Merge tag 'qcom-clk-for-5.20' of https://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-microblaze.git] / fs / ksmbd / transport_rdma.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2017, Microsoft Corporation.
4  *   Copyright (C) 2018, LG Electronics.
5  *
6  *   Author(s): Long Li <longli@microsoft.com>,
7  *              Hyunchul Lee <hyc.lee@gmail.com>
8  *
9  *   This program is free software;  you can redistribute it and/or modify
10  *   it under the terms of the GNU General Public License as published by
11  *   the Free Software Foundation; either version 2 of the License, or
12  *   (at your option) any later version.
13  *
14  *   This program is distributed in the hope that it will be useful,
15  *   but WITHOUT ANY WARRANTY;  without even the implied warranty of
16  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See
17  *   the GNU General Public License for more details.
18  */
19
20 #define SUBMOD_NAME     "smb_direct"
21
22 #include <linux/kthread.h>
23 #include <linux/list.h>
24 #include <linux/mempool.h>
25 #include <linux/highmem.h>
26 #include <linux/scatterlist.h>
27 #include <rdma/ib_verbs.h>
28 #include <rdma/rdma_cm.h>
29 #include <rdma/rw.h>
30
31 #include "glob.h"
32 #include "connection.h"
33 #include "smb_common.h"
34 #include "smbstatus.h"
35 #include "transport_rdma.h"
36
37 #define SMB_DIRECT_PORT_IWARP           5445
38 #define SMB_DIRECT_PORT_INFINIBAND      445
39
40 #define SMB_DIRECT_VERSION_LE           cpu_to_le16(0x0100)
41
42 /* SMB_DIRECT negotiation timeout in seconds */
43 #define SMB_DIRECT_NEGOTIATE_TIMEOUT            120
44
45 #define SMB_DIRECT_MAX_SEND_SGES                8
46 #define SMB_DIRECT_MAX_RECV_SGES                1
47
48 /*
49  * Default maximum number of RDMA read/write outstanding on this connection
50  * This value is possibly decreased during QP creation on hardware limit
51  */
52 #define SMB_DIRECT_CM_INITIATOR_DEPTH           8
53
54 /* Maximum number of retries on data transfer operations */
55 #define SMB_DIRECT_CM_RETRY                     6
56 /* No need to retry on Receiver Not Ready since SMB_DIRECT manages credits */
57 #define SMB_DIRECT_CM_RNR_RETRY         0
58
59 /*
60  * User configurable initial values per SMB_DIRECT transport connection
61  * as defined in [MS-SMBD] 3.1.1.1
62  * Those may change after a SMB_DIRECT negotiation
63  */
64
65 /* Set 445 port to SMB Direct port by default */
66 static int smb_direct_port = SMB_DIRECT_PORT_INFINIBAND;
67
68 /* The local peer's maximum number of credits to grant to the peer */
69 static int smb_direct_receive_credit_max = 255;
70
71 /* The remote peer's credit request of local peer */
72 static int smb_direct_send_credit_target = 255;
73
74 /* The maximum single message size can be sent to remote peer */
75 static int smb_direct_max_send_size = 8192;
76
77 /*  The maximum fragmented upper-layer payload receive size supported */
78 static int smb_direct_max_fragmented_recv_size = 1024 * 1024;
79
80 /*  The maximum single-message size which can be received */
81 static int smb_direct_max_receive_size = 8192;
82
83 static int smb_direct_max_read_write_size = SMBD_DEFAULT_IOSIZE;
84
85 static LIST_HEAD(smb_direct_device_list);
86 static DEFINE_RWLOCK(smb_direct_device_lock);
87
88 struct smb_direct_device {
89         struct ib_device        *ib_dev;
90         struct list_head        list;
91 };
92
93 static struct smb_direct_listener {
94         struct rdma_cm_id       *cm_id;
95 } smb_direct_listener;
96
97 static struct workqueue_struct *smb_direct_wq;
98
99 enum smb_direct_status {
100         SMB_DIRECT_CS_NEW = 0,
101         SMB_DIRECT_CS_CONNECTED,
102         SMB_DIRECT_CS_DISCONNECTING,
103         SMB_DIRECT_CS_DISCONNECTED,
104 };
105
106 struct smb_direct_transport {
107         struct ksmbd_transport  transport;
108
109         enum smb_direct_status  status;
110         bool                    full_packet_received;
111         wait_queue_head_t       wait_status;
112
113         struct rdma_cm_id       *cm_id;
114         struct ib_cq            *send_cq;
115         struct ib_cq            *recv_cq;
116         struct ib_pd            *pd;
117         struct ib_qp            *qp;
118
119         int                     max_send_size;
120         int                     max_recv_size;
121         int                     max_fragmented_send_size;
122         int                     max_fragmented_recv_size;
123         int                     max_rdma_rw_size;
124
125         spinlock_t              reassembly_queue_lock;
126         struct list_head        reassembly_queue;
127         int                     reassembly_data_length;
128         int                     reassembly_queue_length;
129         int                     first_entry_offset;
130         wait_queue_head_t       wait_reassembly_queue;
131
132         spinlock_t              receive_credit_lock;
133         int                     recv_credits;
134         int                     count_avail_recvmsg;
135         int                     recv_credit_max;
136         int                     recv_credit_target;
137
138         spinlock_t              recvmsg_queue_lock;
139         struct list_head        recvmsg_queue;
140
141         spinlock_t              empty_recvmsg_queue_lock;
142         struct list_head        empty_recvmsg_queue;
143
144         int                     send_credit_target;
145         atomic_t                send_credits;
146         spinlock_t              lock_new_recv_credits;
147         int                     new_recv_credits;
148         int                     max_rw_credits;
149         int                     pages_per_rw_credit;
150         atomic_t                rw_credits;
151
152         wait_queue_head_t       wait_send_credits;
153         wait_queue_head_t       wait_rw_credits;
154
155         mempool_t               *sendmsg_mempool;
156         struct kmem_cache       *sendmsg_cache;
157         mempool_t               *recvmsg_mempool;
158         struct kmem_cache       *recvmsg_cache;
159
160         wait_queue_head_t       wait_send_pending;
161         atomic_t                send_pending;
162
163         struct delayed_work     post_recv_credits_work;
164         struct work_struct      send_immediate_work;
165         struct work_struct      disconnect_work;
166
167         bool                    negotiation_requested;
168 };
169
170 #define KSMBD_TRANS(t) ((struct ksmbd_transport *)&((t)->transport))
171
172 enum {
173         SMB_DIRECT_MSG_NEGOTIATE_REQ = 0,
174         SMB_DIRECT_MSG_DATA_TRANSFER
175 };
176
177 static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops;
178
179 struct smb_direct_send_ctx {
180         struct list_head        msg_list;
181         int                     wr_cnt;
182         bool                    need_invalidate_rkey;
183         unsigned int            remote_key;
184 };
185
186 struct smb_direct_sendmsg {
187         struct smb_direct_transport     *transport;
188         struct ib_send_wr       wr;
189         struct list_head        list;
190         int                     num_sge;
191         struct ib_sge           sge[SMB_DIRECT_MAX_SEND_SGES];
192         struct ib_cqe           cqe;
193         u8                      packet[];
194 };
195
196 struct smb_direct_recvmsg {
197         struct smb_direct_transport     *transport;
198         struct list_head        list;
199         int                     type;
200         struct ib_sge           sge;
201         struct ib_cqe           cqe;
202         bool                    first_segment;
203         u8                      packet[];
204 };
205
206 struct smb_direct_rdma_rw_msg {
207         struct smb_direct_transport     *t;
208         struct ib_cqe           cqe;
209         int                     status;
210         struct completion       *completion;
211         struct list_head        list;
212         struct rdma_rw_ctx      rw_ctx;
213         struct sg_table         sgt;
214         struct scatterlist      sg_list[];
215 };
216
217 void init_smbd_max_io_size(unsigned int sz)
218 {
219         sz = clamp_val(sz, SMBD_MIN_IOSIZE, SMBD_MAX_IOSIZE);
220         smb_direct_max_read_write_size = sz;
221 }
222
223 unsigned int get_smbd_max_read_write_size(void)
224 {
225         return smb_direct_max_read_write_size;
226 }
227
228 static inline int get_buf_page_count(void *buf, int size)
229 {
230         return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) -
231                 (uintptr_t)buf / PAGE_SIZE;
232 }
233
234 static void smb_direct_destroy_pools(struct smb_direct_transport *transport);
235 static void smb_direct_post_recv_credits(struct work_struct *work);
236 static int smb_direct_post_send_data(struct smb_direct_transport *t,
237                                      struct smb_direct_send_ctx *send_ctx,
238                                      struct kvec *iov, int niov,
239                                      int remaining_data_length);
240
241 static inline struct smb_direct_transport *
242 smb_trans_direct_transfort(struct ksmbd_transport *t)
243 {
244         return container_of(t, struct smb_direct_transport, transport);
245 }
246
247 static inline void
248 *smb_direct_recvmsg_payload(struct smb_direct_recvmsg *recvmsg)
249 {
250         return (void *)recvmsg->packet;
251 }
252
253 static inline bool is_receive_credit_post_required(int receive_credits,
254                                                    int avail_recvmsg_count)
255 {
256         return receive_credits <= (smb_direct_receive_credit_max >> 3) &&
257                 avail_recvmsg_count >= (receive_credits >> 2);
258 }
259
260 static struct
261 smb_direct_recvmsg *get_free_recvmsg(struct smb_direct_transport *t)
262 {
263         struct smb_direct_recvmsg *recvmsg = NULL;
264
265         spin_lock(&t->recvmsg_queue_lock);
266         if (!list_empty(&t->recvmsg_queue)) {
267                 recvmsg = list_first_entry(&t->recvmsg_queue,
268                                            struct smb_direct_recvmsg,
269                                            list);
270                 list_del(&recvmsg->list);
271         }
272         spin_unlock(&t->recvmsg_queue_lock);
273         return recvmsg;
274 }
275
276 static void put_recvmsg(struct smb_direct_transport *t,
277                         struct smb_direct_recvmsg *recvmsg)
278 {
279         ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
280                             recvmsg->sge.length, DMA_FROM_DEVICE);
281
282         spin_lock(&t->recvmsg_queue_lock);
283         list_add(&recvmsg->list, &t->recvmsg_queue);
284         spin_unlock(&t->recvmsg_queue_lock);
285 }
286
287 static struct
288 smb_direct_recvmsg *get_empty_recvmsg(struct smb_direct_transport *t)
289 {
290         struct smb_direct_recvmsg *recvmsg = NULL;
291
292         spin_lock(&t->empty_recvmsg_queue_lock);
293         if (!list_empty(&t->empty_recvmsg_queue)) {
294                 recvmsg = list_first_entry(&t->empty_recvmsg_queue,
295                                            struct smb_direct_recvmsg, list);
296                 list_del(&recvmsg->list);
297         }
298         spin_unlock(&t->empty_recvmsg_queue_lock);
299         return recvmsg;
300 }
301
302 static void put_empty_recvmsg(struct smb_direct_transport *t,
303                               struct smb_direct_recvmsg *recvmsg)
304 {
305         ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
306                             recvmsg->sge.length, DMA_FROM_DEVICE);
307
308         spin_lock(&t->empty_recvmsg_queue_lock);
309         list_add_tail(&recvmsg->list, &t->empty_recvmsg_queue);
310         spin_unlock(&t->empty_recvmsg_queue_lock);
311 }
312
313 static void enqueue_reassembly(struct smb_direct_transport *t,
314                                struct smb_direct_recvmsg *recvmsg,
315                                int data_length)
316 {
317         spin_lock(&t->reassembly_queue_lock);
318         list_add_tail(&recvmsg->list, &t->reassembly_queue);
319         t->reassembly_queue_length++;
320         /*
321          * Make sure reassembly_data_length is updated after list and
322          * reassembly_queue_length are updated. On the dequeue side
323          * reassembly_data_length is checked without a lock to determine
324          * if reassembly_queue_length and list is up to date
325          */
326         virt_wmb();
327         t->reassembly_data_length += data_length;
328         spin_unlock(&t->reassembly_queue_lock);
329 }
330
331 static struct smb_direct_recvmsg *get_first_reassembly(struct smb_direct_transport *t)
332 {
333         if (!list_empty(&t->reassembly_queue))
334                 return list_first_entry(&t->reassembly_queue,
335                                 struct smb_direct_recvmsg, list);
336         else
337                 return NULL;
338 }
339
340 static void smb_direct_disconnect_rdma_work(struct work_struct *work)
341 {
342         struct smb_direct_transport *t =
343                 container_of(work, struct smb_direct_transport,
344                              disconnect_work);
345
346         if (t->status == SMB_DIRECT_CS_CONNECTED) {
347                 t->status = SMB_DIRECT_CS_DISCONNECTING;
348                 rdma_disconnect(t->cm_id);
349         }
350 }
351
352 static void
353 smb_direct_disconnect_rdma_connection(struct smb_direct_transport *t)
354 {
355         if (t->status == SMB_DIRECT_CS_CONNECTED)
356                 queue_work(smb_direct_wq, &t->disconnect_work);
357 }
358
359 static void smb_direct_send_immediate_work(struct work_struct *work)
360 {
361         struct smb_direct_transport *t = container_of(work,
362                         struct smb_direct_transport, send_immediate_work);
363
364         if (t->status != SMB_DIRECT_CS_CONNECTED)
365                 return;
366
367         smb_direct_post_send_data(t, NULL, NULL, 0, 0);
368 }
369
370 static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
371 {
372         struct smb_direct_transport *t;
373         struct ksmbd_conn *conn;
374
375         t = kzalloc(sizeof(*t), GFP_KERNEL);
376         if (!t)
377                 return NULL;
378
379         t->cm_id = cm_id;
380         cm_id->context = t;
381
382         t->status = SMB_DIRECT_CS_NEW;
383         init_waitqueue_head(&t->wait_status);
384
385         spin_lock_init(&t->reassembly_queue_lock);
386         INIT_LIST_HEAD(&t->reassembly_queue);
387         t->reassembly_data_length = 0;
388         t->reassembly_queue_length = 0;
389         init_waitqueue_head(&t->wait_reassembly_queue);
390         init_waitqueue_head(&t->wait_send_credits);
391         init_waitqueue_head(&t->wait_rw_credits);
392
393         spin_lock_init(&t->receive_credit_lock);
394         spin_lock_init(&t->recvmsg_queue_lock);
395         INIT_LIST_HEAD(&t->recvmsg_queue);
396
397         spin_lock_init(&t->empty_recvmsg_queue_lock);
398         INIT_LIST_HEAD(&t->empty_recvmsg_queue);
399
400         init_waitqueue_head(&t->wait_send_pending);
401         atomic_set(&t->send_pending, 0);
402
403         spin_lock_init(&t->lock_new_recv_credits);
404
405         INIT_DELAYED_WORK(&t->post_recv_credits_work,
406                           smb_direct_post_recv_credits);
407         INIT_WORK(&t->send_immediate_work, smb_direct_send_immediate_work);
408         INIT_WORK(&t->disconnect_work, smb_direct_disconnect_rdma_work);
409
410         conn = ksmbd_conn_alloc();
411         if (!conn)
412                 goto err;
413         conn->transport = KSMBD_TRANS(t);
414         KSMBD_TRANS(t)->conn = conn;
415         KSMBD_TRANS(t)->ops = &ksmbd_smb_direct_transport_ops;
416         return t;
417 err:
418         kfree(t);
419         return NULL;
420 }
421
422 static void free_transport(struct smb_direct_transport *t)
423 {
424         struct smb_direct_recvmsg *recvmsg;
425
426         wake_up_interruptible(&t->wait_send_credits);
427
428         ksmbd_debug(RDMA, "wait for all send posted to IB to finish\n");
429         wait_event(t->wait_send_pending,
430                    atomic_read(&t->send_pending) == 0);
431
432         cancel_work_sync(&t->disconnect_work);
433         cancel_delayed_work_sync(&t->post_recv_credits_work);
434         cancel_work_sync(&t->send_immediate_work);
435
436         if (t->qp) {
437                 ib_drain_qp(t->qp);
438                 ib_mr_pool_destroy(t->qp, &t->qp->rdma_mrs);
439                 ib_destroy_qp(t->qp);
440         }
441
442         ksmbd_debug(RDMA, "drain the reassembly queue\n");
443         do {
444                 spin_lock(&t->reassembly_queue_lock);
445                 recvmsg = get_first_reassembly(t);
446                 if (recvmsg) {
447                         list_del(&recvmsg->list);
448                         spin_unlock(&t->reassembly_queue_lock);
449                         put_recvmsg(t, recvmsg);
450                 } else {
451                         spin_unlock(&t->reassembly_queue_lock);
452                 }
453         } while (recvmsg);
454         t->reassembly_data_length = 0;
455
456         if (t->send_cq)
457                 ib_free_cq(t->send_cq);
458         if (t->recv_cq)
459                 ib_free_cq(t->recv_cq);
460         if (t->pd)
461                 ib_dealloc_pd(t->pd);
462         if (t->cm_id)
463                 rdma_destroy_id(t->cm_id);
464
465         smb_direct_destroy_pools(t);
466         ksmbd_conn_free(KSMBD_TRANS(t)->conn);
467         kfree(t);
468 }
469
470 static struct smb_direct_sendmsg
471 *smb_direct_alloc_sendmsg(struct smb_direct_transport *t)
472 {
473         struct smb_direct_sendmsg *msg;
474
475         msg = mempool_alloc(t->sendmsg_mempool, GFP_KERNEL);
476         if (!msg)
477                 return ERR_PTR(-ENOMEM);
478         msg->transport = t;
479         INIT_LIST_HEAD(&msg->list);
480         msg->num_sge = 0;
481         return msg;
482 }
483
484 static void smb_direct_free_sendmsg(struct smb_direct_transport *t,
485                                     struct smb_direct_sendmsg *msg)
486 {
487         int i;
488
489         if (msg->num_sge > 0) {
490                 ib_dma_unmap_single(t->cm_id->device,
491                                     msg->sge[0].addr, msg->sge[0].length,
492                                     DMA_TO_DEVICE);
493                 for (i = 1; i < msg->num_sge; i++)
494                         ib_dma_unmap_page(t->cm_id->device,
495                                           msg->sge[i].addr, msg->sge[i].length,
496                                           DMA_TO_DEVICE);
497         }
498         mempool_free(msg, t->sendmsg_mempool);
499 }
500
501 static int smb_direct_check_recvmsg(struct smb_direct_recvmsg *recvmsg)
502 {
503         switch (recvmsg->type) {
504         case SMB_DIRECT_MSG_DATA_TRANSFER: {
505                 struct smb_direct_data_transfer *req =
506                         (struct smb_direct_data_transfer *)recvmsg->packet;
507                 struct smb2_hdr *hdr = (struct smb2_hdr *)(recvmsg->packet
508                                 + le32_to_cpu(req->data_offset));
509                 ksmbd_debug(RDMA,
510                             "CreditGranted: %u, CreditRequested: %u, DataLength: %u, RemainingDataLength: %u, SMB: %x, Command: %u\n",
511                             le16_to_cpu(req->credits_granted),
512                             le16_to_cpu(req->credits_requested),
513                             req->data_length, req->remaining_data_length,
514                             hdr->ProtocolId, hdr->Command);
515                 break;
516         }
517         case SMB_DIRECT_MSG_NEGOTIATE_REQ: {
518                 struct smb_direct_negotiate_req *req =
519                         (struct smb_direct_negotiate_req *)recvmsg->packet;
520                 ksmbd_debug(RDMA,
521                             "MinVersion: %u, MaxVersion: %u, CreditRequested: %u, MaxSendSize: %u, MaxRecvSize: %u, MaxFragmentedSize: %u\n",
522                             le16_to_cpu(req->min_version),
523                             le16_to_cpu(req->max_version),
524                             le16_to_cpu(req->credits_requested),
525                             le32_to_cpu(req->preferred_send_size),
526                             le32_to_cpu(req->max_receive_size),
527                             le32_to_cpu(req->max_fragmented_size));
528                 if (le16_to_cpu(req->min_version) > 0x0100 ||
529                     le16_to_cpu(req->max_version) < 0x0100)
530                         return -EOPNOTSUPP;
531                 if (le16_to_cpu(req->credits_requested) <= 0 ||
532                     le32_to_cpu(req->max_receive_size) <= 128 ||
533                     le32_to_cpu(req->max_fragmented_size) <=
534                                         128 * 1024)
535                         return -ECONNABORTED;
536
537                 break;
538         }
539         default:
540                 return -EINVAL;
541         }
542         return 0;
543 }
544
545 static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
546 {
547         struct smb_direct_recvmsg *recvmsg;
548         struct smb_direct_transport *t;
549
550         recvmsg = container_of(wc->wr_cqe, struct smb_direct_recvmsg, cqe);
551         t = recvmsg->transport;
552
553         if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_RECV) {
554                 if (wc->status != IB_WC_WR_FLUSH_ERR) {
555                         pr_err("Recv error. status='%s (%d)' opcode=%d\n",
556                                ib_wc_status_msg(wc->status), wc->status,
557                                wc->opcode);
558                         smb_direct_disconnect_rdma_connection(t);
559                 }
560                 put_empty_recvmsg(t, recvmsg);
561                 return;
562         }
563
564         ksmbd_debug(RDMA, "Recv completed. status='%s (%d)', opcode=%d\n",
565                     ib_wc_status_msg(wc->status), wc->status,
566                     wc->opcode);
567
568         ib_dma_sync_single_for_cpu(wc->qp->device, recvmsg->sge.addr,
569                                    recvmsg->sge.length, DMA_FROM_DEVICE);
570
571         switch (recvmsg->type) {
572         case SMB_DIRECT_MSG_NEGOTIATE_REQ:
573                 if (wc->byte_len < sizeof(struct smb_direct_negotiate_req)) {
574                         put_empty_recvmsg(t, recvmsg);
575                         return;
576                 }
577                 t->negotiation_requested = true;
578                 t->full_packet_received = true;
579                 t->status = SMB_DIRECT_CS_CONNECTED;
580                 enqueue_reassembly(t, recvmsg, 0);
581                 wake_up_interruptible(&t->wait_status);
582                 break;
583         case SMB_DIRECT_MSG_DATA_TRANSFER: {
584                 struct smb_direct_data_transfer *data_transfer =
585                         (struct smb_direct_data_transfer *)recvmsg->packet;
586                 unsigned int data_length;
587                 int avail_recvmsg_count, receive_credits;
588
589                 if (wc->byte_len <
590                     offsetof(struct smb_direct_data_transfer, padding)) {
591                         put_empty_recvmsg(t, recvmsg);
592                         return;
593                 }
594
595                 data_length = le32_to_cpu(data_transfer->data_length);
596                 if (data_length) {
597                         if (wc->byte_len < sizeof(struct smb_direct_data_transfer) +
598                             (u64)data_length) {
599                                 put_empty_recvmsg(t, recvmsg);
600                                 return;
601                         }
602
603                         if (t->full_packet_received)
604                                 recvmsg->first_segment = true;
605
606                         if (le32_to_cpu(data_transfer->remaining_data_length))
607                                 t->full_packet_received = false;
608                         else
609                                 t->full_packet_received = true;
610
611                         enqueue_reassembly(t, recvmsg, (int)data_length);
612                         wake_up_interruptible(&t->wait_reassembly_queue);
613
614                         spin_lock(&t->receive_credit_lock);
615                         receive_credits = --(t->recv_credits);
616                         avail_recvmsg_count = t->count_avail_recvmsg;
617                         spin_unlock(&t->receive_credit_lock);
618                 } else {
619                         put_empty_recvmsg(t, recvmsg);
620
621                         spin_lock(&t->receive_credit_lock);
622                         receive_credits = --(t->recv_credits);
623                         avail_recvmsg_count = ++(t->count_avail_recvmsg);
624                         spin_unlock(&t->receive_credit_lock);
625                 }
626
627                 t->recv_credit_target =
628                                 le16_to_cpu(data_transfer->credits_requested);
629                 atomic_add(le16_to_cpu(data_transfer->credits_granted),
630                            &t->send_credits);
631
632                 if (le16_to_cpu(data_transfer->flags) &
633                     SMB_DIRECT_RESPONSE_REQUESTED)
634                         queue_work(smb_direct_wq, &t->send_immediate_work);
635
636                 if (atomic_read(&t->send_credits) > 0)
637                         wake_up_interruptible(&t->wait_send_credits);
638
639                 if (is_receive_credit_post_required(receive_credits, avail_recvmsg_count))
640                         mod_delayed_work(smb_direct_wq,
641                                          &t->post_recv_credits_work, 0);
642                 break;
643         }
644         default:
645                 break;
646         }
647 }
648
649 static int smb_direct_post_recv(struct smb_direct_transport *t,
650                                 struct smb_direct_recvmsg *recvmsg)
651 {
652         struct ib_recv_wr wr;
653         int ret;
654
655         recvmsg->sge.addr = ib_dma_map_single(t->cm_id->device,
656                                               recvmsg->packet, t->max_recv_size,
657                                               DMA_FROM_DEVICE);
658         ret = ib_dma_mapping_error(t->cm_id->device, recvmsg->sge.addr);
659         if (ret)
660                 return ret;
661         recvmsg->sge.length = t->max_recv_size;
662         recvmsg->sge.lkey = t->pd->local_dma_lkey;
663         recvmsg->cqe.done = recv_done;
664
665         wr.wr_cqe = &recvmsg->cqe;
666         wr.next = NULL;
667         wr.sg_list = &recvmsg->sge;
668         wr.num_sge = 1;
669
670         ret = ib_post_recv(t->qp, &wr, NULL);
671         if (ret) {
672                 pr_err("Can't post recv: %d\n", ret);
673                 ib_dma_unmap_single(t->cm_id->device,
674                                     recvmsg->sge.addr, recvmsg->sge.length,
675                                     DMA_FROM_DEVICE);
676                 smb_direct_disconnect_rdma_connection(t);
677                 return ret;
678         }
679         return ret;
680 }
681
682 static int smb_direct_read(struct ksmbd_transport *t, char *buf,
683                            unsigned int size)
684 {
685         struct smb_direct_recvmsg *recvmsg;
686         struct smb_direct_data_transfer *data_transfer;
687         int to_copy, to_read, data_read, offset;
688         u32 data_length, remaining_data_length, data_offset;
689         int rc;
690         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
691
692 again:
693         if (st->status != SMB_DIRECT_CS_CONNECTED) {
694                 pr_err("disconnected\n");
695                 return -ENOTCONN;
696         }
697
698         /*
699          * No need to hold the reassembly queue lock all the time as we are
700          * the only one reading from the front of the queue. The transport
701          * may add more entries to the back of the queue at the same time
702          */
703         if (st->reassembly_data_length >= size) {
704                 int queue_length;
705                 int queue_removed = 0;
706
707                 /*
708                  * Need to make sure reassembly_data_length is read before
709                  * reading reassembly_queue_length and calling
710                  * get_first_reassembly. This call is lock free
711                  * as we never read at the end of the queue which are being
712                  * updated in SOFTIRQ as more data is received
713                  */
714                 virt_rmb();
715                 queue_length = st->reassembly_queue_length;
716                 data_read = 0;
717                 to_read = size;
718                 offset = st->first_entry_offset;
719                 while (data_read < size) {
720                         recvmsg = get_first_reassembly(st);
721                         data_transfer = smb_direct_recvmsg_payload(recvmsg);
722                         data_length = le32_to_cpu(data_transfer->data_length);
723                         remaining_data_length =
724                                 le32_to_cpu(data_transfer->remaining_data_length);
725                         data_offset = le32_to_cpu(data_transfer->data_offset);
726
727                         /*
728                          * The upper layer expects RFC1002 length at the
729                          * beginning of the payload. Return it to indicate
730                          * the total length of the packet. This minimize the
731                          * change to upper layer packet processing logic. This
732                          * will be eventually remove when an intermediate
733                          * transport layer is added
734                          */
735                         if (recvmsg->first_segment && size == 4) {
736                                 unsigned int rfc1002_len =
737                                         data_length + remaining_data_length;
738                                 *((__be32 *)buf) = cpu_to_be32(rfc1002_len);
739                                 data_read = 4;
740                                 recvmsg->first_segment = false;
741                                 ksmbd_debug(RDMA,
742                                             "returning rfc1002 length %d\n",
743                                             rfc1002_len);
744                                 goto read_rfc1002_done;
745                         }
746
747                         to_copy = min_t(int, data_length - offset, to_read);
748                         memcpy(buf + data_read, (char *)data_transfer + data_offset + offset,
749                                to_copy);
750
751                         /* move on to the next buffer? */
752                         if (to_copy == data_length - offset) {
753                                 queue_length--;
754                                 /*
755                                  * No need to lock if we are not at the
756                                  * end of the queue
757                                  */
758                                 if (queue_length) {
759                                         list_del(&recvmsg->list);
760                                 } else {
761                                         spin_lock_irq(&st->reassembly_queue_lock);
762                                         list_del(&recvmsg->list);
763                                         spin_unlock_irq(&st->reassembly_queue_lock);
764                                 }
765                                 queue_removed++;
766                                 put_recvmsg(st, recvmsg);
767                                 offset = 0;
768                         } else {
769                                 offset += to_copy;
770                         }
771
772                         to_read -= to_copy;
773                         data_read += to_copy;
774                 }
775
776                 spin_lock_irq(&st->reassembly_queue_lock);
777                 st->reassembly_data_length -= data_read;
778                 st->reassembly_queue_length -= queue_removed;
779                 spin_unlock_irq(&st->reassembly_queue_lock);
780
781                 spin_lock(&st->receive_credit_lock);
782                 st->count_avail_recvmsg += queue_removed;
783                 if (is_receive_credit_post_required(st->recv_credits, st->count_avail_recvmsg)) {
784                         spin_unlock(&st->receive_credit_lock);
785                         mod_delayed_work(smb_direct_wq,
786                                          &st->post_recv_credits_work, 0);
787                 } else {
788                         spin_unlock(&st->receive_credit_lock);
789                 }
790
791                 st->first_entry_offset = offset;
792                 ksmbd_debug(RDMA,
793                             "returning to thread data_read=%d reassembly_data_length=%d first_entry_offset=%d\n",
794                             data_read, st->reassembly_data_length,
795                             st->first_entry_offset);
796 read_rfc1002_done:
797                 return data_read;
798         }
799
800         ksmbd_debug(RDMA, "wait_event on more data\n");
801         rc = wait_event_interruptible(st->wait_reassembly_queue,
802                                       st->reassembly_data_length >= size ||
803                                        st->status != SMB_DIRECT_CS_CONNECTED);
804         if (rc)
805                 return -EINTR;
806
807         goto again;
808 }
809
810 static void smb_direct_post_recv_credits(struct work_struct *work)
811 {
812         struct smb_direct_transport *t = container_of(work,
813                 struct smb_direct_transport, post_recv_credits_work.work);
814         struct smb_direct_recvmsg *recvmsg;
815         int receive_credits, credits = 0;
816         int ret;
817         int use_free = 1;
818
819         spin_lock(&t->receive_credit_lock);
820         receive_credits = t->recv_credits;
821         spin_unlock(&t->receive_credit_lock);
822
823         if (receive_credits < t->recv_credit_target) {
824                 while (true) {
825                         if (use_free)
826                                 recvmsg = get_free_recvmsg(t);
827                         else
828                                 recvmsg = get_empty_recvmsg(t);
829                         if (!recvmsg) {
830                                 if (use_free) {
831                                         use_free = 0;
832                                         continue;
833                                 } else {
834                                         break;
835                                 }
836                         }
837
838                         recvmsg->type = SMB_DIRECT_MSG_DATA_TRANSFER;
839                         recvmsg->first_segment = false;
840
841                         ret = smb_direct_post_recv(t, recvmsg);
842                         if (ret) {
843                                 pr_err("Can't post recv: %d\n", ret);
844                                 put_recvmsg(t, recvmsg);
845                                 break;
846                         }
847                         credits++;
848                 }
849         }
850
851         spin_lock(&t->receive_credit_lock);
852         t->recv_credits += credits;
853         t->count_avail_recvmsg -= credits;
854         spin_unlock(&t->receive_credit_lock);
855
856         spin_lock(&t->lock_new_recv_credits);
857         t->new_recv_credits += credits;
858         spin_unlock(&t->lock_new_recv_credits);
859
860         if (credits)
861                 queue_work(smb_direct_wq, &t->send_immediate_work);
862 }
863
864 static void send_done(struct ib_cq *cq, struct ib_wc *wc)
865 {
866         struct smb_direct_sendmsg *sendmsg, *sibling;
867         struct smb_direct_transport *t;
868         struct list_head *pos, *prev, *end;
869
870         sendmsg = container_of(wc->wr_cqe, struct smb_direct_sendmsg, cqe);
871         t = sendmsg->transport;
872
873         ksmbd_debug(RDMA, "Send completed. status='%s (%d)', opcode=%d\n",
874                     ib_wc_status_msg(wc->status), wc->status,
875                     wc->opcode);
876
877         if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
878                 pr_err("Send error. status='%s (%d)', opcode=%d\n",
879                        ib_wc_status_msg(wc->status), wc->status,
880                        wc->opcode);
881                 smb_direct_disconnect_rdma_connection(t);
882         }
883
884         if (atomic_dec_and_test(&t->send_pending))
885                 wake_up(&t->wait_send_pending);
886
887         /* iterate and free the list of messages in reverse. the list's head
888          * is invalid.
889          */
890         for (pos = &sendmsg->list, prev = pos->prev, end = sendmsg->list.next;
891              prev != end; pos = prev, prev = prev->prev) {
892                 sibling = container_of(pos, struct smb_direct_sendmsg, list);
893                 smb_direct_free_sendmsg(t, sibling);
894         }
895
896         sibling = container_of(pos, struct smb_direct_sendmsg, list);
897         smb_direct_free_sendmsg(t, sibling);
898 }
899
900 static int manage_credits_prior_sending(struct smb_direct_transport *t)
901 {
902         int new_credits;
903
904         spin_lock(&t->lock_new_recv_credits);
905         new_credits = t->new_recv_credits;
906         t->new_recv_credits = 0;
907         spin_unlock(&t->lock_new_recv_credits);
908
909         return new_credits;
910 }
911
912 static int smb_direct_post_send(struct smb_direct_transport *t,
913                                 struct ib_send_wr *wr)
914 {
915         int ret;
916
917         atomic_inc(&t->send_pending);
918         ret = ib_post_send(t->qp, wr, NULL);
919         if (ret) {
920                 pr_err("failed to post send: %d\n", ret);
921                 if (atomic_dec_and_test(&t->send_pending))
922                         wake_up(&t->wait_send_pending);
923                 smb_direct_disconnect_rdma_connection(t);
924         }
925         return ret;
926 }
927
928 static void smb_direct_send_ctx_init(struct smb_direct_transport *t,
929                                      struct smb_direct_send_ctx *send_ctx,
930                                      bool need_invalidate_rkey,
931                                      unsigned int remote_key)
932 {
933         INIT_LIST_HEAD(&send_ctx->msg_list);
934         send_ctx->wr_cnt = 0;
935         send_ctx->need_invalidate_rkey = need_invalidate_rkey;
936         send_ctx->remote_key = remote_key;
937 }
938
939 static int smb_direct_flush_send_list(struct smb_direct_transport *t,
940                                       struct smb_direct_send_ctx *send_ctx,
941                                       bool is_last)
942 {
943         struct smb_direct_sendmsg *first, *last;
944         int ret;
945
946         if (list_empty(&send_ctx->msg_list))
947                 return 0;
948
949         first = list_first_entry(&send_ctx->msg_list,
950                                  struct smb_direct_sendmsg,
951                                  list);
952         last = list_last_entry(&send_ctx->msg_list,
953                                struct smb_direct_sendmsg,
954                                list);
955
956         last->wr.send_flags = IB_SEND_SIGNALED;
957         last->wr.wr_cqe = &last->cqe;
958         if (is_last && send_ctx->need_invalidate_rkey) {
959                 last->wr.opcode = IB_WR_SEND_WITH_INV;
960                 last->wr.ex.invalidate_rkey = send_ctx->remote_key;
961         }
962
963         ret = smb_direct_post_send(t, &first->wr);
964         if (!ret) {
965                 smb_direct_send_ctx_init(t, send_ctx,
966                                          send_ctx->need_invalidate_rkey,
967                                          send_ctx->remote_key);
968         } else {
969                 atomic_add(send_ctx->wr_cnt, &t->send_credits);
970                 wake_up(&t->wait_send_credits);
971                 list_for_each_entry_safe(first, last, &send_ctx->msg_list,
972                                          list) {
973                         smb_direct_free_sendmsg(t, first);
974                 }
975         }
976         return ret;
977 }
978
979 static int wait_for_credits(struct smb_direct_transport *t,
980                             wait_queue_head_t *waitq, atomic_t *total_credits,
981                             int needed)
982 {
983         int ret;
984
985         do {
986                 if (atomic_sub_return(needed, total_credits) >= 0)
987                         return 0;
988
989                 atomic_add(needed, total_credits);
990                 ret = wait_event_interruptible(*waitq,
991                                                atomic_read(total_credits) >= needed ||
992                                                t->status != SMB_DIRECT_CS_CONNECTED);
993
994                 if (t->status != SMB_DIRECT_CS_CONNECTED)
995                         return -ENOTCONN;
996                 else if (ret < 0)
997                         return ret;
998         } while (true);
999 }
1000
1001 static int wait_for_send_credits(struct smb_direct_transport *t,
1002                                  struct smb_direct_send_ctx *send_ctx)
1003 {
1004         int ret;
1005
1006         if (send_ctx &&
1007             (send_ctx->wr_cnt >= 16 || atomic_read(&t->send_credits) <= 1)) {
1008                 ret = smb_direct_flush_send_list(t, send_ctx, false);
1009                 if (ret)
1010                         return ret;
1011         }
1012
1013         return wait_for_credits(t, &t->wait_send_credits, &t->send_credits, 1);
1014 }
1015
1016 static int wait_for_rw_credits(struct smb_direct_transport *t, int credits)
1017 {
1018         return wait_for_credits(t, &t->wait_rw_credits, &t->rw_credits, credits);
1019 }
1020
1021 static int calc_rw_credits(struct smb_direct_transport *t,
1022                            char *buf, unsigned int len)
1023 {
1024         return DIV_ROUND_UP(get_buf_page_count(buf, len),
1025                             t->pages_per_rw_credit);
1026 }
1027
1028 static int smb_direct_create_header(struct smb_direct_transport *t,
1029                                     int size, int remaining_data_length,
1030                                     struct smb_direct_sendmsg **sendmsg_out)
1031 {
1032         struct smb_direct_sendmsg *sendmsg;
1033         struct smb_direct_data_transfer *packet;
1034         int header_length;
1035         int ret;
1036
1037         sendmsg = smb_direct_alloc_sendmsg(t);
1038         if (IS_ERR(sendmsg))
1039                 return PTR_ERR(sendmsg);
1040
1041         /* Fill in the packet header */
1042         packet = (struct smb_direct_data_transfer *)sendmsg->packet;
1043         packet->credits_requested = cpu_to_le16(t->send_credit_target);
1044         packet->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1045
1046         packet->flags = 0;
1047         packet->reserved = 0;
1048         if (!size)
1049                 packet->data_offset = 0;
1050         else
1051                 packet->data_offset = cpu_to_le32(24);
1052         packet->data_length = cpu_to_le32(size);
1053         packet->remaining_data_length = cpu_to_le32(remaining_data_length);
1054         packet->padding = 0;
1055
1056         ksmbd_debug(RDMA,
1057                     "credits_requested=%d credits_granted=%d data_offset=%d data_length=%d remaining_data_length=%d\n",
1058                     le16_to_cpu(packet->credits_requested),
1059                     le16_to_cpu(packet->credits_granted),
1060                     le32_to_cpu(packet->data_offset),
1061                     le32_to_cpu(packet->data_length),
1062                     le32_to_cpu(packet->remaining_data_length));
1063
1064         /* Map the packet to DMA */
1065         header_length = sizeof(struct smb_direct_data_transfer);
1066         /* If this is a packet without payload, don't send padding */
1067         if (!size)
1068                 header_length =
1069                         offsetof(struct smb_direct_data_transfer, padding);
1070
1071         sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1072                                                  (void *)packet,
1073                                                  header_length,
1074                                                  DMA_TO_DEVICE);
1075         ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1076         if (ret) {
1077                 smb_direct_free_sendmsg(t, sendmsg);
1078                 return ret;
1079         }
1080
1081         sendmsg->num_sge = 1;
1082         sendmsg->sge[0].length = header_length;
1083         sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1084
1085         *sendmsg_out = sendmsg;
1086         return 0;
1087 }
1088
1089 static int get_sg_list(void *buf, int size, struct scatterlist *sg_list, int nentries)
1090 {
1091         bool high = is_vmalloc_addr(buf);
1092         struct page *page;
1093         int offset, len;
1094         int i = 0;
1095
1096         if (size <= 0 || nentries < get_buf_page_count(buf, size))
1097                 return -EINVAL;
1098
1099         offset = offset_in_page(buf);
1100         buf -= offset;
1101         while (size > 0) {
1102                 len = min_t(int, PAGE_SIZE - offset, size);
1103                 if (high)
1104                         page = vmalloc_to_page(buf);
1105                 else
1106                         page = kmap_to_page(buf);
1107
1108                 if (!sg_list)
1109                         return -EINVAL;
1110                 sg_set_page(sg_list, page, len, offset);
1111                 sg_list = sg_next(sg_list);
1112
1113                 buf += PAGE_SIZE;
1114                 size -= len;
1115                 offset = 0;
1116                 i++;
1117         }
1118         return i;
1119 }
1120
1121 static int get_mapped_sg_list(struct ib_device *device, void *buf, int size,
1122                               struct scatterlist *sg_list, int nentries,
1123                               enum dma_data_direction dir)
1124 {
1125         int npages;
1126
1127         npages = get_sg_list(buf, size, sg_list, nentries);
1128         if (npages < 0)
1129                 return -EINVAL;
1130         return ib_dma_map_sg(device, sg_list, npages, dir);
1131 }
1132
1133 static int post_sendmsg(struct smb_direct_transport *t,
1134                         struct smb_direct_send_ctx *send_ctx,
1135                         struct smb_direct_sendmsg *msg)
1136 {
1137         int i;
1138
1139         for (i = 0; i < msg->num_sge; i++)
1140                 ib_dma_sync_single_for_device(t->cm_id->device,
1141                                               msg->sge[i].addr, msg->sge[i].length,
1142                                               DMA_TO_DEVICE);
1143
1144         msg->cqe.done = send_done;
1145         msg->wr.opcode = IB_WR_SEND;
1146         msg->wr.sg_list = &msg->sge[0];
1147         msg->wr.num_sge = msg->num_sge;
1148         msg->wr.next = NULL;
1149
1150         if (send_ctx) {
1151                 msg->wr.wr_cqe = NULL;
1152                 msg->wr.send_flags = 0;
1153                 if (!list_empty(&send_ctx->msg_list)) {
1154                         struct smb_direct_sendmsg *last;
1155
1156                         last = list_last_entry(&send_ctx->msg_list,
1157                                                struct smb_direct_sendmsg,
1158                                                list);
1159                         last->wr.next = &msg->wr;
1160                 }
1161                 list_add_tail(&msg->list, &send_ctx->msg_list);
1162                 send_ctx->wr_cnt++;
1163                 return 0;
1164         }
1165
1166         msg->wr.wr_cqe = &msg->cqe;
1167         msg->wr.send_flags = IB_SEND_SIGNALED;
1168         return smb_direct_post_send(t, &msg->wr);
1169 }
1170
1171 static int smb_direct_post_send_data(struct smb_direct_transport *t,
1172                                      struct smb_direct_send_ctx *send_ctx,
1173                                      struct kvec *iov, int niov,
1174                                      int remaining_data_length)
1175 {
1176         int i, j, ret;
1177         struct smb_direct_sendmsg *msg;
1178         int data_length;
1179         struct scatterlist sg[SMB_DIRECT_MAX_SEND_SGES - 1];
1180
1181         ret = wait_for_send_credits(t, send_ctx);
1182         if (ret)
1183                 return ret;
1184
1185         data_length = 0;
1186         for (i = 0; i < niov; i++)
1187                 data_length += iov[i].iov_len;
1188
1189         ret = smb_direct_create_header(t, data_length, remaining_data_length,
1190                                        &msg);
1191         if (ret) {
1192                 atomic_inc(&t->send_credits);
1193                 return ret;
1194         }
1195
1196         for (i = 0; i < niov; i++) {
1197                 struct ib_sge *sge;
1198                 int sg_cnt;
1199
1200                 sg_init_table(sg, SMB_DIRECT_MAX_SEND_SGES - 1);
1201                 sg_cnt = get_mapped_sg_list(t->cm_id->device,
1202                                             iov[i].iov_base, iov[i].iov_len,
1203                                             sg, SMB_DIRECT_MAX_SEND_SGES - 1,
1204                                             DMA_TO_DEVICE);
1205                 if (sg_cnt <= 0) {
1206                         pr_err("failed to map buffer\n");
1207                         ret = -ENOMEM;
1208                         goto err;
1209                 } else if (sg_cnt + msg->num_sge > SMB_DIRECT_MAX_SEND_SGES) {
1210                         pr_err("buffer not fitted into sges\n");
1211                         ret = -E2BIG;
1212                         ib_dma_unmap_sg(t->cm_id->device, sg, sg_cnt,
1213                                         DMA_TO_DEVICE);
1214                         goto err;
1215                 }
1216
1217                 for (j = 0; j < sg_cnt; j++) {
1218                         sge = &msg->sge[msg->num_sge];
1219                         sge->addr = sg_dma_address(&sg[j]);
1220                         sge->length = sg_dma_len(&sg[j]);
1221                         sge->lkey  = t->pd->local_dma_lkey;
1222                         msg->num_sge++;
1223                 }
1224         }
1225
1226         ret = post_sendmsg(t, send_ctx, msg);
1227         if (ret)
1228                 goto err;
1229         return 0;
1230 err:
1231         smb_direct_free_sendmsg(t, msg);
1232         atomic_inc(&t->send_credits);
1233         return ret;
1234 }
1235
1236 static int smb_direct_writev(struct ksmbd_transport *t,
1237                              struct kvec *iov, int niovs, int buflen,
1238                              bool need_invalidate, unsigned int remote_key)
1239 {
1240         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1241         int remaining_data_length;
1242         int start, i, j;
1243         int max_iov_size = st->max_send_size -
1244                         sizeof(struct smb_direct_data_transfer);
1245         int ret;
1246         struct kvec vec;
1247         struct smb_direct_send_ctx send_ctx;
1248
1249         if (st->status != SMB_DIRECT_CS_CONNECTED)
1250                 return -ENOTCONN;
1251
1252         //FIXME: skip RFC1002 header..
1253         buflen -= 4;
1254         iov[0].iov_base += 4;
1255         iov[0].iov_len -= 4;
1256
1257         remaining_data_length = buflen;
1258         ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
1259
1260         smb_direct_send_ctx_init(st, &send_ctx, need_invalidate, remote_key);
1261         start = i = 0;
1262         buflen = 0;
1263         while (true) {
1264                 buflen += iov[i].iov_len;
1265                 if (buflen > max_iov_size) {
1266                         if (i > start) {
1267                                 remaining_data_length -=
1268                                         (buflen - iov[i].iov_len);
1269                                 ret = smb_direct_post_send_data(st, &send_ctx,
1270                                                                 &iov[start], i - start,
1271                                                                 remaining_data_length);
1272                                 if (ret)
1273                                         goto done;
1274                         } else {
1275                                 /* iov[start] is too big, break it */
1276                                 int nvec  = (buflen + max_iov_size - 1) /
1277                                                 max_iov_size;
1278
1279                                 for (j = 0; j < nvec; j++) {
1280                                         vec.iov_base =
1281                                                 (char *)iov[start].iov_base +
1282                                                 j * max_iov_size;
1283                                         vec.iov_len =
1284                                                 min_t(int, max_iov_size,
1285                                                       buflen - max_iov_size * j);
1286                                         remaining_data_length -= vec.iov_len;
1287                                         ret = smb_direct_post_send_data(st, &send_ctx, &vec, 1,
1288                                                                         remaining_data_length);
1289                                         if (ret)
1290                                                 goto done;
1291                                 }
1292                                 i++;
1293                                 if (i == niovs)
1294                                         break;
1295                         }
1296                         start = i;
1297                         buflen = 0;
1298                 } else {
1299                         i++;
1300                         if (i == niovs) {
1301                                 /* send out all remaining vecs */
1302                                 remaining_data_length -= buflen;
1303                                 ret = smb_direct_post_send_data(st, &send_ctx,
1304                                                                 &iov[start], i - start,
1305                                                                 remaining_data_length);
1306                                 if (ret)
1307                                         goto done;
1308                                 break;
1309                         }
1310                 }
1311         }
1312
1313 done:
1314         ret = smb_direct_flush_send_list(st, &send_ctx, true);
1315
1316         /*
1317          * As an optimization, we don't wait for individual I/O to finish
1318          * before sending the next one.
1319          * Send them all and wait for pending send count to get to 0
1320          * that means all the I/Os have been out and we are good to return
1321          */
1322
1323         wait_event(st->wait_send_pending,
1324                    atomic_read(&st->send_pending) == 0);
1325         return ret;
1326 }
1327
1328 static void smb_direct_free_rdma_rw_msg(struct smb_direct_transport *t,
1329                                         struct smb_direct_rdma_rw_msg *msg,
1330                                         enum dma_data_direction dir)
1331 {
1332         rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
1333                             msg->sgt.sgl, msg->sgt.nents, dir);
1334         sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1335         kfree(msg);
1336 }
1337
1338 static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
1339                             enum dma_data_direction dir)
1340 {
1341         struct smb_direct_rdma_rw_msg *msg = container_of(wc->wr_cqe,
1342                                                           struct smb_direct_rdma_rw_msg, cqe);
1343         struct smb_direct_transport *t = msg->t;
1344
1345         if (wc->status != IB_WC_SUCCESS) {
1346                 msg->status = -EIO;
1347                 pr_err("read/write error. opcode = %d, status = %s(%d)\n",
1348                        wc->opcode, ib_wc_status_msg(wc->status), wc->status);
1349                 if (wc->status != IB_WC_WR_FLUSH_ERR)
1350                         smb_direct_disconnect_rdma_connection(t);
1351         }
1352
1353         complete(msg->completion);
1354 }
1355
1356 static void read_done(struct ib_cq *cq, struct ib_wc *wc)
1357 {
1358         read_write_done(cq, wc, DMA_FROM_DEVICE);
1359 }
1360
1361 static void write_done(struct ib_cq *cq, struct ib_wc *wc)
1362 {
1363         read_write_done(cq, wc, DMA_TO_DEVICE);
1364 }
1365
1366 static int smb_direct_rdma_xmit(struct smb_direct_transport *t,
1367                                 void *buf, int buf_len,
1368                                 struct smb2_buffer_desc_v1 *desc,
1369                                 unsigned int desc_len,
1370                                 bool is_read)
1371 {
1372         struct smb_direct_rdma_rw_msg *msg, *next_msg;
1373         int i, ret;
1374         DECLARE_COMPLETION_ONSTACK(completion);
1375         struct ib_send_wr *first_wr;
1376         LIST_HEAD(msg_list);
1377         char *desc_buf;
1378         int credits_needed;
1379         unsigned int desc_buf_len;
1380         size_t total_length = 0;
1381
1382         if (t->status != SMB_DIRECT_CS_CONNECTED)
1383                 return -ENOTCONN;
1384
1385         /* calculate needed credits */
1386         credits_needed = 0;
1387         desc_buf = buf;
1388         for (i = 0; i < desc_len / sizeof(*desc); i++) {
1389                 desc_buf_len = le32_to_cpu(desc[i].length);
1390
1391                 credits_needed += calc_rw_credits(t, desc_buf, desc_buf_len);
1392                 desc_buf += desc_buf_len;
1393                 total_length += desc_buf_len;
1394                 if (desc_buf_len == 0 || total_length > buf_len ||
1395                     total_length > t->max_rdma_rw_size)
1396                         return -EINVAL;
1397         }
1398
1399         ksmbd_debug(RDMA, "RDMA %s, len %#x, needed credits %#x\n",
1400                     is_read ? "read" : "write", buf_len, credits_needed);
1401
1402         ret = wait_for_rw_credits(t, credits_needed);
1403         if (ret < 0)
1404                 return ret;
1405
1406         /* build rdma_rw_ctx for each descriptor */
1407         desc_buf = buf;
1408         for (i = 0; i < desc_len / sizeof(*desc); i++) {
1409                 msg = kzalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
1410                               sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
1411                 if (!msg) {
1412                         ret = -ENOMEM;
1413                         goto out;
1414                 }
1415
1416                 desc_buf_len = le32_to_cpu(desc[i].length);
1417
1418                 msg->t = t;
1419                 msg->cqe.done = is_read ? read_done : write_done;
1420                 msg->completion = &completion;
1421
1422                 msg->sgt.sgl = &msg->sg_list[0];
1423                 ret = sg_alloc_table_chained(&msg->sgt,
1424                                              get_buf_page_count(desc_buf, desc_buf_len),
1425                                              msg->sg_list, SG_CHUNK_SIZE);
1426                 if (ret) {
1427                         kfree(msg);
1428                         ret = -ENOMEM;
1429                         goto out;
1430                 }
1431
1432                 ret = get_sg_list(desc_buf, desc_buf_len,
1433                                   msg->sgt.sgl, msg->sgt.orig_nents);
1434                 if (ret < 0) {
1435                         sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1436                         kfree(msg);
1437                         goto out;
1438                 }
1439
1440                 ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
1441                                        msg->sgt.sgl,
1442                                        get_buf_page_count(desc_buf, desc_buf_len),
1443                                        0,
1444                                        le64_to_cpu(desc[i].offset),
1445                                        le32_to_cpu(desc[i].token),
1446                                        is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1447                 if (ret < 0) {
1448                         pr_err("failed to init rdma_rw_ctx: %d\n", ret);
1449                         sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1450                         kfree(msg);
1451                         goto out;
1452                 }
1453
1454                 list_add_tail(&msg->list, &msg_list);
1455                 desc_buf += desc_buf_len;
1456         }
1457
1458         /* concatenate work requests of rdma_rw_ctxs */
1459         first_wr = NULL;
1460         list_for_each_entry_reverse(msg, &msg_list, list) {
1461                 first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
1462                                            &msg->cqe, first_wr);
1463         }
1464
1465         ret = ib_post_send(t->qp, first_wr, NULL);
1466         if (ret) {
1467                 pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
1468                 goto out;
1469         }
1470
1471         msg = list_last_entry(&msg_list, struct smb_direct_rdma_rw_msg, list);
1472         wait_for_completion(&completion);
1473         ret = msg->status;
1474 out:
1475         list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
1476                 list_del(&msg->list);
1477                 smb_direct_free_rdma_rw_msg(t, msg,
1478                                             is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1479         }
1480         atomic_add(credits_needed, &t->rw_credits);
1481         wake_up(&t->wait_rw_credits);
1482         return ret;
1483 }
1484
1485 static int smb_direct_rdma_write(struct ksmbd_transport *t,
1486                                  void *buf, unsigned int buflen,
1487                                  struct smb2_buffer_desc_v1 *desc,
1488                                  unsigned int desc_len)
1489 {
1490         return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1491                                     desc, desc_len, false);
1492 }
1493
1494 static int smb_direct_rdma_read(struct ksmbd_transport *t,
1495                                 void *buf, unsigned int buflen,
1496                                 struct smb2_buffer_desc_v1 *desc,
1497                                 unsigned int desc_len)
1498 {
1499         return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1500                                     desc, desc_len, true);
1501 }
1502
1503 static void smb_direct_disconnect(struct ksmbd_transport *t)
1504 {
1505         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1506
1507         ksmbd_debug(RDMA, "Disconnecting cm_id=%p\n", st->cm_id);
1508
1509         smb_direct_disconnect_rdma_work(&st->disconnect_work);
1510         wait_event_interruptible(st->wait_status,
1511                                  st->status == SMB_DIRECT_CS_DISCONNECTED);
1512         free_transport(st);
1513 }
1514
1515 static void smb_direct_shutdown(struct ksmbd_transport *t)
1516 {
1517         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1518
1519         ksmbd_debug(RDMA, "smb-direct shutdown cm_id=%p\n", st->cm_id);
1520
1521         smb_direct_disconnect_rdma_work(&st->disconnect_work);
1522 }
1523
1524 static int smb_direct_cm_handler(struct rdma_cm_id *cm_id,
1525                                  struct rdma_cm_event *event)
1526 {
1527         struct smb_direct_transport *t = cm_id->context;
1528
1529         ksmbd_debug(RDMA, "RDMA CM event. cm_id=%p event=%s (%d)\n",
1530                     cm_id, rdma_event_msg(event->event), event->event);
1531
1532         switch (event->event) {
1533         case RDMA_CM_EVENT_ESTABLISHED: {
1534                 t->status = SMB_DIRECT_CS_CONNECTED;
1535                 wake_up_interruptible(&t->wait_status);
1536                 break;
1537         }
1538         case RDMA_CM_EVENT_DEVICE_REMOVAL:
1539         case RDMA_CM_EVENT_DISCONNECTED: {
1540                 t->status = SMB_DIRECT_CS_DISCONNECTED;
1541                 wake_up_interruptible(&t->wait_status);
1542                 wake_up_interruptible(&t->wait_reassembly_queue);
1543                 wake_up(&t->wait_send_credits);
1544                 break;
1545         }
1546         case RDMA_CM_EVENT_CONNECT_ERROR: {
1547                 t->status = SMB_DIRECT_CS_DISCONNECTED;
1548                 wake_up_interruptible(&t->wait_status);
1549                 break;
1550         }
1551         default:
1552                 pr_err("Unexpected RDMA CM event. cm_id=%p, event=%s (%d)\n",
1553                        cm_id, rdma_event_msg(event->event),
1554                        event->event);
1555                 break;
1556         }
1557         return 0;
1558 }
1559
1560 static void smb_direct_qpair_handler(struct ib_event *event, void *context)
1561 {
1562         struct smb_direct_transport *t = context;
1563
1564         ksmbd_debug(RDMA, "Received QP event. cm_id=%p, event=%s (%d)\n",
1565                     t->cm_id, ib_event_msg(event->event), event->event);
1566
1567         switch (event->event) {
1568         case IB_EVENT_CQ_ERR:
1569         case IB_EVENT_QP_FATAL:
1570                 smb_direct_disconnect_rdma_connection(t);
1571                 break;
1572         default:
1573                 break;
1574         }
1575 }
1576
1577 static int smb_direct_send_negotiate_response(struct smb_direct_transport *t,
1578                                               int failed)
1579 {
1580         struct smb_direct_sendmsg *sendmsg;
1581         struct smb_direct_negotiate_resp *resp;
1582         int ret;
1583
1584         sendmsg = smb_direct_alloc_sendmsg(t);
1585         if (IS_ERR(sendmsg))
1586                 return -ENOMEM;
1587
1588         resp = (struct smb_direct_negotiate_resp *)sendmsg->packet;
1589         if (failed) {
1590                 memset(resp, 0, sizeof(*resp));
1591                 resp->min_version = cpu_to_le16(0x0100);
1592                 resp->max_version = cpu_to_le16(0x0100);
1593                 resp->status = STATUS_NOT_SUPPORTED;
1594         } else {
1595                 resp->status = STATUS_SUCCESS;
1596                 resp->min_version = SMB_DIRECT_VERSION_LE;
1597                 resp->max_version = SMB_DIRECT_VERSION_LE;
1598                 resp->negotiated_version = SMB_DIRECT_VERSION_LE;
1599                 resp->reserved = 0;
1600                 resp->credits_requested =
1601                                 cpu_to_le16(t->send_credit_target);
1602                 resp->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1603                 resp->max_readwrite_size = cpu_to_le32(t->max_rdma_rw_size);
1604                 resp->preferred_send_size = cpu_to_le32(t->max_send_size);
1605                 resp->max_receive_size = cpu_to_le32(t->max_recv_size);
1606                 resp->max_fragmented_size =
1607                                 cpu_to_le32(t->max_fragmented_recv_size);
1608         }
1609
1610         sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1611                                                  (void *)resp, sizeof(*resp),
1612                                                  DMA_TO_DEVICE);
1613         ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1614         if (ret) {
1615                 smb_direct_free_sendmsg(t, sendmsg);
1616                 return ret;
1617         }
1618
1619         sendmsg->num_sge = 1;
1620         sendmsg->sge[0].length = sizeof(*resp);
1621         sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1622
1623         ret = post_sendmsg(t, NULL, sendmsg);
1624         if (ret) {
1625                 smb_direct_free_sendmsg(t, sendmsg);
1626                 return ret;
1627         }
1628
1629         wait_event(t->wait_send_pending,
1630                    atomic_read(&t->send_pending) == 0);
1631         return 0;
1632 }
1633
1634 static int smb_direct_accept_client(struct smb_direct_transport *t)
1635 {
1636         struct rdma_conn_param conn_param;
1637         struct ib_port_immutable port_immutable;
1638         u32 ird_ord_hdr[2];
1639         int ret;
1640
1641         memset(&conn_param, 0, sizeof(conn_param));
1642         conn_param.initiator_depth = min_t(u8, t->cm_id->device->attrs.max_qp_rd_atom,
1643                                            SMB_DIRECT_CM_INITIATOR_DEPTH);
1644         conn_param.responder_resources = 0;
1645
1646         t->cm_id->device->ops.get_port_immutable(t->cm_id->device,
1647                                                  t->cm_id->port_num,
1648                                                  &port_immutable);
1649         if (port_immutable.core_cap_flags & RDMA_CORE_PORT_IWARP) {
1650                 ird_ord_hdr[0] = conn_param.responder_resources;
1651                 ird_ord_hdr[1] = 1;
1652                 conn_param.private_data = ird_ord_hdr;
1653                 conn_param.private_data_len = sizeof(ird_ord_hdr);
1654         } else {
1655                 conn_param.private_data = NULL;
1656                 conn_param.private_data_len = 0;
1657         }
1658         conn_param.retry_count = SMB_DIRECT_CM_RETRY;
1659         conn_param.rnr_retry_count = SMB_DIRECT_CM_RNR_RETRY;
1660         conn_param.flow_control = 0;
1661
1662         ret = rdma_accept(t->cm_id, &conn_param);
1663         if (ret) {
1664                 pr_err("error at rdma_accept: %d\n", ret);
1665                 return ret;
1666         }
1667         return 0;
1668 }
1669
1670 static int smb_direct_prepare_negotiation(struct smb_direct_transport *t)
1671 {
1672         int ret;
1673         struct smb_direct_recvmsg *recvmsg;
1674
1675         recvmsg = get_free_recvmsg(t);
1676         if (!recvmsg)
1677                 return -ENOMEM;
1678         recvmsg->type = SMB_DIRECT_MSG_NEGOTIATE_REQ;
1679
1680         ret = smb_direct_post_recv(t, recvmsg);
1681         if (ret) {
1682                 pr_err("Can't post recv: %d\n", ret);
1683                 goto out_err;
1684         }
1685
1686         t->negotiation_requested = false;
1687         ret = smb_direct_accept_client(t);
1688         if (ret) {
1689                 pr_err("Can't accept client\n");
1690                 goto out_err;
1691         }
1692
1693         smb_direct_post_recv_credits(&t->post_recv_credits_work.work);
1694         return 0;
1695 out_err:
1696         put_recvmsg(t, recvmsg);
1697         return ret;
1698 }
1699
1700 static unsigned int smb_direct_get_max_fr_pages(struct smb_direct_transport *t)
1701 {
1702         return min_t(unsigned int,
1703                      t->cm_id->device->attrs.max_fast_reg_page_list_len,
1704                      256);
1705 }
1706
1707 static int smb_direct_init_params(struct smb_direct_transport *t,
1708                                   struct ib_qp_cap *cap)
1709 {
1710         struct ib_device *device = t->cm_id->device;
1711         int max_send_sges, max_rw_wrs, max_send_wrs;
1712         unsigned int max_sge_per_wr, wrs_per_credit;
1713
1714         /* need 3 more sge. because a SMB_DIRECT header, SMB2 header,
1715          * SMB2 response could be mapped.
1716          */
1717         t->max_send_size = smb_direct_max_send_size;
1718         max_send_sges = DIV_ROUND_UP(t->max_send_size, PAGE_SIZE) + 3;
1719         if (max_send_sges > SMB_DIRECT_MAX_SEND_SGES) {
1720                 pr_err("max_send_size %d is too large\n", t->max_send_size);
1721                 return -EINVAL;
1722         }
1723
1724         /* Calculate the number of work requests for RDMA R/W.
1725          * The maximum number of pages which can be registered
1726          * with one Memory region can be transferred with one
1727          * R/W credit. And at least 4 work requests for each credit
1728          * are needed for MR registration, RDMA R/W, local & remote
1729          * MR invalidation.
1730          */
1731         t->max_rdma_rw_size = smb_direct_max_read_write_size;
1732         t->pages_per_rw_credit = smb_direct_get_max_fr_pages(t);
1733         t->max_rw_credits = DIV_ROUND_UP(t->max_rdma_rw_size,
1734                                          (t->pages_per_rw_credit - 1) *
1735                                          PAGE_SIZE);
1736
1737         max_sge_per_wr = min_t(unsigned int, device->attrs.max_send_sge,
1738                                device->attrs.max_sge_rd);
1739         max_sge_per_wr = max_t(unsigned int, max_sge_per_wr,
1740                                max_send_sges);
1741         wrs_per_credit = max_t(unsigned int, 4,
1742                                DIV_ROUND_UP(t->pages_per_rw_credit,
1743                                             max_sge_per_wr) + 1);
1744         max_rw_wrs = t->max_rw_credits * wrs_per_credit;
1745
1746         max_send_wrs = smb_direct_send_credit_target + max_rw_wrs;
1747         if (max_send_wrs > device->attrs.max_cqe ||
1748             max_send_wrs > device->attrs.max_qp_wr) {
1749                 pr_err("consider lowering send_credit_target = %d\n",
1750                        smb_direct_send_credit_target);
1751                 pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
1752                        device->attrs.max_cqe, device->attrs.max_qp_wr);
1753                 return -EINVAL;
1754         }
1755
1756         if (smb_direct_receive_credit_max > device->attrs.max_cqe ||
1757             smb_direct_receive_credit_max > device->attrs.max_qp_wr) {
1758                 pr_err("consider lowering receive_credit_max = %d\n",
1759                        smb_direct_receive_credit_max);
1760                 pr_err("Possible CQE overrun, device reporting max_cpe %d max_qp_wr %d\n",
1761                        device->attrs.max_cqe, device->attrs.max_qp_wr);
1762                 return -EINVAL;
1763         }
1764
1765         if (device->attrs.max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
1766                 pr_err("warning: device max_recv_sge = %d too small\n",
1767                        device->attrs.max_recv_sge);
1768                 return -EINVAL;
1769         }
1770
1771         t->recv_credits = 0;
1772         t->count_avail_recvmsg = 0;
1773
1774         t->recv_credit_max = smb_direct_receive_credit_max;
1775         t->recv_credit_target = 10;
1776         t->new_recv_credits = 0;
1777
1778         t->send_credit_target = smb_direct_send_credit_target;
1779         atomic_set(&t->send_credits, 0);
1780         atomic_set(&t->rw_credits, t->max_rw_credits);
1781
1782         t->max_send_size = smb_direct_max_send_size;
1783         t->max_recv_size = smb_direct_max_receive_size;
1784         t->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size;
1785
1786         cap->max_send_wr = max_send_wrs;
1787         cap->max_recv_wr = t->recv_credit_max;
1788         cap->max_send_sge = max_sge_per_wr;
1789         cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
1790         cap->max_inline_data = 0;
1791         cap->max_rdma_ctxs = t->max_rw_credits;
1792         return 0;
1793 }
1794
1795 static void smb_direct_destroy_pools(struct smb_direct_transport *t)
1796 {
1797         struct smb_direct_recvmsg *recvmsg;
1798
1799         while ((recvmsg = get_free_recvmsg(t)))
1800                 mempool_free(recvmsg, t->recvmsg_mempool);
1801         while ((recvmsg = get_empty_recvmsg(t)))
1802                 mempool_free(recvmsg, t->recvmsg_mempool);
1803
1804         mempool_destroy(t->recvmsg_mempool);
1805         t->recvmsg_mempool = NULL;
1806
1807         kmem_cache_destroy(t->recvmsg_cache);
1808         t->recvmsg_cache = NULL;
1809
1810         mempool_destroy(t->sendmsg_mempool);
1811         t->sendmsg_mempool = NULL;
1812
1813         kmem_cache_destroy(t->sendmsg_cache);
1814         t->sendmsg_cache = NULL;
1815 }
1816
1817 static int smb_direct_create_pools(struct smb_direct_transport *t)
1818 {
1819         char name[80];
1820         int i;
1821         struct smb_direct_recvmsg *recvmsg;
1822
1823         snprintf(name, sizeof(name), "smb_direct_rqst_pool_%p", t);
1824         t->sendmsg_cache = kmem_cache_create(name,
1825                                              sizeof(struct smb_direct_sendmsg) +
1826                                               sizeof(struct smb_direct_negotiate_resp),
1827                                              0, SLAB_HWCACHE_ALIGN, NULL);
1828         if (!t->sendmsg_cache)
1829                 return -ENOMEM;
1830
1831         t->sendmsg_mempool = mempool_create(t->send_credit_target,
1832                                             mempool_alloc_slab, mempool_free_slab,
1833                                             t->sendmsg_cache);
1834         if (!t->sendmsg_mempool)
1835                 goto err;
1836
1837         snprintf(name, sizeof(name), "smb_direct_resp_%p", t);
1838         t->recvmsg_cache = kmem_cache_create(name,
1839                                              sizeof(struct smb_direct_recvmsg) +
1840                                               t->max_recv_size,
1841                                              0, SLAB_HWCACHE_ALIGN, NULL);
1842         if (!t->recvmsg_cache)
1843                 goto err;
1844
1845         t->recvmsg_mempool =
1846                 mempool_create(t->recv_credit_max, mempool_alloc_slab,
1847                                mempool_free_slab, t->recvmsg_cache);
1848         if (!t->recvmsg_mempool)
1849                 goto err;
1850
1851         INIT_LIST_HEAD(&t->recvmsg_queue);
1852
1853         for (i = 0; i < t->recv_credit_max; i++) {
1854                 recvmsg = mempool_alloc(t->recvmsg_mempool, GFP_KERNEL);
1855                 if (!recvmsg)
1856                         goto err;
1857                 recvmsg->transport = t;
1858                 list_add(&recvmsg->list, &t->recvmsg_queue);
1859         }
1860         t->count_avail_recvmsg = t->recv_credit_max;
1861
1862         return 0;
1863 err:
1864         smb_direct_destroy_pools(t);
1865         return -ENOMEM;
1866 }
1867
1868 static int smb_direct_create_qpair(struct smb_direct_transport *t,
1869                                    struct ib_qp_cap *cap)
1870 {
1871         int ret;
1872         struct ib_qp_init_attr qp_attr;
1873         int pages_per_rw;
1874
1875         t->pd = ib_alloc_pd(t->cm_id->device, 0);
1876         if (IS_ERR(t->pd)) {
1877                 pr_err("Can't create RDMA PD\n");
1878                 ret = PTR_ERR(t->pd);
1879                 t->pd = NULL;
1880                 return ret;
1881         }
1882
1883         t->send_cq = ib_alloc_cq(t->cm_id->device, t,
1884                                  smb_direct_send_credit_target + cap->max_rdma_ctxs,
1885                                  0, IB_POLL_WORKQUEUE);
1886         if (IS_ERR(t->send_cq)) {
1887                 pr_err("Can't create RDMA send CQ\n");
1888                 ret = PTR_ERR(t->send_cq);
1889                 t->send_cq = NULL;
1890                 goto err;
1891         }
1892
1893         t->recv_cq = ib_alloc_cq(t->cm_id->device, t,
1894                                  t->recv_credit_max, 0, IB_POLL_WORKQUEUE);
1895         if (IS_ERR(t->recv_cq)) {
1896                 pr_err("Can't create RDMA recv CQ\n");
1897                 ret = PTR_ERR(t->recv_cq);
1898                 t->recv_cq = NULL;
1899                 goto err;
1900         }
1901
1902         memset(&qp_attr, 0, sizeof(qp_attr));
1903         qp_attr.event_handler = smb_direct_qpair_handler;
1904         qp_attr.qp_context = t;
1905         qp_attr.cap = *cap;
1906         qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
1907         qp_attr.qp_type = IB_QPT_RC;
1908         qp_attr.send_cq = t->send_cq;
1909         qp_attr.recv_cq = t->recv_cq;
1910         qp_attr.port_num = ~0;
1911
1912         ret = rdma_create_qp(t->cm_id, t->pd, &qp_attr);
1913         if (ret) {
1914                 pr_err("Can't create RDMA QP: %d\n", ret);
1915                 goto err;
1916         }
1917
1918         t->qp = t->cm_id->qp;
1919         t->cm_id->event_handler = smb_direct_cm_handler;
1920
1921         pages_per_rw = DIV_ROUND_UP(t->max_rdma_rw_size, PAGE_SIZE) + 1;
1922         if (pages_per_rw > t->cm_id->device->attrs.max_sgl_rd) {
1923                 ret = ib_mr_pool_init(t->qp, &t->qp->rdma_mrs,
1924                                       t->max_rw_credits, IB_MR_TYPE_MEM_REG,
1925                                       t->pages_per_rw_credit, 0);
1926                 if (ret) {
1927                         pr_err("failed to init mr pool count %d pages %d\n",
1928                                t->max_rw_credits, t->pages_per_rw_credit);
1929                         goto err;
1930                 }
1931         }
1932
1933         return 0;
1934 err:
1935         if (t->qp) {
1936                 ib_destroy_qp(t->qp);
1937                 t->qp = NULL;
1938         }
1939         if (t->recv_cq) {
1940                 ib_destroy_cq(t->recv_cq);
1941                 t->recv_cq = NULL;
1942         }
1943         if (t->send_cq) {
1944                 ib_destroy_cq(t->send_cq);
1945                 t->send_cq = NULL;
1946         }
1947         if (t->pd) {
1948                 ib_dealloc_pd(t->pd);
1949                 t->pd = NULL;
1950         }
1951         return ret;
1952 }
1953
1954 static int smb_direct_prepare(struct ksmbd_transport *t)
1955 {
1956         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1957         struct smb_direct_recvmsg *recvmsg;
1958         struct smb_direct_negotiate_req *req;
1959         int ret;
1960
1961         ksmbd_debug(RDMA, "Waiting for SMB_DIRECT negotiate request\n");
1962         ret = wait_event_interruptible_timeout(st->wait_status,
1963                                                st->negotiation_requested ||
1964                                                st->status == SMB_DIRECT_CS_DISCONNECTED,
1965                                                SMB_DIRECT_NEGOTIATE_TIMEOUT * HZ);
1966         if (ret <= 0 || st->status == SMB_DIRECT_CS_DISCONNECTED)
1967                 return ret < 0 ? ret : -ETIMEDOUT;
1968
1969         recvmsg = get_first_reassembly(st);
1970         if (!recvmsg)
1971                 return -ECONNABORTED;
1972
1973         ret = smb_direct_check_recvmsg(recvmsg);
1974         if (ret == -ECONNABORTED)
1975                 goto out;
1976
1977         req = (struct smb_direct_negotiate_req *)recvmsg->packet;
1978         st->max_recv_size = min_t(int, st->max_recv_size,
1979                                   le32_to_cpu(req->preferred_send_size));
1980         st->max_send_size = min_t(int, st->max_send_size,
1981                                   le32_to_cpu(req->max_receive_size));
1982         st->max_fragmented_send_size =
1983                 le32_to_cpu(req->max_fragmented_size);
1984         st->max_fragmented_recv_size =
1985                 (st->recv_credit_max * st->max_recv_size) / 2;
1986
1987         ret = smb_direct_send_negotiate_response(st, ret);
1988 out:
1989         spin_lock_irq(&st->reassembly_queue_lock);
1990         st->reassembly_queue_length--;
1991         list_del(&recvmsg->list);
1992         spin_unlock_irq(&st->reassembly_queue_lock);
1993         put_recvmsg(st, recvmsg);
1994
1995         return ret;
1996 }
1997
1998 static int smb_direct_connect(struct smb_direct_transport *st)
1999 {
2000         int ret;
2001         struct ib_qp_cap qp_cap;
2002
2003         ret = smb_direct_init_params(st, &qp_cap);
2004         if (ret) {
2005                 pr_err("Can't configure RDMA parameters\n");
2006                 return ret;
2007         }
2008
2009         ret = smb_direct_create_pools(st);
2010         if (ret) {
2011                 pr_err("Can't init RDMA pool: %d\n", ret);
2012                 return ret;
2013         }
2014
2015         ret = smb_direct_create_qpair(st, &qp_cap);
2016         if (ret) {
2017                 pr_err("Can't accept RDMA client: %d\n", ret);
2018                 return ret;
2019         }
2020
2021         ret = smb_direct_prepare_negotiation(st);
2022         if (ret) {
2023                 pr_err("Can't negotiate: %d\n", ret);
2024                 return ret;
2025         }
2026         return 0;
2027 }
2028
2029 static bool rdma_frwr_is_supported(struct ib_device_attr *attrs)
2030 {
2031         if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
2032                 return false;
2033         if (attrs->max_fast_reg_page_list_len == 0)
2034                 return false;
2035         return true;
2036 }
2037
2038 static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id)
2039 {
2040         struct smb_direct_transport *t;
2041         int ret;
2042
2043         if (!rdma_frwr_is_supported(&new_cm_id->device->attrs)) {
2044                 ksmbd_debug(RDMA,
2045                             "Fast Registration Work Requests is not supported. device capabilities=%llx\n",
2046                             new_cm_id->device->attrs.device_cap_flags);
2047                 return -EPROTONOSUPPORT;
2048         }
2049
2050         t = alloc_transport(new_cm_id);
2051         if (!t)
2052                 return -ENOMEM;
2053
2054         ret = smb_direct_connect(t);
2055         if (ret)
2056                 goto out_err;
2057
2058         KSMBD_TRANS(t)->handler = kthread_run(ksmbd_conn_handler_loop,
2059                                               KSMBD_TRANS(t)->conn, "ksmbd:r%u",
2060                                               smb_direct_port);
2061         if (IS_ERR(KSMBD_TRANS(t)->handler)) {
2062                 ret = PTR_ERR(KSMBD_TRANS(t)->handler);
2063                 pr_err("Can't start thread\n");
2064                 goto out_err;
2065         }
2066
2067         return 0;
2068 out_err:
2069         free_transport(t);
2070         return ret;
2071 }
2072
2073 static int smb_direct_listen_handler(struct rdma_cm_id *cm_id,
2074                                      struct rdma_cm_event *event)
2075 {
2076         switch (event->event) {
2077         case RDMA_CM_EVENT_CONNECT_REQUEST: {
2078                 int ret = smb_direct_handle_connect_request(cm_id);
2079
2080                 if (ret) {
2081                         pr_err("Can't create transport: %d\n", ret);
2082                         return ret;
2083                 }
2084
2085                 ksmbd_debug(RDMA, "Received connection request. cm_id=%p\n",
2086                             cm_id);
2087                 break;
2088         }
2089         default:
2090                 pr_err("Unexpected listen event. cm_id=%p, event=%s (%d)\n",
2091                        cm_id, rdma_event_msg(event->event), event->event);
2092                 break;
2093         }
2094         return 0;
2095 }
2096
2097 static int smb_direct_listen(int port)
2098 {
2099         int ret;
2100         struct rdma_cm_id *cm_id;
2101         struct sockaddr_in sin = {
2102                 .sin_family             = AF_INET,
2103                 .sin_addr.s_addr        = htonl(INADDR_ANY),
2104                 .sin_port               = htons(port),
2105         };
2106
2107         cm_id = rdma_create_id(&init_net, smb_direct_listen_handler,
2108                                &smb_direct_listener, RDMA_PS_TCP, IB_QPT_RC);
2109         if (IS_ERR(cm_id)) {
2110                 pr_err("Can't create cm id: %ld\n", PTR_ERR(cm_id));
2111                 return PTR_ERR(cm_id);
2112         }
2113
2114         ret = rdma_bind_addr(cm_id, (struct sockaddr *)&sin);
2115         if (ret) {
2116                 pr_err("Can't bind: %d\n", ret);
2117                 goto err;
2118         }
2119
2120         smb_direct_listener.cm_id = cm_id;
2121
2122         ret = rdma_listen(cm_id, 10);
2123         if (ret) {
2124                 pr_err("Can't listen: %d\n", ret);
2125                 goto err;
2126         }
2127         return 0;
2128 err:
2129         smb_direct_listener.cm_id = NULL;
2130         rdma_destroy_id(cm_id);
2131         return ret;
2132 }
2133
2134 static int smb_direct_ib_client_add(struct ib_device *ib_dev)
2135 {
2136         struct smb_direct_device *smb_dev;
2137
2138         /* Set 5445 port if device type is iWARP(No IB) */
2139         if (ib_dev->node_type != RDMA_NODE_IB_CA)
2140                 smb_direct_port = SMB_DIRECT_PORT_IWARP;
2141
2142         if (!ib_dev->ops.get_netdev ||
2143             !rdma_frwr_is_supported(&ib_dev->attrs))
2144                 return 0;
2145
2146         smb_dev = kzalloc(sizeof(*smb_dev), GFP_KERNEL);
2147         if (!smb_dev)
2148                 return -ENOMEM;
2149         smb_dev->ib_dev = ib_dev;
2150
2151         write_lock(&smb_direct_device_lock);
2152         list_add(&smb_dev->list, &smb_direct_device_list);
2153         write_unlock(&smb_direct_device_lock);
2154
2155         ksmbd_debug(RDMA, "ib device added: name %s\n", ib_dev->name);
2156         return 0;
2157 }
2158
2159 static void smb_direct_ib_client_remove(struct ib_device *ib_dev,
2160                                         void *client_data)
2161 {
2162         struct smb_direct_device *smb_dev, *tmp;
2163
2164         write_lock(&smb_direct_device_lock);
2165         list_for_each_entry_safe(smb_dev, tmp, &smb_direct_device_list, list) {
2166                 if (smb_dev->ib_dev == ib_dev) {
2167                         list_del(&smb_dev->list);
2168                         kfree(smb_dev);
2169                         break;
2170                 }
2171         }
2172         write_unlock(&smb_direct_device_lock);
2173 }
2174
2175 static struct ib_client smb_direct_ib_client = {
2176         .name   = "ksmbd_smb_direct_ib",
2177         .add    = smb_direct_ib_client_add,
2178         .remove = smb_direct_ib_client_remove,
2179 };
2180
2181 int ksmbd_rdma_init(void)
2182 {
2183         int ret;
2184
2185         smb_direct_listener.cm_id = NULL;
2186
2187         ret = ib_register_client(&smb_direct_ib_client);
2188         if (ret) {
2189                 pr_err("failed to ib_register_client\n");
2190                 return ret;
2191         }
2192
2193         /* When a client is running out of send credits, the credits are
2194          * granted by the server's sending a packet using this queue.
2195          * This avoids the situation that a clients cannot send packets
2196          * for lack of credits
2197          */
2198         smb_direct_wq = alloc_workqueue("ksmbd-smb_direct-wq",
2199                                         WQ_HIGHPRI | WQ_MEM_RECLAIM, 0);
2200         if (!smb_direct_wq)
2201                 return -ENOMEM;
2202
2203         ret = smb_direct_listen(smb_direct_port);
2204         if (ret) {
2205                 destroy_workqueue(smb_direct_wq);
2206                 smb_direct_wq = NULL;
2207                 pr_err("Can't listen: %d\n", ret);
2208                 return ret;
2209         }
2210
2211         ksmbd_debug(RDMA, "init RDMA listener. cm_id=%p\n",
2212                     smb_direct_listener.cm_id);
2213         return 0;
2214 }
2215
2216 void ksmbd_rdma_destroy(void)
2217 {
2218         if (!smb_direct_listener.cm_id)
2219                 return;
2220
2221         ib_unregister_client(&smb_direct_ib_client);
2222         rdma_destroy_id(smb_direct_listener.cm_id);
2223
2224         smb_direct_listener.cm_id = NULL;
2225
2226         if (smb_direct_wq) {
2227                 destroy_workqueue(smb_direct_wq);
2228                 smb_direct_wq = NULL;
2229         }
2230 }
2231
2232 bool ksmbd_rdma_capable_netdev(struct net_device *netdev)
2233 {
2234         struct smb_direct_device *smb_dev;
2235         int i;
2236         bool rdma_capable = false;
2237
2238         read_lock(&smb_direct_device_lock);
2239         list_for_each_entry(smb_dev, &smb_direct_device_list, list) {
2240                 for (i = 0; i < smb_dev->ib_dev->phys_port_cnt; i++) {
2241                         struct net_device *ndev;
2242
2243                         ndev = smb_dev->ib_dev->ops.get_netdev(smb_dev->ib_dev,
2244                                                                i + 1);
2245                         if (!ndev)
2246                                 continue;
2247
2248                         if (ndev == netdev) {
2249                                 dev_put(ndev);
2250                                 rdma_capable = true;
2251                                 goto out;
2252                         }
2253                         dev_put(ndev);
2254                 }
2255         }
2256 out:
2257         read_unlock(&smb_direct_device_lock);
2258
2259         if (rdma_capable == false) {
2260                 struct ib_device *ibdev;
2261
2262                 ibdev = ib_device_get_by_netdev(netdev, RDMA_DRIVER_UNKNOWN);
2263                 if (ibdev) {
2264                         if (rdma_frwr_is_supported(&ibdev->attrs))
2265                                 rdma_capable = true;
2266                         ib_device_put(ibdev);
2267                 }
2268         }
2269
2270         return rdma_capable;
2271 }
2272
2273 static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = {
2274         .prepare        = smb_direct_prepare,
2275         .disconnect     = smb_direct_disconnect,
2276         .shutdown       = smb_direct_shutdown,
2277         .writev         = smb_direct_writev,
2278         .read           = smb_direct_read,
2279         .rdma_read      = smb_direct_rdma_read,
2280         .rdma_write     = smb_direct_rdma_write,
2281 };