Merge tag 'for-5.16-deadlock-fix-tag' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-microblaze.git] / drivers / virtio / virtio_ring.c
index 3035bb6..00f64f2 100644 (file)
@@ -14,6 +14,9 @@
 #include <linux/spinlock.h>
 #include <xen/xen.h>
 
+static bool force_used_validation = false;
+module_param(force_used_validation, bool, 0444);
+
 #ifdef DEBUG
 /* For development, we want to crash whenever the ring is screwed. */
 #define BAD_RING(_vq, fmt, args...)                            \
@@ -79,8 +82,8 @@ struct vring_desc_state_packed {
 };
 
 struct vring_desc_extra {
-       dma_addr_t addr;                /* Buffer DMA addr. */
-       u32 len;                        /* Buffer length. */
+       dma_addr_t addr;                /* Descriptor DMA addr. */
+       u32 len;                        /* Descriptor length. */
        u16 flags;                      /* Descriptor flags. */
        u16 next;                       /* The next desc state in a list. */
 };
@@ -182,6 +185,9 @@ struct vring_virtqueue {
                } packed;
        };
 
+       /* Per-descriptor in buffer length */
+       u32 *buflen;
+
        /* How to notify other side. FIXME: commonalize hcalls! */
        bool (*notify)(struct virtqueue *vq);
 
@@ -490,6 +496,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
        unsigned int i, n, avail, descs_used, prev, err_idx;
        int head;
        bool indirect;
+       u32 buflen = 0;
 
        START_USE(vq);
 
@@ -571,6 +578,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
                                                     VRING_DESC_F_NEXT |
                                                     VRING_DESC_F_WRITE,
                                                     indirect);
+                       buflen += sg->length;
                }
        }
        /* Last one doesn't continue. */
@@ -610,6 +618,10 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
        else
                vq->split.desc_state[head].indir_desc = ctx;
 
+       /* Store in buffer length if necessary */
+       if (vq->buflen)
+               vq->buflen[head] = buflen;
+
        /* Put entry in available array (but don't update avail->idx until they
         * do sync). */
        avail = vq->split.avail_idx_shadow & (vq->split.vring.num - 1);
@@ -784,6 +796,11 @@ static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq,
                BAD_RING(vq, "id %u is not a head!\n", i);
                return NULL;
        }
+       if (vq->buflen && unlikely(*len > vq->buflen[i])) {
+               BAD_RING(vq, "used len %d is larger than in buflen %u\n",
+                       *len, vq->buflen[i]);
+               return NULL;
+       }
 
        /* detach_buf_split clears data, so grab it now. */
        ret = vq->split.desc_state[i].data;
@@ -1050,21 +1067,24 @@ static struct vring_packed_desc *alloc_indirect_packed(unsigned int total_sg,
 }
 
 static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
-                                      struct scatterlist *sgs[],
-                                      unsigned int total_sg,
-                                      unsigned int out_sgs,
-                                      unsigned int in_sgs,
-                                      void *data,
-                                      gfp_t gfp)
+                                        struct scatterlist *sgs[],
+                                        unsigned int total_sg,
+                                        unsigned int out_sgs,
+                                        unsigned int in_sgs,
+                                        void *data,
+                                        gfp_t gfp)
 {
        struct vring_packed_desc *desc;
        struct scatterlist *sg;
        unsigned int i, n, err_idx;
        u16 head, id;
        dma_addr_t addr;
+       u32 buflen = 0;
 
        head = vq->packed.next_avail_idx;
        desc = alloc_indirect_packed(total_sg, gfp);
+       if (!desc)
+               return -ENOMEM;
 
        if (unlikely(vq->vq.num_free < 1)) {
                pr_debug("Can't add buf len 1 - avail = 0\n");
@@ -1089,6 +1109,8 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
                        desc[i].addr = cpu_to_le64(addr);
                        desc[i].len = cpu_to_le32(sg->length);
                        i++;
+                       if (n >= out_sgs)
+                               buflen += sg->length;
                }
        }
 
@@ -1142,6 +1164,10 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
        vq->packed.desc_state[id].indir_desc = desc;
        vq->packed.desc_state[id].last = id;
 
+       /* Store in buffer length if necessary */
+       if (vq->buflen)
+               vq->buflen[id] = buflen;
+
        vq->num_added += 1;
 
        pr_debug("Added buffer head %i to %p\n", head, vq);
@@ -1176,6 +1202,8 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
        unsigned int i, n, c, descs_used, err_idx;
        __le16 head_flags, flags;
        u16 head, id, prev, curr, avail_used_flags;
+       int err;
+       u32 buflen = 0;
 
        START_USE(vq);
 
@@ -1191,9 +1219,14 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
 
        BUG_ON(total_sg == 0);
 
-       if (virtqueue_use_indirect(_vq, total_sg))
-               return virtqueue_add_indirect_packed(vq, sgs, total_sg,
-                               out_sgs, in_sgs, data, gfp);
+       if (virtqueue_use_indirect(_vq, total_sg)) {
+               err = virtqueue_add_indirect_packed(vq, sgs, total_sg, out_sgs,
+                                                   in_sgs, data, gfp);
+               if (err != -ENOMEM)
+                       return err;
+
+               /* fall back on direct */
+       }
 
        head = vq->packed.next_avail_idx;
        avail_used_flags = vq->packed.avail_used_flags;
@@ -1250,6 +1283,8 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
                                        1 << VRING_PACKED_DESC_F_AVAIL |
                                        1 << VRING_PACKED_DESC_F_USED;
                        }
+                       if (n >= out_sgs)
+                               buflen += sg->length;
                }
        }
 
@@ -1269,6 +1304,10 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
        vq->packed.desc_state[id].indir_desc = ctx;
        vq->packed.desc_state[id].last = prev;
 
+       /* Store in buffer length if necessary */
+       if (vq->buflen)
+               vq->buflen[id] = buflen;
+
        /*
         * A driver MUST NOT make the first descriptor in the list
         * available before all subsequent descriptors comprising
@@ -1455,6 +1494,11 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
                BAD_RING(vq, "id %u is not a head!\n", id);
                return NULL;
        }
+       if (vq->buflen && unlikely(*len > vq->buflen[id])) {
+               BAD_RING(vq, "used len %d is larger than in buflen %u\n",
+                       *len, vq->buflen[id]);
+               return NULL;
+       }
 
        /* detach_buf_packed clears data, so grab it now. */
        ret = vq->packed.desc_state[id].data;
@@ -1660,6 +1704,7 @@ static struct virtqueue *vring_create_virtqueue_packed(
        struct vring_virtqueue *vq;
        struct vring_packed_desc *ring;
        struct vring_packed_desc_event *driver, *device;
+       struct virtio_driver *drv = drv_to_virtio(vdev->dev.driver);
        dma_addr_t ring_dma_addr, driver_event_dma_addr, device_event_dma_addr;
        size_t ring_size_in_bytes, event_size_in_bytes;
 
@@ -1749,6 +1794,15 @@ static struct virtqueue *vring_create_virtqueue_packed(
        if (!vq->packed.desc_extra)
                goto err_desc_extra;
 
+       if (!drv->suppress_used_validation || force_used_validation) {
+               vq->buflen = kmalloc_array(num, sizeof(*vq->buflen),
+                                          GFP_KERNEL);
+               if (!vq->buflen)
+                       goto err_buflen;
+       } else {
+               vq->buflen = NULL;
+       }
+
        /* No callback?  Tell other side not to bother us. */
        if (!callback) {
                vq->packed.event_flags_shadow = VRING_PACKED_EVENT_FLAG_DISABLE;
@@ -1761,6 +1815,8 @@ static struct virtqueue *vring_create_virtqueue_packed(
        spin_unlock(&vdev->vqs_list_lock);
        return &vq->vq;
 
+err_buflen:
+       kfree(vq->packed.desc_extra);
 err_desc_extra:
        kfree(vq->packed.desc_state);
 err_desc_state:
@@ -2168,6 +2224,7 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
                                        void (*callback)(struct virtqueue *),
                                        const char *name)
 {
+       struct virtio_driver *drv = drv_to_virtio(vdev->dev.driver);
        struct vring_virtqueue *vq;
 
        if (virtio_has_feature(vdev, VIRTIO_F_RING_PACKED))
@@ -2227,6 +2284,15 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
        if (!vq->split.desc_extra)
                goto err_extra;
 
+       if (!drv->suppress_used_validation || force_used_validation) {
+               vq->buflen = kmalloc_array(vring.num, sizeof(*vq->buflen),
+                                          GFP_KERNEL);
+               if (!vq->buflen)
+                       goto err_buflen;
+       } else {
+               vq->buflen = NULL;
+       }
+
        /* Put everything in free lists. */
        vq->free_head = 0;
        memset(vq->split.desc_state, 0, vring.num *
@@ -2237,6 +2303,8 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
        spin_unlock(&vdev->vqs_list_lock);
        return &vq->vq;
 
+err_buflen:
+       kfree(vq->split.desc_extra);
 err_extra:
        kfree(vq->split.desc_state);
 err_state: