SMB3: Resolve data corruption of TCP server info fields
authorRohith Surabattula <rohiths@microsoft.com>
Thu, 8 Oct 2020 09:58:41 +0000 (09:58 +0000)
committerSteve French <stfrench@microsoft.com>
Wed, 21 Oct 2020 22:56:23 +0000 (17:56 -0500)
TCP server info field server->total_read is modified in parallel by
demultiplex thread and decrypt offload worker thread. server->total_read
is used in calculation to discard the remaining data of PDU which is
not read into memory.

Because of parallel modification, server->total_read can get corrupted
and can result in discarding the valid data of next PDU.

Signed-off-by: Rohith Surabattula <rohiths@microsoft.com>
Reviewed-by: Aurelien Aptel <aaptel@suse.com>
Reviewed-by: Pavel Shilovsky <pshilov@microsoft.com>
CC: Stable <stable@vger.kernel.org> #5.4+
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/cifs/smb2ops.c

index f085fe3..2c3cfb2 100644 (file)
@@ -4131,7 +4131,8 @@ smb3_is_transform_hdr(void *buf)
 static int
 decrypt_raw_data(struct TCP_Server_Info *server, char *buf,
                 unsigned int buf_data_size, struct page **pages,
-                unsigned int npages, unsigned int page_data_size)
+                unsigned int npages, unsigned int page_data_size,
+                bool is_offloaded)
 {
        struct kvec iov[2];
        struct smb_rqst rqst = {NULL};
@@ -4157,7 +4158,8 @@ decrypt_raw_data(struct TCP_Server_Info *server, char *buf,
 
        memmove(buf, iov[1].iov_base, buf_data_size);
 
-       server->total_read = buf_data_size + page_data_size;
+       if (!is_offloaded)
+               server->total_read = buf_data_size + page_data_size;
 
        return rc;
 }
@@ -4370,7 +4372,7 @@ static void smb2_decrypt_offload(struct work_struct *work)
        struct mid_q_entry *mid;
 
        rc = decrypt_raw_data(dw->server, dw->buf, dw->server->vals->read_rsp_size,
-                             dw->ppages, dw->npages, dw->len);
+                             dw->ppages, dw->npages, dw->len, true);
        if (rc) {
                cifs_dbg(VFS, "error decrypting rc=%d\n", rc);
                goto free_pages;
@@ -4476,7 +4478,7 @@ receive_encrypted_read(struct TCP_Server_Info *server, struct mid_q_entry **mid,
 
 non_offloaded_decrypt:
        rc = decrypt_raw_data(server, buf, server->vals->read_rsp_size,
-                             pages, npages, len);
+                             pages, npages, len, false);
        if (rc)
                goto free_pages;
 
@@ -4532,7 +4534,7 @@ receive_encrypted_standard(struct TCP_Server_Info *server,
        server->total_read += length;
 
        buf_size = pdu_length - sizeof(struct smb2_transform_hdr);
-       length = decrypt_raw_data(server, buf, buf_size, NULL, 0, 0);
+       length = decrypt_raw_data(server, buf, buf_size, NULL, 0, 0, false);
        if (length)
                return length;