1 // SPDX-License-Identifier: GPL-2.0-only
2 /******************************************************************************
3 *******************************************************************************
4 **
5 **  Copyright (C) Sistina Software, Inc.  1997-2003  All rights reserved.
6 **  Copyright (C) 2004-2009 Red Hat, Inc.  All rights reserved.
7 **
8 **
9 *******************************************************************************
10 ******************************************************************************/
11
12 /*
13  * lowcomms.c
14  *
15  * This is the "low-level" comms layer.
16  *
17  * It is responsible for sending/receiving messages
18  * from other nodes in the cluster.
19  *
20  * Cluster nodes are referred to by their nodeids. nodeids are
21  * simply 32 bit numbers to the locking module - if they need to
22  * be expanded for the cluster infrastructure then that is its
23  * responsibility. It is this layer's
24  * responsibility to resolve these into IP addresses or
25  * whatever it needs for inter-node communication.
26  *
27  * The comms level consists of two workqueues: a receive workqueue
28  * that deals mainly with receiving messages from other nodes and
29  * passing them up to the mid-level comms layer (which understands
30  * the message format) for execution by the locking core, and a send
31  * workqueue which does all the setting up of connections to remote
32  * nodes and the sending of data. Other threads are not allowed to
33  * send their own data because it may cause them to wait in times of
34  * high load. Also, this way, the send side can collect together
35  * messages bound for one node and send them in one block.
36  *
37  * lowcomms will choose to use either TCP or SCTP as its transport layer
38  * depending on the configuration variable 'protocol'. This should be set
39  * to 0 (default) for TCP or 1 for SCTP. It should be configured using a
40  * cluster-wide mechanism as it must be the same on all nodes of the cluster
41  * for the DLM to function.
42  *
43  */
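
/*
 * Illustrative sketch only (the configfs path depends on the running
 * kernel, and this is normally driven by a userspace cluster manager
 * such as dlm_controld rather than done by hand):
 *
 *   echo 1 > /sys/kernel/config/dlm/cluster/protocol    (select SCTP)
 *
 * This must happen on every node before any lockspace is created.
 */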
44
45 #include <asm/ioctls.h>
46 #include <net/sock.h>
47 #include <net/tcp.h>
48 #include <linux/pagemap.h>
49 #include <linux/file.h>
50 #include <linux/mutex.h>
51 #include <linux/sctp.h>
52 #include <linux/slab.h>
53 #include <net/sctp/sctp.h>
54 #include <net/ipv6.h>
55
56 #include "dlm_internal.h"
57 #include "lowcomms.h"
58 #include "midcomms.h"
59 #include "config.h"
60
61 #define NEEDED_RMEM (4*1024*1024)
62
63 /* Number of messages to send before rescheduling */
64 #define MAX_SEND_MSG_COUNT 25
65 #define DLM_SHUTDOWN_WAIT_TIMEOUT msecs_to_jiffies(10000)
66
67 struct connection {
68         struct socket *sock;    /* NULL if not connected */
69         uint32_t nodeid;        /* So we know who we are in the list */
70         struct mutex sock_mutex;
71         unsigned long flags;
72 #define CF_READ_PENDING 1
73 #define CF_WRITE_PENDING 2
74 #define CF_INIT_PENDING 4
75 #define CF_IS_OTHERCON 5
76 #define CF_CLOSE 6
77 #define CF_APP_LIMITED 7
78 #define CF_CLOSING 8
79 #define CF_SHUTDOWN 9
80 #define CF_CONNECTED 10
81 #define CF_RECONNECT 11
82 #define CF_DELAY_CONNECT 12
83 #define CF_EOF 13
84         struct list_head writequeue;  /* List of outgoing writequeue_entries */
85         spinlock_t writequeue_lock;
86         atomic_t writequeue_cnt;
87         struct mutex wq_alloc;
88         int retries;
89 #define MAX_CONNECT_RETRIES 3
90         struct hlist_node list;
91         struct connection *othercon;
92         struct connection *sendcon;
93         struct work_struct rwork; /* Receive workqueue */
94         struct work_struct swork; /* Send workqueue */
95         wait_queue_head_t shutdown_wait; /* wait for graceful shutdown */
96         unsigned char *rx_buf;
97         int rx_buflen;
98         int rx_leftover;
99         struct rcu_head rcu;
100 };
101 #define sock2con(x) ((struct connection *)(x)->sk_user_data)
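
/* sk_user_data is assigned under sk_callback_lock in add_sock(); the
 * sock2con() callers below read it under the same lock.
 */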
102
103 struct listen_connection {
104         struct socket *sock;
105         struct work_struct rwork;
106 };
107
108 #define DLM_WQ_REMAIN_BYTES(e) (PAGE_SIZE - e->end)
109 #define DLM_WQ_LENGTH_BYTES(e) (e->end - e->offset)
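
/* Rough sketch of the page layout behind these macros:
 *
 *   0 ......... offset .............. end ............ PAGE_SIZE
 *   |  sent    |  queued, not sent   |  free space     |
 *
 * DLM_WQ_REMAIN_BYTES() is the space still available for new messages,
 * DLM_WQ_LENGTH_BYTES() the span currently occupied on the page.
 */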
110
111 /* An entry waiting to be sent */
112 struct writequeue_entry {
113         struct list_head list;
114         struct page *page;
115         int offset;
116         int len;
117         int end;
118         int users;
119         bool dirty;
120         struct connection *con;
121         struct list_head msgs;
122         struct kref ref;
123 };
124
125 struct dlm_msg {
126         struct writequeue_entry *entry;
127         struct dlm_msg *orig_msg;
128         bool retransmit;
129         void *ppc;
130         int len;
131         int idx; /* new()/commit() idx exchange */
132
133         struct list_head list;
134         struct kref ref;
135 };
136
137 struct dlm_node_addr {
138         struct list_head list;
139         int nodeid;
140         int mark;
141         int addr_count;
142         int curr_addr_index;
143         struct sockaddr_storage *addr[DLM_MAX_ADDR_COUNT];
144 };
145
146 struct dlm_proto_ops {
147         bool try_new_addr;
148         const char *name;
149         int proto;
150
151         int (*connect)(struct connection *con, struct socket *sock,
152                        struct sockaddr *addr, int addr_len);
153         void (*sockopts)(struct socket *sock);
154         int (*bind)(struct socket *sock);
155         int (*listen_validate)(void);
156         void (*listen_sockopts)(struct socket *sock);
157         int (*listen_bind)(struct socket *sock);
158         /* What to do to shutdown */
159         void (*shutdown_action)(struct connection *con);
160         /* What to do to eof check */
161         bool (*eof_condition)(struct connection *con);
162 };
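
/* Each transport provides one instance of these ops (dlm_tcp_ops and
 * dlm_sctp_ops, defined further down in this file); dlm_proto_ops is
 * pointed at the chosen one when listening starts.
 */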
163
164 static struct listen_sock_callbacks {
165         void (*sk_error_report)(struct sock *);
166         void (*sk_data_ready)(struct sock *);
167         void (*sk_state_change)(struct sock *);
168         void (*sk_write_space)(struct sock *);
169 } listen_sock;
170
171 static LIST_HEAD(dlm_node_addrs);
172 static DEFINE_SPINLOCK(dlm_node_addrs_spin);
173
174 static struct listen_connection listen_con;
175 static struct sockaddr_storage *dlm_local_addr[DLM_MAX_ADDR_COUNT];
176 static int dlm_local_count;
177 int dlm_allow_conn;
178
179 /* Work queues */
180 static struct workqueue_struct *recv_workqueue;
181 static struct workqueue_struct *send_workqueue;
182
183 static struct hlist_head connection_hash[CONN_HASH_SIZE];
184 static DEFINE_SPINLOCK(connections_lock);
185 DEFINE_STATIC_SRCU(connections_srcu);
186
187 static const struct dlm_proto_ops *dlm_proto_ops;
188
189 static void process_recv_sockets(struct work_struct *work);
190 static void process_send_sockets(struct work_struct *work);
191
192 /* must be called with writequeue_lock held */
193 static struct writequeue_entry *con_next_wq(struct connection *con)
194 {
195         struct writequeue_entry *e;
196
197         if (list_empty(&con->writequeue))
198                 return NULL;
199
200         e = list_first_entry(&con->writequeue, struct writequeue_entry,
201                              list);
202         if (e->len == 0)
203                 return NULL;
204
205         return e;
206 }
207
208 static struct connection *__find_con(int nodeid, int r)
209 {
210         struct connection *con;
211
212         hlist_for_each_entry_rcu(con, &connection_hash[r], list) {
213                 if (con->nodeid == nodeid)
214                         return con;
215         }
216
217         return NULL;
218 }
219
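/* TCP only: after the peer signalled EOF, keep the connection around
 * while data is still queued, so the writequeue can be flushed before
 * the socket is finally closed (see the CF_EOF handling below).
 */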
220 static bool tcp_eof_condition(struct connection *con)
221 {
222         return atomic_read(&con->writequeue_cnt);
223 }
224
225 static int dlm_con_init(struct connection *con, int nodeid)
226 {
227         con->rx_buflen = dlm_config.ci_buffer_size;
228         con->rx_buf = kmalloc(con->rx_buflen, GFP_NOFS);
229         if (!con->rx_buf)
230                 return -ENOMEM;
231
232         con->nodeid = nodeid;
233         mutex_init(&con->sock_mutex);
234         INIT_LIST_HEAD(&con->writequeue);
235         spin_lock_init(&con->writequeue_lock);
236         atomic_set(&con->writequeue_cnt, 0);
237         INIT_WORK(&con->swork, process_send_sockets);
238         INIT_WORK(&con->rwork, process_recv_sockets);
239         init_waitqueue_head(&con->shutdown_wait);
240
241         return 0;
242 }
243
244 /*
245  * If 'alloc' is zero then we don't attempt to create a new
246  * connection structure for this node.
247  */
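/* The returned connection is protected by connections_srcu; callers
 * take srcu_read_lock() around the lookup and any use of the result.
 */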
248 static struct connection *nodeid2con(int nodeid, gfp_t alloc)
249 {
250         struct connection *con, *tmp;
251         int r, ret;
252
253         r = nodeid_hash(nodeid);
254         con = __find_con(nodeid, r);
255         if (con || !alloc)
256                 return con;
257
258         con = kzalloc(sizeof(*con), alloc);
259         if (!con)
260                 return NULL;
261
262         ret = dlm_con_init(con, nodeid);
263         if (ret) {
264                 kfree(con);
265                 return NULL;
266         }
267
268         mutex_init(&con->wq_alloc);
269
270         spin_lock(&connections_lock);
271         /* Because multiple workqueues/threads call this function it can
272          * race on multiple CPUs. Instead of locking the hot path
273          * __find_con(), we re-check recently added nodes under the
274          * protection of connections_lock. If one appeared meanwhile we
275          * abort our connection creation and return the existing one.
276          */
277         tmp = __find_con(nodeid, r);
278         if (tmp) {
279                 spin_unlock(&connections_lock);
280                 kfree(con->rx_buf);
281                 kfree(con);
282                 return tmp;
283         }
284
285         hlist_add_head_rcu(&con->list, &connection_hash[r]);
286         spin_unlock(&connections_lock);
287
288         return con;
289 }
290
291 /* Loop round all connections */
292 static void foreach_conn(void (*conn_func)(struct connection *c))
293 {
294         int i;
295         struct connection *con;
296
297         for (i = 0; i < CONN_HASH_SIZE; i++) {
298                 hlist_for_each_entry_rcu(con, &connection_hash[i], list)
299                         conn_func(con);
300         }
301 }
302
303 static struct dlm_node_addr *find_node_addr(int nodeid)
304 {
305         struct dlm_node_addr *na;
306
307         list_for_each_entry(na, &dlm_node_addrs, list) {
308                 if (na->nodeid == nodeid)
309                         return na;
310         }
311         return NULL;
312 }
313
314 static int addr_compare(const struct sockaddr_storage *x,
315                         const struct sockaddr_storage *y)
316 {
317         switch (x->ss_family) {
318         case AF_INET: {
319                 struct sockaddr_in *sinx = (struct sockaddr_in *)x;
320                 struct sockaddr_in *siny = (struct sockaddr_in *)y;
321                 if (sinx->sin_addr.s_addr != siny->sin_addr.s_addr)
322                         return 0;
323                 if (sinx->sin_port != siny->sin_port)
324                         return 0;
325                 break;
326         }
327         case AF_INET6: {
328                 struct sockaddr_in6 *sinx = (struct sockaddr_in6 *)x;
329                 struct sockaddr_in6 *siny = (struct sockaddr_in6 *)y;
330                 if (!ipv6_addr_equal(&sinx->sin6_addr, &siny->sin6_addr))
331                         return 0;
332                 if (sinx->sin6_port != siny->sin6_port)
333                         return 0;
334                 break;
335         }
336         default:
337                 return 0;
338         }
339         return 1;
340 }
341
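/* When try_new_addr is set (the SCTP ops request this), successive
 * calls rotate curr_addr_index through the node's known addresses, so
 * each reconnect attempt can try a different peer address.
 */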
342 static int nodeid_to_addr(int nodeid, struct sockaddr_storage *sas_out,
343                           struct sockaddr *sa_out, bool try_new_addr,
344                           unsigned int *mark)
345 {
346         struct sockaddr_storage sas;
347         struct dlm_node_addr *na;
348
349         if (!dlm_local_count)
350                 return -1;
351
352         spin_lock(&dlm_node_addrs_spin);
353         na = find_node_addr(nodeid);
354         if (na && na->addr_count) {
355                 memcpy(&sas, na->addr[na->curr_addr_index],
356                        sizeof(struct sockaddr_storage));
357
358                 if (try_new_addr) {
359                         na->curr_addr_index++;
360                         if (na->curr_addr_index == na->addr_count)
361                                 na->curr_addr_index = 0;
362                 }
363         }
364         spin_unlock(&dlm_node_addrs_spin);
365
366         if (!na)
367                 return -EEXIST;
368
369         if (!na->addr_count)
370                 return -ENOENT;
371
372         *mark = na->mark;
373
374         if (sas_out)
375                 memcpy(sas_out, &sas, sizeof(struct sockaddr_storage));
376
377         if (!sa_out)
378                 return 0;
379
380         if (dlm_local_addr[0]->ss_family == AF_INET) {
381                 struct sockaddr_in *in4  = (struct sockaddr_in *) &sas;
382                 struct sockaddr_in *ret4 = (struct sockaddr_in *) sa_out;
383                 ret4->sin_addr.s_addr = in4->sin_addr.s_addr;
384         } else {
385                 struct sockaddr_in6 *in6  = (struct sockaddr_in6 *) &sas;
386                 struct sockaddr_in6 *ret6 = (struct sockaddr_in6 *) sa_out;
387                 ret6->sin6_addr = in6->sin6_addr;
388         }
389
390         return 0;
391 }
392
393 static int addr_to_nodeid(struct sockaddr_storage *addr, int *nodeid,
394                           unsigned int *mark)
395 {
396         struct dlm_node_addr *na;
397         int rv = -EEXIST;
398         int addr_i;
399
400         spin_lock(&dlm_node_addrs_spin);
401         list_for_each_entry(na, &dlm_node_addrs, list) {
402                 if (!na->addr_count)
403                         continue;
404
405                 for (addr_i = 0; addr_i < na->addr_count; addr_i++) {
406                         if (addr_compare(na->addr[addr_i], addr)) {
407                                 *nodeid = na->nodeid;
408                                 *mark = na->mark;
409                                 rv = 0;
410                                 goto unlock;
411                         }
412                 }
413         }
414 unlock:
415         spin_unlock(&dlm_node_addrs_spin);
416         return rv;
417 }
418
419 /* caller must hold the dlm_node_addrs_spin lock */
420 static bool dlm_lowcomms_na_has_addr(const struct dlm_node_addr *na,
421                                      const struct sockaddr_storage *addr)
422 {
423         int i;
424
425         for (i = 0; i < na->addr_count; i++) {
426                 if (addr_compare(na->addr[i], addr))
427                         return true;
428         }
429
430         return false;
431 }
432
433 int dlm_lowcomms_addr(int nodeid, struct sockaddr_storage *addr, int len)
434 {
435         struct sockaddr_storage *new_addr;
436         struct dlm_node_addr *new_node, *na;
437         bool ret;
438
439         new_node = kzalloc(sizeof(struct dlm_node_addr), GFP_NOFS);
440         if (!new_node)
441                 return -ENOMEM;
442
443         new_addr = kzalloc(sizeof(struct sockaddr_storage), GFP_NOFS);
444         if (!new_addr) {
445                 kfree(new_node);
446                 return -ENOMEM;
447         }
448
449         memcpy(new_addr, addr, len);
450
451         spin_lock(&dlm_node_addrs_spin);
452         na = find_node_addr(nodeid);
453         if (!na) {
454                 new_node->nodeid = nodeid;
455                 new_node->addr[0] = new_addr;
456                 new_node->addr_count = 1;
457                 new_node->mark = dlm_config.ci_mark;
458                 list_add(&new_node->list, &dlm_node_addrs);
459                 spin_unlock(&dlm_node_addrs_spin);
460                 return 0;
461         }
462
463         ret = dlm_lowcomms_na_has_addr(na, addr);
464         if (ret) {
465                 spin_unlock(&dlm_node_addrs_spin);
466                 kfree(new_addr);
467                 kfree(new_node);
468                 return -EEXIST;
469         }
470
471         if (na->addr_count >= DLM_MAX_ADDR_COUNT) {
472                 spin_unlock(&dlm_node_addrs_spin);
473                 kfree(new_addr);
474                 kfree(new_node);
475                 return -ENOSPC;
476         }
477
478         na->addr[na->addr_count++] = new_addr;
479         spin_unlock(&dlm_node_addrs_spin);
480         kfree(new_node);
481         return 0;
482 }
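
/* Minimal sketch of the expected use of dlm_lowcomms_addr() by the
 * config layer (names and flow are illustrative, error handling cut):
 *
 *   struct sockaddr_storage addr;
 *
 *   ... fill addr from the userspace-supplied comms entry ...
 *   error = dlm_lowcomms_addr(nodeid, &addr, sizeof(addr));
 *   if (error == -EEXIST)
 *           ... the address was already known for this node ...
 */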
483
484 /* Data is available on a socket, or a listening socket received a connect */
485 static void lowcomms_data_ready(struct sock *sk)
486 {
487         struct connection *con;
488
489         read_lock_bh(&sk->sk_callback_lock);
490         con = sock2con(sk);
491         if (con && !test_and_set_bit(CF_READ_PENDING, &con->flags))
492                 queue_work(recv_workqueue, &con->rwork);
493         read_unlock_bh(&sk->sk_callback_lock);
494 }
495
496 static void lowcomms_listen_data_ready(struct sock *sk)
497 {
498         if (!dlm_allow_conn)
499                 return;
500
501         queue_work(recv_workqueue, &listen_con.rwork);
502 }
503
504 static void lowcomms_write_space(struct sock *sk)
505 {
506         struct connection *con;
507
508         read_lock_bh(&sk->sk_callback_lock);
509         con = sock2con(sk);
510         if (!con)
511                 goto out;
512
513         if (!test_and_set_bit(CF_CONNECTED, &con->flags)) {
514                 log_print("successfully connected to node %d", con->nodeid);
515                 queue_work(send_workqueue, &con->swork);
516                 goto out;
517         }
518
519         clear_bit(SOCK_NOSPACE, &con->sock->flags);
520
521         if (test_and_clear_bit(CF_APP_LIMITED, &con->flags)) {
522                 con->sock->sk->sk_write_pending--;
523                 clear_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags);
524         }
525
526         queue_work(send_workqueue, &con->swork);
527 out:
528         read_unlock_bh(&sk->sk_callback_lock);
529 }
530
531 static inline void lowcomms_connect_sock(struct connection *con)
532 {
533         if (test_bit(CF_CLOSE, &con->flags))
534                 return;
535         queue_work(send_workqueue, &con->swork);
536         cond_resched();
537 }
538
539 static void lowcomms_state_change(struct sock *sk)
540 {
541         /* The SCTP layer does not call sk_data_ready when the connection
542          * is done, so we catch the signal here. It also does not switch
543          * the socket state when entering shutdown, so we skip the write
544          * in that case.
545          */
546         if (sk->sk_shutdown) {
547                 if (sk->sk_shutdown == RCV_SHUTDOWN)
548                         lowcomms_data_ready(sk);
549         } else if (sk->sk_state == TCP_ESTABLISHED) {
550                 lowcomms_write_space(sk);
551         }
552 }
553
554 int dlm_lowcomms_connect_node(int nodeid)
555 {
556         struct connection *con;
557         int idx;
558
559         if (nodeid == dlm_our_nodeid())
560                 return 0;
561
562         idx = srcu_read_lock(&connections_srcu);
563         con = nodeid2con(nodeid, GFP_NOFS);
564         if (!con) {
565                 srcu_read_unlock(&connections_srcu, idx);
566                 return -ENOMEM;
567         }
568
569         lowcomms_connect_sock(con);
570         srcu_read_unlock(&connections_srcu, idx);
571
572         return 0;
573 }
574
575 int dlm_lowcomms_nodes_set_mark(int nodeid, unsigned int mark)
576 {
577         struct dlm_node_addr *na;
578
579         spin_lock(&dlm_node_addrs_spin);
580         na = find_node_addr(nodeid);
581         if (!na) {
582                 spin_unlock(&dlm_node_addrs_spin);
583                 return -ENOENT;
584         }
585
586         na->mark = mark;
587         spin_unlock(&dlm_node_addrs_spin);
588
589         return 0;
590 }
591
592 static void lowcomms_error_report(struct sock *sk)
593 {
594         struct connection *con;
595         struct sockaddr_storage saddr;
596         void (*orig_report)(struct sock *) = NULL;
597
598         read_lock_bh(&sk->sk_callback_lock);
599         con = sock2con(sk);
600         if (con == NULL)
601                 goto out;
602
603         orig_report = listen_sock.sk_error_report;
604         if (kernel_getpeername(sk->sk_socket, (struct sockaddr *)&saddr) < 0) {
605                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
606                                    "sending to node %d, port %d, "
607                                    "sk_err=%d/%d\n", dlm_our_nodeid(),
608                                    con->nodeid, dlm_config.ci_tcp_port,
609                                    sk->sk_err, sk->sk_err_soft);
610         } else if (saddr.ss_family == AF_INET) {
611                 struct sockaddr_in *sin4 = (struct sockaddr_in *)&saddr;
612
613                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
614                                    "sending to node %d at %pI4, port %d, "
615                                    "sk_err=%d/%d\n", dlm_our_nodeid(),
616                                    con->nodeid, &sin4->sin_addr.s_addr,
617                                    dlm_config.ci_tcp_port, sk->sk_err,
618                                    sk->sk_err_soft);
619         } else {
620                 struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&saddr;
621
622                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
623                                    "sending to node %d at %u.%u.%u.%u, "
624                                    "port %d, sk_err=%d/%d\n", dlm_our_nodeid(),
625                                    con->nodeid, sin6->sin6_addr.s6_addr32[0],
626                                    sin6->sin6_addr.s6_addr32[1],
627                                    sin6->sin6_addr.s6_addr32[2],
628                                    sin6->sin6_addr.s6_addr32[3],
629                                    dlm_config.ci_tcp_port, sk->sk_err,
630                                    sk->sk_err_soft);
631         }
632
633         /* the handling below operates on the sendcon only */
634         if (test_bit(CF_IS_OTHERCON, &con->flags))
635                 con = con->sendcon;
636
637         switch (sk->sk_err) {
638         case ECONNREFUSED:
639                 set_bit(CF_DELAY_CONNECT, &con->flags);
640                 break;
641         default:
642                 break;
643         }
644
645         if (!test_and_set_bit(CF_RECONNECT, &con->flags))
646                 queue_work(send_workqueue, &con->swork);
647
648 out:
649         read_unlock_bh(&sk->sk_callback_lock);
650         if (orig_report)
651                 orig_report(sk);
652 }
653
654 /* Note: sk_callback_lock must be locked before calling this function. */
655 static void save_listen_callbacks(struct socket *sock)
656 {
657         struct sock *sk = sock->sk;
658
659         listen_sock.sk_data_ready = sk->sk_data_ready;
660         listen_sock.sk_state_change = sk->sk_state_change;
661         listen_sock.sk_write_space = sk->sk_write_space;
662         listen_sock.sk_error_report = sk->sk_error_report;
663 }
664
665 static void restore_callbacks(struct socket *sock)
666 {
667         struct sock *sk = sock->sk;
668
669         write_lock_bh(&sk->sk_callback_lock);
670         sk->sk_user_data = NULL;
671         sk->sk_data_ready = listen_sock.sk_data_ready;
672         sk->sk_state_change = listen_sock.sk_state_change;
673         sk->sk_write_space = listen_sock.sk_write_space;
674         sk->sk_error_report = listen_sock.sk_error_report;
675         write_unlock_bh(&sk->sk_callback_lock);
676 }
677
678 static void add_listen_sock(struct socket *sock, struct listen_connection *con)
679 {
680         struct sock *sk = sock->sk;
681
682         write_lock_bh(&sk->sk_callback_lock);
683         save_listen_callbacks(sock);
684         con->sock = sock;
685
686         sk->sk_user_data = con;
687         sk->sk_allocation = GFP_NOFS;
688         /* Install a data_ready callback */
689         sk->sk_data_ready = lowcomms_listen_data_ready;
690         write_unlock_bh(&sk->sk_callback_lock);
691 }
692
693 /* Make a socket active */
694 static void add_sock(struct socket *sock, struct connection *con)
695 {
696         struct sock *sk = sock->sk;
697
698         write_lock_bh(&sk->sk_callback_lock);
699         con->sock = sock;
700
701         sk->sk_user_data = con;
702         /* Install a data_ready callback */
703         sk->sk_data_ready = lowcomms_data_ready;
704         sk->sk_write_space = lowcomms_write_space;
705         sk->sk_state_change = lowcomms_state_change;
706         sk->sk_allocation = GFP_NOFS;
707         sk->sk_error_report = lowcomms_error_report;
708         write_unlock_bh(&sk->sk_callback_lock);
709 }
710
711 /* Add the port number to an IPv4 or IPv6 sockaddr and return the
712    address length */
713 static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
714                           int *addr_len)
715 {
716         saddr->ss_family =  dlm_local_addr[0]->ss_family;
717         if (saddr->ss_family == AF_INET) {
718                 struct sockaddr_in *in4_addr = (struct sockaddr_in *)saddr;
719                 in4_addr->sin_port = cpu_to_be16(port);
720                 *addr_len = sizeof(struct sockaddr_in);
721                 memset(&in4_addr->sin_zero, 0, sizeof(in4_addr->sin_zero));
722         } else {
723                 struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)saddr;
724                 in6_addr->sin6_port = cpu_to_be16(port);
725                 *addr_len = sizeof(struct sockaddr_in6);
726         }
727         memset((char *)saddr + *addr_len, 0, sizeof(struct sockaddr_storage) - *addr_len);
728 }
729
730 static void dlm_page_release(struct kref *kref)
731 {
732         struct writequeue_entry *e = container_of(kref, struct writequeue_entry,
733                                                   ref);
734
735         __free_page(e->page);
736         kfree(e);
737 }
738
739 static void dlm_msg_release(struct kref *kref)
740 {
741         struct dlm_msg *msg = container_of(kref, struct dlm_msg, ref);
742
743         kref_put(&msg->entry->ref, dlm_page_release);
744         kfree(msg);
745 }
746
747 static void free_entry(struct writequeue_entry *e)
748 {
749         struct dlm_msg *msg, *tmp;
750
751         list_for_each_entry_safe(msg, tmp, &e->msgs, list) {
752                 if (msg->orig_msg) {
753                         msg->orig_msg->retransmit = false;
754                         kref_put(&msg->orig_msg->ref, dlm_msg_release);
755                 }
756
757                 list_del(&msg->list);
758                 kref_put(&msg->ref, dlm_msg_release);
759         }
760
761         list_del(&e->list);
762         atomic_dec(&e->con->writequeue_cnt);
763         kref_put(&e->ref, dlm_page_release);
764 }
765
766 static void dlm_close_sock(struct socket **sock)
767 {
768         if (*sock) {
769                 restore_callbacks(*sock);
770                 sock_release(*sock);
771                 *sock = NULL;
772         }
773 }
774
775 /* Close a remote connection and tidy up */
776 static void close_connection(struct connection *con, bool and_other,
777                              bool tx, bool rx)
778 {
779         bool closing = test_and_set_bit(CF_CLOSING, &con->flags);
780         struct writequeue_entry *e;
781
782         if (tx && !closing && cancel_work_sync(&con->swork)) {
783                 log_print("canceled swork for node %d", con->nodeid);
784                 clear_bit(CF_WRITE_PENDING, &con->flags);
785         }
786         if (rx && !closing && cancel_work_sync(&con->rwork)) {
787                 log_print("canceled rwork for node %d", con->nodeid);
788                 clear_bit(CF_READ_PENDING, &con->flags);
789         }
790
791         mutex_lock(&con->sock_mutex);
792         dlm_close_sock(&con->sock);
793
794         if (con->othercon && and_other) {
795                 /* Will only re-enter once. */
796                 close_connection(con->othercon, false, tx, rx);
797         }
798
799         /* If we sent a writequeue entry only part way, drop the
800          * whole entry: after a reconnect we must not resume in the
801          * middle of a message, which would confuse the other end.
802          *
803          * Whole messages can always be dropped because they will be
804          * retransmitted; what we cannot allow is a half-transmitted
805          * message being processed at the other side. Our policy is
806          * to start from a clean state on disconnect, since we don't
807          * know what was sent or received at the transport layer in
808          * this case.
809          */
810         spin_lock(&con->writequeue_lock);
811         if (!list_empty(&con->writequeue)) {
812                 e = list_first_entry(&con->writequeue, struct writequeue_entry,
813                                      list);
814                 if (e->dirty)
815                         free_entry(e);
816         }
817         spin_unlock(&con->writequeue_lock);
818
819         con->rx_leftover = 0;
820         con->retries = 0;
821         clear_bit(CF_APP_LIMITED, &con->flags);
822         clear_bit(CF_CONNECTED, &con->flags);
823         clear_bit(CF_DELAY_CONNECT, &con->flags);
824         clear_bit(CF_RECONNECT, &con->flags);
825         clear_bit(CF_EOF, &con->flags);
826         mutex_unlock(&con->sock_mutex);
827         clear_bit(CF_CLOSING, &con->flags);
828 }
829
830 static void shutdown_connection(struct connection *con)
831 {
832         int ret;
833
834         flush_work(&con->swork);
835
836         mutex_lock(&con->sock_mutex);
837         /* nothing to shutdown */
838         if (!con->sock) {
839                 mutex_unlock(&con->sock_mutex);
840                 return;
841         }
842
843         set_bit(CF_SHUTDOWN, &con->flags);
844         ret = kernel_sock_shutdown(con->sock, SHUT_WR);
845         mutex_unlock(&con->sock_mutex);
846         if (ret) {
847                 log_print("Connection %p failed to shutdown: %d, will force close",
848                           con, ret);
849                 goto force_close;
850         } else {
851                 ret = wait_event_timeout(con->shutdown_wait,
852                                          !test_bit(CF_SHUTDOWN, &con->flags),
853                                          DLM_SHUTDOWN_WAIT_TIMEOUT);
854                 if (ret == 0) {
855                         log_print("Connection %p shutdown timed out, will force close",
856                                   con);
857                         goto force_close;
858                 }
859         }
860
861         return;
862
863 force_close:
864         clear_bit(CF_SHUTDOWN, &con->flags);
865         close_connection(con, false, true, true);
866 }
867
868 static void dlm_tcp_shutdown(struct connection *con)
869 {
870         if (con->othercon)
871                 shutdown_connection(con->othercon);
872         shutdown_connection(con);
873 }
874
875 static int con_realloc_receive_buf(struct connection *con, int newlen)
876 {
877         unsigned char *newbuf;
878
879         newbuf = kmalloc(newlen, GFP_NOFS);
880         if (!newbuf)
881                 return -ENOMEM;
882
883         /* copy any leftover from last receive */
884         if (con->rx_leftover)
885                 memmove(newbuf, con->rx_buf, con->rx_leftover);
886
887         /* swap to new buffer space */
888         kfree(con->rx_buf);
889         con->rx_buflen = newlen;
890         con->rx_buf = newbuf;
891
892         return 0;
893 }
894
895 /* Data received from remote end */
896 static int receive_from_sock(struct connection *con)
897 {
898         struct msghdr msg;
899         struct kvec iov;
900         int ret, buflen;
901
902         mutex_lock(&con->sock_mutex);
903
904         if (con->sock == NULL) {
905                 ret = -EAGAIN;
906                 goto out_close;
907         }
908
909         /* reallocate if the configured receive buffer size changed */
910         buflen = dlm_config.ci_buffer_size;
911         if (con->rx_buflen != buflen && con->rx_leftover <= buflen) {
912                 ret = con_realloc_receive_buf(con, buflen);
913                 if (ret < 0)
914                         goto out_resched;
915         }
916
917         for (;;) {
918                 /* calculate new buffer parameters, accounting for any
919                  * leftover bytes from the last receive
920                  */
921                 iov.iov_base = con->rx_buf + con->rx_leftover;
922                 iov.iov_len = con->rx_buflen - con->rx_leftover;
923
924                 memset(&msg, 0, sizeof(msg));
925                 msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
926                 ret = kernel_recvmsg(con->sock, &msg, &iov, 1, iov.iov_len,
927                                      msg.msg_flags);
928                 if (ret == -EAGAIN)
929                         break;
930                 else if (ret <= 0)
931                         goto out_close;
932
933                 /* new buflen according to bytes read plus leftover from last receive */
934                 buflen = ret + con->rx_leftover;
935                 ret = dlm_process_incoming_buffer(con->nodeid, con->rx_buf, buflen);
936                 if (ret < 0)
937                         goto out_close;
938
939                 /* calculate the bytes left over from processing and move
940                  * them to the beginning of the receive buffer, so the next
941                  * receive builds up a full message at the start of the buffer.
942                  */
943                 con->rx_leftover = buflen - ret;
944                 if (con->rx_leftover) {
945                         memmove(con->rx_buf, con->rx_buf + ret,
946                                 con->rx_leftover);
947                 }
948         }
949
950         dlm_midcomms_receive_done(con->nodeid);
951         mutex_unlock(&con->sock_mutex);
952         return 0;
953
954 out_resched:
955         if (!test_and_set_bit(CF_READ_PENDING, &con->flags))
956                 queue_work(recv_workqueue, &con->rwork);
957         mutex_unlock(&con->sock_mutex);
958         return -EAGAIN;
959
960 out_close:
961         if (ret == 0) {
962                 log_print("connection %p got EOF from %d",
963                           con, con->nodeid);
964
965                 if (dlm_proto_ops->eof_condition &&
966                     dlm_proto_ops->eof_condition(con)) {
967                         set_bit(CF_EOF, &con->flags);
968                         mutex_unlock(&con->sock_mutex);
969                 } else {
970                         mutex_unlock(&con->sock_mutex);
971                         close_connection(con, false, true, false);
972
973                         /* handling for tcp shutdown */
974                         clear_bit(CF_SHUTDOWN, &con->flags);
975                         wake_up(&con->shutdown_wait);
976                 }
977
978                 /* signal the receive worker to break out */
979                 ret = -1;
980         } else {
981                 mutex_unlock(&con->sock_mutex);
982         }
983         return ret;
984 }
985
986 /* Listening socket is busy, accept a connection */
987 static int accept_from_sock(struct listen_connection *con)
988 {
989         int result;
990         struct sockaddr_storage peeraddr;
991         struct socket *newsock;
992         int len, idx;
993         int nodeid;
994         struct connection *newcon;
995         struct connection *addcon;
996         unsigned int mark;
997
998         if (!con->sock)
999                 return -ENOTCONN;
1000
1001         result = kernel_accept(con->sock, &newsock, O_NONBLOCK);
1002         if (result < 0)
1003                 goto accept_err;
1004
1005         /* Get the connected socket's peer */
1006         memset(&peeraddr, 0, sizeof(peeraddr));
1007         len = newsock->ops->getname(newsock, (struct sockaddr *)&peeraddr, 2);
1008         if (len < 0) {
1009                 result = -ECONNABORTED;
1010                 goto accept_err;
1011         }
1012
1013         /* Get the new node's NODEID */
1014         make_sockaddr(&peeraddr, 0, &len);
1015         if (addr_to_nodeid(&peeraddr, &nodeid, &mark)) {
1016                 unsigned char *b = (unsigned char *)&peeraddr;
1017                 log_print("connect from non cluster node");
1018                 print_hex_dump_bytes("ss: ", DUMP_PREFIX_NONE,
1019                                      b, sizeof(struct sockaddr_storage));
1020                 sock_release(newsock);
1021                 return -1;
1022         }
1023
1024         log_print("got connection from %d", nodeid);
1025
1026         /*  Check to see if we already have a connection to this node. This
1027          *  could happen if the two nodes initiate a connection at roughly
1028          *  the same time and the connections cross on the wire.
1029          *  In this case we store the incoming one in "othercon"
1030          */
1031         idx = srcu_read_lock(&connections_srcu);
1032         newcon = nodeid2con(nodeid, GFP_NOFS);
1033         if (!newcon) {
1034                 srcu_read_unlock(&connections_srcu, idx);
1035                 result = -ENOMEM;
1036                 goto accept_err;
1037         }
1038
1039         sock_set_mark(newsock->sk, mark);
1040
1041         mutex_lock(&newcon->sock_mutex);
1042         if (newcon->sock) {
1043                 struct connection *othercon = newcon->othercon;
1044
1045                 if (!othercon) {
1046                         othercon = kzalloc(sizeof(*othercon), GFP_NOFS);
1047                         if (!othercon) {
1048                                 log_print("failed to allocate incoming socket");
1049                                 mutex_unlock(&newcon->sock_mutex);
1050                                 srcu_read_unlock(&connections_srcu, idx);
1051                                 result = -ENOMEM;
1052                                 goto accept_err;
1053                         }
1054
1055                         result = dlm_con_init(othercon, nodeid);
1056                         if (result < 0) {
1057                                 kfree(othercon);
1058                                 mutex_unlock(&newcon->sock_mutex);
1059                                 srcu_read_unlock(&connections_srcu, idx);
1060                                 goto accept_err;
1061                         }
1062
1063                         lockdep_set_subclass(&othercon->sock_mutex, 1);
1064                         set_bit(CF_IS_OTHERCON, &othercon->flags);
1065                         newcon->othercon = othercon;
1066                         othercon->sendcon = newcon;
1067                 } else {
1068                         /* close other sock con if we have something new */
1069                         close_connection(othercon, false, true, false);
1070                 }
1071
1072                 mutex_lock(&othercon->sock_mutex);
1073                 add_sock(newsock, othercon);
1074                 addcon = othercon;
1075                 mutex_unlock(&othercon->sock_mutex);
1076         } else {
1078                 /* accept copies the sk after we've saved the callbacks, so we
1079                    don't want to save them a second time or comm errors will
1080                    result in calling sk_error_report recursively. */
1081                 add_sock(newsock, newcon);
1082                 addcon = newcon;
1083         }
1084
1085         set_bit(CF_CONNECTED, &addcon->flags);
1086         mutex_unlock(&newcon->sock_mutex);
1087
1088         /*
1089          * Add it to the active queue in case we got data
1090          * between processing the accept and adding the socket
1091          * to the read_sockets list
1092          */
1093         if (!test_and_set_bit(CF_READ_PENDING, &addcon->flags))
1094                 queue_work(recv_workqueue, &addcon->rwork);
1095
1096         srcu_read_unlock(&connections_srcu, idx);
1097
1098         return 0;
1099
1100 accept_err:
1101         if (newsock)
1102                 sock_release(newsock);
1103
1104         if (result != -EAGAIN)
1105                 log_print("error accepting connection from node: %d", result);
1106         return result;
1107 }
1108
1109 /*
1110  * writequeue_entry_complete - try to delete and free write queue entry
1111  * @e: write queue entry to try to delete
1112  * @completed: bytes completed
1113  *
1114  * writequeue_lock must be held.
1115  */
1116 static void writequeue_entry_complete(struct writequeue_entry *e, int completed)
1117 {
1118         e->offset += completed;
1119         e->len -= completed;
1120         /* signal that the page was partially transmitted */
1121         e->dirty = true;
1122
1123         if (e->len == 0 && e->users == 0)
1124                 free_entry(e);
1125 }
1126
1127 /*
1128  * sctp_bind_addrs - bind a SCTP socket to all our addresses
1129  */
1130 static int sctp_bind_addrs(struct socket *sock, uint16_t port)
1131 {
1132         struct sockaddr_storage localaddr;
1133         struct sockaddr *addr = (struct sockaddr *)&localaddr;
1134         int i, addr_len, result = 0;
1135
1136         for (i = 0; i < dlm_local_count; i++) {
1137                 memcpy(&localaddr, dlm_local_addr[i], sizeof(localaddr));
1138                 make_sockaddr(&localaddr, port, &addr_len);
1139
1140                 if (!i)
1141                         result = kernel_bind(sock, addr, addr_len);
1142                 else
1143                         result = sock_bind_add(sock->sk, addr, addr_len);
1144
1145                 if (result < 0) {
1146                 log_print("Can't bind to %d addr number %d, %d.",
1147                                   port, i + 1, result);
1148                         break;
1149                 }
1150         }
1151         return result;
1152 }
1153
1154 /* Get local addresses */
1155 static void init_local(void)
1156 {
1157         struct sockaddr_storage sas, *addr;
1158         int i;
1159
1160         dlm_local_count = 0;
1161         for (i = 0; i < DLM_MAX_ADDR_COUNT; i++) {
1162                 if (dlm_our_addr(&sas, i))
1163                         break;
1164
1165                 addr = kmemdup(&sas, sizeof(*addr), GFP_NOFS);
1166                 if (!addr)
1167                         break;
1168                 dlm_local_addr[dlm_local_count++] = addr;
1169         }
1170 }
1171
1172 static void deinit_local(void)
1173 {
1174         int i;
1175
1176         for (i = 0; i < dlm_local_count; i++)
1177                 kfree(dlm_local_addr[i]);
1178 }
1179
1180 static struct writequeue_entry *new_writequeue_entry(struct connection *con,
1181                                                      gfp_t allocation)
1182 {
1183         struct writequeue_entry *entry;
1184
1185         entry = kzalloc(sizeof(*entry), allocation);
1186         if (!entry)
1187                 return NULL;
1188
1189         entry->page = alloc_page(allocation | __GFP_ZERO);
1190         if (!entry->page) {
1191                 kfree(entry);
1192                 return NULL;
1193         }
1194
1195         entry->con = con;
1196         entry->users = 1;
1197         kref_init(&entry->ref);
1198         INIT_LIST_HEAD(&entry->msgs);
1199
1200         return entry;
1201 }
1202
1203 static struct writequeue_entry *new_wq_entry(struct connection *con, int len,
1204                                              gfp_t allocation, char **ppc,
1205                                              void (*cb)(struct dlm_mhandle *mh),
1206                                              struct dlm_mhandle *mh)
1207 {
1208         struct writequeue_entry *e;
1209
1210         spin_lock(&con->writequeue_lock);
1211         if (!list_empty(&con->writequeue)) {
1212                 e = list_last_entry(&con->writequeue, struct writequeue_entry, list);
1213                 if (DLM_WQ_REMAIN_BYTES(e) >= len) {
1214                         kref_get(&e->ref);
1215
1216                         *ppc = page_address(e->page) + e->end;
1217                         if (cb)
1218                                 cb(mh);
1219
1220                         e->end += len;
1221                         e->users++;
1222                         spin_unlock(&con->writequeue_lock);
1223
1224                         return e;
1225                 }
1226         }
1227         spin_unlock(&con->writequeue_lock);
1228
1229         e = new_writequeue_entry(con, allocation);
1230         if (!e)
1231                 return NULL;
1232
1233         kref_get(&e->ref);
1234         *ppc = page_address(e->page);
1235         e->end += len;
1236         atomic_inc(&con->writequeue_cnt);
1237
1238         spin_lock(&con->writequeue_lock);
1239         if (cb)
1240                 cb(mh);
1241
1242         list_add_tail(&e->list, &con->writequeue);
1243         spin_unlock(&con->writequeue_lock);
1244
1245         return e;
1246 };
1247
1248 static struct dlm_msg *dlm_lowcomms_new_msg_con(struct connection *con, int len,
1249                                                 gfp_t allocation, char **ppc,
1250                                                 void (*cb)(struct dlm_mhandle *mh),
1251                                                 struct dlm_mhandle *mh)
1252 {
1253         struct writequeue_entry *e;
1254         struct dlm_msg *msg;
1255         bool sleepable;
1256
1257         msg = kzalloc(sizeof(*msg), allocation);
1258         if (!msg)
1259                 return NULL;
1260
1261         /* this mutex is used as a wait mechanism to avoid multiple "fast"
1262          * new writequeue page list entry allocs in new_wq_entry() during
1263          * normal operation, which is a sleepable context. Without it we
1264          * could end up with multiple writequeue entries for one dlm
1265          * message, because multiple callers were waiting at the
1266          * writequeue_lock in new_wq_entry().
1267          */
1268         sleepable = gfpflags_normal_context(allocation);
1269         if (sleepable)
1270                 mutex_lock(&con->wq_alloc);
1271
1272         kref_init(&msg->ref);
1273
1274         e = new_wq_entry(con, len, allocation, ppc, cb, mh);
1275         if (!e) {
1276                 if (sleepable)
1277                         mutex_unlock(&con->wq_alloc);
1278
1279                 kfree(msg);
1280                 return NULL;
1281         }
1282
1283         if (sleepable)
1284                 mutex_unlock(&con->wq_alloc);
1285
1286         msg->ppc = *ppc;
1287         msg->len = len;
1288         msg->entry = e;
1289
1290         return msg;
1291 }
1292
1293 struct dlm_msg *dlm_lowcomms_new_msg(int nodeid, int len, gfp_t allocation,
1294                                      char **ppc, void (*cb)(struct dlm_mhandle *mh),
1295                                      struct dlm_mhandle *mh)
1296 {
1297         struct connection *con;
1298         struct dlm_msg *msg;
1299         int idx;
1300
1301         if (len > DLM_MAX_SOCKET_BUFSIZE ||
1302             len < sizeof(struct dlm_header)) {
1303                 BUILD_BUG_ON(PAGE_SIZE < DLM_MAX_SOCKET_BUFSIZE);
1304                 log_print("failed to allocate a buffer of size %d", len);
1305                 WARN_ON(1);
1306                 return NULL;
1307         }
1308
1309         idx = srcu_read_lock(&connections_srcu);
1310         con = nodeid2con(nodeid, allocation);
1311         if (!con) {
1312                 srcu_read_unlock(&connections_srcu, idx);
1313                 return NULL;
1314         }
1315
1316         msg = dlm_lowcomms_new_msg_con(con, len, allocation, ppc, cb, mh);
1317         if (!msg) {
1318                 srcu_read_unlock(&connections_srcu, idx);
1319                 return NULL;
1320         }
1321
1322         /* we assume that on success the caller must call commit */
1323         msg->idx = idx;
1324         return msg;
1325 }
1326
1327 static void _dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1328 {
1329         struct writequeue_entry *e = msg->entry;
1330         struct connection *con = e->con;
1331         int users;
1332
1333         spin_lock(&con->writequeue_lock);
1334         kref_get(&msg->ref);
1335         list_add(&msg->list, &e->msgs);
1336
1337         users = --e->users;
1338         if (users)
1339                 goto out;
1340
1341         e->len = DLM_WQ_LENGTH_BYTES(e);
1342         spin_unlock(&con->writequeue_lock);
1343
1344         queue_work(send_workqueue, &con->swork);
1345         return;
1346
1347 out:
1348         spin_unlock(&con->writequeue_lock);
1349         return;
1350 }
1351
1352 void dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1353 {
1354         _dlm_lowcomms_commit_msg(msg);
1355         srcu_read_unlock(&connections_srcu, msg->idx);
1356 }
1357
1358 void dlm_lowcomms_put_msg(struct dlm_msg *msg)
1359 {
1360         kref_put(&msg->ref, dlm_msg_release);
1361 }
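
/* Sketch of the expected caller pattern for the three calls above
 * (midcomms is the real user; error handling trimmed):
 *
 *   char *ppc;
 *   struct dlm_msg *msg;
 *
 *   msg = dlm_lowcomms_new_msg(nodeid, len, GFP_NOFS, &ppc, NULL, NULL);
 *   if (!msg)
 *           return -ENOMEM;
 *   ... write len bytes of payload at ppc ...
 *   dlm_lowcomms_commit_msg(msg);    releases the srcu idx from new_msg
 *   dlm_lowcomms_put_msg(msg);       drops the caller's reference
 */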
1362
1363 /* does not hold connections_srcu; for use from the workqueue only */
1364 int dlm_lowcomms_resend_msg(struct dlm_msg *msg)
1365 {
1366         struct dlm_msg *msg_resend;
1367         char *ppc;
1368
1369         if (msg->retransmit)
1370                 return 1;
1371
1372         msg_resend = dlm_lowcomms_new_msg_con(msg->entry->con, msg->len,
1373                                               GFP_ATOMIC, &ppc, NULL, NULL);
1374         if (!msg_resend)
1375                 return -ENOMEM;
1376
1377         msg->retransmit = true;
1378         kref_get(&msg->ref);
1379         msg_resend->orig_msg = msg;
1380
1381         memcpy(ppc, msg->ppc, msg->len);
1382         _dlm_lowcomms_commit_msg(msg_resend);
1383         dlm_lowcomms_put_msg(msg_resend);
1384
1385         return 0;
1386 }
1387
1388 /* Send a message */
1389 static void send_to_sock(struct connection *con)
1390 {
1391         const int msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
1392         struct writequeue_entry *e;
1393         int len, offset, ret;
1394         int count = 0;
1395
1396         mutex_lock(&con->sock_mutex);
1397         if (con->sock == NULL)
1398                 goto out_connect;
1399
1400         spin_lock(&con->writequeue_lock);
1401         for (;;) {
1402                 e = con_next_wq(con);
1403                 if (!e)
1404                         break;
1405
1407                 len = e->len;
1408                 offset = e->offset;
1409                 BUG_ON(len == 0 && e->users == 0);
1410                 spin_unlock(&con->writequeue_lock);
1411
1412                 ret = kernel_sendpage(con->sock, e->page, offset, len,
1413                                       msg_flags);
1414                 if (ret == -EAGAIN || ret == 0) {
1415                         if (ret == -EAGAIN &&
1416                             test_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags) &&
1417                             !test_and_set_bit(CF_APP_LIMITED, &con->flags)) {
1418                                 /* Notify TCP that we're limited by the
1419                                  * application window size.
1420                                  */
1421                                 set_bit(SOCK_NOSPACE, &con->sock->flags);
1422                                 con->sock->sk->sk_write_pending++;
1423                         }
1424                         cond_resched();
1425                         goto out;
1426                 } else if (ret < 0)
1427                         goto out;
1428
1429                 /* Don't starve people filling buffers */
1430                 if (++count >= MAX_SEND_MSG_COUNT) {
1431                         cond_resched();
1432                         count = 0;
1433                 }
1434
1435                 spin_lock(&con->writequeue_lock);
1436                 writequeue_entry_complete(e, ret);
1437         }
1438         spin_unlock(&con->writequeue_lock);
1439
1440         /* close if we got EOF */
1441         if (test_and_clear_bit(CF_EOF, &con->flags)) {
1442                 mutex_unlock(&con->sock_mutex);
1443                 close_connection(con, false, false, true);
1444
1445                 /* handling for tcp shutdown */
1446                 clear_bit(CF_SHUTDOWN, &con->flags);
1447                 wake_up(&con->shutdown_wait);
1448         } else {
1449                 mutex_unlock(&con->sock_mutex);
1450         }
1451
1452         return;
1453
1454 out:
1455         mutex_unlock(&con->sock_mutex);
1456         return;
1457
1458 out_connect:
1459         mutex_unlock(&con->sock_mutex);
1460         queue_work(send_workqueue, &con->swork);
1461         cond_resched();
1462 }
1463
1464 static void clean_one_writequeue(struct connection *con)
1465 {
1466         struct writequeue_entry *e, *safe;
1467
1468         spin_lock(&con->writequeue_lock);
1469         list_for_each_entry_safe(e, safe, &con->writequeue, list) {
1470                 free_entry(e);
1471         }
1472         spin_unlock(&con->writequeue_lock);
1473 }
1474
1475 /* Called from recovery when it knows that a node has
1476    left the cluster */
1477 int dlm_lowcomms_close(int nodeid)
1478 {
1479         struct connection *con;
1480         struct dlm_node_addr *na;
1481         int idx;
1482
1483         log_print("closing connection to node %d", nodeid);
1484         idx = srcu_read_lock(&connections_srcu);
1485         con = nodeid2con(nodeid, 0);
1486         if (con) {
1487                 set_bit(CF_CLOSE, &con->flags);
1488                 close_connection(con, true, true, true);
1489                 clean_one_writequeue(con);
1490                 if (con->othercon)
1491                         clean_one_writequeue(con->othercon);
1492         }
1493         srcu_read_unlock(&connections_srcu, idx);
1494
1495         spin_lock(&dlm_node_addrs_spin);
1496         na = find_node_addr(nodeid);
1497         if (na) {
1498                 list_del(&na->list);
1499                 while (na->addr_count--)
1500                         kfree(na->addr[na->addr_count]);
1501                 kfree(na);
1502         }
1503         spin_unlock(&dlm_node_addrs_spin);
1504
1505         return 0;
1506 }
1507
1508 /* Receive workqueue function */
1509 static void process_recv_sockets(struct work_struct *work)
1510 {
1511         struct connection *con = container_of(work, struct connection, rwork);
1512
1513         clear_bit(CF_READ_PENDING, &con->flags);
1514         receive_from_sock(con);
1515 }
1516
1517 static void process_listen_recv_socket(struct work_struct *work)
1518 {
1519         accept_from_sock(&listen_con);
1520 }
1521
1522 static void dlm_connect(struct connection *con)
1523 {
1524         struct sockaddr_storage addr;
1525         int result, addr_len;
1526         struct socket *sock;
1527         unsigned int mark;
1528
1529         /* Some odd races can cause double-connects, ignore them */
1530         if (con->retries++ > MAX_CONNECT_RETRIES)
1531                 return;
1532
1533         if (con->sock) {
1534                 log_print("node %d already connected.", con->nodeid);
1535                 return;
1536         }
1537
1538         memset(&addr, 0, sizeof(addr));
1539         result = nodeid_to_addr(con->nodeid, &addr, NULL,
1540                                 dlm_proto_ops->try_new_addr, &mark);
1541         if (result < 0) {
1542                 log_print("no address for nodeid %d", con->nodeid);
1543                 return;
1544         }
1545
1546         /* Create a socket to communicate with */
1547         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1548                                   SOCK_STREAM, dlm_proto_ops->proto, &sock);
1549         if (result < 0)
1550                 goto socket_err;
1551
1552         sock_set_mark(sock->sk, mark);
1553         dlm_proto_ops->sockopts(sock);
1554
1555         add_sock(sock, con);
1556
1557         result = dlm_proto_ops->bind(sock);
1558         if (result < 0)
1559                 goto add_sock_err;
1560
1561         log_print_ratelimited("connecting to %d", con->nodeid);
1562         make_sockaddr(&addr, dlm_config.ci_tcp_port, &addr_len);
1563         result = dlm_proto_ops->connect(con, sock, (struct sockaddr *)&addr,
1564                                         addr_len);
1565         if (result < 0)
1566                 goto add_sock_err;
1567
1568         return;
1569
1570 add_sock_err:
1571         dlm_close_sock(&con->sock);
1572
1573 socket_err:
1574         /*
1575          * Some errors are fatal and this list might need adjusting. For other
1576          * errors we try again until the max number of retries is reached.
1577          */
1578         if (result != -EHOSTUNREACH &&
1579             result != -ENETUNREACH &&
1580             result != -ENETDOWN &&
1581             result != -EINVAL &&
1582             result != -EPROTONOSUPPORT) {
1583                 log_print("connect %d try %d error %d", con->nodeid,
1584                           con->retries, result);
1585                 msleep(1000);
1586                 lowcomms_connect_sock(con);
1587         }
1588 }
1589
1590 /* Send workqueue function */
1591 static void process_send_sockets(struct work_struct *work)
1592 {
1593         struct connection *con = container_of(work, struct connection, swork);
1594
1595         WARN_ON(test_bit(CF_IS_OTHERCON, &con->flags));
1596
1597         clear_bit(CF_WRITE_PENDING, &con->flags);
1598
1599         if (test_and_clear_bit(CF_RECONNECT, &con->flags)) {
1600                 close_connection(con, false, false, true);
1601                 dlm_midcomms_unack_msg_resend(con->nodeid);
1602         }
1603
1604         if (!con->sock) {
1605                 if (test_and_clear_bit(CF_DELAY_CONNECT, &con->flags))
1606                         msleep(1000);
1607
1608                 mutex_lock(&con->sock_mutex);
1609                 dlm_connect(con);
1610                 mutex_unlock(&con->sock_mutex);
1611         }
1612
1613         if (!list_empty(&con->writequeue))
1614                 send_to_sock(con);
1615 }
1616
1617 static void work_stop(void)
1618 {
1619         if (recv_workqueue) {
1620                 destroy_workqueue(recv_workqueue);
1621                 recv_workqueue = NULL;
1622         }
1623
1624         if (send_workqueue) {
1625                 destroy_workqueue(send_workqueue);
1626                 send_workqueue = NULL;
1627         }
1628 }
1629
1630 static int work_start(void)
1631 {
1632         recv_workqueue = alloc_ordered_workqueue("dlm_recv", WQ_MEM_RECLAIM);
1633         if (!recv_workqueue) {
1634                 log_print("can't start dlm_recv");
1635                 return -ENOMEM;
1636         }
1637
1638         send_workqueue = alloc_ordered_workqueue("dlm_send", WQ_MEM_RECLAIM);
1639         if (!send_workqueue) {
1640                 log_print("can't start dlm_send");
1641                 destroy_workqueue(recv_workqueue);
1642                 recv_workqueue = NULL;
1643                 return -ENOMEM;
1644         }
1645
1646         return 0;
1647 }
1648
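/* Run the protocol-specific shutdown, if any (only TCP provides a
 * shutdown_action; SCTP connections are simply closed).
 */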
1649 static void shutdown_conn(struct connection *con)
1650 {
1651         if (dlm_proto_ops->shutdown_action)
1652                 dlm_proto_ops->shutdown_action(con);
1653 }
1654
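/* Gracefully quiesce lowcomms: block new connection activity, flush
 * the work queues, close the listening socket and shut down every
 * peer connection.
 */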
1655 void dlm_lowcomms_shutdown(void)
1656 {
1657         int idx;
1658
1659         /* Stop any new socket activity; in-flight work is flushed
1660          * below before the sockets are shut down.
1661          */
1662         dlm_allow_conn = 0;
1663
1664         if (recv_workqueue)
1665                 flush_workqueue(recv_workqueue);
1666         if (send_workqueue)
1667                 flush_workqueue(send_workqueue);
1668
1669         dlm_close_sock(&listen_con.sock);
1670
1671         idx = srcu_read_lock(&connections_srcu);
1672         foreach_conn(shutdown_conn);
1673         srcu_read_unlock(&connections_srcu, idx);
1674 }
1675
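/* Quiesce a connection: mark it closing and detach sk_user_data so
 * the socket callbacks stop queueing work. The PENDING bits are set
 * on purpose: the workers clear them, which lets work_flush() detect
 * whether any work still ran after this point.
 */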
1676 static void _stop_conn(struct connection *con, bool and_other)
1677 {
1678         mutex_lock(&con->sock_mutex);
1679         set_bit(CF_CLOSE, &con->flags);
1680         set_bit(CF_READ_PENDING, &con->flags);
1681         set_bit(CF_WRITE_PENDING, &con->flags);
1682         if (con->sock && con->sock->sk) {
1683                 write_lock_bh(&con->sock->sk->sk_callback_lock);
1684                 con->sock->sk->sk_user_data = NULL;
1685                 write_unlock_bh(&con->sock->sk->sk_callback_lock);
1686         }
1687         if (con->othercon && and_other)
1688                 _stop_conn(con->othercon, false);
1689         mutex_unlock(&con->sock_mutex);
1690 }
1691
1692 static void stop_conn(struct connection *con)
1693 {
1694         _stop_conn(con, true);
1695 }
1696
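/* SRCU callback: free the connection memory once all readers are done */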
1697 static void connection_release(struct rcu_head *rcu)
1698 {
1699         struct connection *con = container_of(rcu, struct connection, rcu);
1700
1701         kfree(con->rx_buf);
1702         kfree(con);
1703 }
1704
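/* Close a connection, unhash it and schedule it (and any othercon)
 * to be freed once the SRCU grace period has elapsed.
 */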
1705 static void free_conn(struct connection *con)
1706 {
1707         close_connection(con, true, true, true);
1708         spin_lock(&connections_lock);
1709         hlist_del_rcu(&con->list);
1710         spin_unlock(&connections_lock);
1711         if (con->othercon) {
1712                 clean_one_writequeue(con->othercon);
1713                 call_srcu(&connections_srcu, &con->othercon->rcu,
1714                           connection_release);
1715         }
1716         clean_one_writequeue(con);
1717         call_srcu(&connections_srcu, &con->rcu, connection_release);
1718 }
1719
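/* Loop until a full pass over the connection hash finds every
 * READ/WRITE_PENDING bit still set, i.e. no worker has run (and
 * cleared a bit) since stop_conn() set them and the queues were
 * flushed.
 */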
1720 static void work_flush(void)
1721 {
1722         int ok;
1723         int i;
1724         struct connection *con;
1725
1726         do {
1727                 ok = 1;
1728                 foreach_conn(stop_conn);
1729                 if (recv_workqueue)
1730                         flush_workqueue(recv_workqueue);
1731                 if (send_workqueue)
1732                         flush_workqueue(send_workqueue);
1733                 for (i = 0; i < CONN_HASH_SIZE && ok; i++) {
1734                         hlist_for_each_entry_rcu(con, &connection_hash[i],
1735                                                  list) {
1736                                 ok &= test_bit(CF_READ_PENDING, &con->flags);
1737                                 ok &= test_bit(CF_WRITE_PENDING, &con->flags);
1738                                 if (con->othercon) {
1739                                         ok &= test_bit(CF_READ_PENDING,
1740                                                        &con->othercon->flags);
1741                                         ok &= test_bit(CF_WRITE_PENDING,
1742                                                        &con->othercon->flags);
1743                                 }
1744                         }
1745                 }
1746         } while (!ok);
1747 }
1748
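/* Final teardown: free every connection and destroy the work queues.
 * Normally called after dlm_lowcomms_shutdown() has quiesced traffic.
 */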
1749 void dlm_lowcomms_stop(void)
1750 {
1751         int idx;
1752
1753         idx = srcu_read_lock(&connections_srcu);
1754         work_flush();
1755         foreach_conn(free_conn);
1756         srcu_read_unlock(&connections_srcu, idx);
1757         work_stop();
1758         deinit_local();
1759
1760         dlm_proto_ops = NULL;
1761 }
1762
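/* Create, bind and listen on the cluster-facing socket for the
 * protocol backend selected in dlm_proto_ops.
 */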
1763 static int dlm_listen_for_all(void)
1764 {
1765         struct socket *sock;
1766         int result;
1767
1768         log_print("Using %s for communications",
1769                   dlm_proto_ops->name);
1770
1771         result = dlm_proto_ops->listen_validate();
1772         if (result < 0)
1773                 return result;
1774
1775         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1776                                   SOCK_STREAM, dlm_proto_ops->proto, &sock);
1777         if (result < 0) {
1778                 log_print("can't create comms socket: %d", result);
1779                 return result; /* no socket was created to release */
1780         }
1781
1782         sock_set_mark(sock->sk, dlm_config.ci_mark);
1783         dlm_proto_ops->listen_sockopts(sock);
1784
1785         result = dlm_proto_ops->listen_bind(sock);
1786         if (result < 0)
1787                 goto out;
1788
1789         save_listen_callbacks(sock);
1790         add_listen_sock(sock, &listen_con);
1791
1792         INIT_WORK(&listen_con.rwork, process_listen_recv_socket);
1793         result = sock->ops->listen(sock, 5);
1794         if (result < 0) {
1795                 dlm_close_sock(&listen_con.sock);
1796                 return result; /* sock was already released by dlm_close_sock() */
1797         }
1798
1799         return 0;
1800
1801 out:
1802         sock_release(sock);
1803         return result;
1804 }
1805
1806 static int dlm_tcp_bind(struct socket *sock)
1807 {
1808         struct sockaddr_storage src_addr;
1809         int result, addr_len;
1810
1811         /* Bind to our cluster-known address when connecting to
1812          * avoid routing problems.
1813          */
1814         memcpy(&src_addr, dlm_local_addr[0], sizeof(src_addr));
1815         make_sockaddr(&src_addr, 0, &addr_len);
1816
1817         result = sock->ops->bind(sock, (struct sockaddr *)&src_addr,
1818                                  addr_len);
1819         if (result < 0) {
1820                 /* This *may* not indicate a critical error */
1821                 log_print("could not bind for connect: %d", result);
1822         }
1823
1824         return 0;
1825 }
1826
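/* Kick off a non-blocking TCP connect. -EINPROGRESS only means the
 * handshake is still in flight; completion is signalled through the
 * socket callbacks, so it is treated as success here.
 */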
1827 static int dlm_tcp_connect(struct connection *con, struct socket *sock,
1828                            struct sockaddr *addr, int addr_len)
1829 {
1830         int ret;
1831
1832         ret = sock->ops->connect(sock, addr, addr_len, O_NONBLOCK);
1833         switch (ret) {
1834         case -EINPROGRESS:
1835                 fallthrough;
1836         case 0:
1837                 return 0;
1838         }
1839
1840         return ret;
1841 }
1842
1843 static int dlm_tcp_listen_validate(void)
1844 {
1845         /* We don't support multi-homed hosts */
1846         if (dlm_local_count > 1) {
1847                 log_print("TCP protocol can't handle multi-homed hosts, try SCTP");
1848                 return -EINVAL;
1849         }
1850
1851         return 0;
1852 }
1853
1854 static void dlm_tcp_sockopts(struct socket *sock)
1855 {
1856         /* Turn off Nagle's algorithm */
1857         tcp_sock_set_nodelay(sock->sk);
1858 }
1859
1860 static void dlm_tcp_listen_sockopts(struct socket *sock)
1861 {
1862         dlm_tcp_sockopts(sock);
1863         sock_set_reuseaddr(sock->sk);
1864 }
1865
1866 static int dlm_tcp_listen_bind(struct socket *sock)
1867 {
1868         int addr_len;
1869
1870         /* Bind to our port */
1871         make_sockaddr(dlm_local_addr[0], dlm_config.ci_tcp_port, &addr_len);
1872         return sock->ops->bind(sock, (struct sockaddr *)dlm_local_addr[0],
1873                                addr_len);
1874 }
1875
1876 static const struct dlm_proto_ops dlm_tcp_ops = {
1877         .name = "TCP",
1878         .proto = IPPROTO_TCP,
1879         .connect = dlm_tcp_connect,
1880         .sockopts = dlm_tcp_sockopts,
1881         .bind = dlm_tcp_bind,
1882         .listen_validate = dlm_tcp_listen_validate,
1883         .listen_sockopts = dlm_tcp_listen_sockopts,
1884         .listen_bind = dlm_tcp_listen_bind,
1885         .shutdown_action = dlm_tcp_shutdown,
1886         .eof_condition = tcp_eof_condition,
1887 };
1888
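/* Bind the connecting socket to the configured local addresses;
 * port 0 lets the kernel pick an ephemeral source port.
 */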
1889 static int dlm_sctp_bind(struct socket *sock)
1890 {
1891         return sctp_bind_addrs(sock, 0);
1892 }
1893
1894 static int dlm_sctp_connect(struct connection *con, struct socket *sock,
1895                             struct sockaddr *addr, int addr_len)
1896 {
1897         int ret;
1898
1899         /*
1900          * Make sock->ops->connect() return within a bounded time, since
1901          * the O_NONBLOCK argument to connect() does not work for SCTP.
1902          * Afterwards restore the default send timeout (0 = block forever).
1903          */
1904         sock_set_sndtimeo(sock->sk, 5);
1905         ret = sock->ops->connect(sock, addr, addr_len, 0);
1906         sock_set_sndtimeo(sock->sk, 0);
1907         if (ret < 0)
1908                 return ret;
1909
1910         if (!test_and_set_bit(CF_CONNECTED, &con->flags))
1911                 log_print("successfully connected to node %d", con->nodeid);
1912
1913         return 0;
1914 }
1915
1916 static int dlm_sctp_listen_validate(void)
1917 {
1918         if (!IS_ENABLED(CONFIG_IP_SCTP)) {
1919                 log_print("SCTP is not enabled by this kernel");
1920                 return -EOPNOTSUPP;
1921         }
1922
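        /* Make sure the sctp module is loaded before the listening
         * socket is created; an actual failure will surface later in
         * sock_create_kern().
         */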
1923         request_module("sctp");
1924         return 0;
1925 }
1926
1927 static int dlm_sctp_bind_listen(struct socket *sock)
1928 {
1929         return sctp_bind_addrs(sock, dlm_config.ci_tcp_port);
1930 }
1931
1932 static void dlm_sctp_sockopts(struct socket *sock)
1933 {
1934         /* Turn off Nagle's algorithm */
1935         sctp_sock_set_nodelay(sock->sk);
1936         sock_set_rcvbuf(sock->sk, NEEDED_RMEM);
1937 }
1938
1939 static const struct dlm_proto_ops dlm_sctp_ops = {
1940         .name = "SCTP",
1941         .proto = IPPROTO_SCTP,
1942         .try_new_addr = true,
1943         .connect = dlm_sctp_connect,
1944         .sockopts = dlm_sctp_sockopts,
1945         .bind = dlm_sctp_bind,
1946         .listen_validate = dlm_sctp_listen_validate,
1947         .listen_sockopts = dlm_sctp_sockopts,
1948         .listen_bind = dlm_sctp_bind_listen,
1949 };
1950
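/* Bring the lowcomms layer up: start the work queues, pick the
 * protocol backend from dlm_config.ci_protocol and create the
 * listening socket.
 *
 * A minimal usage sketch (hypothetical caller, just to show the
 * pairing of the entry points exported by this file):
 *
 *	error = dlm_lowcomms_start();
 *	if (error)
 *		return error;
 *	...
 *	dlm_lowcomms_shutdown();
 *	dlm_lowcomms_stop();
 */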
1951 int dlm_lowcomms_start(void)
1952 {
1953         int error = -EINVAL;
1954         int i;
1955
1956         for (i = 0; i < CONN_HASH_SIZE; i++)
1957                 INIT_HLIST_HEAD(&connection_hash[i]);
1958
1959         init_local();
1960         if (!dlm_local_count) {
1961                 error = -ENOTCONN;
1962                 log_print("no local IP address has been set");
1963                 goto fail;
1964         }
1965
1966         INIT_WORK(&listen_con.rwork, process_listen_recv_socket);
1967
1968         error = work_start();
1969         if (error)
1970                 goto fail_local;
1971
1972         dlm_allow_conn = 1;
1973
1974         /* Start listening */
1975         switch (dlm_config.ci_protocol) {
1976         case DLM_PROTO_TCP:
1977                 dlm_proto_ops = &dlm_tcp_ops;
1978                 break;
1979         case DLM_PROTO_SCTP:
1980                 dlm_proto_ops = &dlm_sctp_ops;
1981                 break;
1982         default:
1983                 log_print("Invalid protocol identifier %d set",
1984                           dlm_config.ci_protocol);
1985                 error = -EINVAL;
1986                 goto fail_proto_ops;
1987         }
1988
1989         error = dlm_listen_for_all();
1990         if (error)
1991                 goto fail_listen;
1992
1993         return 0;
1994
1995 fail_listen:
1996         dlm_proto_ops = NULL;
1997 fail_proto_ops:
1998         dlm_allow_conn = 0;
1999         dlm_close_sock(&listen_con.sock);
2000         work_stop();
2001 fail_local:
2002         deinit_local();
2003 fail:
2004         return error;
2005 }
2006
2007 void dlm_lowcomms_exit(void)
2008 {
2009         struct dlm_node_addr *na, *safe;
2010
2011         spin_lock(&dlm_node_addrs_spin);
2012         list_for_each_entry_safe(na, safe, &dlm_node_addrs, list) {
2013                 list_del(&na->list);
2014                 while (na->addr_count--)
2015                         kfree(na->addr[na->addr_count]);
2016                 kfree(na);
2017         }
2018         spin_unlock(&dlm_node_addrs_spin);
2019 }