Merge tag 'block-5.15-2021-09-11' of git://git.kernel.dk/linux-block
[linux-2.6-microblaze.git] / mm / maccess.c
index f98ff91..d3f1a1f 100644 (file)
@@ -24,13 +24,21 @@ bool __weak copy_from_kernel_nofault_allowed(const void *unsafe_src,
 
 long copy_from_kernel_nofault(void *dst, const void *src, size_t size)
 {
+       unsigned long align = 0;
+
+       if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))
+               align = (unsigned long)dst | (unsigned long)src;
+
        if (!copy_from_kernel_nofault_allowed(src, size))
                return -ERANGE;
 
        pagefault_disable();
-       copy_from_kernel_nofault_loop(dst, src, size, u64, Efault);
-       copy_from_kernel_nofault_loop(dst, src, size, u32, Efault);
-       copy_from_kernel_nofault_loop(dst, src, size, u16, Efault);
+       if (!(align & 7))
+               copy_from_kernel_nofault_loop(dst, src, size, u64, Efault);
+       if (!(align & 3))
+               copy_from_kernel_nofault_loop(dst, src, size, u32, Efault);
+       if (!(align & 1))
+               copy_from_kernel_nofault_loop(dst, src, size, u16, Efault);
        copy_from_kernel_nofault_loop(dst, src, size, u8, Efault);
        pagefault_enable();
        return 0;
@@ -50,10 +58,18 @@ EXPORT_SYMBOL_GPL(copy_from_kernel_nofault);
 
 long copy_to_kernel_nofault(void *dst, const void *src, size_t size)
 {
+       unsigned long align = 0;
+
+       if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))
+               align = (unsigned long)dst | (unsigned long)src;
+
        pagefault_disable();
-       copy_to_kernel_nofault_loop(dst, src, size, u64, Efault);
-       copy_to_kernel_nofault_loop(dst, src, size, u32, Efault);
-       copy_to_kernel_nofault_loop(dst, src, size, u16, Efault);
+       if (!(align & 7))
+               copy_to_kernel_nofault_loop(dst, src, size, u64, Efault);
+       if (!(align & 3))
+               copy_to_kernel_nofault_loop(dst, src, size, u32, Efault);
+       if (!(align & 1))
+               copy_to_kernel_nofault_loop(dst, src, size, u16, Efault);
        copy_to_kernel_nofault_loop(dst, src, size, u8, Efault);
        pagefault_enable();
        return 0;
@@ -205,15 +221,14 @@ long strncpy_from_kernel_nofault(char *dst, const void *unsafe_addr, long count)
 long copy_from_user_nofault(void *dst, const void __user *src, size_t size)
 {
        long ret = -EFAULT;
-       mm_segment_t old_fs = get_fs();
+       mm_segment_t old_fs = force_uaccess_begin();
 
-       set_fs(USER_DS);
        if (access_ok(src, size)) {
                pagefault_disable();
                ret = __copy_from_user_inatomic(dst, src, size);
                pagefault_enable();
        }
-       set_fs(old_fs);
+       force_uaccess_end(old_fs);
 
        if (ret)
                return -EFAULT;
@@ -233,15 +248,14 @@ EXPORT_SYMBOL_GPL(copy_from_user_nofault);
 long copy_to_user_nofault(void __user *dst, const void *src, size_t size)
 {
        long ret = -EFAULT;
-       mm_segment_t old_fs = get_fs();
+       mm_segment_t old_fs = force_uaccess_begin();
 
-       set_fs(USER_DS);
        if (access_ok(dst, size)) {
                pagefault_disable();
                ret = __copy_to_user_inatomic(dst, src, size);
                pagefault_enable();
        }
-       set_fs(old_fs);
+       force_uaccess_end(old_fs);
 
        if (ret)
                return -EFAULT;
@@ -270,17 +284,17 @@ EXPORT_SYMBOL_GPL(copy_to_user_nofault);
 long strncpy_from_user_nofault(char *dst, const void __user *unsafe_addr,
                              long count)
 {
-       mm_segment_t old_fs = get_fs();
+       mm_segment_t old_fs;
        long ret;
 
        if (unlikely(count <= 0))
                return 0;
 
-       set_fs(USER_DS);
+       old_fs = force_uaccess_begin();
        pagefault_disable();
        ret = strncpy_from_user(dst, unsafe_addr, count);
        pagefault_enable();
-       set_fs(old_fs);
+       force_uaccess_end(old_fs);
 
        if (ret >= count) {
                ret = count;
@@ -310,14 +324,14 @@ long strncpy_from_user_nofault(char *dst, const void __user *unsafe_addr,
  */
 long strnlen_user_nofault(const void __user *unsafe_addr, long count)
 {
-       mm_segment_t old_fs = get_fs();
+       mm_segment_t old_fs;
        int ret;
 
-       set_fs(USER_DS);
+       old_fs = force_uaccess_begin();
        pagefault_disable();
        ret = strnlen_user(unsafe_addr, count);
        pagefault_enable();
-       set_fs(old_fs);
+       force_uaccess_end(old_fs);
 
        return ret;
 }