mm: simplify compat numa syscalls
[linux-2.6-microblaze.git] / mm / mempolicy.c
index 5e90b3f..eb95578 100644 (file)
@@ -1362,16 +1362,33 @@ mpol_out:
 /*
  * User space interface with variable sized bitmaps for nodelists.
  */
+static int get_bitmap(unsigned long *mask, const unsigned long __user *nmask,
+                     unsigned long maxnode)
+{
+       unsigned long nlongs = BITS_TO_LONGS(maxnode);
+       int ret;
+
+       if (in_compat_syscall())
+               ret = compat_get_bitmap(mask,
+                                       (const compat_ulong_t __user *)nmask,
+                                       maxnode);
+       else
+               ret = copy_from_user(mask, nmask,
+                                    nlongs * sizeof(unsigned long));
+
+       if (ret)
+               return -EFAULT;
+
+       if (maxnode % BITS_PER_LONG)
+               mask[nlongs - 1] &= (1UL << (maxnode % BITS_PER_LONG)) - 1;
+
+       return 0;
+}
 
 /* Copy a node mask from user space. */
 static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
                     unsigned long maxnode)
 {
-       unsigned long k;
-       unsigned long t;
-       unsigned long nlongs;
-       unsigned long endmask;
-
        --maxnode;
        nodes_clear(*nodes);
        if (maxnode == 0 || !nmask)
@@ -1379,49 +1396,29 @@ static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
        if (maxnode > PAGE_SIZE*BITS_PER_BYTE)
                return -EINVAL;
 
-       nlongs = BITS_TO_LONGS(maxnode);
-       if ((maxnode % BITS_PER_LONG) == 0)
-               endmask = ~0UL;
-       else
-               endmask = (1UL << (maxnode % BITS_PER_LONG)) - 1;
-
        /*
         * When the user specified more nodes than supported just check
-        * if the non supported part is all zero.
-        *
-        * If maxnode have more longs than MAX_NUMNODES, check
-        * the bits in that area first. And then go through to
-        * check the rest bits which equal or bigger than MAX_NUMNODES.
-        * Otherwise, just check bits [MAX_NUMNODES, maxnode).
+        * if the non supported part is all zero, one word at a time,
+        * starting at the end.
         */
-       if (nlongs > BITS_TO_LONGS(MAX_NUMNODES)) {
-               for (k = BITS_TO_LONGS(MAX_NUMNODES); k < nlongs; k++) {
-                       if (get_user(t, nmask + k))
-                               return -EFAULT;
-                       if (k == nlongs - 1) {
-                               if (t & endmask)
-                                       return -EINVAL;
-                       } else if (t)
-                               return -EINVAL;
-               }
-               nlongs = BITS_TO_LONGS(MAX_NUMNODES);
-               endmask = ~0UL;
-       }
+       while (maxnode > MAX_NUMNODES) {
+               unsigned long bits = min_t(unsigned long, maxnode, BITS_PER_LONG);
+               unsigned long t;
 
-       if (maxnode > MAX_NUMNODES && MAX_NUMNODES % BITS_PER_LONG != 0) {
-               unsigned long valid_mask = endmask;
-
-               valid_mask &= ~((1UL << (MAX_NUMNODES % BITS_PER_LONG)) - 1);
-               if (get_user(t, nmask + nlongs - 1))
+               if (get_bitmap(&t, &nmask[maxnode / BITS_PER_LONG], bits))
                        return -EFAULT;
-               if (t & valid_mask)
+
+               if (maxnode - bits >= MAX_NUMNODES) {
+                       maxnode -= bits;
+               } else {
+                       maxnode = MAX_NUMNODES;
+                       t &= ~((1UL << (MAX_NUMNODES % BITS_PER_LONG)) - 1);
+               }
+               if (t)
                        return -EINVAL;
        }
 
-       if (copy_from_user(nodes_addr(*nodes), nmask, nlongs*sizeof(unsigned long)))
-               return -EFAULT;
-       nodes_addr(*nodes)[nlongs-1] &= endmask;
-       return 0;
+       return get_bitmap(nodes_addr(*nodes), nmask, maxnode);
 }
 
 /* Copy a kernel node mask to user space */
@@ -1430,6 +1427,10 @@ static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
 {
        unsigned long copy = ALIGN(maxnode-1, 64) / 8;
        unsigned int nbytes = BITS_TO_LONGS(nr_node_ids) * sizeof(long);
+       bool compat = in_compat_syscall();
+
+       if (compat)
+               nbytes = BITS_TO_COMPAT_LONGS(nr_node_ids) * sizeof(compat_long_t);
 
        if (copy > nbytes) {
                if (copy > PAGE_SIZE)
@@ -1437,7 +1438,13 @@ static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
                if (clear_user((char __user *)mask + nbytes, copy - nbytes))
                        return -EFAULT;
                copy = nbytes;
+               maxnode = nr_node_ids;
        }
+
+       if (compat)
+               return compat_put_bitmap((compat_ulong_t __user *)mask,
+                                        nodes_addr(*nodes), maxnode);
+
        return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
 }
 
@@ -1649,72 +1656,22 @@ COMPAT_SYSCALL_DEFINE5(get_mempolicy, int __user *, policy,
                       compat_ulong_t, maxnode,
                       compat_ulong_t, addr, compat_ulong_t, flags)
 {
-       long err;
-       unsigned long __user *nm = NULL;
-       unsigned long nr_bits, alloc_size;
-       DECLARE_BITMAP(bm, MAX_NUMNODES);
-
-       nr_bits = min_t(unsigned long, maxnode-1, nr_node_ids);
-       alloc_size = ALIGN(nr_bits, BITS_PER_LONG) / 8;
-
-       if (nmask)
-               nm = compat_alloc_user_space(alloc_size);
-
-       err = kernel_get_mempolicy(policy, nm, nr_bits+1, addr, flags);
-
-       if (!err && nmask) {
-               unsigned long copy_size;
-               copy_size = min_t(unsigned long, sizeof(bm), alloc_size);
-               err = copy_from_user(bm, nm, copy_size);
-               /* ensure entire bitmap is zeroed */
-               err |= clear_user(nmask, ALIGN(maxnode-1, 8) / 8);
-               err |= compat_put_bitmap(nmask, bm, nr_bits);
-       }
-
-       return err;
+       return kernel_get_mempolicy(policy, (unsigned long __user *)nmask,
+                                   maxnode, addr, flags);
 }
 
 COMPAT_SYSCALL_DEFINE3(set_mempolicy, int, mode, compat_ulong_t __user *, nmask,
                       compat_ulong_t, maxnode)
 {
-       unsigned long __user *nm = NULL;
-       unsigned long nr_bits, alloc_size;
-       DECLARE_BITMAP(bm, MAX_NUMNODES);
-
-       nr_bits = min_t(unsigned long, maxnode-1, MAX_NUMNODES);
-       alloc_size = ALIGN(nr_bits, BITS_PER_LONG) / 8;
-
-       if (nmask) {
-               if (compat_get_bitmap(bm, nmask, nr_bits))
-                       return -EFAULT;
-               nm = compat_alloc_user_space(alloc_size);
-               if (copy_to_user(nm, bm, alloc_size))
-                       return -EFAULT;
-       }
-
-       return kernel_set_mempolicy(mode, nm, nr_bits+1);
+       return kernel_set_mempolicy(mode, (unsigned long __user *)nmask, maxnode);
 }
 
 COMPAT_SYSCALL_DEFINE6(mbind, compat_ulong_t, start, compat_ulong_t, len,
                       compat_ulong_t, mode, compat_ulong_t __user *, nmask,
                       compat_ulong_t, maxnode, compat_ulong_t, flags)
 {
-       unsigned long __user *nm = NULL;
-       unsigned long nr_bits, alloc_size;
-       nodemask_t bm;
-
-       nr_bits = min_t(unsigned long, maxnode-1, MAX_NUMNODES);
-       alloc_size = ALIGN(nr_bits, BITS_PER_LONG) / 8;
-
-       if (nmask) {
-               if (compat_get_bitmap(nodes_addr(bm), nmask, nr_bits))
-                       return -EFAULT;
-               nm = compat_alloc_user_space(alloc_size);
-               if (copy_to_user(nm, nodes_addr(bm), alloc_size))
-                       return -EFAULT;
-       }
-
-       return kernel_mbind(start, len, mode, nm, nr_bits+1, flags);
+       return kernel_mbind(start, len, mode, (unsigned long __user *)nmask,
+                           maxnode, flags);
 }
 
 COMPAT_SYSCALL_DEFINE4(migrate_pages, compat_pid_t, pid,
@@ -1722,32 +1679,9 @@ COMPAT_SYSCALL_DEFINE4(migrate_pages, compat_pid_t, pid,
                       const compat_ulong_t __user *, old_nodes,
                       const compat_ulong_t __user *, new_nodes)
 {
-       unsigned long __user *old = NULL;
-       unsigned long __user *new = NULL;
-       nodemask_t tmp_mask;
-       unsigned long nr_bits;
-       unsigned long size;
-
-       nr_bits = min_t(unsigned long, maxnode - 1, MAX_NUMNODES);
-       size = ALIGN(nr_bits, BITS_PER_LONG) / 8;
-       if (old_nodes) {
-               if (compat_get_bitmap(nodes_addr(tmp_mask), old_nodes, nr_bits))
-                       return -EFAULT;
-               old = compat_alloc_user_space(new_nodes ? size * 2 : size);
-               if (new_nodes)
-                       new = old + size / sizeof(unsigned long);
-               if (copy_to_user(old, nodes_addr(tmp_mask), size))
-                       return -EFAULT;
-       }
-       if (new_nodes) {
-               if (compat_get_bitmap(nodes_addr(tmp_mask), new_nodes, nr_bits))
-                       return -EFAULT;
-               if (new == NULL)
-                       new = compat_alloc_user_space(size);
-               if (copy_to_user(new, nodes_addr(tmp_mask), size))
-                       return -EFAULT;
-       }
-       return kernel_migrate_pages(pid, nr_bits + 1, old, new);
+       return kernel_migrate_pages(pid, maxnode,
+                                   (const unsigned long __user *)old_nodes,
+                                   (const unsigned long __user *)new_nodes);
 }
 
 #endif /* CONFIG_COMPAT */