mptcp: fix use-after-free on tcp fallback
[linux-2.6-microblaze.git] / net / mptcp / protocol.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2017 - 2019, Intel Corporation.
5  */
6
7 #define pr_fmt(fmt) "MPTCP: " fmt
8
9 #include <linux/kernel.h>
10 #include <linux/module.h>
11 #include <linux/netdevice.h>
12 #include <linux/sched/signal.h>
13 #include <linux/atomic.h>
14 #include <net/sock.h>
15 #include <net/inet_common.h>
16 #include <net/inet_hashtables.h>
17 #include <net/protocol.h>
18 #include <net/tcp.h>
19 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
20 #include <net/transp_v6.h>
21 #endif
22 #include <net/mptcp.h>
23 #include "protocol.h"
24
25 #define MPTCP_SAME_STATE TCP_MAX_STATES
26
27 /* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
28  * completed yet or has failed, return the subflow socket.
29  * Otherwise return NULL.
30  */
31 static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
32 {
33         if (!msk->subflow || READ_ONCE(msk->can_ack))
34                 return NULL;
35
36         return msk->subflow;
37 }
38
39 static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk)
40 {
41         return msk->first && !sk_is_mptcp(msk->first);
42 }
43
44 static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
45 {
46         sock_owned_by_me((const struct sock *)msk);
47
48         if (likely(!__mptcp_needs_tcp_fallback(msk)))
49                 return NULL;
50
51         if (msk->subflow) {
52                 release_sock((struct sock *)msk);
53                 return msk->subflow;
54         }
55
56         return NULL;
57 }
58
59 static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
60 {
61         return !msk->first;
62 }
63
64 static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
65 {
66         struct mptcp_subflow_context *subflow;
67         struct sock *sk = (struct sock *)msk;
68         struct socket *ssock;
69         int err;
70
71         ssock = __mptcp_nmpc_socket(msk);
72         if (ssock)
73                 goto set_state;
74
75         if (!__mptcp_can_create_subflow(msk))
76                 return ERR_PTR(-EINVAL);
77
78         err = mptcp_subflow_create_socket(sk, &ssock);
79         if (err)
80                 return ERR_PTR(err);
81
82         msk->first = ssock->sk;
83         msk->subflow = ssock;
84         subflow = mptcp_subflow_ctx(ssock->sk);
85         list_add(&subflow->node, &msk->conn_list);
86         subflow->request_mptcp = 1;
87
88 set_state:
89         if (state != MPTCP_SAME_STATE)
90                 inet_sk_state_store(sk, state);
91         return ssock;
92 }
93
94 static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk)
95 {
96         struct mptcp_subflow_context *subflow;
97
98         sock_owned_by_me((const struct sock *)msk);
99
100         mptcp_for_each_subflow(msk, subflow) {
101                 return mptcp_subflow_tcp_sock(subflow);
102         }
103
104         return NULL;
105 }
106
107 static bool mptcp_ext_cache_refill(struct mptcp_sock *msk)
108 {
109         if (!msk->cached_ext)
110                 msk->cached_ext = __skb_ext_alloc();
111
112         return !!msk->cached_ext;
113 }
114
115 static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk)
116 {
117         struct mptcp_subflow_context *subflow;
118         struct sock *sk = (struct sock *)msk;
119
120         sock_owned_by_me(sk);
121
122         mptcp_for_each_subflow(msk, subflow) {
123                 if (subflow->data_avail)
124                         return mptcp_subflow_tcp_sock(subflow);
125         }
126
127         return NULL;
128 }
129
130 static inline bool mptcp_skb_can_collapse_to(const struct mptcp_sock *msk,
131                                              const struct sk_buff *skb,
132                                              const struct mptcp_ext *mpext)
133 {
134         if (!tcp_skb_can_collapse_to(skb))
135                 return false;
136
137         /* can collapse only if MPTCP level sequence is in order */
138         return mpext && mpext->data_seq + mpext->data_len == msk->write_seq;
139 }
140
141 static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
142                               struct msghdr *msg, long *timeo, int *pmss_now,
143                               int *ps_goal)
144 {
145         int mss_now, avail_size, size_goal, ret;
146         struct mptcp_sock *msk = mptcp_sk(sk);
147         struct mptcp_ext *mpext = NULL;
148         struct sk_buff *skb, *tail;
149         bool can_collapse = false;
150         struct page_frag *pfrag;
151         size_t psize;
152
153         /* use the mptcp page cache so that we can easily move the data
154          * from one substream to another, but do per subflow memory accounting
155          */
156         pfrag = sk_page_frag(sk);
157         while (!sk_page_frag_refill(ssk, pfrag) ||
158                !mptcp_ext_cache_refill(msk)) {
159                 ret = sk_stream_wait_memory(ssk, timeo);
160                 if (ret)
161                         return ret;
162                 if (unlikely(__mptcp_needs_tcp_fallback(msk)))
163                         return 0;
164         }
165
166         /* compute copy limit */
167         mss_now = tcp_send_mss(ssk, &size_goal, msg->msg_flags);
168         *pmss_now = mss_now;
169         *ps_goal = size_goal;
170         avail_size = size_goal;
171         skb = tcp_write_queue_tail(ssk);
172         if (skb) {
173                 mpext = skb_ext_find(skb, SKB_EXT_MPTCP);
174
175                 /* Limit the write to the size available in the
176                  * current skb, if any, so that we create at most a new skb.
177                  * Explicitly tells TCP internals to avoid collapsing on later
178                  * queue management operation, to avoid breaking the ext <->
179                  * SSN association set here
180                  */
181                 can_collapse = (size_goal - skb->len > 0) &&
182                               mptcp_skb_can_collapse_to(msk, skb, mpext);
183                 if (!can_collapse)
184                         TCP_SKB_CB(skb)->eor = 1;
185                 else
186                         avail_size = size_goal - skb->len;
187         }
188         psize = min_t(size_t, pfrag->size - pfrag->offset, avail_size);
189
190         /* Copy to page */
191         pr_debug("left=%zu", msg_data_left(msg));
192         psize = copy_page_from_iter(pfrag->page, pfrag->offset,
193                                     min_t(size_t, msg_data_left(msg), psize),
194                                     &msg->msg_iter);
195         pr_debug("left=%zu", msg_data_left(msg));
196         if (!psize)
197                 return -EINVAL;
198
199         /* tell the TCP stack to delay the push so that we can safely
200          * access the skb after the sendpages call
201          */
202         ret = do_tcp_sendpages(ssk, pfrag->page, pfrag->offset, psize,
203                                msg->msg_flags | MSG_SENDPAGE_NOTLAST);
204         if (ret <= 0)
205                 return ret;
206         if (unlikely(ret < psize))
207                 iov_iter_revert(&msg->msg_iter, psize - ret);
208
209         /* if the tail skb extension is still the cached one, collapsing
210          * really happened. Note: we can't check for 'same skb' as the sk_buff
211          * hdr on tail can be transmitted, freed and re-allocated by the
212          * do_tcp_sendpages() call
213          */
214         tail = tcp_write_queue_tail(ssk);
215         if (mpext && tail && mpext == skb_ext_find(tail, SKB_EXT_MPTCP)) {
216                 WARN_ON_ONCE(!can_collapse);
217                 mpext->data_len += ret;
218                 goto out;
219         }
220
221         skb = tcp_write_queue_tail(ssk);
222         mpext = __skb_ext_set(skb, SKB_EXT_MPTCP, msk->cached_ext);
223         msk->cached_ext = NULL;
224
225         memset(mpext, 0, sizeof(*mpext));
226         mpext->data_seq = msk->write_seq;
227         mpext->subflow_seq = mptcp_subflow_ctx(ssk)->rel_write_seq;
228         mpext->data_len = ret;
229         mpext->use_map = 1;
230         mpext->dsn64 = 1;
231
232         pr_debug("data_seq=%llu subflow_seq=%u data_len=%u dsn64=%d",
233                  mpext->data_seq, mpext->subflow_seq, mpext->data_len,
234                  mpext->dsn64);
235
236 out:
237         pfrag->offset += ret;
238         msk->write_seq += ret;
239         mptcp_subflow_ctx(ssk)->rel_write_seq += ret;
240
241         return ret;
242 }
243
244 static void ssk_check_wmem(struct mptcp_sock *msk, struct sock *ssk)
245 {
246         struct socket *sock;
247
248         if (likely(sk_stream_is_writeable(ssk)))
249                 return;
250
251         sock = READ_ONCE(ssk->sk_socket);
252
253         if (sock) {
254                 clear_bit(MPTCP_SEND_SPACE, &msk->flags);
255                 smp_mb__after_atomic();
256                 /* set NOSPACE only after clearing SEND_SPACE flag */
257                 set_bit(SOCK_NOSPACE, &sock->flags);
258         }
259 }
260
261 static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
262 {
263         int mss_now = 0, size_goal = 0, ret = 0;
264         struct mptcp_sock *msk = mptcp_sk(sk);
265         struct socket *ssock;
266         size_t copied = 0;
267         struct sock *ssk;
268         long timeo;
269
270         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
271                 return -EOPNOTSUPP;
272
273         lock_sock(sk);
274         ssock = __mptcp_tcp_fallback(msk);
275         if (unlikely(ssock)) {
276 fallback:
277                 pr_debug("fallback passthrough");
278                 ret = sock_sendmsg(ssock, msg);
279                 return ret >= 0 ? ret + copied : (copied ? copied : ret);
280         }
281
282         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
283
284         ssk = mptcp_subflow_get(msk);
285         if (!ssk) {
286                 release_sock(sk);
287                 return -ENOTCONN;
288         }
289
290         pr_debug("conn_list->subflow=%p", ssk);
291
292         lock_sock(ssk);
293         while (msg_data_left(msg)) {
294                 ret = mptcp_sendmsg_frag(sk, ssk, msg, &timeo, &mss_now,
295                                          &size_goal);
296                 if (ret < 0)
297                         break;
298                 if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) {
299                         release_sock(ssk);
300                         ssock = __mptcp_tcp_fallback(msk);
301                         goto fallback;
302                 }
303
304                 copied += ret;
305         }
306
307         if (copied) {
308                 ret = copied;
309                 tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle,
310                          size_goal);
311         }
312
313         ssk_check_wmem(msk, ssk);
314         release_sock(ssk);
315         release_sock(sk);
316         return ret;
317 }
318
319 int mptcp_read_actor(read_descriptor_t *desc, struct sk_buff *skb,
320                      unsigned int offset, size_t len)
321 {
322         struct mptcp_read_arg *arg = desc->arg.data;
323         size_t copy_len;
324
325         copy_len = min(desc->count, len);
326
327         if (likely(arg->msg)) {
328                 int err;
329
330                 err = skb_copy_datagram_msg(skb, offset, arg->msg, copy_len);
331                 if (err) {
332                         pr_debug("error path");
333                         desc->error = err;
334                         return err;
335                 }
336         } else {
337                 pr_debug("Flushing skb payload");
338         }
339
340         desc->count -= copy_len;
341
342         pr_debug("consumed %zu bytes, %zu left", copy_len, desc->count);
343         return copy_len;
344 }
345
346 static void mptcp_wait_data(struct sock *sk, long *timeo)
347 {
348         DEFINE_WAIT_FUNC(wait, woken_wake_function);
349         struct mptcp_sock *msk = mptcp_sk(sk);
350
351         add_wait_queue(sk_sleep(sk), &wait);
352         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
353
354         sk_wait_event(sk, timeo,
355                       test_and_clear_bit(MPTCP_DATA_READY, &msk->flags), &wait);
356
357         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
358         remove_wait_queue(sk_sleep(sk), &wait);
359 }
360
361 static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
362                          int nonblock, int flags, int *addr_len)
363 {
364         struct mptcp_sock *msk = mptcp_sk(sk);
365         struct mptcp_subflow_context *subflow;
366         bool more_data_avail = false;
367         struct mptcp_read_arg arg;
368         read_descriptor_t desc;
369         bool wait_data = false;
370         struct socket *ssock;
371         struct tcp_sock *tp;
372         bool done = false;
373         struct sock *ssk;
374         int copied = 0;
375         int target;
376         long timeo;
377
378         if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
379                 return -EOPNOTSUPP;
380
381         lock_sock(sk);
382         ssock = __mptcp_tcp_fallback(msk);
383         if (unlikely(ssock)) {
384 fallback:
385                 pr_debug("fallback-read subflow=%p",
386                          mptcp_subflow_ctx(ssock->sk));
387                 copied = sock_recvmsg(ssock, msg, flags);
388                 return copied;
389         }
390
391         arg.msg = msg;
392         desc.arg.data = &arg;
393         desc.error = 0;
394
395         timeo = sock_rcvtimeo(sk, nonblock);
396
397         len = min_t(size_t, len, INT_MAX);
398         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
399
400         while (!done) {
401                 u32 map_remaining;
402                 int bytes_read;
403
404                 ssk = mptcp_subflow_recv_lookup(msk);
405                 pr_debug("msk=%p ssk=%p", msk, ssk);
406                 if (!ssk)
407                         goto wait_for_data;
408
409                 subflow = mptcp_subflow_ctx(ssk);
410                 tp = tcp_sk(ssk);
411
412                 lock_sock(ssk);
413                 do {
414                         /* try to read as much data as available */
415                         map_remaining = subflow->map_data_len -
416                                         mptcp_subflow_get_map_offset(subflow);
417                         desc.count = min_t(size_t, len - copied, map_remaining);
418                         pr_debug("reading %zu bytes, copied %d", desc.count,
419                                  copied);
420                         bytes_read = tcp_read_sock(ssk, &desc,
421                                                    mptcp_read_actor);
422                         if (bytes_read < 0) {
423                                 if (!copied)
424                                         copied = bytes_read;
425                                 done = true;
426                                 goto next;
427                         }
428
429                         pr_debug("msk ack_seq=%llx -> %llx", msk->ack_seq,
430                                  msk->ack_seq + bytes_read);
431                         msk->ack_seq += bytes_read;
432                         copied += bytes_read;
433                         if (copied >= len) {
434                                 done = true;
435                                 goto next;
436                         }
437                         if (tp->urg_data && tp->urg_seq == tp->copied_seq) {
438                                 pr_err("Urgent data present, cannot proceed");
439                                 done = true;
440                                 goto next;
441                         }
442 next:
443                         more_data_avail = mptcp_subflow_data_available(ssk);
444                 } while (more_data_avail && !done);
445                 release_sock(ssk);
446                 continue;
447
448 wait_for_data:
449                 more_data_avail = false;
450
451                 /* only the master socket status is relevant here. The exit
452                  * conditions mirror closely tcp_recvmsg()
453                  */
454                 if (copied >= target)
455                         break;
456
457                 if (copied) {
458                         if (sk->sk_err ||
459                             sk->sk_state == TCP_CLOSE ||
460                             (sk->sk_shutdown & RCV_SHUTDOWN) ||
461                             !timeo ||
462                             signal_pending(current))
463                                 break;
464                 } else {
465                         if (sk->sk_err) {
466                                 copied = sock_error(sk);
467                                 break;
468                         }
469
470                         if (sk->sk_shutdown & RCV_SHUTDOWN)
471                                 break;
472
473                         if (sk->sk_state == TCP_CLOSE) {
474                                 copied = -ENOTCONN;
475                                 break;
476                         }
477
478                         if (!timeo) {
479                                 copied = -EAGAIN;
480                                 break;
481                         }
482
483                         if (signal_pending(current)) {
484                                 copied = sock_intr_errno(timeo);
485                                 break;
486                         }
487                 }
488
489                 pr_debug("block timeout %ld", timeo);
490                 wait_data = true;
491                 mptcp_wait_data(sk, &timeo);
492                 if (unlikely(__mptcp_tcp_fallback(msk)))
493                         goto fallback;
494         }
495
496         if (more_data_avail) {
497                 if (!test_bit(MPTCP_DATA_READY, &msk->flags))
498                         set_bit(MPTCP_DATA_READY, &msk->flags);
499         } else if (!wait_data) {
500                 clear_bit(MPTCP_DATA_READY, &msk->flags);
501
502                 /* .. race-breaker: ssk might get new data after last
503                  * data_available() returns false.
504                  */
505                 ssk = mptcp_subflow_recv_lookup(msk);
506                 if (unlikely(ssk))
507                         set_bit(MPTCP_DATA_READY, &msk->flags);
508         }
509
510         release_sock(sk);
511         return copied;
512 }
513
514 /* subflow sockets can be either outgoing (connect) or incoming
515  * (accept).
516  *
517  * Outgoing subflows use in-kernel sockets.
518  * Incoming subflows do not have their own 'struct socket' allocated,
519  * so we need to use tcp_close() after detaching them from the mptcp
520  * parent socket.
521  */
522 static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
523                               struct mptcp_subflow_context *subflow,
524                               long timeout)
525 {
526         struct socket *sock = READ_ONCE(ssk->sk_socket);
527
528         list_del(&subflow->node);
529
530         if (sock && sock != sk->sk_socket) {
531                 /* outgoing subflow */
532                 sock_release(sock);
533         } else {
534                 /* incoming subflow */
535                 tcp_close(ssk, timeout);
536         }
537 }
538
539 static int __mptcp_init_sock(struct sock *sk)
540 {
541         struct mptcp_sock *msk = mptcp_sk(sk);
542
543         INIT_LIST_HEAD(&msk->conn_list);
544         __set_bit(MPTCP_SEND_SPACE, &msk->flags);
545
546         msk->first = NULL;
547
548         return 0;
549 }
550
551 static int mptcp_init_sock(struct sock *sk)
552 {
553         if (!mptcp_is_enabled(sock_net(sk)))
554                 return -ENOPROTOOPT;
555
556         return __mptcp_init_sock(sk);
557 }
558
559 static void mptcp_subflow_shutdown(struct sock *ssk, int how)
560 {
561         lock_sock(ssk);
562
563         switch (ssk->sk_state) {
564         case TCP_LISTEN:
565                 if (!(how & RCV_SHUTDOWN))
566                         break;
567                 /* fall through */
568         case TCP_SYN_SENT:
569                 tcp_disconnect(ssk, O_NONBLOCK);
570                 break;
571         default:
572                 ssk->sk_shutdown |= how;
573                 tcp_shutdown(ssk, how);
574                 break;
575         }
576
577         /* Wake up anyone sleeping in poll. */
578         ssk->sk_state_change(ssk);
579         release_sock(ssk);
580 }
581
582 /* Called with msk lock held, releases such lock before returning */
583 static void mptcp_close(struct sock *sk, long timeout)
584 {
585         struct mptcp_subflow_context *subflow, *tmp;
586         struct mptcp_sock *msk = mptcp_sk(sk);
587         LIST_HEAD(conn_list);
588
589         lock_sock(sk);
590
591         mptcp_token_destroy(msk->token);
592         inet_sk_state_store(sk, TCP_CLOSE);
593
594         list_splice_init(&msk->conn_list, &conn_list);
595
596         release_sock(sk);
597
598         list_for_each_entry_safe(subflow, tmp, &conn_list, node) {
599                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
600
601                 __mptcp_close_ssk(sk, ssk, subflow, timeout);
602         }
603
604         sk_common_release(sk);
605 }
606
607 static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
608 {
609 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
610         const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
611         struct ipv6_pinfo *msk6 = inet6_sk(msk);
612
613         msk->sk_v6_daddr = ssk->sk_v6_daddr;
614         msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;
615
616         if (msk6 && ssk6) {
617                 msk6->saddr = ssk6->saddr;
618                 msk6->flow_label = ssk6->flow_label;
619         }
620 #endif
621
622         inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
623         inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
624         inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
625         inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
626         inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
627         inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
628 }
629
630 static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
631                                  bool kern)
632 {
633         struct mptcp_sock *msk = mptcp_sk(sk);
634         struct socket *listener;
635         struct sock *newsk;
636
637         listener = __mptcp_nmpc_socket(msk);
638         if (WARN_ON_ONCE(!listener)) {
639                 *err = -EINVAL;
640                 return NULL;
641         }
642
643         pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
644         newsk = inet_csk_accept(listener->sk, flags, err, kern);
645         if (!newsk)
646                 return NULL;
647
648         pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
649
650         if (sk_is_mptcp(newsk)) {
651                 struct mptcp_subflow_context *subflow;
652                 struct sock *new_mptcp_sock;
653                 struct sock *ssk = newsk;
654                 u64 ack_seq;
655
656                 subflow = mptcp_subflow_ctx(newsk);
657                 lock_sock(sk);
658
659                 local_bh_disable();
660                 new_mptcp_sock = sk_clone_lock(sk, GFP_ATOMIC);
661                 if (!new_mptcp_sock) {
662                         *err = -ENOBUFS;
663                         local_bh_enable();
664                         release_sock(sk);
665                         mptcp_subflow_shutdown(newsk, SHUT_RDWR + 1);
666                         tcp_close(newsk, 0);
667                         return NULL;
668                 }
669
670                 __mptcp_init_sock(new_mptcp_sock);
671
672                 msk = mptcp_sk(new_mptcp_sock);
673                 msk->local_key = subflow->local_key;
674                 msk->token = subflow->token;
675                 msk->subflow = NULL;
676                 msk->first = newsk;
677
678                 mptcp_token_update_accept(newsk, new_mptcp_sock);
679
680                 msk->write_seq = subflow->idsn + 1;
681                 if (subflow->can_ack) {
682                         msk->can_ack = true;
683                         msk->remote_key = subflow->remote_key;
684                         mptcp_crypto_key_sha(msk->remote_key, NULL, &ack_seq);
685                         ack_seq++;
686                         msk->ack_seq = ack_seq;
687                 }
688                 newsk = new_mptcp_sock;
689                 mptcp_copy_inaddrs(newsk, ssk);
690                 list_add(&subflow->node, &msk->conn_list);
691
692                 /* will be fully established at mptcp_stream_accept()
693                  * completion.
694                  */
695                 inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
696                 bh_unlock_sock(new_mptcp_sock);
697                 local_bh_enable();
698                 release_sock(sk);
699
700                 /* the subflow can already receive packet, avoid racing with
701                  * the receive path and process the pending ones
702                  */
703                 lock_sock(ssk);
704                 subflow->rel_write_seq = 1;
705                 subflow->tcp_sock = ssk;
706                 subflow->conn = new_mptcp_sock;
707                 if (unlikely(!skb_queue_empty(&ssk->sk_receive_queue)))
708                         mptcp_subflow_data_available(ssk);
709                 release_sock(ssk);
710         }
711
712         return newsk;
713 }
714
715 static void mptcp_destroy(struct sock *sk)
716 {
717         struct mptcp_sock *msk = mptcp_sk(sk);
718
719         if (msk->cached_ext)
720                 __skb_ext_put(msk->cached_ext);
721 }
722
723 static int mptcp_setsockopt(struct sock *sk, int level, int optname,
724                             char __user *optval, unsigned int optlen)
725 {
726         struct mptcp_sock *msk = mptcp_sk(sk);
727         int ret = -EOPNOTSUPP;
728         struct socket *ssock;
729         struct sock *ssk;
730
731         pr_debug("msk=%p", msk);
732
733         /* @@ the meaning of setsockopt() when the socket is connected and
734          * there are multiple subflows is not defined.
735          */
736         lock_sock(sk);
737         ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
738         if (IS_ERR(ssock)) {
739                 release_sock(sk);
740                 return ret;
741         }
742
743         ssk = ssock->sk;
744         sock_hold(ssk);
745         release_sock(sk);
746
747         ret = tcp_setsockopt(ssk, level, optname, optval, optlen);
748         sock_put(ssk);
749
750         return ret;
751 }
752
753 static int mptcp_getsockopt(struct sock *sk, int level, int optname,
754                             char __user *optval, int __user *option)
755 {
756         struct mptcp_sock *msk = mptcp_sk(sk);
757         int ret = -EOPNOTSUPP;
758         struct socket *ssock;
759         struct sock *ssk;
760
761         pr_debug("msk=%p", msk);
762
763         /* @@ the meaning of getsockopt() when the socket is connected and
764          * there are multiple subflows is not defined.
765          */
766         lock_sock(sk);
767         ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
768         if (IS_ERR(ssock)) {
769                 release_sock(sk);
770                 return ret;
771         }
772
773         ssk = ssock->sk;
774         sock_hold(ssk);
775         release_sock(sk);
776
777         ret = tcp_getsockopt(ssk, level, optname, optval, option);
778         sock_put(ssk);
779
780         return ret;
781 }
782
783 static int mptcp_get_port(struct sock *sk, unsigned short snum)
784 {
785         struct mptcp_sock *msk = mptcp_sk(sk);
786         struct socket *ssock;
787
788         ssock = __mptcp_nmpc_socket(msk);
789         pr_debug("msk=%p, subflow=%p", msk, ssock);
790         if (WARN_ON_ONCE(!ssock))
791                 return -EINVAL;
792
793         return inet_csk_get_port(ssock->sk, snum);
794 }
795
796 void mptcp_finish_connect(struct sock *ssk)
797 {
798         struct mptcp_subflow_context *subflow;
799         struct mptcp_sock *msk;
800         struct sock *sk;
801         u64 ack_seq;
802
803         subflow = mptcp_subflow_ctx(ssk);
804
805         if (!subflow->mp_capable)
806                 return;
807
808         sk = subflow->conn;
809         msk = mptcp_sk(sk);
810
811         pr_debug("msk=%p, token=%u", sk, subflow->token);
812
813         mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
814         ack_seq++;
815         subflow->map_seq = ack_seq;
816         subflow->map_subflow_seq = 1;
817         subflow->rel_write_seq = 1;
818
819         /* the socket is not connected yet, no msk/subflow ops can access/race
820          * accessing the field below
821          */
822         WRITE_ONCE(msk->remote_key, subflow->remote_key);
823         WRITE_ONCE(msk->local_key, subflow->local_key);
824         WRITE_ONCE(msk->token, subflow->token);
825         WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
826         WRITE_ONCE(msk->ack_seq, ack_seq);
827         WRITE_ONCE(msk->can_ack, 1);
828 }
829
830 static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
831 {
832         write_lock_bh(&sk->sk_callback_lock);
833         rcu_assign_pointer(sk->sk_wq, &parent->wq);
834         sk_set_socket(sk, parent);
835         sk->sk_uid = SOCK_INODE(parent)->i_uid;
836         write_unlock_bh(&sk->sk_callback_lock);
837 }
838
839 static bool mptcp_memory_free(const struct sock *sk, int wake)
840 {
841         struct mptcp_sock *msk = mptcp_sk(sk);
842
843         return wake ? test_bit(MPTCP_SEND_SPACE, &msk->flags) : true;
844 }
845
846 static struct proto mptcp_prot = {
847         .name           = "MPTCP",
848         .owner          = THIS_MODULE,
849         .init           = mptcp_init_sock,
850         .close          = mptcp_close,
851         .accept         = mptcp_accept,
852         .setsockopt     = mptcp_setsockopt,
853         .getsockopt     = mptcp_getsockopt,
854         .shutdown       = tcp_shutdown,
855         .destroy        = mptcp_destroy,
856         .sendmsg        = mptcp_sendmsg,
857         .recvmsg        = mptcp_recvmsg,
858         .hash           = inet_hash,
859         .unhash         = inet_unhash,
860         .get_port       = mptcp_get_port,
861         .stream_memory_free     = mptcp_memory_free,
862         .obj_size       = sizeof(struct mptcp_sock),
863         .no_autobind    = true,
864 };
865
866 static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
867 {
868         struct mptcp_sock *msk = mptcp_sk(sock->sk);
869         struct socket *ssock;
870         int err;
871
872         lock_sock(sock->sk);
873         ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
874         if (IS_ERR(ssock)) {
875                 err = PTR_ERR(ssock);
876                 goto unlock;
877         }
878
879         err = ssock->ops->bind(ssock, uaddr, addr_len);
880         if (!err)
881                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
882
883 unlock:
884         release_sock(sock->sk);
885         return err;
886 }
887
888 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
889                                 int addr_len, int flags)
890 {
891         struct mptcp_sock *msk = mptcp_sk(sock->sk);
892         struct socket *ssock;
893         int err;
894
895         lock_sock(sock->sk);
896         ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
897         if (IS_ERR(ssock)) {
898                 err = PTR_ERR(ssock);
899                 goto unlock;
900         }
901
902 #ifdef CONFIG_TCP_MD5SIG
903         /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
904          * TCP option space.
905          */
906         if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
907                 mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
908 #endif
909
910         err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
911         inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
912         mptcp_copy_inaddrs(sock->sk, ssock->sk);
913
914 unlock:
915         release_sock(sock->sk);
916         return err;
917 }
918
919 static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
920                             int peer)
921 {
922         if (sock->sk->sk_prot == &tcp_prot) {
923                 /* we are being invoked from __sys_accept4, after
924                  * mptcp_accept() has just accepted a non-mp-capable
925                  * flow: sk is a tcp_sk, not an mptcp one.
926                  *
927                  * Hand the socket over to tcp so all further socket ops
928                  * bypass mptcp.
929                  */
930                 sock->ops = &inet_stream_ops;
931         }
932
933         return inet_getname(sock, uaddr, peer);
934 }
935
936 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
937 static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
938                             int peer)
939 {
940         if (sock->sk->sk_prot == &tcpv6_prot) {
941                 /* we are being invoked from __sys_accept4 after
942                  * mptcp_accept() has accepted a non-mp-capable
943                  * subflow: sk is a tcp_sk, not mptcp.
944                  *
945                  * Hand the socket over to tcp so all further
946                  * socket ops bypass mptcp.
947                  */
948                 sock->ops = &inet6_stream_ops;
949         }
950
951         return inet6_getname(sock, uaddr, peer);
952 }
953 #endif
954
955 static int mptcp_listen(struct socket *sock, int backlog)
956 {
957         struct mptcp_sock *msk = mptcp_sk(sock->sk);
958         struct socket *ssock;
959         int err;
960
961         pr_debug("msk=%p", msk);
962
963         lock_sock(sock->sk);
964         ssock = __mptcp_socket_create(msk, TCP_LISTEN);
965         if (IS_ERR(ssock)) {
966                 err = PTR_ERR(ssock);
967                 goto unlock;
968         }
969
970         err = ssock->ops->listen(ssock, backlog);
971         inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
972         if (!err)
973                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
974
975 unlock:
976         release_sock(sock->sk);
977         return err;
978 }
979
980 static bool is_tcp_proto(const struct proto *p)
981 {
982 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
983         return p == &tcp_prot || p == &tcpv6_prot;
984 #else
985         return p == &tcp_prot;
986 #endif
987 }
988
989 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
990                                int flags, bool kern)
991 {
992         struct mptcp_sock *msk = mptcp_sk(sock->sk);
993         struct socket *ssock;
994         int err;
995
996         pr_debug("msk=%p", msk);
997
998         lock_sock(sock->sk);
999         if (sock->sk->sk_state != TCP_LISTEN)
1000                 goto unlock_fail;
1001
1002         ssock = __mptcp_nmpc_socket(msk);
1003         if (!ssock)
1004                 goto unlock_fail;
1005
1006         sock_hold(ssock->sk);
1007         release_sock(sock->sk);
1008
1009         err = ssock->ops->accept(sock, newsock, flags, kern);
1010         if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
1011                 struct mptcp_sock *msk = mptcp_sk(newsock->sk);
1012                 struct mptcp_subflow_context *subflow;
1013
1014                 /* set ssk->sk_socket of accept()ed flows to mptcp socket.
1015                  * This is needed so NOSPACE flag can be set from tcp stack.
1016                  */
1017                 list_for_each_entry(subflow, &msk->conn_list, node) {
1018                         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1019
1020                         if (!ssk->sk_socket)
1021                                 mptcp_sock_graft(ssk, newsock);
1022                 }
1023
1024                 inet_sk_state_store(newsock->sk, TCP_ESTABLISHED);
1025         }
1026
1027         sock_put(ssock->sk);
1028         return err;
1029
1030 unlock_fail:
1031         release_sock(sock->sk);
1032         return -EINVAL;
1033 }
1034
1035 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
1036                            struct poll_table_struct *wait)
1037 {
1038         struct sock *sk = sock->sk;
1039         struct mptcp_sock *msk;
1040         struct socket *ssock;
1041         __poll_t mask = 0;
1042
1043         msk = mptcp_sk(sk);
1044         lock_sock(sk);
1045         ssock = __mptcp_nmpc_socket(msk);
1046         if (ssock) {
1047                 mask = ssock->ops->poll(file, ssock, wait);
1048                 release_sock(sk);
1049                 return mask;
1050         }
1051
1052         release_sock(sk);
1053         sock_poll_wait(file, sock, wait);
1054         lock_sock(sk);
1055         ssock = __mptcp_tcp_fallback(msk);
1056         if (unlikely(ssock))
1057                 return ssock->ops->poll(file, ssock, NULL);
1058
1059         if (test_bit(MPTCP_DATA_READY, &msk->flags))
1060                 mask = EPOLLIN | EPOLLRDNORM;
1061         if (sk_stream_is_writeable(sk) &&
1062             test_bit(MPTCP_SEND_SPACE, &msk->flags))
1063                 mask |= EPOLLOUT | EPOLLWRNORM;
1064         if (sk->sk_shutdown & RCV_SHUTDOWN)
1065                 mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
1066
1067         release_sock(sk);
1068
1069         return mask;
1070 }
1071
1072 static int mptcp_shutdown(struct socket *sock, int how)
1073 {
1074         struct mptcp_sock *msk = mptcp_sk(sock->sk);
1075         struct mptcp_subflow_context *subflow;
1076         int ret = 0;
1077
1078         pr_debug("sk=%p, how=%d", msk, how);
1079
1080         lock_sock(sock->sk);
1081
1082         if (how == SHUT_WR || how == SHUT_RDWR)
1083                 inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
1084
1085         how++;
1086
1087         if ((how & ~SHUTDOWN_MASK) || !how) {
1088                 ret = -EINVAL;
1089                 goto out_unlock;
1090         }
1091
1092         if (sock->state == SS_CONNECTING) {
1093                 if ((1 << sock->sk->sk_state) &
1094                     (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))
1095                         sock->state = SS_DISCONNECTING;
1096                 else
1097                         sock->state = SS_CONNECTED;
1098         }
1099
1100         mptcp_for_each_subflow(msk, subflow) {
1101                 struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
1102
1103                 mptcp_subflow_shutdown(tcp_sk, how);
1104         }
1105
1106 out_unlock:
1107         release_sock(sock->sk);
1108
1109         return ret;
1110 }
1111
1112 static const struct proto_ops mptcp_stream_ops = {
1113         .family            = PF_INET,
1114         .owner             = THIS_MODULE,
1115         .release           = inet_release,
1116         .bind              = mptcp_bind,
1117         .connect           = mptcp_stream_connect,
1118         .socketpair        = sock_no_socketpair,
1119         .accept            = mptcp_stream_accept,
1120         .getname           = mptcp_v4_getname,
1121         .poll              = mptcp_poll,
1122         .ioctl             = inet_ioctl,
1123         .gettstamp         = sock_gettstamp,
1124         .listen            = mptcp_listen,
1125         .shutdown          = mptcp_shutdown,
1126         .setsockopt        = sock_common_setsockopt,
1127         .getsockopt        = sock_common_getsockopt,
1128         .sendmsg           = inet_sendmsg,
1129         .recvmsg           = inet_recvmsg,
1130         .mmap              = sock_no_mmap,
1131         .sendpage          = inet_sendpage,
1132 #ifdef CONFIG_COMPAT
1133         .compat_setsockopt = compat_sock_common_setsockopt,
1134         .compat_getsockopt = compat_sock_common_getsockopt,
1135 #endif
1136 };
1137
1138 static struct inet_protosw mptcp_protosw = {
1139         .type           = SOCK_STREAM,
1140         .protocol       = IPPROTO_MPTCP,
1141         .prot           = &mptcp_prot,
1142         .ops            = &mptcp_stream_ops,
1143         .flags          = INET_PROTOSW_ICSK,
1144 };
1145
1146 void mptcp_proto_init(void)
1147 {
1148         mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
1149
1150         mptcp_subflow_init();
1151
1152         if (proto_register(&mptcp_prot, 1) != 0)
1153                 panic("Failed to register MPTCP proto.\n");
1154
1155         inet_register_protosw(&mptcp_protosw);
1156 }
1157
1158 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1159 static const struct proto_ops mptcp_v6_stream_ops = {
1160         .family            = PF_INET6,
1161         .owner             = THIS_MODULE,
1162         .release           = inet6_release,
1163         .bind              = mptcp_bind,
1164         .connect           = mptcp_stream_connect,
1165         .socketpair        = sock_no_socketpair,
1166         .accept            = mptcp_stream_accept,
1167         .getname           = mptcp_v6_getname,
1168         .poll              = mptcp_poll,
1169         .ioctl             = inet6_ioctl,
1170         .gettstamp         = sock_gettstamp,
1171         .listen            = mptcp_listen,
1172         .shutdown          = mptcp_shutdown,
1173         .setsockopt        = sock_common_setsockopt,
1174         .getsockopt        = sock_common_getsockopt,
1175         .sendmsg           = inet6_sendmsg,
1176         .recvmsg           = inet6_recvmsg,
1177         .mmap              = sock_no_mmap,
1178         .sendpage          = inet_sendpage,
1179 #ifdef CONFIG_COMPAT
1180         .compat_setsockopt = compat_sock_common_setsockopt,
1181         .compat_getsockopt = compat_sock_common_getsockopt,
1182 #endif
1183 };
1184
1185 static struct proto mptcp_v6_prot;
1186
1187 static void mptcp_v6_destroy(struct sock *sk)
1188 {
1189         mptcp_destroy(sk);
1190         inet6_destroy_sock(sk);
1191 }
1192
1193 static struct inet_protosw mptcp_v6_protosw = {
1194         .type           = SOCK_STREAM,
1195         .protocol       = IPPROTO_MPTCP,
1196         .prot           = &mptcp_v6_prot,
1197         .ops            = &mptcp_v6_stream_ops,
1198         .flags          = INET_PROTOSW_ICSK,
1199 };
1200
1201 int mptcp_proto_v6_init(void)
1202 {
1203         int err;
1204
1205         mptcp_v6_prot = mptcp_prot;
1206         strcpy(mptcp_v6_prot.name, "MPTCPv6");
1207         mptcp_v6_prot.slab = NULL;
1208         mptcp_v6_prot.destroy = mptcp_v6_destroy;
1209         mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
1210                                  sizeof(struct ipv6_pinfo);
1211
1212         err = proto_register(&mptcp_v6_prot, 1);
1213         if (err)
1214                 return err;
1215
1216         err = inet6_register_protosw(&mptcp_v6_protosw);
1217         if (err)
1218                 proto_unregister(&mptcp_v6_prot);
1219
1220         return err;
1221 }
1222 #endif