Merge tag '6.2-rc-smb3-client-fixes-part1' of git://git.samba.org/sfrench/cifs-2.6
[linux-2.6-microblaze.git] / fs / cifs / smb2ops.c
index 32b3877..d33b00a 100644 (file)
@@ -4204,69 +4204,82 @@ fill_transform_hdr(struct smb2_transform_hdr *tr_hdr, unsigned int orig_len,
        memcpy(&tr_hdr->SessionId, &shdr->SessionId, 8);
 }
 
-/* We can not use the normal sg_set_buf() as we will sometimes pass a
- * stack object as buf.
- */
-static inline void smb2_sg_set_buf(struct scatterlist *sg, const void *buf,
-                                  unsigned int buflen)
+static void *smb2_aead_req_alloc(struct crypto_aead *tfm, const struct smb_rqst *rqst,
+                                int num_rqst, const u8 *sig, u8 **iv,
+                                struct aead_request **req, struct scatterlist **sgl,
+                                unsigned int *num_sgs)
 {
-       void *addr;
-       /*
-        * VMAP_STACK (at least) puts stack into the vmalloc address space
-        */
-       if (is_vmalloc_addr(buf))
-               addr = vmalloc_to_page(buf);
-       else
-               addr = virt_to_page(buf);
-       sg_set_page(sg, addr, buflen, offset_in_page(buf));
+       unsigned int req_size = sizeof(**req) + crypto_aead_reqsize(tfm);
+       unsigned int iv_size = crypto_aead_ivsize(tfm);
+       unsigned int len;
+       u8 *p;
+
+       *num_sgs = cifs_get_num_sgs(rqst, num_rqst, sig);
+
+       len = iv_size;
+       len += crypto_aead_alignmask(tfm) & ~(crypto_tfm_ctx_alignment() - 1);
+       len = ALIGN(len, crypto_tfm_ctx_alignment());
+       len += req_size;
+       len = ALIGN(len, __alignof__(struct scatterlist));
+       len += *num_sgs * sizeof(**sgl);
+
+       p = kmalloc(len, GFP_ATOMIC);
+       if (!p)
+               return NULL;
+
+       *iv = (u8 *)PTR_ALIGN(p, crypto_aead_alignmask(tfm) + 1);
+       *req = (struct aead_request *)PTR_ALIGN(*iv + iv_size,
+                                               crypto_tfm_ctx_alignment());
+       *sgl = (struct scatterlist *)PTR_ALIGN((u8 *)*req + req_size,
+                                              __alignof__(struct scatterlist));
+       return p;
 }
 
-/* Assumes the first rqst has a transform header as the first iov.
- * I.e.
- * rqst[0].rq_iov[0]  is transform header
- * rqst[0].rq_iov[1+] data to be encrypted/decrypted
- * rqst[1+].rq_iov[0+] data to be encrypted/decrypted
- */
-static struct scatterlist *
-init_sg(int num_rqst, struct smb_rqst *rqst, u8 *sign)
+static void *smb2_get_aead_req(struct crypto_aead *tfm, const struct smb_rqst *rqst,
+                              int num_rqst, const u8 *sig, u8 **iv,
+                              struct aead_request **req, struct scatterlist **sgl)
 {
-       unsigned int sg_len;
+       unsigned int off, len, skip;
        struct scatterlist *sg;
-       unsigned int i;
-       unsigned int j;
-       unsigned int idx = 0;
-       int skip;
-
-       sg_len = 1;
-       for (i = 0; i < num_rqst; i++)
-               sg_len += rqst[i].rq_nvec + rqst[i].rq_npages;
+       unsigned int num_sgs;
+       unsigned long addr;
+       int i, j;
+       void *p;
 
-       sg = kmalloc_array(sg_len, sizeof(struct scatterlist), GFP_KERNEL);
-       if (!sg)
+       p = smb2_aead_req_alloc(tfm, rqst, num_rqst, sig, iv, req, sgl, &num_sgs);
+       if (!p)
                return NULL;
 
-       sg_init_table(sg, sg_len);
+       sg_init_table(*sgl, num_sgs);
+       sg = *sgl;
+
+       /* Assumes the first rqst has a transform header as the first iov.
+        * I.e.
+        * rqst[0].rq_iov[0]  is transform header
+        * rqst[0].rq_iov[1+] data to be encrypted/decrypted
+        * rqst[1+].rq_iov[0+] data to be encrypted/decrypted
+        */
        for (i = 0; i < num_rqst; i++) {
+               /*
+                * The first rqst has a transform header where the
+                * first 20 bytes are not part of the encrypted blob.
+                */
                for (j = 0; j < rqst[i].rq_nvec; j++) {
-                       /*
-                        * The first rqst has a transform header where the
-                        * first 20 bytes are not part of the encrypted blob
-                        */
-                       skip = (i == 0) && (j == 0) ? 20 : 0;
-                       smb2_sg_set_buf(&sg[idx++],
-                                       rqst[i].rq_iov[j].iov_base + skip,
-                                       rqst[i].rq_iov[j].iov_len - skip);
-                       }
+                       struct kvec *iov = &rqst[i].rq_iov[j];
 
+                       skip = (i == 0) && (j == 0) ? 20 : 0;
+                       addr = (unsigned long)iov->iov_base + skip;
+                       len = iov->iov_len - skip;
+                       sg = cifs_sg_set_buf(sg, (void *)addr, len);
+               }
                for (j = 0; j < rqst[i].rq_npages; j++) {
-                       unsigned int len, offset;
-
-                       rqst_page_get_length(&rqst[i], j, &len, &offset);
-                       sg_set_page(&sg[idx++], rqst[i].rq_pages[j], len, offset);
+                       rqst_page_get_length(&rqst[i], j, &len, &off);
+                       sg_set_page(sg++, rqst[i].rq_pages[j], len, off);
                }
        }
-       smb2_sg_set_buf(&sg[idx], sign, SMB2_SIGNATURE_SIZE);
-       return sg;
+       cifs_sg_set_buf(sg, sig, SMB2_SIGNATURE_SIZE);
+
+       return p;
 }
 
 static int
@@ -4314,11 +4327,11 @@ crypt_message(struct TCP_Server_Info *server, int num_rqst,
        u8 sign[SMB2_SIGNATURE_SIZE] = {};
        u8 key[SMB3_ENC_DEC_KEY_SIZE];
        struct aead_request *req;
-       char *iv;
-       unsigned int iv_len;
+       u8 *iv;
        DECLARE_CRYPTO_WAIT(wait);
        struct crypto_aead *tfm;
        unsigned int crypt_len = le32_to_cpu(tr_hdr->OriginalMessageSize);
+       void *creq;
 
        rc = smb2_get_enc_key(server, le64_to_cpu(tr_hdr->SessionId), enc, key);
        if (rc) {
@@ -4352,32 +4365,15 @@ crypt_message(struct TCP_Server_Info *server, int num_rqst,
                return rc;
        }
 
-       req = aead_request_alloc(tfm, GFP_KERNEL);
-       if (!req) {
-               cifs_server_dbg(VFS, "%s: Failed to alloc aead request\n", __func__);
+       creq = smb2_get_aead_req(tfm, rqst, num_rqst, sign, &iv, &req, &sg);
+       if (unlikely(!creq))
                return -ENOMEM;
-       }
 
        if (!enc) {
                memcpy(sign, &tr_hdr->Signature, SMB2_SIGNATURE_SIZE);
                crypt_len += SMB2_SIGNATURE_SIZE;
        }
 
-       sg = init_sg(num_rqst, rqst, sign);
-       if (!sg) {
-               cifs_server_dbg(VFS, "%s: Failed to init sg\n", __func__);
-               rc = -ENOMEM;
-               goto free_req;
-       }
-
-       iv_len = crypto_aead_ivsize(tfm);
-       iv = kzalloc(iv_len, GFP_KERNEL);
-       if (!iv) {
-               cifs_server_dbg(VFS, "%s: Failed to alloc iv\n", __func__);
-               rc = -ENOMEM;
-               goto free_sg;
-       }
-
        if ((server->cipher_type == SMB2_ENCRYPTION_AES128_GCM) ||
            (server->cipher_type == SMB2_ENCRYPTION_AES256_GCM))
                memcpy(iv, (char *)tr_hdr->Nonce, SMB3_AES_GCM_NONCE);
@@ -4386,6 +4382,7 @@ crypt_message(struct TCP_Server_Info *server, int num_rqst,
                memcpy(iv + 1, (char *)tr_hdr->Nonce, SMB3_AES_CCM_NONCE);
        }
 
+       aead_request_set_tfm(req, tfm);
        aead_request_set_crypt(req, sg, sg, crypt_len, iv);
        aead_request_set_ad(req, assoc_data_len);
 
@@ -4398,11 +4395,7 @@ crypt_message(struct TCP_Server_Info *server, int num_rqst,
        if (!rc && enc)
                memcpy(&tr_hdr->Signature, sign, SMB2_SIGNATURE_SIZE);
 
-       kfree_sensitive(iv);
-free_sg:
-       kfree_sensitive(sg);
-free_req:
-       kfree_sensitive(req);
+       kfree_sensitive(creq);
        return rc;
 }
 
@@ -4445,21 +4438,27 @@ smb3_init_transform_rq(struct TCP_Server_Info *server, int num_rqst,
        int rc = -ENOMEM;
 
        for (i = 1; i < num_rqst; i++) {
-               npages = old_rq[i - 1].rq_npages;
+               struct smb_rqst *old = &old_rq[i - 1];
+               struct smb_rqst *new = &new_rq[i];
+
+               orig_len += smb_rqst_len(server, old);
+               new->rq_iov = old->rq_iov;
+               new->rq_nvec = old->rq_nvec;
+
+               npages = old->rq_npages;
+               if (!npages)
+                       continue;
+
                pages = kmalloc_array(npages, sizeof(struct page *),
                                      GFP_KERNEL);
                if (!pages)
                        goto err_free;
 
-               new_rq[i].rq_pages = pages;
-               new_rq[i].rq_npages = npages;
-               new_rq[i].rq_offset = old_rq[i - 1].rq_offset;
-               new_rq[i].rq_pagesz = old_rq[i - 1].rq_pagesz;
-               new_rq[i].rq_tailsz = old_rq[i - 1].rq_tailsz;
-               new_rq[i].rq_iov = old_rq[i - 1].rq_iov;
-               new_rq[i].rq_nvec = old_rq[i - 1].rq_nvec;
-
-               orig_len += smb_rqst_len(server, &old_rq[i - 1]);
+               new->rq_pages = pages;
+               new->rq_npages = npages;
+               new->rq_offset = old->rq_offset;
+               new->rq_pagesz = old->rq_pagesz;
+               new->rq_tailsz = old->rq_tailsz;
 
                for (j = 0; j < npages; j++) {
                        pages[j] = alloc_page(GFP_KERNEL|__GFP_HIGHMEM);
@@ -4472,14 +4471,14 @@ smb3_init_transform_rq(struct TCP_Server_Info *server, int num_rqst,
                        char *dst, *src;
                        unsigned int offset, len;
 
-                       rqst_page_get_length(&new_rq[i], j, &len, &offset);
+                       rqst_page_get_length(new, j, &len, &offset);
 
-                       dst = (char *) kmap(new_rq[i].rq_pages[j]) + offset;
-                       src = (char *) kmap(old_rq[i - 1].rq_pages[j]) + offset;
+                       dst = kmap_local_page(new->rq_pages[j]) + offset;
+                       src = kmap_local_page(old->rq_pages[j]) + offset;
 
                        memcpy(dst, src, len);
-                       kunmap(new_rq[i].rq_pages[j]);
-                       kunmap(old_rq[i - 1].rq_pages[j]);
+                       kunmap(new->rq_pages[j]);
+                       kunmap(old->rq_pages[j]);
                }
        }