Merge branch 'for-next' into for-linus
[linux-2.6-microblaze.git] / mm / maccess.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Access kernel or user memory without faulting.
4  */
5 #include <linux/export.h>
6 #include <linux/mm.h>
7 #include <linux/uaccess.h>
8
9 bool __weak probe_kernel_read_allowed(const void *unsafe_src, size_t size)
10 {
11         return true;
12 }
13
14 #ifdef HAVE_GET_KERNEL_NOFAULT
15
16 #define probe_kernel_read_loop(dst, src, len, type, err_label)          \
17         while (len >= sizeof(type)) {                                   \
18                 __get_kernel_nofault(dst, src, type, err_label);                \
19                 dst += sizeof(type);                                    \
20                 src += sizeof(type);                                    \
21                 len -= sizeof(type);                                    \
22         }
23
24 long probe_kernel_read(void *dst, const void *src, size_t size)
25 {
26         if (!probe_kernel_read_allowed(src, size))
27                 return -ERANGE;
28
29         pagefault_disable();
30         probe_kernel_read_loop(dst, src, size, u64, Efault);
31         probe_kernel_read_loop(dst, src, size, u32, Efault);
32         probe_kernel_read_loop(dst, src, size, u16, Efault);
33         probe_kernel_read_loop(dst, src, size, u8, Efault);
34         pagefault_enable();
35         return 0;
36 Efault:
37         pagefault_enable();
38         return -EFAULT;
39 }
40 EXPORT_SYMBOL_GPL(probe_kernel_read);
41
42 #define probe_kernel_write_loop(dst, src, len, type, err_label)         \
43         while (len >= sizeof(type)) {                                   \
44                 __put_kernel_nofault(dst, src, type, err_label);                \
45                 dst += sizeof(type);                                    \
46                 src += sizeof(type);                                    \
47                 len -= sizeof(type);                                    \
48         }
49
50 long probe_kernel_write(void *dst, const void *src, size_t size)
51 {
52         pagefault_disable();
53         probe_kernel_write_loop(dst, src, size, u64, Efault);
54         probe_kernel_write_loop(dst, src, size, u32, Efault);
55         probe_kernel_write_loop(dst, src, size, u16, Efault);
56         probe_kernel_write_loop(dst, src, size, u8, Efault);
57         pagefault_enable();
58         return 0;
59 Efault:
60         pagefault_enable();
61         return -EFAULT;
62 }
63
64 long strncpy_from_kernel_nofault(char *dst, const void *unsafe_addr, long count)
65 {
66         const void *src = unsafe_addr;
67
68         if (unlikely(count <= 0))
69                 return 0;
70         if (!probe_kernel_read_allowed(unsafe_addr, count))
71                 return -ERANGE;
72
73         pagefault_disable();
74         do {
75                 __get_kernel_nofault(dst, src, u8, Efault);
76                 dst++;
77                 src++;
78         } while (dst[-1] && src - unsafe_addr < count);
79         pagefault_enable();
80
81         dst[-1] = '\0';
82         return src - unsafe_addr;
83 Efault:
84         pagefault_enable();
85         dst[-1] = '\0';
86         return -EFAULT;
87 }
88 #else /* HAVE_GET_KERNEL_NOFAULT */
89 /**
90  * probe_kernel_read(): safely attempt to read from kernel-space
91  * @dst: pointer to the buffer that shall take the data
92  * @src: address to read from
93  * @size: size of the data chunk
94  *
95  * Safely read from kernel address @src to the buffer at @dst.  If a kernel
96  * fault happens, handle that and return -EFAULT.  If @src is not a valid kernel
97  * address, return -ERANGE.
98  *
99  * We ensure that the copy_from_user is executed in atomic context so that
100  * do_page_fault() doesn't attempt to take mmap_lock.  This makes
101  * probe_kernel_read() suitable for use within regions where the caller
102  * already holds mmap_lock, or other locks which nest inside mmap_lock.
103  */
104 long probe_kernel_read(void *dst, const void *src, size_t size)
105 {
106         long ret;
107         mm_segment_t old_fs = get_fs();
108
109         if (!probe_kernel_read_allowed(src, size))
110                 return -ERANGE;
111
112         set_fs(KERNEL_DS);
113         pagefault_disable();
114         ret = __copy_from_user_inatomic(dst, (__force const void __user *)src,
115                         size);
116         pagefault_enable();
117         set_fs(old_fs);
118
119         if (ret)
120                 return -EFAULT;
121         return 0;
122 }
123 EXPORT_SYMBOL_GPL(probe_kernel_read);
124
125 /**
126  * probe_kernel_write(): safely attempt to write to a location
127  * @dst: address to write to
128  * @src: pointer to the data that shall be written
129  * @size: size of the data chunk
130  *
131  * Safely write to address @dst from the buffer at @src.  If a kernel fault
132  * happens, handle that and return -EFAULT.
133  */
134 long probe_kernel_write(void *dst, const void *src, size_t size)
135 {
136         long ret;
137         mm_segment_t old_fs = get_fs();
138
139         set_fs(KERNEL_DS);
140         pagefault_disable();
141         ret = __copy_to_user_inatomic((__force void __user *)dst, src, size);
142         pagefault_enable();
143         set_fs(old_fs);
144
145         if (ret)
146                 return -EFAULT;
147         return 0;
148 }
149
150 /**
151  * strncpy_from_kernel_nofault: - Copy a NUL terminated string from unsafe
152  *                               address.
153  * @dst:   Destination address, in kernel space.  This buffer must be at
154  *         least @count bytes long.
155  * @unsafe_addr: Unsafe address.
156  * @count: Maximum number of bytes to copy, including the trailing NUL.
157  *
158  * Copies a NUL-terminated string from unsafe address to kernel buffer.
159  *
160  * On success, returns the length of the string INCLUDING the trailing NUL.
161  *
162  * If access fails, returns -EFAULT (some data may have been copied and the
163  * trailing NUL added).  If @unsafe_addr is not a valid kernel address, return
164  * -ERANGE.
165  *
166  * If @count is smaller than the length of the string, copies @count-1 bytes,
167  * sets the last byte of @dst buffer to NUL and returns @count.
168  */
169 long strncpy_from_kernel_nofault(char *dst, const void *unsafe_addr, long count)
170 {
171         mm_segment_t old_fs = get_fs();
172         const void *src = unsafe_addr;
173         long ret;
174
175         if (unlikely(count <= 0))
176                 return 0;
177         if (!probe_kernel_read_allowed(unsafe_addr, count))
178                 return -ERANGE;
179
180         set_fs(KERNEL_DS);
181         pagefault_disable();
182
183         do {
184                 ret = __get_user(*dst++, (const char __user __force *)src++);
185         } while (dst[-1] && ret == 0 && src - unsafe_addr < count);
186
187         dst[-1] = '\0';
188         pagefault_enable();
189         set_fs(old_fs);
190
191         return ret ? -EFAULT : src - unsafe_addr;
192 }
193 #endif /* HAVE_GET_KERNEL_NOFAULT */
194
195 /**
196  * probe_user_read(): safely attempt to read from a user-space location
197  * @dst: pointer to the buffer that shall take the data
198  * @src: address to read from. This must be a user address.
199  * @size: size of the data chunk
200  *
201  * Safely read from user address @src to the buffer at @dst. If a kernel fault
202  * happens, handle that and return -EFAULT.
203  */
204 long probe_user_read(void *dst, const void __user *src, size_t size)
205 {
206         long ret = -EFAULT;
207         mm_segment_t old_fs = get_fs();
208
209         set_fs(USER_DS);
210         if (access_ok(src, size)) {
211                 pagefault_disable();
212                 ret = __copy_from_user_inatomic(dst, src, size);
213                 pagefault_enable();
214         }
215         set_fs(old_fs);
216
217         if (ret)
218                 return -EFAULT;
219         return 0;
220 }
221 EXPORT_SYMBOL_GPL(probe_user_read);
222
223 /**
224  * probe_user_write(): safely attempt to write to a user-space location
225  * @dst: address to write to
226  * @src: pointer to the data that shall be written
227  * @size: size of the data chunk
228  *
229  * Safely write to address @dst from the buffer at @src.  If a kernel fault
230  * happens, handle that and return -EFAULT.
231  */
232 long probe_user_write(void __user *dst, const void *src, size_t size)
233 {
234         long ret = -EFAULT;
235         mm_segment_t old_fs = get_fs();
236
237         set_fs(USER_DS);
238         if (access_ok(dst, size)) {
239                 pagefault_disable();
240                 ret = __copy_to_user_inatomic(dst, src, size);
241                 pagefault_enable();
242         }
243         set_fs(old_fs);
244
245         if (ret)
246                 return -EFAULT;
247         return 0;
248 }
249 EXPORT_SYMBOL_GPL(probe_user_write);
250
251 /**
252  * strncpy_from_user_nofault: - Copy a NUL terminated string from unsafe user
253  *                              address.
254  * @dst:   Destination address, in kernel space.  This buffer must be at
255  *         least @count bytes long.
256  * @unsafe_addr: Unsafe user address.
257  * @count: Maximum number of bytes to copy, including the trailing NUL.
258  *
259  * Copies a NUL-terminated string from unsafe user address to kernel buffer.
260  *
261  * On success, returns the length of the string INCLUDING the trailing NUL.
262  *
263  * If access fails, returns -EFAULT (some data may have been copied
264  * and the trailing NUL added).
265  *
266  * If @count is smaller than the length of the string, copies @count-1 bytes,
267  * sets the last byte of @dst buffer to NUL and returns @count.
268  */
269 long strncpy_from_user_nofault(char *dst, const void __user *unsafe_addr,
270                               long count)
271 {
272         mm_segment_t old_fs = get_fs();
273         long ret;
274
275         if (unlikely(count <= 0))
276                 return 0;
277
278         set_fs(USER_DS);
279         pagefault_disable();
280         ret = strncpy_from_user(dst, unsafe_addr, count);
281         pagefault_enable();
282         set_fs(old_fs);
283
284         if (ret >= count) {
285                 ret = count;
286                 dst[ret - 1] = '\0';
287         } else if (ret > 0) {
288                 ret++;
289         }
290
291         return ret;
292 }
293
294 /**
295  * strnlen_user_nofault: - Get the size of a user string INCLUDING final NUL.
296  * @unsafe_addr: The string to measure.
297  * @count: Maximum count (including NUL)
298  *
299  * Get the size of a NUL-terminated string in user space without pagefault.
300  *
301  * Returns the size of the string INCLUDING the terminating NUL.
302  *
303  * If the string is too long, returns a number larger than @count. User
304  * has to check the return value against "> count".
305  * On exception (or invalid count), returns 0.
306  *
307  * Unlike strnlen_user, this can be used from IRQ handler etc. because
308  * it disables pagefaults.
309  */
310 long strnlen_user_nofault(const void __user *unsafe_addr, long count)
311 {
312         mm_segment_t old_fs = get_fs();
313         int ret;
314
315         set_fs(USER_DS);
316         pagefault_disable();
317         ret = strnlen_user(unsafe_addr, count);
318         pagefault_enable();
319         set_fs(old_fs);
320
321         return ret;
322 }