crypto: lib/mpi - Add error checks to extension
authorHerbert Xu <herbert@gondor.apana.org.au>
Sat, 10 Aug 2024 06:20:57 +0000 (14:20 +0800)
committerHerbert Xu <herbert@gondor.apana.org.au>
Sat, 17 Aug 2024 05:55:50 +0000 (13:55 +0800)
The remaining functions added by commit
a8ea8bdd9df92a0e5db5b43900abb7a288b8a53e did not check for memory
allocation errors.  Add the checks and change the API to allow errors
to be returned.

Fixes: a8ea8bdd9df9 ("lib/mpi: Extend the MPI library")
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
include/linux/mpi.h
lib/crypto/mpi/mpi-add.c
lib/crypto/mpi/mpi-bit.c
lib/crypto/mpi/mpi-div.c
lib/crypto/mpi/mpi-internal.h
lib/crypto/mpi/mpi-mod.c
lib/crypto/mpi/mpi-mul.c
lib/crypto/mpi/mpiutil.c

index e081428..47be46f 100644 (file)
@@ -59,7 +59,7 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sg, unsigned nbytes,
                     int *sign);
 
 /*-- mpi-mod.c --*/
-void mpi_mod(MPI rem, MPI dividend, MPI divisor);
+int mpi_mod(MPI rem, MPI dividend, MPI divisor);
 
 /*-- mpi-pow.c --*/
 int mpi_powm(MPI res, MPI base, MPI exp, MPI mod);
@@ -75,22 +75,22 @@ int mpi_sub_ui(MPI w, MPI u, unsigned long vval);
 void mpi_normalize(MPI a);
 unsigned mpi_get_nbits(MPI a);
 int mpi_test_bit(MPI a, unsigned int n);
-void mpi_set_bit(MPI a, unsigned int n);
-void mpi_rshift(MPI x, MPI a, unsigned int n);
+int mpi_set_bit(MPI a, unsigned int n);
+int mpi_rshift(MPI x, MPI a, unsigned int n);
 
 /*-- mpi-add.c --*/
-void mpi_add(MPI w, MPI u, MPI v);
-void mpi_sub(MPI w, MPI u, MPI v);
-void mpi_addm(MPI w, MPI u, MPI v, MPI m);
-void mpi_subm(MPI w, MPI u, MPI v, MPI m);
+int mpi_add(MPI w, MPI u, MPI v);
+int mpi_sub(MPI w, MPI u, MPI v);
+int mpi_addm(MPI w, MPI u, MPI v, MPI m);
+int mpi_subm(MPI w, MPI u, MPI v, MPI m);
 
 /*-- mpi-mul.c --*/
-void mpi_mul(MPI w, MPI u, MPI v);
-void mpi_mulm(MPI w, MPI u, MPI v, MPI m);
+int mpi_mul(MPI w, MPI u, MPI v);
+int mpi_mulm(MPI w, MPI u, MPI v, MPI m);
 
 /*-- mpi-div.c --*/
-void mpi_tdiv_r(MPI rem, MPI num, MPI den);
-void mpi_fdiv_r(MPI rem, MPI dividend, MPI divisor);
+int mpi_tdiv_r(MPI rem, MPI num, MPI den);
+int mpi_fdiv_r(MPI rem, MPI dividend, MPI divisor);
 
 /* inline functions */
 
index b47c8c3..3015140 100644 (file)
 
 #include "mpi-internal.h"
 
-void mpi_add(MPI w, MPI u, MPI v)
+int mpi_add(MPI w, MPI u, MPI v)
 {
        mpi_ptr_t wp, up, vp;
        mpi_size_t usize, vsize, wsize;
        int usign, vsign, wsign;
+       int err;
 
        if (u->nlimbs < v->nlimbs) { /* Swap U and V. */
                usize = v->nlimbs;
@@ -25,7 +26,9 @@ void mpi_add(MPI w, MPI u, MPI v)
                vsize = u->nlimbs;
                vsign = u->sign;
                wsize = usize + 1;
-               RESIZE_IF_NEEDED(w, wsize);
+               err = RESIZE_IF_NEEDED(w, wsize);
+               if (err)
+                       return err;
                /* These must be after realloc (u or v may be the same as w).  */
                up = v->d;
                vp = u->d;
@@ -35,7 +38,9 @@ void mpi_add(MPI w, MPI u, MPI v)
                vsize = v->nlimbs;
                vsign = v->sign;
                wsize = usize + 1;
-               RESIZE_IF_NEEDED(w, wsize);
+               err = RESIZE_IF_NEEDED(w, wsize);
+               if (err)
+                       return err;
                /* These must be after realloc (u or v may be the same as w).  */
                up = u->d;
                vp = v->d;
@@ -77,28 +82,37 @@ void mpi_add(MPI w, MPI u, MPI v)
 
        w->nlimbs = wsize;
        w->sign = wsign;
+       return 0;
 }
 EXPORT_SYMBOL_GPL(mpi_add);
 
-void mpi_sub(MPI w, MPI u, MPI v)
+int mpi_sub(MPI w, MPI u, MPI v)
 {
-       MPI vv = mpi_copy(v);
+       int err;
+       MPI vv;
+
+       vv = mpi_copy(v);
+       if (!vv)
+               return -ENOMEM;
+
        vv->sign = !vv->sign;
-       mpi_add(w, u, vv);
+       err = mpi_add(w, u, vv);
        mpi_free(vv);
+
+       return err;
 }
 EXPORT_SYMBOL_GPL(mpi_sub);
 
-void mpi_addm(MPI w, MPI u, MPI v, MPI m)
+int mpi_addm(MPI w, MPI u, MPI v, MPI m)
 {
-       mpi_add(w, u, v);
-       mpi_mod(w, w, m);
+       return mpi_add(w, u, v) ?:
+              mpi_mod(w, w, m);
 }
 EXPORT_SYMBOL_GPL(mpi_addm);
 
-void mpi_subm(MPI w, MPI u, MPI v, MPI m)
+int mpi_subm(MPI w, MPI u, MPI v, MPI m)
 {
-       mpi_sub(w, u, v);
-       mpi_mod(w, w, m);
+       return mpi_sub(w, u, v) ?:
+              mpi_mod(w, w, m);
 }
 EXPORT_SYMBOL_GPL(mpi_subm);
index c29b853..835a2f0 100644 (file)
@@ -76,9 +76,10 @@ EXPORT_SYMBOL_GPL(mpi_test_bit);
 /****************
  * Set bit N of A.
  */
-void mpi_set_bit(MPI a, unsigned int n)
+int mpi_set_bit(MPI a, unsigned int n)
 {
        unsigned int i, limbno, bitno;
+       int err;
 
        limbno = n / BITS_PER_MPI_LIMB;
        bitno  = n % BITS_PER_MPI_LIMB;
@@ -86,27 +87,31 @@ void mpi_set_bit(MPI a, unsigned int n)
        if (limbno >= a->nlimbs) {
                for (i = a->nlimbs; i < a->alloced; i++)
                        a->d[i] = 0;
-               mpi_resize(a, limbno+1);
+               err = mpi_resize(a, limbno+1);
+               if (err)
+                       return err;
                a->nlimbs = limbno+1;
        }
        a->d[limbno] |= (A_LIMB_1<<bitno);
+       return 0;
 }
 
 /*
  * Shift A by N bits to the right.
  */
-void mpi_rshift(MPI x, MPI a, unsigned int n)
+int mpi_rshift(MPI x, MPI a, unsigned int n)
 {
        mpi_size_t xsize;
        unsigned int i;
        unsigned int nlimbs = (n/BITS_PER_MPI_LIMB);
        unsigned int nbits = (n%BITS_PER_MPI_LIMB);
+       int err;
 
        if (x == a) {
                /* In-place operation.  */
                if (nlimbs >= x->nlimbs) {
                        x->nlimbs = 0;
-                       return;
+                       return 0;
                }
 
                if (nlimbs) {
@@ -121,7 +126,9 @@ void mpi_rshift(MPI x, MPI a, unsigned int n)
                /* Copy and shift by more or equal bits than in a limb. */
                xsize = a->nlimbs;
                x->sign = a->sign;
-               RESIZE_IF_NEEDED(x, xsize);
+               err = RESIZE_IF_NEEDED(x, xsize);
+               if (err)
+                       return err;
                x->nlimbs = xsize;
                for (i = 0; i < a->nlimbs; i++)
                        x->d[i] = a->d[i];
@@ -129,7 +136,7 @@ void mpi_rshift(MPI x, MPI a, unsigned int n)
 
                if (nlimbs >= x->nlimbs) {
                        x->nlimbs = 0;
-                       return;
+                       return 0;
                }
 
                for (i = 0; i < x->nlimbs - nlimbs; i++)
@@ -143,7 +150,9 @@ void mpi_rshift(MPI x, MPI a, unsigned int n)
                /* Copy and shift by less than bits in a limb.  */
                xsize = a->nlimbs;
                x->sign = a->sign;
-               RESIZE_IF_NEEDED(x, xsize);
+               err = RESIZE_IF_NEEDED(x, xsize);
+               if (err)
+                       return err;
                x->nlimbs = xsize;
 
                if (xsize) {
@@ -159,5 +168,7 @@ void mpi_rshift(MPI x, MPI a, unsigned int n)
                }
        }
        MPN_NORMALIZE(x->d, x->nlimbs);
+
+       return 0;
 }
 EXPORT_SYMBOL_GPL(mpi_rshift);
index 2ff0ebd..6e5044e 100644 (file)
 #include "mpi-internal.h"
 #include "longlong.h"
 
-void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den);
+int mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den);
 
-void mpi_fdiv_r(MPI rem, MPI dividend, MPI divisor)
+int mpi_fdiv_r(MPI rem, MPI dividend, MPI divisor)
 {
        int divisor_sign = divisor->sign;
        MPI temp_divisor = NULL;
+       int err;
 
        /* We need the original value of the divisor after the remainder has been
         * preliminary calculated.      We have to copy it to temporary space if it's
@@ -27,16 +28,22 @@ void mpi_fdiv_r(MPI rem, MPI dividend, MPI divisor)
         */
        if (rem == divisor) {
                temp_divisor = mpi_copy(divisor);
+               if (!temp_divisor)
+                       return -ENOMEM;
                divisor = temp_divisor;
        }
 
-       mpi_tdiv_r(rem, dividend, divisor);
+       err = mpi_tdiv_r(rem, dividend, divisor);
+       if (err)
+               goto free_temp_divisor;
 
        if (((divisor_sign?1:0) ^ (dividend->sign?1:0)) && rem->nlimbs)
-               mpi_add(rem, rem, divisor);
+               err = mpi_add(rem, rem, divisor);
 
-       if (temp_divisor)
-               mpi_free(temp_divisor);
+free_temp_divisor:
+       mpi_free(temp_divisor);
+
+       return err;
 }
 
 /* If den == quot, den needs temporary storage.
@@ -46,12 +53,12 @@ void mpi_fdiv_r(MPI rem, MPI dividend, MPI divisor)
  *   i.e no extra storage should be allocated.
  */
 
-void mpi_tdiv_r(MPI rem, MPI num, MPI den)
+int mpi_tdiv_r(MPI rem, MPI num, MPI den)
 {
-       mpi_tdiv_qr(NULL, rem, num, den);
+       return mpi_tdiv_qr(NULL, rem, num, den);
 }
 
-void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
+int mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
 {
        mpi_ptr_t np, dp;
        mpi_ptr_t qp, rp;
@@ -64,13 +71,16 @@ void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
        mpi_limb_t q_limb;
        mpi_ptr_t marker[5];
        int markidx = 0;
+       int err;
 
        /* Ensure space is enough for quotient and remainder.
         * We need space for an extra limb in the remainder, because it's
         * up-shifted (normalized) below.
         */
        rsize = nsize + 1;
-       mpi_resize(rem, rsize);
+       err = mpi_resize(rem, rsize);
+       if (err)
+               return err;
 
        qsize = rsize - dsize;    /* qsize cannot be bigger than this.  */
        if (qsize <= 0) {
@@ -86,11 +96,14 @@ void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
                        quot->nlimbs = 0;
                        quot->sign = 0;
                }
-               return;
+               return 0;
        }
 
-       if (quot)
-               mpi_resize(quot, qsize);
+       if (quot) {
+               err = mpi_resize(quot, qsize);
+               if (err)
+                       return err;
+       }
 
        /* Read pointers here, when reallocation is finished.  */
        np = num->d;
@@ -112,10 +125,10 @@ void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
                rsize = rlimb != 0?1:0;
                rem->nlimbs = rsize;
                rem->sign = sign_remainder;
-               return;
+               return 0;
        }
 
-
+       err = -ENOMEM;
        if (quot) {
                qp = quot->d;
                /* Make sure QP and NP point to different objects.  Otherwise the
@@ -123,6 +136,8 @@ void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
                 */
                if (qp == np) { /* Copy NP object to temporary space.  */
                        np = marker[markidx++] = mpi_alloc_limb_space(nsize);
+                       if (!np)
+                               goto out_free_marker;
                        MPN_COPY(np, qp, nsize);
                }
        } else /* Put quotient at top of remainder. */
@@ -143,6 +158,8 @@ void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
                 * the original contents of the denominator.
                 */
                tp = marker[markidx++] = mpi_alloc_limb_space(dsize);
+               if (!tp)
+                       goto out_free_marker;
                mpihelp_lshift(tp, dp, dsize, normalization_steps);
                dp = tp;
 
@@ -164,6 +181,8 @@ void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
                        mpi_ptr_t tp;
 
                        tp = marker[markidx++] = mpi_alloc_limb_space(dsize);
+                       if (!tp)
+                               goto out_free_marker;
                        MPN_COPY(tp, dp, dsize);
                        dp = tp;
                }
@@ -198,8 +217,14 @@ void mpi_tdiv_qr(MPI quot, MPI rem, MPI num, MPI den)
 
        rem->nlimbs = rsize;
        rem->sign       = sign_remainder;
+
+       err = 0;
+
+out_free_marker:
        while (markidx) {
                markidx--;
                mpi_free_limb_space(marker[markidx]);
        }
+
+       return err;
 }
index b6fbb43..8a4f49e 100644 (file)
 typedef mpi_limb_t *mpi_ptr_t; /* pointer to a limb */
 typedef int mpi_size_t;                /* (must be a signed type) */
 
-#define RESIZE_IF_NEEDED(a, b)                 \
-       do {                                    \
-               if ((a)->alloced < (b))         \
-                       mpi_resize((a), (b));   \
-       } while (0)
+static inline int RESIZE_IF_NEEDED(MPI a, unsigned b)
+{
+       if (a->alloced < b)
+               return mpi_resize(a, b);
+       return 0;
+}
 
 /* Copy N limbs from S to D.  */
 #define MPN_COPY(d, s, n) \
index 691bbdc..d5fdaec 100644 (file)
@@ -7,7 +7,7 @@
 
 #include "mpi-internal.h"
 
-void mpi_mod(MPI rem, MPI dividend, MPI divisor)
+int mpi_mod(MPI rem, MPI dividend, MPI divisor)
 {
-       mpi_fdiv_r(rem, dividend, divisor);
+       return mpi_fdiv_r(rem, dividend, divisor);
 }
index 7f4eda8..892a246 100644 (file)
@@ -13,7 +13,7 @@
 
 #include "mpi-internal.h"
 
-void mpi_mul(MPI w, MPI u, MPI v)
+int mpi_mul(MPI w, MPI u, MPI v)
 {
        mpi_size_t usize, vsize, wsize;
        mpi_ptr_t up, vp, wp;
@@ -21,6 +21,7 @@ void mpi_mul(MPI w, MPI u, MPI v)
        int usign, vsign, sign_product;
        int assign_wp = 0;
        mpi_ptr_t tmp_limb = NULL;
+       int err;
 
        if (u->nlimbs < v->nlimbs) {
                /* Swap U and V. */
@@ -46,15 +47,21 @@ void mpi_mul(MPI w, MPI u, MPI v)
        if (w->alloced < wsize) {
                if (wp == up || wp == vp) {
                        wp = mpi_alloc_limb_space(wsize);
+                       if (!wp)
+                               return -ENOMEM;
                        assign_wp = 1;
                } else {
-                       mpi_resize(w, wsize);
+                       err = mpi_resize(w, wsize);
+                       if (err)
+                               return err;
                        wp = w->d;
                }
        } else { /* Make U and V not overlap with W.    */
                if (wp == up) {
                        /* W and U are identical.  Allocate temporary space for U. */
                        up = tmp_limb = mpi_alloc_limb_space(usize);
+                       if (!up)
+                               return -ENOMEM;
                        /* Is V identical too?  Keep it identical with U.  */
                        if (wp == vp)
                                vp = up;
@@ -63,6 +70,8 @@ void mpi_mul(MPI w, MPI u, MPI v)
                } else if (wp == vp) {
                        /* W and V are identical.  Allocate temporary space for V. */
                        vp = tmp_limb = mpi_alloc_limb_space(vsize);
+                       if (!vp)
+                               return -ENOMEM;
                        /* Copy to the temporary space.  */
                        MPN_COPY(vp, wp, vsize);
                }
@@ -71,7 +80,12 @@ void mpi_mul(MPI w, MPI u, MPI v)
        if (!vsize)
                wsize = 0;
        else {
-               mpihelp_mul(wp, up, usize, vp, vsize, &cy);
+               err = mpihelp_mul(wp, up, usize, vp, vsize, &cy);
+               if (err) {
+                       if (assign_wp)
+                               mpi_free_limb_space(wp);
+                       goto free_tmp_limb;
+               }
                wsize -= cy ? 0:1;
        }
 
@@ -79,14 +93,17 @@ void mpi_mul(MPI w, MPI u, MPI v)
                mpi_assign_limb_space(w, wp, wsize);
        w->nlimbs = wsize;
        w->sign = sign_product;
+
+free_tmp_limb:
        if (tmp_limb)
                mpi_free_limb_space(tmp_limb);
+       return err;
 }
 EXPORT_SYMBOL_GPL(mpi_mul);
 
-void mpi_mulm(MPI w, MPI u, MPI v, MPI m)
+int mpi_mulm(MPI w, MPI u, MPI v, MPI m)
 {
-       mpi_mul(w, u, v);
-       mpi_tdiv_r(w, w, m);
+       return mpi_mul(w, u, v) ?:
+              mpi_tdiv_r(w, w, m);
 }
 EXPORT_SYMBOL_GPL(mpi_mulm);
index d57fd8a..979ece5 100644 (file)
@@ -133,6 +133,8 @@ MPI mpi_copy(MPI a)
 
        if (a) {
                b = mpi_alloc(a->nlimbs);
+               if (!b)
+                       return NULL;
                b->nlimbs = a->nlimbs;
                b->sign = a->sign;
                b->flags = a->flags;