Merge branch 'pm-opp'
[linux-2.6-microblaze.git] / fs / ksmbd / connection.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2016 Namjae Jeon <namjae.jeon@protocolfreedom.org>
4  *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
5  */
6
7 #include <linux/mutex.h>
8 #include <linux/freezer.h>
9 #include <linux/module.h>
10
11 #include "server.h"
12 #include "smb_common.h"
13 #include "mgmt/ksmbd_ida.h"
14 #include "connection.h"
15 #include "transport_tcp.h"
16 #include "transport_rdma.h"
17
18 static DEFINE_MUTEX(init_lock);
19
20 static struct ksmbd_conn_ops default_conn_ops;
21
22 LIST_HEAD(conn_list);
23 DEFINE_RWLOCK(conn_list_lock);
24
25 /**
26  * ksmbd_conn_free() - free resources of the connection instance
27  *
28  * @conn:       connection instance to be cleand up
29  *
30  * During the thread termination, the corresponding conn instance
31  * resources(sock/memory) are released and finally the conn object is freed.
32  */
void ksmbd_conn_free(struct ksmbd_conn *conn)
{
	/* Unlink from the global connection list before freeing. */
	write_lock(&conn_list_lock);
	list_del(&conn->conns_list);
	write_unlock(&conn_list_lock);

	/* request_buf is kvmalloc'ed in the handler loop, hence kvfree(). */
	kvfree(conn->request_buf);
	kfree(conn->preauth_info);
	kfree(conn);
}
43
44 /**
45  * ksmbd_conn_alloc() - initialize a new connection instance
46  *
47  * Return:      ksmbd_conn struct on success, otherwise NULL
48  */
struct ksmbd_conn *ksmbd_conn_alloc(void)
{
	struct ksmbd_conn *conn;

	conn = kzalloc(sizeof(struct ksmbd_conn), GFP_KERNEL);
	if (!conn)
		return NULL;

	/* A fresh connection must negotiate a dialect first. */
	conn->need_neg = true;
	conn->status = KSMBD_SESS_NEW;
	/* Prefer UTF-8; fall back to the system default NLS table. */
	conn->local_nls = load_nls("utf8");
	if (!conn->local_nls)
		conn->local_nls = load_nls_default();
	atomic_set(&conn->req_running, 0);
	atomic_set(&conn->r_count, 0);
	init_waitqueue_head(&conn->req_running_q);
	INIT_LIST_HEAD(&conn->conns_list);
	INIT_LIST_HEAD(&conn->sessions);
	INIT_LIST_HEAD(&conn->requests);
	INIT_LIST_HEAD(&conn->async_requests);
	spin_lock_init(&conn->request_lock);
	spin_lock_init(&conn->credits_lock);
	ida_init(&conn->async_ida);

	spin_lock_init(&conn->llist_lock);
	INIT_LIST_HEAD(&conn->lock_list);

	/*
	 * Publish the connection on the global list only after it is
	 * fully initialized.
	 */
	write_lock(&conn_list_lock);
	list_add(&conn->conns_list, &conn_list);
	write_unlock(&conn_list_lock);
	return conn;
}
81
82 bool ksmbd_conn_lookup_dialect(struct ksmbd_conn *c)
83 {
84         struct ksmbd_conn *t;
85         bool ret = false;
86
87         read_lock(&conn_list_lock);
88         list_for_each_entry(t, &conn_list, conns_list) {
89                 if (memcmp(t->ClientGUID, c->ClientGUID, SMB2_CLIENT_GUID_SIZE))
90                         continue;
91
92                 ret = true;
93                 break;
94         }
95         read_unlock(&conn_list_lock);
96         return ret;
97 }
98
99 void ksmbd_conn_enqueue_request(struct ksmbd_work *work)
100 {
101         struct ksmbd_conn *conn = work->conn;
102         struct list_head *requests_queue = NULL;
103
104         if (conn->ops->get_cmd_val(work) != SMB2_CANCEL_HE) {
105                 requests_queue = &conn->requests;
106                 work->syncronous = true;
107         }
108
109         if (requests_queue) {
110                 atomic_inc(&conn->req_running);
111                 spin_lock(&conn->request_lock);
112                 list_add_tail(&work->request_entry, requests_queue);
113                 spin_unlock(&conn->request_lock);
114         }
115 }
116
int ksmbd_conn_try_dequeue_request(struct ksmbd_work *work)
{
	struct ksmbd_conn *conn = work->conn;
	int ret = 1;

	/* Work was never enqueued (e.g. SMB2 CANCEL) - nothing to do. */
	if (list_empty(&work->request_entry) &&
	    list_empty(&work->async_request_entry))
		return 0;

	/*
	 * Multi-response requests stay queued and keep their req_running
	 * slot until the final response has been sent.
	 */
	if (!work->multiRsp)
		atomic_dec(&conn->req_running);
	spin_lock(&conn->request_lock);
	if (!work->multiRsp) {
		list_del_init(&work->request_entry);
		if (work->syncronous == false)
			list_del_init(&work->async_request_entry);
		ret = 0;
	}
	spin_unlock(&conn->request_lock);

	/* Let waiters such as ksmbd_conn_wait_idle() re-check req_running. */
	wake_up_all(&conn->req_running_q);
	return ret;
}
140
/* Serialize response transmission on this connection (see ksmbd_conn_write()). */
static void ksmbd_conn_lock(struct ksmbd_conn *conn)
{
	mutex_lock(&conn->srv_mutex);
}
145
/* Counterpart of ksmbd_conn_lock(). */
static void ksmbd_conn_unlock(struct ksmbd_conn *conn)
{
	mutex_unlock(&conn->srv_mutex);
}
150
void ksmbd_conn_wait_idle(struct ksmbd_conn *conn)
{
	/*
	 * Wait until the calling request is the only one in flight; the
	 * caller itself still holds one req_running slot, hence "< 2".
	 */
	wait_event(conn->req_running_q, atomic_read(&conn->req_running) < 2);
}
155
/*
 * Send the response in @work to the client: build up to three kvecs
 * (optional transform header, SMB header, optional auxiliary payload)
 * and hand them to the transport under srv_mutex.
 *
 * Return: 0 on success, negative error code from the transport otherwise.
 */
int ksmbd_conn_write(struct ksmbd_work *work)
{
	struct ksmbd_conn *conn = work->conn;
	struct smb_hdr *rsp_hdr = work->response_buf;
	size_t len = 0;
	int sent;
	struct kvec iov[3];
	int iov_idx = 0;

	ksmbd_conn_try_dequeue_request(work);
	if (!rsp_hdr) {
		pr_err("NULL response header\n");
		return -EINVAL;
	}

	/* Encrypted response: the transform header goes out first. */
	if (work->tr_buf) {
		iov[iov_idx] = (struct kvec) { work->tr_buf,
				sizeof(struct smb2_transform_hdr) };
		len += iov[iov_idx++].iov_len;
	}

	if (work->aux_payload_sz) {
		/* Header and auxiliary payload live in separate buffers. */
		iov[iov_idx] = (struct kvec) { rsp_hdr, work->resp_hdr_sz };
		len += iov[iov_idx++].iov_len;
		iov[iov_idx] = (struct kvec) { work->aux_payload_buf, work->aux_payload_sz };
		len += iov[iov_idx++].iov_len;
	} else {
		/*
		 * Single buffer: when encrypted, only the header portion is
		 * sent from rsp_hdr; otherwise the RFC1002 length field
		 * (+4 for the field itself) gives the full on-wire length.
		 */
		if (work->tr_buf)
			iov[iov_idx].iov_len = work->resp_hdr_sz;
		else
			iov[iov_idx].iov_len = get_rfc1002_len(rsp_hdr) + 4;
		iov[iov_idx].iov_base = rsp_hdr;
		len += iov[iov_idx++].iov_len;
	}

	/* srv_mutex keeps concurrent responses from interleaving on the wire. */
	ksmbd_conn_lock(conn);
	sent = conn->transport->ops->writev(conn->transport, &iov[0],
					iov_idx, len,
					work->need_invalidate_rkey,
					work->remote_key);
	ksmbd_conn_unlock(conn);

	if (sent < 0) {
		pr_err("Failed to send message: %d\n", sent);
		return sent;
	}

	return 0;
}
205
206 int ksmbd_conn_rdma_read(struct ksmbd_conn *conn, void *buf,
207                          unsigned int buflen, u32 remote_key, u64 remote_offset,
208                          u32 remote_len)
209 {
210         int ret = -EINVAL;
211
212         if (conn->transport->ops->rdma_read)
213                 ret = conn->transport->ops->rdma_read(conn->transport,
214                                                       buf, buflen,
215                                                       remote_key, remote_offset,
216                                                       remote_len);
217         return ret;
218 }
219
220 int ksmbd_conn_rdma_write(struct ksmbd_conn *conn, void *buf,
221                           unsigned int buflen, u32 remote_key,
222                           u64 remote_offset, u32 remote_len)
223 {
224         int ret = -EINVAL;
225
226         if (conn->transport->ops->rdma_write)
227                 ret = conn->transport->ops->rdma_write(conn->transport,
228                                                        buf, buflen,
229                                                        remote_key, remote_offset,
230                                                        remote_len);
231         return ret;
232 }
233
234 bool ksmbd_conn_alive(struct ksmbd_conn *conn)
235 {
236         if (!ksmbd_server_running())
237                 return false;
238
239         if (conn->status == KSMBD_SESS_EXITING)
240                 return false;
241
242         if (kthread_should_stop())
243                 return false;
244
245         if (atomic_read(&conn->stats.open_files_count) > 0)
246                 return true;
247
248         /*
249          * Stop current session if the time that get last request from client
250          * is bigger than deadtime user configured and opening file count is
251          * zero.
252          */
253         if (server_conf.deadtime > 0 &&
254             time_after(jiffies, conn->last_active + server_conf.deadtime)) {
255                 ksmbd_debug(CONN, "No response from client in %lu minutes\n",
256                             server_conf.deadtime / SMB_ECHO_INTERVAL);
257                 return false;
258         }
259         return true;
260 }
261
262 /**
263  * ksmbd_conn_handler_loop() - session thread to listen on new smb requests
264  * @p:          connection instance
265  *
266  * One thread each per connection
267  *
268  * Return:      0 on success
269  */
270 int ksmbd_conn_handler_loop(void *p)
271 {
272         struct ksmbd_conn *conn = (struct ksmbd_conn *)p;
273         struct ksmbd_transport *t = conn->transport;
274         unsigned int pdu_size;
275         char hdr_buf[4] = {0,};
276         int size;
277
278         mutex_init(&conn->srv_mutex);
279         __module_get(THIS_MODULE);
280
281         if (t->ops->prepare && t->ops->prepare(t))
282                 goto out;
283
284         conn->last_active = jiffies;
285         while (ksmbd_conn_alive(conn)) {
286                 if (try_to_freeze())
287                         continue;
288
289                 kvfree(conn->request_buf);
290                 conn->request_buf = NULL;
291
292                 size = t->ops->read(t, hdr_buf, sizeof(hdr_buf));
293                 if (size != sizeof(hdr_buf))
294                         break;
295
296                 pdu_size = get_rfc1002_len(hdr_buf);
297                 ksmbd_debug(CONN, "RFC1002 header %u bytes\n", pdu_size);
298
299                 /* make sure we have enough to get to SMB header end */
300                 if (!ksmbd_pdu_size_has_room(pdu_size)) {
301                         ksmbd_debug(CONN, "SMB request too short (%u bytes)\n",
302                                     pdu_size);
303                         continue;
304                 }
305
306                 /* 4 for rfc1002 length field */
307                 size = pdu_size + 4;
308                 conn->request_buf = kvmalloc(size, GFP_KERNEL);
309                 if (!conn->request_buf)
310                         continue;
311
312                 memcpy(conn->request_buf, hdr_buf, sizeof(hdr_buf));
313                 if (!ksmbd_smb_request(conn))
314                         break;
315
316                 /*
317                  * We already read 4 bytes to find out PDU size, now
318                  * read in PDU
319                  */
320                 size = t->ops->read(t, conn->request_buf + 4, pdu_size);
321                 if (size < 0) {
322                         pr_err("sock_read failed: %d\n", size);
323                         break;
324                 }
325
326                 if (size != pdu_size) {
327                         pr_err("PDU error. Read: %d, Expected: %d\n",
328                                size, pdu_size);
329                         continue;
330                 }
331
332                 if (!default_conn_ops.process_fn) {
333                         pr_err("No connection request callback\n");
334                         break;
335                 }
336
337                 if (default_conn_ops.process_fn(conn)) {
338                         pr_err("Cannot handle request\n");
339                         break;
340                 }
341         }
342
343 out:
344         /* Wait till all reference dropped to the Server object*/
345         while (atomic_read(&conn->r_count) > 0)
346                 schedule_timeout(HZ);
347
348         unload_nls(conn->local_nls);
349         if (default_conn_ops.terminate_fn)
350                 default_conn_ops.terminate_fn(conn);
351         t->ops->disconnect(t);
352         module_put(THIS_MODULE);
353         return 0;
354 }
355
void ksmbd_conn_init_server_callbacks(struct ksmbd_conn_ops *ops)
{
	/* Installed once at server start; consumed by the handler loop. */
	default_conn_ops.process_fn = ops->process_fn;
	default_conn_ops.terminate_fn = ops->terminate_fn;
}
361
362 int ksmbd_conn_transport_init(void)
363 {
364         int ret;
365
366         mutex_lock(&init_lock);
367         ret = ksmbd_tcp_init();
368         if (ret) {
369                 pr_err("Failed to init TCP subsystem: %d\n", ret);
370                 goto out;
371         }
372
373         ret = ksmbd_rdma_init();
374         if (ret) {
375                 pr_err("Failed to init RDMA subsystem: %d\n", ret);
376                 goto out;
377         }
378 out:
379         mutex_unlock(&init_lock);
380         return ret;
381 }
382
383 static void stop_sessions(void)
384 {
385         struct ksmbd_conn *conn;
386
387 again:
388         read_lock(&conn_list_lock);
389         list_for_each_entry(conn, &conn_list, conns_list) {
390                 struct task_struct *task;
391
392                 task = conn->transport->handler;
393                 if (task)
394                         ksmbd_debug(CONN, "Stop session handler %s/%d\n",
395                                     task->comm, task_pid_nr(task));
396                 conn->status = KSMBD_SESS_EXITING;
397         }
398         read_unlock(&conn_list_lock);
399
400         if (!list_empty(&conn_list)) {
401                 schedule_timeout_interruptible(HZ / 10); /* 100ms */
402                 goto again;
403         }
404 }
405
void ksmbd_conn_transport_destroy(void)
{
	mutex_lock(&init_lock);
	/*
	 * Tear down the listeners first so no new connections arrive,
	 * then stop the sessions that are still running.
	 */
	ksmbd_tcp_destroy();
	ksmbd_rdma_destroy();
	stop_sessions();
	mutex_unlock(&init_lock);
}