net/smc: Dynamically control handshake limitation via socket options
diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
index d5ea62b..97dcdc0 100644
@@ -59,6 +59,7 @@ static DEFINE_MUTEX(smc_client_lgr_pending);  /* serialize link group
                                                 * creation on client
                                                 */
 
+static struct workqueue_struct *smc_tcp_ls_wq; /* wq for tcp listen work */
 struct workqueue_struct        *smc_hs_wq;     /* wq for handshake work */
 struct workqueue_struct        *smc_close_wq;  /* wq for close work */
 
@@ -72,6 +73,51 @@ static void smc_set_keepalive(struct sock *sk, int val)
        smc->clcsock->sk->sk_prot->keepalive(smc->clcsock->sk, val);
 }
 
+static struct sock *smc_tcp_syn_recv_sock(const struct sock *sk,
+                                         struct sk_buff *skb,
+                                         struct request_sock *req,
+                                         struct dst_entry *dst,
+                                         struct request_sock *req_unhash,
+                                         bool *own_req)
+{
+       struct smc_sock *smc;
+
+       smc = smc_clcsock_user_data(sk);
+
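+       /* Drop the connection request if accepted connections plus SMC
+        * handshakes still in flight would exceed the listen backlog.
+        */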
+       if (READ_ONCE(sk->sk_ack_backlog) + atomic_read(&smc->queued_smc_hs) >
+                               sk->sk_max_ack_backlog)
+               goto drop;
+
+       if (sk_acceptq_is_full(&smc->sk)) {
+               NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS);
+               goto drop;
+       }
+
+       /* pass through to the original syn_recv_sock function */
+       return smc->ori_af_ops->syn_recv_sock(sk, skb, req, dst, req_unhash,
+                                             own_req);
+
+drop:
+       dst_release(dst);
+       tcp_listendrop(sk);
+       return NULL;
+}
+
+static bool smc_hs_congested(const struct sock *sk)
+{
+       const struct smc_sock *smc;
+
+       smc = smc_clcsock_user_data(sk);
+
+       if (!smc)
+               return true;
+
+       if (workqueue_congested(WORK_CPU_UNBOUND, smc_hs_wq))
+               return true;
+
+       return false;
+}
+
 static struct smc_hashinfo smc_v4_hashinfo = {
        .lock = __RW_LOCK_UNLOCKED(smc_v4_hashinfo.lock),
 };
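
Both new helpers recover the listening smc_sock from the clcsock's
sk_user_data. A minimal sketch of smc_clcsock_user_data(), inferred from the
open-coded cast this patch removes from smc_clcsock_data_ready() further
down (the helper itself presumably lives in net/smc/smc.h):

    static inline struct smc_sock *smc_clcsock_user_data(const struct sock *clcsk)
    {
            /* sk_user_data holds the smc_sock tagged with SK_USER_DATA_NOCOPY;
             * mask the flag bit off before using it as a pointer.
             */
            return (struct smc_sock *)
                   ((uintptr_t)clcsk->sk_user_data & ~SK_USER_DATA_NOCOPY);
    }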
@@ -566,17 +612,115 @@ static void smc_stat_fallback(struct smc_sock *smc)
        mutex_unlock(&net->smc.mutex_fback_rsn);
 }
 
+/* must be called under rcu read lock */
+static void smc_fback_wakeup_waitqueue(struct smc_sock *smc, void *key)
+{
+       struct socket_wq *wq;
+       __poll_t flags;
+
+       wq = rcu_dereference(smc->sk.sk_wq);
+       if (!skwq_has_sleeper(wq))
+               return;
+
+       /* wake up smc sk->sk_wq */
+       if (!key) {
+               /* sk_state_change */
+               wake_up_interruptible_all(&wq->wait);
+       } else {
+               flags = key_to_poll(key);
+               if (flags & (EPOLLIN | EPOLLOUT))
+                       /* sk_data_ready or sk_write_space */
+                       wake_up_interruptible_sync_poll(&wq->wait, flags);
+               else if (flags & EPOLLERR)
+                       /* sk_error_report */
+                       wake_up_interruptible_poll(&wq->wait, flags);
+       }
+}
+
+static int smc_fback_mark_woken(wait_queue_entry_t *wait,
+                               unsigned int mode, int sync, void *key)
+{
+       struct smc_mark_woken *mark =
+               container_of(wait, struct smc_mark_woken, wait_entry);
+
+       mark->woken = true;
+       mark->key = key;
+       return 0;
+}
+
+static void smc_fback_forward_wakeup(struct smc_sock *smc, struct sock *clcsk,
+                                    void (*clcsock_callback)(struct sock *sk))
+{
+       struct smc_mark_woken mark = { .woken = false };
+       struct socket_wq *wq;
+
+       init_waitqueue_func_entry(&mark.wait_entry,
+                                 smc_fback_mark_woken);
+       rcu_read_lock();
+       wq = rcu_dereference(clcsk->sk_wq);
+       if (!wq)
+               goto out;
+       add_wait_queue(sk_sleep(clcsk), &mark.wait_entry);
+       clcsock_callback(clcsk);
+       remove_wait_queue(sk_sleep(clcsk), &mark.wait_entry);
+
+       if (mark.woken)
+               smc_fback_wakeup_waitqueue(smc, mark.key);
+out:
+       rcu_read_unlock();
+}
+
+static void smc_fback_state_change(struct sock *clcsk)
+{
+       struct smc_sock *smc =
+               smc_clcsock_user_data(clcsk);
+
+       if (!smc)
+               return;
+       smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_state_change);
+}
+
+static void smc_fback_data_ready(struct sock *clcsk)
+{
+       struct smc_sock *smc =
+               smc_clcsock_user_data(clcsk);
+
+       if (!smc)
+               return;
+       smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_data_ready);
+}
+
+static void smc_fback_write_space(struct sock *clcsk)
+{
+       struct smc_sock *smc =
+               smc_clcsock_user_data(clcsk);
+
+       if (!smc)
+               return;
+       smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_write_space);
+}
+
+static void smc_fback_error_report(struct sock *clcsk)
+{
+       struct smc_sock *smc =
+               smc_clcsock_user_data(clcsk);
+
+       if (!smc)
+               return;
+       smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_error_report);
+}
+
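+
(Editor's note: smc_fback_forward_wakeup() parks a one-shot wait entry on the
clcsock's wait queue, invokes the original callback, and replays any wakeup it
recorded onto the smc socket's own wait queue via smc_fback_wakeup_waitqueue().
The bookkeeping struct filled in by smc_fback_mark_woken() is presumably
declared near the top of af_smc.c along these lines, fields inferred from its
uses above:

    struct smc_mark_woken {
            bool woken;                     /* set once a wakeup fired */
            void *key;                      /* poll flags passed to the wakeup */
            wait_queue_entry_t wait_entry;  /* hooked into clcsock's wait queue */
    };
)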
 static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
 {
-       wait_queue_head_t *smc_wait = sk_sleep(&smc->sk);
-       wait_queue_head_t *clc_wait;
-       unsigned long flags;
+       struct sock *clcsk;
 
        mutex_lock(&smc->clcsock_release_lock);
        if (!smc->clcsock) {
                mutex_unlock(&smc->clcsock_release_lock);
                return -EBADF;
        }
+       clcsk = smc->clcsock->sk;
+
        smc->use_fallback = true;
        smc->fallback_rsn = reason_code;
        smc_stat_fallback(smc);
@@ -587,16 +731,22 @@ static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
                smc->clcsock->wq.fasync_list =
                        smc->sk.sk_socket->wq.fasync_list;
 
-               /* There may be some entries remaining in
-                * smc socket->wq, which should be removed
-                * to clcsocket->wq during the fallback.
+               /* There might be some wait entries remaining
+                * in smc sk->sk_wq and they should be woken up
+                * as clcsock's wait queue is woken up.
                 */
-               clc_wait = sk_sleep(smc->clcsock->sk);
-               spin_lock_irqsave(&smc_wait->lock, flags);
-               spin_lock_nested(&clc_wait->lock, SINGLE_DEPTH_NESTING);
-               list_splice_init(&smc_wait->head, &clc_wait->head);
-               spin_unlock(&clc_wait->lock);
-               spin_unlock_irqrestore(&smc_wait->lock, flags);
+               smc->clcsk_state_change = clcsk->sk_state_change;
+               smc->clcsk_data_ready = clcsk->sk_data_ready;
+               smc->clcsk_write_space = clcsk->sk_write_space;
+               smc->clcsk_error_report = clcsk->sk_error_report;
+
+               clcsk->sk_state_change = smc_fback_state_change;
+               clcsk->sk_data_ready = smc_fback_data_ready;
+               clcsk->sk_write_space = smc_fback_write_space;
+               clcsk->sk_error_report = smc_fback_error_report;
+
+               smc->clcsock->sk->sk_user_data =
+                       (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
        }
        mutex_unlock(&smc->clcsock_release_lock);
        return 0;
@@ -1490,6 +1640,9 @@ static void smc_listen_out(struct smc_sock *new_smc)
        struct smc_sock *lsmc = new_smc->listen_smc;
        struct sock *newsmcsk = &new_smc->sk;
 
+       if (tcp_sk(new_smc->clcsock->sk)->syn_smc)
+               atomic_dec(&lsmc->queued_smc_hs);
+
        if (lsmc->sk.sk_state == SMC_LISTEN) {
                lock_sock_nested(&lsmc->sk, SINGLE_DEPTH_NESTING);
                smc_accept_enqueue(&lsmc->sk, newsmcsk);
@@ -2095,6 +2248,9 @@ static void smc_tcp_listen_work(struct work_struct *work)
                if (!new_smc)
                        continue;
 
+               if (tcp_sk(new_smc->clcsock->sk)->syn_smc)
+                       atomic_inc(&lsmc->queued_smc_hs);
+
                new_smc->listen_smc = lsmc;
                new_smc->use_fallback = lsmc->use_fallback;
                new_smc->fallback_rsn = lsmc->fallback_rsn;
@@ -2115,16 +2271,15 @@ out:
 
 static void smc_clcsock_data_ready(struct sock *listen_clcsock)
 {
-       struct smc_sock *lsmc;
+       struct smc_sock *lsmc =
+               smc_clcsock_user_data(listen_clcsock);
 
-       lsmc = (struct smc_sock *)
-              ((uintptr_t)listen_clcsock->sk_user_data & ~SK_USER_DATA_NOCOPY);
        if (!lsmc)
                return;
        lsmc->clcsk_data_ready(listen_clcsock);
        if (lsmc->sk.sk_state == SMC_LISTEN) {
                sock_hold(&lsmc->sk); /* sock_put in smc_tcp_listen_work() */
-               if (!queue_work(smc_hs_wq, &lsmc->tcp_listen_work))
+               if (!queue_work(smc_tcp_ls_wq, &lsmc->tcp_listen_work))
                        sock_put(&lsmc->sk);
        }
 }
@@ -2162,6 +2317,18 @@ static int smc_listen(struct socket *sock, int backlog)
        smc->clcsock->sk->sk_data_ready = smc_clcsock_data_ready;
        smc->clcsock->sk->sk_user_data =
                (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+
+       /* save original ops */
+       smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops;
+
+       smc->af_ops = *smc->ori_af_ops;
+       smc->af_ops.syn_recv_sock = smc_tcp_syn_recv_sock;
+
+       inet_csk(smc->clcsock->sk)->icsk_af_ops = &smc->af_ops;
+
+       if (smc->limit_smc_hs)
+               tcp_sk(smc->clcsock->sk)->smc_hs_congested = smc_hs_congested;
+
        rc = kernel_listen(smc->clcsock, backlog);
        if (rc) {
                smc->clcsock->sk->sk_data_ready = smc->clcsk_data_ready;
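
The members used here (ori_af_ops, af_ops, limit_smc_hs, queued_smc_hs and
the saved clcsk_* callbacks) come from a companion change to struct smc_sock.
A sketch of those fields as inferred from their uses in this file; layout and
comments are assumptions, the real definition being in net/smc/smc.h:

    struct smc_sock {
            /* ... existing members ... */
            void (*clcsk_state_change)(struct sock *sk); /* original callbacks, */
            void (*clcsk_data_ready)(struct sock *sk);   /* saved across fallback */
            void (*clcsk_write_space)(struct sock *sk);
            void (*clcsk_error_report)(struct sock *sk);
            struct inet_connection_sock_af_ops af_ops;   /* copy with patched
                                                          * syn_recv_sock
                                                          */
            const struct inet_connection_sock_af_ops *ori_af_ops; /* original ops */
            atomic_t queued_smc_hs;         /* SMC handshakes in flight */
            bool limit_smc_hs;              /* constrain handshakes when the
                                             * handshake wq is congested
                                             */
            /* ... */
    };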
@@ -2455,6 +2622,67 @@ out:
        return rc ? rc : rc1;
 }
 
+static int __smc_getsockopt(struct socket *sock, int level, int optname,
+                           char __user *optval, int __user *optlen)
+{
+       struct smc_sock *smc;
+       int val, len;
+
+       smc = smc_sk(sock->sk);
+
+       if (get_user(len, optlen))
+               return -EFAULT;
+
+       len = min_t(int, len, sizeof(int));
+
+       if (len < 0)
+               return -EINVAL;
+
+       switch (optname) {
+       case SMC_LIMIT_HS:
+               val = smc->limit_smc_hs;
+               break;
+       default:
+               return -EOPNOTSUPP;
+       }
+
+       if (put_user(len, optlen))
+               return -EFAULT;
+       if (copy_to_user(optval, &val, len))
+               return -EFAULT;
+
+       return 0;
+}
+
+static int __smc_setsockopt(struct socket *sock, int level, int optname,
+                           sockptr_t optval, unsigned int optlen)
+{
+       struct sock *sk = sock->sk;
+       struct smc_sock *smc;
+       int val, rc;
+
+       smc = smc_sk(sk);
+
+       lock_sock(sk);
+       switch (optname) {
+       case SMC_LIMIT_HS:
+               if (optlen < sizeof(int)) {
+                       rc = -EINVAL;
+                       break;
+               }
+               if (copy_from_sockptr(&val, optval, sizeof(int))) {
+                       rc = -EFAULT;
+                       break;
+               }
+
+               smc->limit_smc_hs = !!val;
+               rc = 0;
+               break;
+       default:
+               rc = -EOPNOTSUPP;
+               break;
+       }
+       release_sock(sk);
+
+       return rc;
+}
+
 static int smc_setsockopt(struct socket *sock, int level, int optname,
                          sockptr_t optval, unsigned int optlen)
 {
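
With SOL_SMC dispatched by smc_setsockopt()/smc_getsockopt() below, an
application can enable the limit per listen socket. A hypothetical usage
sketch; the SMC_LIMIT_HS and SOL_SMC values mirror the companion
uapi/socket.h changes and are guarded here in case installed headers
predate them (error handling elided):

    #include <sys/socket.h>
    #include <netinet/in.h>

    #ifndef AF_SMC
    #define AF_SMC 43
    #endif
    #ifndef SOL_SMC
    #define SOL_SMC 286             /* added by the companion socket.h patch */
    #endif
    #ifndef SMC_LIMIT_HS
    #define SMC_LIMIT_HS 1          /* added by the companion uapi patch */
    #endif

    int smc_listen_limited(unsigned short port)
    {
            struct sockaddr_in addr = {
                    .sin_family = AF_INET,
                    .sin_port = htons(port),
            };
            int fd = socket(AF_SMC, SOCK_STREAM, 0);
            int one = 1;

            /* Must be enabled before listen(): smc_listen() installs the
             * smc_hs_congested hook only if limit_smc_hs is already set.
             */
            setsockopt(fd, SOL_SMC, SMC_LIMIT_HS, &one, sizeof(one));
            bind(fd, (struct sockaddr *)&addr, sizeof(addr));
            listen(fd, 128);
            return fd;
    }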
@@ -2464,6 +2692,8 @@ static int smc_setsockopt(struct socket *sock, int level, int optname,
 
        if (level == SOL_TCP && optname == TCP_ULP)
                return -EOPNOTSUPP;
+       else if (level == SOL_SMC)
+               return __smc_setsockopt(sock, level, optname, optval, optlen);
 
        smc = smc_sk(sk);
 
@@ -2523,8 +2753,8 @@ static int smc_setsockopt(struct socket *sock, int level, int optname,
                    sk->sk_state != SMC_CLOSED) {
                        if (!val) {
                                SMC_STAT_INC(smc, cork_cnt);
-                               mod_delayed_work(smc->conn.lgr->tx_wq,
-                                                &smc->conn.tx_work, 0);
+                               smc_tx_pending(&smc->conn);
+                               cancel_delayed_work(&smc->conn.tx_work);
                        }
                }
                break;
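
Clearing TCP_CORK now pushes the queued data out synchronously through
smc_tx_pending() and cancels the delayed work, instead of rescheduling that
work with zero delay, so corked data is flushed before setsockopt() returns.
A hypothetical userspace pattern that relies on this (fd is a connected
AF_SMC socket; error handling elided):

    #include <stddef.h>
    #include <sys/socket.h>
    #include <netinet/in.h>         /* IPPROTO_TCP */
    #include <netinet/tcp.h>        /* TCP_CORK */

    void send_corked(int fd, const void *hdr, size_t hlen,
                     const void *body, size_t blen)
    {
            int on = 1, off = 0;

            setsockopt(fd, IPPROTO_TCP, TCP_CORK, &on, sizeof(on));
            send(fd, hdr, hlen, 0);         /* queued */
            send(fd, body, blen, 0);        /* queued */
            /* uncork: with this change the queued bytes go out
             * immediately rather than on the next tx_work run
             */
            setsockopt(fd, IPPROTO_TCP, TCP_CORK, &off, sizeof(off));
    }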
@@ -2546,6 +2776,9 @@ static int smc_getsockopt(struct socket *sock, int level, int optname,
        struct smc_sock *smc;
        int rc;
 
+       if (level == SOL_SMC)
+               return __smc_getsockopt(sock, level, optname, optval, optlen);
+
        smc = smc_sk(sock->sk);
        mutex_lock(&smc->clcsock_release_lock);
        if (!smc->clcsock) {
@@ -2662,8 +2895,10 @@ static ssize_t smc_sendpage(struct socket *sock, struct page *page,
                rc = kernel_sendpage(smc->clcsock, page, offset,
                                     size, flags);
        } else {
+               lock_sock(sk);
+               rc = smc_tx_sendpage(smc, page, offset, size, flags);
+               release_sock(sk);
                SMC_STAT_INC(smc, sendpage_cnt);
-               rc = sock_no_sendpage(sock, page, offset, size, flags);
        }
 
 out:
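
Instead of bouncing through sock_no_sendpage(), which goes back through the
socket layer's sendmsg dispatch and takes the sock lock again inside
smc_sendmsg(), sendpage now calls the SMC transmit path directly under the
lock taken here. A rough sketch of what smc_tx_sendpage() in
net/smc/smc_tx.c presumably does, inferred from this call site (an
assumption, not the verbatim helper): map the page and feed it to
smc_tx_sendmsg() as a kvec:

    ssize_t smc_tx_sendpage(struct smc_sock *smc, struct page *page, int offset,
                            size_t size, int flags)
    {
            struct msghdr msg = { .msg_flags = flags };
            char *kaddr = kmap(page);       /* page may be highmem */
            struct kvec iov = {
                    .iov_base = kaddr + offset,
                    .iov_len = size,
            };
            int rc;

            iov_iter_kvec(&msg.msg_iter, WRITE, &iov, 1, size);
            rc = smc_tx_sendmsg(smc, &msg, size);
            kunmap(page);
            return rc;
    }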
@@ -2919,9 +3154,14 @@ static int __init smc_init(void)
                goto out_nl;
 
        rc = -ENOMEM;
+
+       smc_tcp_ls_wq = alloc_workqueue("smc_tcp_ls_wq", 0, 0);
+       if (!smc_tcp_ls_wq)
+               goto out_pnet;
+
        smc_hs_wq = alloc_workqueue("smc_hs_wq", 0, 0);
        if (!smc_hs_wq)
-               goto out_pnet;
+               goto out_alloc_tcp_ls_wq;
 
        smc_close_wq = alloc_workqueue("smc_close_wq", 0, 0);
        if (!smc_close_wq)
@@ -2992,6 +3232,8 @@ out_alloc_wqs:
        destroy_workqueue(smc_close_wq);
 out_alloc_hs_wq:
        destroy_workqueue(smc_hs_wq);
+out_alloc_tcp_ls_wq:
+       destroy_workqueue(smc_tcp_ls_wq);
 out_pnet:
        smc_pnet_exit();
 out_nl:
@@ -3010,6 +3252,7 @@ static void __exit smc_exit(void)
        smc_core_exit();
        smc_ib_unregister_client();
        destroy_workqueue(smc_close_wq);
+       destroy_workqueue(smc_tcp_ls_wq);
        destroy_workqueue(smc_hs_wq);
        proto_unregister(&smc_proto6);
        proto_unregister(&smc_proto);