fs: dlm: introduce con_next_wq helper
[linux-2.6-microblaze.git] / fs / dlm / lowcomms.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /******************************************************************************
3 *******************************************************************************
4 **
5 **  Copyright (C) Sistina Software, Inc.  1997-2003  All rights reserved.
6 **  Copyright (C) 2004-2009 Red Hat, Inc.  All rights reserved.
7 **
8 **
9 *******************************************************************************
10 ******************************************************************************/
11
12 /*
13  * lowcomms.c
14  *
15  * This is the "low-level" comms layer.
16  *
17  * It is responsible for sending/receiving messages
18  * from other nodes in the cluster.
19  *
20  * Cluster nodes are referred to by their nodeids. nodeids are
21  * simply 32 bit numbers to the locking module - if they need to
22  * be expanded for the cluster infrastructure then that is its
23  * responsibility. It is this layer's
24  * responsibility to resolve these into IP address or
25  * whatever it needs for inter-node communication.
26  *
 * The comms level is two workqueues: a receive workqueue that deals
 * mainly with receiving messages from other nodes and passing them
 * up to the mid-level comms layer (which understands the
 * message format) for execution by the locking core, and
 * a send workqueue which does all the setting up of connections
 * to remote nodes and the sending of data. Callers are not allowed
 * to send their own data because it may cause them to wait in times
 * of high load. Also, this way, the send work can collect together
 * messages bound for one node and send them in one block.
36  *
37  * lowcomms will choose to use either TCP or SCTP as its transport layer
38  * depending on the configuration variable 'protocol'. This should be set
39  * to 0 (default) for TCP or 1 for SCTP. It should be configured using a
40  * cluster-wide mechanism as it must be the same on all nodes of the cluster
41  * for the DLM to function.
42  *
43  */
44
45 #include <asm/ioctls.h>
46 #include <net/sock.h>
47 #include <net/tcp.h>
48 #include <linux/pagemap.h>
49 #include <linux/file.h>
50 #include <linux/mutex.h>
51 #include <linux/sctp.h>
52 #include <linux/slab.h>
53 #include <net/sctp/sctp.h>
54 #include <net/ipv6.h>
55
56 #include "dlm_internal.h"
57 #include "lowcomms.h"
58 #include "midcomms.h"
59 #include "config.h"
60
/* NOTE(review): not referenced in the visible part of this file;
 * presumably the receive buffer size wanted on dlm sockets — confirm
 * at its users.
 */
#define NEEDED_RMEM (4*1024*1024)

/* Number of messages to send before rescheduling */
#define MAX_SEND_MSG_COUNT 25
/* How long shutdown_connection() waits for a graceful shutdown before
 * it force-closes the connection.
 */
#define DLM_SHUTDOWN_WAIT_TIMEOUT msecs_to_jiffies(10000)

/*
 * Per-node connection state.  nodeid2con() keeps at most one of these
 * per nodeid in connection_hash.
 */
struct connection {
	struct socket *sock;	/* NULL if not connected */
	uint32_t nodeid;	/* So we know who we are in the list */
	struct mutex sock_mutex;	/* held while ->sock is used or replaced */
	/*
	 * Bit numbers for ->flags, used with the *_bit() helpers:
	 * CF_READ_PENDING   rwork queued by lowcomms_data_ready()
	 * CF_WRITE_PENDING  cleared when close_connection() cancels swork
	 * CF_IS_OTHERCON    this connection is another one's othercon
	 * CF_CLOSE          lowcomms_connect_sock() no longer queues swork
	 * CF_APP_LIMITED    undone by lowcomms_write_space()
	 * CF_CLOSING        close_connection() is in progress
	 * CF_SHUTDOWN       shutdown_connection() waits for this to clear
	 * CF_CONNECTED      set on the first write-space event after connect
	 * CF_RECONNECT      reconnect requested by lowcomms_error_report()
	 * CF_DELAY_CONNECT  the peer refused the connection (ECONNREFUSED)
	 * CF_INIT_PENDING, CF_EOF: users not in view — NOTE(review) confirm
	 */
	unsigned long flags;
#define CF_READ_PENDING 1
#define CF_WRITE_PENDING 2
#define CF_INIT_PENDING 4
#define CF_IS_OTHERCON 5
#define CF_CLOSE 6
#define CF_APP_LIMITED 7
#define CF_CLOSING 8
#define CF_SHUTDOWN 9
#define CF_CONNECTED 10
#define CF_RECONNECT 11
#define CF_DELAY_CONNECT 12
#define CF_EOF 13
	struct list_head writequeue;  /* List of outgoing writequeue_entries */
	spinlock_t writequeue_lock;	/* protects writequeue */
	atomic_t writequeue_cnt;	/* decremented by free_entry(); tested by tcp_eof_condition() */
	void (*connect_action) (struct connection *);	/* What to do to connect */
	void (*shutdown_action)(struct connection *con); /* What to do to shutdown */
	bool (*eof_condition)(struct connection *con); /* What to do to eof check */
	int retries;	/* reset by close_connection() */
#define MAX_CONNECT_RETRIES 3
	struct hlist_node list;	/* chain in connection_hash[] */
	/* NOTE(review): presumably a second connection to the same node;
	 * it is closed / shut down together with this one.
	 */
	struct connection *othercon;
	/* on an othercon: the connection lowcomms_error_report() acts on */
	struct connection *sendcon;
	struct work_struct rwork; /* Receive work item, runs on recv_workqueue */
	struct work_struct swork; /* Send work item, runs on send_workqueue */
	wait_queue_head_t shutdown_wait; /* wait for graceful shutdown */
	unsigned char *rx_buf;	/* receive buffer of rx_buflen bytes */
	int rx_buflen;
	int rx_leftover;	/* unprocessed bytes kept at the start of rx_buf */
	struct rcu_head rcu;
};
/* Connection attached to a socket through sk_user_data (see add_sock()) */
#define sock2con(x) ((struct connection *)(x)->sk_user_data)
104
/* The single listening socket and its receive work (see listen_con). */
struct listen_connection {
	struct socket *sock;
	struct work_struct rwork;	/* queued by lowcomms_listen_data_ready() */
};

/* Bytes of the entry's page that are still free after ->end. */
#define DLM_WQ_REMAIN_BYTES(e) (PAGE_SIZE - e->end)
/* Bytes between ->offset and ->end of the entry.
 * NOTE(review): presumably the not-yet-sent part — confirm at users.
 */
#define DLM_WQ_LENGTH_BYTES(e) (e->end - e->offset)

/* An entry waiting to be sent */
struct writequeue_entry {
	struct list_head list;	/* link in connection->writequeue */
	struct page *page;	/* freed by dlm_page_release() */
	int offset;
	int len;	/* con_next_wq() returns NULL for a head entry with len 0 */
	int end;
	int users;
	bool dirty;	/* close_connection() drops the head entry if set */
	struct connection *con;	/* owning connection */
	struct list_head msgs;	/* dlm_msg list, released by free_entry() */
	struct kref ref;	/* final put runs dlm_page_release() */
};
126
/* A message carried by a writequeue_entry. */
struct dlm_msg {
	struct writequeue_entry *entry;	/* reference dropped in dlm_msg_release() */
	/* when non-NULL, presumably the message this one retransmits;
	 * free_entry() clears its ->retransmit and drops a reference
	 */
	struct dlm_msg *orig_msg;
	bool retransmit;
	void *ppc;
	int len;
	int idx; /* new()/commit() idx exchange */

	struct list_head list;	/* link in writequeue_entry->msgs */
	struct kref ref;	/* final put runs dlm_msg_release() */
};

/* Configured addresses of one cluster node, linked on dlm_node_addrs. */
struct dlm_node_addr {
	struct list_head list;
	int nodeid;
	int mark;
	int addr_count;	/* number of valid entries in addr[] */
	int curr_addr_index;	/* address nodeid_to_addr() hands out next */
	struct sockaddr_storage *addr[DLM_MAX_ADDR_COUNT];
};

/*
 * Callbacks saved from the listening socket by save_listen_callbacks();
 * restore_callbacks() installs them on every socket dlm_close_sock()
 * closes, and lowcomms_error_report() chains to sk_error_report.
 */
static struct listen_sock_callbacks {
	void (*sk_error_report)(struct sock *);
	void (*sk_data_ready)(struct sock *);
	void (*sk_state_change)(struct sock *);
	void (*sk_write_space)(struct sock *);
} listen_sock;
154
/* Configured node addresses; accessed under dlm_node_addrs_spin. */
static LIST_HEAD(dlm_node_addrs);
static DEFINE_SPINLOCK(dlm_node_addrs_spin);

static struct listen_connection listen_con;
static struct sockaddr_storage *dlm_local_addr[DLM_MAX_ADDR_COUNT];
static int dlm_local_count;	/* nodeid_to_addr() fails while this is 0 */
int dlm_allow_conn;	/* lowcomms_listen_data_ready() ignores events while 0 */

/* Work queues */
static struct workqueue_struct *recv_workqueue;
static struct workqueue_struct *send_workqueue;

/*
 * Connections hashed by nodeid_hash(nodeid).  Insertions take
 * connections_lock; lookups walk the chains with the RCU list helpers
 * (dlm_lowcomms_connect_node() holds connections_srcu for this).
 */
static struct hlist_head connection_hash[CONN_HASH_SIZE];
static DEFINE_SPINLOCK(connections_lock);
DEFINE_STATIC_SRCU(connections_srcu);

static void process_recv_sockets(struct work_struct *work);
static void process_send_sockets(struct work_struct *work);

static void sctp_connect_to_sock(struct connection *con);
static void tcp_connect_to_sock(struct connection *con);
static void dlm_tcp_shutdown(struct connection *con);
177
178 /* need to held writequeue_lock */
179 static struct writequeue_entry *con_next_wq(struct connection *con)
180 {
181         struct writequeue_entry *e;
182
183         if (list_empty(&con->writequeue))
184                 return NULL;
185
186         e = list_first_entry(&con->writequeue, struct writequeue_entry,
187                              list);
188         if (e->len == 0)
189                 return NULL;
190
191         return e;
192 }
193
194 static struct connection *__find_con(int nodeid, int r)
195 {
196         struct connection *con;
197
198         hlist_for_each_entry_rcu(con, &connection_hash[r], list) {
199                 if (con->nodeid == nodeid)
200                         return con;
201         }
202
203         return NULL;
204 }
205
206 static bool tcp_eof_condition(struct connection *con)
207 {
208         return atomic_read(&con->writequeue_cnt);
209 }
210
211 static int dlm_con_init(struct connection *con, int nodeid)
212 {
213         con->rx_buflen = dlm_config.ci_buffer_size;
214         con->rx_buf = kmalloc(con->rx_buflen, GFP_NOFS);
215         if (!con->rx_buf)
216                 return -ENOMEM;
217
218         con->nodeid = nodeid;
219         mutex_init(&con->sock_mutex);
220         INIT_LIST_HEAD(&con->writequeue);
221         spin_lock_init(&con->writequeue_lock);
222         atomic_set(&con->writequeue_cnt, 0);
223         INIT_WORK(&con->swork, process_send_sockets);
224         INIT_WORK(&con->rwork, process_recv_sockets);
225         init_waitqueue_head(&con->shutdown_wait);
226
227         switch (dlm_config.ci_protocol) {
228         case DLM_PROTO_TCP:
229                 con->connect_action = tcp_connect_to_sock;
230                 con->shutdown_action = dlm_tcp_shutdown;
231                 con->eof_condition = tcp_eof_condition;
232                 break;
233         case DLM_PROTO_SCTP:
234                 con->connect_action = sctp_connect_to_sock;
235                 break;
236         default:
237                 kfree(con->rx_buf);
238                 return -EINVAL;
239         }
240
241         return 0;
242 }
243
/*
 * Return the connection for @nodeid, creating and hashing a new one
 * (allocated with the gfp flags @alloc) if none exists yet.
 *
 * If @alloc is zero then we don't attempt to create a new connection
 * structure for this node, and NULL is returned when none exists.
 * NULL is also returned if allocation or dlm_con_init() fails.
 *
 * NOTE(review): the lockless __find_con() walk relies on the caller's
 * SRCU read side; the visible caller dlm_lowcomms_connect_node() holds
 * connections_srcu — confirm for other callers.
 */
static struct connection *nodeid2con(int nodeid, gfp_t alloc)
{
	struct connection *con, *tmp;
	int r, ret;

	r = nodeid_hash(nodeid);
	con = __find_con(nodeid, r);
	if (con || !alloc)
		return con;

	con = kzalloc(sizeof(*con), alloc);
	if (!con)
		return NULL;

	ret = dlm_con_init(con, nodeid);
	if (ret) {
		kfree(con);
		return NULL;
	}

	spin_lock(&connections_lock);
	/* Because multiple workqueues/threads call this function, it can
	 * race on multiple CPUs. Instead of locking the hot path
	 * __find_con(), we check again, under protection of
	 * connections_lock, for the rare case of a node added in the
	 * meantime. If that happened, we abort our connection creation
	 * and return the existing connection.
	 */
	tmp = __find_con(nodeid, r);
	if (tmp) {
		spin_unlock(&connections_lock);
		kfree(con->rx_buf);
		kfree(con);
		return tmp;
	}

	hlist_add_head_rcu(&con->list, &connection_hash[r]);
	spin_unlock(&connections_lock);

	return con;
}
288
289 /* Loop round all connections */
290 static void foreach_conn(void (*conn_func)(struct connection *c))
291 {
292         int i;
293         struct connection *con;
294
295         for (i = 0; i < CONN_HASH_SIZE; i++) {
296                 hlist_for_each_entry_rcu(con, &connection_hash[i], list)
297                         conn_func(con);
298         }
299 }
300
301 static struct dlm_node_addr *find_node_addr(int nodeid)
302 {
303         struct dlm_node_addr *na;
304
305         list_for_each_entry(na, &dlm_node_addrs, list) {
306                 if (na->nodeid == nodeid)
307                         return na;
308         }
309         return NULL;
310 }
311
312 static int addr_compare(const struct sockaddr_storage *x,
313                         const struct sockaddr_storage *y)
314 {
315         switch (x->ss_family) {
316         case AF_INET: {
317                 struct sockaddr_in *sinx = (struct sockaddr_in *)x;
318                 struct sockaddr_in *siny = (struct sockaddr_in *)y;
319                 if (sinx->sin_addr.s_addr != siny->sin_addr.s_addr)
320                         return 0;
321                 if (sinx->sin_port != siny->sin_port)
322                         return 0;
323                 break;
324         }
325         case AF_INET6: {
326                 struct sockaddr_in6 *sinx = (struct sockaddr_in6 *)x;
327                 struct sockaddr_in6 *siny = (struct sockaddr_in6 *)y;
328                 if (!ipv6_addr_equal(&sinx->sin6_addr, &siny->sin6_addr))
329                         return 0;
330                 if (sinx->sin6_port != siny->sin6_port)
331                         return 0;
332                 break;
333         }
334         default:
335                 return 0;
336         }
337         return 1;
338 }
339
340 static int nodeid_to_addr(int nodeid, struct sockaddr_storage *sas_out,
341                           struct sockaddr *sa_out, bool try_new_addr,
342                           unsigned int *mark)
343 {
344         struct sockaddr_storage sas;
345         struct dlm_node_addr *na;
346
347         if (!dlm_local_count)
348                 return -1;
349
350         spin_lock(&dlm_node_addrs_spin);
351         na = find_node_addr(nodeid);
352         if (na && na->addr_count) {
353                 memcpy(&sas, na->addr[na->curr_addr_index],
354                        sizeof(struct sockaddr_storage));
355
356                 if (try_new_addr) {
357                         na->curr_addr_index++;
358                         if (na->curr_addr_index == na->addr_count)
359                                 na->curr_addr_index = 0;
360                 }
361         }
362         spin_unlock(&dlm_node_addrs_spin);
363
364         if (!na)
365                 return -EEXIST;
366
367         if (!na->addr_count)
368                 return -ENOENT;
369
370         *mark = na->mark;
371
372         if (sas_out)
373                 memcpy(sas_out, &sas, sizeof(struct sockaddr_storage));
374
375         if (!sa_out)
376                 return 0;
377
378         if (dlm_local_addr[0]->ss_family == AF_INET) {
379                 struct sockaddr_in *in4  = (struct sockaddr_in *) &sas;
380                 struct sockaddr_in *ret4 = (struct sockaddr_in *) sa_out;
381                 ret4->sin_addr.s_addr = in4->sin_addr.s_addr;
382         } else {
383                 struct sockaddr_in6 *in6  = (struct sockaddr_in6 *) &sas;
384                 struct sockaddr_in6 *ret6 = (struct sockaddr_in6 *) sa_out;
385                 ret6->sin6_addr = in6->sin6_addr;
386         }
387
388         return 0;
389 }
390
391 static int addr_to_nodeid(struct sockaddr_storage *addr, int *nodeid,
392                           unsigned int *mark)
393 {
394         struct dlm_node_addr *na;
395         int rv = -EEXIST;
396         int addr_i;
397
398         spin_lock(&dlm_node_addrs_spin);
399         list_for_each_entry(na, &dlm_node_addrs, list) {
400                 if (!na->addr_count)
401                         continue;
402
403                 for (addr_i = 0; addr_i < na->addr_count; addr_i++) {
404                         if (addr_compare(na->addr[addr_i], addr)) {
405                                 *nodeid = na->nodeid;
406                                 *mark = na->mark;
407                                 rv = 0;
408                                 goto unlock;
409                         }
410                 }
411         }
412 unlock:
413         spin_unlock(&dlm_node_addrs_spin);
414         return rv;
415 }
416
417 /* caller need to held dlm_node_addrs_spin lock */
418 static bool dlm_lowcomms_na_has_addr(const struct dlm_node_addr *na,
419                                      const struct sockaddr_storage *addr)
420 {
421         int i;
422
423         for (i = 0; i < na->addr_count; i++) {
424                 if (addr_compare(na->addr[i], addr))
425                         return true;
426         }
427
428         return false;
429 }
430
431 int dlm_lowcomms_addr(int nodeid, struct sockaddr_storage *addr, int len)
432 {
433         struct sockaddr_storage *new_addr;
434         struct dlm_node_addr *new_node, *na;
435         bool ret;
436
437         new_node = kzalloc(sizeof(struct dlm_node_addr), GFP_NOFS);
438         if (!new_node)
439                 return -ENOMEM;
440
441         new_addr = kzalloc(sizeof(struct sockaddr_storage), GFP_NOFS);
442         if (!new_addr) {
443                 kfree(new_node);
444                 return -ENOMEM;
445         }
446
447         memcpy(new_addr, addr, len);
448
449         spin_lock(&dlm_node_addrs_spin);
450         na = find_node_addr(nodeid);
451         if (!na) {
452                 new_node->nodeid = nodeid;
453                 new_node->addr[0] = new_addr;
454                 new_node->addr_count = 1;
455                 new_node->mark = dlm_config.ci_mark;
456                 list_add(&new_node->list, &dlm_node_addrs);
457                 spin_unlock(&dlm_node_addrs_spin);
458                 return 0;
459         }
460
461         ret = dlm_lowcomms_na_has_addr(na, addr);
462         if (ret) {
463                 spin_unlock(&dlm_node_addrs_spin);
464                 kfree(new_addr);
465                 kfree(new_node);
466                 return -EEXIST;
467         }
468
469         if (na->addr_count >= DLM_MAX_ADDR_COUNT) {
470                 spin_unlock(&dlm_node_addrs_spin);
471                 kfree(new_addr);
472                 kfree(new_node);
473                 return -ENOSPC;
474         }
475
476         na->addr[na->addr_count++] = new_addr;
477         spin_unlock(&dlm_node_addrs_spin);
478         kfree(new_node);
479         return 0;
480 }
481
482 /* Data available on socket or listen socket received a connect */
483 static void lowcomms_data_ready(struct sock *sk)
484 {
485         struct connection *con;
486
487         read_lock_bh(&sk->sk_callback_lock);
488         con = sock2con(sk);
489         if (con && !test_and_set_bit(CF_READ_PENDING, &con->flags))
490                 queue_work(recv_workqueue, &con->rwork);
491         read_unlock_bh(&sk->sk_callback_lock);
492 }
493
494 static void lowcomms_listen_data_ready(struct sock *sk)
495 {
496         if (!dlm_allow_conn)
497                 return;
498
499         queue_work(recv_workqueue, &listen_con.rwork);
500 }
501
/*
 * sk_write_space callback for connection sockets (also invoked by
 * lowcomms_state_change() once the socket is TCP_ESTABLISHED).
 *
 * The first call after a (re)connect sets CF_CONNECTED, logs it and
 * queues the send work.  Later calls clear SOCK_NOSPACE, undo the
 * CF_APP_LIMITED accounting if that bit was set, and queue the send
 * work so pending data gets written.
 */
static void lowcomms_write_space(struct sock *sk)
{
	struct connection *con;

	read_lock_bh(&sk->sk_callback_lock);
	con = sock2con(sk);
	if (!con)
		goto out;

	/* CF_CONNECTED is cleared again by close_connection() */
	if (!test_and_set_bit(CF_CONNECTED, &con->flags)) {
		log_print("successful connected to node %d", con->nodeid);
		queue_work(send_workqueue, &con->swork);
		goto out;
	}

	clear_bit(SOCK_NOSPACE, &con->sock->flags);

	/* NOTE(review): presumably balances an sk_write_pending++ made by
	 * the send path when it set CF_APP_LIMITED — confirm there.
	 */
	if (test_and_clear_bit(CF_APP_LIMITED, &con->flags)) {
		con->sock->sk->sk_write_pending--;
		clear_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags);
	}

	queue_work(send_workqueue, &con->swork);
out:
	read_unlock_bh(&sk->sk_callback_lock);
}
528
529 static inline void lowcomms_connect_sock(struct connection *con)
530 {
531         if (test_bit(CF_CLOSE, &con->flags))
532                 return;
533         queue_work(send_workqueue, &con->swork);
534         cond_resched();
535 }
536
537 static void lowcomms_state_change(struct sock *sk)
538 {
539         /* SCTP layer is not calling sk_data_ready when the connection
540          * is done, so we catch the signal through here. Also, it
541          * doesn't switch socket state when entering shutdown, so we
542          * skip the write in that case.
543          */
544         if (sk->sk_shutdown) {
545                 if (sk->sk_shutdown == RCV_SHUTDOWN)
546                         lowcomms_data_ready(sk);
547         } else if (sk->sk_state == TCP_ESTABLISHED) {
548                 lowcomms_write_space(sk);
549         }
550 }
551
552 int dlm_lowcomms_connect_node(int nodeid)
553 {
554         struct connection *con;
555         int idx;
556
557         if (nodeid == dlm_our_nodeid())
558                 return 0;
559
560         idx = srcu_read_lock(&connections_srcu);
561         con = nodeid2con(nodeid, GFP_NOFS);
562         if (!con) {
563                 srcu_read_unlock(&connections_srcu, idx);
564                 return -ENOMEM;
565         }
566
567         lowcomms_connect_sock(con);
568         srcu_read_unlock(&connections_srcu, idx);
569
570         return 0;
571 }
572
573 int dlm_lowcomms_nodes_set_mark(int nodeid, unsigned int mark)
574 {
575         struct dlm_node_addr *na;
576
577         spin_lock(&dlm_node_addrs_spin);
578         na = find_node_addr(nodeid);
579         if (!na) {
580                 spin_unlock(&dlm_node_addrs_spin);
581                 return -ENOENT;
582         }
583
584         na->mark = mark;
585         spin_unlock(&dlm_node_addrs_spin);
586
587         return 0;
588 }
589
/*
 * sk_error_report callback for connection sockets.
 *
 * Logs the error (rate limited) with the peer address when it can be
 * obtained, then works on the sending connection (the sendcon of an
 * othercon): an ECONNREFUSED error sets CF_DELAY_CONNECT, and the send
 * work is queued the first time CF_RECONNECT gets set.  Finally the
 * sk_error_report saved in listen_sock is chained — but only when the
 * socket still had a connection attached.
 */
static void lowcomms_error_report(struct sock *sk)
{
	struct connection *con;
	struct sockaddr_storage saddr;
	void (*orig_report)(struct sock *) = NULL;

	read_lock_bh(&sk->sk_callback_lock);
	con = sock2con(sk);
	if (con == NULL)
		goto out;

	orig_report = listen_sock.sk_error_report;
	if (kernel_getpeername(sk->sk_socket, (struct sockaddr *)&saddr) < 0) {
		printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
				   "sending to node %d, port %d, "
				   "sk_err=%d/%d\n", dlm_our_nodeid(),
				   con->nodeid, dlm_config.ci_tcp_port,
				   sk->sk_err, sk->sk_err_soft);
	} else if (saddr.ss_family == AF_INET) {
		struct sockaddr_in *sin4 = (struct sockaddr_in *)&saddr;

		printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
				   "sending to node %d at %pI4, port %d, "
				   "sk_err=%d/%d\n", dlm_our_nodeid(),
				   con->nodeid, &sin4->sin_addr.s_addr,
				   dlm_config.ci_tcp_port, sk->sk_err,
				   sk->sk_err_soft);
	} else {
		struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&saddr;

		/* %pI6c prints the address in compressed IPv6 notation */
		printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
				   "sending to node %d at %pI6c, "
				   "port %d, sk_err=%d/%d\n", dlm_our_nodeid(),
				   con->nodeid, &sin6->sin6_addr,
				   dlm_config.ci_tcp_port, sk->sk_err,
				   sk->sk_err_soft);
	}

	/* below sendcon only handling */
	if (test_bit(CF_IS_OTHERCON, &con->flags))
		con = con->sendcon;

	switch (sk->sk_err) {
	case ECONNREFUSED:
		set_bit(CF_DELAY_CONNECT, &con->flags);
		break;
	default:
		break;
	}

	if (!test_and_set_bit(CF_RECONNECT, &con->flags))
		queue_work(send_workqueue, &con->swork);

out:
	read_unlock_bh(&sk->sk_callback_lock);
	if (orig_report)
		orig_report(sk);
}
651
652 /* Note: sk_callback_lock must be locked before calling this function. */
653 static void save_listen_callbacks(struct socket *sock)
654 {
655         struct sock *sk = sock->sk;
656
657         listen_sock.sk_data_ready = sk->sk_data_ready;
658         listen_sock.sk_state_change = sk->sk_state_change;
659         listen_sock.sk_write_space = sk->sk_write_space;
660         listen_sock.sk_error_report = sk->sk_error_report;
661 }
662
/*
 * Detach @sock from its connection (sk_user_data = NULL) and reinstall
 * the callbacks saved in listen_sock, all under sk_callback_lock so the
 * lowcomms callbacks (which read sk_user_data under the read lock) see
 * either the old or the new state.  Used by dlm_close_sock().
 *
 * NOTE(review): listen_sock holds the callbacks saved from the
 * listening socket; presumably they are the protocol defaults that are
 * valid for any dlm socket — confirm.
 */
static void restore_callbacks(struct socket *sock)
{
	struct sock *sk = sock->sk;

	write_lock_bh(&sk->sk_callback_lock);
	sk->sk_user_data = NULL;
	sk->sk_data_ready = listen_sock.sk_data_ready;
	sk->sk_state_change = listen_sock.sk_state_change;
	sk->sk_write_space = listen_sock.sk_write_space;
	sk->sk_error_report = listen_sock.sk_error_report;
	write_unlock_bh(&sk->sk_callback_lock);
}
675
/*
 * Make @sock the listening socket of @con: save the socket's original
 * callbacks in listen_sock, attach @con through sk_user_data, use
 * GFP_NOFS for the socket's allocations and install
 * lowcomms_listen_data_ready() as the data-ready callback.
 */
static void add_listen_sock(struct socket *sock, struct listen_connection *con)
{
	struct sock *sk = sock->sk;

	write_lock_bh(&sk->sk_callback_lock);
	save_listen_callbacks(sock);
	con->sock = sock;

	sk->sk_user_data = con;
	sk->sk_allocation = GFP_NOFS;
	/* Install a data_ready callback */
	sk->sk_data_ready = lowcomms_listen_data_ready;
	write_unlock_bh(&sk->sk_callback_lock);
}
690
/*
 * Make a socket active: attach @sock to @con (con->sock and
 * sk_user_data) and install the lowcomms data-ready, write-space,
 * state-change and error-report callbacks, all under sk_callback_lock.
 * The socket's allocations use GFP_NOFS.
 */
static void add_sock(struct socket *sock, struct connection *con)
{
	struct sock *sk = sock->sk;

	write_lock_bh(&sk->sk_callback_lock);
	con->sock = sock;

	sk->sk_user_data = con;
	/* Install a data_ready callback */
	sk->sk_data_ready = lowcomms_data_ready;
	sk->sk_write_space = lowcomms_write_space;
	sk->sk_state_change = lowcomms_state_change;
	sk->sk_allocation = GFP_NOFS;
	sk->sk_error_report = lowcomms_error_report;
	write_unlock_bh(&sk->sk_callback_lock);
}
708
709 /* Add the port number to an IPv6 or 4 sockaddr and return the address
710    length */
711 static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
712                           int *addr_len)
713 {
714         saddr->ss_family =  dlm_local_addr[0]->ss_family;
715         if (saddr->ss_family == AF_INET) {
716                 struct sockaddr_in *in4_addr = (struct sockaddr_in *)saddr;
717                 in4_addr->sin_port = cpu_to_be16(port);
718                 *addr_len = sizeof(struct sockaddr_in);
719                 memset(&in4_addr->sin_zero, 0, sizeof(in4_addr->sin_zero));
720         } else {
721                 struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)saddr;
722                 in6_addr->sin6_port = cpu_to_be16(port);
723                 *addr_len = sizeof(struct sockaddr_in6);
724         }
725         memset((char *)saddr + *addr_len, 0, sizeof(struct sockaddr_storage) - *addr_len);
726 }
727
728 static void dlm_page_release(struct kref *kref)
729 {
730         struct writequeue_entry *e = container_of(kref, struct writequeue_entry,
731                                                   ref);
732
733         __free_page(e->page);
734         kfree(e);
735 }
736
737 static void dlm_msg_release(struct kref *kref)
738 {
739         struct dlm_msg *msg = container_of(kref, struct dlm_msg, ref);
740
741         kref_put(&msg->entry->ref, dlm_page_release);
742         kfree(msg);
743 }
744
/*
 * Remove writequeue entry @e from its connection and drop a reference
 * on it.
 *
 * Every message on e->msgs is unlinked and its reference dropped; for a
 * message with an orig_msg, the original's retransmit flag is cleared
 * and a reference on the original is dropped as well.  The owning
 * connection's writequeue_cnt is decremented.
 *
 * NOTE(review): the visible caller (close_connection()) holds
 * con->writequeue_lock; presumably that is required — confirm.
 */
static void free_entry(struct writequeue_entry *e)
{
	struct dlm_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &e->msgs, list) {
		if (msg->orig_msg) {
			msg->orig_msg->retransmit = false;
			kref_put(&msg->orig_msg->ref, dlm_msg_release);
		}

		list_del(&msg->list);
		kref_put(&msg->ref, dlm_msg_release);
	}

	list_del(&e->list);
	atomic_dec(&e->con->writequeue_cnt);
	kref_put(&e->ref, dlm_page_release);
}
763
764 static void dlm_close_sock(struct socket **sock)
765 {
766         if (*sock) {
767                 restore_callbacks(*sock);
768                 sock_release(*sock);
769                 *sock = NULL;
770         }
771 }
772
/*
 * Close a remote connection and tidy up.
 *
 * @and_other: also close con->othercon (recursion goes one level only)
 * @tx:        cancel queued send work, unless a close is already running
 * @rx:        cancel queued receive work, unless a close is already running
 *
 * Under sock_mutex the socket is released, a dirty head entry of the
 * writequeue is dropped and the per-connection state (rx_leftover,
 * retries and several CF_* bits) is reset so a reconnect starts clean.
 *
 * NOTE(review): a nested/concurrent call (CF_CLOSING already set)
 * still clears CF_CLOSING at its end — confirm that is intended.
 */
static void close_connection(struct connection *con, bool and_other,
			     bool tx, bool rx)
{
	bool closing = test_and_set_bit(CF_CLOSING, &con->flags);
	struct writequeue_entry *e;

	if (tx && !closing && cancel_work_sync(&con->swork)) {
		log_print("canceled swork for node %d", con->nodeid);
		clear_bit(CF_WRITE_PENDING, &con->flags);
	}
	if (rx && !closing && cancel_work_sync(&con->rwork)) {
		log_print("canceled rwork for node %d", con->nodeid);
		clear_bit(CF_READ_PENDING, &con->flags);
	}

	mutex_lock(&con->sock_mutex);
	dlm_close_sock(&con->sock);

	if (con->othercon && and_other) {
		/* Will only re-enter once. */
		close_connection(con->othercon, false, tx, rx);
	}

	/* If the head writequeue entry was only partially sent (dirty),
	 * drop the whole entry: after a reconnect we must not resume in
	 * the middle of a message, which would confuse the other end.
	 *
	 * Dropping messages is always acceptable because of retransmits;
	 * what we cannot allow is transmitting half messages that may be
	 * processed by the other side.
	 *
	 * Our policy is to start from a clean state on disconnect, as we
	 * don't know what was sent/received at the transport layer.
	 */
	spin_lock(&con->writequeue_lock);
	if (!list_empty(&con->writequeue)) {
		e = list_first_entry(&con->writequeue, struct writequeue_entry,
				     list);
		if (e->dirty)
			free_entry(e);
	}
	spin_unlock(&con->writequeue_lock);

	con->rx_leftover = 0;
	con->retries = 0;
	clear_bit(CF_APP_LIMITED, &con->flags);
	clear_bit(CF_CONNECTED, &con->flags);
	clear_bit(CF_DELAY_CONNECT, &con->flags);
	clear_bit(CF_RECONNECT, &con->flags);
	clear_bit(CF_EOF, &con->flags);
	mutex_unlock(&con->sock_mutex);
	clear_bit(CF_CLOSING, &con->flags);
}
827
/*
 * Gracefully shut down the sending side of @con.
 *
 * Flushes pending send work and, if there is a socket, sets CF_SHUTDOWN
 * and issues kernel_sock_shutdown(SHUT_WR).  It then waits on
 * shutdown_wait, up to DLM_SHUTDOWN_WAIT_TIMEOUT, for CF_SHUTDOWN to be
 * cleared (by code outside this view — presumably the receive path on
 * EOF, TODO confirm).  If the shutdown call fails or the wait times
 * out, the connection is force-closed with close_connection().
 */
static void shutdown_connection(struct connection *con)
{
	int ret;

	flush_work(&con->swork);

	mutex_lock(&con->sock_mutex);
	/* nothing to shutdown */
	if (!con->sock) {
		mutex_unlock(&con->sock_mutex);
		return;
	}

	set_bit(CF_SHUTDOWN, &con->flags);
	ret = kernel_sock_shutdown(con->sock, SHUT_WR);
	mutex_unlock(&con->sock_mutex);
	if (ret) {
		log_print("Connection %p failed to shutdown: %d will force close",
			  con, ret);
		goto force_close;
	} else {
		/* wait_event_timeout() returns 0 only on timeout */
		ret = wait_event_timeout(con->shutdown_wait,
					 !test_bit(CF_SHUTDOWN, &con->flags),
					 DLM_SHUTDOWN_WAIT_TIMEOUT);
		if (ret == 0) {
			log_print("Connection %p shutdown timed out, will force close",
				  con);
			goto force_close;
		}
	}

	return;

force_close:
	clear_bit(CF_SHUTDOWN, &con->flags);
	close_connection(con, false, true, true);
}
865
866 static void dlm_tcp_shutdown(struct connection *con)
867 {
868         if (con->othercon)
869                 shutdown_connection(con->othercon);
870         shutdown_connection(con);
871 }
872
873 static int con_realloc_receive_buf(struct connection *con, int newlen)
874 {
875         unsigned char *newbuf;
876
877         newbuf = kmalloc(newlen, GFP_NOFS);
878         if (!newbuf)
879                 return -ENOMEM;
880
881         /* copy any leftover from last receive */
882         if (con->rx_leftover)
883                 memmove(newbuf, con->rx_buf, con->rx_leftover);
884
885         /* swap to new buffer space */
886         kfree(con->rx_buf);
887         con->rx_buflen = newlen;
888         con->rx_buf = newbuf;
889
890         return 0;
891 }
892
893 /* Data received from remote end */
894 static int receive_from_sock(struct connection *con)
895 {
896         int call_again_soon = 0;
897         struct msghdr msg;
898         struct kvec iov;
899         int ret, buflen;
900
901         mutex_lock(&con->sock_mutex);
902
903         if (con->sock == NULL) {
904                 ret = -EAGAIN;
905                 goto out_close;
906         }
907
908         /* realloc if we get new buffer size to read out */
909         buflen = dlm_config.ci_buffer_size;
910         if (con->rx_buflen != buflen && con->rx_leftover <= buflen) {
911                 ret = con_realloc_receive_buf(con, buflen);
912                 if (ret < 0)
913                         goto out_resched;
914         }
915
916         /* calculate new buffer parameter regarding last receive and
917          * possible leftover bytes
918          */
919         iov.iov_base = con->rx_buf + con->rx_leftover;
920         iov.iov_len = con->rx_buflen - con->rx_leftover;
921
922         memset(&msg, 0, sizeof(msg));
923         msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
924         ret = kernel_recvmsg(con->sock, &msg, &iov, 1, iov.iov_len,
925                              msg.msg_flags);
926         if (ret <= 0)
927                 goto out_close;
928         else if (ret == iov.iov_len)
929                 call_again_soon = 1;
930
931         /* new buflen according readed bytes and leftover from last receive */
932         buflen = ret + con->rx_leftover;
933         ret = dlm_process_incoming_buffer(con->nodeid, con->rx_buf, buflen);
934         if (ret < 0)
935                 goto out_close;
936
937         /* calculate leftover bytes from process and put it into begin of
938          * the receive buffer, so next receive we have the full message
939          * at the start address of the receive buffer.
940          */
941         con->rx_leftover = buflen - ret;
942         if (con->rx_leftover) {
943                 memmove(con->rx_buf, con->rx_buf + ret,
944                         con->rx_leftover);
945                 call_again_soon = true;
946         }
947
948         if (call_again_soon)
949                 goto out_resched;
950
951         mutex_unlock(&con->sock_mutex);
952         return 0;
953
954 out_resched:
955         if (!test_and_set_bit(CF_READ_PENDING, &con->flags))
956                 queue_work(recv_workqueue, &con->rwork);
957         mutex_unlock(&con->sock_mutex);
958         return -EAGAIN;
959
960 out_close:
961         if (ret == 0) {
962                 log_print("connection %p got EOF from %d",
963                           con, con->nodeid);
964
965                 if (con->eof_condition && con->eof_condition(con)) {
966                         set_bit(CF_EOF, &con->flags);
967                         mutex_unlock(&con->sock_mutex);
968                 } else {
969                         mutex_unlock(&con->sock_mutex);
970                         close_connection(con, false, true, false);
971
972                         /* handling for tcp shutdown */
973                         clear_bit(CF_SHUTDOWN, &con->flags);
974                         wake_up(&con->shutdown_wait);
975                 }
976
977                 /* signal to breaking receive worker */
978                 ret = -1;
979         } else {
980                 mutex_unlock(&con->sock_mutex);
981         }
982         return ret;
983 }
984
985 /* Listening socket is busy, accept a connection */
986 static int accept_from_sock(struct listen_connection *con)
987 {
988         int result;
989         struct sockaddr_storage peeraddr;
990         struct socket *newsock;
991         int len, idx;
992         int nodeid;
993         struct connection *newcon;
994         struct connection *addcon;
995         unsigned int mark;
996
997         if (!con->sock)
998                 return -ENOTCONN;
999
1000         result = kernel_accept(con->sock, &newsock, O_NONBLOCK);
1001         if (result < 0)
1002                 goto accept_err;
1003
1004         /* Get the connected socket's peer */
1005         memset(&peeraddr, 0, sizeof(peeraddr));
1006         len = newsock->ops->getname(newsock, (struct sockaddr *)&peeraddr, 2);
1007         if (len < 0) {
1008                 result = -ECONNABORTED;
1009                 goto accept_err;
1010         }
1011
1012         /* Get the new node's NODEID */
1013         make_sockaddr(&peeraddr, 0, &len);
1014         if (addr_to_nodeid(&peeraddr, &nodeid, &mark)) {
1015                 unsigned char *b=(unsigned char *)&peeraddr;
1016                 log_print("connect from non cluster node");
1017                 print_hex_dump_bytes("ss: ", DUMP_PREFIX_NONE, 
1018                                      b, sizeof(struct sockaddr_storage));
1019                 sock_release(newsock);
1020                 return -1;
1021         }
1022
1023         log_print("got connection from %d", nodeid);
1024
1025         /*  Check to see if we already have a connection to this node. This
1026          *  could happen if the two nodes initiate a connection at roughly
1027          *  the same time and the connections cross on the wire.
1028          *  In this case we store the incoming one in "othercon"
1029          */
1030         idx = srcu_read_lock(&connections_srcu);
1031         newcon = nodeid2con(nodeid, GFP_NOFS);
1032         if (!newcon) {
1033                 srcu_read_unlock(&connections_srcu, idx);
1034                 result = -ENOMEM;
1035                 goto accept_err;
1036         }
1037
1038         sock_set_mark(newsock->sk, mark);
1039
1040         mutex_lock(&newcon->sock_mutex);
1041         if (newcon->sock) {
1042                 struct connection *othercon = newcon->othercon;
1043
1044                 if (!othercon) {
1045                         othercon = kzalloc(sizeof(*othercon), GFP_NOFS);
1046                         if (!othercon) {
1047                                 log_print("failed to allocate incoming socket");
1048                                 mutex_unlock(&newcon->sock_mutex);
1049                                 srcu_read_unlock(&connections_srcu, idx);
1050                                 result = -ENOMEM;
1051                                 goto accept_err;
1052                         }
1053
1054                         result = dlm_con_init(othercon, nodeid);
1055                         if (result < 0) {
1056                                 kfree(othercon);
1057                                 mutex_unlock(&newcon->sock_mutex);
1058                                 srcu_read_unlock(&connections_srcu, idx);
1059                                 goto accept_err;
1060                         }
1061
1062                         lockdep_set_subclass(&othercon->sock_mutex, 1);
1063                         set_bit(CF_IS_OTHERCON, &othercon->flags);
1064                         newcon->othercon = othercon;
1065                         othercon->sendcon = newcon;
1066                 } else {
1067                         /* close other sock con if we have something new */
1068                         close_connection(othercon, false, true, false);
1069                 }
1070
1071                 mutex_lock(&othercon->sock_mutex);
1072                 add_sock(newsock, othercon);
1073                 addcon = othercon;
1074                 mutex_unlock(&othercon->sock_mutex);
1075         }
1076         else {
1077                 /* accept copies the sk after we've saved the callbacks, so we
1078                    don't want to save them a second time or comm errors will
1079                    result in calling sk_error_report recursively. */
1080                 add_sock(newsock, newcon);
1081                 addcon = newcon;
1082         }
1083
1084         set_bit(CF_CONNECTED, &addcon->flags);
1085         mutex_unlock(&newcon->sock_mutex);
1086
1087         /*
1088          * Add it to the active queue in case we got data
1089          * between processing the accept adding the socket
1090          * to the read_sockets list
1091          */
1092         if (!test_and_set_bit(CF_READ_PENDING, &addcon->flags))
1093                 queue_work(recv_workqueue, &addcon->rwork);
1094
1095         srcu_read_unlock(&connections_srcu, idx);
1096
1097         return 0;
1098
1099 accept_err:
1100         if (newsock)
1101                 sock_release(newsock);
1102
1103         if (result != -EAGAIN)
1104                 log_print("error accepting connection from node: %d", result);
1105         return result;
1106 }
1107
1108 /*
1109  * writequeue_entry_complete - try to delete and free write queue entry
1110  * @e: write queue entry to try to delete
1111  * @completed: bytes completed
1112  *
1113  * writequeue_lock must be held.
1114  */
1115 static void writequeue_entry_complete(struct writequeue_entry *e, int completed)
1116 {
1117         e->offset += completed;
1118         e->len -= completed;
1119         /* signal that page was half way transmitted */
1120         e->dirty = true;
1121
1122         if (e->len == 0 && e->users == 0)
1123                 free_entry(e);
1124 }
1125
1126 /*
1127  * sctp_bind_addrs - bind a SCTP socket to all our addresses
1128  */
1129 static int sctp_bind_addrs(struct socket *sock, uint16_t port)
1130 {
1131         struct sockaddr_storage localaddr;
1132         struct sockaddr *addr = (struct sockaddr *)&localaddr;
1133         int i, addr_len, result = 0;
1134
1135         for (i = 0; i < dlm_local_count; i++) {
1136                 memcpy(&localaddr, dlm_local_addr[i], sizeof(localaddr));
1137                 make_sockaddr(&localaddr, port, &addr_len);
1138
1139                 if (!i)
1140                         result = kernel_bind(sock, addr, addr_len);
1141                 else
1142                         result = sock_bind_add(sock->sk, addr, addr_len);
1143
1144                 if (result < 0) {
1145                         log_print("Can't bind to %d addr number %d, %d.\n",
1146                                   port, i + 1, result);
1147                         break;
1148                 }
1149         }
1150         return result;
1151 }
1152
1153 /* Initiate an SCTP association.
1154    This is a special case of send_to_sock() in that we don't yet have a
1155    peeled-off socket for this association, so we use the listening socket
1156    and add the primary IP address of the remote node.
1157  */
1158 static void sctp_connect_to_sock(struct connection *con)
1159 {
1160         struct sockaddr_storage daddr;
1161         int result;
1162         int addr_len;
1163         struct socket *sock;
1164         unsigned int mark;
1165
1166         mutex_lock(&con->sock_mutex);
1167
1168         /* Some odd races can cause double-connects, ignore them */
1169         if (con->retries++ > MAX_CONNECT_RETRIES)
1170                 goto out;
1171
1172         if (con->sock) {
1173                 log_print("node %d already connected.", con->nodeid);
1174                 goto out;
1175         }
1176
1177         memset(&daddr, 0, sizeof(daddr));
1178         result = nodeid_to_addr(con->nodeid, &daddr, NULL, true, &mark);
1179         if (result < 0) {
1180                 log_print("no address for nodeid %d", con->nodeid);
1181                 goto out;
1182         }
1183
1184         /* Create a socket to communicate with */
1185         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1186                                   SOCK_STREAM, IPPROTO_SCTP, &sock);
1187         if (result < 0)
1188                 goto socket_err;
1189
1190         sock_set_mark(sock->sk, mark);
1191
1192         add_sock(sock, con);
1193
1194         /* Bind to all addresses. */
1195         if (sctp_bind_addrs(con->sock, 0))
1196                 goto bind_err;
1197
1198         make_sockaddr(&daddr, dlm_config.ci_tcp_port, &addr_len);
1199
1200         log_print_ratelimited("connecting to %d", con->nodeid);
1201
1202         /* Turn off Nagle's algorithm */
1203         sctp_sock_set_nodelay(sock->sk);
1204
1205         /*
1206          * Make sock->ops->connect() function return in specified time,
1207          * since O_NONBLOCK argument in connect() function does not work here,
1208          * then, we should restore the default value of this attribute.
1209          */
1210         sock_set_sndtimeo(sock->sk, 5);
1211         result = sock->ops->connect(sock, (struct sockaddr *)&daddr, addr_len,
1212                                    0);
1213         sock_set_sndtimeo(sock->sk, 0);
1214
1215         if (result == -EINPROGRESS)
1216                 result = 0;
1217         if (result == 0) {
1218                 if (!test_and_set_bit(CF_CONNECTED, &con->flags))
1219                         log_print("successful connected to node %d", con->nodeid);
1220                 goto out;
1221         }
1222
1223 bind_err:
1224         con->sock = NULL;
1225         sock_release(sock);
1226
1227 socket_err:
1228         /*
1229          * Some errors are fatal and this list might need adjusting. For other
1230          * errors we try again until the max number of retries is reached.
1231          */
1232         if (result != -EHOSTUNREACH &&
1233             result != -ENETUNREACH &&
1234             result != -ENETDOWN &&
1235             result != -EINVAL &&
1236             result != -EPROTONOSUPPORT) {
1237                 log_print("connect %d try %d error %d", con->nodeid,
1238                           con->retries, result);
1239                 mutex_unlock(&con->sock_mutex);
1240                 msleep(1000);
1241                 lowcomms_connect_sock(con);
1242                 return;
1243         }
1244
1245 out:
1246         mutex_unlock(&con->sock_mutex);
1247 }
1248
1249 /* Connect a new socket to its peer */
1250 static void tcp_connect_to_sock(struct connection *con)
1251 {
1252         struct sockaddr_storage saddr, src_addr;
1253         unsigned int mark;
1254         int addr_len;
1255         struct socket *sock = NULL;
1256         int result;
1257
1258         mutex_lock(&con->sock_mutex);
1259         if (con->retries++ > MAX_CONNECT_RETRIES)
1260                 goto out;
1261
1262         /* Some odd races can cause double-connects, ignore them */
1263         if (con->sock)
1264                 goto out;
1265
1266         /* Create a socket to communicate with */
1267         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1268                                   SOCK_STREAM, IPPROTO_TCP, &sock);
1269         if (result < 0)
1270                 goto out_err;
1271
1272         memset(&saddr, 0, sizeof(saddr));
1273         result = nodeid_to_addr(con->nodeid, &saddr, NULL, false, &mark);
1274         if (result < 0) {
1275                 log_print("no address for nodeid %d", con->nodeid);
1276                 goto out_err;
1277         }
1278
1279         sock_set_mark(sock->sk, mark);
1280
1281         add_sock(sock, con);
1282
1283         /* Bind to our cluster-known address connecting to avoid
1284            routing problems */
1285         memcpy(&src_addr, dlm_local_addr[0], sizeof(src_addr));
1286         make_sockaddr(&src_addr, 0, &addr_len);
1287         result = sock->ops->bind(sock, (struct sockaddr *) &src_addr,
1288                                  addr_len);
1289         if (result < 0) {
1290                 log_print("could not bind for connect: %d", result);
1291                 /* This *may* not indicate a critical error */
1292         }
1293
1294         make_sockaddr(&saddr, dlm_config.ci_tcp_port, &addr_len);
1295
1296         log_print_ratelimited("connecting to %d", con->nodeid);
1297
1298         /* Turn off Nagle's algorithm */
1299         tcp_sock_set_nodelay(sock->sk);
1300
1301         result = sock->ops->connect(sock, (struct sockaddr *)&saddr, addr_len,
1302                                    O_NONBLOCK);
1303         if (result == -EINPROGRESS)
1304                 result = 0;
1305         if (result == 0)
1306                 goto out;
1307
1308 out_err:
1309         if (con->sock) {
1310                 sock_release(con->sock);
1311                 con->sock = NULL;
1312         } else if (sock) {
1313                 sock_release(sock);
1314         }
1315         /*
1316          * Some errors are fatal and this list might need adjusting. For other
1317          * errors we try again until the max number of retries is reached.
1318          */
1319         if (result != -EHOSTUNREACH &&
1320             result != -ENETUNREACH &&
1321             result != -ENETDOWN && 
1322             result != -EINVAL &&
1323             result != -EPROTONOSUPPORT) {
1324                 log_print("connect %d try %d error %d", con->nodeid,
1325                           con->retries, result);
1326                 mutex_unlock(&con->sock_mutex);
1327                 msleep(1000);
1328                 lowcomms_connect_sock(con);
1329                 return;
1330         }
1331 out:
1332         mutex_unlock(&con->sock_mutex);
1333         return;
1334 }
1335
1336 /* On error caller must run dlm_close_sock() for the
1337  * listen connection socket.
1338  */
1339 static int tcp_create_listen_sock(struct listen_connection *con,
1340                                   struct sockaddr_storage *saddr)
1341 {
1342         struct socket *sock = NULL;
1343         int result = 0;
1344         int addr_len;
1345
1346         if (dlm_local_addr[0]->ss_family == AF_INET)
1347                 addr_len = sizeof(struct sockaddr_in);
1348         else
1349                 addr_len = sizeof(struct sockaddr_in6);
1350
1351         /* Create a socket to communicate with */
1352         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1353                                   SOCK_STREAM, IPPROTO_TCP, &sock);
1354         if (result < 0) {
1355                 log_print("Can't create listening comms socket");
1356                 goto create_out;
1357         }
1358
1359         sock_set_mark(sock->sk, dlm_config.ci_mark);
1360
1361         /* Turn off Nagle's algorithm */
1362         tcp_sock_set_nodelay(sock->sk);
1363
1364         sock_set_reuseaddr(sock->sk);
1365
1366         add_listen_sock(sock, con);
1367
1368         /* Bind to our port */
1369         make_sockaddr(saddr, dlm_config.ci_tcp_port, &addr_len);
1370         result = sock->ops->bind(sock, (struct sockaddr *) saddr, addr_len);
1371         if (result < 0) {
1372                 log_print("Can't bind to port %d", dlm_config.ci_tcp_port);
1373                 goto create_out;
1374         }
1375         sock_set_keepalive(sock->sk);
1376
1377         result = sock->ops->listen(sock, 5);
1378         if (result < 0) {
1379                 log_print("Can't listen on port %d", dlm_config.ci_tcp_port);
1380                 goto create_out;
1381         }
1382
1383         return 0;
1384
1385 create_out:
1386         return result;
1387 }
1388
1389 /* Get local addresses */
1390 static void init_local(void)
1391 {
1392         struct sockaddr_storage sas, *addr;
1393         int i;
1394
1395         dlm_local_count = 0;
1396         for (i = 0; i < DLM_MAX_ADDR_COUNT; i++) {
1397                 if (dlm_our_addr(&sas, i))
1398                         break;
1399
1400                 addr = kmemdup(&sas, sizeof(*addr), GFP_NOFS);
1401                 if (!addr)
1402                         break;
1403                 dlm_local_addr[dlm_local_count++] = addr;
1404         }
1405 }
1406
1407 static void deinit_local(void)
1408 {
1409         int i;
1410
1411         for (i = 0; i < dlm_local_count; i++)
1412                 kfree(dlm_local_addr[i]);
1413 }
1414
1415 /* Initialise SCTP socket and bind to all interfaces
1416  * On error caller must run dlm_close_sock() for the
1417  * listen connection socket.
1418  */
1419 static int sctp_listen_for_all(struct listen_connection *con)
1420 {
1421         struct socket *sock = NULL;
1422         int result = -EINVAL;
1423
1424         log_print("Using SCTP for communications");
1425
1426         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1427                                   SOCK_STREAM, IPPROTO_SCTP, &sock);
1428         if (result < 0) {
1429                 log_print("Can't create comms socket, check SCTP is loaded");
1430                 goto out;
1431         }
1432
1433         sock_set_rcvbuf(sock->sk, NEEDED_RMEM);
1434         sock_set_mark(sock->sk, dlm_config.ci_mark);
1435         sctp_sock_set_nodelay(sock->sk);
1436
1437         add_listen_sock(sock, con);
1438
1439         /* Bind to all addresses. */
1440         result = sctp_bind_addrs(con->sock, dlm_config.ci_tcp_port);
1441         if (result < 0)
1442                 goto out;
1443
1444         result = sock->ops->listen(sock, 5);
1445         if (result < 0) {
1446                 log_print("Can't set socket listening");
1447                 goto out;
1448         }
1449
1450         return 0;
1451
1452 out:
1453         return result;
1454 }
1455
1456 static int tcp_listen_for_all(void)
1457 {
1458         /* We don't support multi-homed hosts */
1459         if (dlm_local_count > 1) {
1460                 log_print("TCP protocol can't handle multi-homed hosts, "
1461                           "try SCTP");
1462                 return -EINVAL;
1463         }
1464
1465         log_print("Using TCP for communications");
1466
1467         return tcp_create_listen_sock(&listen_con, dlm_local_addr[0]);
1468 }
1469
1470
1471
1472 static struct writequeue_entry *new_writequeue_entry(struct connection *con,
1473                                                      gfp_t allocation)
1474 {
1475         struct writequeue_entry *entry;
1476
1477         entry = kzalloc(sizeof(*entry), allocation);
1478         if (!entry)
1479                 return NULL;
1480
1481         entry->page = alloc_page(allocation | __GFP_ZERO);
1482         if (!entry->page) {
1483                 kfree(entry);
1484                 return NULL;
1485         }
1486
1487         entry->con = con;
1488         entry->users = 1;
1489         kref_init(&entry->ref);
1490         INIT_LIST_HEAD(&entry->msgs);
1491
1492         return entry;
1493 }
1494
1495 static struct writequeue_entry *new_wq_entry(struct connection *con, int len,
1496                                              gfp_t allocation, char **ppc,
1497                                              void (*cb)(struct dlm_mhandle *mh),
1498                                              struct dlm_mhandle *mh)
1499 {
1500         struct writequeue_entry *e;
1501
1502         spin_lock(&con->writequeue_lock);
1503         if (!list_empty(&con->writequeue)) {
1504                 e = list_last_entry(&con->writequeue, struct writequeue_entry, list);
1505                 if (DLM_WQ_REMAIN_BYTES(e) >= len) {
1506                         kref_get(&e->ref);
1507
1508                         *ppc = page_address(e->page) + e->end;
1509                         if (cb)
1510                                 cb(mh);
1511
1512                         e->end += len;
1513                         e->users++;
1514                         spin_unlock(&con->writequeue_lock);
1515
1516                         return e;
1517                 }
1518         }
1519         spin_unlock(&con->writequeue_lock);
1520
1521         e = new_writequeue_entry(con, allocation);
1522         if (!e)
1523                 return NULL;
1524
1525         kref_get(&e->ref);
1526         *ppc = page_address(e->page);
1527         e->end += len;
1528         atomic_inc(&con->writequeue_cnt);
1529
1530         spin_lock(&con->writequeue_lock);
1531         if (cb)
1532                 cb(mh);
1533
1534         list_add_tail(&e->list, &con->writequeue);
1535         spin_unlock(&con->writequeue_lock);
1536
1537         return e;
1538 };
1539
1540 static struct dlm_msg *dlm_lowcomms_new_msg_con(struct connection *con, int len,
1541                                                 gfp_t allocation, char **ppc,
1542                                                 void (*cb)(struct dlm_mhandle *mh),
1543                                                 struct dlm_mhandle *mh)
1544 {
1545         struct writequeue_entry *e;
1546         struct dlm_msg *msg;
1547
1548         msg = kzalloc(sizeof(*msg), allocation);
1549         if (!msg)
1550                 return NULL;
1551
1552         kref_init(&msg->ref);
1553
1554         e = new_wq_entry(con, len, allocation, ppc, cb, mh);
1555         if (!e) {
1556                 kfree(msg);
1557                 return NULL;
1558         }
1559
1560         msg->ppc = *ppc;
1561         msg->len = len;
1562         msg->entry = e;
1563
1564         return msg;
1565 }
1566
1567 struct dlm_msg *dlm_lowcomms_new_msg(int nodeid, int len, gfp_t allocation,
1568                                      char **ppc, void (*cb)(struct dlm_mhandle *mh),
1569                                      struct dlm_mhandle *mh)
1570 {
1571         struct connection *con;
1572         struct dlm_msg *msg;
1573         int idx;
1574
1575         if (len > DLM_MAX_SOCKET_BUFSIZE ||
1576             len < sizeof(struct dlm_header)) {
1577                 BUILD_BUG_ON(PAGE_SIZE < DLM_MAX_SOCKET_BUFSIZE);
1578                 log_print("failed to allocate a buffer of size %d", len);
1579                 WARN_ON(1);
1580                 return NULL;
1581         }
1582
1583         idx = srcu_read_lock(&connections_srcu);
1584         con = nodeid2con(nodeid, allocation);
1585         if (!con) {
1586                 srcu_read_unlock(&connections_srcu, idx);
1587                 return NULL;
1588         }
1589
1590         msg = dlm_lowcomms_new_msg_con(con, len, allocation, ppc, cb, mh);
1591         if (!msg) {
1592                 srcu_read_unlock(&connections_srcu, idx);
1593                 return NULL;
1594         }
1595
1596         /* we assume if successful commit must called */
1597         msg->idx = idx;
1598         return msg;
1599 }
1600
1601 static void _dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1602 {
1603         struct writequeue_entry *e = msg->entry;
1604         struct connection *con = e->con;
1605         int users;
1606
1607         spin_lock(&con->writequeue_lock);
1608         kref_get(&msg->ref);
1609         list_add(&msg->list, &e->msgs);
1610
1611         users = --e->users;
1612         if (users)
1613                 goto out;
1614
1615         e->len = DLM_WQ_LENGTH_BYTES(e);
1616         spin_unlock(&con->writequeue_lock);
1617
1618         queue_work(send_workqueue, &con->swork);
1619         return;
1620
1621 out:
1622         spin_unlock(&con->writequeue_lock);
1623         return;
1624 }
1625
1626 void dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1627 {
1628         _dlm_lowcomms_commit_msg(msg);
1629         srcu_read_unlock(&connections_srcu, msg->idx);
1630 }
1631
1632 void dlm_lowcomms_put_msg(struct dlm_msg *msg)
1633 {
1634         kref_put(&msg->ref, dlm_msg_release);
1635 }
1636
1637 /* does not held connections_srcu, usage workqueue only */
1638 int dlm_lowcomms_resend_msg(struct dlm_msg *msg)
1639 {
1640         struct dlm_msg *msg_resend;
1641         char *ppc;
1642
1643         if (msg->retransmit)
1644                 return 1;
1645
1646         msg_resend = dlm_lowcomms_new_msg_con(msg->entry->con, msg->len,
1647                                               GFP_ATOMIC, &ppc, NULL, NULL);
1648         if (!msg_resend)
1649                 return -ENOMEM;
1650
1651         msg->retransmit = true;
1652         kref_get(&msg->ref);
1653         msg_resend->orig_msg = msg;
1654
1655         memcpy(ppc, msg->ppc, msg->len);
1656         _dlm_lowcomms_commit_msg(msg_resend);
1657         dlm_lowcomms_put_msg(msg_resend);
1658
1659         return 0;
1660 }
1661
1662 /* Send a message */
1663 static void send_to_sock(struct connection *con)
1664 {
1665         const int msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
1666         struct writequeue_entry *e;
1667         int len, offset, ret;
1668         int count = 0;
1669
1670         mutex_lock(&con->sock_mutex);
1671         if (con->sock == NULL)
1672                 goto out_connect;
1673
1674         spin_lock(&con->writequeue_lock);
1675         for (;;) {
1676                 e = con_next_wq(con);
1677                 if (!e)
1678                         break;
1679
1680                 e = list_first_entry(&con->writequeue, struct writequeue_entry, list);
1681                 len = e->len;
1682                 offset = e->offset;
1683                 BUG_ON(len == 0 && e->users == 0);
1684                 spin_unlock(&con->writequeue_lock);
1685
1686                 ret = kernel_sendpage(con->sock, e->page, offset, len,
1687                                       msg_flags);
1688                 if (ret == -EAGAIN || ret == 0) {
1689                         if (ret == -EAGAIN &&
1690                             test_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags) &&
1691                             !test_and_set_bit(CF_APP_LIMITED, &con->flags)) {
1692                                 /* Notify TCP that we're limited by the
1693                                  * application window size.
1694                                  */
1695                                 set_bit(SOCK_NOSPACE, &con->sock->flags);
1696                                 con->sock->sk->sk_write_pending++;
1697                         }
1698                         cond_resched();
1699                         goto out;
1700                 } else if (ret < 0)
1701                         goto out;
1702
1703                 /* Don't starve people filling buffers */
1704                 if (++count >= MAX_SEND_MSG_COUNT) {
1705                         cond_resched();
1706                         count = 0;
1707                 }
1708
1709                 spin_lock(&con->writequeue_lock);
1710                 writequeue_entry_complete(e, ret);
1711         }
1712         spin_unlock(&con->writequeue_lock);
1713
1714         /* close if we got EOF */
1715         if (test_and_clear_bit(CF_EOF, &con->flags)) {
1716                 mutex_unlock(&con->sock_mutex);
1717                 close_connection(con, false, false, true);
1718
1719                 /* handling for tcp shutdown */
1720                 clear_bit(CF_SHUTDOWN, &con->flags);
1721                 wake_up(&con->shutdown_wait);
1722         } else {
1723                 mutex_unlock(&con->sock_mutex);
1724         }
1725
1726         return;
1727
1728 out:
1729         mutex_unlock(&con->sock_mutex);
1730         return;
1731
1732 out_connect:
1733         mutex_unlock(&con->sock_mutex);
1734         queue_work(send_workqueue, &con->swork);
1735         cond_resched();
1736 }
1737
1738 static void clean_one_writequeue(struct connection *con)
1739 {
1740         struct writequeue_entry *e, *safe;
1741
1742         spin_lock(&con->writequeue_lock);
1743         list_for_each_entry_safe(e, safe, &con->writequeue, list) {
1744                 free_entry(e);
1745         }
1746         spin_unlock(&con->writequeue_lock);
1747 }
1748
1749 /* Called from recovery when it knows that a node has
1750    left the cluster */
1751 int dlm_lowcomms_close(int nodeid)
1752 {
1753         struct connection *con;
1754         struct dlm_node_addr *na;
1755         int idx;
1756
1757         log_print("closing connection to node %d", nodeid);
1758         idx = srcu_read_lock(&connections_srcu);
1759         con = nodeid2con(nodeid, 0);
1760         if (con) {
1761                 set_bit(CF_CLOSE, &con->flags);
1762                 close_connection(con, true, true, true);
1763                 clean_one_writequeue(con);
1764                 if (con->othercon)
1765                         clean_one_writequeue(con->othercon);
1766         }
1767         srcu_read_unlock(&connections_srcu, idx);
1768
1769         spin_lock(&dlm_node_addrs_spin);
1770         na = find_node_addr(nodeid);
1771         if (na) {
1772                 list_del(&na->list);
1773                 while (na->addr_count--)
1774                         kfree(na->addr[na->addr_count]);
1775                 kfree(na);
1776         }
1777         spin_unlock(&dlm_node_addrs_spin);
1778
1779         return 0;
1780 }
1781
1782 /* Receive workqueue function */
1783 static void process_recv_sockets(struct work_struct *work)
1784 {
1785         struct connection *con = container_of(work, struct connection, rwork);
1786         int err;
1787
1788         clear_bit(CF_READ_PENDING, &con->flags);
1789         do {
1790                 err = receive_from_sock(con);
1791         } while (!err);
1792 }
1793
1794 static void process_listen_recv_socket(struct work_struct *work)
1795 {
1796         accept_from_sock(&listen_con);
1797 }
1798
1799 /* Send workqueue function */
1800 static void process_send_sockets(struct work_struct *work)
1801 {
1802         struct connection *con = container_of(work, struct connection, swork);
1803
1804         WARN_ON(test_bit(CF_IS_OTHERCON, &con->flags));
1805
1806         clear_bit(CF_WRITE_PENDING, &con->flags);
1807
1808         if (test_and_clear_bit(CF_RECONNECT, &con->flags)) {
1809                 close_connection(con, false, false, true);
1810                 dlm_midcomms_unack_msg_resend(con->nodeid);
1811         }
1812
1813         if (con->sock == NULL) { /* not mutex protected so check it inside too */
1814                 if (test_and_clear_bit(CF_DELAY_CONNECT, &con->flags))
1815                         msleep(1000);
1816                 con->connect_action(con);
1817         }
1818         if (!list_empty(&con->writequeue))
1819                 send_to_sock(con);
1820 }
1821
1822 static void work_stop(void)
1823 {
1824         if (recv_workqueue) {
1825                 destroy_workqueue(recv_workqueue);
1826                 recv_workqueue = NULL;
1827         }
1828
1829         if (send_workqueue) {
1830                 destroy_workqueue(send_workqueue);
1831                 send_workqueue = NULL;
1832         }
1833 }
1834
1835 static int work_start(void)
1836 {
1837         recv_workqueue = alloc_ordered_workqueue("dlm_recv", WQ_MEM_RECLAIM);
1838         if (!recv_workqueue) {
1839                 log_print("can't start dlm_recv");
1840                 return -ENOMEM;
1841         }
1842
1843         send_workqueue = alloc_ordered_workqueue("dlm_send", WQ_MEM_RECLAIM);
1844         if (!send_workqueue) {
1845                 log_print("can't start dlm_send");
1846                 destroy_workqueue(recv_workqueue);
1847                 recv_workqueue = NULL;
1848                 return -ENOMEM;
1849         }
1850
1851         return 0;
1852 }
1853
1854 static void shutdown_conn(struct connection *con)
1855 {
1856         if (con->shutdown_action)
1857                 con->shutdown_action(con);
1858 }
1859
1860 void dlm_lowcomms_shutdown(void)
1861 {
1862         int idx;
1863
1864         /* Set all the flags to prevent any
1865          * socket activity.
1866          */
1867         dlm_allow_conn = 0;
1868
1869         if (recv_workqueue)
1870                 flush_workqueue(recv_workqueue);
1871         if (send_workqueue)
1872                 flush_workqueue(send_workqueue);
1873
1874         dlm_close_sock(&listen_con.sock);
1875
1876         idx = srcu_read_lock(&connections_srcu);
1877         foreach_conn(shutdown_conn);
1878         srcu_read_unlock(&connections_srcu, idx);
1879 }
1880
/*
 * _stop_conn - mark a connection stopped and detach it from its socket
 * @con: connection to stop
 * @and_other: also stop con->othercon; the recursive call passes false, so
 *             the recursion is at most one level deep
 *
 * Under con->sock_mutex this sets CF_CLOSE, CF_READ_PENDING and
 * CF_WRITE_PENDING, then clears sk_user_data on the underlying sock.
 * work_flush() reads both *_PENDING bits still being set as "no rx/tx work
 * ran since the stop". The receive and send work handlers clear these bits.
 */
static void _stop_conn(struct connection *con, bool and_other)
{
	mutex_lock(&con->sock_mutex);
	set_bit(CF_CLOSE, &con->flags);
	/* NOTE(review): presumably setting the PENDING bits also keeps new
	 * rx/tx work from being queued during shutdown - confirm against the
	 * queueing paths.
	 */
	set_bit(CF_READ_PENDING, &con->flags);
	set_bit(CF_WRITE_PENDING, &con->flags);
	if (con->sock && con->sock->sk) {
		/* sk_user_data is changed under sk_callback_lock; presumably the
		 * socket callbacks use it to find the connection - verify.
		 */
		write_lock_bh(&con->sock->sk->sk_callback_lock);
		con->sock->sk->sk_user_data = NULL;
		write_unlock_bh(&con->sock->sk->sk_callback_lock);
	}
	/* othercon's sock_mutex is taken while con's is still held */
	if (con->othercon && and_other)
		_stop_conn(con->othercon, false);
	mutex_unlock(&con->sock_mutex);
}
1896
1897 static void stop_conn(struct connection *con)
1898 {
1899         _stop_conn(con, true);
1900 }
1901
1902 static void connection_release(struct rcu_head *rcu)
1903 {
1904         struct connection *con = container_of(rcu, struct connection, rcu);
1905
1906         kfree(con->rx_buf);
1907         kfree(con);
1908 }
1909
1910 static void free_conn(struct connection *con)
1911 {
1912         close_connection(con, true, true, true);
1913         spin_lock(&connections_lock);
1914         hlist_del_rcu(&con->list);
1915         spin_unlock(&connections_lock);
1916         if (con->othercon) {
1917                 clean_one_writequeue(con->othercon);
1918                 call_srcu(&connections_srcu, &con->othercon->rcu,
1919                           connection_release);
1920         }
1921         clean_one_writequeue(con);
1922         call_srcu(&connections_srcu, &con->rcu, connection_release);
1923 }
1924
/*
 * work_flush - stop every connection and drain the rx/tx workqueues
 *
 * Each pass stops all connections with stop_conn(), which sets
 * CF_READ_PENDING and CF_WRITE_PENDING, and then flushes both workqueues.
 * The receive handler clears CF_READ_PENDING and the send handler clears
 * CF_WRITE_PENDING. A connection (or othercon) with either bit clear after
 * the flush therefore had work run after it was stopped, and another pass
 * is made. The loop ends once every connection still has both bits set.
 *
 * The hash is walked with hlist_for_each_entry_rcu(). The caller visible
 * here, dlm_lowcomms_stop(), holds the connections_srcu read lock.
 */
static void work_flush(void)
{
	int ok;
	int i;
	struct connection *con;

	do {
		ok = 1;
		foreach_conn(stop_conn);
		if (recv_workqueue)
			flush_workqueue(recv_workqueue);
		if (send_workqueue)
			flush_workqueue(send_workqueue);
		/* "&& ok": stop scanning buckets once a cleared bit is seen */
		for (i = 0; i < CONN_HASH_SIZE && ok; i++) {
			hlist_for_each_entry_rcu(con, &connection_hash[i],
						 list) {
				ok &= test_bit(CF_READ_PENDING, &con->flags);
				ok &= test_bit(CF_WRITE_PENDING, &con->flags);
				if (con->othercon) {
					ok &= test_bit(CF_READ_PENDING,
						       &con->othercon->flags);
					ok &= test_bit(CF_WRITE_PENDING,
						       &con->othercon->flags);
				}
			}
		}
	} while (!ok);
}
1953
1954 void dlm_lowcomms_stop(void)
1955 {
1956         int idx;
1957
1958         idx = srcu_read_lock(&connections_srcu);
1959         work_flush();
1960         foreach_conn(free_conn);
1961         srcu_read_unlock(&connections_srcu, idx);
1962         work_stop();
1963         deinit_local();
1964 }
1965
1966 int dlm_lowcomms_start(void)
1967 {
1968         int error = -EINVAL;
1969         int i;
1970
1971         for (i = 0; i < CONN_HASH_SIZE; i++)
1972                 INIT_HLIST_HEAD(&connection_hash[i]);
1973
1974         init_local();
1975         if (!dlm_local_count) {
1976                 error = -ENOTCONN;
1977                 log_print("no local IP address has been set");
1978                 goto fail;
1979         }
1980
1981         INIT_WORK(&listen_con.rwork, process_listen_recv_socket);
1982
1983         error = work_start();
1984         if (error)
1985                 goto fail_local;
1986
1987         dlm_allow_conn = 1;
1988
1989         /* Start listening */
1990         switch (dlm_config.ci_protocol) {
1991         case DLM_PROTO_TCP:
1992                 error = tcp_listen_for_all();
1993                 break;
1994         case DLM_PROTO_SCTP:
1995                 error = sctp_listen_for_all(&listen_con);
1996                 break;
1997         default:
1998                 log_print("Invalid protocol identifier %d set",
1999                           dlm_config.ci_protocol);
2000                 error = -EINVAL;
2001                 break;
2002         }
2003         if (error)
2004                 goto fail_unlisten;
2005
2006         return 0;
2007
2008 fail_unlisten:
2009         dlm_allow_conn = 0;
2010         dlm_close_sock(&listen_con.sock);
2011         work_stop();
2012 fail_local:
2013         deinit_local();
2014 fail:
2015         return error;
2016 }
2017
2018 void dlm_lowcomms_exit(void)
2019 {
2020         struct dlm_node_addr *na, *safe;
2021
2022         spin_lock(&dlm_node_addrs_spin);
2023         list_for_each_entry_safe(na, safe, &dlm_node_addrs, list) {
2024                 list_del(&na->list);
2025                 while (na->addr_count--)
2026                         kfree(na->addr[na->addr_count]);
2027                 kfree(na);
2028         }
2029         spin_unlock(&dlm_node_addrs_spin);
2030 }