b3c3dbc89b3f9068b81722278bf9f2ce3ec06859
[linux-2.6-microblaze.git] / net / mptcp / protocol.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2017 - 2019, Intel Corporation.
5  */
6
7 #define pr_fmt(fmt) "MPTCP: " fmt
8
9 #include <linux/kernel.h>
10 #include <linux/module.h>
11 #include <linux/netdevice.h>
12 #include <linux/sched/signal.h>
13 #include <linux/atomic.h>
14 #include <net/sock.h>
15 #include <net/inet_common.h>
16 #include <net/inet_hashtables.h>
17 #include <net/protocol.h>
18 #include <net/tcp.h>
19 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
20 #include <net/transp_v6.h>
21 #endif
22 #include <net/mptcp.h>
23 #include "protocol.h"
24 #include "mib.h"
25
26 #define MPTCP_SAME_STATE TCP_MAX_STATES
27
28 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
29 struct mptcp6_sock {
30         struct mptcp_sock msk;
31         struct ipv6_pinfo np;
32 };
33 #endif
34
35 struct mptcp_skb_cb {
36         u32 offset;
37 };
38
39 #define MPTCP_SKB_CB(__skb)     ((struct mptcp_skb_cb *)&((__skb)->cb[0]))
40
41 static struct percpu_counter mptcp_sockets_allocated;
42
43 /* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
44  * completed yet or has failed, return the subflow socket.
45  * Otherwise return NULL.
46  */
47 static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
48 {
49         if (!msk->subflow || READ_ONCE(msk->can_ack))
50                 return NULL;
51
52         return msk->subflow;
53 }
54
55 static bool mptcp_is_tcpsk(struct sock *sk)
56 {
57         struct socket *sock = sk->sk_socket;
58
59         if (unlikely(sk->sk_prot == &tcp_prot)) {
60                 /* we are being invoked after mptcp_accept() has
61                  * accepted a non-mp-capable flow: sk is a tcp_sk,
62                  * not an mptcp one.
63                  *
64                  * Hand the socket over to tcp so all further socket ops
65                  * bypass mptcp.
66                  */
67                 sock->ops = &inet_stream_ops;
68                 return true;
69 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
70         } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
71                 sock->ops = &inet6_stream_ops;
72                 return true;
73 #endif
74         }
75
76         return false;
77 }
78
79 static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk)
80 {
81         sock_owned_by_me((const struct sock *)msk);
82
83         if (likely(!__mptcp_check_fallback(msk)))
84                 return NULL;
85
86         return msk->first;
87 }
88
89 static int __mptcp_socket_create(struct mptcp_sock *msk)
90 {
91         struct mptcp_subflow_context *subflow;
92         struct sock *sk = (struct sock *)msk;
93         struct socket *ssock;
94         int err;
95
96         err = mptcp_subflow_create_socket(sk, &ssock);
97         if (err)
98                 return err;
99
100         msk->first = ssock->sk;
101         msk->subflow = ssock;
102         subflow = mptcp_subflow_ctx(ssock->sk);
103         list_add(&subflow->node, &msk->conn_list);
104         subflow->request_mptcp = 1;
105
106         /* accept() will wait on first subflow sk_wq, and we always wakes up
107          * via msk->sk_socket
108          */
109         RCU_INIT_POINTER(msk->first->sk_wq, &sk->sk_socket->wq);
110
111         return 0;
112 }
113
114 static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk,
115                              struct sk_buff *skb,
116                              unsigned int offset, size_t copy_len)
117 {
118         struct sock *sk = (struct sock *)msk;
119         struct sk_buff *tail;
120
121         __skb_unlink(skb, &ssk->sk_receive_queue);
122
123         skb_ext_reset(skb);
124         skb_orphan(skb);
125         msk->ack_seq += copy_len;
126
127         tail = skb_peek_tail(&sk->sk_receive_queue);
128         if (offset == 0 && tail) {
129                 bool fragstolen;
130                 int delta;
131
132                 if (skb_try_coalesce(tail, skb, &fragstolen, &delta)) {
133                         kfree_skb_partial(skb, fragstolen);
134                         atomic_add(delta, &sk->sk_rmem_alloc);
135                         sk_mem_charge(sk, delta);
136                         return;
137                 }
138         }
139
140         skb_set_owner_r(skb, sk);
141         __skb_queue_tail(&sk->sk_receive_queue, skb);
142         MPTCP_SKB_CB(skb)->offset = offset;
143 }
144
145 /* both sockets must be locked */
146 static bool mptcp_subflow_dsn_valid(const struct mptcp_sock *msk,
147                                     struct sock *ssk)
148 {
149         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
150         u64 dsn = mptcp_subflow_get_mapped_dsn(subflow);
151
152         /* revalidate data sequence number.
153          *
154          * mptcp_subflow_data_available() is usually called
155          * without msk lock.  Its unlikely (but possible)
156          * that msk->ack_seq has been advanced since the last
157          * call found in-sequence data.
158          */
159         if (likely(dsn == msk->ack_seq))
160                 return true;
161
162         subflow->data_avail = 0;
163         return mptcp_subflow_data_available(ssk);
164 }
165
166 static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
167                                            struct sock *ssk,
168                                            unsigned int *bytes)
169 {
170         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
171         struct sock *sk = (struct sock *)msk;
172         unsigned int moved = 0;
173         bool more_data_avail;
174         struct tcp_sock *tp;
175         bool done = false;
176
177         if (!mptcp_subflow_dsn_valid(msk, ssk)) {
178                 *bytes = 0;
179                 return false;
180         }
181
182         tp = tcp_sk(ssk);
183         do {
184                 u32 map_remaining, offset;
185                 u32 seq = tp->copied_seq;
186                 struct sk_buff *skb;
187                 bool fin;
188
189                 /* try to move as much data as available */
190                 map_remaining = subflow->map_data_len -
191                                 mptcp_subflow_get_map_offset(subflow);
192
193                 skb = skb_peek(&ssk->sk_receive_queue);
194                 if (!skb)
195                         break;
196
197                 if (__mptcp_check_fallback(msk)) {
198                         /* if we are running under the workqueue, TCP could have
199                          * collapsed skbs between dummy map creation and now
200                          * be sure to adjust the size
201                          */
202                         map_remaining = skb->len;
203                         subflow->map_data_len = skb->len;
204                 }
205
206                 offset = seq - TCP_SKB_CB(skb)->seq;
207                 fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN;
208                 if (fin) {
209                         done = true;
210                         seq++;
211                 }
212
213                 if (offset < skb->len) {
214                         size_t len = skb->len - offset;
215
216                         if (tp->urg_data)
217                                 done = true;
218
219                         __mptcp_move_skb(msk, ssk, skb, offset, len);
220                         seq += len;
221                         moved += len;
222
223                         if (WARN_ON_ONCE(map_remaining < len))
224                                 break;
225                 } else {
226                         WARN_ON_ONCE(!fin);
227                         sk_eat_skb(ssk, skb);
228                         done = true;
229                 }
230
231                 WRITE_ONCE(tp->copied_seq, seq);
232                 more_data_avail = mptcp_subflow_data_available(ssk);
233
234                 if (atomic_read(&sk->sk_rmem_alloc) > READ_ONCE(sk->sk_rcvbuf)) {
235                         done = true;
236                         break;
237                 }
238         } while (more_data_avail);
239
240         *bytes = moved;
241
242         return done;
243 }
244
245 /* In most cases we will be able to lock the mptcp socket.  If its already
246  * owned, we need to defer to the work queue to avoid ABBA deadlock.
247  */
248 static bool move_skbs_to_msk(struct mptcp_sock *msk, struct sock *ssk)
249 {
250         struct sock *sk = (struct sock *)msk;
251         unsigned int moved = 0;
252
253         if (READ_ONCE(sk->sk_lock.owned))
254                 return false;
255
256         if (unlikely(!spin_trylock_bh(&sk->sk_lock.slock)))
257                 return false;
258
259         /* must re-check after taking the lock */
260         if (!READ_ONCE(sk->sk_lock.owned))
261                 __mptcp_move_skbs_from_subflow(msk, ssk, &moved);
262
263         spin_unlock_bh(&sk->sk_lock.slock);
264
265         return moved > 0;
266 }
267
268 void mptcp_data_ready(struct sock *sk, struct sock *ssk)
269 {
270         struct mptcp_sock *msk = mptcp_sk(sk);
271
272         set_bit(MPTCP_DATA_READY, &msk->flags);
273
274         if (atomic_read(&sk->sk_rmem_alloc) < READ_ONCE(sk->sk_rcvbuf) &&
275             move_skbs_to_msk(msk, ssk))
276                 goto wake;
277
278         /* don't schedule if mptcp sk is (still) over limit */
279         if (atomic_read(&sk->sk_rmem_alloc) > READ_ONCE(sk->sk_rcvbuf))
280                 goto wake;
281
282         /* mptcp socket is owned, release_cb should retry */
283         if (!test_and_set_bit(TCP_DELACK_TIMER_DEFERRED,
284                               &sk->sk_tsq_flags)) {
285                 sock_hold(sk);
286
287                 /* need to try again, its possible release_cb() has already
288                  * been called after the test_and_set_bit() above.
289                  */
290                 move_skbs_to_msk(msk, ssk);
291         }
292 wake:
293         sk->sk_data_ready(sk);
294 }
295
296 static void __mptcp_flush_join_list(struct mptcp_sock *msk)
297 {
298         if (likely(list_empty(&msk->join_list)))
299                 return;
300
301         spin_lock_bh(&msk->join_list_lock);
302         list_splice_tail_init(&msk->join_list, &msk->conn_list);
303         spin_unlock_bh(&msk->join_list_lock);
304 }
305
306 static void mptcp_set_timeout(const struct sock *sk, const struct sock *ssk)
307 {
308         long tout = ssk && inet_csk(ssk)->icsk_pending ?
309                                       inet_csk(ssk)->icsk_timeout - jiffies : 0;
310
311         if (tout <= 0)
312                 tout = mptcp_sk(sk)->timer_ival;
313         mptcp_sk(sk)->timer_ival = tout > 0 ? tout : TCP_RTO_MIN;
314 }
315
316 static bool mptcp_timer_pending(struct sock *sk)
317 {
318         return timer_pending(&inet_csk(sk)->icsk_retransmit_timer);
319 }
320
321 static void mptcp_reset_timer(struct sock *sk)
322 {
323         struct inet_connection_sock *icsk = inet_csk(sk);
324         unsigned long tout;
325
326         /* should never be called with mptcp level timer cleared */
327         tout = READ_ONCE(mptcp_sk(sk)->timer_ival);
328         if (WARN_ON_ONCE(!tout))
329                 tout = TCP_RTO_MIN;
330         sk_reset_timer(sk, &icsk->icsk_retransmit_timer, jiffies + tout);
331 }
332
333 void mptcp_data_acked(struct sock *sk)
334 {
335         mptcp_reset_timer(sk);
336
337         if (!sk_stream_is_writeable(sk) &&
338             schedule_work(&mptcp_sk(sk)->work))
339                 sock_hold(sk);
340 }
341
342 void mptcp_subflow_eof(struct sock *sk)
343 {
344         struct mptcp_sock *msk = mptcp_sk(sk);
345
346         if (!test_and_set_bit(MPTCP_WORK_EOF, &msk->flags) &&
347             schedule_work(&msk->work))
348                 sock_hold(sk);
349 }
350
351 static void mptcp_check_for_eof(struct mptcp_sock *msk)
352 {
353         struct mptcp_subflow_context *subflow;
354         struct sock *sk = (struct sock *)msk;
355         int receivers = 0;
356
357         mptcp_for_each_subflow(msk, subflow)
358                 receivers += !subflow->rx_eof;
359
360         if (!receivers && !(sk->sk_shutdown & RCV_SHUTDOWN)) {
361                 /* hopefully temporary hack: propagate shutdown status
362                  * to msk, when all subflows agree on it
363                  */
364                 sk->sk_shutdown |= RCV_SHUTDOWN;
365
366                 smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
367                 set_bit(MPTCP_DATA_READY, &msk->flags);
368                 sk->sk_data_ready(sk);
369         }
370 }
371
372 static void mptcp_stop_timer(struct sock *sk)
373 {
374         struct inet_connection_sock *icsk = inet_csk(sk);
375
376         sk_stop_timer(sk, &icsk->icsk_retransmit_timer);
377         mptcp_sk(sk)->timer_ival = 0;
378 }
379
380 static bool mptcp_ext_cache_refill(struct mptcp_sock *msk)
381 {
382         const struct sock *sk = (const struct sock *)msk;
383
384         if (!msk->cached_ext)
385                 msk->cached_ext = __skb_ext_alloc(sk->sk_allocation);
386
387         return !!msk->cached_ext;
388 }
389
390 static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk)
391 {
392         struct mptcp_subflow_context *subflow;
393         struct sock *sk = (struct sock *)msk;
394
395         sock_owned_by_me(sk);
396
397         mptcp_for_each_subflow(msk, subflow) {
398                 if (subflow->data_avail)
399                         return mptcp_subflow_tcp_sock(subflow);
400         }
401
402         return NULL;
403 }
404
405 static bool mptcp_skb_can_collapse_to(u64 write_seq,
406                                       const struct sk_buff *skb,
407                                       const struct mptcp_ext *mpext)
408 {
409         if (!tcp_skb_can_collapse_to(skb))
410                 return false;
411
412         /* can collapse only if MPTCP level sequence is in order */
413         return mpext && mpext->data_seq + mpext->data_len == write_seq;
414 }
415
416 static bool mptcp_frag_can_collapse_to(const struct mptcp_sock *msk,
417                                        const struct page_frag *pfrag,
418                                        const struct mptcp_data_frag *df)
419 {
420         return df && pfrag->page == df->page &&
421                 df->data_seq + df->data_len == msk->write_seq;
422 }
423
424 static void dfrag_uncharge(struct sock *sk, int len)
425 {
426         sk_mem_uncharge(sk, len);
427         sk_wmem_queued_add(sk, -len);
428 }
429
430 static void dfrag_clear(struct sock *sk, struct mptcp_data_frag *dfrag)
431 {
432         int len = dfrag->data_len + dfrag->overhead;
433
434         list_del(&dfrag->list);
435         dfrag_uncharge(sk, len);
436         put_page(dfrag->page);
437 }
438
439 static void mptcp_clean_una(struct sock *sk)
440 {
441         struct mptcp_sock *msk = mptcp_sk(sk);
442         struct mptcp_data_frag *dtmp, *dfrag;
443         bool cleaned = false;
444         u64 snd_una;
445
446         /* on fallback we just need to ignore snd_una, as this is really
447          * plain TCP
448          */
449         if (__mptcp_check_fallback(msk))
450                 atomic64_set(&msk->snd_una, msk->write_seq);
451         snd_una = atomic64_read(&msk->snd_una);
452
453         list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) {
454                 if (after64(dfrag->data_seq + dfrag->data_len, snd_una))
455                         break;
456
457                 dfrag_clear(sk, dfrag);
458                 cleaned = true;
459         }
460
461         dfrag = mptcp_rtx_head(sk);
462         if (dfrag && after64(snd_una, dfrag->data_seq)) {
463                 u64 delta = snd_una - dfrag->data_seq;
464
465                 if (WARN_ON_ONCE(delta > dfrag->data_len))
466                         goto out;
467
468                 dfrag->data_seq += delta;
469                 dfrag->offset += delta;
470                 dfrag->data_len -= delta;
471
472                 dfrag_uncharge(sk, delta);
473                 cleaned = true;
474         }
475
476 out:
477         if (cleaned) {
478                 sk_mem_reclaim_partial(sk);
479
480                 /* Only wake up writers if a subflow is ready */
481                 if (test_bit(MPTCP_SEND_SPACE, &msk->flags))
482                         sk_stream_write_space(sk);
483         }
484 }
485
486 /* ensure we get enough memory for the frag hdr, beyond some minimal amount of
487  * data
488  */
489 static bool mptcp_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
490 {
491         if (likely(skb_page_frag_refill(32U + sizeof(struct mptcp_data_frag),
492                                         pfrag, sk->sk_allocation)))
493                 return true;
494
495         sk->sk_prot->enter_memory_pressure(sk);
496         sk_stream_moderate_sndbuf(sk);
497         return false;
498 }
499
500 static struct mptcp_data_frag *
501 mptcp_carve_data_frag(const struct mptcp_sock *msk, struct page_frag *pfrag,
502                       int orig_offset)
503 {
504         int offset = ALIGN(orig_offset, sizeof(long));
505         struct mptcp_data_frag *dfrag;
506
507         dfrag = (struct mptcp_data_frag *)(page_to_virt(pfrag->page) + offset);
508         dfrag->data_len = 0;
509         dfrag->data_seq = msk->write_seq;
510         dfrag->overhead = offset - orig_offset + sizeof(struct mptcp_data_frag);
511         dfrag->offset = offset + sizeof(struct mptcp_data_frag);
512         dfrag->page = pfrag->page;
513
514         return dfrag;
515 }
516
517 static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
518                               struct msghdr *msg, struct mptcp_data_frag *dfrag,
519                               long *timeo, int *pmss_now,
520                               int *ps_goal)
521 {
522         int mss_now, avail_size, size_goal, offset, ret, frag_truesize = 0;
523         bool dfrag_collapsed, can_collapse = false;
524         struct mptcp_sock *msk = mptcp_sk(sk);
525         struct mptcp_ext *mpext = NULL;
526         bool retransmission = !!dfrag;
527         struct sk_buff *skb, *tail;
528         struct page_frag *pfrag;
529         struct page *page;
530         u64 *write_seq;
531         size_t psize;
532
533         /* use the mptcp page cache so that we can easily move the data
534          * from one substream to another, but do per subflow memory accounting
535          * Note: pfrag is used only !retransmission, but the compiler if
536          * fooled into a warning if we don't init here
537          */
538         pfrag = sk_page_frag(sk);
539         if (!retransmission) {
540                 write_seq = &msk->write_seq;
541                 page = pfrag->page;
542         } else {
543                 write_seq = &dfrag->data_seq;
544                 page = dfrag->page;
545         }
546
547         /* compute copy limit */
548         mss_now = tcp_send_mss(ssk, &size_goal, msg->msg_flags);
549         *pmss_now = mss_now;
550         *ps_goal = size_goal;
551         avail_size = size_goal;
552         skb = tcp_write_queue_tail(ssk);
553         if (skb) {
554                 mpext = skb_ext_find(skb, SKB_EXT_MPTCP);
555
556                 /* Limit the write to the size available in the
557                  * current skb, if any, so that we create at most a new skb.
558                  * Explicitly tells TCP internals to avoid collapsing on later
559                  * queue management operation, to avoid breaking the ext <->
560                  * SSN association set here
561                  */
562                 can_collapse = (size_goal - skb->len > 0) &&
563                               mptcp_skb_can_collapse_to(*write_seq, skb, mpext);
564                 if (!can_collapse)
565                         TCP_SKB_CB(skb)->eor = 1;
566                 else
567                         avail_size = size_goal - skb->len;
568         }
569
570         if (!retransmission) {
571                 /* reuse tail pfrag, if possible, or carve a new one from the
572                  * page allocator
573                  */
574                 dfrag = mptcp_rtx_tail(sk);
575                 offset = pfrag->offset;
576                 dfrag_collapsed = mptcp_frag_can_collapse_to(msk, pfrag, dfrag);
577                 if (!dfrag_collapsed) {
578                         dfrag = mptcp_carve_data_frag(msk, pfrag, offset);
579                         offset = dfrag->offset;
580                         frag_truesize = dfrag->overhead;
581                 }
582                 psize = min_t(size_t, pfrag->size - offset, avail_size);
583
584                 /* Copy to page */
585                 pr_debug("left=%zu", msg_data_left(msg));
586                 psize = copy_page_from_iter(pfrag->page, offset,
587                                             min_t(size_t, msg_data_left(msg),
588                                                   psize),
589                                             &msg->msg_iter);
590                 pr_debug("left=%zu", msg_data_left(msg));
591                 if (!psize)
592                         return -EINVAL;
593
594                 if (!sk_wmem_schedule(sk, psize + dfrag->overhead))
595                         return -ENOMEM;
596         } else {
597                 offset = dfrag->offset;
598                 psize = min_t(size_t, dfrag->data_len, avail_size);
599         }
600
601         /* tell the TCP stack to delay the push so that we can safely
602          * access the skb after the sendpages call
603          */
604         ret = do_tcp_sendpages(ssk, page, offset, psize,
605                                msg->msg_flags | MSG_SENDPAGE_NOTLAST | MSG_DONTWAIT);
606         if (ret <= 0)
607                 return ret;
608
609         frag_truesize += ret;
610         if (!retransmission) {
611                 if (unlikely(ret < psize))
612                         iov_iter_revert(&msg->msg_iter, psize - ret);
613
614                 /* send successful, keep track of sent data for mptcp-level
615                  * retransmission
616                  */
617                 dfrag->data_len += ret;
618                 if (!dfrag_collapsed) {
619                         get_page(dfrag->page);
620                         list_add_tail(&dfrag->list, &msk->rtx_queue);
621                         sk_wmem_queued_add(sk, frag_truesize);
622                 } else {
623                         sk_wmem_queued_add(sk, ret);
624                 }
625
626                 /* charge data on mptcp rtx queue to the master socket
627                  * Note: we charge such data both to sk and ssk
628                  */
629                 sk->sk_forward_alloc -= frag_truesize;
630         }
631
632         /* if the tail skb extension is still the cached one, collapsing
633          * really happened. Note: we can't check for 'same skb' as the sk_buff
634          * hdr on tail can be transmitted, freed and re-allocated by the
635          * do_tcp_sendpages() call
636          */
637         tail = tcp_write_queue_tail(ssk);
638         if (mpext && tail && mpext == skb_ext_find(tail, SKB_EXT_MPTCP)) {
639                 WARN_ON_ONCE(!can_collapse);
640                 mpext->data_len += ret;
641                 goto out;
642         }
643
644         skb = tcp_write_queue_tail(ssk);
645         mpext = __skb_ext_set(skb, SKB_EXT_MPTCP, msk->cached_ext);
646         msk->cached_ext = NULL;
647
648         memset(mpext, 0, sizeof(*mpext));
649         mpext->data_seq = *write_seq;
650         mpext->subflow_seq = mptcp_subflow_ctx(ssk)->rel_write_seq;
651         mpext->data_len = ret;
652         mpext->use_map = 1;
653         mpext->dsn64 = 1;
654
655         pr_debug("data_seq=%llu subflow_seq=%u data_len=%u dsn64=%d",
656                  mpext->data_seq, mpext->subflow_seq, mpext->data_len,
657                  mpext->dsn64);
658
659 out:
660         if (!retransmission)
661                 pfrag->offset += frag_truesize;
662         *write_seq += ret;
663         mptcp_subflow_ctx(ssk)->rel_write_seq += ret;
664
665         return ret;
666 }
667
668 static void mptcp_nospace(struct mptcp_sock *msk, struct socket *sock)
669 {
670         clear_bit(MPTCP_SEND_SPACE, &msk->flags);
671         smp_mb__after_atomic(); /* msk->flags is changed by write_space cb */
672
673         /* enables sk->write_space() callbacks */
674         set_bit(SOCK_NOSPACE, &sock->flags);
675 }
676
677 static struct sock *mptcp_subflow_get_send(struct mptcp_sock *msk)
678 {
679         struct mptcp_subflow_context *subflow;
680         struct sock *backup = NULL;
681
682         sock_owned_by_me((const struct sock *)msk);
683
684         if (!mptcp_ext_cache_refill(msk))
685                 return NULL;
686
687         mptcp_for_each_subflow(msk, subflow) {
688                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
689
690                 if (!sk_stream_memory_free(ssk)) {
691                         struct socket *sock = ssk->sk_socket;
692
693                         if (sock)
694                                 mptcp_nospace(msk, sock);
695
696                         return NULL;
697                 }
698
699                 if (subflow->backup) {
700                         if (!backup)
701                                 backup = ssk;
702
703                         continue;
704                 }
705
706                 return ssk;
707         }
708
709         return backup;
710 }
711
712 static void ssk_check_wmem(struct mptcp_sock *msk, struct sock *ssk)
713 {
714         struct socket *sock;
715
716         if (likely(sk_stream_is_writeable(ssk)))
717                 return;
718
719         sock = READ_ONCE(ssk->sk_socket);
720         if (sock)
721                 mptcp_nospace(msk, sock);
722 }
723
724 static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
725 {
726         int mss_now = 0, size_goal = 0, ret = 0;
727         struct mptcp_sock *msk = mptcp_sk(sk);
728         struct page_frag *pfrag;
729         size_t copied = 0;
730         struct sock *ssk;
731         bool tx_ok;
732         long timeo;
733
734         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
735                 return -EOPNOTSUPP;
736
737         lock_sock(sk);
738
739         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
740
741         if ((1 << sk->sk_state) & ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) {
742                 ret = sk_stream_wait_connect(sk, &timeo);
743                 if (ret)
744                         goto out;
745         }
746
747         pfrag = sk_page_frag(sk);
748 restart:
749         mptcp_clean_una(sk);
750
751         if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) {
752                 ret = -EPIPE;
753                 goto out;
754         }
755
756 wait_for_sndbuf:
757         __mptcp_flush_join_list(msk);
758         ssk = mptcp_subflow_get_send(msk);
759         while (!sk_stream_memory_free(sk) ||
760                !ssk ||
761                !mptcp_page_frag_refill(ssk, pfrag)) {
762                 if (ssk) {
763                         /* make sure retransmit timer is
764                          * running before we wait for memory.
765                          *
766                          * The retransmit timer might be needed
767                          * to make the peer send an up-to-date
768                          * MPTCP Ack.
769                          */
770                         mptcp_set_timeout(sk, ssk);
771                         if (!mptcp_timer_pending(sk))
772                                 mptcp_reset_timer(sk);
773                 }
774
775                 ret = sk_stream_wait_memory(sk, &timeo);
776                 if (ret)
777                         goto out;
778
779                 mptcp_clean_una(sk);
780
781                 ssk = mptcp_subflow_get_send(msk);
782                 if (list_empty(&msk->conn_list)) {
783                         ret = -ENOTCONN;
784                         goto out;
785                 }
786         }
787
788         pr_debug("conn_list->subflow=%p", ssk);
789
790         lock_sock(ssk);
791         tx_ok = msg_data_left(msg);
792         while (tx_ok) {
793                 ret = mptcp_sendmsg_frag(sk, ssk, msg, NULL, &timeo, &mss_now,
794                                          &size_goal);
795                 if (ret < 0) {
796                         if (ret == -EAGAIN && timeo > 0) {
797                                 mptcp_set_timeout(sk, ssk);
798                                 release_sock(ssk);
799                                 goto restart;
800                         }
801                         break;
802                 }
803
804                 copied += ret;
805
806                 tx_ok = msg_data_left(msg);
807                 if (!tx_ok)
808                         break;
809
810                 if (!sk_stream_memory_free(ssk) ||
811                     !mptcp_page_frag_refill(ssk, pfrag) ||
812                     !mptcp_ext_cache_refill(msk)) {
813                         set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
814                         tcp_push(ssk, msg->msg_flags, mss_now,
815                                  tcp_sk(ssk)->nonagle, size_goal);
816                         mptcp_set_timeout(sk, ssk);
817                         release_sock(ssk);
818                         goto restart;
819                 }
820
821                 /* memory is charged to mptcp level socket as well, i.e.
822                  * if msg is very large, mptcp socket may run out of buffer
823                  * space.  mptcp_clean_una() will release data that has
824                  * been acked at mptcp level in the mean time, so there is
825                  * a good chance we can continue sending data right away.
826                  *
827                  * Normally, when the tcp subflow can accept more data, then
828                  * so can the MPTCP socket.  However, we need to cope with
829                  * peers that might lag behind in their MPTCP-level
830                  * acknowledgements, i.e.  data might have been acked at
831                  * tcp level only.  So, we must also check the MPTCP socket
832                  * limits before we send more data.
833                  */
834                 if (unlikely(!sk_stream_memory_free(sk))) {
835                         tcp_push(ssk, msg->msg_flags, mss_now,
836                                  tcp_sk(ssk)->nonagle, size_goal);
837                         mptcp_clean_una(sk);
838                         if (!sk_stream_memory_free(sk)) {
839                                 /* can't send more for now, need to wait for
840                                  * MPTCP-level ACKs from peer.
841                                  *
842                                  * Wakeup will happen via mptcp_clean_una().
843                                  */
844                                 mptcp_set_timeout(sk, ssk);
845                                 release_sock(ssk);
846                                 goto wait_for_sndbuf;
847                         }
848                 }
849         }
850
851         mptcp_set_timeout(sk, ssk);
852         if (copied) {
853                 ret = copied;
854                 tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle,
855                          size_goal);
856
857                 /* start the timer, if it's not pending */
858                 if (!mptcp_timer_pending(sk))
859                         mptcp_reset_timer(sk);
860         }
861
862         ssk_check_wmem(msk, ssk);
863         release_sock(ssk);
864 out:
865         release_sock(sk);
866         return ret;
867 }
868
869 static void mptcp_wait_data(struct sock *sk, long *timeo)
870 {
871         DEFINE_WAIT_FUNC(wait, woken_wake_function);
872         struct mptcp_sock *msk = mptcp_sk(sk);
873
874         add_wait_queue(sk_sleep(sk), &wait);
875         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
876
877         sk_wait_event(sk, timeo,
878                       test_and_clear_bit(MPTCP_DATA_READY, &msk->flags), &wait);
879
880         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
881         remove_wait_queue(sk_sleep(sk), &wait);
882 }
883
884 static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk,
885                                 struct msghdr *msg,
886                                 size_t len)
887 {
888         struct sock *sk = (struct sock *)msk;
889         struct sk_buff *skb;
890         int copied = 0;
891
892         while ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) {
893                 u32 offset = MPTCP_SKB_CB(skb)->offset;
894                 u32 data_len = skb->len - offset;
895                 u32 count = min_t(size_t, len - copied, data_len);
896                 int err;
897
898                 err = skb_copy_datagram_msg(skb, offset, msg, count);
899                 if (unlikely(err < 0)) {
900                         if (!copied)
901                                 return err;
902                         break;
903                 }
904
905                 copied += count;
906
907                 if (count < data_len) {
908                         MPTCP_SKB_CB(skb)->offset += count;
909                         break;
910                 }
911
912                 __skb_unlink(skb, &sk->sk_receive_queue);
913                 __kfree_skb(skb);
914
915                 if (copied >= len)
916                         break;
917         }
918
919         return copied;
920 }
921
922 /* receive buffer autotuning.  See tcp_rcv_space_adjust for more information.
923  *
924  * Only difference: Use highest rtt estimate of the subflows in use.
925  */
926 static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)
927 {
928         struct mptcp_subflow_context *subflow;
929         struct sock *sk = (struct sock *)msk;
930         u32 time, advmss = 1;
931         u64 rtt_us, mstamp;
932
933         sock_owned_by_me(sk);
934
935         if (copied <= 0)
936                 return;
937
938         msk->rcvq_space.copied += copied;
939
940         mstamp = div_u64(tcp_clock_ns(), NSEC_PER_USEC);
941         time = tcp_stamp_us_delta(mstamp, msk->rcvq_space.time);
942
943         rtt_us = msk->rcvq_space.rtt_us;
944         if (rtt_us && time < (rtt_us >> 3))
945                 return;
946
947         rtt_us = 0;
948         mptcp_for_each_subflow(msk, subflow) {
949                 const struct tcp_sock *tp;
950                 u64 sf_rtt_us;
951                 u32 sf_advmss;
952
953                 tp = tcp_sk(mptcp_subflow_tcp_sock(subflow));
954
955                 sf_rtt_us = READ_ONCE(tp->rcv_rtt_est.rtt_us);
956                 sf_advmss = READ_ONCE(tp->advmss);
957
958                 rtt_us = max(sf_rtt_us, rtt_us);
959                 advmss = max(sf_advmss, advmss);
960         }
961
962         msk->rcvq_space.rtt_us = rtt_us;
963         if (time < (rtt_us >> 3) || rtt_us == 0)
964                 return;
965
966         if (msk->rcvq_space.copied <= msk->rcvq_space.space)
967                 goto new_measure;
968
969         if (sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf &&
970             !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) {
971                 int rcvmem, rcvbuf;
972                 u64 rcvwin, grow;
973
974                 rcvwin = ((u64)msk->rcvq_space.copied << 1) + 16 * advmss;
975
976                 grow = rcvwin * (msk->rcvq_space.copied - msk->rcvq_space.space);
977
978                 do_div(grow, msk->rcvq_space.space);
979                 rcvwin += (grow << 1);
980
981                 rcvmem = SKB_TRUESIZE(advmss + MAX_TCP_HEADER);
982                 while (tcp_win_from_space(sk, rcvmem) < advmss)
983                         rcvmem += 128;
984
985                 do_div(rcvwin, advmss);
986                 rcvbuf = min_t(u64, rcvwin * rcvmem,
987                                sock_net(sk)->ipv4.sysctl_tcp_rmem[2]);
988
989                 if (rcvbuf > sk->sk_rcvbuf) {
990                         u32 window_clamp;
991
992                         window_clamp = tcp_win_from_space(sk, rcvbuf);
993                         WRITE_ONCE(sk->sk_rcvbuf, rcvbuf);
994
995                         /* Make subflows follow along.  If we do not do this, we
996                          * get drops at subflow level if skbs can't be moved to
997                          * the mptcp rx queue fast enough (announced rcv_win can
998                          * exceed ssk->sk_rcvbuf).
999                          */
1000                         mptcp_for_each_subflow(msk, subflow) {
1001                                 struct sock *ssk;
1002
1003                                 ssk = mptcp_subflow_tcp_sock(subflow);
1004                                 WRITE_ONCE(ssk->sk_rcvbuf, rcvbuf);
1005                                 tcp_sk(ssk)->window_clamp = window_clamp;
1006                         }
1007                 }
1008         }
1009
1010         msk->rcvq_space.space = msk->rcvq_space.copied;
1011 new_measure:
1012         msk->rcvq_space.copied = 0;
1013         msk->rcvq_space.time = mstamp;
1014 }
1015
1016 static bool __mptcp_move_skbs(struct mptcp_sock *msk)
1017 {
1018         unsigned int moved = 0;
1019         bool done;
1020
1021         do {
1022                 struct sock *ssk = mptcp_subflow_recv_lookup(msk);
1023
1024                 if (!ssk)
1025                         break;
1026
1027                 lock_sock(ssk);
1028                 done = __mptcp_move_skbs_from_subflow(msk, ssk, &moved);
1029                 release_sock(ssk);
1030         } while (!done);
1031
1032         return moved > 0;
1033 }
1034
1035 static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
1036                          int nonblock, int flags, int *addr_len)
1037 {
1038         struct mptcp_sock *msk = mptcp_sk(sk);
1039         int copied = 0;
1040         int target;
1041         long timeo;
1042
1043         if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
1044                 return -EOPNOTSUPP;
1045
1046         lock_sock(sk);
1047         timeo = sock_rcvtimeo(sk, nonblock);
1048
1049         len = min_t(size_t, len, INT_MAX);
1050         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1051         __mptcp_flush_join_list(msk);
1052
1053         while (len > (size_t)copied) {
1054                 int bytes_read;
1055
1056                 bytes_read = __mptcp_recvmsg_mskq(msk, msg, len - copied);
1057                 if (unlikely(bytes_read < 0)) {
1058                         if (!copied)
1059                                 copied = bytes_read;
1060                         goto out_err;
1061                 }
1062
1063                 copied += bytes_read;
1064
1065                 if (skb_queue_empty(&sk->sk_receive_queue) &&
1066                     __mptcp_move_skbs(msk))
1067                         continue;
1068
1069                 /* only the master socket status is relevant here. The exit
1070                  * conditions mirror closely tcp_recvmsg()
1071                  */
1072                 if (copied >= target)
1073                         break;
1074
1075                 if (copied) {
1076                         if (sk->sk_err ||
1077                             sk->sk_state == TCP_CLOSE ||
1078                             (sk->sk_shutdown & RCV_SHUTDOWN) ||
1079                             !timeo ||
1080                             signal_pending(current))
1081                                 break;
1082                 } else {
1083                         if (sk->sk_err) {
1084                                 copied = sock_error(sk);
1085                                 break;
1086                         }
1087
1088                         if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
1089                                 mptcp_check_for_eof(msk);
1090
1091                         if (sk->sk_shutdown & RCV_SHUTDOWN)
1092                                 break;
1093
1094                         if (sk->sk_state == TCP_CLOSE) {
1095                                 copied = -ENOTCONN;
1096                                 break;
1097                         }
1098
1099                         if (!timeo) {
1100                                 copied = -EAGAIN;
1101                                 break;
1102                         }
1103
1104                         if (signal_pending(current)) {
1105                                 copied = sock_intr_errno(timeo);
1106                                 break;
1107                         }
1108                 }
1109
1110                 pr_debug("block timeout %ld", timeo);
1111                 mptcp_wait_data(sk, &timeo);
1112         }
1113
1114         if (skb_queue_empty(&sk->sk_receive_queue)) {
1115                 /* entire backlog drained, clear DATA_READY. */
1116                 clear_bit(MPTCP_DATA_READY, &msk->flags);
1117
1118                 /* .. race-breaker: ssk might have gotten new data
1119                  * after last __mptcp_move_skbs() returned false.
1120                  */
1121                 if (unlikely(__mptcp_move_skbs(msk)))
1122                         set_bit(MPTCP_DATA_READY, &msk->flags);
1123         } else if (unlikely(!test_bit(MPTCP_DATA_READY, &msk->flags))) {
1124                 /* data to read but mptcp_wait_data() cleared DATA_READY */
1125                 set_bit(MPTCP_DATA_READY, &msk->flags);
1126         }
1127 out_err:
1128         mptcp_rcv_space_adjust(msk, copied);
1129
1130         release_sock(sk);
1131         return copied;
1132 }
1133
1134 static void mptcp_retransmit_handler(struct sock *sk)
1135 {
1136         struct mptcp_sock *msk = mptcp_sk(sk);
1137
1138         if (atomic64_read(&msk->snd_una) == msk->write_seq) {
1139                 mptcp_stop_timer(sk);
1140         } else {
1141                 set_bit(MPTCP_WORK_RTX, &msk->flags);
1142                 if (schedule_work(&msk->work))
1143                         sock_hold(sk);
1144         }
1145 }
1146
1147 static void mptcp_retransmit_timer(struct timer_list *t)
1148 {
1149         struct inet_connection_sock *icsk = from_timer(icsk, t,
1150                                                        icsk_retransmit_timer);
1151         struct sock *sk = &icsk->icsk_inet.sk;
1152
1153         bh_lock_sock(sk);
1154         if (!sock_owned_by_user(sk)) {
1155                 mptcp_retransmit_handler(sk);
1156         } else {
1157                 /* delegate our work to tcp_release_cb() */
1158                 if (!test_and_set_bit(TCP_WRITE_TIMER_DEFERRED,
1159                                       &sk->sk_tsq_flags))
1160                         sock_hold(sk);
1161         }
1162         bh_unlock_sock(sk);
1163         sock_put(sk);
1164 }
1165
1166 /* Find an idle subflow.  Return NULL if there is unacked data at tcp
1167  * level.
1168  *
1169  * A backup subflow is returned only if that is the only kind available.
1170  */
1171 static struct sock *mptcp_subflow_get_retrans(const struct mptcp_sock *msk)
1172 {
1173         struct mptcp_subflow_context *subflow;
1174         struct sock *backup = NULL;
1175
1176         sock_owned_by_me((const struct sock *)msk);
1177
1178         mptcp_for_each_subflow(msk, subflow) {
1179                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1180
1181                 /* still data outstanding at TCP level?  Don't retransmit. */
1182                 if (!tcp_write_queue_empty(ssk))
1183                         return NULL;
1184
1185                 if (subflow->backup) {
1186                         if (!backup)
1187                                 backup = ssk;
1188                         continue;
1189                 }
1190
1191                 return ssk;
1192         }
1193
1194         return backup;
1195 }
1196
1197 /* subflow sockets can be either outgoing (connect) or incoming
1198  * (accept).
1199  *
1200  * Outgoing subflows use in-kernel sockets.
1201  * Incoming subflows do not have their own 'struct socket' allocated,
1202  * so we need to use tcp_close() after detaching them from the mptcp
1203  * parent socket.
1204  */
1205 static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
1206                               struct mptcp_subflow_context *subflow,
1207                               long timeout)
1208 {
1209         struct socket *sock = READ_ONCE(ssk->sk_socket);
1210
1211         list_del(&subflow->node);
1212
1213         if (sock && sock != sk->sk_socket) {
1214                 /* outgoing subflow */
1215                 sock_release(sock);
1216         } else {
1217                 /* incoming subflow */
1218                 tcp_close(ssk, timeout);
1219         }
1220 }
1221
1222 static unsigned int mptcp_sync_mss(struct sock *sk, u32 pmtu)
1223 {
1224         return 0;
1225 }
1226
1227 static void pm_work(struct mptcp_sock *msk)
1228 {
1229         struct mptcp_pm_data *pm = &msk->pm;
1230
1231         spin_lock_bh(&msk->pm.lock);
1232
1233         pr_debug("msk=%p status=%x", msk, pm->status);
1234         if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
1235                 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
1236                 mptcp_pm_nl_add_addr_received(msk);
1237         }
1238         if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
1239                 pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
1240                 mptcp_pm_nl_fully_established(msk);
1241         }
1242         if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
1243                 pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
1244                 mptcp_pm_nl_subflow_established(msk);
1245         }
1246
1247         spin_unlock_bh(&msk->pm.lock);
1248 }
1249
1250 static void mptcp_worker(struct work_struct *work)
1251 {
1252         struct mptcp_sock *msk = container_of(work, struct mptcp_sock, work);
1253         struct sock *ssk, *sk = &msk->sk.icsk_inet.sk;
1254         int orig_len, orig_offset, mss_now = 0, size_goal = 0;
1255         struct mptcp_data_frag *dfrag;
1256         u64 orig_write_seq;
1257         size_t copied = 0;
1258         struct msghdr msg;
1259         long timeo = 0;
1260
1261         lock_sock(sk);
1262         mptcp_clean_una(sk);
1263         __mptcp_flush_join_list(msk);
1264         __mptcp_move_skbs(msk);
1265
1266         if (msk->pm.status)
1267                 pm_work(msk);
1268
1269         if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
1270                 mptcp_check_for_eof(msk);
1271
1272         if (!test_and_clear_bit(MPTCP_WORK_RTX, &msk->flags))
1273                 goto unlock;
1274
1275         dfrag = mptcp_rtx_head(sk);
1276         if (!dfrag)
1277                 goto unlock;
1278
1279         if (!mptcp_ext_cache_refill(msk))
1280                 goto reset_unlock;
1281
1282         ssk = mptcp_subflow_get_retrans(msk);
1283         if (!ssk)
1284                 goto reset_unlock;
1285
1286         lock_sock(ssk);
1287
1288         msg.msg_flags = MSG_DONTWAIT;
1289         orig_len = dfrag->data_len;
1290         orig_offset = dfrag->offset;
1291         orig_write_seq = dfrag->data_seq;
1292         while (dfrag->data_len > 0) {
1293                 int ret = mptcp_sendmsg_frag(sk, ssk, &msg, dfrag, &timeo,
1294                                              &mss_now, &size_goal);
1295                 if (ret < 0)
1296                         break;
1297
1298                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RETRANSSEGS);
1299                 copied += ret;
1300                 dfrag->data_len -= ret;
1301                 dfrag->offset += ret;
1302
1303                 if (!mptcp_ext_cache_refill(msk))
1304                         break;
1305         }
1306         if (copied)
1307                 tcp_push(ssk, msg.msg_flags, mss_now, tcp_sk(ssk)->nonagle,
1308                          size_goal);
1309
1310         dfrag->data_seq = orig_write_seq;
1311         dfrag->offset = orig_offset;
1312         dfrag->data_len = orig_len;
1313
1314         mptcp_set_timeout(sk, ssk);
1315         release_sock(ssk);
1316
1317 reset_unlock:
1318         if (!mptcp_timer_pending(sk))
1319                 mptcp_reset_timer(sk);
1320
1321 unlock:
1322         release_sock(sk);
1323         sock_put(sk);
1324 }
1325
1326 static int __mptcp_init_sock(struct sock *sk)
1327 {
1328         struct mptcp_sock *msk = mptcp_sk(sk);
1329
1330         spin_lock_init(&msk->join_list_lock);
1331
1332         INIT_LIST_HEAD(&msk->conn_list);
1333         INIT_LIST_HEAD(&msk->join_list);
1334         INIT_LIST_HEAD(&msk->rtx_queue);
1335         __set_bit(MPTCP_SEND_SPACE, &msk->flags);
1336         INIT_WORK(&msk->work, mptcp_worker);
1337
1338         msk->first = NULL;
1339         inet_csk(sk)->icsk_sync_mss = mptcp_sync_mss;
1340
1341         mptcp_pm_data_init(msk);
1342
1343         /* re-use the csk retrans timer for MPTCP-level retrans */
1344         timer_setup(&msk->sk.icsk_retransmit_timer, mptcp_retransmit_timer, 0);
1345
1346         return 0;
1347 }
1348
1349 static int mptcp_init_sock(struct sock *sk)
1350 {
1351         struct net *net = sock_net(sk);
1352         int ret;
1353
1354         if (!mptcp_is_enabled(net))
1355                 return -ENOPROTOOPT;
1356
1357         if (unlikely(!net->mib.mptcp_statistics) && !mptcp_mib_alloc(net))
1358                 return -ENOMEM;
1359
1360         ret = __mptcp_init_sock(sk);
1361         if (ret)
1362                 return ret;
1363
1364         ret = __mptcp_socket_create(mptcp_sk(sk));
1365         if (ret)
1366                 return ret;
1367
1368         sk_sockets_allocated_inc(sk);
1369         sk->sk_rcvbuf = sock_net(sk)->ipv4.sysctl_tcp_rmem[1];
1370         sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[2];
1371
1372         return 0;
1373 }
1374
1375 static void __mptcp_clear_xmit(struct sock *sk)
1376 {
1377         struct mptcp_sock *msk = mptcp_sk(sk);
1378         struct mptcp_data_frag *dtmp, *dfrag;
1379
1380         sk_stop_timer(sk, &msk->sk.icsk_retransmit_timer);
1381
1382         list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list)
1383                 dfrag_clear(sk, dfrag);
1384 }
1385
1386 static void mptcp_cancel_work(struct sock *sk)
1387 {
1388         struct mptcp_sock *msk = mptcp_sk(sk);
1389
1390         if (cancel_work_sync(&msk->work))
1391                 sock_put(sk);
1392 }
1393
1394 static void mptcp_subflow_shutdown(struct sock *ssk, int how,
1395                                    bool data_fin_tx_enable, u64 data_fin_tx_seq)
1396 {
1397         lock_sock(ssk);
1398
1399         switch (ssk->sk_state) {
1400         case TCP_LISTEN:
1401                 if (!(how & RCV_SHUTDOWN))
1402                         break;
1403                 /* fall through */
1404         case TCP_SYN_SENT:
1405                 tcp_disconnect(ssk, O_NONBLOCK);
1406                 break;
1407         default:
1408                 if (data_fin_tx_enable) {
1409                         struct mptcp_subflow_context *subflow;
1410
1411                         subflow = mptcp_subflow_ctx(ssk);
1412                         subflow->data_fin_tx_seq = data_fin_tx_seq;
1413                         subflow->data_fin_tx_enable = 1;
1414                 }
1415
1416                 ssk->sk_shutdown |= how;
1417                 tcp_shutdown(ssk, how);
1418                 break;
1419         }
1420
1421         release_sock(ssk);
1422 }
1423
1424 /* Called with msk lock held, releases such lock before returning */
1425 static void mptcp_close(struct sock *sk, long timeout)
1426 {
1427         struct mptcp_subflow_context *subflow, *tmp;
1428         struct mptcp_sock *msk = mptcp_sk(sk);
1429         LIST_HEAD(conn_list);
1430         u64 data_fin_tx_seq;
1431
1432         lock_sock(sk);
1433
1434         inet_sk_state_store(sk, TCP_CLOSE);
1435
1436         /* be sure to always acquire the join list lock, to sync vs
1437          * mptcp_finish_join().
1438          */
1439         spin_lock_bh(&msk->join_list_lock);
1440         list_splice_tail_init(&msk->join_list, &msk->conn_list);
1441         spin_unlock_bh(&msk->join_list_lock);
1442         list_splice_init(&msk->conn_list, &conn_list);
1443
1444         data_fin_tx_seq = msk->write_seq;
1445
1446         __mptcp_clear_xmit(sk);
1447
1448         release_sock(sk);
1449
1450         list_for_each_entry_safe(subflow, tmp, &conn_list, node) {
1451                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1452
1453                 subflow->data_fin_tx_seq = data_fin_tx_seq;
1454                 subflow->data_fin_tx_enable = 1;
1455                 __mptcp_close_ssk(sk, ssk, subflow, timeout);
1456         }
1457
1458         mptcp_cancel_work(sk);
1459
1460         __skb_queue_purge(&sk->sk_receive_queue);
1461
1462         sk_common_release(sk);
1463 }
1464
1465 static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
1466 {
1467 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1468         const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
1469         struct ipv6_pinfo *msk6 = inet6_sk(msk);
1470
1471         msk->sk_v6_daddr = ssk->sk_v6_daddr;
1472         msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;
1473
1474         if (msk6 && ssk6) {
1475                 msk6->saddr = ssk6->saddr;
1476                 msk6->flow_label = ssk6->flow_label;
1477         }
1478 #endif
1479
1480         inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
1481         inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
1482         inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
1483         inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
1484         inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
1485         inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
1486 }
1487
1488 static int mptcp_disconnect(struct sock *sk, int flags)
1489 {
1490         /* Should never be called.
1491          * inet_stream_connect() calls ->disconnect, but that
1492          * refers to the subflow socket, not the mptcp one.
1493          */
1494         WARN_ON_ONCE(1);
1495         return 0;
1496 }
1497
1498 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1499 static struct ipv6_pinfo *mptcp_inet6_sk(const struct sock *sk)
1500 {
1501         unsigned int offset = sizeof(struct mptcp6_sock) - sizeof(struct ipv6_pinfo);
1502
1503         return (struct ipv6_pinfo *)(((u8 *)sk) + offset);
1504 }
1505 #endif
1506
1507 struct sock *mptcp_sk_clone(const struct sock *sk,
1508                             const struct mptcp_options_received *mp_opt,
1509                             struct request_sock *req)
1510 {
1511         struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
1512         struct sock *nsk = sk_clone_lock(sk, GFP_ATOMIC);
1513         struct mptcp_sock *msk;
1514         u64 ack_seq;
1515
1516         if (!nsk)
1517                 return NULL;
1518
1519 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1520         if (nsk->sk_family == AF_INET6)
1521                 inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
1522 #endif
1523
1524         __mptcp_init_sock(nsk);
1525
1526         msk = mptcp_sk(nsk);
1527         msk->local_key = subflow_req->local_key;
1528         msk->token = subflow_req->token;
1529         msk->subflow = NULL;
1530         WRITE_ONCE(msk->fully_established, false);
1531
1532         msk->write_seq = subflow_req->idsn + 1;
1533         atomic64_set(&msk->snd_una, msk->write_seq);
1534         if (mp_opt->mp_capable) {
1535                 msk->can_ack = true;
1536                 msk->remote_key = mp_opt->sndr_key;
1537                 mptcp_crypto_key_sha(msk->remote_key, NULL, &ack_seq);
1538                 ack_seq++;
1539                 msk->ack_seq = ack_seq;
1540         }
1541
1542         sock_reset_flag(nsk, SOCK_RCU_FREE);
1543         /* will be fully established after successful MPC subflow creation */
1544         inet_sk_state_store(nsk, TCP_SYN_RECV);
1545         bh_unlock_sock(nsk);
1546
1547         /* keep a single reference */
1548         __sock_put(nsk);
1549         return nsk;
1550 }
1551
1552 void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk)
1553 {
1554         const struct tcp_sock *tp = tcp_sk(ssk);
1555
1556         msk->rcvq_space.copied = 0;
1557         msk->rcvq_space.rtt_us = 0;
1558
1559         msk->rcvq_space.time = tp->tcp_mstamp;
1560
1561         /* initial rcv_space offering made to peer */
1562         msk->rcvq_space.space = min_t(u32, tp->rcv_wnd,
1563                                       TCP_INIT_CWND * tp->advmss);
1564         if (msk->rcvq_space.space == 0)
1565                 msk->rcvq_space.space = TCP_INIT_CWND * TCP_MSS_DEFAULT;
1566 }
1567
1568 static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
1569                                  bool kern)
1570 {
1571         struct mptcp_sock *msk = mptcp_sk(sk);
1572         struct socket *listener;
1573         struct sock *newsk;
1574
1575         listener = __mptcp_nmpc_socket(msk);
1576         if (WARN_ON_ONCE(!listener)) {
1577                 *err = -EINVAL;
1578                 return NULL;
1579         }
1580
1581         pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
1582         newsk = inet_csk_accept(listener->sk, flags, err, kern);
1583         if (!newsk)
1584                 return NULL;
1585
1586         pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
1587         if (sk_is_mptcp(newsk)) {
1588                 struct mptcp_subflow_context *subflow;
1589                 struct sock *new_mptcp_sock;
1590                 struct sock *ssk = newsk;
1591
1592                 subflow = mptcp_subflow_ctx(newsk);
1593                 new_mptcp_sock = subflow->conn;
1594
1595                 /* is_mptcp should be false if subflow->conn is missing, see
1596                  * subflow_syn_recv_sock()
1597                  */
1598                 if (WARN_ON_ONCE(!new_mptcp_sock)) {
1599                         tcp_sk(newsk)->is_mptcp = 0;
1600                         return newsk;
1601                 }
1602
1603                 /* acquire the 2nd reference for the owning socket */
1604                 sock_hold(new_mptcp_sock);
1605
1606                 local_bh_disable();
1607                 bh_lock_sock(new_mptcp_sock);
1608                 msk = mptcp_sk(new_mptcp_sock);
1609                 msk->first = newsk;
1610
1611                 newsk = new_mptcp_sock;
1612                 mptcp_copy_inaddrs(newsk, ssk);
1613                 list_add(&subflow->node, &msk->conn_list);
1614
1615                 mptcp_rcv_space_init(msk, ssk);
1616                 bh_unlock_sock(new_mptcp_sock);
1617
1618                 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
1619                 local_bh_enable();
1620         } else {
1621                 MPTCP_INC_STATS(sock_net(sk),
1622                                 MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK);
1623         }
1624
1625         return newsk;
1626 }
1627
1628 static void mptcp_destroy(struct sock *sk)
1629 {
1630         struct mptcp_sock *msk = mptcp_sk(sk);
1631
1632         mptcp_token_destroy(msk);
1633         if (msk->cached_ext)
1634                 __skb_ext_put(msk->cached_ext);
1635
1636         sk_sockets_allocated_dec(sk);
1637 }
1638
1639 static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
1640                                        sockptr_t optval, unsigned int optlen)
1641 {
1642         struct sock *sk = (struct sock *)msk;
1643         struct socket *ssock;
1644         int ret;
1645
1646         switch (optname) {
1647         case SO_REUSEPORT:
1648         case SO_REUSEADDR:
1649                 lock_sock(sk);
1650                 ssock = __mptcp_nmpc_socket(msk);
1651                 if (!ssock) {
1652                         release_sock(sk);
1653                         return -EINVAL;
1654                 }
1655
1656                 ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen);
1657                 if (ret == 0) {
1658                         if (optname == SO_REUSEPORT)
1659                                 sk->sk_reuseport = ssock->sk->sk_reuseport;
1660                         else if (optname == SO_REUSEADDR)
1661                                 sk->sk_reuse = ssock->sk->sk_reuse;
1662                 }
1663                 release_sock(sk);
1664                 return ret;
1665         }
1666
1667         return sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, optval, optlen);
1668 }
1669
1670 static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
1671                                sockptr_t optval, unsigned int optlen)
1672 {
1673         struct sock *sk = (struct sock *)msk;
1674         int ret = -EOPNOTSUPP;
1675         struct socket *ssock;
1676
1677         switch (optname) {
1678         case IPV6_V6ONLY:
1679                 lock_sock(sk);
1680                 ssock = __mptcp_nmpc_socket(msk);
1681                 if (!ssock) {
1682                         release_sock(sk);
1683                         return -EINVAL;
1684                 }
1685
1686                 ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen);
1687                 if (ret == 0)
1688                         sk->sk_ipv6only = ssock->sk->sk_ipv6only;
1689
1690                 release_sock(sk);
1691                 break;
1692         }
1693
1694         return ret;
1695 }
1696
1697 static int mptcp_setsockopt(struct sock *sk, int level, int optname,
1698                             sockptr_t optval, unsigned int optlen)
1699 {
1700         struct mptcp_sock *msk = mptcp_sk(sk);
1701         struct sock *ssk;
1702
1703         pr_debug("msk=%p", msk);
1704
1705         if (level == SOL_SOCKET)
1706                 return mptcp_setsockopt_sol_socket(msk, optname, optval, optlen);
1707
1708         /* @@ the meaning of setsockopt() when the socket is connected and
1709          * there are multiple subflows is not yet defined. It is up to the
1710          * MPTCP-level socket to configure the subflows until the subflow
1711          * is in TCP fallback, when TCP socket options are passed through
1712          * to the one remaining subflow.
1713          */
1714         lock_sock(sk);
1715         ssk = __mptcp_tcp_fallback(msk);
1716         release_sock(sk);
1717         if (ssk)
1718                 return tcp_setsockopt(ssk, level, optname, optval, optlen);
1719
1720         if (level == SOL_IPV6)
1721                 return mptcp_setsockopt_v6(msk, optname, optval, optlen);
1722
1723         return -EOPNOTSUPP;
1724 }
1725
1726 static int mptcp_getsockopt(struct sock *sk, int level, int optname,
1727                             char __user *optval, int __user *option)
1728 {
1729         struct mptcp_sock *msk = mptcp_sk(sk);
1730         struct sock *ssk;
1731
1732         pr_debug("msk=%p", msk);
1733
1734         /* @@ the meaning of setsockopt() when the socket is connected and
1735          * there are multiple subflows is not yet defined. It is up to the
1736          * MPTCP-level socket to configure the subflows until the subflow
1737          * is in TCP fallback, when socket options are passed through
1738          * to the one remaining subflow.
1739          */
1740         lock_sock(sk);
1741         ssk = __mptcp_tcp_fallback(msk);
1742         release_sock(sk);
1743         if (ssk)
1744                 return tcp_getsockopt(ssk, level, optname, optval, option);
1745
1746         return -EOPNOTSUPP;
1747 }
1748
1749 #define MPTCP_DEFERRED_ALL (TCPF_DELACK_TIMER_DEFERRED | \
1750                             TCPF_WRITE_TIMER_DEFERRED)
1751
1752 /* this is very alike tcp_release_cb() but we must handle differently a
1753  * different set of events
1754  */
1755 static void mptcp_release_cb(struct sock *sk)
1756 {
1757         unsigned long flags, nflags;
1758
1759         do {
1760                 flags = sk->sk_tsq_flags;
1761                 if (!(flags & MPTCP_DEFERRED_ALL))
1762                         return;
1763                 nflags = flags & ~MPTCP_DEFERRED_ALL;
1764         } while (cmpxchg(&sk->sk_tsq_flags, flags, nflags) != flags);
1765
1766         sock_release_ownership(sk);
1767
1768         if (flags & TCPF_DELACK_TIMER_DEFERRED) {
1769                 struct mptcp_sock *msk = mptcp_sk(sk);
1770                 struct sock *ssk;
1771
1772                 ssk = mptcp_subflow_recv_lookup(msk);
1773                 if (!ssk || !schedule_work(&msk->work))
1774                         __sock_put(sk);
1775         }
1776
1777         if (flags & TCPF_WRITE_TIMER_DEFERRED) {
1778                 mptcp_retransmit_handler(sk);
1779                 __sock_put(sk);
1780         }
1781 }
1782
1783 static int mptcp_hash(struct sock *sk)
1784 {
1785         /* should never be called,
1786          * we hash the TCP subflows not the master socket
1787          */
1788         WARN_ON_ONCE(1);
1789         return 0;
1790 }
1791
1792 static void mptcp_unhash(struct sock *sk)
1793 {
1794         /* called from sk_common_release(), but nothing to do here */
1795 }
1796
1797 static int mptcp_get_port(struct sock *sk, unsigned short snum)
1798 {
1799         struct mptcp_sock *msk = mptcp_sk(sk);
1800         struct socket *ssock;
1801
1802         ssock = __mptcp_nmpc_socket(msk);
1803         pr_debug("msk=%p, subflow=%p", msk, ssock);
1804         if (WARN_ON_ONCE(!ssock))
1805                 return -EINVAL;
1806
1807         return inet_csk_get_port(ssock->sk, snum);
1808 }
1809
1810 void mptcp_finish_connect(struct sock *ssk)
1811 {
1812         struct mptcp_subflow_context *subflow;
1813         struct mptcp_sock *msk;
1814         struct sock *sk;
1815         u64 ack_seq;
1816
1817         subflow = mptcp_subflow_ctx(ssk);
1818         sk = subflow->conn;
1819         msk = mptcp_sk(sk);
1820
1821         pr_debug("msk=%p, token=%u", sk, subflow->token);
1822
1823         mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
1824         ack_seq++;
1825         subflow->map_seq = ack_seq;
1826         subflow->map_subflow_seq = 1;
1827
1828         /* the socket is not connected yet, no msk/subflow ops can access/race
1829          * accessing the field below
1830          */
1831         WRITE_ONCE(msk->remote_key, subflow->remote_key);
1832         WRITE_ONCE(msk->local_key, subflow->local_key);
1833         WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
1834         WRITE_ONCE(msk->ack_seq, ack_seq);
1835         WRITE_ONCE(msk->can_ack, 1);
1836         atomic64_set(&msk->snd_una, msk->write_seq);
1837
1838         mptcp_pm_new_connection(msk, 0);
1839
1840         mptcp_rcv_space_init(msk, ssk);
1841 }
1842
1843 static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
1844 {
1845         write_lock_bh(&sk->sk_callback_lock);
1846         rcu_assign_pointer(sk->sk_wq, &parent->wq);
1847         sk_set_socket(sk, parent);
1848         sk->sk_uid = SOCK_INODE(parent)->i_uid;
1849         write_unlock_bh(&sk->sk_callback_lock);
1850 }
1851
1852 bool mptcp_finish_join(struct sock *sk)
1853 {
1854         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
1855         struct mptcp_sock *msk = mptcp_sk(subflow->conn);
1856         struct sock *parent = (void *)msk;
1857         struct socket *parent_sock;
1858         bool ret;
1859
1860         pr_debug("msk=%p, subflow=%p", msk, subflow);
1861
1862         /* mptcp socket already closing? */
1863         if (!mptcp_is_fully_established(parent))
1864                 return false;
1865
1866         if (!msk->pm.server_side)
1867                 return true;
1868
1869         if (!mptcp_pm_allow_new_subflow(msk))
1870                 return false;
1871
1872         /* active connections are already on conn_list, and we can't acquire
1873          * msk lock here.
1874          * use the join list lock as synchronization point and double-check
1875          * msk status to avoid racing with mptcp_close()
1876          */
1877         spin_lock_bh(&msk->join_list_lock);
1878         ret = inet_sk_state_load(parent) == TCP_ESTABLISHED;
1879         if (ret && !WARN_ON_ONCE(!list_empty(&subflow->node)))
1880                 list_add_tail(&subflow->node, &msk->join_list);
1881         spin_unlock_bh(&msk->join_list_lock);
1882         if (!ret)
1883                 return false;
1884
1885         /* attach to msk socket only after we are sure he will deal with us
1886          * at close time
1887          */
1888         parent_sock = READ_ONCE(parent->sk_socket);
1889         if (parent_sock && !sk->sk_socket)
1890                 mptcp_sock_graft(sk, parent_sock);
1891         subflow->map_seq = msk->ack_seq;
1892         return true;
1893 }
1894
1895 static bool mptcp_memory_free(const struct sock *sk, int wake)
1896 {
1897         struct mptcp_sock *msk = mptcp_sk(sk);
1898
1899         return wake ? test_bit(MPTCP_SEND_SPACE, &msk->flags) : true;
1900 }
1901
1902 static struct proto mptcp_prot = {
1903         .name           = "MPTCP",
1904         .owner          = THIS_MODULE,
1905         .init           = mptcp_init_sock,
1906         .disconnect     = mptcp_disconnect,
1907         .close          = mptcp_close,
1908         .accept         = mptcp_accept,
1909         .setsockopt     = mptcp_setsockopt,
1910         .getsockopt     = mptcp_getsockopt,
1911         .shutdown       = tcp_shutdown,
1912         .destroy        = mptcp_destroy,
1913         .sendmsg        = mptcp_sendmsg,
1914         .recvmsg        = mptcp_recvmsg,
1915         .release_cb     = mptcp_release_cb,
1916         .hash           = mptcp_hash,
1917         .unhash         = mptcp_unhash,
1918         .get_port       = mptcp_get_port,
1919         .sockets_allocated      = &mptcp_sockets_allocated,
1920         .memory_allocated       = &tcp_memory_allocated,
1921         .memory_pressure        = &tcp_memory_pressure,
1922         .stream_memory_free     = mptcp_memory_free,
1923         .sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_tcp_wmem),
1924         .sysctl_mem     = sysctl_tcp_mem,
1925         .obj_size       = sizeof(struct mptcp_sock),
1926         .slab_flags     = SLAB_TYPESAFE_BY_RCU,
1927         .no_autobind    = true,
1928 };
1929
1930 static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
1931 {
1932         struct mptcp_sock *msk = mptcp_sk(sock->sk);
1933         struct socket *ssock;
1934         int err;
1935
1936         lock_sock(sock->sk);
1937         ssock = __mptcp_nmpc_socket(msk);
1938         if (!ssock) {
1939                 err = -EINVAL;
1940                 goto unlock;
1941         }
1942
1943         err = ssock->ops->bind(ssock, uaddr, addr_len);
1944         if (!err)
1945                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
1946
1947 unlock:
1948         release_sock(sock->sk);
1949         return err;
1950 }
1951
1952 static void mptcp_subflow_early_fallback(struct mptcp_sock *msk,
1953                                          struct mptcp_subflow_context *subflow)
1954 {
1955         subflow->request_mptcp = 0;
1956         __mptcp_do_fallback(msk);
1957 }
1958
1959 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
1960                                 int addr_len, int flags)
1961 {
1962         struct mptcp_sock *msk = mptcp_sk(sock->sk);
1963         struct mptcp_subflow_context *subflow;
1964         struct socket *ssock;
1965         int err;
1966
1967         lock_sock(sock->sk);
1968         if (sock->state != SS_UNCONNECTED && msk->subflow) {
1969                 /* pending connection or invalid state, let existing subflow
1970                  * cope with that
1971                  */
1972                 ssock = msk->subflow;
1973                 goto do_connect;
1974         }
1975
1976         ssock = __mptcp_nmpc_socket(msk);
1977         if (!ssock) {
1978                 err = -EINVAL;
1979                 goto unlock;
1980         }
1981
1982         mptcp_token_destroy(msk);
1983         inet_sk_state_store(sock->sk, TCP_SYN_SENT);
1984         subflow = mptcp_subflow_ctx(ssock->sk);
1985 #ifdef CONFIG_TCP_MD5SIG
1986         /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
1987          * TCP option space.
1988          */
1989         if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
1990                 mptcp_subflow_early_fallback(msk, subflow);
1991 #endif
1992         if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk))
1993                 mptcp_subflow_early_fallback(msk, subflow);
1994
1995 do_connect:
1996         err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
1997         sock->state = ssock->state;
1998
1999         /* on successful connect, the msk state will be moved to established by
2000          * subflow_finish_connect()
2001          */
2002         if (!err || err == EINPROGRESS)
2003                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
2004         else
2005                 inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
2006
2007 unlock:
2008         release_sock(sock->sk);
2009         return err;
2010 }
2011
2012 static int mptcp_listen(struct socket *sock, int backlog)
2013 {
2014         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2015         struct socket *ssock;
2016         int err;
2017
2018         pr_debug("msk=%p", msk);
2019
2020         lock_sock(sock->sk);
2021         ssock = __mptcp_nmpc_socket(msk);
2022         if (!ssock) {
2023                 err = -EINVAL;
2024                 goto unlock;
2025         }
2026
2027         mptcp_token_destroy(msk);
2028         inet_sk_state_store(sock->sk, TCP_LISTEN);
2029         sock_set_flag(sock->sk, SOCK_RCU_FREE);
2030
2031         err = ssock->ops->listen(ssock, backlog);
2032         inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
2033         if (!err)
2034                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
2035
2036 unlock:
2037         release_sock(sock->sk);
2038         return err;
2039 }
2040
2041 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
2042                                int flags, bool kern)
2043 {
2044         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2045         struct socket *ssock;
2046         int err;
2047
2048         pr_debug("msk=%p", msk);
2049
2050         lock_sock(sock->sk);
2051         if (sock->sk->sk_state != TCP_LISTEN)
2052                 goto unlock_fail;
2053
2054         ssock = __mptcp_nmpc_socket(msk);
2055         if (!ssock)
2056                 goto unlock_fail;
2057
2058         clear_bit(MPTCP_DATA_READY, &msk->flags);
2059         sock_hold(ssock->sk);
2060         release_sock(sock->sk);
2061
2062         err = ssock->ops->accept(sock, newsock, flags, kern);
2063         if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
2064                 struct mptcp_sock *msk = mptcp_sk(newsock->sk);
2065                 struct mptcp_subflow_context *subflow;
2066
2067                 /* set ssk->sk_socket of accept()ed flows to mptcp socket.
2068                  * This is needed so NOSPACE flag can be set from tcp stack.
2069                  */
2070                 __mptcp_flush_join_list(msk);
2071                 list_for_each_entry(subflow, &msk->conn_list, node) {
2072                         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
2073
2074                         if (!ssk->sk_socket)
2075                                 mptcp_sock_graft(ssk, newsock);
2076                 }
2077         }
2078
2079         if (inet_csk_listen_poll(ssock->sk))
2080                 set_bit(MPTCP_DATA_READY, &msk->flags);
2081         sock_put(ssock->sk);
2082         return err;
2083
2084 unlock_fail:
2085         release_sock(sock->sk);
2086         return -EINVAL;
2087 }
2088
2089 static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
2090 {
2091         return test_bit(MPTCP_DATA_READY, &msk->flags) ? EPOLLIN | EPOLLRDNORM :
2092                0;
2093 }
2094
2095 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
2096                            struct poll_table_struct *wait)
2097 {
2098         struct sock *sk = sock->sk;
2099         struct mptcp_sock *msk;
2100         __poll_t mask = 0;
2101         int state;
2102
2103         msk = mptcp_sk(sk);
2104         sock_poll_wait(file, sock, wait);
2105
2106         state = inet_sk_state_load(sk);
2107         if (state == TCP_LISTEN)
2108                 return mptcp_check_readable(msk);
2109
2110         if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
2111                 mask |= mptcp_check_readable(msk);
2112                 if (sk_stream_is_writeable(sk) &&
2113                     test_bit(MPTCP_SEND_SPACE, &msk->flags))
2114                         mask |= EPOLLOUT | EPOLLWRNORM;
2115         }
2116         if (sk->sk_shutdown & RCV_SHUTDOWN)
2117                 mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
2118
2119         return mask;
2120 }
2121
2122 static int mptcp_shutdown(struct socket *sock, int how)
2123 {
2124         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2125         struct mptcp_subflow_context *subflow;
2126         int ret = 0;
2127
2128         pr_debug("sk=%p, how=%d", msk, how);
2129
2130         lock_sock(sock->sk);
2131         if (how == SHUT_WR || how == SHUT_RDWR)
2132                 inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
2133
2134         how++;
2135
2136         if ((how & ~SHUTDOWN_MASK) || !how) {
2137                 ret = -EINVAL;
2138                 goto out_unlock;
2139         }
2140
2141         if (sock->state == SS_CONNECTING) {
2142                 if ((1 << sock->sk->sk_state) &
2143                     (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))
2144                         sock->state = SS_DISCONNECTING;
2145                 else
2146                         sock->state = SS_CONNECTED;
2147         }
2148
2149         __mptcp_flush_join_list(msk);
2150         mptcp_for_each_subflow(msk, subflow) {
2151                 struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
2152
2153                 mptcp_subflow_shutdown(tcp_sk, how, 1, msk->write_seq);
2154         }
2155
2156         /* Wake up anyone sleeping in poll. */
2157         sock->sk->sk_state_change(sock->sk);
2158
2159 out_unlock:
2160         release_sock(sock->sk);
2161
2162         return ret;
2163 }
2164
2165 static const struct proto_ops mptcp_stream_ops = {
2166         .family            = PF_INET,
2167         .owner             = THIS_MODULE,
2168         .release           = inet_release,
2169         .bind              = mptcp_bind,
2170         .connect           = mptcp_stream_connect,
2171         .socketpair        = sock_no_socketpair,
2172         .accept            = mptcp_stream_accept,
2173         .getname           = inet_getname,
2174         .poll              = mptcp_poll,
2175         .ioctl             = inet_ioctl,
2176         .gettstamp         = sock_gettstamp,
2177         .listen            = mptcp_listen,
2178         .shutdown          = mptcp_shutdown,
2179         .setsockopt        = sock_common_setsockopt,
2180         .getsockopt        = sock_common_getsockopt,
2181         .sendmsg           = inet_sendmsg,
2182         .recvmsg           = inet_recvmsg,
2183         .mmap              = sock_no_mmap,
2184         .sendpage          = inet_sendpage,
2185 };
2186
2187 static struct inet_protosw mptcp_protosw = {
2188         .type           = SOCK_STREAM,
2189         .protocol       = IPPROTO_MPTCP,
2190         .prot           = &mptcp_prot,
2191         .ops            = &mptcp_stream_ops,
2192         .flags          = INET_PROTOSW_ICSK,
2193 };
2194
2195 void __init mptcp_proto_init(void)
2196 {
2197         mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
2198
2199         if (percpu_counter_init(&mptcp_sockets_allocated, 0, GFP_KERNEL))
2200                 panic("Failed to allocate MPTCP pcpu counter\n");
2201
2202         mptcp_subflow_init();
2203         mptcp_pm_init();
2204         mptcp_token_init();
2205
2206         if (proto_register(&mptcp_prot, 1) != 0)
2207                 panic("Failed to register MPTCP proto.\n");
2208
2209         inet_register_protosw(&mptcp_protosw);
2210
2211         BUILD_BUG_ON(sizeof(struct mptcp_skb_cb) > sizeof_field(struct sk_buff, cb));
2212 }
2213
2214 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
2215 static const struct proto_ops mptcp_v6_stream_ops = {
2216         .family            = PF_INET6,
2217         .owner             = THIS_MODULE,
2218         .release           = inet6_release,
2219         .bind              = mptcp_bind,
2220         .connect           = mptcp_stream_connect,
2221         .socketpair        = sock_no_socketpair,
2222         .accept            = mptcp_stream_accept,
2223         .getname           = inet6_getname,
2224         .poll              = mptcp_poll,
2225         .ioctl             = inet6_ioctl,
2226         .gettstamp         = sock_gettstamp,
2227         .listen            = mptcp_listen,
2228         .shutdown          = mptcp_shutdown,
2229         .setsockopt        = sock_common_setsockopt,
2230         .getsockopt        = sock_common_getsockopt,
2231         .sendmsg           = inet6_sendmsg,
2232         .recvmsg           = inet6_recvmsg,
2233         .mmap              = sock_no_mmap,
2234         .sendpage          = inet_sendpage,
2235 #ifdef CONFIG_COMPAT
2236         .compat_ioctl      = inet6_compat_ioctl,
2237 #endif
2238 };
2239
2240 static struct proto mptcp_v6_prot;
2241
2242 static void mptcp_v6_destroy(struct sock *sk)
2243 {
2244         mptcp_destroy(sk);
2245         inet6_destroy_sock(sk);
2246 }
2247
2248 static struct inet_protosw mptcp_v6_protosw = {
2249         .type           = SOCK_STREAM,
2250         .protocol       = IPPROTO_MPTCP,
2251         .prot           = &mptcp_v6_prot,
2252         .ops            = &mptcp_v6_stream_ops,
2253         .flags          = INET_PROTOSW_ICSK,
2254 };
2255
2256 int __init mptcp_proto_v6_init(void)
2257 {
2258         int err;
2259
2260         mptcp_v6_prot = mptcp_prot;
2261         strcpy(mptcp_v6_prot.name, "MPTCPv6");
2262         mptcp_v6_prot.slab = NULL;
2263         mptcp_v6_prot.destroy = mptcp_v6_destroy;
2264         mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock);
2265
2266         err = proto_register(&mptcp_v6_prot, 1);
2267         if (err)
2268                 return err;
2269
2270         err = inet6_register_protosw(&mptcp_v6_protosw);
2271         if (err)
2272                 proto_unregister(&mptcp_v6_prot);
2273
2274         return err;
2275 }
2276 #endif