Merge tag 'tag-chrome-platform-for-v5.16' of git://git.kernel.org/pub/scm/linux/kerne...
[linux-2.6-microblaze.git] / net / mctp / af_mctp.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP)
4  *
5  * Copyright (c) 2021 Code Construct
6  * Copyright (c) 2021 Google
7  */
8
9 #include <linux/if_arp.h>
10 #include <linux/net.h>
11 #include <linux/mctp.h>
12 #include <linux/module.h>
13 #include <linux/socket.h>
14
15 #include <net/mctp.h>
16 #include <net/mctpdevice.h>
17 #include <net/sock.h>
18
19 #define CREATE_TRACE_POINTS
20 #include <trace/events/mctp.h>
21
22 /* socket implementation */
23
24 static int mctp_release(struct socket *sock)
25 {
26         struct sock *sk = sock->sk;
27
28         if (sk) {
29                 sock->sk = NULL;
30                 sk->sk_prot->close(sk, 0);
31         }
32
33         return 0;
34 }
35
36 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
37 {
38         struct sock *sk = sock->sk;
39         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
40         struct sockaddr_mctp *smctp;
41         int rc;
42
43         if (addrlen < sizeof(*smctp))
44                 return -EINVAL;
45
46         if (addr->sa_family != AF_MCTP)
47                 return -EAFNOSUPPORT;
48
49         if (!capable(CAP_NET_BIND_SERVICE))
50                 return -EACCES;
51
52         /* it's a valid sockaddr for MCTP, cast and do protocol checks */
53         smctp = (struct sockaddr_mctp *)addr;
54
55         lock_sock(sk);
56
57         /* TODO: allow rebind */
58         if (sk_hashed(sk)) {
59                 rc = -EADDRINUSE;
60                 goto out_release;
61         }
62         msk->bind_net = smctp->smctp_network;
63         msk->bind_addr = smctp->smctp_addr.s_addr;
64         msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
65
66         rc = sk->sk_prot->hash(sk);
67
68 out_release:
69         release_sock(sk);
70
71         return rc;
72 }
73
74 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
75 {
76         DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
77         const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
78         int rc, addrlen = msg->msg_namelen;
79         struct sock *sk = sock->sk;
80         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
81         struct mctp_skb_cb *cb;
82         struct mctp_route *rt;
83         struct sk_buff *skb;
84
85         if (addr) {
86                 if (addrlen < sizeof(struct sockaddr_mctp))
87                         return -EINVAL;
88                 if (addr->smctp_family != AF_MCTP)
89                         return -EINVAL;
90                 if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
91                         return -EINVAL;
92
93         } else {
94                 /* TODO: connect()ed sockets */
95                 return -EDESTADDRREQ;
96         }
97
98         if (!capable(CAP_NET_RAW))
99                 return -EACCES;
100
101         if (addr->smctp_network == MCTP_NET_ANY)
102                 addr->smctp_network = mctp_default_net(sock_net(sk));
103
104         skb = sock_alloc_send_skb(sk, hlen + 1 + len,
105                                   msg->msg_flags & MSG_DONTWAIT, &rc);
106         if (!skb)
107                 return rc;
108
109         skb_reserve(skb, hlen);
110
111         /* set type as fist byte in payload */
112         *(u8 *)skb_put(skb, 1) = addr->smctp_type;
113
114         rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
115         if (rc < 0)
116                 goto err_free;
117
118         /* set up cb */
119         cb = __mctp_cb(skb);
120         cb->net = addr->smctp_network;
121
122         /* direct addressing */
123         if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
124                 DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
125                                  extaddr, msg->msg_name);
126
127                 if (extaddr->smctp_halen > sizeof(cb->haddr)) {
128                         rc = -EINVAL;
129                         goto err_free;
130                 }
131
132                 cb->ifindex = extaddr->smctp_ifindex;
133                 cb->halen = extaddr->smctp_halen;
134                 memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
135
136                 rt = NULL;
137         } else {
138                 rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
139                                        addr->smctp_addr.s_addr);
140                 if (!rt) {
141                         rc = -EHOSTUNREACH;
142                         goto err_free;
143                 }
144         }
145
146         rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
147                                addr->smctp_tag);
148
149         return rc ? : len;
150
151 err_free:
152         kfree_skb(skb);
153         return rc;
154 }
155
156 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
157                         int flags)
158 {
159         DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
160         struct sock *sk = sock->sk;
161         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
162         struct sk_buff *skb;
163         size_t msglen;
164         u8 type;
165         int rc;
166
167         if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
168                 return -EOPNOTSUPP;
169
170         skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
171         if (!skb)
172                 return rc;
173
174         if (!skb->len) {
175                 rc = 0;
176                 goto out_free;
177         }
178
179         /* extract message type, remove from data */
180         type = *((u8 *)skb->data);
181         msglen = skb->len - 1;
182
183         if (len < msglen)
184                 msg->msg_flags |= MSG_TRUNC;
185         else
186                 len = msglen;
187
188         rc = skb_copy_datagram_msg(skb, 1, msg, len);
189         if (rc < 0)
190                 goto out_free;
191
192         sock_recv_ts_and_drops(msg, sk, skb);
193
194         if (addr) {
195                 struct mctp_skb_cb *cb = mctp_cb(skb);
196                 /* TODO: expand mctp_skb_cb for header fields? */
197                 struct mctp_hdr *hdr = mctp_hdr(skb);
198
199                 addr = msg->msg_name;
200                 addr->smctp_family = AF_MCTP;
201                 addr->smctp_network = cb->net;
202                 addr->smctp_addr.s_addr = hdr->src;
203                 addr->smctp_type = type;
204                 addr->smctp_tag = hdr->flags_seq_tag &
205                                         (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
206                 msg->msg_namelen = sizeof(*addr);
207
208                 if (msk->addr_ext) {
209                         DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
210                                          msg->msg_name);
211                         msg->msg_namelen = sizeof(*ae);
212                         ae->smctp_ifindex = cb->ifindex;
213                         ae->smctp_halen = cb->halen;
214                         memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
215                         memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
216                 }
217         }
218
219         rc = len;
220
221         if (flags & MSG_TRUNC)
222                 rc = msglen;
223
224 out_free:
225         skb_free_datagram(sk, skb);
226         return rc;
227 }
228
229 static int mctp_setsockopt(struct socket *sock, int level, int optname,
230                            sockptr_t optval, unsigned int optlen)
231 {
232         struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
233         int val;
234
235         if (level != SOL_MCTP)
236                 return -EINVAL;
237
238         if (optname == MCTP_OPT_ADDR_EXT) {
239                 if (optlen != sizeof(int))
240                         return -EINVAL;
241                 if (copy_from_sockptr(&val, optval, sizeof(int)))
242                         return -EFAULT;
243                 msk->addr_ext = val;
244                 return 0;
245         }
246
247         return -ENOPROTOOPT;
248 }
249
250 static int mctp_getsockopt(struct socket *sock, int level, int optname,
251                            char __user *optval, int __user *optlen)
252 {
253         struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
254         int len, val;
255
256         if (level != SOL_MCTP)
257                 return -EINVAL;
258
259         if (get_user(len, optlen))
260                 return -EFAULT;
261
262         if (optname == MCTP_OPT_ADDR_EXT) {
263                 if (len != sizeof(int))
264                         return -EINVAL;
265                 val = !!msk->addr_ext;
266                 if (copy_to_user(optval, &val, len))
267                         return -EFAULT;
268                 return 0;
269         }
270
271         return -EINVAL;
272 }
273
274 static const struct proto_ops mctp_dgram_ops = {
275         .family         = PF_MCTP,
276         .release        = mctp_release,
277         .bind           = mctp_bind,
278         .connect        = sock_no_connect,
279         .socketpair     = sock_no_socketpair,
280         .accept         = sock_no_accept,
281         .getname        = sock_no_getname,
282         .poll           = datagram_poll,
283         .ioctl          = sock_no_ioctl,
284         .gettstamp      = sock_gettstamp,
285         .listen         = sock_no_listen,
286         .shutdown       = sock_no_shutdown,
287         .setsockopt     = mctp_setsockopt,
288         .getsockopt     = mctp_getsockopt,
289         .sendmsg        = mctp_sendmsg,
290         .recvmsg        = mctp_recvmsg,
291         .mmap           = sock_no_mmap,
292         .sendpage       = sock_no_sendpage,
293 };
294
295 static void mctp_sk_expire_keys(struct timer_list *timer)
296 {
297         struct mctp_sock *msk = container_of(timer, struct mctp_sock,
298                                              key_expiry);
299         struct net *net = sock_net(&msk->sk);
300         unsigned long next_expiry, flags;
301         struct mctp_sk_key *key;
302         struct hlist_node *tmp;
303         bool next_expiry_valid = false;
304
305         spin_lock_irqsave(&net->mctp.keys_lock, flags);
306
307         hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
308                 spin_lock(&key->lock);
309
310                 if (!time_after_eq(key->expiry, jiffies)) {
311                         trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT);
312                         key->valid = false;
313                         hlist_del_rcu(&key->hlist);
314                         hlist_del_rcu(&key->sklist);
315                         spin_unlock(&key->lock);
316                         mctp_key_unref(key);
317                         continue;
318                 }
319
320                 if (next_expiry_valid) {
321                         if (time_before(key->expiry, next_expiry))
322                                 next_expiry = key->expiry;
323                 } else {
324                         next_expiry = key->expiry;
325                         next_expiry_valid = true;
326                 }
327                 spin_unlock(&key->lock);
328         }
329
330         spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
331
332         if (next_expiry_valid)
333                 mod_timer(timer, next_expiry);
334 }
335
336 static int mctp_sk_init(struct sock *sk)
337 {
338         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
339
340         INIT_HLIST_HEAD(&msk->keys);
341         timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
342         return 0;
343 }
344
345 static void mctp_sk_close(struct sock *sk, long timeout)
346 {
347         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
348
349         del_timer_sync(&msk->key_expiry);
350         sk_common_release(sk);
351 }
352
353 static int mctp_sk_hash(struct sock *sk)
354 {
355         struct net *net = sock_net(sk);
356
357         mutex_lock(&net->mctp.bind_lock);
358         sk_add_node_rcu(sk, &net->mctp.binds);
359         mutex_unlock(&net->mctp.bind_lock);
360
361         return 0;
362 }
363
364 static void mctp_sk_unhash(struct sock *sk)
365 {
366         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
367         struct net *net = sock_net(sk);
368         struct mctp_sk_key *key;
369         struct hlist_node *tmp;
370         unsigned long flags;
371
372         /* remove from any type-based binds */
373         mutex_lock(&net->mctp.bind_lock);
374         sk_del_node_init_rcu(sk);
375         mutex_unlock(&net->mctp.bind_lock);
376
377         /* remove tag allocations */
378         spin_lock_irqsave(&net->mctp.keys_lock, flags);
379         hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
380                 hlist_del(&key->sklist);
381                 hlist_del(&key->hlist);
382
383                 trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);
384
385                 spin_lock(&key->lock);
386                 if (key->reasm_head)
387                         kfree_skb(key->reasm_head);
388                 key->reasm_head = NULL;
389                 key->reasm_dead = true;
390                 key->valid = false;
391                 spin_unlock(&key->lock);
392
393                 /* key is no longer on the lookup lists, unref */
394                 mctp_key_unref(key);
395         }
396         spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
397 }
398
399 static struct proto mctp_proto = {
400         .name           = "MCTP",
401         .owner          = THIS_MODULE,
402         .obj_size       = sizeof(struct mctp_sock),
403         .init           = mctp_sk_init,
404         .close          = mctp_sk_close,
405         .hash           = mctp_sk_hash,
406         .unhash         = mctp_sk_unhash,
407 };
408
409 static int mctp_pf_create(struct net *net, struct socket *sock,
410                           int protocol, int kern)
411 {
412         const struct proto_ops *ops;
413         struct proto *proto;
414         struct sock *sk;
415         int rc;
416
417         if (protocol)
418                 return -EPROTONOSUPPORT;
419
420         /* only datagram sockets are supported */
421         if (sock->type != SOCK_DGRAM)
422                 return -ESOCKTNOSUPPORT;
423
424         proto = &mctp_proto;
425         ops = &mctp_dgram_ops;
426
427         sock->state = SS_UNCONNECTED;
428         sock->ops = ops;
429
430         sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
431         if (!sk)
432                 return -ENOMEM;
433
434         sock_init_data(sock, sk);
435
436         rc = 0;
437         if (sk->sk_prot->init)
438                 rc = sk->sk_prot->init(sk);
439
440         if (rc)
441                 goto err_sk_put;
442
443         return 0;
444
445 err_sk_put:
446         sock_orphan(sk);
447         sock_put(sk);
448         return rc;
449 }
450
451 static struct net_proto_family mctp_pf = {
452         .family = PF_MCTP,
453         .create = mctp_pf_create,
454         .owner = THIS_MODULE,
455 };
456
457 static __init int mctp_init(void)
458 {
459         int rc;
460
461         /* ensure our uapi tag definitions match the header format */
462         BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
463         BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
464
465         pr_info("mctp: management component transport protocol core\n");
466
467         rc = sock_register(&mctp_pf);
468         if (rc)
469                 return rc;
470
471         rc = proto_register(&mctp_proto, 0);
472         if (rc)
473                 goto err_unreg_sock;
474
475         rc = mctp_routes_init();
476         if (rc)
477                 goto err_unreg_proto;
478
479         rc = mctp_neigh_init();
480         if (rc)
481                 goto err_unreg_proto;
482
483         mctp_device_init();
484
485         return 0;
486
487 err_unreg_proto:
488         proto_unregister(&mctp_proto);
489 err_unreg_sock:
490         sock_unregister(PF_MCTP);
491
492         return rc;
493 }
494
495 static __exit void mctp_exit(void)
496 {
497         mctp_device_exit();
498         mctp_neigh_exit();
499         mctp_routes_exit();
500         proto_unregister(&mctp_proto);
501         sock_unregister(PF_MCTP);
502 }
503
504 subsys_initcall(mctp_init);
505 module_exit(mctp_exit);
506
507 MODULE_DESCRIPTION("MCTP core");
508 MODULE_LICENSE("GPL v2");
509 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");
510
511 MODULE_ALIAS_NETPROTO(PF_MCTP);