fs: dlm: fix lowcomms_start error case
[linux-2.6-microblaze.git] / fs / dlm / lowcomms.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /******************************************************************************
3 *******************************************************************************
4 **
5 **  Copyright (C) Sistina Software, Inc.  1997-2003  All rights reserved.
6 **  Copyright (C) 2004-2009 Red Hat, Inc.  All rights reserved.
7 **
8 **
9 *******************************************************************************
10 ******************************************************************************/
11
12 /*
13  * lowcomms.c
14  *
15  * This is the "low-level" comms layer.
16  *
17  * It is responsible for sending/receiving messages
18  * from other nodes in the cluster.
19  *
20  * Cluster nodes are referred to by their nodeids. nodeids are
21  * simply 32 bit numbers to the locking module - if they need to
22  * be expanded for the cluster infrastructure then that is its
23  * responsibility. It is this layer's
24  * responsibility to resolve these into IP address or
25  * whatever it needs for inter-node communication.
26  *
 * The comms level is two workqueues: a receive workqueue that deals mainly
 * with receiving messages from other nodes and passing them up to the
 * mid-level comms layer (which understands the message format) for
 * execution by the locking core, and a send workqueue which does all the
 * setting up of connections to remote nodes and the sending of data.
 * Callers are not allowed to send their own data because it may cause them
 * to wait in times of high load. Also, this way, the sending side can
 * collect together messages bound for one node and send them in one block.
36  *
37  * lowcomms will choose to use either TCP or SCTP as its transport layer
38  * depending on the configuration variable 'protocol'. This should be set
39  * to 0 (default) for TCP or 1 for SCTP. It should be configured using a
40  * cluster-wide mechanism as it must be the same on all nodes of the cluster
41  * for the DLM to function.
42  *
43  */
44
45 #include <asm/ioctls.h>
46 #include <net/sock.h>
47 #include <net/tcp.h>
48 #include <linux/pagemap.h>
49 #include <linux/file.h>
50 #include <linux/mutex.h>
51 #include <linux/sctp.h>
52 #include <linux/slab.h>
53 #include <net/sctp/sctp.h>
54 #include <net/ipv6.h>
55
56 #include "dlm_internal.h"
57 #include "lowcomms.h"
58 #include "midcomms.h"
59 #include "config.h"
60
61 #define NEEDED_RMEM (4*1024*1024)
62
63 /* Number of messages to send before rescheduling */
64 #define MAX_SEND_MSG_COUNT 25
65 #define DLM_SHUTDOWN_WAIT_TIMEOUT msecs_to_jiffies(10000)
66
67 struct connection {
68         struct socket *sock;    /* NULL if not connected */
69         uint32_t nodeid;        /* So we know who we are in the list */
70         struct mutex sock_mutex;
71         unsigned long flags;
72 #define CF_READ_PENDING 1
73 #define CF_WRITE_PENDING 2
74 #define CF_INIT_PENDING 4
75 #define CF_IS_OTHERCON 5
76 #define CF_CLOSE 6
77 #define CF_APP_LIMITED 7
78 #define CF_CLOSING 8
79 #define CF_SHUTDOWN 9
80 #define CF_CONNECTED 10
81 #define CF_RECONNECT 11
82 #define CF_DELAY_CONNECT 12
83 #define CF_EOF 13
84         struct list_head writequeue;  /* List of outgoing writequeue_entries */
85         spinlock_t writequeue_lock;
86         atomic_t writequeue_cnt;
87         void (*connect_action) (struct connection *);   /* What to do to connect */
88         void (*shutdown_action)(struct connection *con); /* What to do to shutdown */
89         bool (*eof_condition)(struct connection *con); /* What to do to eof check */
90         int retries;
91 #define MAX_CONNECT_RETRIES 3
92         struct hlist_node list;
93         struct connection *othercon;
94         struct connection *sendcon;
95         struct work_struct rwork; /* Receive workqueue */
96         struct work_struct swork; /* Send workqueue */
97         wait_queue_head_t shutdown_wait; /* wait for graceful shutdown */
98         unsigned char *rx_buf;
99         int rx_buflen;
100         int rx_leftover;
101         struct rcu_head rcu;
102 };
103 #define sock2con(x) ((struct connection *)(x)->sk_user_data)
104
105 struct listen_connection {
106         struct socket *sock;
107         struct work_struct rwork;
108 };
109
110 #define DLM_WQ_REMAIN_BYTES(e) (PAGE_SIZE - e->end)
111 #define DLM_WQ_LENGTH_BYTES(e) (e->end - e->offset)
112
113 /* An entry waiting to be sent */
114 struct writequeue_entry {
115         struct list_head list;
116         struct page *page;
117         int offset;
118         int len;
119         int end;
120         int users;
121         bool dirty;
122         struct connection *con;
123         struct list_head msgs;
124         struct kref ref;
125 };
126
127 struct dlm_msg {
128         struct writequeue_entry *entry;
129         struct dlm_msg *orig_msg;
130         bool retransmit;
131         void *ppc;
132         int len;
133         int idx; /* new()/commit() idx exchange */
134
135         struct list_head list;
136         struct kref ref;
137 };
138
139 struct dlm_node_addr {
140         struct list_head list;
141         int nodeid;
142         int mark;
143         int addr_count;
144         int curr_addr_index;
145         struct sockaddr_storage *addr[DLM_MAX_ADDR_COUNT];
146 };
147
148 static struct listen_sock_callbacks {
149         void (*sk_error_report)(struct sock *);
150         void (*sk_data_ready)(struct sock *);
151         void (*sk_state_change)(struct sock *);
152         void (*sk_write_space)(struct sock *);
153 } listen_sock;
154
155 static LIST_HEAD(dlm_node_addrs);
156 static DEFINE_SPINLOCK(dlm_node_addrs_spin);
157
158 static struct listen_connection listen_con;
159 static struct sockaddr_storage *dlm_local_addr[DLM_MAX_ADDR_COUNT];
160 static int dlm_local_count;
161 int dlm_allow_conn;
162
163 /* Work queues */
164 static struct workqueue_struct *recv_workqueue;
165 static struct workqueue_struct *send_workqueue;
166
167 static struct hlist_head connection_hash[CONN_HASH_SIZE];
168 static DEFINE_SPINLOCK(connections_lock);
169 DEFINE_STATIC_SRCU(connections_srcu);
170
171 static void process_recv_sockets(struct work_struct *work);
172 static void process_send_sockets(struct work_struct *work);
173
174 static void sctp_connect_to_sock(struct connection *con);
175 static void tcp_connect_to_sock(struct connection *con);
176 static void dlm_tcp_shutdown(struct connection *con);
177
178 static struct connection *__find_con(int nodeid, int r)
179 {
180         struct connection *con;
181
182         hlist_for_each_entry_rcu(con, &connection_hash[r], list) {
183                 if (con->nodeid == nodeid)
184                         return con;
185         }
186
187         return NULL;
188 }
189
190 static bool tcp_eof_condition(struct connection *con)
191 {
192         return atomic_read(&con->writequeue_cnt);
193 }
194
195 static int dlm_con_init(struct connection *con, int nodeid)
196 {
197         con->rx_buflen = dlm_config.ci_buffer_size;
198         con->rx_buf = kmalloc(con->rx_buflen, GFP_NOFS);
199         if (!con->rx_buf)
200                 return -ENOMEM;
201
202         con->nodeid = nodeid;
203         mutex_init(&con->sock_mutex);
204         INIT_LIST_HEAD(&con->writequeue);
205         spin_lock_init(&con->writequeue_lock);
206         atomic_set(&con->writequeue_cnt, 0);
207         INIT_WORK(&con->swork, process_send_sockets);
208         INIT_WORK(&con->rwork, process_recv_sockets);
209         init_waitqueue_head(&con->shutdown_wait);
210
211         if (dlm_config.ci_protocol == 0) {
212                 con->connect_action = tcp_connect_to_sock;
213                 con->shutdown_action = dlm_tcp_shutdown;
214                 con->eof_condition = tcp_eof_condition;
215         } else {
216                 con->connect_action = sctp_connect_to_sock;
217         }
218
219         return 0;
220 }
221
222 /*
223  * If 'allocation' is zero then we don't attempt to create a new
224  * connection structure for this node.
225  */
226 static struct connection *nodeid2con(int nodeid, gfp_t alloc)
227 {
228         struct connection *con, *tmp;
229         int r, ret;
230
231         r = nodeid_hash(nodeid);
232         con = __find_con(nodeid, r);
233         if (con || !alloc)
234                 return con;
235
236         con = kzalloc(sizeof(*con), alloc);
237         if (!con)
238                 return NULL;
239
240         ret = dlm_con_init(con, nodeid);
241         if (ret) {
242                 kfree(con);
243                 return NULL;
244         }
245
246         spin_lock(&connections_lock);
247         /* Because multiple workqueues/threads calls this function it can
248          * race on multiple cpu's. Instead of locking hot path __find_con()
249          * we just check in rare cases of recently added nodes again
250          * under protection of connections_lock. If this is the case we
251          * abort our connection creation and return the existing connection.
252          */
253         tmp = __find_con(nodeid, r);
254         if (tmp) {
255                 spin_unlock(&connections_lock);
256                 kfree(con->rx_buf);
257                 kfree(con);
258                 return tmp;
259         }
260
261         hlist_add_head_rcu(&con->list, &connection_hash[r]);
262         spin_unlock(&connections_lock);
263
264         return con;
265 }
266
267 /* Loop round all connections */
268 static void foreach_conn(void (*conn_func)(struct connection *c))
269 {
270         int i;
271         struct connection *con;
272
273         for (i = 0; i < CONN_HASH_SIZE; i++) {
274                 hlist_for_each_entry_rcu(con, &connection_hash[i], list)
275                         conn_func(con);
276         }
277 }
278
279 static struct dlm_node_addr *find_node_addr(int nodeid)
280 {
281         struct dlm_node_addr *na;
282
283         list_for_each_entry(na, &dlm_node_addrs, list) {
284                 if (na->nodeid == nodeid)
285                         return na;
286         }
287         return NULL;
288 }
289
290 static int addr_compare(const struct sockaddr_storage *x,
291                         const struct sockaddr_storage *y)
292 {
293         switch (x->ss_family) {
294         case AF_INET: {
295                 struct sockaddr_in *sinx = (struct sockaddr_in *)x;
296                 struct sockaddr_in *siny = (struct sockaddr_in *)y;
297                 if (sinx->sin_addr.s_addr != siny->sin_addr.s_addr)
298                         return 0;
299                 if (sinx->sin_port != siny->sin_port)
300                         return 0;
301                 break;
302         }
303         case AF_INET6: {
304                 struct sockaddr_in6 *sinx = (struct sockaddr_in6 *)x;
305                 struct sockaddr_in6 *siny = (struct sockaddr_in6 *)y;
306                 if (!ipv6_addr_equal(&sinx->sin6_addr, &siny->sin6_addr))
307                         return 0;
308                 if (sinx->sin6_port != siny->sin6_port)
309                         return 0;
310                 break;
311         }
312         default:
313                 return 0;
314         }
315         return 1;
316 }
317
318 static int nodeid_to_addr(int nodeid, struct sockaddr_storage *sas_out,
319                           struct sockaddr *sa_out, bool try_new_addr,
320                           unsigned int *mark)
321 {
322         struct sockaddr_storage sas;
323         struct dlm_node_addr *na;
324
325         if (!dlm_local_count)
326                 return -1;
327
328         spin_lock(&dlm_node_addrs_spin);
329         na = find_node_addr(nodeid);
330         if (na && na->addr_count) {
331                 memcpy(&sas, na->addr[na->curr_addr_index],
332                        sizeof(struct sockaddr_storage));
333
334                 if (try_new_addr) {
335                         na->curr_addr_index++;
336                         if (na->curr_addr_index == na->addr_count)
337                                 na->curr_addr_index = 0;
338                 }
339         }
340         spin_unlock(&dlm_node_addrs_spin);
341
342         if (!na)
343                 return -EEXIST;
344
345         if (!na->addr_count)
346                 return -ENOENT;
347
348         *mark = na->mark;
349
350         if (sas_out)
351                 memcpy(sas_out, &sas, sizeof(struct sockaddr_storage));
352
353         if (!sa_out)
354                 return 0;
355
356         if (dlm_local_addr[0]->ss_family == AF_INET) {
357                 struct sockaddr_in *in4  = (struct sockaddr_in *) &sas;
358                 struct sockaddr_in *ret4 = (struct sockaddr_in *) sa_out;
359                 ret4->sin_addr.s_addr = in4->sin_addr.s_addr;
360         } else {
361                 struct sockaddr_in6 *in6  = (struct sockaddr_in6 *) &sas;
362                 struct sockaddr_in6 *ret6 = (struct sockaddr_in6 *) sa_out;
363                 ret6->sin6_addr = in6->sin6_addr;
364         }
365
366         return 0;
367 }
368
369 static int addr_to_nodeid(struct sockaddr_storage *addr, int *nodeid,
370                           unsigned int *mark)
371 {
372         struct dlm_node_addr *na;
373         int rv = -EEXIST;
374         int addr_i;
375
376         spin_lock(&dlm_node_addrs_spin);
377         list_for_each_entry(na, &dlm_node_addrs, list) {
378                 if (!na->addr_count)
379                         continue;
380
381                 for (addr_i = 0; addr_i < na->addr_count; addr_i++) {
382                         if (addr_compare(na->addr[addr_i], addr)) {
383                                 *nodeid = na->nodeid;
384                                 *mark = na->mark;
385                                 rv = 0;
386                                 goto unlock;
387                         }
388                 }
389         }
390 unlock:
391         spin_unlock(&dlm_node_addrs_spin);
392         return rv;
393 }
394
395 /* caller need to held dlm_node_addrs_spin lock */
396 static bool dlm_lowcomms_na_has_addr(const struct dlm_node_addr *na,
397                                      const struct sockaddr_storage *addr)
398 {
399         int i;
400
401         for (i = 0; i < na->addr_count; i++) {
402                 if (addr_compare(na->addr[i], addr))
403                         return true;
404         }
405
406         return false;
407 }
408
409 int dlm_lowcomms_addr(int nodeid, struct sockaddr_storage *addr, int len)
410 {
411         struct sockaddr_storage *new_addr;
412         struct dlm_node_addr *new_node, *na;
413         bool ret;
414
415         new_node = kzalloc(sizeof(struct dlm_node_addr), GFP_NOFS);
416         if (!new_node)
417                 return -ENOMEM;
418
419         new_addr = kzalloc(sizeof(struct sockaddr_storage), GFP_NOFS);
420         if (!new_addr) {
421                 kfree(new_node);
422                 return -ENOMEM;
423         }
424
425         memcpy(new_addr, addr, len);
426
427         spin_lock(&dlm_node_addrs_spin);
428         na = find_node_addr(nodeid);
429         if (!na) {
430                 new_node->nodeid = nodeid;
431                 new_node->addr[0] = new_addr;
432                 new_node->addr_count = 1;
433                 new_node->mark = dlm_config.ci_mark;
434                 list_add(&new_node->list, &dlm_node_addrs);
435                 spin_unlock(&dlm_node_addrs_spin);
436                 return 0;
437         }
438
439         ret = dlm_lowcomms_na_has_addr(na, addr);
440         if (ret) {
441                 spin_unlock(&dlm_node_addrs_spin);
442                 kfree(new_addr);
443                 kfree(new_node);
444                 return -EEXIST;
445         }
446
447         if (na->addr_count >= DLM_MAX_ADDR_COUNT) {
448                 spin_unlock(&dlm_node_addrs_spin);
449                 kfree(new_addr);
450                 kfree(new_node);
451                 return -ENOSPC;
452         }
453
454         na->addr[na->addr_count++] = new_addr;
455         spin_unlock(&dlm_node_addrs_spin);
456         kfree(new_node);
457         return 0;
458 }
459
460 /* Data available on socket or listen socket received a connect */
461 static void lowcomms_data_ready(struct sock *sk)
462 {
463         struct connection *con;
464
465         read_lock_bh(&sk->sk_callback_lock);
466         con = sock2con(sk);
467         if (con && !test_and_set_bit(CF_READ_PENDING, &con->flags))
468                 queue_work(recv_workqueue, &con->rwork);
469         read_unlock_bh(&sk->sk_callback_lock);
470 }
471
472 static void lowcomms_listen_data_ready(struct sock *sk)
473 {
474         queue_work(recv_workqueue, &listen_con.rwork);
475 }
476
477 static void lowcomms_write_space(struct sock *sk)
478 {
479         struct connection *con;
480
481         read_lock_bh(&sk->sk_callback_lock);
482         con = sock2con(sk);
483         if (!con)
484                 goto out;
485
486         if (!test_and_set_bit(CF_CONNECTED, &con->flags)) {
487                 log_print("successful connected to node %d", con->nodeid);
488                 queue_work(send_workqueue, &con->swork);
489                 goto out;
490         }
491
492         clear_bit(SOCK_NOSPACE, &con->sock->flags);
493
494         if (test_and_clear_bit(CF_APP_LIMITED, &con->flags)) {
495                 con->sock->sk->sk_write_pending--;
496                 clear_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags);
497         }
498
499         queue_work(send_workqueue, &con->swork);
500 out:
501         read_unlock_bh(&sk->sk_callback_lock);
502 }
503
504 static inline void lowcomms_connect_sock(struct connection *con)
505 {
506         if (test_bit(CF_CLOSE, &con->flags))
507                 return;
508         queue_work(send_workqueue, &con->swork);
509         cond_resched();
510 }
511
512 static void lowcomms_state_change(struct sock *sk)
513 {
514         /* SCTP layer is not calling sk_data_ready when the connection
515          * is done, so we catch the signal through here. Also, it
516          * doesn't switch socket state when entering shutdown, so we
517          * skip the write in that case.
518          */
519         if (sk->sk_shutdown) {
520                 if (sk->sk_shutdown == RCV_SHUTDOWN)
521                         lowcomms_data_ready(sk);
522         } else if (sk->sk_state == TCP_ESTABLISHED) {
523                 lowcomms_write_space(sk);
524         }
525 }
526
527 int dlm_lowcomms_connect_node(int nodeid)
528 {
529         struct connection *con;
530         int idx;
531
532         if (nodeid == dlm_our_nodeid())
533                 return 0;
534
535         idx = srcu_read_lock(&connections_srcu);
536         con = nodeid2con(nodeid, GFP_NOFS);
537         if (!con) {
538                 srcu_read_unlock(&connections_srcu, idx);
539                 return -ENOMEM;
540         }
541
542         lowcomms_connect_sock(con);
543         srcu_read_unlock(&connections_srcu, idx);
544
545         return 0;
546 }
547
548 int dlm_lowcomms_nodes_set_mark(int nodeid, unsigned int mark)
549 {
550         struct dlm_node_addr *na;
551
552         spin_lock(&dlm_node_addrs_spin);
553         na = find_node_addr(nodeid);
554         if (!na) {
555                 spin_unlock(&dlm_node_addrs_spin);
556                 return -ENOENT;
557         }
558
559         na->mark = mark;
560         spin_unlock(&dlm_node_addrs_spin);
561
562         return 0;
563 }
564
565 static void lowcomms_error_report(struct sock *sk)
566 {
567         struct connection *con;
568         struct sockaddr_storage saddr;
569         void (*orig_report)(struct sock *) = NULL;
570
571         read_lock_bh(&sk->sk_callback_lock);
572         con = sock2con(sk);
573         if (con == NULL)
574                 goto out;
575
576         orig_report = listen_sock.sk_error_report;
577         if (con->sock == NULL ||
578             kernel_getpeername(con->sock, (struct sockaddr *)&saddr) < 0) {
579                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
580                                    "sending to node %d, port %d, "
581                                    "sk_err=%d/%d\n", dlm_our_nodeid(),
582                                    con->nodeid, dlm_config.ci_tcp_port,
583                                    sk->sk_err, sk->sk_err_soft);
584         } else if (saddr.ss_family == AF_INET) {
585                 struct sockaddr_in *sin4 = (struct sockaddr_in *)&saddr;
586
587                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
588                                    "sending to node %d at %pI4, port %d, "
589                                    "sk_err=%d/%d\n", dlm_our_nodeid(),
590                                    con->nodeid, &sin4->sin_addr.s_addr,
591                                    dlm_config.ci_tcp_port, sk->sk_err,
592                                    sk->sk_err_soft);
593         } else {
594                 struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&saddr;
595
596                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
597                                    "sending to node %d at %u.%u.%u.%u, "
598                                    "port %d, sk_err=%d/%d\n", dlm_our_nodeid(),
599                                    con->nodeid, sin6->sin6_addr.s6_addr32[0],
600                                    sin6->sin6_addr.s6_addr32[1],
601                                    sin6->sin6_addr.s6_addr32[2],
602                                    sin6->sin6_addr.s6_addr32[3],
603                                    dlm_config.ci_tcp_port, sk->sk_err,
604                                    sk->sk_err_soft);
605         }
606
607         /* below sendcon only handling */
608         if (test_bit(CF_IS_OTHERCON, &con->flags))
609                 con = con->sendcon;
610
611         switch (sk->sk_err) {
612         case ECONNREFUSED:
613                 set_bit(CF_DELAY_CONNECT, &con->flags);
614                 break;
615         default:
616                 break;
617         }
618
619         if (!test_and_set_bit(CF_RECONNECT, &con->flags))
620                 queue_work(send_workqueue, &con->swork);
621
622 out:
623         read_unlock_bh(&sk->sk_callback_lock);
624         if (orig_report)
625                 orig_report(sk);
626 }
627
628 /* Note: sk_callback_lock must be locked before calling this function. */
629 static void save_listen_callbacks(struct socket *sock)
630 {
631         struct sock *sk = sock->sk;
632
633         listen_sock.sk_data_ready = sk->sk_data_ready;
634         listen_sock.sk_state_change = sk->sk_state_change;
635         listen_sock.sk_write_space = sk->sk_write_space;
636         listen_sock.sk_error_report = sk->sk_error_report;
637 }
638
639 static void restore_callbacks(struct socket *sock)
640 {
641         struct sock *sk = sock->sk;
642
643         write_lock_bh(&sk->sk_callback_lock);
644         sk->sk_user_data = NULL;
645         sk->sk_data_ready = listen_sock.sk_data_ready;
646         sk->sk_state_change = listen_sock.sk_state_change;
647         sk->sk_write_space = listen_sock.sk_write_space;
648         sk->sk_error_report = listen_sock.sk_error_report;
649         write_unlock_bh(&sk->sk_callback_lock);
650 }
651
652 static void add_listen_sock(struct socket *sock, struct listen_connection *con)
653 {
654         struct sock *sk = sock->sk;
655
656         write_lock_bh(&sk->sk_callback_lock);
657         save_listen_callbacks(sock);
658         con->sock = sock;
659
660         sk->sk_user_data = con;
661         sk->sk_allocation = GFP_NOFS;
662         /* Install a data_ready callback */
663         sk->sk_data_ready = lowcomms_listen_data_ready;
664         write_unlock_bh(&sk->sk_callback_lock);
665 }
666
667 /* Make a socket active */
668 static void add_sock(struct socket *sock, struct connection *con)
669 {
670         struct sock *sk = sock->sk;
671
672         write_lock_bh(&sk->sk_callback_lock);
673         con->sock = sock;
674
675         sk->sk_user_data = con;
676         /* Install a data_ready callback */
677         sk->sk_data_ready = lowcomms_data_ready;
678         sk->sk_write_space = lowcomms_write_space;
679         sk->sk_state_change = lowcomms_state_change;
680         sk->sk_allocation = GFP_NOFS;
681         sk->sk_error_report = lowcomms_error_report;
682         write_unlock_bh(&sk->sk_callback_lock);
683 }
684
685 /* Add the port number to an IPv6 or 4 sockaddr and return the address
686    length */
687 static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
688                           int *addr_len)
689 {
690         saddr->ss_family =  dlm_local_addr[0]->ss_family;
691         if (saddr->ss_family == AF_INET) {
692                 struct sockaddr_in *in4_addr = (struct sockaddr_in *)saddr;
693                 in4_addr->sin_port = cpu_to_be16(port);
694                 *addr_len = sizeof(struct sockaddr_in);
695                 memset(&in4_addr->sin_zero, 0, sizeof(in4_addr->sin_zero));
696         } else {
697                 struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)saddr;
698                 in6_addr->sin6_port = cpu_to_be16(port);
699                 *addr_len = sizeof(struct sockaddr_in6);
700         }
701         memset((char *)saddr + *addr_len, 0, sizeof(struct sockaddr_storage) - *addr_len);
702 }
703
704 static void dlm_page_release(struct kref *kref)
705 {
706         struct writequeue_entry *e = container_of(kref, struct writequeue_entry,
707                                                   ref);
708
709         __free_page(e->page);
710         kfree(e);
711 }
712
713 static void dlm_msg_release(struct kref *kref)
714 {
715         struct dlm_msg *msg = container_of(kref, struct dlm_msg, ref);
716
717         kref_put(&msg->entry->ref, dlm_page_release);
718         kfree(msg);
719 }
720
721 static void free_entry(struct writequeue_entry *e)
722 {
723         struct dlm_msg *msg, *tmp;
724
725         list_for_each_entry_safe(msg, tmp, &e->msgs, list) {
726                 if (msg->orig_msg) {
727                         msg->orig_msg->retransmit = false;
728                         kref_put(&msg->orig_msg->ref, dlm_msg_release);
729                 }
730
731                 list_del(&msg->list);
732                 kref_put(&msg->ref, dlm_msg_release);
733         }
734
735         list_del(&e->list);
736         atomic_dec(&e->con->writequeue_cnt);
737         kref_put(&e->ref, dlm_page_release);
738 }
739
740 static void dlm_close_sock(struct socket **sock)
741 {
742         if (*sock) {
743                 restore_callbacks(*sock);
744                 sock_release(*sock);
745                 *sock = NULL;
746         }
747 }
748
749 /* Close a remote connection and tidy up */
750 static void close_connection(struct connection *con, bool and_other,
751                              bool tx, bool rx)
752 {
753         bool closing = test_and_set_bit(CF_CLOSING, &con->flags);
754         struct writequeue_entry *e;
755
756         if (tx && !closing && cancel_work_sync(&con->swork)) {
757                 log_print("canceled swork for node %d", con->nodeid);
758                 clear_bit(CF_WRITE_PENDING, &con->flags);
759         }
760         if (rx && !closing && cancel_work_sync(&con->rwork)) {
761                 log_print("canceled rwork for node %d", con->nodeid);
762                 clear_bit(CF_READ_PENDING, &con->flags);
763         }
764
765         mutex_lock(&con->sock_mutex);
766         dlm_close_sock(&con->sock);
767
768         if (con->othercon && and_other) {
769                 /* Will only re-enter once. */
770                 close_connection(con->othercon, false, tx, rx);
771         }
772
773         /* if we send a writequeue entry only a half way, we drop the
774          * whole entry because reconnection and that we not start of the
775          * middle of a msg which will confuse the other end.
776          *
777          * we can always drop messages because retransmits, but what we
778          * cannot allow is to transmit half messages which may be processed
779          * at the other side.
780          *
781          * our policy is to start on a clean state when disconnects, we don't
782          * know what's send/received on transport layer in this case.
783          */
784         spin_lock(&con->writequeue_lock);
785         if (!list_empty(&con->writequeue)) {
786                 e = list_first_entry(&con->writequeue, struct writequeue_entry,
787                                      list);
788                 if (e->dirty)
789                         free_entry(e);
790         }
791         spin_unlock(&con->writequeue_lock);
792
793         con->rx_leftover = 0;
794         con->retries = 0;
795         clear_bit(CF_CONNECTED, &con->flags);
796         clear_bit(CF_DELAY_CONNECT, &con->flags);
797         clear_bit(CF_RECONNECT, &con->flags);
798         clear_bit(CF_EOF, &con->flags);
799         mutex_unlock(&con->sock_mutex);
800         clear_bit(CF_CLOSING, &con->flags);
801 }
802
803 static void shutdown_connection(struct connection *con)
804 {
805         int ret;
806
807         flush_work(&con->swork);
808
809         mutex_lock(&con->sock_mutex);
810         /* nothing to shutdown */
811         if (!con->sock) {
812                 mutex_unlock(&con->sock_mutex);
813                 return;
814         }
815
816         set_bit(CF_SHUTDOWN, &con->flags);
817         ret = kernel_sock_shutdown(con->sock, SHUT_WR);
818         mutex_unlock(&con->sock_mutex);
819         if (ret) {
820                 log_print("Connection %p failed to shutdown: %d will force close",
821                           con, ret);
822                 goto force_close;
823         } else {
824                 ret = wait_event_timeout(con->shutdown_wait,
825                                          !test_bit(CF_SHUTDOWN, &con->flags),
826                                          DLM_SHUTDOWN_WAIT_TIMEOUT);
827                 if (ret == 0) {
828                         log_print("Connection %p shutdown timed out, will force close",
829                                   con);
830                         goto force_close;
831                 }
832         }
833
834         return;
835
836 force_close:
837         clear_bit(CF_SHUTDOWN, &con->flags);
838         close_connection(con, false, true, true);
839 }
840
841 static void dlm_tcp_shutdown(struct connection *con)
842 {
843         if (con->othercon)
844                 shutdown_connection(con->othercon);
845         shutdown_connection(con);
846 }
847
848 static int con_realloc_receive_buf(struct connection *con, int newlen)
849 {
850         unsigned char *newbuf;
851
852         newbuf = kmalloc(newlen, GFP_NOFS);
853         if (!newbuf)
854                 return -ENOMEM;
855
856         /* copy any leftover from last receive */
857         if (con->rx_leftover)
858                 memmove(newbuf, con->rx_buf, con->rx_leftover);
859
860         /* swap to new buffer space */
861         kfree(con->rx_buf);
862         con->rx_buflen = newlen;
863         con->rx_buf = newbuf;
864
865         return 0;
866 }
867
/*
 * receive_from_sock - read data from @con's socket and process it
 * @con: connection to receive on
 *
 * Reads into con->rx_buf behind any bytes left over from the previous call,
 * passes everything to dlm_process_incoming_buffer() and moves whatever it
 * did not consume to the start of the buffer as the new leftover. Before
 * reading, the buffer is resized to dlm_config.ci_buffer_size if that
 * setting differs from the current size and the leftover still fits.
 *
 * Runs entirely under con->sock_mutex, which is dropped on every exit path.
 *
 * Return: 0 after a read that was completely handled (process_recv_sockets()
 * keeps calling while 0 is returned); -EAGAIN when there is no socket or when
 * another run has been queued on recv_workqueue; -1 on EOF; otherwise the
 * negative error from kernel_recvmsg() or dlm_process_incoming_buffer().
 */
static int receive_from_sock(struct connection *con)
{
	int call_again_soon = 0;
	struct msghdr msg;
	struct kvec iov;
	int ret, buflen;

	mutex_lock(&con->sock_mutex);

	/* no socket (any more): report -EAGAIN via the out_close path */
	if (con->sock == NULL) {
		ret = -EAGAIN;
		goto out_close;
	}

	/* realloc if we get new buffer size to read out */
	buflen = dlm_config.ci_buffer_size;
	if (con->rx_buflen != buflen && con->rx_leftover <= buflen) {
		ret = con_realloc_receive_buf(con, buflen);
		/* allocation failed: queue another attempt on the workqueue */
		if (ret < 0)
			goto out_resched;
	}

	/* calculate new buffer parameter regarding last receive and
	 * possible leftover bytes
	 */
	iov.iov_base = con->rx_buf + con->rx_leftover;
	iov.iov_len = con->rx_buflen - con->rx_leftover;

	memset(&msg, 0, sizeof(msg));
	msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
	ret = kernel_recvmsg(con->sock, &msg, &iov, 1, iov.iov_len,
			     msg.msg_flags);
	/* 0 is handled as EOF below, < 0 is returned to the caller */
	if (ret <= 0)
		goto out_close;
	/* the free space was filled completely: more data may be waiting */
	else if (ret == iov.iov_len)
		call_again_soon = 1;

	/* new buflen according to the bytes read and leftover from last receive */
	buflen = ret + con->rx_leftover;
	/* on success ret is the number of bytes that were consumed */
	ret = dlm_process_incoming_buffer(con->nodeid, con->rx_buf, buflen);
	if (ret < 0)
		goto out_close;

	/* calculate leftover bytes from process and put it into begin of
	 * the receive buffer, so next receive we have the full message
	 * at the start address of the receive buffer.
	 */
	con->rx_leftover = buflen - ret;
	if (con->rx_leftover) {
		memmove(con->rx_buf, con->rx_buf + ret,
			con->rx_leftover);
		call_again_soon = true;
	}

	if (call_again_soon)
		goto out_resched;

	mutex_unlock(&con->sock_mutex);
	return 0;

out_resched:
	/* queue another run of the receive work unless one is already pending */
	if (!test_and_set_bit(CF_READ_PENDING, &con->flags))
		queue_work(recv_workqueue, &con->rwork);
	mutex_unlock(&con->sock_mutex);
	return -EAGAIN;

out_close:
	if (ret == 0) {
		log_print("connection %p got EOF from %d",
			  con, con->nodeid);

		/*
		 * If the connection's eof_condition callback says so, only
		 * flag CF_EOF; send_to_sock() closes the connection when it
		 * finds that flag. Otherwise close immediately and wake any
		 * shutdown waiter.
		 */
		if (con->eof_condition && con->eof_condition(con)) {
			set_bit(CF_EOF, &con->flags);
			mutex_unlock(&con->sock_mutex);
		} else {
			mutex_unlock(&con->sock_mutex);
			close_connection(con, false, true, false);

			/* handling for tcp shutdown */
			clear_bit(CF_SHUTDOWN, &con->flags);
			wake_up(&con->shutdown_wait);
		}

		/* non-zero return stops the loop in process_recv_sockets() */
		ret = -1;
	} else {
		mutex_unlock(&con->sock_mutex);
	}
	return ret;
}
959
/*
 * accept_from_sock - accept one pending connection on the listening socket
 * @con: the listen connection
 *
 * Looks up the cluster nodeid belonging to the peer address and attaches the
 * accepted socket to that node's connection. If the node's connection already
 * has a socket (both nodes connected to each other at about the same time),
 * the incoming socket goes to the node's "othercon" instead. The othercon is
 * allocated on first use; if it already exists, its previous socket is closed
 * first. Finally a receive is queued in case data arrived before the socket
 * was attached.
 *
 * Return: 0 on success; -1 if new connections are not allowed or the peer is
 * not a cluster node; -ENOTCONN if there is no listening socket; otherwise a
 * negative error, which is logged unless it is -EAGAIN.
 */
static int accept_from_sock(struct listen_connection *con)
{
	int result;
	struct sockaddr_storage peeraddr;
	struct socket *newsock;
	int len, idx;
	int nodeid;
	struct connection *newcon;
	struct connection *addcon;
	unsigned int mark;

	if (!dlm_allow_conn) {
		return -1;
	}

	if (!con->sock)
		return -ENOTCONN;

	/*
	 * Non-blocking accept; presumably -EAGAIN when nothing is pending,
	 * which is why that error is not logged at accept_err.
	 * NOTE(review): newsock is not initialised here, so the accept_err
	 * path relies on kernel_accept() setting it (to NULL) on failure.
	 */
	result = kernel_accept(con->sock, &newsock, O_NONBLOCK);
	if (result < 0)
		goto accept_err;

	/* Get the connected socket's peer */
	memset(&peeraddr, 0, sizeof(peeraddr));
	len = newsock->ops->getname(newsock, (struct sockaddr *)&peeraddr, 2);
	if (len < 0) {
		result = -ECONNABORTED;
		goto accept_err;
	}

	/* Get the new node's NODEID */
	make_sockaddr(&peeraddr, 0, &len);
	if (addr_to_nodeid(&peeraddr, &nodeid, &mark)) {
		/* unknown peer: dump its raw address and drop the socket */
		unsigned char *b=(unsigned char *)&peeraddr;
		log_print("connect from non cluster node");
		print_hex_dump_bytes("ss: ", DUMP_PREFIX_NONE,
				     b, sizeof(struct sockaddr_storage));
		sock_release(newsock);
		return -1;
	}

	log_print("got connection from %d", nodeid);

	/*  Check to see if we already have a connection to this node. This
	 *  could happen if the two nodes initiate a connection at roughly
	 *  the same time and the connections cross on the wire.
	 *  In this case we store the incoming one in "othercon"
	 */
	idx = srcu_read_lock(&connections_srcu);
	newcon = nodeid2con(nodeid, GFP_NOFS);
	if (!newcon) {
		srcu_read_unlock(&connections_srcu, idx);
		result = -ENOMEM;
		goto accept_err;
	}

	sock_set_mark(newsock->sk, mark);

	mutex_lock(&newcon->sock_mutex);
	if (newcon->sock) {
		struct connection *othercon = newcon->othercon;

		if (!othercon) {
			othercon = kzalloc(sizeof(*othercon), GFP_NOFS);
			if (!othercon) {
				log_print("failed to allocate incoming socket");
				mutex_unlock(&newcon->sock_mutex);
				srcu_read_unlock(&connections_srcu, idx);
				result = -ENOMEM;
				goto accept_err;
			}

			result = dlm_con_init(othercon, nodeid);
			if (result < 0) {
				kfree(othercon);
				mutex_unlock(&newcon->sock_mutex);
				srcu_read_unlock(&connections_srcu, idx);
				goto accept_err;
			}

			/*
			 * othercon's sock_mutex is taken below while newcon's
			 * is held, so give it its own lockdep subclass.
			 */
			lockdep_set_subclass(&othercon->sock_mutex, 1);
			set_bit(CF_IS_OTHERCON, &othercon->flags);
			newcon->othercon = othercon;
			othercon->sendcon = newcon;
		} else {
			/* close other sock con if we have something new */
			close_connection(othercon, false, true, false);
		}

		mutex_lock(&othercon->sock_mutex);
		add_sock(newsock, othercon);
		addcon = othercon;
		mutex_unlock(&othercon->sock_mutex);
	}
	else {
		/* accept copies the sk after we've saved the callbacks, so we
		   don't want to save them a second time or comm errors will
		   result in calling sk_error_report recursively. */
		add_sock(newsock, newcon);
		addcon = newcon;
	}

	set_bit(CF_CONNECTED, &addcon->flags);
	mutex_unlock(&newcon->sock_mutex);

	/*
	 * Add it to the active queue in case we got data
	 * between processing the accept adding the socket
	 * to the read_sockets list
	 */
	if (!test_and_set_bit(CF_READ_PENDING, &addcon->flags))
		queue_work(recv_workqueue, &addcon->rwork);

	srcu_read_unlock(&connections_srcu, idx);

	return 0;

accept_err:
	if (newsock)
		sock_release(newsock);

	if (result != -EAGAIN)
		log_print("error accepting connection from node: %d", result);
	return result;
}
1086
1087 /*
1088  * writequeue_entry_complete - try to delete and free write queue entry
1089  * @e: write queue entry to try to delete
1090  * @completed: bytes completed
1091  *
1092  * writequeue_lock must be held.
1093  */
1094 static void writequeue_entry_complete(struct writequeue_entry *e, int completed)
1095 {
1096         e->offset += completed;
1097         e->len -= completed;
1098         /* signal that page was half way transmitted */
1099         e->dirty = true;
1100
1101         if (e->len == 0 && e->users == 0)
1102                 free_entry(e);
1103 }
1104
1105 /*
1106  * sctp_bind_addrs - bind a SCTP socket to all our addresses
1107  */
1108 static int sctp_bind_addrs(struct socket *sock, uint16_t port)
1109 {
1110         struct sockaddr_storage localaddr;
1111         struct sockaddr *addr = (struct sockaddr *)&localaddr;
1112         int i, addr_len, result = 0;
1113
1114         for (i = 0; i < dlm_local_count; i++) {
1115                 memcpy(&localaddr, dlm_local_addr[i], sizeof(localaddr));
1116                 make_sockaddr(&localaddr, port, &addr_len);
1117
1118                 if (!i)
1119                         result = kernel_bind(sock, addr, addr_len);
1120                 else
1121                         result = sock_bind_add(sock->sk, addr, addr_len);
1122
1123                 if (result < 0) {
1124                         log_print("Can't bind to %d addr number %d, %d.\n",
1125                                   port, i + 1, result);
1126                         break;
1127                 }
1128         }
1129         return result;
1130 }
1131
1132 /* Initiate an SCTP association.
1133    This is a special case of send_to_sock() in that we don't yet have a
1134    peeled-off socket for this association, so we use the listening socket
1135    and add the primary IP address of the remote node.
1136  */
1137 static void sctp_connect_to_sock(struct connection *con)
1138 {
1139         struct sockaddr_storage daddr;
1140         int result;
1141         int addr_len;
1142         struct socket *sock;
1143         unsigned int mark;
1144
1145         mutex_lock(&con->sock_mutex);
1146
1147         /* Some odd races can cause double-connects, ignore them */
1148         if (con->retries++ > MAX_CONNECT_RETRIES)
1149                 goto out;
1150
1151         if (con->sock) {
1152                 log_print("node %d already connected.", con->nodeid);
1153                 goto out;
1154         }
1155
1156         memset(&daddr, 0, sizeof(daddr));
1157         result = nodeid_to_addr(con->nodeid, &daddr, NULL, true, &mark);
1158         if (result < 0) {
1159                 log_print("no address for nodeid %d", con->nodeid);
1160                 goto out;
1161         }
1162
1163         /* Create a socket to communicate with */
1164         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1165                                   SOCK_STREAM, IPPROTO_SCTP, &sock);
1166         if (result < 0)
1167                 goto socket_err;
1168
1169         sock_set_mark(sock->sk, mark);
1170
1171         add_sock(sock, con);
1172
1173         /* Bind to all addresses. */
1174         if (sctp_bind_addrs(con->sock, 0))
1175                 goto bind_err;
1176
1177         make_sockaddr(&daddr, dlm_config.ci_tcp_port, &addr_len);
1178
1179         log_print_ratelimited("connecting to %d", con->nodeid);
1180
1181         /* Turn off Nagle's algorithm */
1182         sctp_sock_set_nodelay(sock->sk);
1183
1184         /*
1185          * Make sock->ops->connect() function return in specified time,
1186          * since O_NONBLOCK argument in connect() function does not work here,
1187          * then, we should restore the default value of this attribute.
1188          */
1189         sock_set_sndtimeo(sock->sk, 5);
1190         result = sock->ops->connect(sock, (struct sockaddr *)&daddr, addr_len,
1191                                    0);
1192         sock_set_sndtimeo(sock->sk, 0);
1193
1194         if (result == -EINPROGRESS)
1195                 result = 0;
1196         if (result == 0) {
1197                 if (!test_and_set_bit(CF_CONNECTED, &con->flags))
1198                         log_print("successful connected to node %d", con->nodeid);
1199                 goto out;
1200         }
1201
1202 bind_err:
1203         con->sock = NULL;
1204         sock_release(sock);
1205
1206 socket_err:
1207         /*
1208          * Some errors are fatal and this list might need adjusting. For other
1209          * errors we try again until the max number of retries is reached.
1210          */
1211         if (result != -EHOSTUNREACH &&
1212             result != -ENETUNREACH &&
1213             result != -ENETDOWN &&
1214             result != -EINVAL &&
1215             result != -EPROTONOSUPPORT) {
1216                 log_print("connect %d try %d error %d", con->nodeid,
1217                           con->retries, result);
1218                 mutex_unlock(&con->sock_mutex);
1219                 msleep(1000);
1220                 lowcomms_connect_sock(con);
1221                 return;
1222         }
1223
1224 out:
1225         mutex_unlock(&con->sock_mutex);
1226 }
1227
/*
 * tcp_connect_to_sock - connect a new TCP socket to the peer con->nodeid
 * @con: connection to the remote node
 *
 * Creates a socket and binds it to our first local address; a bind failure
 * is only logged. It then starts a non-blocking connect to the node's address
 * on dlm_config.ci_tcp_port. -EINPROGRESS counts as success.
 *
 * Nothing is attempted once con->retries has exceeded MAX_CONNECT_RETRIES or
 * if @con already has a socket. On failure the socket is released. Errors
 * other than -EHOSTUNREACH, -ENETUNREACH, -ENETDOWN, -EINVAL and
 * -EPROTONOSUPPORT are retried after sleeping one second, via
 * lowcomms_connect_sock().
 */
static void tcp_connect_to_sock(struct connection *con)
{
	struct sockaddr_storage saddr, src_addr;
	unsigned int mark;
	int addr_len;
	struct socket *sock = NULL;
	int result;

	mutex_lock(&con->sock_mutex);
	/* retry budget exhausted */
	if (con->retries++ > MAX_CONNECT_RETRIES)
		goto out;

	/* Some odd races can cause double-connects, ignore them */
	if (con->sock)
		goto out;

	/* Create a socket to communicate with */
	result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
				  SOCK_STREAM, IPPROTO_TCP, &sock);
	if (result < 0)
		goto out_err;

	memset(&saddr, 0, sizeof(saddr));
	result = nodeid_to_addr(con->nodeid, &saddr, NULL, false, &mark);
	if (result < 0) {
		log_print("no address for nodeid %d", con->nodeid);
		goto out_err;
	}

	sock_set_mark(sock->sk, mark);

	/* from here on con->sock is set, so out_err releases it via con */
	add_sock(sock, con);

	/* Bind to our cluster-known address connecting to avoid
	   routing problems */
	memcpy(&src_addr, dlm_local_addr[0], sizeof(src_addr));
	make_sockaddr(&src_addr, 0, &addr_len);
	result = sock->ops->bind(sock, (struct sockaddr *) &src_addr,
				 addr_len);
	if (result < 0) {
		log_print("could not bind for connect: %d", result);
		/* This *may* not indicate a critical error */
	}

	make_sockaddr(&saddr, dlm_config.ci_tcp_port, &addr_len);

	log_print_ratelimited("connecting to %d", con->nodeid);

	/* Turn off Nagle's algorithm */
	tcp_sock_set_nodelay(sock->sk);

	result = sock->ops->connect(sock, (struct sockaddr *)&saddr, addr_len,
				   O_NONBLOCK);
	if (result == -EINPROGRESS)
		result = 0;
	if (result == 0)
		goto out;

out_err:
	/*
	 * Release whichever socket exists: the one attached to con after
	 * add_sock(), or the local one if we failed before attaching it.
	 */
	if (con->sock) {
		sock_release(con->sock);
		con->sock = NULL;
	} else if (sock) {
		sock_release(sock);
	}
	/*
	 * Some errors are fatal and this list might need adjusting. For other
	 * errors we try again until the max number of retries is reached.
	 */
	if (result != -EHOSTUNREACH &&
	    result != -ENETUNREACH &&
	    result != -ENETDOWN &&
	    result != -EINVAL &&
	    result != -EPROTONOSUPPORT) {
		log_print("connect %d try %d error %d", con->nodeid,
			  con->retries, result);
		mutex_unlock(&con->sock_mutex);
		msleep(1000);
		lowcomms_connect_sock(con);
		return;
	}
out:
	mutex_unlock(&con->sock_mutex);
	return;
}
1314
1315 /* On error caller must run dlm_close_sock() for the
1316  * listen connection socket.
1317  */
1318 static int tcp_create_listen_sock(struct listen_connection *con,
1319                                   struct sockaddr_storage *saddr)
1320 {
1321         struct socket *sock = NULL;
1322         int result = 0;
1323         int addr_len;
1324
1325         if (dlm_local_addr[0]->ss_family == AF_INET)
1326                 addr_len = sizeof(struct sockaddr_in);
1327         else
1328                 addr_len = sizeof(struct sockaddr_in6);
1329
1330         /* Create a socket to communicate with */
1331         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1332                                   SOCK_STREAM, IPPROTO_TCP, &sock);
1333         if (result < 0) {
1334                 log_print("Can't create listening comms socket");
1335                 goto create_out;
1336         }
1337
1338         sock_set_mark(sock->sk, dlm_config.ci_mark);
1339
1340         /* Turn off Nagle's algorithm */
1341         tcp_sock_set_nodelay(sock->sk);
1342
1343         sock_set_reuseaddr(sock->sk);
1344
1345         add_listen_sock(sock, con);
1346
1347         /* Bind to our port */
1348         make_sockaddr(saddr, dlm_config.ci_tcp_port, &addr_len);
1349         result = sock->ops->bind(sock, (struct sockaddr *) saddr, addr_len);
1350         if (result < 0) {
1351                 log_print("Can't bind to port %d", dlm_config.ci_tcp_port);
1352                 goto create_out;
1353         }
1354         sock_set_keepalive(sock->sk);
1355
1356         result = sock->ops->listen(sock, 5);
1357         if (result < 0) {
1358                 log_print("Can't listen on port %d", dlm_config.ci_tcp_port);
1359                 goto create_out;
1360         }
1361
1362         return 0;
1363
1364 create_out:
1365         return result;
1366 }
1367
1368 /* Get local addresses */
1369 static void init_local(void)
1370 {
1371         struct sockaddr_storage sas, *addr;
1372         int i;
1373
1374         dlm_local_count = 0;
1375         for (i = 0; i < DLM_MAX_ADDR_COUNT; i++) {
1376                 if (dlm_our_addr(&sas, i))
1377                         break;
1378
1379                 addr = kmemdup(&sas, sizeof(*addr), GFP_NOFS);
1380                 if (!addr)
1381                         break;
1382                 dlm_local_addr[dlm_local_count++] = addr;
1383         }
1384 }
1385
1386 static void deinit_local(void)
1387 {
1388         int i;
1389
1390         for (i = 0; i < dlm_local_count; i++)
1391                 kfree(dlm_local_addr[i]);
1392 }
1393
1394 /* Initialise SCTP socket and bind to all interfaces
1395  * On error caller must run dlm_close_sock() for the
1396  * listen connection socket.
1397  */
1398 static int sctp_listen_for_all(struct listen_connection *con)
1399 {
1400         struct socket *sock = NULL;
1401         int result = -EINVAL;
1402
1403         log_print("Using SCTP for communications");
1404
1405         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1406                                   SOCK_STREAM, IPPROTO_SCTP, &sock);
1407         if (result < 0) {
1408                 log_print("Can't create comms socket, check SCTP is loaded");
1409                 goto out;
1410         }
1411
1412         sock_set_rcvbuf(sock->sk, NEEDED_RMEM);
1413         sock_set_mark(sock->sk, dlm_config.ci_mark);
1414         sctp_sock_set_nodelay(sock->sk);
1415
1416         add_listen_sock(sock, con);
1417
1418         /* Bind to all addresses. */
1419         result = sctp_bind_addrs(con->sock, dlm_config.ci_tcp_port);
1420         if (result < 0)
1421                 goto out;
1422
1423         result = sock->ops->listen(sock, 5);
1424         if (result < 0) {
1425                 log_print("Can't set socket listening");
1426                 goto out;
1427         }
1428
1429         return 0;
1430
1431 out:
1432         return result;
1433 }
1434
1435 static int tcp_listen_for_all(void)
1436 {
1437         /* We don't support multi-homed hosts */
1438         if (dlm_local_count > 1) {
1439                 log_print("TCP protocol can't handle multi-homed hosts, "
1440                           "try SCTP");
1441                 return -EINVAL;
1442         }
1443
1444         log_print("Using TCP for communications");
1445
1446         return tcp_create_listen_sock(&listen_con, dlm_local_addr[0]);
1447 }
1448
1449
1450
1451 static struct writequeue_entry *new_writequeue_entry(struct connection *con,
1452                                                      gfp_t allocation)
1453 {
1454         struct writequeue_entry *entry;
1455
1456         entry = kzalloc(sizeof(*entry), allocation);
1457         if (!entry)
1458                 return NULL;
1459
1460         entry->page = alloc_page(allocation | __GFP_ZERO);
1461         if (!entry->page) {
1462                 kfree(entry);
1463                 return NULL;
1464         }
1465
1466         entry->con = con;
1467         entry->users = 1;
1468         kref_init(&entry->ref);
1469         INIT_LIST_HEAD(&entry->msgs);
1470
1471         return entry;
1472 }
1473
1474 static struct writequeue_entry *new_wq_entry(struct connection *con, int len,
1475                                              gfp_t allocation, char **ppc,
1476                                              void (*cb)(struct dlm_mhandle *mh),
1477                                              struct dlm_mhandle *mh)
1478 {
1479         struct writequeue_entry *e;
1480
1481         spin_lock(&con->writequeue_lock);
1482         if (!list_empty(&con->writequeue)) {
1483                 e = list_last_entry(&con->writequeue, struct writequeue_entry, list);
1484                 if (DLM_WQ_REMAIN_BYTES(e) >= len) {
1485                         kref_get(&e->ref);
1486
1487                         *ppc = page_address(e->page) + e->end;
1488                         if (cb)
1489                                 cb(mh);
1490
1491                         e->end += len;
1492                         e->users++;
1493                         spin_unlock(&con->writequeue_lock);
1494
1495                         return e;
1496                 }
1497         }
1498         spin_unlock(&con->writequeue_lock);
1499
1500         e = new_writequeue_entry(con, allocation);
1501         if (!e)
1502                 return NULL;
1503
1504         kref_get(&e->ref);
1505         *ppc = page_address(e->page);
1506         e->end += len;
1507         atomic_inc(&con->writequeue_cnt);
1508
1509         spin_lock(&con->writequeue_lock);
1510         if (cb)
1511                 cb(mh);
1512
1513         list_add_tail(&e->list, &con->writequeue);
1514         spin_unlock(&con->writequeue_lock);
1515
1516         return e;
1517 };
1518
1519 static struct dlm_msg *dlm_lowcomms_new_msg_con(struct connection *con, int len,
1520                                                 gfp_t allocation, char **ppc,
1521                                                 void (*cb)(struct dlm_mhandle *mh),
1522                                                 struct dlm_mhandle *mh)
1523 {
1524         struct writequeue_entry *e;
1525         struct dlm_msg *msg;
1526
1527         msg = kzalloc(sizeof(*msg), allocation);
1528         if (!msg)
1529                 return NULL;
1530
1531         kref_init(&msg->ref);
1532
1533         e = new_wq_entry(con, len, allocation, ppc, cb, mh);
1534         if (!e) {
1535                 kfree(msg);
1536                 return NULL;
1537         }
1538
1539         msg->ppc = *ppc;
1540         msg->len = len;
1541         msg->entry = e;
1542
1543         return msg;
1544 }
1545
1546 struct dlm_msg *dlm_lowcomms_new_msg(int nodeid, int len, gfp_t allocation,
1547                                      char **ppc, void (*cb)(struct dlm_mhandle *mh),
1548                                      struct dlm_mhandle *mh)
1549 {
1550         struct connection *con;
1551         struct dlm_msg *msg;
1552         int idx;
1553
1554         if (len > DEFAULT_BUFFER_SIZE ||
1555             len < sizeof(struct dlm_header)) {
1556                 BUILD_BUG_ON(PAGE_SIZE < DEFAULT_BUFFER_SIZE);
1557                 log_print("failed to allocate a buffer of size %d", len);
1558                 WARN_ON(1);
1559                 return NULL;
1560         }
1561
1562         idx = srcu_read_lock(&connections_srcu);
1563         con = nodeid2con(nodeid, allocation);
1564         if (!con) {
1565                 srcu_read_unlock(&connections_srcu, idx);
1566                 return NULL;
1567         }
1568
1569         msg = dlm_lowcomms_new_msg_con(con, len, allocation, ppc, cb, mh);
1570         if (!msg) {
1571                 srcu_read_unlock(&connections_srcu, idx);
1572                 return NULL;
1573         }
1574
1575         /* we assume if successful commit must called */
1576         msg->idx = idx;
1577         return msg;
1578 }
1579
1580 static void _dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1581 {
1582         struct writequeue_entry *e = msg->entry;
1583         struct connection *con = e->con;
1584         int users;
1585
1586         spin_lock(&con->writequeue_lock);
1587         kref_get(&msg->ref);
1588         list_add(&msg->list, &e->msgs);
1589
1590         users = --e->users;
1591         if (users)
1592                 goto out;
1593
1594         e->len = DLM_WQ_LENGTH_BYTES(e);
1595         spin_unlock(&con->writequeue_lock);
1596
1597         queue_work(send_workqueue, &con->swork);
1598         return;
1599
1600 out:
1601         spin_unlock(&con->writequeue_lock);
1602         return;
1603 }
1604
1605 void dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1606 {
1607         _dlm_lowcomms_commit_msg(msg);
1608         srcu_read_unlock(&connections_srcu, msg->idx);
1609 }
1610
1611 void dlm_lowcomms_put_msg(struct dlm_msg *msg)
1612 {
1613         kref_put(&msg->ref, dlm_msg_release);
1614 }
1615
1616 /* does not held connections_srcu, usage workqueue only */
1617 int dlm_lowcomms_resend_msg(struct dlm_msg *msg)
1618 {
1619         struct dlm_msg *msg_resend;
1620         char *ppc;
1621
1622         if (msg->retransmit)
1623                 return 1;
1624
1625         msg_resend = dlm_lowcomms_new_msg_con(msg->entry->con, msg->len,
1626                                               GFP_ATOMIC, &ppc, NULL, NULL);
1627         if (!msg_resend)
1628                 return -ENOMEM;
1629
1630         msg->retransmit = true;
1631         kref_get(&msg->ref);
1632         msg_resend->orig_msg = msg;
1633
1634         memcpy(ppc, msg->ppc, msg->len);
1635         _dlm_lowcomms_commit_msg(msg_resend);
1636         dlm_lowcomms_put_msg(msg_resend);
1637
1638         return 0;
1639 }
1640
/*
 * send_to_sock - transmit queued write queue entries on @con's socket
 * @con: connection to send on
 *
 * Sends entries from the head of con->writequeue with kernel_sendpage()
 * until the queue is empty or the socket stops accepting data. If @con has
 * no socket, the send work is requeued instead.
 *
 * When the queue has been drained and CF_EOF is set (by the receive path
 * when the peer closed), the connection is closed and shutdown_wait waiters
 * are woken.
 */
static void send_to_sock(struct connection *con)
{
	int ret = 0;
	const int msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
	struct writequeue_entry *e;
	int len, offset;
	int count = 0;

	mutex_lock(&con->sock_mutex);
	if (con->sock == NULL)
		goto out_connect;

	spin_lock(&con->writequeue_lock);
	for (;;) {
		if (list_empty(&con->writequeue))
			break;

		e = list_first_entry(&con->writequeue, struct writequeue_entry, list);
		len = e->len;
		offset = e->offset;
		/* an entry with nothing to send must still have a writer */
		BUG_ON(len == 0 && e->users == 0);
		/* the spinlock is not held across the (possibly slow) send */
		spin_unlock(&con->writequeue_lock);

		/*
		 * len == 0 means a writer still holds the entry: nothing is
		 * sent and, since the entry is not freed, the loop revisits it
		 * until _dlm_lowcomms_commit_msg() sets its length.
		 */
		ret = 0;
		if (len) {
			ret = kernel_sendpage(con->sock, e->page, offset, len,
					      msg_flags);
			if (ret == -EAGAIN || ret == 0) {
				if (ret == -EAGAIN &&
				    test_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags) &&
				    !test_and_set_bit(CF_APP_LIMITED, &con->flags)) {
					/* Notify TCP that we're limited by the
					 * application window size.
					 */
					set_bit(SOCK_NOSPACE, &con->sock->flags);
					con->sock->sk->sk_write_pending++;
				}
				cond_resched();
				goto out;
			} else if (ret < 0)
				goto out;
		}

		/* Don't starve people filling buffers */
		if (++count >= MAX_SEND_MSG_COUNT) {
			cond_resched();
			count = 0;
		}

		/* account the sent bytes; frees the entry once fully sent */
		spin_lock(&con->writequeue_lock);
		writequeue_entry_complete(e, ret);
	}
	spin_unlock(&con->writequeue_lock);

	/* close if we got EOF */
	if (test_and_clear_bit(CF_EOF, &con->flags)) {
		mutex_unlock(&con->sock_mutex);
		close_connection(con, false, false, true);

		/* handling for tcp shutdown */
		clear_bit(CF_SHUTDOWN, &con->flags);
		wake_up(&con->shutdown_wait);
	} else {
		mutex_unlock(&con->sock_mutex);
	}

	return;

out:
	mutex_unlock(&con->sock_mutex);
	return;

out_connect:
	/* no socket yet: requeue the send work */
	mutex_unlock(&con->sock_mutex);
	queue_work(send_workqueue, &con->swork);
	cond_resched();
}
1719
1720 static void clean_one_writequeue(struct connection *con)
1721 {
1722         struct writequeue_entry *e, *safe;
1723
1724         spin_lock(&con->writequeue_lock);
1725         list_for_each_entry_safe(e, safe, &con->writequeue, list) {
1726                 free_entry(e);
1727         }
1728         spin_unlock(&con->writequeue_lock);
1729 }
1730
1731 /* Called from recovery when it knows that a node has
1732    left the cluster */
1733 int dlm_lowcomms_close(int nodeid)
1734 {
1735         struct connection *con;
1736         struct dlm_node_addr *na;
1737         int idx;
1738
1739         log_print("closing connection to node %d", nodeid);
1740         idx = srcu_read_lock(&connections_srcu);
1741         con = nodeid2con(nodeid, 0);
1742         if (con) {
1743                 set_bit(CF_CLOSE, &con->flags);
1744                 close_connection(con, true, true, true);
1745                 clean_one_writequeue(con);
1746                 if (con->othercon)
1747                         clean_one_writequeue(con->othercon);
1748         }
1749         srcu_read_unlock(&connections_srcu, idx);
1750
1751         spin_lock(&dlm_node_addrs_spin);
1752         na = find_node_addr(nodeid);
1753         if (na) {
1754                 list_del(&na->list);
1755                 while (na->addr_count--)
1756                         kfree(na->addr[na->addr_count]);
1757                 kfree(na);
1758         }
1759         spin_unlock(&dlm_node_addrs_spin);
1760
1761         return 0;
1762 }
1763
1764 /* Receive workqueue function */
1765 static void process_recv_sockets(struct work_struct *work)
1766 {
1767         struct connection *con = container_of(work, struct connection, rwork);
1768         int err;
1769
1770         clear_bit(CF_READ_PENDING, &con->flags);
1771         do {
1772                 err = receive_from_sock(con);
1773         } while (!err);
1774 }
1775
1776 static void process_listen_recv_socket(struct work_struct *work)
1777 {
1778         accept_from_sock(&listen_con);
1779 }
1780
1781 /* Send workqueue function */
1782 static void process_send_sockets(struct work_struct *work)
1783 {
1784         struct connection *con = container_of(work, struct connection, swork);
1785
1786         WARN_ON(test_bit(CF_IS_OTHERCON, &con->flags));
1787
1788         clear_bit(CF_WRITE_PENDING, &con->flags);
1789
1790         if (test_and_clear_bit(CF_RECONNECT, &con->flags)) {
1791                 close_connection(con, false, false, true);
1792                 dlm_midcomms_unack_msg_resend(con->nodeid);
1793         }
1794
1795         if (con->sock == NULL) { /* not mutex protected so check it inside too */
1796                 if (test_and_clear_bit(CF_DELAY_CONNECT, &con->flags))
1797                         msleep(1000);
1798                 con->connect_action(con);
1799         }
1800         if (!list_empty(&con->writequeue))
1801                 send_to_sock(con);
1802 }
1803
1804 static void work_stop(void)
1805 {
1806         if (recv_workqueue) {
1807                 destroy_workqueue(recv_workqueue);
1808                 recv_workqueue = NULL;
1809         }
1810
1811         if (send_workqueue) {
1812                 destroy_workqueue(send_workqueue);
1813                 send_workqueue = NULL;
1814         }
1815 }
1816
1817 static int work_start(void)
1818 {
1819         recv_workqueue = alloc_workqueue("dlm_recv",
1820                                          WQ_UNBOUND | WQ_MEM_RECLAIM, 1);
1821         if (!recv_workqueue) {
1822                 log_print("can't start dlm_recv");
1823                 return -ENOMEM;
1824         }
1825
1826         send_workqueue = alloc_workqueue("dlm_send",
1827                                          WQ_UNBOUND | WQ_MEM_RECLAIM, 1);
1828         if (!send_workqueue) {
1829                 log_print("can't start dlm_send");
1830                 destroy_workqueue(recv_workqueue);
1831                 recv_workqueue = NULL;
1832                 return -ENOMEM;
1833         }
1834
1835         return 0;
1836 }
1837
1838 static void shutdown_conn(struct connection *con)
1839 {
1840         if (con->shutdown_action)
1841                 con->shutdown_action(con);
1842 }
1843
1844 void dlm_lowcomms_shutdown(void)
1845 {
1846         int idx;
1847
1848         /* Set all the flags to prevent any
1849          * socket activity.
1850          */
1851         dlm_allow_conn = 0;
1852
1853         if (recv_workqueue)
1854                 flush_workqueue(recv_workqueue);
1855         if (send_workqueue)
1856                 flush_workqueue(send_workqueue);
1857
1858         dlm_close_sock(&listen_con.sock);
1859
1860         idx = srcu_read_lock(&connections_srcu);
1861         foreach_conn(shutdown_conn);
1862         srcu_read_unlock(&connections_srcu, idx);
1863 }
1864
1865 static void _stop_conn(struct connection *con, bool and_other)
1866 {
1867         mutex_lock(&con->sock_mutex);
1868         set_bit(CF_CLOSE, &con->flags);
1869         set_bit(CF_READ_PENDING, &con->flags);
1870         set_bit(CF_WRITE_PENDING, &con->flags);
1871         if (con->sock && con->sock->sk) {
1872                 write_lock_bh(&con->sock->sk->sk_callback_lock);
1873                 con->sock->sk->sk_user_data = NULL;
1874                 write_unlock_bh(&con->sock->sk->sk_callback_lock);
1875         }
1876         if (con->othercon && and_other)
1877                 _stop_conn(con->othercon, false);
1878         mutex_unlock(&con->sock_mutex);
1879 }
1880
1881 static void stop_conn(struct connection *con)
1882 {
1883         _stop_conn(con, true);
1884 }
1885
1886 static void connection_release(struct rcu_head *rcu)
1887 {
1888         struct connection *con = container_of(rcu, struct connection, rcu);
1889
1890         kfree(con->rx_buf);
1891         kfree(con);
1892 }
1893
1894 static void free_conn(struct connection *con)
1895 {
1896         close_connection(con, true, true, true);
1897         spin_lock(&connections_lock);
1898         hlist_del_rcu(&con->list);
1899         spin_unlock(&connections_lock);
1900         if (con->othercon) {
1901                 clean_one_writequeue(con->othercon);
1902                 call_srcu(&connections_srcu, &con->othercon->rcu,
1903                           connection_release);
1904         }
1905         clean_one_writequeue(con);
1906         call_srcu(&connections_srcu, &con->rcu, connection_release);
1907 }
1908
1909 static void work_flush(void)
1910 {
1911         int ok;
1912         int i;
1913         struct connection *con;
1914
1915         do {
1916                 ok = 1;
1917                 foreach_conn(stop_conn);
1918                 if (recv_workqueue)
1919                         flush_workqueue(recv_workqueue);
1920                 if (send_workqueue)
1921                         flush_workqueue(send_workqueue);
1922                 for (i = 0; i < CONN_HASH_SIZE && ok; i++) {
1923                         hlist_for_each_entry_rcu(con, &connection_hash[i],
1924                                                  list) {
1925                                 ok &= test_bit(CF_READ_PENDING, &con->flags);
1926                                 ok &= test_bit(CF_WRITE_PENDING, &con->flags);
1927                                 if (con->othercon) {
1928                                         ok &= test_bit(CF_READ_PENDING,
1929                                                        &con->othercon->flags);
1930                                         ok &= test_bit(CF_WRITE_PENDING,
1931                                                        &con->othercon->flags);
1932                                 }
1933                         }
1934                 }
1935         } while (!ok);
1936 }
1937
1938 void dlm_lowcomms_stop(void)
1939 {
1940         int idx;
1941
1942         idx = srcu_read_lock(&connections_srcu);
1943         work_flush();
1944         foreach_conn(free_conn);
1945         srcu_read_unlock(&connections_srcu, idx);
1946         work_stop();
1947         deinit_local();
1948 }
1949
1950 int dlm_lowcomms_start(void)
1951 {
1952         int error = -EINVAL;
1953         int i;
1954
1955         for (i = 0; i < CONN_HASH_SIZE; i++)
1956                 INIT_HLIST_HEAD(&connection_hash[i]);
1957
1958         init_local();
1959         if (!dlm_local_count) {
1960                 error = -ENOTCONN;
1961                 log_print("no local IP address has been set");
1962                 goto fail;
1963         }
1964
1965         INIT_WORK(&listen_con.rwork, process_listen_recv_socket);
1966
1967         error = work_start();
1968         if (error)
1969                 goto fail_local;
1970
1971         dlm_allow_conn = 1;
1972
1973         /* Start listening */
1974         if (dlm_config.ci_protocol == 0)
1975                 error = tcp_listen_for_all();
1976         else
1977                 error = sctp_listen_for_all(&listen_con);
1978         if (error)
1979                 goto fail_unlisten;
1980
1981         return 0;
1982
1983 fail_unlisten:
1984         dlm_allow_conn = 0;
1985         dlm_close_sock(&listen_con.sock);
1986         work_stop();
1987 fail_local:
1988         deinit_local();
1989 fail:
1990         return error;
1991 }
1992
1993 void dlm_lowcomms_exit(void)
1994 {
1995         struct dlm_node_addr *na, *safe;
1996
1997         spin_lock(&dlm_node_addrs_spin);
1998         list_for_each_entry_safe(na, safe, &dlm_node_addrs, list) {
1999                 list_del(&na->list);
2000                 while (na->addr_count--)
2001                         kfree(na->addr[na->addr_count]);
2002                 kfree(na);
2003         }
2004         spin_unlock(&dlm_node_addrs_spin);
2005 }