drivers/infiniband/ulp/rtrs/rtrs-srv.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * RDMA Transport Layer
4  *
5  * Copyright (c) 2014 - 2018 ProfitBricks GmbH. All rights reserved.
6  * Copyright (c) 2018 - 2019 1&1 IONOS Cloud GmbH. All rights reserved.
7  * Copyright (c) 2019 - 2020 1&1 IONOS SE. All rights reserved.
8  */
9
10 #undef pr_fmt
11 #define pr_fmt(fmt) KBUILD_MODNAME " L" __stringify(__LINE__) ": " fmt
12
13 #include <linux/module.h>
14 #include <linux/mempool.h>
15
16 #include "rtrs-srv.h"
17 #include "rtrs-log.h"
18 #include <rdma/ib_cm.h>
19
20 MODULE_DESCRIPTION("RDMA Transport Server");
21 MODULE_LICENSE("GPL");
22
23 /* Must be power of 2, see mask from mr->page_size in ib_sg_to_pages() */
24 #define DEFAULT_MAX_CHUNK_SIZE (128 << 10)
25 #define DEFAULT_SESS_QUEUE_DEPTH 512
26 #define MAX_HDR_SIZE PAGE_SIZE
27
28 /* We guarantee to serve at least 10 paths */
29 #define CHUNK_POOL_SZ 10
30
31 static struct rtrs_rdma_dev_pd dev_pd;
32 static mempool_t *chunk_pool;
33 struct class *rtrs_dev_class;
34
35 static int __read_mostly max_chunk_size = DEFAULT_MAX_CHUNK_SIZE;
36 static int __read_mostly sess_queue_depth = DEFAULT_SESS_QUEUE_DEPTH;
37
38 static bool always_invalidate = true;
39 module_param(always_invalidate, bool, 0444);
40 MODULE_PARM_DESC(always_invalidate,
41                  "Invalidate memory registration for contiguous memory regions before accessing.");
42
43 module_param_named(max_chunk_size, max_chunk_size, int, 0444);
44 MODULE_PARM_DESC(max_chunk_size,
45                  "Max size for each IO request, in bytes (default: "
46                  __stringify(DEFAULT_MAX_CHUNK_SIZE) " bytes)");
47
48 module_param_named(sess_queue_depth, sess_queue_depth, int, 0444);
49 MODULE_PARM_DESC(sess_queue_depth,
50                  "Number of buffers for pending I/O requests to allocate per session. Maximum: "
51                  __stringify(MAX_SESS_QUEUE_DEPTH) " (default: "
52                  __stringify(DEFAULT_SESS_QUEUE_DEPTH) ")");
53
54 static cpumask_t cq_affinity_mask = { CPU_BITS_ALL };
55
56 static struct workqueue_struct *rtrs_wq;
57
58 static inline struct rtrs_srv_con *to_srv_con(struct rtrs_con *c)
59 {
60         return container_of(c, struct rtrs_srv_con, c);
61 }
62
63 static inline struct rtrs_srv_sess *to_srv_sess(struct rtrs_sess *s)
64 {
65         return container_of(s, struct rtrs_srv_sess, s);
66 }
67
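/*
 * Advance the session state machine; the caller must hold
 * sess->state_lock.  Only CONNECTING -> CONNECTED,
 * CONNECTING/CONNECTED -> CLOSING and CLOSING -> CLOSED transitions
 * are allowed.  Returns true if the state was actually changed.
 */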
68 static bool __rtrs_srv_change_state(struct rtrs_srv_sess *sess,
69                                      enum rtrs_srv_state new_state)
70 {
71         enum rtrs_srv_state old_state;
72         bool changed = false;
73
74         lockdep_assert_held(&sess->state_lock);
75         old_state = sess->state;
76         switch (new_state) {
77         case RTRS_SRV_CONNECTED:
78                 switch (old_state) {
79                 case RTRS_SRV_CONNECTING:
80                         changed = true;
81                         fallthrough;
82                 default:
83                         break;
84                 }
85                 break;
86         case RTRS_SRV_CLOSING:
87                 switch (old_state) {
88                 case RTRS_SRV_CONNECTING:
89                 case RTRS_SRV_CONNECTED:
90                         changed = true;
91                         fallthrough;
92                 default:
93                         break;
94                 }
95                 break;
96         case RTRS_SRV_CLOSED:
97                 switch (old_state) {
98                 case RTRS_SRV_CLOSING:
99                         changed = true;
100                         fallthrough;
101                 default:
102                         break;
103                 }
104                 break;
105         default:
106                 break;
107         }
108         if (changed)
109                 sess->state = new_state;
110
111         return changed;
112 }
113
114 static bool rtrs_srv_change_state_get_old(struct rtrs_srv_sess *sess,
115                                            enum rtrs_srv_state new_state,
116                                            enum rtrs_srv_state *old_state)
117 {
118         bool changed;
119
120         spin_lock_irq(&sess->state_lock);
121         *old_state = sess->state;
122         changed = __rtrs_srv_change_state(sess, new_state);
123         spin_unlock_irq(&sess->state_lock);
124
125         return changed;
126 }
127
128 static bool rtrs_srv_change_state(struct rtrs_srv_sess *sess,
129                                    enum rtrs_srv_state new_state)
130 {
131         enum rtrs_srv_state old_state;
132
133         return rtrs_srv_change_state_get_old(sess, new_state, &old_state);
134 }
135
136 static void free_id(struct rtrs_srv_op *id)
137 {
138         if (!id)
139                 return;
140         kfree(id);
141 }
142
143 static void rtrs_srv_free_ops_ids(struct rtrs_srv_sess *sess)
144 {
145         struct rtrs_srv *srv = sess->srv;
146         int i;
147
148         WARN_ON(atomic_read(&sess->ids_inflight));
149         if (sess->ops_ids) {
150                 for (i = 0; i < srv->queue_depth; i++)
151                         free_id(sess->ops_ids[i]);
152                 kfree(sess->ops_ids);
153                 sess->ops_ids = NULL;
154         }
155 }
156
157 static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc);
158
159 static struct ib_cqe io_comp_cqe = {
160         .done = rtrs_srv_rdma_done
161 };
162
163 static int rtrs_srv_alloc_ops_ids(struct rtrs_srv_sess *sess)
164 {
165         struct rtrs_srv *srv = sess->srv;
166         struct rtrs_srv_op *id;
167         int i;
168
169         sess->ops_ids = kcalloc(srv->queue_depth, sizeof(*sess->ops_ids),
170                                 GFP_KERNEL);
171         if (!sess->ops_ids)
172                 goto err;
173
174         for (i = 0; i < srv->queue_depth; ++i) {
175                 id = kzalloc(sizeof(*id), GFP_KERNEL);
176                 if (!id)
177                         goto err;
178
179                 sess->ops_ids[i] = id;
180         }
181         init_waitqueue_head(&sess->ids_waitq);
182         atomic_set(&sess->ids_inflight, 0);
183
184         return 0;
185
186 err:
187         rtrs_srv_free_ops_ids(sess);
188         return -ENOMEM;
189 }
190
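/*
 * ids_inflight counts I/O operations currently handed to the upper
 * layer; rtrs_srv_wait_ops_ids() blocks until it drops to zero so the
 * session can be torn down safely (see rtrs_srv_close_work()).
 */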
191 static inline void rtrs_srv_get_ops_ids(struct rtrs_srv_sess *sess)
192 {
193         atomic_inc(&sess->ids_inflight);
194 }
195
196 static inline void rtrs_srv_put_ops_ids(struct rtrs_srv_sess *sess)
197 {
198         if (atomic_dec_and_test(&sess->ids_inflight))
199                 wake_up(&sess->ids_waitq);
200 }
201
202 static void rtrs_srv_wait_ops_ids(struct rtrs_srv_sess *sess)
203 {
204         wait_event(sess->ids_waitq, !atomic_read(&sess->ids_inflight));
205 }
206
207
208 static void rtrs_srv_reg_mr_done(struct ib_cq *cq, struct ib_wc *wc)
209 {
210         struct rtrs_srv_con *con = cq->cq_context;
211         struct rtrs_sess *s = con->c.sess;
212         struct rtrs_srv_sess *sess = to_srv_sess(s);
213
214         if (unlikely(wc->status != IB_WC_SUCCESS)) {
215                 rtrs_err(s, "REG MR failed: %s\n",
216                           ib_wc_status_msg(wc->status));
217                 close_sess(sess);
218                 return;
219         }
220 }
221
222 static struct ib_cqe local_reg_cqe = {
223         .done = rtrs_srv_reg_mr_done
224 };
225
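/*
 * Reply to a client READ request: post an RDMA WRITE carrying the data
 * back to the client's buffer (described in rd_msg->desc[0]), chained
 * with an IMM notification and, depending on need_inval and
 * always_invalidate, SEND_WITH_INV and REG_MR work requests.
 */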
226 static int rdma_write_sg(struct rtrs_srv_op *id)
227 {
228         struct rtrs_sess *s = id->con->c.sess;
229         struct rtrs_srv_sess *sess = to_srv_sess(s);
230         dma_addr_t dma_addr = sess->dma_addr[id->msg_id];
231         struct rtrs_srv_mr *srv_mr;
232         struct rtrs_srv *srv = sess->srv;
233         struct ib_send_wr inv_wr, imm_wr;
234         struct ib_rdma_wr *wr = NULL;
235         enum ib_send_flags flags;
236         size_t sg_cnt;
237         int err, offset;
238         bool need_inval;
239         u32 rkey = 0;
240         struct ib_reg_wr rwr;
241         struct ib_sge *plist;
242         struct ib_sge list;
243
244         sg_cnt = le16_to_cpu(id->rd_msg->sg_cnt);
245         need_inval = le16_to_cpu(id->rd_msg->flags) & RTRS_MSG_NEED_INVAL_F;
246         if (unlikely(sg_cnt != 1))
247                 return -EINVAL;
248
249         offset = 0;
250
251         wr              = &id->tx_wr;
252         plist           = &id->tx_sg;
253         plist->addr     = dma_addr + offset;
254         plist->length   = le32_to_cpu(id->rd_msg->desc[0].len);
255
256         /* WR will fail with length error
257          * if this is 0
258          */
259         if (unlikely(plist->length == 0)) {
260                 rtrs_err(s, "Invalid RDMA-Write sg list length 0\n");
261                 return -EINVAL;
262         }
263
264         plist->lkey = sess->s.dev->ib_pd->local_dma_lkey;
265         offset += plist->length;
266
267         wr->wr.sg_list  = plist;
268         wr->wr.num_sge  = 1;
269         wr->remote_addr = le64_to_cpu(id->rd_msg->desc[0].addr);
270         wr->rkey        = le32_to_cpu(id->rd_msg->desc[0].key);
271         if (rkey == 0)
272                 rkey = wr->rkey;
273         else
274                 /* Only one key is actually used */
275                 WARN_ON_ONCE(rkey != wr->rkey);
276
277         wr->wr.opcode = IB_WR_RDMA_WRITE;
278         wr->wr.ex.imm_data = 0;
279         wr->wr.send_flags  = 0;
280
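        /* Chain the WRs: RDMA_WRITE [-> REG_MR] [-> SEND_WITH_INV] -> IMM */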
281         if (need_inval && always_invalidate) {
282                 wr->wr.next = &rwr.wr;
283                 rwr.wr.next = &inv_wr;
284                 inv_wr.next = &imm_wr;
285         } else if (always_invalidate) {
286                 wr->wr.next = &rwr.wr;
287                 rwr.wr.next = &imm_wr;
288         } else if (need_inval) {
289                 wr->wr.next = &inv_wr;
290                 inv_wr.next = &imm_wr;
291         } else {
292                 wr->wr.next = &imm_wr;
293         }
294         /*
295          * From time to time we have to post signaled sends,
296          * or the send queue will fill up and only a QP reset can help.
297          */
298         flags = (atomic_inc_return(&id->con->wr_cnt) % srv->queue_depth) ?
299                 0 : IB_SEND_SIGNALED;
300
301         if (need_inval) {
302                 inv_wr.sg_list = NULL;
303                 inv_wr.num_sge = 0;
304                 inv_wr.opcode = IB_WR_SEND_WITH_INV;
305                 inv_wr.send_flags = 0;
306                 inv_wr.ex.invalidate_rkey = rkey;
307         }
308
309         imm_wr.next = NULL;
310         if (always_invalidate) {
311                 struct rtrs_msg_rkey_rsp *msg;
312
313                 srv_mr = &sess->mrs[id->msg_id];
314                 rwr.wr.opcode = IB_WR_REG_MR;
315                 rwr.wr.num_sge = 0;
316                 rwr.mr = srv_mr->mr;
317                 rwr.wr.send_flags = 0;
318                 rwr.key = srv_mr->mr->rkey;
319                 rwr.access = (IB_ACCESS_LOCAL_WRITE |
320                               IB_ACCESS_REMOTE_WRITE);
321                 msg = srv_mr->iu->buf;
322                 msg->buf_id = cpu_to_le16(id->msg_id);
323                 msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
324                 msg->rkey = cpu_to_le32(srv_mr->mr->rkey);
325
326                 list.addr   = srv_mr->iu->dma_addr;
327                 list.length = sizeof(*msg);
328                 list.lkey   = sess->s.dev->ib_pd->local_dma_lkey;
329                 imm_wr.sg_list = &list;
330                 imm_wr.num_sge = 1;
331                 imm_wr.opcode = IB_WR_SEND_WITH_IMM;
332                 ib_dma_sync_single_for_device(sess->s.dev->ib_dev,
333                                               srv_mr->iu->dma_addr,
334                                               srv_mr->iu->size, DMA_TO_DEVICE);
335         } else {
336                 imm_wr.sg_list = NULL;
337                 imm_wr.num_sge = 0;
338                 imm_wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
339         }
340         imm_wr.send_flags = flags;
341         imm_wr.ex.imm_data = cpu_to_be32(rtrs_to_io_rsp_imm(id->msg_id,
342                                                              0, need_inval));
343
344         imm_wr.wr_cqe   = &io_comp_cqe;
345         ib_dma_sync_single_for_device(sess->s.dev->ib_dev, dma_addr,
346                                       offset, DMA_BIDIRECTIONAL);
347
348         err = ib_post_send(id->con->c.qp, &id->tx_wr.wr, NULL);
349         if (unlikely(err))
350                 rtrs_err(s,
351                           "Posting RDMA-Write-Request to QP failed, err: %d\n",
352                           err);
353
354         return err;
355 }
356
357 /**
358  * send_io_resp_imm() - respond to client with empty IMM on failed READ/WRITE
359  *                      requests or on a successful WRITE request.
360  * @con:        the connection on which to send back the result
361  * @id:         the id associated with the IO
362  * @errno:      the error number of the IO.
363  *
364  * Return 0 on success, errno otherwise.
365  */
366 static int send_io_resp_imm(struct rtrs_srv_con *con, struct rtrs_srv_op *id,
367                             int errno)
368 {
369         struct rtrs_sess *s = con->c.sess;
370         struct rtrs_srv_sess *sess = to_srv_sess(s);
371         struct ib_send_wr inv_wr, imm_wr, *wr = NULL;
372         struct ib_reg_wr rwr;
373         struct rtrs_srv *srv = sess->srv;
374         struct rtrs_srv_mr *srv_mr;
375         bool need_inval = false;
376         enum ib_send_flags flags;
377         u32 imm;
378         int err;
379
380         if (id->dir == READ) {
381                 struct rtrs_msg_rdma_read *rd_msg = id->rd_msg;
382                 size_t sg_cnt;
383
384                 need_inval = le16_to_cpu(rd_msg->flags) &
385                                 RTRS_MSG_NEED_INVAL_F;
386                 sg_cnt = le16_to_cpu(rd_msg->sg_cnt);
387
388                 if (need_inval) {
389                         if (likely(sg_cnt)) {
390                                 inv_wr.sg_list = NULL;
391                                 inv_wr.num_sge = 0;
392                                 inv_wr.opcode = IB_WR_SEND_WITH_INV;
393                                 inv_wr.send_flags = 0;
394                                 /* Only one key is actually used */
395                                 inv_wr.ex.invalidate_rkey =
396                                         le32_to_cpu(rd_msg->desc[0].key);
397                         } else {
398                                 WARN_ON_ONCE(1);
399                                 need_inval = false;
400                         }
401                 }
402         }
403
404         if (need_inval && always_invalidate) {
405                 wr = &inv_wr;
406                 inv_wr.next = &rwr.wr;
407                 rwr.wr.next = &imm_wr;
408         } else if (always_invalidate) {
409                 wr = &rwr.wr;
410                 rwr.wr.next = &imm_wr;
411         } else if (need_inval) {
412                 wr = &inv_wr;
413                 inv_wr.next = &imm_wr;
414         } else {
415                 wr = &imm_wr;
416         }
417         /*
418          * From time to time we have to post signaled sends,
419          * or the send queue will fill up and only a QP reset can help.
420          */
421         flags = (atomic_inc_return(&con->wr_cnt) % srv->queue_depth) ?
422                 0 : IB_SEND_SIGNALED;
423         imm = rtrs_to_io_rsp_imm(id->msg_id, errno, need_inval);
424         imm_wr.next = NULL;
425         if (always_invalidate) {
426                 struct ib_sge list;
427                 struct rtrs_msg_rkey_rsp *msg;
428
429                 srv_mr = &sess->mrs[id->msg_id];
430                 rwr.wr.next = &imm_wr;
431                 rwr.wr.opcode = IB_WR_REG_MR;
432                 rwr.wr.num_sge = 0;
433                 rwr.wr.send_flags = 0;
434                 rwr.mr = srv_mr->mr;
435                 rwr.key = srv_mr->mr->rkey;
436                 rwr.access = (IB_ACCESS_LOCAL_WRITE |
437                               IB_ACCESS_REMOTE_WRITE);
438                 msg = srv_mr->iu->buf;
439                 msg->buf_id = cpu_to_le16(id->msg_id);
440                 msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
441                 msg->rkey = cpu_to_le32(srv_mr->mr->rkey);
442
443                 list.addr   = srv_mr->iu->dma_addr;
444                 list.length = sizeof(*msg);
445                 list.lkey   = sess->s.dev->ib_pd->local_dma_lkey;
446                 imm_wr.sg_list = &list;
447                 imm_wr.num_sge = 1;
448                 imm_wr.opcode = IB_WR_SEND_WITH_IMM;
449                 ib_dma_sync_single_for_device(sess->s.dev->ib_dev,
450                                               srv_mr->iu->dma_addr,
451                                               srv_mr->iu->size, DMA_TO_DEVICE);
452         } else {
453                 imm_wr.sg_list = NULL;
454                 imm_wr.num_sge = 0;
455                 imm_wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
456         }
457         imm_wr.send_flags = flags;
458         imm_wr.wr_cqe   = &io_comp_cqe;
459
460         imm_wr.ex.imm_data = cpu_to_be32(imm);
461
462         err = ib_post_send(id->con->c.qp, wr, NULL);
463         if (unlikely(err))
464                 rtrs_err_rl(s, "Posting RDMA-Reply to QP failed, err: %d\n",
465                              err);
466
467         return err;
468 }
469
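/*
 * Kick off session teardown: move the session to CLOSING (if the
 * transition is legal) and queue close_work; rtrs_srv_close_work()
 * performs the actual cleanup.
 */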
470 void close_sess(struct rtrs_srv_sess *sess)
471 {
472         enum rtrs_srv_state old_state;
473
474         if (rtrs_srv_change_state_get_old(sess, RTRS_SRV_CLOSING,
475                                            &old_state))
476                 queue_work(rtrs_wq, &sess->close_work);
477         WARN_ON(sess->state != RTRS_SRV_CLOSING);
478 }
479
480 static inline const char *rtrs_srv_state_str(enum rtrs_srv_state state)
481 {
482         switch (state) {
483         case RTRS_SRV_CONNECTING:
484                 return "RTRS_SRV_CONNECTING";
485         case RTRS_SRV_CONNECTED:
486                 return "RTRS_SRV_CONNECTED";
487         case RTRS_SRV_CLOSING:
488                 return "RTRS_SRV_CLOSING";
489         case RTRS_SRV_CLOSED:
490                 return "RTRS_SRV_CLOSED";
491         default:
492                 return "UNKNOWN";
493         }
494 }
495
496 /**
497  * rtrs_srv_resp_rdma() - Finish an RDMA request
498  *
499  * @id:         Internal RTRS operation identifier
500  * @status:     Response Code sent to the other side for this operation.
501  *              0 = success, < 0 error
502  * Context: any
503  *
504  * Finish an RDMA operation. A message is sent to the client and the
505  * corresponding memory areas will be released.
506  */
507 bool rtrs_srv_resp_rdma(struct rtrs_srv_op *id, int status)
508 {
509         struct rtrs_srv_sess *sess;
510         struct rtrs_srv_con *con;
511         struct rtrs_sess *s;
512         int err;
513
514         if (WARN_ON(!id))
515                 return true;
516
517         con = id->con;
518         s = con->c.sess;
519         sess = to_srv_sess(s);
520
521         id->status = status;
522
523         if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
524                 rtrs_err_rl(s,
525                              "Sending I/O response failed, session is disconnected, sess state %s\n",
526                              rtrs_srv_state_str(sess->state));
527                 goto out;
528         }
529         if (always_invalidate) {
530                 struct rtrs_srv_mr *mr = &sess->mrs[id->msg_id];
531
532                 ib_update_fast_reg_key(mr->mr, ib_inc_rkey(mr->mr->rkey));
533         }
534         if (unlikely(atomic_sub_return(1,
535                                        &con->sq_wr_avail) < 0)) {
536                 pr_err("IB send queue full\n");
537                 atomic_add(1, &con->sq_wr_avail);
538                 spin_lock(&con->rsp_wr_wait_lock);
539                 list_add_tail(&id->wait_list, &con->rsp_wr_wait_list);
540                 spin_unlock(&con->rsp_wr_wait_lock);
541                 return false;
542         }
543
544         if (status || id->dir == WRITE || !id->rd_msg->sg_cnt)
545                 err = send_io_resp_imm(con, id, status);
546         else
547                 err = rdma_write_sg(id);
548
549         if (unlikely(err)) {
550                 rtrs_err_rl(s, "IO response failed: %d\n", err);
551                 close_sess(sess);
552         }
553 out:
554         rtrs_srv_put_ops_ids(sess);
555         return true;
556 }
557 EXPORT_SYMBOL(rtrs_srv_resp_rdma);
558
559 /**
560  * rtrs_srv_set_sess_priv() - Set private pointer in rtrs_srv.
561  * @srv:        Session pointer
562  * @priv:       The private pointer that is associated with the session.
563  */
564 void rtrs_srv_set_sess_priv(struct rtrs_srv *srv, void *priv)
565 {
566         srv->priv = priv;
567 }
568 EXPORT_SYMBOL(rtrs_srv_set_sess_priv);
569
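/*
 * Undo map_cont_bufs(): free the rkey-response IUs, deregister the
 * MRs and unmap/free the per-MR chunk scatterlists.
 */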
570 static void unmap_cont_bufs(struct rtrs_srv_sess *sess)
571 {
572         int i;
573
574         for (i = 0; i < sess->mrs_num; i++) {
575                 struct rtrs_srv_mr *srv_mr;
576
577                 srv_mr = &sess->mrs[i];
578                 rtrs_iu_free(srv_mr->iu, DMA_TO_DEVICE,
579                               sess->s.dev->ib_dev, 1);
580                 ib_dereg_mr(srv_mr->mr);
581                 ib_dma_unmap_sg(sess->s.dev->ib_dev, srv_mr->sgt.sgl,
582                                 srv_mr->sgt.nents, DMA_BIDIRECTIONAL);
583                 sg_free_table(&srv_mr->sgt);
584         }
585         kfree(sess->mrs);
586 }
587
588 static int map_cont_bufs(struct rtrs_srv_sess *sess)
589 {
590         struct rtrs_srv *srv = sess->srv;
591         struct rtrs_sess *ss = &sess->s;
592         int i, mri, err, mrs_num;
593         unsigned int chunk_bits;
594         int chunks_per_mr = 1;
595
596         /*
597          * Here we map queue_depth chunks to MRs.  First we have to
598          * figure out how many chunks we can map per MR.
599          */
600         if (always_invalidate) {
601                 /*
602                  * In order to invalidate each chunk of memory separately,
603                  * we need one memory region per chunk.
604                  */
605                 mrs_num = srv->queue_depth;
606         } else {
607                 chunks_per_mr =
608                         sess->s.dev->ib_dev->attrs.max_fast_reg_page_list_len;
609                 mrs_num = DIV_ROUND_UP(srv->queue_depth, chunks_per_mr);
610                 chunks_per_mr = DIV_ROUND_UP(srv->queue_depth, mrs_num);
611         }
612
613         sess->mrs = kcalloc(mrs_num, sizeof(*sess->mrs), GFP_KERNEL);
614         if (!sess->mrs)
615                 return -ENOMEM;
616
617         sess->mrs_num = mrs_num;
618
619         for (mri = 0; mri < mrs_num; mri++) {
620                 struct rtrs_srv_mr *srv_mr = &sess->mrs[mri];
621                 struct sg_table *sgt = &srv_mr->sgt;
622                 struct scatterlist *s;
623                 struct ib_mr *mr;
624                 int nr, chunks;
625
626                 chunks = chunks_per_mr * mri;
627                 if (!always_invalidate)
628                         chunks_per_mr = min_t(int, chunks_per_mr,
629                                               srv->queue_depth - chunks);
630
631                 err = sg_alloc_table(sgt, chunks_per_mr, GFP_KERNEL);
632                 if (err)
633                         goto err;
634
635                 for_each_sg(sgt->sgl, s, chunks_per_mr, i)
636                         sg_set_page(s, srv->chunks[chunks + i],
637                                     max_chunk_size, 0);
638
639                 nr = ib_dma_map_sg(sess->s.dev->ib_dev, sgt->sgl,
640                                    sgt->nents, DMA_BIDIRECTIONAL);
641                 if (nr < sgt->nents) {
642                         err = nr < 0 ? nr : -EINVAL;
643                         goto free_sg;
644                 }
645                 mr = ib_alloc_mr(sess->s.dev->ib_pd, IB_MR_TYPE_MEM_REG,
646                                  sgt->nents);
647                 if (IS_ERR(mr)) {
648                         err = PTR_ERR(mr);
649                         goto unmap_sg;
650                 }
651                 nr = ib_map_mr_sg(mr, sgt->sgl, sgt->nents,
652                                   NULL, max_chunk_size);
653                 if (nr < 0 || nr < sgt->nents) {
654                         err = nr < 0 ? nr : -EINVAL;
655                         goto dereg_mr;
656                 }
657
658                 if (always_invalidate) {
659                         srv_mr->iu = rtrs_iu_alloc(1,
660                                         sizeof(struct rtrs_msg_rkey_rsp),
661                                         GFP_KERNEL, sess->s.dev->ib_dev,
662                                         DMA_TO_DEVICE, rtrs_srv_rdma_done);
663                         if (!srv_mr->iu) {
664                                 err = -ENOMEM;
665                                 rtrs_err(ss, "rtrs_iu_alloc(), err: %d\n", err);
666                                 goto free_iu;
667                         }
668                 }
669                 /* Eventually dma addr for each chunk can be cached */
670                 for_each_sg(sgt->sgl, s, sgt->orig_nents, i)
671                         sess->dma_addr[chunks + i] = sg_dma_address(s);
672
673                 ib_update_fast_reg_key(mr, ib_inc_rkey(mr->rkey));
674                 srv_mr->mr = mr;
675
676                 continue;
677 err:
678                 while (mri--) {
679                         srv_mr = &sess->mrs[mri];
680                         sgt = &srv_mr->sgt;
681                         mr = srv_mr->mr;
682 free_iu:
683                         rtrs_iu_free(srv_mr->iu, DMA_TO_DEVICE,
684                                       sess->s.dev->ib_dev, 1);
685 dereg_mr:
686                         ib_dereg_mr(mr);
687 unmap_sg:
688                         ib_dma_unmap_sg(sess->s.dev->ib_dev, sgt->sgl,
689                                         sgt->nents, DMA_BIDIRECTIONAL);
690 free_sg:
691                         sg_free_table(sgt);
692                 }
693                 kfree(sess->mrs);
694
695                 return err;
696         }
697
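        /*
         * The IMM payload is split into a buffer id (the upper chunk_bits
         * bits) and an offset inside the chunk (the lower mem_bits bits);
         * see rtrs_srv_rdma_done().
         */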
698         chunk_bits = ilog2(srv->queue_depth - 1) + 1;
699         sess->mem_bits = (MAX_IMM_PAYL_BITS - chunk_bits);
700
701         return 0;
702 }
703
704 static void rtrs_srv_hb_err_handler(struct rtrs_con *c)
705 {
706         close_sess(to_srv_sess(c->sess));
707 }
708
709 static void rtrs_srv_init_hb(struct rtrs_srv_sess *sess)
710 {
711         rtrs_init_hb(&sess->s, &io_comp_cqe,
712                       RTRS_HB_INTERVAL_MS,
713                       RTRS_HB_MISSED_MAX,
714                       rtrs_srv_hb_err_handler,
715                       rtrs_wq);
716 }
717
718 static void rtrs_srv_start_hb(struct rtrs_srv_sess *sess)
719 {
720         rtrs_start_hb(&sess->s);
721 }
722
723 static void rtrs_srv_stop_hb(struct rtrs_srv_sess *sess)
724 {
725         rtrs_stop_hb(&sess->s);
726 }
727
728 static void rtrs_srv_info_rsp_done(struct ib_cq *cq, struct ib_wc *wc)
729 {
730         struct rtrs_srv_con *con = cq->cq_context;
731         struct rtrs_sess *s = con->c.sess;
732         struct rtrs_srv_sess *sess = to_srv_sess(s);
733         struct rtrs_iu *iu;
734
735         iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
736         rtrs_iu_free(iu, DMA_TO_DEVICE, sess->s.dev->ib_dev, 1);
737
738         if (unlikely(wc->status != IB_WC_SUCCESS)) {
739                 rtrs_err(s, "Sess info response send failed: %s\n",
740                           ib_wc_status_msg(wc->status));
741                 close_sess(sess);
742                 return;
743         }
744         WARN_ON(wc->opcode != IB_WC_SEND);
745 }
746
747 static void rtrs_srv_sess_up(struct rtrs_srv_sess *sess)
748 {
749         struct rtrs_srv *srv = sess->srv;
750         struct rtrs_srv_ctx *ctx = srv->ctx;
751         int up;
752
753         mutex_lock(&srv->paths_ev_mutex);
754         up = ++srv->paths_up;
755         if (up == 1)
756                 ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_CONNECTED, NULL);
757         mutex_unlock(&srv->paths_ev_mutex);
758
759         /* Mark session as established */
760         sess->established = true;
761 }
762
763 static void rtrs_srv_sess_down(struct rtrs_srv_sess *sess)
764 {
765         struct rtrs_srv *srv = sess->srv;
766         struct rtrs_srv_ctx *ctx = srv->ctx;
767
768         if (!sess->established)
769                 return;
770
771         sess->established = false;
772         mutex_lock(&srv->paths_ev_mutex);
773         WARN_ON(!srv->paths_up);
774         if (--srv->paths_up == 0)
775                 ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_DISCONNECTED, srv->priv);
776         mutex_unlock(&srv->paths_ev_mutex);
777 }
778
779 static int post_recv_sess(struct rtrs_srv_sess *sess);
780
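/*
 * Handle the client's info request: post receive buffers on all
 * connections, register every MR via chained REG_MR work requests,
 * mark the session connected and reply with an info response
 * describing each registered buffer (addr/rkey/len).
 */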
781 static int process_info_req(struct rtrs_srv_con *con,
782                             struct rtrs_msg_info_req *msg)
783 {
784         struct rtrs_sess *s = con->c.sess;
785         struct rtrs_srv_sess *sess = to_srv_sess(s);
786         struct ib_send_wr *reg_wr = NULL;
787         struct rtrs_msg_info_rsp *rsp;
788         struct rtrs_iu *tx_iu;
789         struct ib_reg_wr *rwr;
790         int mri, err;
791         size_t tx_sz;
792
793         err = post_recv_sess(sess);
794         if (unlikely(err)) {
795                 rtrs_err(s, "post_recv_sess(), err: %d\n", err);
796                 return err;
797         }
798         rwr = kcalloc(sess->mrs_num, sizeof(*rwr), GFP_KERNEL);
799         if (unlikely(!rwr))
800                 return -ENOMEM;
801         strlcpy(sess->s.sessname, msg->sessname, sizeof(sess->s.sessname));
802
803         tx_sz  = sizeof(*rsp);
804         tx_sz += sizeof(rsp->desc[0]) * sess->mrs_num;
805         tx_iu = rtrs_iu_alloc(1, tx_sz, GFP_KERNEL, sess->s.dev->ib_dev,
806                                DMA_TO_DEVICE, rtrs_srv_info_rsp_done);
807         if (unlikely(!tx_iu)) {
808                 err = -ENOMEM;
809                 goto rwr_free;
810         }
811
812         rsp = tx_iu->buf;
813         rsp->type = cpu_to_le16(RTRS_MSG_INFO_RSP);
814         rsp->sg_cnt = cpu_to_le16(sess->mrs_num);
815
816         for (mri = 0; mri < sess->mrs_num; mri++) {
817                 struct ib_mr *mr = sess->mrs[mri].mr;
818
819                 rsp->desc[mri].addr = cpu_to_le64(mr->iova);
820                 rsp->desc[mri].key  = cpu_to_le32(mr->rkey);
821                 rsp->desc[mri].len  = cpu_to_le32(mr->length);
822
823                 /*
824                  * Fill in reg MR request and chain them *backwards*
825                  */
826                 rwr[mri].wr.next = mri ? &rwr[mri - 1].wr : NULL;
827                 rwr[mri].wr.opcode = IB_WR_REG_MR;
828                 rwr[mri].wr.wr_cqe = &local_reg_cqe;
829                 rwr[mri].wr.num_sge = 0;
830                 rwr[mri].wr.send_flags = mri ? 0 : IB_SEND_SIGNALED;
831                 rwr[mri].mr = mr;
832                 rwr[mri].key = mr->rkey;
833                 rwr[mri].access = (IB_ACCESS_LOCAL_WRITE |
834                                    IB_ACCESS_REMOTE_WRITE);
835                 reg_wr = &rwr[mri].wr;
836         }
837
838         err = rtrs_srv_create_sess_files(sess);
839         if (unlikely(err))
840                 goto iu_free;
841         kobject_get(&sess->kobj);
842         get_device(&sess->srv->dev);
843         rtrs_srv_change_state(sess, RTRS_SRV_CONNECTED);
844         rtrs_srv_start_hb(sess);
845
846         /*
847          * We do not account for the number of established connections at
848          * the moment; we rely on the client, which sends the info request
849          * once all connections have been established.  Thus, simply notify
850          * the listener with a proper event if we are the first path.
851          */
852         rtrs_srv_sess_up(sess);
853
854         ib_dma_sync_single_for_device(sess->s.dev->ib_dev, tx_iu->dma_addr,
855                                       tx_iu->size, DMA_TO_DEVICE);
856
857         /* Send info response */
858         err = rtrs_iu_post_send(&con->c, tx_iu, tx_sz, reg_wr);
859         if (unlikely(err)) {
860                 rtrs_err(s, "rtrs_iu_post_send(), err: %d\n", err);
861 iu_free:
862                 rtrs_iu_free(tx_iu, DMA_TO_DEVICE, sess->s.dev->ib_dev, 1);
863         }
864 rwr_free:
865         kfree(rwr);
866
867         return err;
868 }
869
870 static void rtrs_srv_info_req_done(struct ib_cq *cq, struct ib_wc *wc)
871 {
872         struct rtrs_srv_con *con = cq->cq_context;
873         struct rtrs_sess *s = con->c.sess;
874         struct rtrs_srv_sess *sess = to_srv_sess(s);
875         struct rtrs_msg_info_req *msg;
876         struct rtrs_iu *iu;
877         int err;
878
879         WARN_ON(con->c.cid);
880
881         iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
882         if (unlikely(wc->status != IB_WC_SUCCESS)) {
883                 rtrs_err(s, "Sess info request receive failed: %s\n",
884                           ib_wc_status_msg(wc->status));
885                 goto close;
886         }
887         WARN_ON(wc->opcode != IB_WC_RECV);
888
889         if (unlikely(wc->byte_len < sizeof(*msg))) {
890                 rtrs_err(s, "Sess info request is malformed: size %d\n",
891                           wc->byte_len);
892                 goto close;
893         }
894         ib_dma_sync_single_for_cpu(sess->s.dev->ib_dev, iu->dma_addr,
895                                    iu->size, DMA_FROM_DEVICE);
896         msg = iu->buf;
897         if (unlikely(le16_to_cpu(msg->type) != RTRS_MSG_INFO_REQ)) {
898                 rtrs_err(s, "Sess info request is malformed: type %d\n",
899                           le16_to_cpu(msg->type));
900                 goto close;
901         }
902         err = process_info_req(con, msg);
903         if (unlikely(err))
904                 goto close;
905
906 out:
907         rtrs_iu_free(iu, DMA_FROM_DEVICE, sess->s.dev->ib_dev, 1);
908         return;
909 close:
910         close_sess(sess);
911         goto out;
912 }
913
914 static int post_recv_info_req(struct rtrs_srv_con *con)
915 {
916         struct rtrs_sess *s = con->c.sess;
917         struct rtrs_srv_sess *sess = to_srv_sess(s);
918         struct rtrs_iu *rx_iu;
919         int err;
920
921         rx_iu = rtrs_iu_alloc(1, sizeof(struct rtrs_msg_info_req),
922                                GFP_KERNEL, sess->s.dev->ib_dev,
923                                DMA_FROM_DEVICE, rtrs_srv_info_req_done);
924         if (unlikely(!rx_iu))
925                 return -ENOMEM;
926         /* Prepare for getting info response */
927         err = rtrs_iu_post_recv(&con->c, rx_iu);
928         if (unlikely(err)) {
929                 rtrs_err(s, "rtrs_iu_post_recv(), err: %d\n", err);
930                 rtrs_iu_free(rx_iu, DMA_FROM_DEVICE, sess->s.dev->ib_dev, 1);
931                 return err;
932         }
933
934         return 0;
935 }
936
937 static int post_recv_io(struct rtrs_srv_con *con, size_t q_size)
938 {
939         int i, err;
940
941         for (i = 0; i < q_size; i++) {
942                 err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
943                 if (unlikely(err))
944                         return err;
945         }
946
947         return 0;
948 }
949
950 static int post_recv_sess(struct rtrs_srv_sess *sess)
951 {
952         struct rtrs_srv *srv = sess->srv;
953         struct rtrs_sess *s = &sess->s;
954         size_t q_size;
955         int err, cid;
956
957         for (cid = 0; cid < sess->s.con_num; cid++) {
958                 if (cid == 0)
959                         q_size = SERVICE_CON_QUEUE_DEPTH;
960                 else
961                         q_size = srv->queue_depth;
962
963                 err = post_recv_io(to_srv_con(sess->s.con[cid]), q_size);
964                 if (unlikely(err)) {
965                         rtrs_err(s, "post_recv_io(), err: %d\n", err);
966                         return err;
967                 }
968         }
969
970         return 0;
971 }
972
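/*
 * Hand a client READ request to the upper layer via ctx->ops.rdma_ev();
 * on error, reply immediately with an error IMM message.
 */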
973 static void process_read(struct rtrs_srv_con *con,
974                          struct rtrs_msg_rdma_read *msg,
975                          u32 buf_id, u32 off)
976 {
977         struct rtrs_sess *s = con->c.sess;
978         struct rtrs_srv_sess *sess = to_srv_sess(s);
979         struct rtrs_srv *srv = sess->srv;
980         struct rtrs_srv_ctx *ctx = srv->ctx;
981         struct rtrs_srv_op *id;
982
983         size_t usr_len, data_len;
984         void *data;
985         int ret;
986
987         if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
988                 rtrs_err_rl(s,
989                              "Processing read request failed, session is disconnected, sess state %s\n",
990                              rtrs_srv_state_str(sess->state));
991                 return;
992         }
993         if (unlikely(msg->sg_cnt != 1 && msg->sg_cnt != 0)) {
994                 rtrs_err_rl(s,
995                             "Processing read request failed, invalid message\n");
996                 return;
997         }
998         rtrs_srv_get_ops_ids(sess);
999         rtrs_srv_update_rdma_stats(sess->stats, off, READ);
1000         id = sess->ops_ids[buf_id];
1001         id->con         = con;
1002         id->dir         = READ;
1003         id->msg_id      = buf_id;
1004         id->rd_msg      = msg;
1005         usr_len = le16_to_cpu(msg->usr_len);
1006         data_len = off - usr_len;
1007         data = page_address(srv->chunks[buf_id]);
1008         ret = ctx->ops.rdma_ev(srv, srv->priv, id, READ, data, data_len,
1009                            data + data_len, usr_len);
1010
1011         if (unlikely(ret)) {
1012                 rtrs_err_rl(s,
1013                              "Processing read request failed, user module cb reported for msg_id %d, err: %d\n",
1014                              buf_id, ret);
1015                 goto send_err_msg;
1016         }
1017
1018         return;
1019
1020 send_err_msg:
1021         ret = send_io_resp_imm(con, id, ret);
1022         if (ret < 0) {
1023                 rtrs_err_rl(s,
1024                              "Sending err msg for failed RDMA-Write-Req failed, msg_id %d, err: %d\n",
1025                              buf_id, ret);
1026                 close_sess(sess);
1027         }
1028         rtrs_srv_put_ops_ids(sess);
1029 }
1030
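/*
 * Hand a client WRITE request (data already placed in the chunk by the
 * client's RDMA WRITE) to the upper layer via ctx->ops.rdma_ev().
 */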
1031 static void process_write(struct rtrs_srv_con *con,
1032                           struct rtrs_msg_rdma_write *req,
1033                           u32 buf_id, u32 off)
1034 {
1035         struct rtrs_sess *s = con->c.sess;
1036         struct rtrs_srv_sess *sess = to_srv_sess(s);
1037         struct rtrs_srv *srv = sess->srv;
1038         struct rtrs_srv_ctx *ctx = srv->ctx;
1039         struct rtrs_srv_op *id;
1040
1041         size_t data_len, usr_len;
1042         void *data;
1043         int ret;
1044
1045         if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
1046                 rtrs_err_rl(s,
1047                              "Processing write request failed, session is disconnected, sess state %s\n",
1048                              rtrs_srv_state_str(sess->state));
1049                 return;
1050         }
1051         rtrs_srv_get_ops_ids(sess);
1052         rtrs_srv_update_rdma_stats(sess->stats, off, WRITE);
1053         id = sess->ops_ids[buf_id];
1054         id->con    = con;
1055         id->dir    = WRITE;
1056         id->msg_id = buf_id;
1057
1058         usr_len = le16_to_cpu(req->usr_len);
1059         data_len = off - usr_len;
1060         data = page_address(srv->chunks[buf_id]);
1061         ret = ctx->ops.rdma_ev(srv, srv->priv, id, WRITE, data, data_len,
1062                            data + data_len, usr_len);
1063         if (unlikely(ret)) {
1064                 rtrs_err_rl(s,
1065                              "Processing write request failed, user module callback reports err: %d\n",
1066                              ret);
1067                 goto send_err_msg;
1068         }
1069
1070         return;
1071
1072 send_err_msg:
1073         ret = send_io_resp_imm(con, id, ret);
1074         if (ret < 0) {
1075                 rtrs_err_rl(s,
1076                              "Processing write request failed, sending I/O response failed, msg_id %d, err: %d\n",
1077                              buf_id, ret);
1078                 close_sess(sess);
1079         }
1080         rtrs_srv_put_ops_ids(sess);
1081 }
1082
1083 static void process_io_req(struct rtrs_srv_con *con, void *msg,
1084                            u32 id, u32 off)
1085 {
1086         struct rtrs_sess *s = con->c.sess;
1087         struct rtrs_srv_sess *sess = to_srv_sess(s);
1088         struct rtrs_msg_rdma_hdr *hdr;
1089         unsigned int type;
1090
1091         ib_dma_sync_single_for_cpu(sess->s.dev->ib_dev, sess->dma_addr[id],
1092                                    max_chunk_size, DMA_BIDIRECTIONAL);
1093         hdr = msg;
1094         type = le16_to_cpu(hdr->type);
1095
1096         switch (type) {
1097         case RTRS_MSG_WRITE:
1098                 process_write(con, msg, id, off);
1099                 break;
1100         case RTRS_MSG_READ:
1101                 process_read(con, msg, id, off);
1102                 break;
1103         default:
1104                 rtrs_err(s,
1105                           "Processing I/O request failed, unknown message type received: 0x%02x\n",
1106                           type);
1107                 goto err;
1108         }
1109
1110         return;
1111
1112 err:
1113         close_sess(sess);
1114 }
1115
1116 static void rtrs_srv_inv_rkey_done(struct ib_cq *cq, struct ib_wc *wc)
1117 {
1118         struct rtrs_srv_mr *mr =
1119                 container_of(wc->wr_cqe, typeof(*mr), inv_cqe);
1120         struct rtrs_srv_con *con = cq->cq_context;
1121         struct rtrs_sess *s = con->c.sess;
1122         struct rtrs_srv_sess *sess = to_srv_sess(s);
1123         struct rtrs_srv *srv = sess->srv;
1124         u32 msg_id, off;
1125         void *data;
1126
1127         if (unlikely(wc->status != IB_WC_SUCCESS)) {
1128                 rtrs_err(s, "Failed IB_WR_LOCAL_INV: %s\n",
1129                           ib_wc_status_msg(wc->status));
1130                 close_sess(sess);
1131         }
1132         msg_id = mr->msg_id;
1133         off = mr->msg_off;
1134         data = page_address(srv->chunks[msg_id]) + off;
1135         process_io_req(con, data, msg_id, off);
1136 }
1137
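/*
 * Post a LOCAL_INV work request for the chunk's rkey; processing of
 * the I/O request continues in rtrs_srv_inv_rkey_done() once the
 * invalidation completes.
 */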
1138 static int rtrs_srv_inv_rkey(struct rtrs_srv_con *con,
1139                               struct rtrs_srv_mr *mr)
1140 {
1141         struct ib_send_wr wr = {
1142                 .opcode             = IB_WR_LOCAL_INV,
1143                 .wr_cqe             = &mr->inv_cqe,
1144                 .send_flags         = IB_SEND_SIGNALED,
1145                 .ex.invalidate_rkey = mr->mr->rkey,
1146         };
1147         mr->inv_cqe.done = rtrs_srv_inv_rkey_done;
1148
1149         return ib_post_send(con->c.qp, &wr, NULL);
1150 }
1151
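/*
 * Retry responses that were parked on rsp_wr_wait_list because the
 * send queue was full; see rtrs_srv_resp_rdma().
 */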
1152 static void rtrs_rdma_process_wr_wait_list(struct rtrs_srv_con *con)
1153 {
1154         spin_lock(&con->rsp_wr_wait_lock);
1155         while (!list_empty(&con->rsp_wr_wait_list)) {
1156                 struct rtrs_srv_op *id;
1157                 int ret;
1158
1159                 id = list_entry(con->rsp_wr_wait_list.next,
1160                                 struct rtrs_srv_op, wait_list);
1161                 list_del(&id->wait_list);
1162
1163                 spin_unlock(&con->rsp_wr_wait_lock);
1164                 ret = rtrs_srv_resp_rdma(id, id->status);
1165                 spin_lock(&con->rsp_wr_wait_lock);
1166
1167                 if (!ret) {
1168                         list_add(&id->wait_list, &con->rsp_wr_wait_list);
1169                         break;
1170                 }
1171         }
1172         spin_unlock(&con->rsp_wr_wait_lock);
1173 }
1174
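/*
 * Common CQ completion handler for I/O, rkey invalidation and
 * heartbeat traffic on a server connection.
 */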
1175 static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc)
1176 {
1177         struct rtrs_srv_con *con = cq->cq_context;
1178         struct rtrs_sess *s = con->c.sess;
1179         struct rtrs_srv_sess *sess = to_srv_sess(s);
1180         struct rtrs_srv *srv = sess->srv;
1181         u32 imm_type, imm_payload;
1182         int err;
1183
1184         if (unlikely(wc->status != IB_WC_SUCCESS)) {
1185                 if (wc->status != IB_WC_WR_FLUSH_ERR) {
1186                         rtrs_err(s,
1187                                   "%s (wr_cqe: %p, type: %d, vendor_err: 0x%x, len: %u)\n",
1188                                   ib_wc_status_msg(wc->status), wc->wr_cqe,
1189                                   wc->opcode, wc->vendor_err, wc->byte_len);
1190                         close_sess(sess);
1191                 }
1192                 return;
1193         }
1194
1195         switch (wc->opcode) {
1196         case IB_WC_RECV_RDMA_WITH_IMM:
1197                 /*
1198                  * post_recv() RDMA write completions of IO reqs (read/write)
1199                  * and hb
1200                  */
1201                 if (WARN_ON(wc->wr_cqe != &io_comp_cqe))
1202                         return;
1203                 err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
1204                 if (unlikely(err)) {
1205                         rtrs_err(s, "rtrs_post_recv(), err: %d\n", err);
1206                         close_sess(sess);
1207                         break;
1208                 }
1209                 rtrs_from_imm(be32_to_cpu(wc->ex.imm_data),
1210                                &imm_type, &imm_payload);
1211                 if (likely(imm_type == RTRS_IO_REQ_IMM)) {
1212                         u32 msg_id, off;
1213                         void *data;
1214
1215                         msg_id = imm_payload >> sess->mem_bits;
1216                         off = imm_payload & ((1 << sess->mem_bits) - 1);
1217                         if (unlikely(msg_id >= srv->queue_depth ||
1218                                      off >= max_chunk_size)) {
1219                                 rtrs_err(s, "Wrong msg_id %u, off %u\n",
1220                                           msg_id, off);
1221                                 close_sess(sess);
1222                                 return;
1223                         }
1224                         if (always_invalidate) {
1225                                 struct rtrs_srv_mr *mr = &sess->mrs[msg_id];
1226
1227                                 mr->msg_off = off;
1228                                 mr->msg_id = msg_id;
1229                                 err = rtrs_srv_inv_rkey(con, mr);
1230                                 if (unlikely(err)) {
1231                                         rtrs_err(s, "rtrs_srv_inv_rkey(), err: %d\n",
1232                                                   err);
1233                                         close_sess(sess);
1234                                         break;
1235                                 }
1236                         } else {
1237                                 data = page_address(srv->chunks[msg_id]) + off;
1238                                 process_io_req(con, data, msg_id, off);
1239                         }
1240                 } else if (imm_type == RTRS_HB_MSG_IMM) {
1241                         WARN_ON(con->c.cid);
1242                         rtrs_send_hb_ack(&sess->s);
1243                 } else if (imm_type == RTRS_HB_ACK_IMM) {
1244                         WARN_ON(con->c.cid);
1245                         sess->s.hb_missed_cnt = 0;
1246                 } else {
1247                         rtrs_wrn(s, "Unknown IMM type %u\n", imm_type);
1248                 }
1249                 break;
1250         case IB_WC_RDMA_WRITE:
1251         case IB_WC_SEND:
1252                 /*
1253                  * post_send() RDMA write completions of IO reqs (read/write)
1254                  * and hb
1255                  */
1256                 atomic_add(srv->queue_depth, &con->sq_wr_avail);
1257
1258                 if (unlikely(!list_empty_careful(&con->rsp_wr_wait_list)))
1259                         rtrs_rdma_process_wr_wait_list(con);
1260
1261                 break;
1262         default:
1263                 rtrs_wrn(s, "Unexpected WC type: %d\n", wc->opcode);
1264                 return;
1265         }
1266 }
1267
1268 /**
1269  * rtrs_srv_get_sess_name() - Get rtrs_srv peer hostname.
1270  * @srv:        Session
1271  * @sessname:   Sessname buffer
1272  * @len:        Length of sessname buffer
1273  */
1274 int rtrs_srv_get_sess_name(struct rtrs_srv *srv, char *sessname, size_t len)
1275 {
1276         struct rtrs_srv_sess *sess;
1277         int err = -ENOTCONN;
1278
1279         mutex_lock(&srv->paths_mutex);
1280         list_for_each_entry(sess, &srv->paths_list, s.entry) {
1281                 if (sess->state != RTRS_SRV_CONNECTED)
1282                         continue;
1283                 strlcpy(sessname, sess->s.sessname,
1284                        min_t(size_t, sizeof(sess->s.sessname), len));
1285                 err = 0;
1286                 break;
1287         }
1288         mutex_unlock(&srv->paths_mutex);
1289
1290         return err;
1291 }
1292 EXPORT_SYMBOL(rtrs_srv_get_sess_name);
1293
1294 /**
1295  * rtrs_srv_get_queue_depth() - Get rtrs_srv queue depth.
1296  * @srv:        Session
1297  */
1298 int rtrs_srv_get_queue_depth(struct rtrs_srv *srv)
1299 {
1300         return srv->queue_depth;
1301 }
1302 EXPORT_SYMBOL(rtrs_srv_get_queue_depth);
1303
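/*
 * Round-robin the completion vector over the CPUs in cq_affinity_mask
 * so that sessions spread their CQ work across cores.
 */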
1304 static int find_next_bit_ring(struct rtrs_srv_sess *sess)
1305 {
1306         struct ib_device *ib_dev = sess->s.dev->ib_dev;
1307         int v;
1308
1309         v = cpumask_next(sess->cur_cq_vector, &cq_affinity_mask);
1310         if (v >= nr_cpu_ids || v >= ib_dev->num_comp_vectors)
1311                 v = cpumask_first(&cq_affinity_mask);
1312         return v;
1313 }
1314
1315 static int rtrs_srv_get_next_cq_vector(struct rtrs_srv_sess *sess)
1316 {
1317         sess->cur_cq_vector = find_next_bit_ring(sess);
1318
1319         return sess->cur_cq_vector;
1320 }
1321
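/*
 * Allocate a new rtrs_srv for a fresh paths_uuid and pre-allocate
 * queue_depth chunk buffers from the chunk pool.  Called with
 * ctx->srv_mutex held.
 */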
1322 static struct rtrs_srv *__alloc_srv(struct rtrs_srv_ctx *ctx,
1323                                      const uuid_t *paths_uuid)
1324 {
1325         struct rtrs_srv *srv;
1326         int i;
1327
1328         srv = kzalloc(sizeof(*srv), GFP_KERNEL);
1329         if  (!srv)
1330                 return NULL;
1331
1332         refcount_set(&srv->refcount, 1);
1333         INIT_LIST_HEAD(&srv->paths_list);
1334         mutex_init(&srv->paths_mutex);
1335         mutex_init(&srv->paths_ev_mutex);
1336         uuid_copy(&srv->paths_uuid, paths_uuid);
1337         srv->queue_depth = sess_queue_depth;
1338         srv->ctx = ctx;
1339
1340         srv->chunks = kcalloc(srv->queue_depth, sizeof(*srv->chunks),
1341                               GFP_KERNEL);
1342         if (!srv->chunks)
1343                 goto err_free_srv;
1344
1345         for (i = 0; i < srv->queue_depth; i++) {
1346                 srv->chunks[i] = mempool_alloc(chunk_pool, GFP_KERNEL);
1347                 if (!srv->chunks[i])
1348                         goto err_free_chunks;
1349         }
1350         list_add(&srv->ctx_list, &ctx->srv_list);
1351
1352         return srv;
1353
1354 err_free_chunks:
1355         while (i--)
1356                 mempool_free(srv->chunks[i], chunk_pool);
1357         kfree(srv->chunks);
1358
1359 err_free_srv:
1360         kfree(srv);
1361
1362         return NULL;
1363 }
1364
1365 static void free_srv(struct rtrs_srv *srv)
1366 {
1367         int i;
1368
1369         WARN_ON(refcount_read(&srv->refcount));
1370         for (i = 0; i < srv->queue_depth; i++)
1371                 mempool_free(srv->chunks[i], chunk_pool);
1372         kfree(srv->chunks);
1373         mutex_destroy(&srv->paths_mutex);
1374         mutex_destroy(&srv->paths_ev_mutex);
1375         /* last put to release the srv structure */
1376         put_device(&srv->dev);
1377 }
1378
1379 static inline struct rtrs_srv *__find_srv_and_get(struct rtrs_srv_ctx *ctx,
1380                                                    const uuid_t *paths_uuid)
1381 {
1382         struct rtrs_srv *srv;
1383
1384         list_for_each_entry(srv, &ctx->srv_list, ctx_list) {
1385                 if (uuid_equal(&srv->paths_uuid, paths_uuid) &&
1386                     refcount_inc_not_zero(&srv->refcount))
1387                         return srv;
1388         }
1389
1390         return NULL;
1391 }
1392
1393 static struct rtrs_srv *get_or_create_srv(struct rtrs_srv_ctx *ctx,
1394                                            const uuid_t *paths_uuid)
1395 {
1396         struct rtrs_srv *srv;
1397
1398         mutex_lock(&ctx->srv_mutex);
1399         srv = __find_srv_and_get(ctx, paths_uuid);
1400         if (!srv)
1401                 srv = __alloc_srv(ctx, paths_uuid);
1402         mutex_unlock(&ctx->srv_mutex);
1403
1404         return srv;
1405 }
1406
1407 static void put_srv(struct rtrs_srv *srv)
1408 {
1409         if (refcount_dec_and_test(&srv->refcount)) {
1410                 struct rtrs_srv_ctx *ctx = srv->ctx;
1411
1412                 WARN_ON(srv->dev.kobj.state_in_sysfs);
1413
1414                 mutex_lock(&ctx->srv_mutex);
1415                 list_del(&srv->ctx_list);
1416                 mutex_unlock(&ctx->srv_mutex);
1417                 free_srv(srv);
1418         }
1419 }
1420
1421 static void __add_path_to_srv(struct rtrs_srv *srv,
1422                               struct rtrs_srv_sess *sess)
1423 {
1424         list_add_tail(&sess->s.entry, &srv->paths_list);
1425         srv->paths_num++;
1426         WARN_ON(srv->paths_num >= MAX_PATHS_NUM);
1427 }
1428
1429 static void del_path_from_srv(struct rtrs_srv_sess *sess)
1430 {
1431         struct rtrs_srv *srv = sess->srv;
1432
1433         if (WARN_ON(!srv))
1434                 return;
1435
1436         mutex_lock(&srv->paths_mutex);
1437         list_del(&sess->s.entry);
1438         WARN_ON(!srv->paths_num);
1439         srv->paths_num--;
1440         mutex_unlock(&srv->paths_mutex);
1441 }
1442
1443 /* Returns 0 if the addresses match (memcmp() semantics), non-zero otherwise, -ENOENT for an unknown address family */
1444 static int sockaddr_cmp(const struct sockaddr *a, const struct sockaddr *b)
1445 {
1446         switch (a->sa_family) {
1447         case AF_IB:
1448                 return memcmp(&((struct sockaddr_ib *)a)->sib_addr,
1449                               &((struct sockaddr_ib *)b)->sib_addr,
1450                               sizeof(struct ib_addr)) &&
1451                         (b->sa_family == AF_IB);
1452         case AF_INET:
1453                 return memcmp(&((struct sockaddr_in *)a)->sin_addr,
1454                               &((struct sockaddr_in *)b)->sin_addr,
1455                               sizeof(struct in_addr)) &&
1456                         (b->sa_family == AF_INET);
1457         case AF_INET6:
1458                 return memcmp(&((struct sockaddr_in6 *)a)->sin6_addr,
1459                               &((struct sockaddr_in6 *)b)->sin6_addr,
1460                               sizeof(struct in6_addr)) &&
1461                         (b->sa_family == AF_INET6);
1462         default:
1463                 return -ENOENT;
1464         }
1465 }
1466
1467 static bool __is_path_w_addr_exists(struct rtrs_srv *srv,
1468                                     struct rdma_addr *addr)
1469 {
1470         struct rtrs_srv_sess *sess;
1471
1472         list_for_each_entry(sess, &srv->paths_list, s.entry)
1473                 if (!sockaddr_cmp((struct sockaddr *)&sess->s.dst_addr,
1474                                   (struct sockaddr *)&addr->dst_addr) &&
1475                     !sockaddr_cmp((struct sockaddr *)&sess->s.src_addr,
1476                                   (struct sockaddr *)&addr->src_addr))
1477                         return true;
1478
1479         return false;
1480 }
1481
1482 static void free_sess(struct rtrs_srv_sess *sess)
1483 {
1484         if (sess->kobj.state_in_sysfs)
1485                 kobject_put(&sess->kobj);
1486         else
1487                 kfree(sess);
1488 }
1489
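/*
 * Session teardown, run from rtrs_wq: destroy sysfs files, stop the
 * heartbeat, disconnect and drain all QPs, wait for in-flight
 * operations, release buffers and finally free the session.
 */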
1490 static void rtrs_srv_close_work(struct work_struct *work)
1491 {
1492         struct rtrs_srv_sess *sess;
1493         struct rtrs_srv_con *con;
1494         int i;
1495
1496         sess = container_of(work, typeof(*sess), close_work);
1497
1498         rtrs_srv_destroy_sess_files(sess);
1499         rtrs_srv_stop_hb(sess);
1500
1501         for (i = 0; i < sess->s.con_num; i++) {
1502                 if (!sess->s.con[i])
1503                         continue;
1504                 con = to_srv_con(sess->s.con[i]);
1505                 rdma_disconnect(con->c.cm_id);
1506                 ib_drain_qp(con->c.qp);
1507         }
1508         /* Wait for all inflights */
1509         rtrs_srv_wait_ops_ids(sess);
1510
1511         /* Notify upper layer if we are the last path */
1512         rtrs_srv_sess_down(sess);
1513
1514         unmap_cont_bufs(sess);
1515         rtrs_srv_free_ops_ids(sess);
1516
1517         for (i = 0; i < sess->s.con_num; i++) {
1518                 if (!sess->s.con[i])
1519                         continue;
1520                 con = to_srv_con(sess->s.con[i]);
1521                 rtrs_cq_qp_destroy(&con->c);
1522                 rdma_destroy_id(con->c.cm_id);
1523                 kfree(con);
1524         }
1525         rtrs_ib_dev_put(sess->s.dev);
1526
1527         del_path_from_srv(sess);
1528         put_srv(sess->srv);
1529         sess->srv = NULL;
1530         rtrs_srv_change_state(sess, RTRS_SRV_CLOSED);
1531
1532         kfree(sess->dma_addr);
1533         kfree(sess->s.con);
1534         free_sess(sess);
1535 }

static int rtrs_rdma_do_accept(struct rtrs_srv_sess *sess,
                               struct rdma_cm_id *cm_id)
{
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_msg_conn_rsp msg;
        struct rdma_conn_param param;
        int err;

        param = (struct rdma_conn_param) {
                .rnr_retry_count = 7,
                .private_data = &msg,
                .private_data_len = sizeof(msg),
        };

        msg = (struct rtrs_msg_conn_rsp) {
                .magic = cpu_to_le16(RTRS_MAGIC),
                .version = cpu_to_le16(RTRS_PROTO_VER),
                .queue_depth = cpu_to_le16(srv->queue_depth),
                .max_io_size = cpu_to_le32(max_chunk_size - MAX_HDR_SIZE),
                .max_hdr_size = cpu_to_le32(MAX_HDR_SIZE),
        };

        if (always_invalidate)
                msg.flags = cpu_to_le32(RTRS_MSG_NEW_RKEY_F);

        err = rdma_accept(cm_id, &param);
        if (err)
                pr_err("rdma_accept(), err: %d\n", err);

        return err;
}

static int rtrs_rdma_do_reject(struct rdma_cm_id *cm_id, int errno)
{
        struct rtrs_msg_conn_rsp msg;
        int err;

        msg = (struct rtrs_msg_conn_rsp) {
                .magic = cpu_to_le16(RTRS_MAGIC),
                .version = cpu_to_le16(RTRS_PROTO_VER),
                .errno = cpu_to_le16(errno),
        };

        err = rdma_reject(cm_id, &msg, sizeof(msg), IB_CM_REJ_CONSUMER_DEFINED);
        if (err)
                pr_err("rdma_reject(), err: %d\n", err);

        /* Bounce errno back */
        return errno;
}

static struct rtrs_srv_sess *
__find_sess(struct rtrs_srv *srv, const uuid_t *sess_uuid)
{
        struct rtrs_srv_sess *sess;

        list_for_each_entry(sess, &srv->paths_list, s.entry) {
                if (uuid_equal(&sess->s.uuid, sess_uuid))
                        return sess;
        }

        return NULL;
}

static int create_con(struct rtrs_srv_sess *sess,
                      struct rdma_cm_id *cm_id,
                      unsigned int cid)
{
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_sess *s = &sess->s;
        struct rtrs_srv_con *con;
        u16 cq_size, wr_queue_size;
        int err, cq_vector;

        con = kzalloc(sizeof(*con), GFP_KERNEL);
        if (!con) {
                err = -ENOMEM;
                goto err;
        }

        spin_lock_init(&con->rsp_wr_wait_lock);
        INIT_LIST_HEAD(&con->rsp_wr_wait_list);
        con->c.cm_id = cm_id;
        con->c.sess = &sess->s;
        con->c.cid = cid;
        atomic_set(&con->wr_cnt, 0);

        if (con->c.cid == 0) {
                /*
                 * All receive and all send (each requiring invalidate)
                 * + 2 for drain and heartbeat
                 */
                wr_queue_size = SERVICE_CON_QUEUE_DEPTH * 3 + 2;
                cq_size = wr_queue_size;
        } else {
                /*
                 * The CQ must be able to absorb the case where all
                 * receive buffers and all write responses are posted and
                 * every read request additionally posts an invalidate,
                 * plus one extra WR for draining the QP once it goes
                 * into the error state.
                 */
                cq_size = srv->queue_depth * 3 + 1;
                /*
                 * In theory we might have queue_depth * 32
                 * outstanding requests if an unsafe global key is used
                 * and we have queue_depth read requests each consisting
                 * of 32 different addresses.  Divide by 3 to stay within
                 * the mlx5 limits.
                 */
                wr_queue_size = sess->s.dev->ib_dev->attrs.max_qp_wr / 3;
        }
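        /*
         * Worked sizing example (illustrative only; actual values depend
         * on the module parameters and the HCA): with the default
         * sess_queue_depth of 512, cq_size above is 512 * 3 + 1 = 1537,
         * and on a device reporting max_qp_wr = 32768 the send queue is
         * sized to 32768 / 3 = 10922 WRs.
         */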
        atomic_set(&con->sq_wr_avail, wr_queue_size);
        cq_vector = rtrs_srv_get_next_cq_vector(sess);

        /* TODO: SOFTIRQ can be faster, but be careful with softirq context */
        err = rtrs_cq_qp_create(&sess->s, &con->c, 1, cq_vector, cq_size,
                                 wr_queue_size, IB_POLL_WORKQUEUE);
        if (err) {
                rtrs_err(s, "rtrs_cq_qp_create(), err: %d\n", err);
                goto free_con;
        }
        if (con->c.cid == 0) {
                err = post_recv_info_req(con);
                if (err)
                        goto free_cqqp;
        }
        WARN_ON(sess->s.con[cid]);
        sess->s.con[cid] = &con->c;

        /*
         * Change context from server to current connection.  The other
         * way is to use cm_id->qp->qp_context, which does not work on OFED.
         */
        cm_id->context = &con->c;

        return 0;

free_cqqp:
        rtrs_cq_qp_destroy(&con->c);
free_con:
        kfree(con);

err:
        return err;
}

static struct rtrs_srv_sess *__alloc_sess(struct rtrs_srv *srv,
                                           struct rdma_cm_id *cm_id,
                                           unsigned int con_num,
                                           unsigned int recon_cnt,
                                           const uuid_t *uuid)
{
        struct rtrs_srv_sess *sess;
        int err = -ENOMEM;

        if (srv->paths_num >= MAX_PATHS_NUM) {
                err = -ECONNRESET;
                goto err;
        }
        if (__is_path_w_addr_exists(srv, &cm_id->route.addr)) {
                err = -EEXIST;
                pr_err("Path with same addr exists\n");
                goto err;
        }
        sess = kzalloc(sizeof(*sess), GFP_KERNEL);
        if (!sess)
                goto err;

        sess->stats = kzalloc(sizeof(*sess->stats), GFP_KERNEL);
        if (!sess->stats)
                goto err_free_sess;

        sess->stats->sess = sess;

        sess->dma_addr = kcalloc(srv->queue_depth, sizeof(*sess->dma_addr),
                                 GFP_KERNEL);
        if (!sess->dma_addr)
                goto err_free_stats;

        sess->s.con = kcalloc(con_num, sizeof(*sess->s.con), GFP_KERNEL);
        if (!sess->s.con)
                goto err_free_dma_addr;

        sess->state = RTRS_SRV_CONNECTING;
        sess->srv = srv;
        sess->cur_cq_vector = -1;
        sess->s.dst_addr = cm_id->route.addr.dst_addr;
        sess->s.src_addr = cm_id->route.addr.src_addr;
        sess->s.con_num = con_num;
        sess->s.recon_cnt = recon_cnt;
        uuid_copy(&sess->s.uuid, uuid);
        spin_lock_init(&sess->state_lock);
        INIT_WORK(&sess->close_work, rtrs_srv_close_work);
        rtrs_srv_init_hb(sess);

        sess->s.dev = rtrs_ib_dev_find_or_add(cm_id->device, &dev_pd);
        if (!sess->s.dev) {
                err = -ENOMEM;
                goto err_free_con;
        }
        err = map_cont_bufs(sess);
        if (err)
                goto err_put_dev;

        err = rtrs_srv_alloc_ops_ids(sess);
        if (err)
                goto err_unmap_bufs;

        __add_path_to_srv(srv, sess);

        return sess;

err_unmap_bufs:
        unmap_cont_bufs(sess);
err_put_dev:
        rtrs_ib_dev_put(sess->s.dev);
err_free_con:
        kfree(sess->s.con);
err_free_dma_addr:
        kfree(sess->dma_addr);
err_free_stats:
        kfree(sess->stats);
err_free_sess:
        kfree(sess);
err:
        return ERR_PTR(err);
}

static int rtrs_rdma_connect(struct rdma_cm_id *cm_id,
                              const struct rtrs_msg_conn_req *msg,
                              size_t len)
{
        struct rtrs_srv_ctx *ctx = cm_id->context;
        struct rtrs_srv_sess *sess;
        struct rtrs_srv *srv;
        u16 version, con_num, cid;
        u16 recon_cnt;
        int err;

        if (len < sizeof(*msg)) {
                pr_err("Invalid RTRS connection request\n");
                goto reject_w_econnreset;
        }
        if (le16_to_cpu(msg->magic) != RTRS_MAGIC) {
                pr_err("Invalid RTRS magic\n");
                goto reject_w_econnreset;
        }
        version = le16_to_cpu(msg->version);
        if (version >> 8 != RTRS_PROTO_VER_MAJOR) {
                pr_err("Unsupported major RTRS version: %d, expected %d\n",
                       version >> 8, RTRS_PROTO_VER_MAJOR);
                goto reject_w_econnreset;
        }
        con_num = le16_to_cpu(msg->cid_num);
        if (con_num > 4096) {
                /* Sanity check */
                pr_err("Too many connections requested: %d\n", con_num);
                goto reject_w_econnreset;
        }
        cid = le16_to_cpu(msg->cid);
        if (cid >= con_num) {
                /* Sanity check */
                pr_err("Incorrect cid: %d >= %d\n", cid, con_num);
                goto reject_w_econnreset;
        }
        recon_cnt = le16_to_cpu(msg->recon_cnt);
        srv = get_or_create_srv(ctx, &msg->paths_uuid);
        if (!srv) {
                err = -ENOMEM;
                goto reject_w_err;
        }
        mutex_lock(&srv->paths_mutex);
        sess = __find_sess(srv, &msg->sess_uuid);
        if (sess) {
                struct rtrs_sess *s = &sess->s;

                /* Session already holds a reference */
                put_srv(srv);

                if (sess->state != RTRS_SRV_CONNECTING) {
                        rtrs_err(s, "Session in wrong state: %s\n",
                                  rtrs_srv_state_str(sess->state));
                        mutex_unlock(&srv->paths_mutex);
                        goto reject_w_econnreset;
                }
                /* Sanity checks */
                if (con_num != s->con_num || cid >= s->con_num) {
                        rtrs_err(s, "Incorrect request: %d, %d\n",
                                  cid, con_num);
                        mutex_unlock(&srv->paths_mutex);
                        goto reject_w_econnreset;
                }
                if (s->con[cid]) {
                        rtrs_err(s, "Connection already exists: %d\n",
                                  cid);
                        mutex_unlock(&srv->paths_mutex);
                        goto reject_w_econnreset;
                }
        } else {
                sess = __alloc_sess(srv, cm_id, con_num, recon_cnt,
                                    &msg->sess_uuid);
                if (IS_ERR(sess)) {
                        mutex_unlock(&srv->paths_mutex);
                        put_srv(srv);
                        err = PTR_ERR(sess);
                        goto reject_w_err;
                }
        }
        err = create_con(sess, cm_id, cid);
        if (err) {
                (void)rtrs_rdma_do_reject(cm_id, err);
                /*
                 * The session may have other connections, so follow the
                 * normal teardown path through the workqueue, but still
                 * return an error so that cma.c calls rdma_destroy_id()
                 * for the current connection.
                 */
                goto close_and_return_err;
        }
        err = rtrs_rdma_do_accept(sess, cm_id);
        if (err) {
                (void)rtrs_rdma_do_reject(cm_id, err);
                /*
                 * The current connection was successfully added to the
                 * session, so follow the normal teardown path through the
                 * workqueue to close the session; return 0 to tell cma.c
                 * that we will call rdma_destroy_id() ourselves.
                 */
                err = 0;
                goto close_and_return_err;
        }
        mutex_unlock(&srv->paths_mutex);

        return 0;

reject_w_err:
        return rtrs_rdma_do_reject(cm_id, err);

reject_w_econnreset:
        return rtrs_rdma_do_reject(cm_id, -ECONNRESET);

close_and_return_err:
        close_sess(sess);
        mutex_unlock(&srv->paths_mutex);

        return err;
}

static int rtrs_srv_rdma_cm_handler(struct rdma_cm_id *cm_id,
                                     struct rdma_cm_event *ev)
{
        struct rtrs_srv_sess *sess = NULL;
        struct rtrs_sess *s = NULL;

        if (ev->event != RDMA_CM_EVENT_CONNECT_REQUEST) {
                struct rtrs_con *c = cm_id->context;

                s = c->sess;
                sess = to_srv_sess(s);
        }

        switch (ev->event) {
        case RDMA_CM_EVENT_CONNECT_REQUEST:
                /*
                 * In case of error cma.c will destroy cm_id,
                 * see cma_process_remove()
                 */
                return rtrs_rdma_connect(cm_id, ev->param.conn.private_data,
                                          ev->param.conn.private_data_len);
        case RDMA_CM_EVENT_ESTABLISHED:
                /* Nothing here */
                break;
        case RDMA_CM_EVENT_REJECTED:
        case RDMA_CM_EVENT_CONNECT_ERROR:
        case RDMA_CM_EVENT_UNREACHABLE:
                rtrs_err(s, "CM error (CM event: %s, err: %d)\n",
                          rdma_event_msg(ev->event), ev->status);
                close_sess(sess);
                break;
        case RDMA_CM_EVENT_DISCONNECTED:
        case RDMA_CM_EVENT_ADDR_CHANGE:
        case RDMA_CM_EVENT_TIMEWAIT_EXIT:
        case RDMA_CM_EVENT_DEVICE_REMOVAL:
                close_sess(sess);
                break;
        default:
                pr_err("Ignoring unexpected CM event %s, err %d\n",
                       rdma_event_msg(ev->event), ev->status);
                break;
        }

        return 0;
}

static struct rdma_cm_id *rtrs_srv_cm_init(struct rtrs_srv_ctx *ctx,
                                            struct sockaddr *addr,
                                            enum rdma_ucm_port_space ps)
{
        struct rdma_cm_id *cm_id;
        int ret;

        cm_id = rdma_create_id(&init_net, rtrs_srv_rdma_cm_handler,
                               ctx, ps, IB_QPT_RC);
        if (IS_ERR(cm_id)) {
                ret = PTR_ERR(cm_id);
                pr_err("Creating id for RDMA connection failed, err: %d\n",
                       ret);
                goto err_out;
        }
        ret = rdma_bind_addr(cm_id, addr);
        if (ret) {
                pr_err("Binding RDMA address failed, err: %d\n", ret);
                goto err_cm;
        }
        ret = rdma_listen(cm_id, 64);
        if (ret) {
                pr_err("Listening on RDMA connection failed, err: %d\n",
                       ret);
                goto err_cm;
        }

        return cm_id;

err_cm:
        rdma_destroy_id(cm_id);
err_out:
        return ERR_PTR(ret);
}

static int rtrs_srv_rdma_init(struct rtrs_srv_ctx *ctx, u16 port)
{
        struct sockaddr_in6 sin = {
                .sin6_family    = AF_INET6,
                .sin6_addr      = IN6ADDR_ANY_INIT,
                .sin6_port      = htons(port),
        };
        struct sockaddr_ib sib = {
                .sib_family     = AF_IB,
                .sib_sid        = cpu_to_be64(RDMA_IB_IP_PS_IB | port),
                .sib_sid_mask   = cpu_to_be64(0xffffffffffffffffULL),
                .sib_pkey       = cpu_to_be16(0xffff),
        };
        struct rdma_cm_id *cm_ip, *cm_ib;
        int ret;

        /*
         * We accept both IPoIB and IB connections, so we need to keep
         * two cm_ids, one for each socket type and port space.
         * If the cm initialization of one of the ids fails, we abort
         * everything.
         */
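        /*
         * Illustrative note, not extra behaviour: because sib_sid_mask
         * above is all ones, the AF_IB listener matches only the exact
         * service ID RDMA_IB_IP_PS_IB | port, while binding the
         * RDMA_PS_TCP listener to the IPv6 wildcard address also covers
         * IPv4 clients (assuming the usual mapped-address behaviour of
         * rdma_bind_addr()).
         */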
        cm_ip = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sin, RDMA_PS_TCP);
        if (IS_ERR(cm_ip))
                return PTR_ERR(cm_ip);

        cm_ib = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sib, RDMA_PS_IB);
        if (IS_ERR(cm_ib)) {
                ret = PTR_ERR(cm_ib);
                goto free_cm_ip;
        }

        ctx->cm_id_ip = cm_ip;
        ctx->cm_id_ib = cm_ib;

        return 0;

free_cm_ip:
        rdma_destroy_id(cm_ip);

        return ret;
}

static struct rtrs_srv_ctx *alloc_srv_ctx(struct rtrs_srv_ops *ops)
{
        struct rtrs_srv_ctx *ctx;

        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
        if (!ctx)
                return NULL;

        ctx->ops = *ops;
        mutex_init(&ctx->srv_mutex);
        INIT_LIST_HEAD(&ctx->srv_list);

        return ctx;
}

static void free_srv_ctx(struct rtrs_srv_ctx *ctx)
{
        WARN_ON(!list_empty(&ctx->srv_list));
        mutex_destroy(&ctx->srv_mutex);
        kfree(ctx);
}

/**
 * rtrs_srv_open() - open RTRS server context
 * @ops:                callback functions
 * @port:               port to listen on
 *
 * Creates server context with specified callbacks.
 *
 * Return: a valid pointer on success, an ERR_PTR() on failure.
 */
struct rtrs_srv_ctx *rtrs_srv_open(struct rtrs_srv_ops *ops, u16 port)
{
        struct rtrs_srv_ctx *ctx;
        int err;

        ctx = alloc_srv_ctx(ops);
        if (!ctx)
                return ERR_PTR(-ENOMEM);

        err = rtrs_srv_rdma_init(ctx, port);
        if (err) {
                free_srv_ctx(ctx);
                return ERR_PTR(err);
        }

        return ctx;
}
EXPORT_SYMBOL(rtrs_srv_open);
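
/*
 * Minimal usage sketch for a ULP (illustrative only; my_rdma_ev and
 * my_link_ev are hypothetical placeholders, not part of rtrs):
 *
 *	static struct rtrs_srv_ops my_ops = {
 *		.rdma_ev = my_rdma_ev,
 *		.link_ev = my_link_ev,
 *	};
 *
 *	ctx = rtrs_srv_open(&my_ops, port);
 *	if (IS_ERR(ctx))
 *		return PTR_ERR(ctx);
 *	...
 *	rtrs_srv_close(ctx);
 *
 * The in-tree user, rnbd-srv, follows this pattern when it registers
 * with rtrs.
 */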

static void close_sessions(struct rtrs_srv *srv)
{
        struct rtrs_srv_sess *sess;

        mutex_lock(&srv->paths_mutex);
        list_for_each_entry(sess, &srv->paths_list, s.entry)
                close_sess(sess);
        mutex_unlock(&srv->paths_mutex);
}

static void close_ctx(struct rtrs_srv_ctx *ctx)
{
        struct rtrs_srv *srv;

        mutex_lock(&ctx->srv_mutex);
        list_for_each_entry(srv, &ctx->srv_list, ctx_list)
                close_sessions(srv);
        mutex_unlock(&ctx->srv_mutex);
        flush_workqueue(rtrs_wq);
}

/**
 * rtrs_srv_close() - close RTRS server context
 * @ctx: pointer to server context
 *
 * Closes RTRS server context with all client sessions.
 */
void rtrs_srv_close(struct rtrs_srv_ctx *ctx)
{
        rdma_destroy_id(ctx->cm_id_ip);
        rdma_destroy_id(ctx->cm_id_ib);
        close_ctx(ctx);
        free_srv_ctx(ctx);
}
EXPORT_SYMBOL(rtrs_srv_close);

static int check_module_params(void)
{
        if (sess_queue_depth < 1 || sess_queue_depth > MAX_SESS_QUEUE_DEPTH) {
                pr_err("Invalid sess_queue_depth value %d, has to be >= %d, <= %d.\n",
                       sess_queue_depth, 1, MAX_SESS_QUEUE_DEPTH);
                return -EINVAL;
        }
        if (max_chunk_size < 4096 || !is_power_of_2(max_chunk_size)) {
                pr_err("Invalid max_chunk_size value %d, has to be >= %d and a power of two.\n",
                       max_chunk_size, 4096);
                return -EINVAL;
        }

        /*
         * Check that the IB immediate data size is big enough to hold the
         * mem_id and the offset inside the memory chunk.
         */
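        /*
         * Worked example with the defaults (a sketch; assumes
         * MAX_IMM_PAYL_BITS is 28 as defined in rtrs-pr.h):
         * sess_queue_depth = 512 needs ilog2(511) + 1 = 9 bits and
         * max_chunk_size = 128 KiB needs ilog2(131071) + 1 = 17 bits,
         * so 9 + 17 = 26 bits fit into the immediate payload.
         */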
        if ((ilog2(sess_queue_depth - 1) + 1) +
            (ilog2(max_chunk_size - 1) + 1) > MAX_IMM_PAYL_BITS) {
                pr_err("RDMA immediate size (%db) not enough to encode %d buffers of size %dB. Reduce 'sess_queue_depth' or 'max_chunk_size' parameters.\n",
                       MAX_IMM_PAYL_BITS, sess_queue_depth, max_chunk_size);
                return -EINVAL;
        }

        return 0;
}

static int __init rtrs_server_init(void)
{
        int err;

        pr_info("Loading module %s, proto %s: (max_chunk_size: %d (pure IO %ld, headers %ld), sess_queue_depth: %d, always_invalidate: %d)\n",
                KBUILD_MODNAME, RTRS_PROTO_VER_STRING,
                max_chunk_size, max_chunk_size - MAX_HDR_SIZE, MAX_HDR_SIZE,
                sess_queue_depth, always_invalidate);

        rtrs_rdma_dev_pd_init(0, &dev_pd);

        err = check_module_params();
        if (err) {
                pr_err("Failed to load module, invalid module parameters, err: %d\n",
                       err);
                return err;
        }
        chunk_pool = mempool_create_page_pool(sess_queue_depth * CHUNK_POOL_SZ,
                                              get_order(max_chunk_size));
        if (!chunk_pool)
                return -ENOMEM;
        rtrs_dev_class = class_create(THIS_MODULE, "rtrs-server");
        if (IS_ERR(rtrs_dev_class)) {
                err = PTR_ERR(rtrs_dev_class);
                goto out_chunk_pool;
        }
        rtrs_wq = alloc_workqueue("rtrs_server_wq", 0, 0);
        if (!rtrs_wq) {
                err = -ENOMEM;
                goto out_dev_class;
        }

        return 0;

out_dev_class:
        class_destroy(rtrs_dev_class);
out_chunk_pool:
        mempool_destroy(chunk_pool);

        return err;
}

static void __exit rtrs_server_exit(void)
{
        destroy_workqueue(rtrs_wq);
        class_destroy(rtrs_dev_class);
        mempool_destroy(chunk_pool);
        rtrs_rdma_dev_pd_deinit(&dev_pd);
}

module_init(rtrs_server_init);
module_exit(rtrs_server_exit);