tls: Stricter error checking in zerocopy sendmsg path
[linux-2.6-microblaze.git] / net / tls / tls_sw.c
index f127fac..4618f1c 100644 (file)
@@ -440,7 +440,7 @@ alloc_encrypted:
                        ret = tls_push_record(sk, msg->msg_flags, record_type);
                        if (!ret)
                                continue;
-                       if (ret == -EAGAIN)
+                       if (ret < 0)
                                goto send_end;
 
                        copied -= try_to_copy;
@@ -701,6 +701,10 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
        nsg = skb_to_sgvec(skb, &sgin[1],
                           rxm->offset + tls_ctx->rx.prepend_size,
                           rxm->full_len - tls_ctx->rx.prepend_size);
+       if (nsg < 0) {
+               ret = nsg;
+               goto out;
+       }
 
        tls_make_aad(ctx->rx_aad_ciphertext,
                     rxm->full_len - tls_ctx->rx.overhead_size,
@@ -712,6 +716,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
                                rxm->full_len - tls_ctx->rx.overhead_size,
                                skb, sk->sk_allocation);
 
+out:
        if (sgin != &sgin_arr[0])
                kfree(sgin);
 
@@ -919,22 +924,23 @@ splice_read_end:
        return copied ? : err;
 }
 
-__poll_t tls_sw_poll_mask(struct socket *sock, __poll_t events)
+unsigned int tls_sw_poll(struct file *file, struct socket *sock,
+                        struct poll_table_struct *wait)
 {
+       unsigned int ret;
        struct sock *sk = sock->sk;
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-       __poll_t mask;
 
-       /* Grab EPOLLOUT and EPOLLHUP from the underlying socket */
-       mask = ctx->sk_poll_mask(sock, events);
+       /* Grab POLLOUT and POLLHUP from the underlying socket */
+       ret = ctx->sk_poll(file, sock, wait);
 
-       /* Clear EPOLLIN bits, and set based on recv_pkt */
-       mask &= ~(EPOLLIN | EPOLLRDNORM);
+       /* Clear POLLIN bits, and set based on recv_pkt */
+       ret &= ~(POLLIN | POLLRDNORM);
        if (ctx->recv_pkt)
-               mask |= EPOLLIN | EPOLLRDNORM;
+               ret |= POLLIN | POLLRDNORM;
 
-       return mask;
+       return ret;
 }
 
 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
@@ -1191,7 +1197,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
                sk->sk_data_ready = tls_data_ready;
                write_unlock_bh(&sk->sk_callback_lock);
 
-               sw_ctx_rx->sk_poll_mask = sk->sk_socket->ops->poll_mask;
+               sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
 
                strp_check_rcv(&sw_ctx_rx->strp);
        }