virtio_ring : keep used_wrap_counter in vq->last_used_idx
authorhuangjie.albert <huangjie.albert@bytedance.com>
Fri, 17 Jun 2022 02:04:11 +0000 (10:04 +0800)
committerMichael S. Tsirkin <mst@redhat.com>
Fri, 24 Jun 2022 06:49:48 +0000 (02:49 -0400)
the used_wrap_counter and the vq->last_used_idx may get
out of sync if they are separate assignment,and interrupt
might use an incorrect value to check for the used index.

for example:OOB access
ksoftirqd may consume the packet and it will call:
virtnet_poll
-->virtnet_receive
-->virtqueue_get_buf_ctx
-->virtqueue_get_buf_ctx_packed
and in virtqueue_get_buf_ctx_packed:

vq->last_used_idx += vq->packed.desc_state[id].num;
if (unlikely(vq->last_used_idx >= vq->packed.vring.num)) {
         vq->last_used_idx -= vq->packed.vring.num;
         vq->packed.used_wrap_counter ^= 1;
}

if at the same time, there comes a vring interrupt,in vring_interrupt:
we will call:
vring_interrupt
-->more_used
-->more_used_packed
-->is_used_desc_packed
in is_used_desc_packed, the last_used_idx maybe >= vq->packed.vring.num.
so this could case a memory out of bounds bug.

this patch is to keep the used_wrap_counter in vq->last_used_idx
so we can get the correct value to check for used index in interrupt.

v3->v4:
- use READ_ONCE/WRITE_ONCE to get/set vq->last_used_idx

v2->v3:
- add inline function to get used_wrap_counter and last_used
- when use vq->last_used_idx, only read once
  if vq->last_used_idx is read twice, the values can be inconsistent.
- use last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR))
  to get the all bits below VRING_PACKED_EVENT_F_WRAP_CTR

v1->v2:
- reuse the VRING_PACKED_EVENT_F_WRAP_CTR
- Remove parameter judgment in is_used_desc_packed,
because it can't be illegal

Signed-off-by: huangjie.albert <huangjie.albert@bytedance.com>
Message-Id: <20220617020411.80367-1-huangjie.albert@bytedance.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
drivers/virtio/virtio_ring.c

index 13a7348..ef30465 100644 (file)
@@ -111,7 +111,12 @@ struct vring_virtqueue {
        /* Number we've added since last sync. */
        unsigned int num_added;
 
-       /* Last used index we've seen. */
+       /* Last used index  we've seen.
+        * for split ring, it just contains last used index
+        * for packed ring:
+        * bits up to VRING_PACKED_EVENT_F_WRAP_CTR include the last used index.
+        * bits from VRING_PACKED_EVENT_F_WRAP_CTR include the used wrap counter.
+        */
        u16 last_used_idx;
 
        /* Hint for event idx: already triggered no need to disable. */
@@ -154,9 +159,6 @@ struct vring_virtqueue {
                        /* Driver ring wrap counter. */
                        bool avail_wrap_counter;
 
-                       /* Device ring wrap counter. */
-                       bool used_wrap_counter;
-
                        /* Avail used flags. */
                        u16 avail_used_flags;
 
@@ -973,6 +975,15 @@ static struct virtqueue *vring_create_virtqueue_split(
 /*
  * Packed ring specific functions - *_packed().
  */
+static inline bool packed_used_wrap_counter(u16 last_used_idx)
+{
+       return !!(last_used_idx & (1 << VRING_PACKED_EVENT_F_WRAP_CTR));
+}
+
+static inline u16 packed_last_used(u16 last_used_idx)
+{
+       return last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR));
+}
 
 static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
                                     struct vring_desc_extra *extra)
@@ -1406,8 +1417,14 @@ static inline bool is_used_desc_packed(const struct vring_virtqueue *vq,
 
 static inline bool more_used_packed(const struct vring_virtqueue *vq)
 {
-       return is_used_desc_packed(vq, vq->last_used_idx,
-                       vq->packed.used_wrap_counter);
+       u16 last_used;
+       u16 last_used_idx;
+       bool used_wrap_counter;
+
+       last_used_idx = READ_ONCE(vq->last_used_idx);
+       last_used = packed_last_used(last_used_idx);
+       used_wrap_counter = packed_used_wrap_counter(last_used_idx);
+       return is_used_desc_packed(vq, last_used, used_wrap_counter);
 }
 
 static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
@@ -1415,7 +1432,8 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
                                          void **ctx)
 {
        struct vring_virtqueue *vq = to_vvq(_vq);
-       u16 last_used, id;
+       u16 last_used, id, last_used_idx;
+       bool used_wrap_counter;
        void *ret;
 
        START_USE(vq);
@@ -1434,7 +1452,9 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
        /* Only get used elements after they have been exposed by host. */
        virtio_rmb(vq->weak_barriers);
 
-       last_used = vq->last_used_idx;
+       last_used_idx = READ_ONCE(vq->last_used_idx);
+       used_wrap_counter = packed_used_wrap_counter(last_used_idx);
+       last_used = packed_last_used(last_used_idx);
        id = le16_to_cpu(vq->packed.vring.desc[last_used].id);
        *len = le32_to_cpu(vq->packed.vring.desc[last_used].len);
 
@@ -1451,12 +1471,15 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
        ret = vq->packed.desc_state[id].data;
        detach_buf_packed(vq, id, ctx);
 
-       vq->last_used_idx += vq->packed.desc_state[id].num;
-       if (unlikely(vq->last_used_idx >= vq->packed.vring.num)) {
-               vq->last_used_idx -= vq->packed.vring.num;
-               vq->packed.used_wrap_counter ^= 1;
+       last_used += vq->packed.desc_state[id].num;
+       if (unlikely(last_used >= vq->packed.vring.num)) {
+               last_used -= vq->packed.vring.num;
+               used_wrap_counter ^= 1;
        }
 
+       last_used = (last_used | (used_wrap_counter << VRING_PACKED_EVENT_F_WRAP_CTR));
+       WRITE_ONCE(vq->last_used_idx, last_used);
+
        /*
         * If we expect an interrupt for the next entry, tell host
         * by writing event index and flush out the write before
@@ -1465,9 +1488,7 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
        if (vq->packed.event_flags_shadow == VRING_PACKED_EVENT_FLAG_DESC)
                virtio_store_mb(vq->weak_barriers,
                                &vq->packed.vring.driver->off_wrap,
-                               cpu_to_le16(vq->last_used_idx |
-                                       (vq->packed.used_wrap_counter <<
-                                        VRING_PACKED_EVENT_F_WRAP_CTR)));
+                               cpu_to_le16(vq->last_used_idx));
 
        LAST_ADD_TIME_INVALID(vq);
 
@@ -1499,9 +1520,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
 
        if (vq->event) {
                vq->packed.vring.driver->off_wrap =
-                       cpu_to_le16(vq->last_used_idx |
-                               (vq->packed.used_wrap_counter <<
-                                VRING_PACKED_EVENT_F_WRAP_CTR));
+                       cpu_to_le16(vq->last_used_idx);
                /*
                 * We need to update event offset and event wrap
                 * counter first before updating event flags.
@@ -1518,8 +1537,7 @@ static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
        }
 
        END_USE(vq);
-       return vq->last_used_idx | ((u16)vq->packed.used_wrap_counter <<
-                       VRING_PACKED_EVENT_F_WRAP_CTR);
+       return vq->last_used_idx;
 }
 
 static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap)
@@ -1537,7 +1555,7 @@ static bool virtqueue_poll_packed(struct virtqueue *_vq, u16 off_wrap)
 static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
 {
        struct vring_virtqueue *vq = to_vvq(_vq);
-       u16 used_idx, wrap_counter;
+       u16 used_idx, wrap_counter, last_used_idx;
        u16 bufs;
 
        START_USE(vq);
@@ -1550,9 +1568,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
        if (vq->event) {
                /* TODO: tune this threshold */
                bufs = (vq->packed.vring.num - vq->vq.num_free) * 3 / 4;
-               wrap_counter = vq->packed.used_wrap_counter;
+               last_used_idx = READ_ONCE(vq->last_used_idx);
+               wrap_counter = packed_used_wrap_counter(last_used_idx);
 
-               used_idx = vq->last_used_idx + bufs;
+               used_idx = packed_last_used(last_used_idx) + bufs;
                if (used_idx >= vq->packed.vring.num) {
                        used_idx -= vq->packed.vring.num;
                        wrap_counter ^= 1;
@@ -1582,9 +1601,10 @@ static bool virtqueue_enable_cb_delayed_packed(struct virtqueue *_vq)
         */
        virtio_mb(vq->weak_barriers);
 
-       if (is_used_desc_packed(vq,
-                               vq->last_used_idx,
-                               vq->packed.used_wrap_counter)) {
+       last_used_idx = READ_ONCE(vq->last_used_idx);
+       wrap_counter = packed_used_wrap_counter(last_used_idx);
+       used_idx = packed_last_used(last_used_idx);
+       if (is_used_desc_packed(vq, used_idx, wrap_counter)) {
                END_USE(vq);
                return false;
        }
@@ -1689,7 +1709,7 @@ static struct virtqueue *vring_create_virtqueue_packed(
        vq->notify = notify;
        vq->weak_barriers = weak_barriers;
        vq->broken = true;
-       vq->last_used_idx = 0;
+       vq->last_used_idx = 0 | (1 << VRING_PACKED_EVENT_F_WRAP_CTR);
        vq->event_triggered = false;
        vq->num_added = 0;
        vq->packed_ring = true;
@@ -1720,7 +1740,6 @@ static struct virtqueue *vring_create_virtqueue_packed(
 
        vq->packed.next_avail_idx = 0;
        vq->packed.avail_wrap_counter = 1;
-       vq->packed.used_wrap_counter = 1;
        vq->packed.event_flags_shadow = 0;
        vq->packed.avail_used_flags = 1 << VRING_PACKED_DESC_F_AVAIL;