virtio-blk: Avoid use-after-free on suspend/resume
[linux-2.6-microblaze.git] / drivers / block / virtio_blk.c
index 6fc7850..d756423 100644 (file)
@@ -101,6 +101,14 @@ static inline blk_status_t virtblk_result(struct virtblk_req *vbr)
        }
 }
 
+static inline struct virtio_blk_vq *get_virtio_blk_vq(struct blk_mq_hw_ctx *hctx)
+{
+       struct virtio_blk *vblk = hctx->queue->queuedata;
+       struct virtio_blk_vq *vq = &vblk->vqs[hctx->queue_num];
+
+       return vq;
+}
+
 static int virtblk_add_req(struct virtqueue *vq, struct virtblk_req *vbr)
 {
        struct scatterlist hdr, status, *sgs[3];
@@ -416,7 +424,7 @@ static void virtio_queue_rqs(struct request **rqlist)
        struct request *requeue_list = NULL;
 
        rq_list_for_each_safe(rqlist, req, next) {
-               struct virtio_blk_vq *vq = req->mq_hctx->driver_data;
+               struct virtio_blk_vq *vq = get_virtio_blk_vq(req->mq_hctx);
                bool kick;
 
                if (!virtblk_prep_rq_batch(req)) {
@@ -837,7 +845,7 @@ static void virtblk_complete_batch(struct io_comp_batch *iob)
 static int virtblk_poll(struct blk_mq_hw_ctx *hctx, struct io_comp_batch *iob)
 {
        struct virtio_blk *vblk = hctx->queue->queuedata;
-       struct virtio_blk_vq *vq = hctx->driver_data;
+       struct virtio_blk_vq *vq = get_virtio_blk_vq(hctx);
        struct virtblk_req *vbr;
        unsigned long flags;
        unsigned int len;
@@ -862,22 +870,10 @@ static int virtblk_poll(struct blk_mq_hw_ctx *hctx, struct io_comp_batch *iob)
        return found;
 }
 
-static int virtblk_init_hctx(struct blk_mq_hw_ctx *hctx, void *data,
-                         unsigned int hctx_idx)
-{
-       struct virtio_blk *vblk = data;
-       struct virtio_blk_vq *vq = &vblk->vqs[hctx_idx];
-
-       WARN_ON(vblk->tag_set.tags[hctx_idx] != hctx->tags);
-       hctx->driver_data = vq;
-       return 0;
-}
-
 static const struct blk_mq_ops virtio_mq_ops = {
        .queue_rq       = virtio_queue_rq,
        .queue_rqs      = virtio_queue_rqs,
        .commit_rqs     = virtio_commit_rqs,
-       .init_hctx      = virtblk_init_hctx,
        .complete       = virtblk_request_done,
        .map_queues     = virtblk_map_queues,
        .poll           = virtblk_poll,