Merge tag 'mips-fixes_5.16_3' of git://git.kernel.org/pub/scm/linux/kernel/git/mips...
[linux-2.6-microblaze.git] / fs / ksmbd / connection.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2016 Namjae Jeon <namjae.jeon@protocolfreedom.org>
4  *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
5  */
6
7 #include <linux/mutex.h>
8 #include <linux/freezer.h>
9 #include <linux/module.h>
10
11 #include "server.h"
12 #include "smb_common.h"
13 #include "mgmt/ksmbd_ida.h"
14 #include "connection.h"
15 #include "transport_tcp.h"
16 #include "transport_rdma.h"
17
18 static DEFINE_MUTEX(init_lock);
19
20 static struct ksmbd_conn_ops default_conn_ops;
21
22 LIST_HEAD(conn_list);
23 DEFINE_RWLOCK(conn_list_lock);
24
25 /**
26  * ksmbd_conn_free() - free resources of the connection instance
27  *
28  * @conn:       connection instance to be cleand up
29  *
30  * During the thread termination, the corresponding conn instance
31  * resources(sock/memory) are released and finally the conn object is freed.
32  */
33 void ksmbd_conn_free(struct ksmbd_conn *conn)
34 {
35         write_lock(&conn_list_lock);
36         list_del(&conn->conns_list);
37         write_unlock(&conn_list_lock);
38
39         kvfree(conn->request_buf);
40         kfree(conn->preauth_info);
41         kfree(conn);
42 }
43
44 /**
45  * ksmbd_conn_alloc() - initialize a new connection instance
46  *
47  * Return:      ksmbd_conn struct on success, otherwise NULL
48  */
49 struct ksmbd_conn *ksmbd_conn_alloc(void)
50 {
51         struct ksmbd_conn *conn;
52
53         conn = kzalloc(sizeof(struct ksmbd_conn), GFP_KERNEL);
54         if (!conn)
55                 return NULL;
56
57         conn->need_neg = true;
58         conn->status = KSMBD_SESS_NEW;
59         conn->local_nls = load_nls("utf8");
60         if (!conn->local_nls)
61                 conn->local_nls = load_nls_default();
62         atomic_set(&conn->req_running, 0);
63         atomic_set(&conn->r_count, 0);
64         conn->total_credits = 1;
65
66         init_waitqueue_head(&conn->req_running_q);
67         INIT_LIST_HEAD(&conn->conns_list);
68         INIT_LIST_HEAD(&conn->sessions);
69         INIT_LIST_HEAD(&conn->requests);
70         INIT_LIST_HEAD(&conn->async_requests);
71         spin_lock_init(&conn->request_lock);
72         spin_lock_init(&conn->credits_lock);
73         ida_init(&conn->async_ida);
74
75         spin_lock_init(&conn->llist_lock);
76         INIT_LIST_HEAD(&conn->lock_list);
77
78         write_lock(&conn_list_lock);
79         list_add(&conn->conns_list, &conn_list);
80         write_unlock(&conn_list_lock);
81         return conn;
82 }
83
84 bool ksmbd_conn_lookup_dialect(struct ksmbd_conn *c)
85 {
86         struct ksmbd_conn *t;
87         bool ret = false;
88
89         read_lock(&conn_list_lock);
90         list_for_each_entry(t, &conn_list, conns_list) {
91                 if (memcmp(t->ClientGUID, c->ClientGUID, SMB2_CLIENT_GUID_SIZE))
92                         continue;
93
94                 ret = true;
95                 break;
96         }
97         read_unlock(&conn_list_lock);
98         return ret;
99 }
100
101 void ksmbd_conn_enqueue_request(struct ksmbd_work *work)
102 {
103         struct ksmbd_conn *conn = work->conn;
104         struct list_head *requests_queue = NULL;
105
106         if (conn->ops->get_cmd_val(work) != SMB2_CANCEL_HE) {
107                 requests_queue = &conn->requests;
108                 work->syncronous = true;
109         }
110
111         if (requests_queue) {
112                 atomic_inc(&conn->req_running);
113                 spin_lock(&conn->request_lock);
114                 list_add_tail(&work->request_entry, requests_queue);
115                 spin_unlock(&conn->request_lock);
116         }
117 }
118
119 int ksmbd_conn_try_dequeue_request(struct ksmbd_work *work)
120 {
121         struct ksmbd_conn *conn = work->conn;
122         int ret = 1;
123
124         if (list_empty(&work->request_entry) &&
125             list_empty(&work->async_request_entry))
126                 return 0;
127
128         if (!work->multiRsp)
129                 atomic_dec(&conn->req_running);
130         spin_lock(&conn->request_lock);
131         if (!work->multiRsp) {
132                 list_del_init(&work->request_entry);
133                 if (work->syncronous == false)
134                         list_del_init(&work->async_request_entry);
135                 ret = 0;
136         }
137         spin_unlock(&conn->request_lock);
138
139         wake_up_all(&conn->req_running_q);
140         return ret;
141 }
142
143 static void ksmbd_conn_lock(struct ksmbd_conn *conn)
144 {
145         mutex_lock(&conn->srv_mutex);
146 }
147
148 static void ksmbd_conn_unlock(struct ksmbd_conn *conn)
149 {
150         mutex_unlock(&conn->srv_mutex);
151 }
152
153 void ksmbd_conn_wait_idle(struct ksmbd_conn *conn)
154 {
155         wait_event(conn->req_running_q, atomic_read(&conn->req_running) < 2);
156 }
157
158 int ksmbd_conn_write(struct ksmbd_work *work)
159 {
160         struct ksmbd_conn *conn = work->conn;
161         size_t len = 0;
162         int sent;
163         struct kvec iov[3];
164         int iov_idx = 0;
165
166         ksmbd_conn_try_dequeue_request(work);
167         if (!work->response_buf) {
168                 pr_err("NULL response header\n");
169                 return -EINVAL;
170         }
171
172         if (work->tr_buf) {
173                 iov[iov_idx] = (struct kvec) { work->tr_buf,
174                                 sizeof(struct smb2_transform_hdr) + 4 };
175                 len += iov[iov_idx++].iov_len;
176         }
177
178         if (work->aux_payload_sz) {
179                 iov[iov_idx] = (struct kvec) { work->response_buf, work->resp_hdr_sz };
180                 len += iov[iov_idx++].iov_len;
181                 iov[iov_idx] = (struct kvec) { work->aux_payload_buf, work->aux_payload_sz };
182                 len += iov[iov_idx++].iov_len;
183         } else {
184                 if (work->tr_buf)
185                         iov[iov_idx].iov_len = work->resp_hdr_sz;
186                 else
187                         iov[iov_idx].iov_len = get_rfc1002_len(work->response_buf) + 4;
188                 iov[iov_idx].iov_base = work->response_buf;
189                 len += iov[iov_idx++].iov_len;
190         }
191
192         ksmbd_conn_lock(conn);
193         sent = conn->transport->ops->writev(conn->transport, &iov[0],
194                                         iov_idx, len,
195                                         work->need_invalidate_rkey,
196                                         work->remote_key);
197         ksmbd_conn_unlock(conn);
198
199         if (sent < 0) {
200                 pr_err("Failed to send message: %d\n", sent);
201                 return sent;
202         }
203
204         return 0;
205 }
206
207 int ksmbd_conn_rdma_read(struct ksmbd_conn *conn, void *buf,
208                          unsigned int buflen, u32 remote_key, u64 remote_offset,
209                          u32 remote_len)
210 {
211         int ret = -EINVAL;
212
213         if (conn->transport->ops->rdma_read)
214                 ret = conn->transport->ops->rdma_read(conn->transport,
215                                                       buf, buflen,
216                                                       remote_key, remote_offset,
217                                                       remote_len);
218         return ret;
219 }
220
221 int ksmbd_conn_rdma_write(struct ksmbd_conn *conn, void *buf,
222                           unsigned int buflen, u32 remote_key,
223                           u64 remote_offset, u32 remote_len)
224 {
225         int ret = -EINVAL;
226
227         if (conn->transport->ops->rdma_write)
228                 ret = conn->transport->ops->rdma_write(conn->transport,
229                                                        buf, buflen,
230                                                        remote_key, remote_offset,
231                                                        remote_len);
232         return ret;
233 }
234
235 bool ksmbd_conn_alive(struct ksmbd_conn *conn)
236 {
237         if (!ksmbd_server_running())
238                 return false;
239
240         if (conn->status == KSMBD_SESS_EXITING)
241                 return false;
242
243         if (kthread_should_stop())
244                 return false;
245
246         if (atomic_read(&conn->stats.open_files_count) > 0)
247                 return true;
248
249         /*
250          * Stop current session if the time that get last request from client
251          * is bigger than deadtime user configured and opening file count is
252          * zero.
253          */
254         if (server_conf.deadtime > 0 &&
255             time_after(jiffies, conn->last_active + server_conf.deadtime)) {
256                 ksmbd_debug(CONN, "No response from client in %lu minutes\n",
257                             server_conf.deadtime / SMB_ECHO_INTERVAL);
258                 return false;
259         }
260         return true;
261 }
262
263 /**
264  * ksmbd_conn_handler_loop() - session thread to listen on new smb requests
265  * @p:          connection instance
266  *
267  * One thread each per connection
268  *
269  * Return:      0 on success
270  */
271 int ksmbd_conn_handler_loop(void *p)
272 {
273         struct ksmbd_conn *conn = (struct ksmbd_conn *)p;
274         struct ksmbd_transport *t = conn->transport;
275         unsigned int pdu_size;
276         char hdr_buf[4] = {0,};
277         int size;
278
279         mutex_init(&conn->srv_mutex);
280         __module_get(THIS_MODULE);
281
282         if (t->ops->prepare && t->ops->prepare(t))
283                 goto out;
284
285         conn->last_active = jiffies;
286         while (ksmbd_conn_alive(conn)) {
287                 if (try_to_freeze())
288                         continue;
289
290                 kvfree(conn->request_buf);
291                 conn->request_buf = NULL;
292
293                 size = t->ops->read(t, hdr_buf, sizeof(hdr_buf));
294                 if (size != sizeof(hdr_buf))
295                         break;
296
297                 pdu_size = get_rfc1002_len(hdr_buf);
298                 ksmbd_debug(CONN, "RFC1002 header %u bytes\n", pdu_size);
299
300                 /*
301                  * Check if pdu size is valid (min : smb header size,
302                  * max : 0x00FFFFFF).
303                  */
304                 if (pdu_size < __SMB2_HEADER_STRUCTURE_SIZE ||
305                     pdu_size > MAX_STREAM_PROT_LEN) {
306                         continue;
307                 }
308
309                 /* 4 for rfc1002 length field */
310                 size = pdu_size + 4;
311                 conn->request_buf = kvmalloc(size, GFP_KERNEL);
312                 if (!conn->request_buf)
313                         continue;
314
315                 memcpy(conn->request_buf, hdr_buf, sizeof(hdr_buf));
316                 if (!ksmbd_smb_request(conn))
317                         break;
318
319                 /*
320                  * We already read 4 bytes to find out PDU size, now
321                  * read in PDU
322                  */
323                 size = t->ops->read(t, conn->request_buf + 4, pdu_size);
324                 if (size < 0) {
325                         pr_err("sock_read failed: %d\n", size);
326                         break;
327                 }
328
329                 if (size != pdu_size) {
330                         pr_err("PDU error. Read: %d, Expected: %d\n",
331                                size, pdu_size);
332                         continue;
333                 }
334
335                 if (!default_conn_ops.process_fn) {
336                         pr_err("No connection request callback\n");
337                         break;
338                 }
339
340                 if (default_conn_ops.process_fn(conn)) {
341                         pr_err("Cannot handle request\n");
342                         break;
343                 }
344         }
345
346 out:
347         /* Wait till all reference dropped to the Server object*/
348         while (atomic_read(&conn->r_count) > 0)
349                 schedule_timeout(HZ);
350
351         unload_nls(conn->local_nls);
352         if (default_conn_ops.terminate_fn)
353                 default_conn_ops.terminate_fn(conn);
354         t->ops->disconnect(t);
355         module_put(THIS_MODULE);
356         return 0;
357 }
358
359 void ksmbd_conn_init_server_callbacks(struct ksmbd_conn_ops *ops)
360 {
361         default_conn_ops.process_fn = ops->process_fn;
362         default_conn_ops.terminate_fn = ops->terminate_fn;
363 }
364
365 int ksmbd_conn_transport_init(void)
366 {
367         int ret;
368
369         mutex_lock(&init_lock);
370         ret = ksmbd_tcp_init();
371         if (ret) {
372                 pr_err("Failed to init TCP subsystem: %d\n", ret);
373                 goto out;
374         }
375
376         ret = ksmbd_rdma_init();
377         if (ret) {
378                 pr_err("Failed to init RDMA subsystem: %d\n", ret);
379                 goto out;
380         }
381 out:
382         mutex_unlock(&init_lock);
383         return ret;
384 }
385
386 static void stop_sessions(void)
387 {
388         struct ksmbd_conn *conn;
389
390 again:
391         read_lock(&conn_list_lock);
392         list_for_each_entry(conn, &conn_list, conns_list) {
393                 struct task_struct *task;
394
395                 task = conn->transport->handler;
396                 if (task)
397                         ksmbd_debug(CONN, "Stop session handler %s/%d\n",
398                                     task->comm, task_pid_nr(task));
399                 conn->status = KSMBD_SESS_EXITING;
400         }
401         read_unlock(&conn_list_lock);
402
403         if (!list_empty(&conn_list)) {
404                 schedule_timeout_interruptible(HZ / 10); /* 100ms */
405                 goto again;
406         }
407 }
408
409 void ksmbd_conn_transport_destroy(void)
410 {
411         mutex_lock(&init_lock);
412         ksmbd_tcp_destroy();
413         ksmbd_rdma_destroy();
414         stop_sessions();
415         mutex_unlock(&init_lock);
416 }