drivers/vfio/pci/mlx5/main.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved
4  */
5
6 #include <linux/device.h>
7 #include <linux/eventfd.h>
8 #include <linux/file.h>
9 #include <linux/interrupt.h>
10 #include <linux/iommu.h>
11 #include <linux/module.h>
12 #include <linux/mutex.h>
13 #include <linux/notifier.h>
14 #include <linux/pci.h>
15 #include <linux/pm_runtime.h>
16 #include <linux/types.h>
17 #include <linux/uaccess.h>
18 #include <linux/vfio.h>
19 #include <linux/sched/mm.h>
20 #include <linux/anon_inodes.h>
21
22 #include "cmd.h"
23
24 /* Device specification max LOAD size */
25 #define MAX_LOAD_SIZE (BIT_ULL(__mlx5_bit_sz(load_vhca_state_in, size)) - 1)
26
27 #define MAX_CHUNK_SIZE SZ_8M
28
29 static struct mlx5vf_pci_core_device *mlx5vf_drvdata(struct pci_dev *pdev)
30 {
31         struct vfio_pci_core_device *core_device = dev_get_drvdata(&pdev->dev);
32
33         return container_of(core_device, struct mlx5vf_pci_core_device,
34                             core_device);
35 }
36
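/*
 * Translate a byte offset within the migration buffer into the page that
 * backs it. Accesses are sequential, so the scatterlist position of the
 * previous lookup is cached and the walk resumes from there.
 */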
37 struct page *
38 mlx5vf_get_migration_page(struct mlx5_vhca_data_buffer *buf,
39                           unsigned long offset)
40 {
41         unsigned long cur_offset = 0;
42         struct scatterlist *sg;
43         unsigned int i;
44
45         /* All accesses are sequential */
46         if (offset < buf->last_offset || !buf->last_offset_sg) {
47                 buf->last_offset = 0;
48                 buf->last_offset_sg = buf->table.sgt.sgl;
49                 buf->sg_last_entry = 0;
50         }
51
52         cur_offset = buf->last_offset;
53
54         for_each_sg(buf->last_offset_sg, sg,
55                         buf->table.sgt.orig_nents - buf->sg_last_entry, i) {
56                 if (offset < sg->length + cur_offset) {
57                         buf->last_offset_sg = sg;
58                         buf->sg_last_entry += i;
59                         buf->last_offset = cur_offset;
60                         return nth_page(sg_page(sg),
61                                         (offset - cur_offset) / PAGE_SIZE);
62                 }
63                 cur_offset += sg->length;
64         }
65         return NULL;
66 }
67
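/*
 * Allocate @npages additional pages and append them to the buffer's
 * scatter-gather table, growing allocated_length accordingly. Pages are
 * requested in bulk, at most one page worth of pointers at a time.
 */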
68 int mlx5vf_add_migration_pages(struct mlx5_vhca_data_buffer *buf,
69                                unsigned int npages)
70 {
71         unsigned int to_alloc = npages;
72         struct page **page_list;
73         unsigned long filled;
74         unsigned int to_fill;
75         int ret;
76
77         to_fill = min_t(unsigned int, npages, PAGE_SIZE / sizeof(*page_list));
78         page_list = kvzalloc(to_fill * sizeof(*page_list), GFP_KERNEL_ACCOUNT);
79         if (!page_list)
80                 return -ENOMEM;
81
82         do {
83                 filled = alloc_pages_bulk_array(GFP_KERNEL_ACCOUNT, to_fill,
84                                                 page_list);
85                 if (!filled) {
86                         ret = -ENOMEM;
87                         goto err;
88                 }
89                 to_alloc -= filled;
90                 ret = sg_alloc_append_table_from_pages(
91                         &buf->table, page_list, filled, 0,
92                         filled << PAGE_SHIFT, UINT_MAX, SG_MAX_SINGLE_ALLOC,
93                         GFP_KERNEL_ACCOUNT);
94
95                 if (ret)
96                         goto err;
97                 buf->allocated_length += filled * PAGE_SIZE;
98                 /* clear the input array for the next bulk allocation */
99                 memset(page_list, 0, filled * sizeof(*page_list));
100                 to_fill = min_t(unsigned int, to_alloc,
101                                 PAGE_SIZE / sizeof(*page_list));
102         } while (to_alloc > 0);
103
104         kvfree(page_list);
105         return 0;
106
107 err:
108         kvfree(page_list);
109         return ret;
110 }
111
112 static void mlx5vf_disable_fd(struct mlx5_vf_migration_file *migf)
113 {
114         mutex_lock(&migf->lock);
115         migf->state = MLX5_MIGF_STATE_ERROR;
116         migf->filp->f_pos = 0;
117         mutex_unlock(&migf->lock);
118 }
119
120 static int mlx5vf_release_file(struct inode *inode, struct file *filp)
121 {
122         struct mlx5_vf_migration_file *migf = filp->private_data;
123
124         mlx5vf_disable_fd(migf);
125         mutex_destroy(&migf->lock);
126         kfree(migf);
127         return 0;
128 }
129
130 static struct mlx5_vhca_data_buffer *
131 mlx5vf_get_data_buff_from_pos(struct mlx5_vf_migration_file *migf, loff_t pos,
132                               bool *end_of_data)
133 {
134         struct mlx5_vhca_data_buffer *buf;
135         bool found = false;
136
137         *end_of_data = false;
138         spin_lock_irq(&migf->list_lock);
139         if (list_empty(&migf->buf_list)) {
140                 *end_of_data = true;
141                 goto end;
142         }
143
144         buf = list_first_entry(&migf->buf_list, struct mlx5_vhca_data_buffer,
145                                buf_elm);
146         if (pos >= buf->start_pos &&
147             pos < buf->start_pos + buf->length) {
148                 found = true;
149                 goto end;
150         }
151
152         /*
153          * Since this is a stream-based FD, the data is expected to always
154          * be on the first chunk.
155          */
156         migf->state = MLX5_MIGF_STATE_ERROR;
157
158 end:
159         spin_unlock_irq(&migf->list_lock);
160         return found ? buf : NULL;
161 }
162
163 static void mlx5vf_buf_read_done(struct mlx5_vhca_data_buffer *vhca_buf)
164 {
165         struct mlx5_vf_migration_file *migf = vhca_buf->migf;
166
167         if (vhca_buf->stop_copy_chunk_num) {
168                 bool is_header = vhca_buf->dma_dir == DMA_NONE;
169                 u8 chunk_num = vhca_buf->stop_copy_chunk_num;
170                 size_t next_required_umem_size = 0;
171
172                 if (is_header)
173                         migf->buf_header[chunk_num - 1] = vhca_buf;
174                 else
175                         migf->buf[chunk_num - 1] = vhca_buf;
176
177                 spin_lock_irq(&migf->list_lock);
178                 list_del_init(&vhca_buf->buf_elm);
179                 if (!is_header) {
180                         next_required_umem_size =
181                                 migf->next_required_umem_size;
182                         migf->next_required_umem_size = 0;
183                         migf->num_ready_chunks--;
184                 }
185                 spin_unlock_irq(&migf->list_lock);
186                 if (next_required_umem_size)
187                         mlx5vf_mig_file_set_save_work(migf, chunk_num,
188                                                       next_required_umem_size);
189                 return;
190         }
191
192         spin_lock_irq(&migf->list_lock);
193         list_del_init(&vhca_buf->buf_elm);
194         list_add_tail(&vhca_buf->buf_elm, &vhca_buf->migf->avail_list);
195         spin_unlock_irq(&migf->list_lock);
196 }
197
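/*
 * Copy data from @vhca_buf to userspace, page by page, starting at *pos.
 * Once the buffer is fully consumed, mlx5vf_buf_read_done() either records
 * the stop-copy chunk completion or recycles the buffer onto avail_list.
 */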
198 static ssize_t mlx5vf_buf_read(struct mlx5_vhca_data_buffer *vhca_buf,
199                                char __user **buf, size_t *len, loff_t *pos)
200 {
201         unsigned long offset;
202         ssize_t done = 0;
203         size_t copy_len;
204
205         copy_len = min_t(size_t,
206                          vhca_buf->start_pos + vhca_buf->length - *pos, *len);
207         while (copy_len) {
208                 size_t page_offset;
209                 struct page *page;
210                 size_t page_len;
211                 u8 *from_buff;
212                 int ret;
213
214                 offset = *pos - vhca_buf->start_pos;
215                 page_offset = offset % PAGE_SIZE;
216                 offset -= page_offset;
217                 page = mlx5vf_get_migration_page(vhca_buf, offset);
218                 if (!page)
219                         return -EINVAL;
220                 page_len = min_t(size_t, copy_len, PAGE_SIZE - page_offset);
221                 from_buff = kmap_local_page(page);
222                 ret = copy_to_user(*buf, from_buff + page_offset, page_len);
223                 kunmap_local(from_buff);
224                 if (ret)
225                         return -EFAULT;
226                 *pos += page_len;
227                 *len -= page_len;
228                 *buf += page_len;
229                 done += page_len;
230                 copy_len -= page_len;
231         }
232
233         if (*pos >= vhca_buf->start_pos + vhca_buf->length)
234                 mlx5vf_buf_read_done(vhca_buf);
235
236         return done;
237 }
238
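/*
 * read() handler of the save file descriptor. Unless O_NONBLOCK is set,
 * wait until data is queued on buf_list or the migration file reaches a
 * terminal state, then copy the queued buffers to userspace in stream order.
 */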
239 static ssize_t mlx5vf_save_read(struct file *filp, char __user *buf, size_t len,
240                                loff_t *pos)
241 {
242         struct mlx5_vf_migration_file *migf = filp->private_data;
243         struct mlx5_vhca_data_buffer *vhca_buf;
244         bool first_loop_call = true;
245         bool end_of_data;
246         ssize_t done = 0;
247
248         if (pos)
249                 return -ESPIPE;
250         pos = &filp->f_pos;
251
252         if (!(filp->f_flags & O_NONBLOCK)) {
253                 if (wait_event_interruptible(migf->poll_wait,
254                                 !list_empty(&migf->buf_list) ||
255                                 migf->state == MLX5_MIGF_STATE_ERROR ||
256                                 migf->state == MLX5_MIGF_STATE_PRE_COPY_ERROR ||
257                                 migf->state == MLX5_MIGF_STATE_PRE_COPY ||
258                                 migf->state == MLX5_MIGF_STATE_COMPLETE))
259                         return -ERESTARTSYS;
260         }
261
262         mutex_lock(&migf->lock);
263         if (migf->state == MLX5_MIGF_STATE_ERROR) {
264                 done = -ENODEV;
265                 goto out_unlock;
266         }
267
268         while (len) {
269                 ssize_t count;
270
271                 vhca_buf = mlx5vf_get_data_buff_from_pos(migf, *pos,
272                                                          &end_of_data);
273                 if (first_loop_call) {
274                         first_loop_call = false;
275                         /* Temporary end of file as part of PRE_COPY */
276                         if (end_of_data && (migf->state == MLX5_MIGF_STATE_PRE_COPY ||
277                                 migf->state == MLX5_MIGF_STATE_PRE_COPY_ERROR)) {
278                                 done = -ENOMSG;
279                                 goto out_unlock;
280                         }
281
282                         if (end_of_data && migf->state != MLX5_MIGF_STATE_COMPLETE) {
283                                 if (filp->f_flags & O_NONBLOCK) {
284                                         done = -EAGAIN;
285                                         goto out_unlock;
286                                 }
287                         }
288                 }
289
290                 if (end_of_data)
291                         goto out_unlock;
292
293                 if (!vhca_buf) {
294                         done = -EINVAL;
295                         goto out_unlock;
296                 }
297
298                 count = mlx5vf_buf_read(vhca_buf, &buf, &len, pos);
299                 if (count < 0) {
300                         done = count;
301                         goto out_unlock;
302                 }
303                 done += count;
304         }
305
306 out_unlock:
307         mutex_unlock(&migf->lock);
308         return done;
309 }
310
311 static __poll_t mlx5vf_save_poll(struct file *filp,
312                                  struct poll_table_struct *wait)
313 {
314         struct mlx5_vf_migration_file *migf = filp->private_data;
315         __poll_t pollflags = 0;
316
317         poll_wait(filp, &migf->poll_wait, wait);
318
319         mutex_lock(&migf->lock);
320         if (migf->state == MLX5_MIGF_STATE_ERROR)
321                 pollflags = EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
322         else if (!list_empty(&migf->buf_list) ||
323                  migf->state == MLX5_MIGF_STATE_COMPLETE)
324                 pollflags = EPOLLIN | EPOLLRDNORM;
325         mutex_unlock(&migf->lock);
326
327         return pollflags;
328 }
329
330 /*
331  * The FD remains exposed and the user can use it after receiving an error.
332  * Mark migf as being in error, and wake the user.
333  */
334 static void mlx5vf_mark_err(struct mlx5_vf_migration_file *migf)
335 {
336         migf->state = MLX5_MIGF_STATE_ERROR;
337         wake_up_interruptible(&migf->poll_wait);
338 }
339
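/*
 * Queue a work item to save the next stop-copy chunk; a file reference is
 * taken here and dropped by the work handler.
 */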
340 void mlx5vf_mig_file_set_save_work(struct mlx5_vf_migration_file *migf,
341                                    u8 chunk_num, size_t next_required_umem_size)
342 {
343         migf->save_data[chunk_num - 1].next_required_umem_size =
344                         next_required_umem_size;
345         migf->save_data[chunk_num - 1].migf = migf;
346         get_file(migf->filp);
347         queue_work(migf->mvdev->cb_wq,
348                    &migf->save_data[chunk_num - 1].work);
349 }
350
351 static struct mlx5_vhca_data_buffer *
352 mlx5vf_mig_file_get_stop_copy_buf(struct mlx5_vf_migration_file *migf,
353                                   u8 index, size_t required_length)
354 {
355         struct mlx5_vhca_data_buffer *buf = migf->buf[index];
356         u8 chunk_num;
357
358         WARN_ON(!buf);
359         chunk_num = buf->stop_copy_chunk_num;
360         buf->migf->buf[index] = NULL;
361         /* Check whether the pre-allocated buffer is large enough */
362         if (buf->allocated_length >= required_length)
363                 return buf;
364
365         mlx5vf_put_data_buffer(buf);
366         buf = mlx5vf_get_data_buffer(buf->migf, required_length,
367                                      DMA_FROM_DEVICE);
368         if (IS_ERR(buf))
369                 return buf;
370
371         buf->stop_copy_chunk_num = chunk_num;
372         return buf;
373 }
374
375 static void mlx5vf_mig_file_save_work(struct work_struct *_work)
376 {
377         struct mlx5vf_save_work_data *save_data = container_of(_work,
378                 struct mlx5vf_save_work_data, work);
379         struct mlx5_vf_migration_file *migf = save_data->migf;
380         struct mlx5vf_pci_core_device *mvdev = migf->mvdev;
381         struct mlx5_vhca_data_buffer *buf;
382
383         mutex_lock(&mvdev->state_mutex);
384         if (migf->state == MLX5_MIGF_STATE_ERROR)
385                 goto end;
386
387         buf = mlx5vf_mig_file_get_stop_copy_buf(migf,
388                                 save_data->chunk_num - 1,
389                                 save_data->next_required_umem_size);
390         if (IS_ERR(buf))
391                 goto err;
392
393         if (mlx5vf_cmd_save_vhca_state(mvdev, migf, buf, true, false))
394                 goto err_save;
395
396         goto end;
397
398 err_save:
399         mlx5vf_put_data_buffer(buf);
400 err:
401         mlx5vf_mark_err(migf);
402 end:
403         mlx5vf_state_mutex_unlock(mvdev);
404         fput(migf->filp);
405 }
406
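/*
 * Queue the optional STOP_COPY_SIZE tagged record (header plus data) on the
 * stream so the destination can size its buffers for the stop-copy phase
 * ahead of time.
 */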
407 static int mlx5vf_add_stop_copy_header(struct mlx5_vf_migration_file *migf,
408                                        bool track)
409 {
410         size_t size = sizeof(struct mlx5_vf_migration_header) +
411                 sizeof(struct mlx5_vf_migration_tag_stop_copy_data);
412         struct mlx5_vf_migration_tag_stop_copy_data data = {};
413         struct mlx5_vhca_data_buffer *header_buf = NULL;
414         struct mlx5_vf_migration_header header = {};
415         unsigned long flags;
416         struct page *page;
417         u8 *to_buff;
418         int ret;
419
420         header_buf = mlx5vf_get_data_buffer(migf, size, DMA_NONE);
421         if (IS_ERR(header_buf))
422                 return PTR_ERR(header_buf);
423
424         header.record_size = cpu_to_le64(sizeof(data));
425         header.flags = cpu_to_le32(MLX5_MIGF_HEADER_FLAGS_TAG_OPTIONAL);
426         header.tag = cpu_to_le32(MLX5_MIGF_HEADER_TAG_STOP_COPY_SIZE);
427         page = mlx5vf_get_migration_page(header_buf, 0);
428         if (!page) {
429                 ret = -EINVAL;
430                 goto err;
431         }
432         to_buff = kmap_local_page(page);
433         memcpy(to_buff, &header, sizeof(header));
434         header_buf->length = sizeof(header);
435         data.stop_copy_size = cpu_to_le64(migf->buf[0]->allocated_length);
436         memcpy(to_buff + sizeof(header), &data, sizeof(data));
437         header_buf->length += sizeof(data);
438         kunmap_local(to_buff);
439         header_buf->start_pos = header_buf->migf->max_pos;
440         migf->max_pos += header_buf->length;
441         spin_lock_irqsave(&migf->list_lock, flags);
442         list_add_tail(&header_buf->buf_elm, &migf->buf_list);
443         spin_unlock_irqrestore(&migf->list_lock, flags);
444         if (track)
445                 migf->pre_copy_initial_bytes = size;
446         return 0;
447 err:
448         mlx5vf_put_data_buffer(header_buf);
449         return ret;
450 }
451
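/*
 * Pre-allocate the data and header buffers used for the stop-copy phase,
 * one pair per chunk when chunk mode is supported, and add the stop-copy
 * size record to the stream.
 */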
452 static int mlx5vf_prep_stop_copy(struct mlx5vf_pci_core_device *mvdev,
453                                  struct mlx5_vf_migration_file *migf,
454                                  size_t state_size, u64 full_size,
455                                  bool track)
456 {
457         struct mlx5_vhca_data_buffer *buf;
458         size_t inc_state_size;
459         int num_chunks;
460         int ret;
461         int i;
462
463         if (mvdev->chunk_mode) {
464                 size_t chunk_size = min_t(size_t, MAX_CHUNK_SIZE, full_size);
465
466                 /* From the firmware's perspective, a buffer of at least 'state_size' must be provided */
467                 inc_state_size = max(state_size, chunk_size);
468         } else {
469                 if (track) {
470                         /* be ready for a stop_copy size that might grow by 10 percent */
471                         if (check_add_overflow(state_size, state_size / 10, &inc_state_size))
472                                 inc_state_size = state_size;
473                 } else {
474                         inc_state_size = state_size;
475                 }
476         }
477
478         /* don't overflow the device specification's max SAVE size */
479         inc_state_size = min_t(size_t, inc_state_size,
480                 (BIT_ULL(__mlx5_bit_sz(save_vhca_state_in, size)) - PAGE_SIZE));
481
482         num_chunks = mvdev->chunk_mode ? MAX_NUM_CHUNKS : 1;
483         for (i = 0; i < num_chunks; i++) {
484                 buf = mlx5vf_get_data_buffer(migf, inc_state_size, DMA_FROM_DEVICE);
485                 if (IS_ERR(buf)) {
486                         ret = PTR_ERR(buf);
487                         goto err;
488                 }
489
490                 migf->buf[i] = buf;
491                 buf = mlx5vf_get_data_buffer(migf,
492                                 sizeof(struct mlx5_vf_migration_header), DMA_NONE);
493                 if (IS_ERR(buf)) {
494                         ret = PTR_ERR(buf);
495                         goto err;
496                 }
497                 migf->buf_header[i] = buf;
498                 if (mvdev->chunk_mode) {
499                         migf->buf[i]->stop_copy_chunk_num = i + 1;
500                         migf->buf_header[i]->stop_copy_chunk_num = i + 1;
501                         INIT_WORK(&migf->save_data[i].work,
502                                   mlx5vf_mig_file_save_work);
503                         migf->save_data[i].chunk_num = i + 1;
504                 }
505         }
506
507         ret = mlx5vf_add_stop_copy_header(migf, track);
508         if (ret)
509                 goto err;
510         return 0;
511
512 err:
513         for (i = 0; i < num_chunks; i++) {
514                 if (migf->buf[i]) {
515                         mlx5vf_put_data_buffer(migf->buf[i]);
516                         migf->buf[i] = NULL;
517                 }
518                 if (migf->buf_header[i]) {
519                         mlx5vf_put_data_buffer(migf->buf_header[i]);
520                         migf->buf_header[i] = NULL;
521                 }
522         }
523
524         return ret;
525 }
526
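/*
 * VFIO_MIG_GET_PRECOPY_INFO handler: report the remaining initial/dirty
 * bytes and, once the already-queued data has been fully read, kick off
 * saving a new incremental state.
 */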
527 static long mlx5vf_precopy_ioctl(struct file *filp, unsigned int cmd,
528                                  unsigned long arg)
529 {
530         struct mlx5_vf_migration_file *migf = filp->private_data;
531         struct mlx5vf_pci_core_device *mvdev = migf->mvdev;
532         struct mlx5_vhca_data_buffer *buf;
533         struct vfio_precopy_info info = {};
534         loff_t *pos = &filp->f_pos;
535         unsigned long minsz;
536         size_t inc_length = 0;
537         bool end_of_data = false;
538         int ret;
539
540         if (cmd != VFIO_MIG_GET_PRECOPY_INFO)
541                 return -ENOTTY;
542
543         minsz = offsetofend(struct vfio_precopy_info, dirty_bytes);
544
545         if (copy_from_user(&info, (void __user *)arg, minsz))
546                 return -EFAULT;
547
548         if (info.argsz < minsz)
549                 return -EINVAL;
550
551         mutex_lock(&mvdev->state_mutex);
552         if (mvdev->mig_state != VFIO_DEVICE_STATE_PRE_COPY &&
553             mvdev->mig_state != VFIO_DEVICE_STATE_PRE_COPY_P2P) {
554                 ret = -EINVAL;
555                 goto err_state_unlock;
556         }
557
558         /*
559          * We can't issue a SAVE command when the device is suspended, so in
560          * VFIO_DEVICE_STATE_PRE_COPY_P2P there is no reason to query for extra
561          * bytes that can't be read.
562          */
563         if (mvdev->mig_state == VFIO_DEVICE_STATE_PRE_COPY) {
564                 /*
565                  * Once the query returns, it's guaranteed that there is no
566                  * active SAVE command.
567                  * As such, the code below is safe with the proper locks held.
568                  */
569                 ret = mlx5vf_cmd_query_vhca_migration_state(mvdev, &inc_length,
570                                                             NULL, MLX5VF_QUERY_INC);
571                 if (ret)
572                         goto err_state_unlock;
573         }
574
575         mutex_lock(&migf->lock);
576         if (migf->state == MLX5_MIGF_STATE_ERROR) {
577                 ret = -ENODEV;
578                 goto err_migf_unlock;
579         }
580
581         if (migf->pre_copy_initial_bytes > *pos) {
582                 info.initial_bytes = migf->pre_copy_initial_bytes - *pos;
583         } else {
584                 info.dirty_bytes = migf->max_pos - *pos;
585                 if (!info.dirty_bytes)
586                         end_of_data = true;
587                 info.dirty_bytes += inc_length;
588         }
589
590         if (!end_of_data || !inc_length) {
591                 mutex_unlock(&migf->lock);
592                 goto done;
593         }
594
595         mutex_unlock(&migf->lock);
596         /*
597          * We finished transferring the current state and the device has
598          * dirty state; save a new state so we are ready for it.
599          */
600         buf = mlx5vf_get_data_buffer(migf, inc_length, DMA_FROM_DEVICE);
601         if (IS_ERR(buf)) {
602                 ret = PTR_ERR(buf);
603                 mlx5vf_mark_err(migf);
604                 goto err_state_unlock;
605         }
606
607         ret = mlx5vf_cmd_save_vhca_state(mvdev, migf, buf, true, true);
608         if (ret) {
609                 mlx5vf_mark_err(migf);
610                 mlx5vf_put_data_buffer(buf);
611                 goto err_state_unlock;
612         }
613
614 done:
615         mlx5vf_state_mutex_unlock(mvdev);
616         if (copy_to_user((void __user *)arg, &info, minsz))
617                 return -EFAULT;
618         return 0;
619
620 err_migf_unlock:
621         mutex_unlock(&migf->lock);
622 err_state_unlock:
623         mlx5vf_state_mutex_unlock(mvdev);
624         return ret;
625 }
626
627 static const struct file_operations mlx5vf_save_fops = {
628         .owner = THIS_MODULE,
629         .read = mlx5vf_save_read,
630         .poll = mlx5vf_save_poll,
631         .unlocked_ioctl = mlx5vf_precopy_ioctl,
632         .compat_ioctl = compat_ptr_ioctl,
633         .release = mlx5vf_release_file,
634         .llseek = no_llseek,
635 };
636
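/*
 * Save the final stop-copy data once the VHCA has been suspended, reusing
 * (or resizing) the pre-allocated buffer of the first stop-copy chunk.
 */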
637 static int mlx5vf_pci_save_device_inc_data(struct mlx5vf_pci_core_device *mvdev)
638 {
639         struct mlx5_vf_migration_file *migf = mvdev->saving_migf;
640         struct mlx5_vhca_data_buffer *buf;
641         size_t length;
642         int ret;
643
644         if (migf->state == MLX5_MIGF_STATE_ERROR)
645                 return -ENODEV;
646
647         ret = mlx5vf_cmd_query_vhca_migration_state(mvdev, &length, NULL,
648                                 MLX5VF_QUERY_INC | MLX5VF_QUERY_FINAL);
649         if (ret)
650                 goto err;
651
652         buf = mlx5vf_mig_file_get_stop_copy_buf(migf, 0, length);
653         if (IS_ERR(buf)) {
654                 ret = PTR_ERR(buf);
655                 goto err;
656         }
657
658         ret = mlx5vf_cmd_save_vhca_state(mvdev, migf, buf, true, false);
659         if (ret)
660                 goto err_save;
661
662         return 0;
663
664 err_save:
665         mlx5vf_put_data_buffer(buf);
666 err:
667         mlx5vf_mark_err(migf);
668         return ret;
669 }
670
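/*
 * Create the save migration file: allocate the anonymous FD, PD and
 * stop-copy buffers, then issue the initial SAVE command (pre-copy when
 * @track is set, otherwise the full stop-copy image).
 */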
671 static struct mlx5_vf_migration_file *
672 mlx5vf_pci_save_device_data(struct mlx5vf_pci_core_device *mvdev, bool track)
673 {
674         struct mlx5_vf_migration_file *migf;
675         struct mlx5_vhca_data_buffer *buf;
676         size_t length;
677         u64 full_size;
678         int ret;
679
680         migf = kzalloc(sizeof(*migf), GFP_KERNEL_ACCOUNT);
681         if (!migf)
682                 return ERR_PTR(-ENOMEM);
683
684         migf->filp = anon_inode_getfile("mlx5vf_mig", &mlx5vf_save_fops, migf,
685                                         O_RDONLY);
686         if (IS_ERR(migf->filp)) {
687                 ret = PTR_ERR(migf->filp);
688                 goto end;
689         }
690
691         migf->mvdev = mvdev;
692         ret = mlx5vf_cmd_alloc_pd(migf);
693         if (ret)
694                 goto out_free;
695
696         stream_open(migf->filp->f_inode, migf->filp);
697         mutex_init(&migf->lock);
698         init_waitqueue_head(&migf->poll_wait);
699         init_completion(&migf->save_comp);
700         /*
701          * save_comp is being used as a binary semaphore built from
702          * a completion. A normal mutex cannot be used because the lock is
703          * passed between kernel threads and lockdep can't model this.
704          */
705         complete(&migf->save_comp);
706         mlx5_cmd_init_async_ctx(mvdev->mdev, &migf->async_ctx);
707         INIT_WORK(&migf->async_data.work, mlx5vf_mig_file_cleanup_cb);
708         INIT_LIST_HEAD(&migf->buf_list);
709         INIT_LIST_HEAD(&migf->avail_list);
710         spin_lock_init(&migf->list_lock);
711         ret = mlx5vf_cmd_query_vhca_migration_state(mvdev, &length, &full_size, 0);
712         if (ret)
713                 goto out_pd;
714
715         ret = mlx5vf_prep_stop_copy(mvdev, migf, length, full_size, track);
716         if (ret)
717                 goto out_pd;
718
719         if (track) {
720                 /* leave the allocated buffer ready for the stop-copy phase */
721                 buf = mlx5vf_alloc_data_buffer(migf,
722                         migf->buf[0]->allocated_length, DMA_FROM_DEVICE);
723                 if (IS_ERR(buf)) {
724                         ret = PTR_ERR(buf);
725                         goto out_pd;
726                 }
727         } else {
728                 buf = migf->buf[0];
729                 migf->buf[0] = NULL;
730         }
731
732         ret = mlx5vf_cmd_save_vhca_state(mvdev, migf, buf, false, track);
733         if (ret)
734                 goto out_save;
735         return migf;
736 out_save:
737         mlx5vf_free_data_buffer(buf);
738 out_pd:
739         mlx5fv_cmd_clean_migf_resources(migf);
740 out_free:
741         fput(migf->filp);
742 end:
743         kfree(migf);
744         return ERR_PTR(ret);
745 }
746
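/*
 * Copy at most one page of user data into the migration buffer at *pos and
 * advance the caller's cursors.
 */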
747 static int
748 mlx5vf_append_page_to_mig_buf(struct mlx5_vhca_data_buffer *vhca_buf,
749                               const char __user **buf, size_t *len,
750                               loff_t *pos, ssize_t *done)
751 {
752         unsigned long offset;
753         size_t page_offset;
754         struct page *page;
755         size_t page_len;
756         u8 *to_buff;
757         int ret;
758
759         offset = *pos - vhca_buf->start_pos;
760         page_offset = offset % PAGE_SIZE;
761
762         page = mlx5vf_get_migration_page(vhca_buf, offset - page_offset);
763         if (!page)
764                 return -EINVAL;
765         page_len = min_t(size_t, *len, PAGE_SIZE - page_offset);
766         to_buff = kmap_local_page(page);
767         ret = copy_from_user(to_buff + page_offset, *buf, page_len);
768         kunmap_local(to_buff);
769         if (ret)
770                 return -EFAULT;
771
772         *pos += page_len;
773         *done += page_len;
774         *buf += page_len;
775         *len -= page_len;
776         vhca_buf->length += page_len;
777         return 0;
778 }
779
780 static int
781 mlx5vf_resume_read_image_no_header(struct mlx5_vhca_data_buffer *vhca_buf,
782                                    loff_t requested_length,
783                                    const char __user **buf, size_t *len,
784                                    loff_t *pos, ssize_t *done)
785 {
786         int ret;
787
788         if (requested_length > MAX_LOAD_SIZE)
789                 return -ENOMEM;
790
791         if (vhca_buf->allocated_length < requested_length) {
792                 ret = mlx5vf_add_migration_pages(
793                         vhca_buf,
794                         DIV_ROUND_UP(requested_length - vhca_buf->allocated_length,
795                                      PAGE_SIZE));
796                 if (ret)
797                         return ret;
798         }
799
800         while (*len) {
801                 ret = mlx5vf_append_page_to_mig_buf(vhca_buf, buf, len, pos,
802                                                     done);
803                 if (ret)
804                         return ret;
805         }
806
807         return 0;
808 }
809
810 static ssize_t
811 mlx5vf_resume_read_image(struct mlx5_vf_migration_file *migf,
812                          struct mlx5_vhca_data_buffer *vhca_buf,
813                          size_t image_size, const char __user **buf,
814                          size_t *len, loff_t *pos, ssize_t *done,
815                          bool *has_work)
816 {
817         size_t copy_len, to_copy;
818         int ret;
819
820         to_copy = min_t(size_t, *len, image_size - vhca_buf->length);
821         copy_len = to_copy;
822         while (to_copy) {
823                 ret = mlx5vf_append_page_to_mig_buf(vhca_buf, buf, &to_copy, pos,
824                                                     done);
825                 if (ret)
826                         return ret;
827         }
828
829         *len -= copy_len;
830         if (vhca_buf->length == image_size) {
831                 migf->load_state = MLX5_VF_LOAD_STATE_LOAD_IMAGE;
832                 migf->max_pos += image_size;
833                 *has_work = true;
834         }
835
836         return 0;
837 }
838
839 static int
840 mlx5vf_resume_read_header_data(struct mlx5_vf_migration_file *migf,
841                                struct mlx5_vhca_data_buffer *vhca_buf,
842                                const char __user **buf, size_t *len,
843                                loff_t *pos, ssize_t *done)
844 {
845         size_t copy_len, to_copy;
846         size_t required_data;
847         u8 *to_buff;
848         int ret;
849
850         required_data = migf->record_size - vhca_buf->length;
851         to_copy = min_t(size_t, *len, required_data);
852         copy_len = to_copy;
853         while (to_copy) {
854                 ret = mlx5vf_append_page_to_mig_buf(vhca_buf, buf, &to_copy, pos,
855                                                     done);
856                 if (ret)
857                         return ret;
858         }
859
860         *len -= copy_len;
861         if (vhca_buf->length == migf->record_size) {
862                 switch (migf->record_tag) {
863                 case MLX5_MIGF_HEADER_TAG_STOP_COPY_SIZE:
864                 {
865                         struct page *page;
866
867                         page = mlx5vf_get_migration_page(vhca_buf, 0);
868                         if (!page)
869                                 return -EINVAL;
870                         to_buff = kmap_local_page(page);
871                         migf->stop_copy_prep_size = min_t(u64,
872                                 le64_to_cpup((__le64 *)to_buff), MAX_LOAD_SIZE);
873                         kunmap_local(to_buff);
874                         break;
875                 }
876                 default:
877                         /* Optional tag */
878                         break;
879                 }
880
881                 migf->load_state = MLX5_VF_LOAD_STATE_READ_HEADER;
882                 migf->max_pos += migf->record_size;
883                 vhca_buf->length = 0;
884         }
885
886         return 0;
887 }
888
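/*
 * Accumulate a record header from userspace and, once complete, pick the
 * next load state based on the record's tag.
 */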
889 static int
890 mlx5vf_resume_read_header(struct mlx5_vf_migration_file *migf,
891                           struct mlx5_vhca_data_buffer *vhca_buf,
892                           const char __user **buf,
893                           size_t *len, loff_t *pos,
894                           ssize_t *done, bool *has_work)
895 {
896         struct page *page;
897         size_t copy_len;
898         u8 *to_buff;
899         int ret;
900
901         copy_len = min_t(size_t, *len,
902                 sizeof(struct mlx5_vf_migration_header) - vhca_buf->length);
903         page = mlx5vf_get_migration_page(vhca_buf, 0);
904         if (!page)
905                 return -EINVAL;
906         to_buff = kmap_local_page(page);
907         ret = copy_from_user(to_buff + vhca_buf->length, *buf, copy_len);
908         if (ret) {
909                 ret = -EFAULT;
910                 goto end;
911         }
912
913         *buf += copy_len;
914         *pos += copy_len;
915         *done += copy_len;
916         *len -= copy_len;
917         vhca_buf->length += copy_len;
918         if (vhca_buf->length == sizeof(struct mlx5_vf_migration_header)) {
919                 u64 record_size;
920                 u32 flags;
921
922                 record_size = le64_to_cpup((__le64 *)to_buff);
923                 if (record_size > MAX_LOAD_SIZE) {
924                         ret = -ENOMEM;
925                         goto end;
926                 }
927
928                 migf->record_size = record_size;
929                 flags = le32_to_cpup((__le32 *)(to_buff +
930                             offsetof(struct mlx5_vf_migration_header, flags)));
931                 migf->record_tag = le32_to_cpup((__le32 *)(to_buff +
932                             offsetof(struct mlx5_vf_migration_header, tag)));
933                 switch (migf->record_tag) {
934                 case MLX5_MIGF_HEADER_TAG_FW_DATA:
935                         migf->load_state = MLX5_VF_LOAD_STATE_PREP_IMAGE;
936                         break;
937                 case MLX5_MIGF_HEADER_TAG_STOP_COPY_SIZE:
938                         migf->load_state = MLX5_VF_LOAD_STATE_PREP_HEADER_DATA;
939                         break;
940                 default:
941                         if (!(flags & MLX5_MIGF_HEADER_FLAGS_TAG_OPTIONAL)) {
942                                 ret = -EOPNOTSUPP;
943                                 goto end;
944                         }
945                         /* We may read and skip this optional record data */
946                         migf->load_state = MLX5_VF_LOAD_STATE_PREP_HEADER_DATA;
947                 }
948
949                 migf->max_pos += vhca_buf->length;
950                 vhca_buf->length = 0;
951                 *has_work = true;
952         }
953 end:
954         kunmap_local(to_buff);
955         return ret;
956 }
957
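/*
 * write() handler of the resume file descriptor: a state machine that
 * alternates between reading record headers, optional header data and image
 * data, and loads each completed image into the device.
 */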
958 static ssize_t mlx5vf_resume_write(struct file *filp, const char __user *buf,
959                                    size_t len, loff_t *pos)
960 {
961         struct mlx5_vf_migration_file *migf = filp->private_data;
962         struct mlx5_vhca_data_buffer *vhca_buf = migf->buf[0];
963         struct mlx5_vhca_data_buffer *vhca_buf_header = migf->buf_header[0];
964         loff_t requested_length;
965         bool has_work = false;
966         ssize_t done = 0;
967         int ret = 0;
968
969         if (pos)
970                 return -ESPIPE;
971         pos = &filp->f_pos;
972
973         if (*pos < 0 ||
974             check_add_overflow((loff_t)len, *pos, &requested_length))
975                 return -EINVAL;
976
977         mutex_lock(&migf->mvdev->state_mutex);
978         mutex_lock(&migf->lock);
979         if (migf->state == MLX5_MIGF_STATE_ERROR) {
980                 ret = -ENODEV;
981                 goto out_unlock;
982         }
983
984         while (len || has_work) {
985                 has_work = false;
986                 switch (migf->load_state) {
987                 case MLX5_VF_LOAD_STATE_READ_HEADER:
988                         ret = mlx5vf_resume_read_header(migf, vhca_buf_header,
989                                                         &buf, &len, pos,
990                                                         &done, &has_work);
991                         if (ret)
992                                 goto out_unlock;
993                         break;
994                 case MLX5_VF_LOAD_STATE_PREP_HEADER_DATA:
995                         if (vhca_buf_header->allocated_length < migf->record_size) {
996                                 mlx5vf_free_data_buffer(vhca_buf_header);
997
998                                 migf->buf_header[0] = mlx5vf_alloc_data_buffer(migf,
999                                                 migf->record_size, DMA_NONE);
1000                                 if (IS_ERR(migf->buf_header[0])) {
1001                                         ret = PTR_ERR(migf->buf_header[0]);
1002                                         migf->buf_header[0] = NULL;
1003                                         goto out_unlock;
1004                                 }
1005
1006                                 vhca_buf_header = migf->buf_header[0];
1007                         }
1008
1009                         vhca_buf_header->start_pos = migf->max_pos;
1010                         migf->load_state = MLX5_VF_LOAD_STATE_READ_HEADER_DATA;
1011                         break;
1012                 case MLX5_VF_LOAD_STATE_READ_HEADER_DATA:
1013                         ret = mlx5vf_resume_read_header_data(migf, vhca_buf_header,
1014                                                         &buf, &len, pos, &done);
1015                         if (ret)
1016                                 goto out_unlock;
1017                         break;
1018                 case MLX5_VF_LOAD_STATE_PREP_IMAGE:
1019                 {
1020                         u64 size = max(migf->record_size,
1021                                        migf->stop_copy_prep_size);
1022
1023                         if (vhca_buf->allocated_length < size) {
1024                                 mlx5vf_free_data_buffer(vhca_buf);
1025
1026                                 migf->buf[0] = mlx5vf_alloc_data_buffer(migf,
1027                                                         size, DMA_TO_DEVICE);
1028                                 if (IS_ERR(migf->buf[0])) {
1029                                         ret = PTR_ERR(migf->buf[0]);
1030                                         migf->buf[0] = NULL;
1031                                         goto out_unlock;
1032                                 }
1033
1034                                 vhca_buf = migf->buf[0];
1035                         }
1036
1037                         vhca_buf->start_pos = migf->max_pos;
1038                         migf->load_state = MLX5_VF_LOAD_STATE_READ_IMAGE;
1039                         break;
1040                 }
1041                 case MLX5_VF_LOAD_STATE_READ_IMAGE_NO_HEADER:
1042                         ret = mlx5vf_resume_read_image_no_header(vhca_buf,
1043                                                 requested_length,
1044                                                 &buf, &len, pos, &done);
1045                         if (ret)
1046                                 goto out_unlock;
1047                         break;
1048                 case MLX5_VF_LOAD_STATE_READ_IMAGE:
1049                         ret = mlx5vf_resume_read_image(migf, vhca_buf,
1050                                                 migf->record_size,
1051                                                 &buf, &len, pos, &done, &has_work);
1052                         if (ret)
1053                                 goto out_unlock;
1054                         break;
1055                 case MLX5_VF_LOAD_STATE_LOAD_IMAGE:
1056                         ret = mlx5vf_cmd_load_vhca_state(migf->mvdev, migf, vhca_buf);
1057                         if (ret)
1058                                 goto out_unlock;
1059                         migf->load_state = MLX5_VF_LOAD_STATE_READ_HEADER;
1060
1061                         /* prep header buf for next image */
1062                         vhca_buf_header->length = 0;
1063                         /* prep data buf for next image */
1064                         vhca_buf->length = 0;
1065
1066                         break;
1067                 default:
1068                         break;
1069                 }
1070         }
1071
1072 out_unlock:
1073         if (ret)
1074                 migf->state = MLX5_MIGF_STATE_ERROR;
1075         mutex_unlock(&migf->lock);
1076         mlx5vf_state_mutex_unlock(migf->mvdev);
1077         return ret ? ret : done;
1078 }
1079
1080 static const struct file_operations mlx5vf_resume_fops = {
1081         .owner = THIS_MODULE,
1082         .write = mlx5vf_resume_write,
1083         .release = mlx5vf_release_file,
1084         .llseek = no_llseek,
1085 };
1086
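/*
 * Create the resume migration file and its initial buffers. Devices without
 * PRE_COPY support receive a raw image with no record headers.
 */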
1087 static struct mlx5_vf_migration_file *
1088 mlx5vf_pci_resume_device_data(struct mlx5vf_pci_core_device *mvdev)
1089 {
1090         struct mlx5_vf_migration_file *migf;
1091         struct mlx5_vhca_data_buffer *buf;
1092         int ret;
1093
1094         migf = kzalloc(sizeof(*migf), GFP_KERNEL_ACCOUNT);
1095         if (!migf)
1096                 return ERR_PTR(-ENOMEM);
1097
1098         migf->filp = anon_inode_getfile("mlx5vf_mig", &mlx5vf_resume_fops, migf,
1099                                         O_WRONLY);
1100         if (IS_ERR(migf->filp)) {
1101                 ret = PTR_ERR(migf->filp);
1102                 goto end;
1103         }
1104
1105         migf->mvdev = mvdev;
1106         ret = mlx5vf_cmd_alloc_pd(migf);
1107         if (ret)
1108                 goto out_free;
1109
1110         buf = mlx5vf_alloc_data_buffer(migf, 0, DMA_TO_DEVICE);
1111         if (IS_ERR(buf)) {
1112                 ret = PTR_ERR(buf);
1113                 goto out_pd;
1114         }
1115
1116         migf->buf[0] = buf;
1117         if (MLX5VF_PRE_COPY_SUPP(mvdev)) {
1118                 buf = mlx5vf_alloc_data_buffer(migf,
1119                         sizeof(struct mlx5_vf_migration_header), DMA_NONE);
1120                 if (IS_ERR(buf)) {
1121                         ret = PTR_ERR(buf);
1122                         goto out_buf;
1123                 }
1124
1125                 migf->buf_header[0] = buf;
1126                 migf->load_state = MLX5_VF_LOAD_STATE_READ_HEADER;
1127         } else {
1128                 /* Initial state will be to read the image */
1129                 migf->load_state = MLX5_VF_LOAD_STATE_READ_IMAGE_NO_HEADER;
1130         }
1131
1132         stream_open(migf->filp->f_inode, migf->filp);
1133         mutex_init(&migf->lock);
1134         INIT_LIST_HEAD(&migf->buf_list);
1135         INIT_LIST_HEAD(&migf->avail_list);
1136         spin_lock_init(&migf->list_lock);
1137         return migf;
1138 out_buf:
1139         mlx5vf_free_data_buffer(migf->buf[0]);
1140 out_pd:
1141         mlx5vf_cmd_dealloc_pd(migf);
1142 out_free:
1143         fput(migf->filp);
1144 end:
1145         kfree(migf);
1146         return ERR_PTR(ret);
1147 }
1148
1149 void mlx5vf_disable_fds(struct mlx5vf_pci_core_device *mvdev)
1150 {
1151         if (mvdev->resuming_migf) {
1152                 mlx5vf_disable_fd(mvdev->resuming_migf);
1153                 mlx5fv_cmd_clean_migf_resources(mvdev->resuming_migf);
1154                 fput(mvdev->resuming_migf->filp);
1155                 mvdev->resuming_migf = NULL;
1156         }
1157         if (mvdev->saving_migf) {
1158                 mlx5_cmd_cleanup_async_ctx(&mvdev->saving_migf->async_ctx);
1159                 cancel_work_sync(&mvdev->saving_migf->async_data.work);
1160                 mlx5vf_disable_fd(mvdev->saving_migf);
1161                 wake_up_interruptible(&mvdev->saving_migf->poll_wait);
1162                 mlx5fv_cmd_clean_migf_resources(mvdev->saving_migf);
1163                 fput(mvdev->saving_migf->filp);
1164                 mvdev->saving_migf = NULL;
1165         }
1166 }
1167
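/*
 * Execute a single arc of the VFIO device migration state machine and
 * return a migration FD for the arcs that create one.
 */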
1168 static struct file *
1169 mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
1170                                     u32 new)
1171 {
1172         u32 cur = mvdev->mig_state;
1173         int ret;
1174
1175         if (cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_STOP) {
1176                 ret = mlx5vf_cmd_suspend_vhca(mvdev,
1177                         MLX5_SUSPEND_VHCA_IN_OP_MOD_SUSPEND_RESPONDER);
1178                 if (ret)
1179                         return ERR_PTR(ret);
1180                 return NULL;
1181         }
1182
1183         if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_RUNNING_P2P) {
1184                 ret = mlx5vf_cmd_resume_vhca(mvdev,
1185                         MLX5_RESUME_VHCA_IN_OP_MOD_RESUME_RESPONDER);
1186                 if (ret)
1187                         return ERR_PTR(ret);
1188                 return NULL;
1189         }
1190
1191         if ((cur == VFIO_DEVICE_STATE_RUNNING && new == VFIO_DEVICE_STATE_RUNNING_P2P) ||
1192             (cur == VFIO_DEVICE_STATE_PRE_COPY && new == VFIO_DEVICE_STATE_PRE_COPY_P2P)) {
1193                 ret = mlx5vf_cmd_suspend_vhca(mvdev,
1194                         MLX5_SUSPEND_VHCA_IN_OP_MOD_SUSPEND_INITIATOR);
1195                 if (ret)
1196                         return ERR_PTR(ret);
1197                 return NULL;
1198         }
1199
1200         if ((cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_RUNNING) ||
1201             (cur == VFIO_DEVICE_STATE_PRE_COPY_P2P && new == VFIO_DEVICE_STATE_PRE_COPY)) {
1202                 ret = mlx5vf_cmd_resume_vhca(mvdev,
1203                         MLX5_RESUME_VHCA_IN_OP_MOD_RESUME_INITIATOR);
1204                 if (ret)
1205                         return ERR_PTR(ret);
1206                 return NULL;
1207         }
1208
1209         if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_STOP_COPY) {
1210                 struct mlx5_vf_migration_file *migf;
1211
1212                 migf = mlx5vf_pci_save_device_data(mvdev, false);
1213                 if (IS_ERR(migf))
1214                         return ERR_CAST(migf);
1215                 get_file(migf->filp);
1216                 mvdev->saving_migf = migf;
1217                 return migf->filp;
1218         }
1219
1220         if ((cur == VFIO_DEVICE_STATE_STOP_COPY && new == VFIO_DEVICE_STATE_STOP) ||
1221             (cur == VFIO_DEVICE_STATE_PRE_COPY && new == VFIO_DEVICE_STATE_RUNNING) ||
1222             (cur == VFIO_DEVICE_STATE_PRE_COPY_P2P &&
1223              new == VFIO_DEVICE_STATE_RUNNING_P2P)) {
1224                 mlx5vf_disable_fds(mvdev);
1225                 return NULL;
1226         }
1227
1228         if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_RESUMING) {
1229                 struct mlx5_vf_migration_file *migf;
1230
1231                 migf = mlx5vf_pci_resume_device_data(mvdev);
1232                 if (IS_ERR(migf))
1233                         return ERR_CAST(migf);
1234                 get_file(migf->filp);
1235                 mvdev->resuming_migf = migf;
1236                 return migf->filp;
1237         }
1238
1239         if (cur == VFIO_DEVICE_STATE_RESUMING && new == VFIO_DEVICE_STATE_STOP) {
1240                 if (!MLX5VF_PRE_COPY_SUPP(mvdev)) {
1241                         ret = mlx5vf_cmd_load_vhca_state(mvdev,
1242                                                          mvdev->resuming_migf,
1243                                                          mvdev->resuming_migf->buf[0]);
1244                         if (ret)
1245                                 return ERR_PTR(ret);
1246                 }
1247                 mlx5vf_disable_fds(mvdev);
1248                 return NULL;
1249         }
1250
1251         if ((cur == VFIO_DEVICE_STATE_RUNNING && new == VFIO_DEVICE_STATE_PRE_COPY) ||
1252             (cur == VFIO_DEVICE_STATE_RUNNING_P2P &&
1253              new == VFIO_DEVICE_STATE_PRE_COPY_P2P)) {
1254                 struct mlx5_vf_migration_file *migf;
1255
1256                 migf = mlx5vf_pci_save_device_data(mvdev, true);
1257                 if (IS_ERR(migf))
1258                         return ERR_CAST(migf);
1259                 get_file(migf->filp);
1260                 mvdev->saving_migf = migf;
1261                 return migf->filp;
1262         }
1263
1264         if (cur == VFIO_DEVICE_STATE_PRE_COPY_P2P && new == VFIO_DEVICE_STATE_STOP_COPY) {
1265                 ret = mlx5vf_cmd_suspend_vhca(mvdev,
1266                         MLX5_SUSPEND_VHCA_IN_OP_MOD_SUSPEND_RESPONDER);
1267                 if (ret)
1268                         return ERR_PTR(ret);
1269                 ret = mlx5vf_pci_save_device_inc_data(mvdev);
1270                 return ret ? ERR_PTR(ret) : NULL;
1271         }
1272
1273         /*
1274          * vfio_mig_get_next_state() does not use arcs other than the above
1275          */
1276         WARN_ON(true);
1277         return ERR_PTR(-EINVAL);
1278 }
1279
1280 /*
1281  * This function is called in all state_mutex unlock cases to
1282  * handle a 'deferred_reset', if one exists.
1283  */
1284 void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
1285 {
1286 again:
1287         spin_lock(&mvdev->reset_lock);
1288         if (mvdev->deferred_reset) {
1289                 mvdev->deferred_reset = false;
1290                 spin_unlock(&mvdev->reset_lock);
1291                 mvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
1292                 mlx5vf_disable_fds(mvdev);
1293                 goto again;
1294         }
1295         mutex_unlock(&mvdev->state_mutex);
1296         spin_unlock(&mvdev->reset_lock);
1297 }
1298
1299 static struct file *
1300 mlx5vf_pci_set_device_state(struct vfio_device *vdev,
1301                             enum vfio_device_mig_state new_state)
1302 {
1303         struct mlx5vf_pci_core_device *mvdev = container_of(
1304                 vdev, struct mlx5vf_pci_core_device, core_device.vdev);
1305         enum vfio_device_mig_state next_state;
1306         struct file *res = NULL;
1307         int ret;
1308
1309         mutex_lock(&mvdev->state_mutex);
1310         while (new_state != mvdev->mig_state) {
1311                 ret = vfio_mig_get_next_state(vdev, mvdev->mig_state,
1312                                               new_state, &next_state);
1313                 if (ret) {
1314                         res = ERR_PTR(ret);
1315                         break;
1316                 }
1317                 res = mlx5vf_pci_step_device_state_locked(mvdev, next_state);
1318                 if (IS_ERR(res))
1319                         break;
1320                 mvdev->mig_state = next_state;
1321                 if (WARN_ON(res && new_state != mvdev->mig_state)) {
1322                         fput(res);
1323                         res = ERR_PTR(-EINVAL);
1324                         break;
1325                 }
1326         }
1327         mlx5vf_state_mutex_unlock(mvdev);
1328         return res;
1329 }
1330
1331 static int mlx5vf_pci_get_data_size(struct vfio_device *vdev,
1332                                     unsigned long *stop_copy_length)
1333 {
1334         struct mlx5vf_pci_core_device *mvdev = container_of(
1335                 vdev, struct mlx5vf_pci_core_device, core_device.vdev);
1336         size_t state_size;
1337         u64 total_size;
1338         int ret;
1339
1340         mutex_lock(&mvdev->state_mutex);
1341         ret = mlx5vf_cmd_query_vhca_migration_state(mvdev, &state_size,
1342                                                     &total_size, 0);
1343         if (!ret)
1344                 *stop_copy_length = total_size;
1345         mlx5vf_state_mutex_unlock(mvdev);
1346         return ret;
1347 }
1348
1349 static int mlx5vf_pci_get_device_state(struct vfio_device *vdev,
1350                                        enum vfio_device_mig_state *curr_state)
1351 {
1352         struct mlx5vf_pci_core_device *mvdev = container_of(
1353                 vdev, struct mlx5vf_pci_core_device, core_device.vdev);
1354
1355         mutex_lock(&mvdev->state_mutex);
1356         *curr_state = mvdev->mig_state;
1357         mlx5vf_state_mutex_unlock(mvdev);
1358         return 0;
1359 }
1360
1361 static void mlx5vf_pci_aer_reset_done(struct pci_dev *pdev)
1362 {
1363         struct mlx5vf_pci_core_device *mvdev = mlx5vf_drvdata(pdev);
1364
1365         if (!mvdev->migrate_cap)
1366                 return;
1367
1368         /*
1369          * As the higher VFIO layers hold locks across reset and use those
1370          * same locks with the mm_lock, we need to prevent an ABBA deadlock
1371          * between the state_mutex and the mm_lock.
1372          * If the state_mutex was already taken, defer the cleanup work to
1373          * the unlock flow of the other running context.
1374          */
1375         spin_lock(&mvdev->reset_lock);
1376         mvdev->deferred_reset = true;
1377         if (!mutex_trylock(&mvdev->state_mutex)) {
1378                 spin_unlock(&mvdev->reset_lock);
1379                 return;
1380         }
1381         spin_unlock(&mvdev->reset_lock);
1382         mlx5vf_state_mutex_unlock(mvdev);
1383 }
1384
1385 static int mlx5vf_pci_open_device(struct vfio_device *core_vdev)
1386 {
1387         struct mlx5vf_pci_core_device *mvdev = container_of(
1388                 core_vdev, struct mlx5vf_pci_core_device, core_device.vdev);
1389         struct vfio_pci_core_device *vdev = &mvdev->core_device;
1390         int ret;
1391
1392         ret = vfio_pci_core_enable(vdev);
1393         if (ret)
1394                 return ret;
1395
1396         if (mvdev->migrate_cap)
1397                 mvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
1398         vfio_pci_core_finish_enable(vdev);
1399         return 0;
1400 }
1401
1402 static void mlx5vf_pci_close_device(struct vfio_device *core_vdev)
1403 {
1404         struct mlx5vf_pci_core_device *mvdev = container_of(
1405                 core_vdev, struct mlx5vf_pci_core_device, core_device.vdev);
1406
1407         mlx5vf_cmd_close_migratable(mvdev);
1408         vfio_pci_core_close_device(core_vdev);
1409 }
1410
1411 static const struct vfio_migration_ops mlx5vf_pci_mig_ops = {
1412         .migration_set_state = mlx5vf_pci_set_device_state,
1413         .migration_get_state = mlx5vf_pci_get_device_state,
1414         .migration_get_data_size = mlx5vf_pci_get_data_size,
1415 };
1416
1417 static const struct vfio_log_ops mlx5vf_pci_log_ops = {
1418         .log_start = mlx5vf_start_page_tracker,
1419         .log_stop = mlx5vf_stop_page_tracker,
1420         .log_read_and_clear = mlx5vf_tracker_read_and_clear,
1421 };
1422
1423 static int mlx5vf_pci_init_dev(struct vfio_device *core_vdev)
1424 {
1425         struct mlx5vf_pci_core_device *mvdev = container_of(core_vdev,
1426                         struct mlx5vf_pci_core_device, core_device.vdev);
1427         int ret;
1428
1429         ret = vfio_pci_core_init_dev(core_vdev);
1430         if (ret)
1431                 return ret;
1432
1433         mlx5vf_cmd_set_migratable(mvdev, &mlx5vf_pci_mig_ops,
1434                                   &mlx5vf_pci_log_ops);
1435
1436         return 0;
1437 }
1438
1439 static void mlx5vf_pci_release_dev(struct vfio_device *core_vdev)
1440 {
1441         struct mlx5vf_pci_core_device *mvdev = container_of(core_vdev,
1442                         struct mlx5vf_pci_core_device, core_device.vdev);
1443
1444         mlx5vf_cmd_remove_migratable(mvdev);
1445         vfio_pci_core_release_dev(core_vdev);
1446 }
1447
1448 static const struct vfio_device_ops mlx5vf_pci_ops = {
1449         .name = "mlx5-vfio-pci",
1450         .init = mlx5vf_pci_init_dev,
1451         .release = mlx5vf_pci_release_dev,
1452         .open_device = mlx5vf_pci_open_device,
1453         .close_device = mlx5vf_pci_close_device,
1454         .ioctl = vfio_pci_core_ioctl,
1455         .device_feature = vfio_pci_core_ioctl_feature,
1456         .read = vfio_pci_core_read,
1457         .write = vfio_pci_core_write,
1458         .mmap = vfio_pci_core_mmap,
1459         .request = vfio_pci_core_request,
1460         .match = vfio_pci_core_match,
1461         .bind_iommufd = vfio_iommufd_physical_bind,
1462         .unbind_iommufd = vfio_iommufd_physical_unbind,
1463         .attach_ioas = vfio_iommufd_physical_attach_ioas,
1464         .detach_ioas = vfio_iommufd_physical_detach_ioas,
1465 };
1466
1467 static int mlx5vf_pci_probe(struct pci_dev *pdev,
1468                             const struct pci_device_id *id)
1469 {
1470         struct mlx5vf_pci_core_device *mvdev;
1471         int ret;
1472
1473         mvdev = vfio_alloc_device(mlx5vf_pci_core_device, core_device.vdev,
1474                                   &pdev->dev, &mlx5vf_pci_ops);
1475         if (IS_ERR(mvdev))
1476                 return PTR_ERR(mvdev);
1477
1478         dev_set_drvdata(&pdev->dev, &mvdev->core_device);
1479         ret = vfio_pci_core_register_device(&mvdev->core_device);
1480         if (ret)
1481                 goto out_put_vdev;
1482         return 0;
1483
1484 out_put_vdev:
1485         vfio_put_device(&mvdev->core_device.vdev);
1486         return ret;
1487 }
1488
1489 static void mlx5vf_pci_remove(struct pci_dev *pdev)
1490 {
1491         struct mlx5vf_pci_core_device *mvdev = mlx5vf_drvdata(pdev);
1492
1493         vfio_pci_core_unregister_device(&mvdev->core_device);
1494         vfio_put_device(&mvdev->core_device.vdev);
1495 }
1496
1497 static const struct pci_device_id mlx5vf_pci_table[] = {
1498         { PCI_DRIVER_OVERRIDE_DEVICE_VFIO(PCI_VENDOR_ID_MELLANOX, 0x101e) }, /* ConnectX Family mlx5Gen Virtual Function */
1499         {}
1500 };
1501
1502 MODULE_DEVICE_TABLE(pci, mlx5vf_pci_table);
1503
1504 static const struct pci_error_handlers mlx5vf_err_handlers = {
1505         .reset_done = mlx5vf_pci_aer_reset_done,
1506         .error_detected = vfio_pci_core_aer_err_detected,
1507 };
1508
1509 static struct pci_driver mlx5vf_pci_driver = {
1510         .name = KBUILD_MODNAME,
1511         .id_table = mlx5vf_pci_table,
1512         .probe = mlx5vf_pci_probe,
1513         .remove = mlx5vf_pci_remove,
1514         .err_handler = &mlx5vf_err_handlers,
1515         .driver_managed_dma = true,
1516 };
1517
1518 module_pci_driver(mlx5vf_pci_driver);
1519
1520 MODULE_IMPORT_NS(IOMMUFD);
1521 MODULE_LICENSE("GPL");
1522 MODULE_AUTHOR("Max Gurtovoy <mgurtovoy@nvidia.com>");
1523 MODULE_AUTHOR("Yishai Hadas <yishaih@nvidia.com>");
1524 MODULE_DESCRIPTION(
1525         "MLX5 VFIO PCI - User Level meta-driver for MLX5 device family");