DLM: fix race condition between dlm_send and dlm_recv
diff --git a/fs/dlm/lowcomms.c b/fs/dlm/lowcomms.c
index 4813d0e..420946d 100644
--- a/fs/dlm/lowcomms.c
+++ b/fs/dlm/lowcomms.c
@@ -106,12 +106,11 @@ struct connection {
        struct mutex sock_mutex;
        unsigned long flags;
 #define CF_READ_PENDING 1
-#define CF_WRITE_PENDING 2
-#define CF_CONNECT_PENDING 3
 #define CF_INIT_PENDING 4
 #define CF_IS_OTHERCON 5
 #define CF_CLOSE 6
 #define CF_APP_LIMITED 7
+#define CF_CLOSING 8
        struct list_head writequeue;  /* List of outgoing writequeue_entries */
        spinlock_t writequeue_lock;
        int (*rx_action) (struct connection *); /* What to do when active */
@@ -124,10 +123,6 @@ struct connection {
        struct connection *othercon;
        struct work_struct rwork; /* Receive workqueue */
        struct work_struct swork; /* Send workqueue */
-       void (*orig_error_report)(struct sock *);
-       void (*orig_data_ready)(struct sock *);
-       void (*orig_state_change)(struct sock *);
-       void (*orig_write_space)(struct sock *);
 };
 #define sock2con(x) ((struct connection *)(x)->sk_user_data)
 
@@ -150,6 +145,13 @@ struct dlm_node_addr {
        struct sockaddr_storage *addr[DLM_MAX_ADDR_COUNT];
 };
 
+static struct listen_sock_callbacks {
+       void (*sk_error_report)(struct sock *);
+       void (*sk_data_ready)(struct sock *);
+       void (*sk_state_change)(struct sock *);
+       void (*sk_write_space)(struct sock *);
+} listen_sock;
+
 static LIST_HEAD(dlm_node_addrs);
 static DEFINE_SPINLOCK(dlm_node_addrs_spin);
 
@@ -427,16 +429,15 @@ static void lowcomms_write_space(struct sock *sk)
                clear_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags);
        }
 
-       if (!test_and_set_bit(CF_WRITE_PENDING, &con->flags))
-               queue_work(send_workqueue, &con->swork);
+       queue_work(send_workqueue, &con->swork);
 }
 
 static inline void lowcomms_connect_sock(struct connection *con)
 {
        if (test_bit(CF_CLOSE, &con->flags))
                return;
-       if (!test_and_set_bit(CF_CONNECT_PENDING, &con->flags))
-               queue_work(send_workqueue, &con->swork);
+       queue_work(send_workqueue, &con->swork);
+       cond_resched();
 }
 
 static void lowcomms_state_change(struct sock *sk)
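
A note on the two hunks above: the CF_WRITE_PENDING/CF_CONNECT_PENDING gating duplicated a guarantee the workqueue core already provides, since queue_work() refuses to queue a work item that is still pending. A minimal sketch of that behaviour (demo_wq, demo_work and demo_fn are made-up names for this example only, not part of the patch):

	static void demo_fn(struct work_struct *work) { }
	static DECLARE_WORK(demo_work, demo_fn);

	static void queue_twice(struct workqueue_struct *demo_wq)
	{
		bool first  = queue_work(demo_wq, &demo_work); /* true: newly queued */
		bool second = queue_work(demo_wq, &demo_work); /* false: still pending (assuming it has not started yet) */

		pr_info("queue_twice: first=%d second=%d\n", first, second);
	}

The cond_resched() added to lowcomms_connect_sock() presumably just yields so the newly queued send work gets a chance to run sooner; it is not needed for correctness of the queueing itself.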
@@ -480,7 +481,7 @@ static void lowcomms_error_report(struct sock *sk)
        if (con == NULL)
                goto out;
 
-       orig_report = con->orig_error_report;
+       orig_report = listen_sock.sk_error_report;
        if (con->sock == NULL ||
            kernel_getpeername(con->sock, (struct sockaddr *)&saddr, &buflen)) {
                printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
@@ -517,27 +518,31 @@ out:
 }
 
 /* Note: sk_callback_lock must be locked before calling this function. */
-static void save_callbacks(struct connection *con, struct sock *sk)
+static void save_listen_callbacks(struct socket *sock)
 {
-       con->orig_data_ready = sk->sk_data_ready;
-       con->orig_state_change = sk->sk_state_change;
-       con->orig_write_space = sk->sk_write_space;
-       con->orig_error_report = sk->sk_error_report;
+       struct sock *sk = sock->sk;
+
+       listen_sock.sk_data_ready = sk->sk_data_ready;
+       listen_sock.sk_state_change = sk->sk_state_change;
+       listen_sock.sk_write_space = sk->sk_write_space;
+       listen_sock.sk_error_report = sk->sk_error_report;
 }
 
-static void restore_callbacks(struct connection *con, struct sock *sk)
+static void restore_callbacks(struct socket *sock)
 {
+       struct sock *sk = sock->sk;
+
        write_lock_bh(&sk->sk_callback_lock);
        sk->sk_user_data = NULL;
-       sk->sk_data_ready = con->orig_data_ready;
-       sk->sk_state_change = con->orig_state_change;
-       sk->sk_write_space = con->orig_write_space;
-       sk->sk_error_report = con->orig_error_report;
+       sk->sk_data_ready = listen_sock.sk_data_ready;
+       sk->sk_state_change = listen_sock.sk_state_change;
+       sk->sk_write_space = listen_sock.sk_write_space;
+       sk->sk_error_report = listen_sock.sk_error_report;
        write_unlock_bh(&sk->sk_callback_lock);
 }
 
 /* Make a socket active */
-static void add_sock(struct socket *sock, struct connection *con, bool save_cb)
+static void add_sock(struct socket *sock, struct connection *con)
 {
        struct sock *sk = sock->sk;
 
@@ -545,8 +550,6 @@ static void add_sock(struct socket *sock, struct connection *con, bool save_cb)
        con->sock = sock;
 
        sk->sk_user_data = con;
-       if (save_cb)
-               save_callbacks(con, sk);
        /* Install a data_ready callback */
        sk->sk_data_ready = lowcomms_data_ready;
        sk->sk_write_space = lowcomms_write_space;
@@ -579,17 +582,16 @@ static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
 static void close_connection(struct connection *con, bool and_other,
                             bool tx, bool rx)
 {
-       clear_bit(CF_CONNECT_PENDING, &con->flags);
-       clear_bit(CF_WRITE_PENDING, &con->flags);
-       if (tx && cancel_work_sync(&con->swork))
+       bool closing = test_and_set_bit(CF_CLOSING, &con->flags);
+
+       if (tx && !closing && cancel_work_sync(&con->swork))
                log_print("canceled swork for node %d", con->nodeid);
-       if (rx && cancel_work_sync(&con->rwork))
+       if (rx && !closing && cancel_work_sync(&con->rwork))
                log_print("canceled rwork for node %d", con->nodeid);
 
        mutex_lock(&con->sock_mutex);
        if (con->sock) {
-               if (!test_bit(CF_IS_OTHERCON, &con->flags))
-                       restore_callbacks(con, con->sock->sk);
+               restore_callbacks(con->sock);
                sock_release(con->sock);
                con->sock = NULL;
        }
@@ -604,6 +606,7 @@ static void close_connection(struct connection *con, bool and_other,
 
        con->retries = 0;
        mutex_unlock(&con->sock_mutex);
+       clear_bit(CF_CLOSING, &con->flags);
 }
 
 /* Data received from remote end */
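
The new CF_CLOSING bit makes teardown single-entry: only the caller that flips the bit from 0 to 1 performs the cancel_work_sync() calls, while a concurrent closer (possibly running inside the very work item the first one is cancelling) skips them and goes straight to the socket release under sock_mutex. This appears to be the dlm_send/dlm_recv race named in the subject. The guard in isolation, with the tx/rx selectors dropped for brevity:

	if (!test_and_set_bit(CF_CLOSING, &con->flags)) {
		cancel_work_sync(&con->swork);
		cancel_work_sync(&con->rwork);
	}
	mutex_lock(&con->sock_mutex);
	/* ... restore callbacks, release con->sock ... */
	mutex_unlock(&con->sock_mutex);
	clear_bit(CF_CLOSING, &con->flags);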
@@ -802,7 +805,7 @@ static int tcp_accept_from_sock(struct connection *con)
                        newcon->othercon = othercon;
                        othercon->sock = newsock;
                        newsock->sk->sk_user_data = othercon;
-                       add_sock(newsock, othercon, false);
+                       add_sock(newsock, othercon);
                        addcon = othercon;
                }
                else {
@@ -818,7 +821,7 @@ static int tcp_accept_from_sock(struct connection *con)
                /* accept copies the sk after we've saved the callbacks, so we
                   don't want to save them a second time or comm errors will
                   result in calling sk_error_report recursively. */
-               add_sock(newsock, newcon, false);
+               add_sock(newsock, newcon);
                addcon = newcon;
        }
 
@@ -919,7 +922,7 @@ static int sctp_accept_from_sock(struct connection *con)
                        newcon->othercon = othercon;
                        othercon->sock = newsock;
                        newsock->sk->sk_user_data = othercon;
-                       add_sock(newsock, othercon, false);
+                       add_sock(newsock, othercon);
                        addcon = othercon;
                } else {
                        printk("Extra connection from node %d attempted\n", nodeid);
@@ -930,7 +933,7 @@ static int sctp_accept_from_sock(struct connection *con)
        } else {
                newsock->sk->sk_user_data = newcon;
                newcon->rx_action = receive_from_sock;
-               add_sock(newsock, newcon, false);
+               add_sock(newsock, newcon);
                addcon = newcon;
        }
 
@@ -1058,7 +1061,7 @@ static void sctp_connect_to_sock(struct connection *con)
        sock->sk->sk_user_data = con;
        con->rx_action = receive_from_sock;
        con->connect_action = sctp_connect_to_sock;
-       add_sock(sock, con, true);
+       add_sock(sock, con);
 
        /* Bind to all addresses. */
        if (sctp_bind_addrs(con, 0))
@@ -1079,7 +1082,6 @@ static void sctp_connect_to_sock(struct connection *con)
        if (result == 0)
                goto out;
 
-
 bind_err:
        con->sock = NULL;
        sock_release(sock);
@@ -1098,14 +1100,12 @@ socket_err:
                          con->retries, result);
                mutex_unlock(&con->sock_mutex);
                msleep(1000);
-               clear_bit(CF_CONNECT_PENDING, &con->flags);
                lowcomms_connect_sock(con);
                return;
        }
 
 out:
        mutex_unlock(&con->sock_mutex);
-       set_bit(CF_WRITE_PENDING, &con->flags);
 }
 
 /* Connect a new socket to its peer */
@@ -1146,7 +1146,7 @@ static void tcp_connect_to_sock(struct connection *con)
        sock->sk->sk_user_data = con;
        con->rx_action = receive_from_sock;
        con->connect_action = tcp_connect_to_sock;
-       add_sock(sock, con, true);
+       add_sock(sock, con);
 
        /* Bind to our cluster-known address connecting to avoid
           routing problems */
@@ -1194,13 +1194,11 @@ out_err:
                          con->retries, result);
                mutex_unlock(&con->sock_mutex);
                msleep(1000);
-               clear_bit(CF_CONNECT_PENDING, &con->flags);
                lowcomms_connect_sock(con);
                return;
        }
 out:
        mutex_unlock(&con->sock_mutex);
-       set_bit(CF_WRITE_PENDING, &con->flags);
        return;
 }
 
@@ -1236,7 +1234,7 @@ static struct socket *tcp_create_listen_sock(struct connection *con,
                log_print("Failed to set SO_REUSEADDR on socket: %d", result);
        }
        sock->sk->sk_user_data = con;
-
+       save_listen_callbacks(sock);
        con->rx_action = tcp_accept_from_sock;
        con->connect_action = tcp_connect_to_sock;
 
@@ -1320,6 +1318,7 @@ static int sctp_listen_for_all(void)
        write_lock_bh(&sock->sk->sk_callback_lock);
        /* Init con struct */
        sock->sk->sk_user_data = con;
+       save_listen_callbacks(sock);
        con->sock = sock;
        con->sock->sk->sk_data_ready = lowcomms_data_ready;
        con->rx_action = sctp_accept_from_sock;
@@ -1366,7 +1365,7 @@ static int tcp_listen_for_all(void)
 
        sock = tcp_create_listen_sock(con, dlm_local_addr[0]);
        if (sock) {
-               add_sock(sock, con, true);
+               add_sock(sock, con);
                result = 0;
        }
        else {
@@ -1456,9 +1455,7 @@ void dlm_lowcomms_commit_buffer(void *mh)
        e->len = e->end - e->offset;
        spin_unlock(&con->writequeue_lock);
 
-       if (!test_and_set_bit(CF_WRITE_PENDING, &con->flags)) {
-               queue_work(send_workqueue, &con->swork);
-       }
+       queue_work(send_workqueue, &con->swork);
        return;
 
 out:
@@ -1528,12 +1525,15 @@ out:
 send_error:
        mutex_unlock(&con->sock_mutex);
        close_connection(con, false, false, true);
-       lowcomms_connect_sock(con);
+       /* Requeue the send work. When the work daemon runs again, it will try
+          a new connection, then call this function again. */
+       queue_work(send_workqueue, &con->swork);
        return;
 
 out_connect:
        mutex_unlock(&con->sock_mutex);
-       lowcomms_connect_sock(con);
+       queue_work(send_workqueue, &con->swork);
+       cond_resched();
 }
 
 static void clean_one_writequeue(struct connection *con)
@@ -1593,9 +1593,9 @@ static void process_send_sockets(struct work_struct *work)
 {
        struct connection *con = container_of(work, struct connection, swork);
 
-       if (test_and_clear_bit(CF_CONNECT_PENDING, &con->flags))
+       if (con->sock == NULL) /* not mutex protected so check it inside too */
                con->connect_action(con);
-       if (test_and_clear_bit(CF_WRITE_PENDING, &con->flags))
+       if (!list_empty(&con->writequeue))
                send_to_sock(con);
 }
 
@@ -1632,11 +1632,20 @@ static int work_start(void)
        return 0;
 }
 
-static void stop_conn(struct connection *con)
+static void _stop_conn(struct connection *con, bool and_other)
 {
-       con->flags |= 0x0F;
+       mutex_lock(&con->sock_mutex);
+       set_bit(CF_READ_PENDING, &con->flags);
        if (con->sock && con->sock->sk)
                con->sock->sk->sk_user_data = NULL;
+       if (con->othercon && and_other)
+               _stop_conn(con->othercon, false);
+       mutex_unlock(&con->sock_mutex);
+}
+
+static void stop_conn(struct connection *con)
+{
+       _stop_conn(con, true);
 }
 
 static void free_conn(struct connection *con)
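
For reference, the old blanket assignment set bits 0-3, i.e. CF_READ_PENDING plus the now-removed CF_WRITE_PENDING and CF_CONNECT_PENDING (bit 0 has no CF_ name in this file), with a plain non-atomic read-modify-write. With only CF_READ_PENDING still meaningful here, the replacement sets exactly that bit atomically, and takes sock_mutex around the whole thing, presumably so clearing sk_user_data cannot race with a concurrent connect or close:

	con->flags |= 0x0F;			/* old: non-atomic, sets bits 0-3 (0x01|0x02|0x04|0x08) */
	set_bit(CF_READ_PENDING, &con->flags);	/* new: atomic, sets only bit 1 (0x02) */

_stop_conn() also recurses once into con->othercon so the incoming half of a duplex pair gets the same treatment.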
@@ -1648,6 +1657,32 @@ static void free_conn(struct connection *con)
        kmem_cache_free(con_cache, con);
 }
 
+static void work_flush(void)
+{
+       int ok;
+       int i;
+       struct hlist_node *n;
+       struct connection *con;
+
+       flush_workqueue(recv_workqueue);
+       flush_workqueue(send_workqueue);
+       do {
+               ok = 1;
+               foreach_conn(stop_conn);
+               flush_workqueue(recv_workqueue);
+               flush_workqueue(send_workqueue);
+               for (i = 0; i < CONN_HASH_SIZE && ok; i++) {
+                       hlist_for_each_entry_safe(con, n,
+                                                 &connection_hash[i], list) {
+                               ok &= test_bit(CF_READ_PENDING, &con->flags);
+                               if (con->othercon)
+                                       ok &= test_bit(CF_READ_PENDING,
+                                                      &con->othercon->flags);
+                       }
+               }
+       } while (!ok);
+}
+
 void dlm_lowcomms_stop(void)
 {
        /* Set all the flags to prevent any
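
work_flush() is a fixed-point loop: stop_conn() sets CF_READ_PENDING on every connection, both workqueues are flushed, and a pass only counts as stable if the bit is still set everywhere afterwards. That works because, in this version of lowcomms, the receive work handler clears CF_READ_PENDING on entry, so a bit that survives both flushes means no receive work has run since that connection was marked. Abbreviated from the unchanged code and best treated as a sketch:

	static void process_recv_sockets(struct work_struct *work)
	{
		struct connection *con = container_of(work, struct connection, rwork);
		int err;

		clear_bit(CF_READ_PENDING, &con->flags);	/* resets the "stopped" marker */
		do {
			err = con->rx_action(con);
		} while (!err);
	}

The final hunk below also drops connections_lock before calling work_flush(), presumably because flushing while holding it could wait forever on handlers such as the accept path, which takes the same mutex to test dlm_allow_conn.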
@@ -1655,11 +1690,10 @@ void dlm_lowcomms_stop(void)
        */
        mutex_lock(&connections_lock);
        dlm_allow_conn = 0;
-       foreach_conn(stop_conn);
+       mutex_unlock(&connections_lock);
+       work_flush();
        clean_writequeues();
        foreach_conn(free_conn);
-       mutex_unlock(&connections_lock);
-
        work_stop();
 
        kmem_cache_destroy(con_cache);