1 // SPDX-License-Identifier: GPL-2.0
3 * Management Component Transport Protocol (MCTP)
5 * Copyright (c) 2021 Code Construct
6 * Copyright (c) 2021 Google
9 #include <linux/if_arp.h>
10 #include <linux/net.h>
11 #include <linux/mctp.h>
12 #include <linux/module.h>
13 #include <linux/socket.h>
16 #include <net/mctpdevice.h>
19 #define CREATE_TRACE_POINTS
20 #include <trace/events/mctp.h>
22 /* socket implementation */
/*
 * mctp_release - .release handler for MCTP sockets (wired up in
 * mctp_dgram_ops).
 *
 * Passes the socket's struct sock to the protocol close hook
 * (sk->sk_prot->close) with a zero timeout.
 *
 * NOTE(review): this excerpt is missing source lines (the embedded
 * numbering skips values). Any NULL check on sk, detaching of sock->sk
 * and the return statement are not visible here. Confirm against the
 * full file before editing.
 */
24 static int mctp_release(struct socket *sock)
26 struct sock *sk = sock->sk;
30 sk->sk_prot->close(sk, 0);
/*
 * mctp_bind - .bind handler: record the local network, address and
 * message type for this socket, then hash the socket via
 * sk->sk_prot->hash.
 *
 * The request is rejected when:
 *   - addrlen is shorter than a struct sockaddr_mctp,
 *   - the family is not AF_MCTP, or
 *   - the caller lacks CAP_NET_BIND_SERVICE.
 *
 * NOTE(review): the specific error codes, any socket locking and the
 * check implied by the "allow rebind" TODO are on lines missing from
 * this excerpt. Confirm against the full file before editing.
 */
36 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
38 struct sock *sk = sock->sk;
39 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
40 struct sockaddr_mctp *smctp;
/* must be at least a full struct sockaddr_mctp before it is cast below */
43 if (addrlen < sizeof(*smctp))
46 if (addr->sa_family != AF_MCTP)
49 if (!capable(CAP_NET_BIND_SERVICE))
52 /* it's a valid sockaddr for MCTP, cast and do protocol checks */
53 smctp = (struct sockaddr_mctp *)addr;
57 /* TODO: allow rebind */
/*
 * Store the bind parameters on the socket; presumably the receive path
 * matches incoming messages against them (confirm).
 */
62 msk->bind_net = smctp->smctp_network;
63 msk->bind_addr = smctp->smctp_addr.s_addr;
64 msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
66 rc = sk->sk_prot->hash(sk);
/*
 * mctp_sendmsg - .sendmsg handler: send one MCTP message.
 *
 * The destination is taken from msg_name as a struct sockaddr_mctp. The
 * TODO below indicates connect()ed sockets are not handled.
 *
 * The address is validated as follows:
 *   - it must be at least sizeof(struct sockaddr_mctp) and AF_MCTP;
 *   - the tag may only use the MCTP_TAG_MASK and MCTP_TAG_OWNER bits;
 *   - the caller must hold CAP_NET_RAW.
 *
 * An skb is allocated with hlen bytes of headroom. The message type is
 * written as the first payload byte, followed by len bytes of user data.
 *
 * If msk->addr_ext is set (the flag reported by MCTP_OPT_ADDR_EXT) and the
 * caller passed a full struct sockaddr_mctp_ext, the interface index and
 * hardware address are copied into the skb control block. A route is
 * looked up for (network, destination address). The message is then
 * handed to mctp_local_output().
 *
 * NOTE(review): lines are missing from this excerpt. These include the
 * error returns, the skb/cb declarations and assignment, the branch
 * structure around the route lookup and the cleanup path. Whether the
 * route lookup runs only on the non-extended path cannot be seen here.
 */
74 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
76 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
/*
 * Headroom for the MCTP header plus MCTP_HEADER_MAXLEN bytes, presumably
 * reserved for the link-layer header (confirm).
 */
77 const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
78 int rc, addrlen = msg->msg_namelen;
79 struct sock *sk = sock->sk;
80 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
81 struct mctp_skb_cb *cb;
82 struct mctp_route *rt;
86 if (addrlen < sizeof(struct sockaddr_mctp))
88 if (addr->smctp_family != AF_MCTP)
/* only the tag value bits and the tag-owner flag may be set */
90 if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
94 /* TODO: connect()ed sockets */
98 if (!capable(CAP_NET_RAW))
/* MCTP_NET_ANY: substitute this namespace's default network */
101 if (addr->smctp_network == MCTP_NET_ANY)
102 addr->smctp_network = mctp_default_net(sock_net(sk));
/* +1 for the message-type byte that prefixes the payload */
104 skb = sock_alloc_send_skb(sk, hlen + 1 + len,
105 msg->msg_flags & MSG_DONTWAIT, &rc);
109 skb_reserve(skb, hlen);
111 /* set type as first byte in payload */
112 *(u8 *)skb_put(skb, 1) = addr->smctp_type;
114 rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
120 cb->net = addr->smctp_network;
122 /* direct addressing */
123 if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
124 DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
125 extaddr, msg->msg_name);
/* reject hardware addresses that would overflow cb->haddr */
127 if (extaddr->smctp_halen > sizeof(cb->haddr)) {
132 cb->ifindex = extaddr->smctp_ifindex;
133 cb->halen = extaddr->smctp_halen;
134 memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
/* route lookup keyed on (network, destination address) */
138 rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
139 addr->smctp_addr.s_addr);
146 rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
/*
 * mctp_recvmsg - .recvmsg handler: receive one MCTP message.
 *
 * Only MSG_DONTWAIT, MSG_TRUNC and MSG_PEEK are accepted in flags.
 *
 * The first byte of the queued skb is the message type. It is stripped,
 * and the remaining payload is copied to msg starting at skb offset 1.
 * MSG_TRUNC is set in msg->msg_flags on a condition not shown here,
 * presumably len < msglen.
 *
 * The caller's address buffer receives:
 *   - the network,
 *   - the source address from the MCTP header,
 *   - the message type,
 *   - the tag: tag bits plus the TO flag.
 * In the extended-addressing case, the interface index and hardware
 * address are also reported.
 *
 * NOTE(review): lines are missing from this excerpt. These include the
 * error handling after skb_recv_datagram, the truncation test, the
 * conditions guarding both address fills and the final return value.
 * Confirm against the full file.
 */
156 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
159 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
160 struct sock *sk = sock->sk;
161 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
167 if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
170 skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
179 /* extract message type, remove from data */
180 type = *((u8 *)skb->data);
181 msglen = skb->len - 1;
/* msglen excludes the type byte; payload copy below starts at offset 1 */
184 msg->msg_flags |= MSG_TRUNC;
188 rc = skb_copy_datagram_msg(skb, 1, msg, len);
192 sock_recv_ts_and_drops(msg, sk, skb);
195 struct mctp_skb_cb *cb = mctp_cb(skb);
196 /* TODO: expand mctp_skb_cb for header fields? */
197 struct mctp_hdr *hdr = mctp_hdr(skb);
/*
 * NOTE(review): only the named fields are written below. Confirm that any
 * padding in struct sockaddr_mctp is cleared before it reaches userspace.
 */
199 addr = msg->msg_name;
200 addr->smctp_family = AF_MCTP;
201 addr->smctp_network = cb->net;
202 addr->smctp_addr.s_addr = hdr->src;
203 addr->smctp_type = type;
204 addr->smctp_tag = hdr->flags_seq_tag &
205 (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
206 msg->msg_namelen = sizeof(*addr);
209 DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
211 msg->msg_namelen = sizeof(*ae);
212 ae->smctp_ifindex = cb->ifindex;
213 ae->smctp_halen = cb->halen;
/*
 * The whole haddr field is zeroed, then halen bytes are copied in.
 * NOTE(review): this relies on cb->halen <= sizeof(ae->smctp_haddr);
 * no bound check is visible here. Confirm where cb->halen is limited.
 */
214 memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
215 memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
/*
 * MSG_TRUNC requested by the caller: presumably the full message length
 * (msglen) is returned instead of the copied length. The assignment is
 * not visible in this excerpt.
 */
221 if (flags & MSG_TRUNC)
225 skb_free_datagram(sk, skb);
/*
 * mctp_setsockopt - .setsockopt handler for MCTP sockets.
 *
 * Only level SOL_MCTP is handled. MCTP_OPT_ADDR_EXT takes an int. optlen
 * must be exactly sizeof(int), and the value is copied from optval.
 *
 * NOTE(review): the option presumably sets msk->addr_ext (it is read back
 * as msk->addr_ext by the getsockopt handler). The assignment and all
 * return paths are on lines not shown in this excerpt.
 */
229 static int mctp_setsockopt(struct socket *sock, int level, int optname,
230 sockptr_t optval, unsigned int optlen)
232 struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
235 if (level != SOL_MCTP)
238 if (optname == MCTP_OPT_ADDR_EXT) {
239 if (optlen != sizeof(int))
241 if (copy_from_sockptr(&val, optval, sizeof(int)))
/*
 * mctp_getsockopt - .getsockopt handler for MCTP sockets.
 *
 * Only level SOL_MCTP is handled. The caller's buffer length is read from
 * optlen. MCTP_OPT_ADDR_EXT requires that length to be exactly
 * sizeof(int). It returns msk->addr_ext normalised to 0 or 1.
 *
 * NOTE(review): error returns and the final return are on lines not shown
 * in this excerpt.
 */
250 static int mctp_getsockopt(struct socket *sock, int level, int optname,
251 char __user *optval, int __user *optlen)
253 struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
256 if (level != SOL_MCTP)
259 if (get_user(len, optlen))
262 if (optname == MCTP_OPT_ADDR_EXT) {
263 if (len != sizeof(int))
/* !! collapses any non-zero flag value to 1 */
265 val = !!msk->addr_ext;
266 if (copy_to_user(optval, &val, len))
/*
 * Socket-layer operations for SOCK_DGRAM MCTP sockets.
 *
 * Unsupported operations use the generic sock_no_* stubs: connect,
 * socketpair, accept, getname, ioctl, listen, shutdown, mmap and
 * sendpage. Polling uses datagram_poll and timestamp queries use
 * sock_gettstamp.
 *
 * NOTE(review): some initializer entries are missing from this excerpt.
 * For example, mctp_bind is not referenced in the visible lines. Confirm
 * against the full file.
 */
274 static const struct proto_ops mctp_dgram_ops = {
276 .release = mctp_release,
278 .connect = sock_no_connect,
279 .socketpair = sock_no_socketpair,
280 .accept = sock_no_accept,
281 .getname = sock_no_getname,
282 .poll = datagram_poll,
283 .ioctl = sock_no_ioctl,
284 .gettstamp = sock_gettstamp,
285 .listen = sock_no_listen,
286 .shutdown = sock_no_shutdown,
287 .setsockopt = mctp_setsockopt,
288 .getsockopt = mctp_getsockopt,
289 .sendmsg = mctp_sendmsg,
290 .recvmsg = mctp_recvmsg,
291 .mmap = sock_no_mmap,
292 .sendpage = sock_no_sendpage,
/*
 * mctp_sk_expire_keys - key_expiry timer callback for an MCTP socket.
 *
 * Walks msk->keys (linked through key->sklist) under net->mctp.keys_lock,
 * taken with spin_lock_irqsave. Each key is examined under key->lock:
 *   - A key whose expiry time has passed is traced as
 *     MCTP_TRACE_KEY_TIMEOUT and unlinked from both its hlist and sklist.
 *   - For every other key, the earliest expiry time is tracked.
 * If any unexpired key remains, the timer is re-armed for that earliest
 * expiry.
 *
 * Lock order: keys_lock, then key->lock.
 *
 * NOTE(review): lines are missing from this excerpt. These include the
 * release of an expired key, the loop continue, and the else preceding the
 * first next_expiry assignment. Confirm against the full file.
 */
295 static void mctp_sk_expire_keys(struct timer_list *timer)
297 struct mctp_sock *msk = container_of(timer, struct mctp_sock,
299 struct net *net = sock_net(&msk->sk);
300 unsigned long next_expiry, flags;
301 struct mctp_sk_key *key;
302 struct hlist_node *tmp;
303 bool next_expiry_valid = false;
305 spin_lock_irqsave(&net->mctp.keys_lock, flags);
307 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
308 spin_lock(&key->lock);
/* !time_after_eq(expiry, jiffies): the expiry time is strictly in the past */
310 if (!time_after_eq(key->expiry, jiffies)) {
311 trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT);
313 hlist_del_rcu(&key->hlist);
314 hlist_del_rcu(&key->sklist);
315 spin_unlock(&key->lock);
320 if (next_expiry_valid) {
321 if (time_before(key->expiry, next_expiry))
322 next_expiry = key->expiry;
324 next_expiry = key->expiry;
325 next_expiry_valid = true;
327 spin_unlock(&key->lock);
330 spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
/* re-arm only while at least one key is still pending expiry */
332 if (next_expiry_valid)
333 mod_timer(timer, next_expiry);
/*
 * mctp_sk_init - proto .init hook.
 *
 * Sets up the per-socket key list and the key expiry timer, whose callback
 * is mctp_sk_expire_keys.
 *
 * NOTE(review): the return statement is on a line not shown in this
 * excerpt.
 */
336 static int mctp_sk_init(struct sock *sk)
338 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
340 INIT_HLIST_HEAD(&msk->keys);
341 timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
/*
 * mctp_sk_close - proto .close hook.
 *
 * Stops the key expiry timer before the common socket release.
 * del_timer_sync also waits for a running callback to finish. The timeout
 * argument is not used in the visible lines.
 */
345 static void mctp_sk_close(struct sock *sk, long timeout)
347 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
349 del_timer_sync(&msk->key_expiry);
350 sk_common_release(sk);
/*
 * mctp_sk_hash - proto .hash hook.
 *
 * Adds the socket to the per-namespace net->mctp.binds list. Writers are
 * serialised by the net->mctp.bind_lock mutex. The RCU list helper is
 * used, presumably so readers can walk binds under RCU (confirm).
 *
 * NOTE(review): the return statement is on a line not shown in this
 * excerpt.
 */
353 static int mctp_sk_hash(struct sock *sk)
355 struct net *net = sock_net(sk);
357 mutex_lock(&net->mctp.bind_lock);
358 sk_add_node_rcu(sk, &net->mctp.binds);
359 mutex_unlock(&net->mctp.bind_lock);
/*
 * mctp_sk_unhash - proto .unhash hook.
 *
 * First removes the socket from the bind list, under bind_lock. Then,
 * under net->mctp.keys_lock (spin_lock_irqsave), it detaches every key on
 * msk->keys:
 *   - unlinks the key from sklist and hlist;
 *   - traces it as MCTP_TRACE_KEY_CLOSED;
 *   - under key->lock, frees any partial reassembly skb and sets
 *     reasm_dead.
 *
 * NOTE(review): the expiry timer path unlinks the same hlist/sklist nodes
 * with hlist_del_rcu(), while this path uses hlist_del(). Confirm whether
 * any reader walks these lists without holding keys_lock. The flags
 * declaration and the unref after the trailing comment are on lines
 * missing from this excerpt.
 */
364 static void mctp_sk_unhash(struct sock *sk)
366 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
367 struct net *net = sock_net(sk);
368 struct mctp_sk_key *key;
369 struct hlist_node *tmp;
372 /* remove from any type-based binds */
373 mutex_lock(&net->mctp.bind_lock);
374 sk_del_node_init_rcu(sk);
375 mutex_unlock(&net->mctp.bind_lock);
377 /* remove tag allocations */
378 spin_lock_irqsave(&net->mctp.keys_lock, flags);
379 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
380 hlist_del(&key->sklist);
381 hlist_del(&key->hlist);
383 trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);
385 spin_lock(&key->lock);
/*
 * Drop any partially reassembled message and mark reassembly dead,
 * presumably so no further fragments are accepted for this key (confirm).
 */
387 kfree_skb(key->reasm_head);
388 key->reasm_head = NULL;
389 key->reasm_dead = true;
391 spin_unlock(&key->lock);
393 /* key is no longer on the lookup lists, unref */
396 spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
/*
 * Protocol descriptor for MCTP sockets.
 *
 * The per-socket object is a struct mctp_sock. The init/close/hash/unhash
 * hooks are the mctp_sk_* functions.
 *
 * NOTE(review): some initializer lines are missing from this excerpt.
 */
399 static struct proto mctp_proto = {
401 .owner = THIS_MODULE,
402 .obj_size = sizeof(struct mctp_sock),
403 .init = mctp_sk_init,
404 .close = mctp_sk_close,
405 .hash = mctp_sk_hash,
406 .unhash = mctp_sk_unhash,
/*
 * mctp_pf_create - net_proto_family .create hook for MCTP sockets.
 *
 * Only SOCK_DGRAM sockets are supported (-ESOCKTNOSUPPORT otherwise).
 * -EPROTONOSUPPORT is returned on a protocol condition not visible here.
 *
 * Steps:
 *   1. select mctp_dgram_ops and mark the socket SS_UNCONNECTED;
 *   2. allocate the struct sock with sk_alloc(PF_MCTP);
 *   3. initialise it with sock_init_data();
 *   4. run the protocol's .init hook, if one is set.
 *
 * NOTE(review): lines are missing from this excerpt. These include the
 * protocol test, the proto and sock->ops assignments, the sk_alloc failure
 * check and the error path taken when .init fails. Confirm against the
 * full file.
 */
409 static int mctp_pf_create(struct net *net, struct socket *sock,
410 int protocol, int kern)
412 const struct proto_ops *ops;
418 return -EPROTONOSUPPORT;
420 /* only datagram sockets are supported */
421 if (sock->type != SOCK_DGRAM)
422 return -ESOCKTNOSUPPORT;
425 ops = &mctp_dgram_ops;
427 sock->state = SS_UNCONNECTED;
430 sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
434 sock_init_data(sock, sk);
/* .init is optional in struct proto, hence the NULL check */
437 if (sk->sk_prot->init)
438 rc = sk->sk_prot->init(sk);
/*
 * Address-family descriptor registered with sock_register() in mctp_init.
 * New sockets are created by mctp_pf_create.
 *
 * NOTE(review): some initializer lines are missing from this excerpt.
 */
451 static struct net_proto_family mctp_pf = {
453 .create = mctp_pf_create,
454 .owner = THIS_MODULE,
/*
 * mctp_init - MCTP core initialisation.
 *
 * Setup order:
 *   1. register the socket family (mctp_pf);
 *   2. register the proto (mctp_proto);
 *   3. initialise routing state (mctp_routes_init);
 *   4. initialise neighbour state (mctp_neigh_init).
 * The error path unregisters in reverse order.
 *
 * The BUILD_BUG_ONs pin the uapi tag constants to the wire-header
 * definitions. sendmsg checks user tags against MCTP_TAG_MASK and
 * MCTP_TAG_OWNER, and recvmsg copies MCTP_HDR_TAG_MASK and
 * MCTP_HDR_FLAG_TO bits straight into smctp_tag, so both sets must match.
 *
 * NOTE(review): several lines are missing from this excerpt, including
 * the rc checks, the error labels and the return.
 */
457 static __init int mctp_init(void)
461 /* ensure our uapi tag definitions match the header format */
462 BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
463 BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
465 pr_info("mctp: management component transport protocol core\n");
467 rc = sock_register(&mctp_pf);
471 rc = proto_register(&mctp_proto, 0);
475 rc = mctp_routes_init();
477 goto err_unreg_proto;
/*
 * NOTE(review): a mctp_neigh_init() failure jumps to the same
 * err_unreg_proto label as a routes failure. No mctp_routes_exit() call
 * is visible on that path, so state set up by mctp_routes_init() may be
 * leaked. Consider a dedicated label that tears down routes first.
 */
479 rc = mctp_neigh_init();
481 goto err_unreg_proto;
488 proto_unregister(&mctp_proto);
490 sock_unregister(PF_MCTP);
/*
 * mctp_exit - MCTP core teardown.
 *
 * Unregisters the MCTP proto and then the PF_MCTP socket family, the
 * reverse of the registration order in mctp_init.
 *
 * NOTE(review): other teardown calls, if any, are on lines missing from
 * this excerpt.
 */
495 static __exit void mctp_exit(void)
500 proto_unregister(&mctp_proto);
501 sock_unregister(PF_MCTP);
/* mctp_init runs at subsys_initcall level; mctp_exit runs on module unload */
504 subsys_initcall(mctp_init);
505 module_exit(mctp_exit);
/*
 * Module metadata. The netproto alias ties this module to PF_MCTP so it
 * can be loaded on demand when that family is requested.
 */
507 MODULE_DESCRIPTION("MCTP core");
508 MODULE_LICENSE("GPL v2");
509 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");
511 MODULE_ALIAS_NETPROTO(PF_MCTP);