drivers/net/virtio_net.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* A network driver using virtio.
3  *
4  * Copyright 2007 Rusty Russell <rusty@rustcorp.com.au> IBM Corporation
5  */
6 //#define DEBUG
7 #include <linux/netdevice.h>
8 #include <linux/etherdevice.h>
9 #include <linux/ethtool.h>
10 #include <linux/module.h>
11 #include <linux/virtio.h>
12 #include <linux/virtio_net.h>
13 #include <linux/bpf.h>
14 #include <linux/bpf_trace.h>
15 #include <linux/scatterlist.h>
16 #include <linux/if_vlan.h>
17 #include <linux/slab.h>
18 #include <linux/cpu.h>
19 #include <linux/average.h>
20 #include <linux/filter.h>
21 #include <linux/kernel.h>
22 #include <net/route.h>
23 #include <net/xdp.h>
24 #include <net/net_failover.h>
25
26 static int napi_weight = NAPI_POLL_WEIGHT;
27 module_param(napi_weight, int, 0444);
28
29 static bool csum = true, gso = true, napi_tx = true;
30 module_param(csum, bool, 0444);
31 module_param(gso, bool, 0444);
32 module_param(napi_tx, bool, 0644);
33
34 /* FIXME: MTU in config. */
35 #define GOOD_PACKET_LEN (ETH_HLEN + VLAN_HLEN + ETH_DATA_LEN)
36 #define GOOD_COPY_LEN   128
37
38 #define VIRTNET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)
39
40 /* Amount of XDP headroom to prepend to packets for use by xdp_adjust_head */
41 #define VIRTIO_XDP_HEADROOM 256
42
43 /* Flags distinguishing the two types of XDP xmit */
44 #define VIRTIO_XDP_TX           BIT(0)
45 #define VIRTIO_XDP_REDIR        BIT(1)
46
47 #define VIRTIO_XDP_FLAG BIT(0)
48
49 /* RX packet size EWMA. The average packet size is used to determine the packet
50  * buffer size when refilling RX rings. As the entire RX ring may be refilled
51  * at once, the weight is chosen so that the EWMA will be insensitive to short-
52  * term, transient changes in packet size.
53  */
54 DECLARE_EWMA(pkt_len, 0, 64)
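/* Illustrative note on the weight above (generic kernel EWMA helper,
 * example numbers only): with a weight factor of 1/64, each new sample
 * moves the average by roughly 1/64 of the difference, e.g. an average
 * of 1500 bytes followed by a single 60 byte packet only drops to about
 * 1500 - (1500 - 60)/64 ~= 1478 bytes, so short bursts of small packets
 * barely disturb the refill size estimate.
 */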
55
56 #define VIRTNET_DRIVER_VERSION "1.0.0"
57
58 static const unsigned long guest_offloads[] = {
59         VIRTIO_NET_F_GUEST_TSO4,
60         VIRTIO_NET_F_GUEST_TSO6,
61         VIRTIO_NET_F_GUEST_ECN,
62         VIRTIO_NET_F_GUEST_UFO,
63         VIRTIO_NET_F_GUEST_CSUM
64 };
65
66 #define GUEST_OFFLOAD_GRO_HW_MASK ((1ULL << VIRTIO_NET_F_GUEST_TSO4) | \
67                                 (1ULL << VIRTIO_NET_F_GUEST_TSO6) | \
68                                 (1ULL << VIRTIO_NET_F_GUEST_ECN)  | \
69                                 (1ULL << VIRTIO_NET_F_GUEST_UFO))
70
71 struct virtnet_stat_desc {
72         char desc[ETH_GSTRING_LEN];
73         size_t offset;
74 };
75
76 struct virtnet_sq_stats {
77         struct u64_stats_sync syncp;
78         u64 packets;
79         u64 bytes;
80         u64 xdp_tx;
81         u64 xdp_tx_drops;
82         u64 kicks;
83         u64 tx_timeouts;
84 };
85
86 struct virtnet_rq_stats {
87         struct u64_stats_sync syncp;
88         u64 packets;
89         u64 bytes;
90         u64 drops;
91         u64 xdp_packets;
92         u64 xdp_tx;
93         u64 xdp_redirects;
94         u64 xdp_drops;
95         u64 kicks;
96 };
97
98 #define VIRTNET_SQ_STAT(m)      offsetof(struct virtnet_sq_stats, m)
99 #define VIRTNET_RQ_STAT(m)      offsetof(struct virtnet_rq_stats, m)
100
101 static const struct virtnet_stat_desc virtnet_sq_stats_desc[] = {
102         { "packets",            VIRTNET_SQ_STAT(packets) },
103         { "bytes",              VIRTNET_SQ_STAT(bytes) },
104         { "xdp_tx",             VIRTNET_SQ_STAT(xdp_tx) },
105         { "xdp_tx_drops",       VIRTNET_SQ_STAT(xdp_tx_drops) },
106         { "kicks",              VIRTNET_SQ_STAT(kicks) },
107         { "tx_timeouts",        VIRTNET_SQ_STAT(tx_timeouts) },
108 };
109
110 static const struct virtnet_stat_desc virtnet_rq_stats_desc[] = {
111         { "packets",            VIRTNET_RQ_STAT(packets) },
112         { "bytes",              VIRTNET_RQ_STAT(bytes) },
113         { "drops",              VIRTNET_RQ_STAT(drops) },
114         { "xdp_packets",        VIRTNET_RQ_STAT(xdp_packets) },
115         { "xdp_tx",             VIRTNET_RQ_STAT(xdp_tx) },
116         { "xdp_redirects",      VIRTNET_RQ_STAT(xdp_redirects) },
117         { "xdp_drops",          VIRTNET_RQ_STAT(xdp_drops) },
118         { "kicks",              VIRTNET_RQ_STAT(kicks) },
119 };
120
121 #define VIRTNET_SQ_STATS_LEN    ARRAY_SIZE(virtnet_sq_stats_desc)
122 #define VIRTNET_RQ_STATS_LEN    ARRAY_SIZE(virtnet_rq_stats_desc)
123
124 /* Internal representation of a send virtqueue */
125 struct send_queue {
126         /* Virtqueue associated with this send_queue */
127         struct virtqueue *vq;
128
129         /* TX: fragments + linear part + virtio header */
130         struct scatterlist sg[MAX_SKB_FRAGS + 2];
131
132         /* Name of the send queue: output.$index */
133         char name[40];
134
135         struct virtnet_sq_stats stats;
136
137         struct napi_struct napi;
138 };
139
140 /* Internal representation of a receive virtqueue */
141 struct receive_queue {
142         /* Virtqueue associated with this receive_queue */
143         struct virtqueue *vq;
144
145         struct napi_struct napi;
146
147         struct bpf_prog __rcu *xdp_prog;
148
149         struct virtnet_rq_stats stats;
150
151         /* Chain pages by the private ptr. */
152         struct page *pages;
153
154         /* Average packet length for mergeable receive buffers. */
155         struct ewma_pkt_len mrg_avg_pkt_len;
156
157         /* Page frag for packet buffer allocation. */
158         struct page_frag alloc_frag;
159
160         /* RX: fragments + linear part + virtio header */
161         struct scatterlist sg[MAX_SKB_FRAGS + 2];
162
163         /* Min single buffer size for mergeable buffers case. */
164         unsigned int min_buf_len;
165
166         /* Name of this receive queue: input.$index */
167         char name[40];
168
169         struct xdp_rxq_info xdp_rxq;
170 };
171
172 /* This structure can contain an RSS message with maximum settings for the indirection table and key size.
173  * Note that the default structure that describes the RSS configuration, virtio_net_rss_config,
174  * contains the same info but can't hold the table values.
175  * In any case, the structure is passed to the virtio device through an sg_buf split into parts,
176  * because table sizes may differ according to the device configuration.
177  */
178 #define VIRTIO_NET_RSS_MAX_KEY_SIZE     40
179 #define VIRTIO_NET_RSS_MAX_TABLE_LEN    128
180 struct virtio_net_ctrl_rss {
181         u32 hash_types;
182         u16 indirection_table_mask;
183         u16 unclassified_queue;
184         u16 indirection_table[VIRTIO_NET_RSS_MAX_TABLE_LEN];
185         u16 max_tx_vq;
186         u8 hash_key_length;
187         u8 key[VIRTIO_NET_RSS_MAX_KEY_SIZE];
188 };
189
190 /* Control VQ buffers: protected by the rtnl lock */
191 struct control_buf {
192         struct virtio_net_ctrl_hdr hdr;
193         virtio_net_ctrl_ack status;
194         struct virtio_net_ctrl_mq mq;
195         u8 promisc;
196         u8 allmulti;
197         __virtio16 vid;
198         __virtio64 offloads;
199         struct virtio_net_ctrl_rss rss;
200 };
201
202 struct virtnet_info {
203         struct virtio_device *vdev;
204         struct virtqueue *cvq;
205         struct net_device *dev;
206         struct send_queue *sq;
207         struct receive_queue *rq;
208         unsigned int status;
209
210         /* Max # of queue pairs supported by the device */
211         u16 max_queue_pairs;
212
213         /* # of queue pairs currently used by the driver */
214         u16 curr_queue_pairs;
215
216         /* # of XDP queue pairs currently used by the driver */
217         u16 xdp_queue_pairs;
218
219         /* xdp_queue_pairs may be 0 even when XDP is already loaded, so track enablement here. */
220         bool xdp_enabled;
221
222         /* I like... big packets and I cannot lie! */
223         bool big_packets;
224
225         /* Host will merge rx buffers for big packets (shake it! shake it!) */
226         bool mergeable_rx_bufs;
227
228         /* Host supports rss and/or hash report */
229         bool has_rss;
230         bool has_rss_hash_report;
231         u8 rss_key_size;
232         u16 rss_indir_table_size;
233         u32 rss_hash_types_supported;
234         u32 rss_hash_types_saved;
235
236         /* Has control virtqueue */
237         bool has_cvq;
238
239         /* Host can handle any s/g split between our header and packet data */
240         bool any_header_sg;
241
242         /* Packet virtio header size */
243         u8 hdr_len;
244
245         /* Work struct for delayed refilling if we run low on memory. */
246         struct delayed_work refill;
247
248         /* Is delayed refill enabled? */
249         bool refill_enabled;
250
251         /* The lock to synchronize the access to refill_enabled */
252         spinlock_t refill_lock;
253
254         /* Work struct for config space updates */
255         struct work_struct config_work;
256
257         /* Is the affinity hint set for virtqueues? */
258         bool affinity_hint_set;
259
260         /* CPU hotplug instances for online & dead */
261         struct hlist_node node;
262         struct hlist_node node_dead;
263
264         struct control_buf *ctrl;
265
266         /* Ethtool settings */
267         u8 duplex;
268         u32 speed;
269
270         unsigned long guest_offloads;
271         unsigned long guest_offloads_capable;
272
273         /* failover when STANDBY feature enabled */
274         struct failover *failover;
275 };
276
277 struct padded_vnet_hdr {
278         struct virtio_net_hdr_v1_hash hdr;
279         /*
280          * hdr is in a separate sg buffer, and data sg buffer shares same page
281          * with this header sg. This padding makes next sg 16 byte aligned
282          * after the header.
283          */
284         char padding[12];
285 };
286
287 static bool is_xdp_frame(void *ptr)
288 {
289         return (unsigned long)ptr & VIRTIO_XDP_FLAG;
290 }
291
292 static void *xdp_to_ptr(struct xdp_frame *ptr)
293 {
294         return (void *)((unsigned long)ptr | VIRTIO_XDP_FLAG);
295 }
296
297 static struct xdp_frame *ptr_to_xdp(void *ptr)
298 {
299         return (struct xdp_frame *)((unsigned long)ptr & ~VIRTIO_XDP_FLAG);
300 }
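/* Illustration of the tagging scheme above (addresses are made up):
 * tx completion tokens for skbs and xdp_frames share one virtqueue
 * token namespace, and since both pointers are at least word aligned,
 * bit 0 (VIRTIO_XDP_FLAG) is free to mark the XDP case:
 *
 *	xdpf = 0x...5670;  tok = xdp_to_ptr(xdpf) = 0x...5671
 *	is_xdp_frame(tok) == true;   ptr_to_xdp(tok) == xdpf
 *	is_xdp_frame(skb) == false  (plain skb pointer, bit 0 clear)
 */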
301
302 /* Converting between virtqueue no. and kernel tx/rx queue no.
303  * 0:rx0 1:tx0 2:rx1 3:tx1 ... 2N:rxN 2N+1:txN 2N+2:cvq
304  */
305 static int vq2txq(struct virtqueue *vq)
306 {
307         return (vq->index - 1) / 2;
308 }
309
310 static int txq2vq(int txq)
311 {
312         return txq * 2 + 1;
313 }
314
315 static int vq2rxq(struct virtqueue *vq)
316 {
317         return vq->index / 2;
318 }
319
320 static int rxq2vq(int rxq)
321 {
322         return rxq * 2;
323 }
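/* Worked example of the mapping above (illustrative): with two queue
 * pairs the virtqueues are vq0:rx0, vq1:tx0, vq2:rx1, vq3:tx1, vq4:cvq,
 * so vq2rxq(vq2) == 1, rxq2vq(1) == 2, vq2txq(vq3) == 1 and
 * txq2vq(1) == 3.
 */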
324
325 static inline struct virtio_net_hdr_mrg_rxbuf *skb_vnet_hdr(struct sk_buff *skb)
326 {
327         return (struct virtio_net_hdr_mrg_rxbuf *)skb->cb;
328 }
329
330 /*
331  * private is used to chain pages for big packets; put the whole
332  * most recently used list at the beginning for reuse
333  */
334 static void give_pages(struct receive_queue *rq, struct page *page)
335 {
336         struct page *end;
337
338         /* Find end of list, sew whole thing into vi->rq.pages. */
339         for (end = page; end->private; end = (struct page *)end->private);
340         end->private = (unsigned long)rq->pages;
341         rq->pages = page;
342 }
343
344 static struct page *get_a_page(struct receive_queue *rq, gfp_t gfp_mask)
345 {
346         struct page *p = rq->pages;
347
348         if (p) {
349                 rq->pages = (struct page *)p->private;
350                 /* clear private here, it is used to chain pages */
351                 p->private = 0;
352         } else
353                 p = alloc_page(gfp_mask);
354         return p;
355 }
356
357 static void enable_delayed_refill(struct virtnet_info *vi)
358 {
359         spin_lock_bh(&vi->refill_lock);
360         vi->refill_enabled = true;
361         spin_unlock_bh(&vi->refill_lock);
362 }
363
364 static void disable_delayed_refill(struct virtnet_info *vi)
365 {
366         spin_lock_bh(&vi->refill_lock);
367         vi->refill_enabled = false;
368         spin_unlock_bh(&vi->refill_lock);
369 }
370
371 static void virtqueue_napi_schedule(struct napi_struct *napi,
372                                     struct virtqueue *vq)
373 {
374         if (napi_schedule_prep(napi)) {
375                 virtqueue_disable_cb(vq);
376                 __napi_schedule(napi);
377         }
378 }
379
380 static void virtqueue_napi_complete(struct napi_struct *napi,
381                                     struct virtqueue *vq, int processed)
382 {
383         int opaque;
384
385         opaque = virtqueue_enable_cb_prepare(vq);
386         if (napi_complete_done(napi, processed)) {
387                 if (unlikely(virtqueue_poll(vq, opaque)))
388                         virtqueue_napi_schedule(napi, vq);
389         } else {
390                 virtqueue_disable_cb(vq);
391         }
392 }
393
394 static void skb_xmit_done(struct virtqueue *vq)
395 {
396         struct virtnet_info *vi = vq->vdev->priv;
397         struct napi_struct *napi = &vi->sq[vq2txq(vq)].napi;
398
399         /* Suppress further interrupts. */
400         virtqueue_disable_cb(vq);
401
402         if (napi->weight)
403                 virtqueue_napi_schedule(napi, vq);
404         else
405                 /* We were probably waiting for more output buffers. */
406                 netif_wake_subqueue(vi->dev, vq2txq(vq));
407 }
408
409 #define MRG_CTX_HEADER_SHIFT 22
410 static void *mergeable_len_to_ctx(unsigned int truesize,
411                                   unsigned int headroom)
412 {
413         return (void *)(unsigned long)((headroom << MRG_CTX_HEADER_SHIFT) | truesize);
414 }
415
416 static unsigned int mergeable_ctx_to_headroom(void *mrg_ctx)
417 {
418         return (unsigned long)mrg_ctx >> MRG_CTX_HEADER_SHIFT;
419 }
420
421 static unsigned int mergeable_ctx_to_truesize(void *mrg_ctx)
422 {
423         return (unsigned long)mrg_ctx & ((1 << MRG_CTX_HEADER_SHIFT) - 1);
424 }
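/* Example of the context encoding above (illustrative values): the low
 * MRG_CTX_HEADER_SHIFT bits carry the truesize and the bits above carry
 * the headroom, so truesize must stay below 1 << 22 (4 MiB):
 *
 *	ctx = mergeable_len_to_ctx(1536, 256);
 *	mergeable_ctx_to_truesize(ctx) == 1536
 *	mergeable_ctx_to_headroom(ctx) == 256
 */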
425
426 /* Called from bottom half context */
427 static struct sk_buff *page_to_skb(struct virtnet_info *vi,
428                                    struct receive_queue *rq,
429                                    struct page *page, unsigned int offset,
430                                    unsigned int len, unsigned int truesize,
431                                    bool hdr_valid, unsigned int metasize,
432                                    unsigned int headroom)
433 {
434         struct sk_buff *skb;
435         struct virtio_net_hdr_mrg_rxbuf *hdr;
436         unsigned int copy, hdr_len, hdr_padded_len;
437         struct page *page_to_free = NULL;
438         int tailroom, shinfo_size;
439         char *p, *hdr_p, *buf;
440
441         p = page_address(page) + offset;
442         hdr_p = p;
443
444         hdr_len = vi->hdr_len;
445         if (vi->mergeable_rx_bufs)
446                 hdr_padded_len = hdr_len;
447         else
448                 hdr_padded_len = sizeof(struct padded_vnet_hdr);
449
450         /* If headroom is not 0, there is an offset between the beginning of the
451          * data and the allocated space, otherwise the data and the allocated
452          * space are aligned.
453          *
454          * Buffers with headroom use PAGE_SIZE as alloc size, see
455          * add_recvbuf_mergeable() + get_mergeable_buf_len()
456          */
457         truesize = headroom ? PAGE_SIZE : truesize;
458         tailroom = truesize - headroom;
459         buf = p - headroom;
460
461         len -= hdr_len;
462         offset += hdr_padded_len;
463         p += hdr_padded_len;
464         tailroom -= hdr_padded_len + len;
465
466         shinfo_size = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
467
468         /* For larger packets, build the skb around the page to avoid a copy */
469         if (!NET_IP_ALIGN && len > GOOD_COPY_LEN && tailroom >= shinfo_size) {
470                 skb = build_skb(buf, truesize);
471                 if (unlikely(!skb))
472                         return NULL;
473
474                 skb_reserve(skb, p - buf);
475                 skb_put(skb, len);
476
477                 page = (struct page *)page->private;
478                 if (page)
479                         give_pages(rq, page);
480                 goto ok;
481         }
482
483         /* copy small packet so we can reuse these pages for small data */
484         skb = napi_alloc_skb(&rq->napi, GOOD_COPY_LEN);
485         if (unlikely(!skb))
486                 return NULL;
487
488         /* Copy the whole frame if it fits in skb->head, otherwise
489          * we let virtio_net_hdr_to_skb() and GRO pull headers as needed.
490          */
491         if (len <= skb_tailroom(skb))
492                 copy = len;
493         else
494                 copy = ETH_HLEN + metasize;
495         skb_put_data(skb, p, copy);
496
497         len -= copy;
498         offset += copy;
499
500         if (vi->mergeable_rx_bufs) {
501                 if (len)
502                         skb_add_rx_frag(skb, 0, page, offset, len, truesize);
503                 else
504                         page_to_free = page;
505                 goto ok;
506         }
507
508         /*
509          * Verify that we can indeed put this data into a skb.
510          * This is here to handle cases when the device erroneously
511          * tries to receive more than is possible. This usually
512          * indicates a broken device.
513          */
514         if (unlikely(len > MAX_SKB_FRAGS * PAGE_SIZE)) {
515                 net_dbg_ratelimited("%s: too much data\n", skb->dev->name);
516                 dev_kfree_skb(skb);
517                 return NULL;
518         }
519         BUG_ON(offset >= PAGE_SIZE);
520         while (len) {
521                 unsigned int frag_size = min((unsigned)PAGE_SIZE - offset, len);
522                 skb_add_rx_frag(skb, skb_shinfo(skb)->nr_frags, page, offset,
523                                 frag_size, truesize);
524                 len -= frag_size;
525                 page = (struct page *)page->private;
526                 offset = 0;
527         }
528
529         if (page)
530                 give_pages(rq, page);
531
532 ok:
533         /* hdr_valid means no XDP, so we can copy the vnet header */
534         if (hdr_valid) {
535                 hdr = skb_vnet_hdr(skb);
536                 memcpy(hdr, hdr_p, hdr_len);
537         }
538         if (page_to_free)
539                 put_page(page_to_free);
540
541         if (metasize) {
542                 __skb_pull(skb, metasize);
543                 skb_metadata_set(skb, metasize);
544         }
545
546         return skb;
547 }
548
549 static int __virtnet_xdp_xmit_one(struct virtnet_info *vi,
550                                    struct send_queue *sq,
551                                    struct xdp_frame *xdpf)
552 {
553         struct virtio_net_hdr_mrg_rxbuf *hdr;
554         int err;
555
556         if (unlikely(xdpf->headroom < vi->hdr_len))
557                 return -EOVERFLOW;
558
559         /* Make room for virtqueue hdr (also change xdpf->headroom?) */
560         xdpf->data -= vi->hdr_len;
561         /* Zero header and leave csum up to XDP layers */
562         hdr = xdpf->data;
563         memset(hdr, 0, vi->hdr_len);
564         xdpf->len   += vi->hdr_len;
565
566         sg_init_one(sq->sg, xdpf->data, xdpf->len);
567
568         err = virtqueue_add_outbuf(sq->vq, sq->sg, 1, xdp_to_ptr(xdpf),
569                                    GFP_ATOMIC);
570         if (unlikely(err))
571                 return -ENOSPC; /* Caller handle free/refcnt */
572
573         return 0;
574 }
575
576 /* When vi->curr_queue_pairs > nr_cpu_ids, the txq/sq is only used for XDP tx on
577  * the current cpu, so it does not need to be locked.
578  *
579  * Here we use a macro instead of inline functions because we have to deal with
580  * three issues at the same time: 1. the choice of sq, 2. deciding whether to
581  * lock/unlock the txq, and 3. keeping sparse happy. It is difficult for two inline
582  * functions to solve these three problems at the same time.
583  */
584 #define virtnet_xdp_get_sq(vi) ({                                       \
585         int cpu = smp_processor_id();                                   \
586         struct netdev_queue *txq;                                       \
587         typeof(vi) v = (vi);                                            \
588         unsigned int qp;                                                \
589                                                                         \
590         if (v->curr_queue_pairs > nr_cpu_ids) {                         \
591                 qp = v->curr_queue_pairs - v->xdp_queue_pairs;          \
592                 qp += cpu;                                              \
593                 txq = netdev_get_tx_queue(v->dev, qp);                  \
594                 __netif_tx_acquire(txq);                                \
595         } else {                                                        \
596                 qp = cpu % v->curr_queue_pairs;                         \
597                 txq = netdev_get_tx_queue(v->dev, qp);                  \
598                 __netif_tx_lock(txq, cpu);                              \
599         }                                                               \
600         v->sq + qp;                                                     \
601 })
602
603 #define virtnet_xdp_put_sq(vi, q) {                                     \
604         struct netdev_queue *txq;                                       \
605         typeof(vi) v = (vi);                                            \
606                                                                         \
607         txq = netdev_get_tx_queue(v->dev, (q) - v->sq);                 \
608         if (v->curr_queue_pairs > nr_cpu_ids)                           \
609                 __netif_tx_release(txq);                                \
610         else                                                            \
611                 __netif_tx_unlock(txq);                                 \
612 }
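/* The two macros above are always used as a pair; sketch of the usage
 * pattern (see virtnet_xdp_xmit() below for the real user):
 *
 *	struct send_queue *sq = virtnet_xdp_get_sq(vi);
 *	... add xdp frames to sq->vq, kick if needed ...
 *	virtnet_xdp_put_sq(vi, sq);
 */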
613
614 static int virtnet_xdp_xmit(struct net_device *dev,
615                             int n, struct xdp_frame **frames, u32 flags)
616 {
617         struct virtnet_info *vi = netdev_priv(dev);
618         struct receive_queue *rq = vi->rq;
619         struct bpf_prog *xdp_prog;
620         struct send_queue *sq;
621         unsigned int len;
622         int packets = 0;
623         int bytes = 0;
624         int nxmit = 0;
625         int kicks = 0;
626         void *ptr;
627         int ret;
628         int i;
629
630         /* Only allow ndo_xdp_xmit if XDP is loaded on dev, as this
631          * indicates XDP resources have been successfully allocated.
632          */
633         xdp_prog = rcu_access_pointer(rq->xdp_prog);
634         if (!xdp_prog)
635                 return -ENXIO;
636
637         sq = virtnet_xdp_get_sq(vi);
638
639         if (unlikely(flags & ~XDP_XMIT_FLAGS_MASK)) {
640                 ret = -EINVAL;
641                 goto out;
642         }
643
644         /* Free up any pending old buffers before queueing new ones. */
645         while ((ptr = virtqueue_get_buf(sq->vq, &len)) != NULL) {
646                 if (likely(is_xdp_frame(ptr))) {
647                         struct xdp_frame *frame = ptr_to_xdp(ptr);
648
649                         bytes += frame->len;
650                         xdp_return_frame(frame);
651                 } else {
652                         struct sk_buff *skb = ptr;
653
654                         bytes += skb->len;
655                         napi_consume_skb(skb, false);
656                 }
657                 packets++;
658         }
659
660         for (i = 0; i < n; i++) {
661                 struct xdp_frame *xdpf = frames[i];
662
663                 if (__virtnet_xdp_xmit_one(vi, sq, xdpf))
664                         break;
665                 nxmit++;
666         }
667         ret = nxmit;
668
669         if (flags & XDP_XMIT_FLUSH) {
670                 if (virtqueue_kick_prepare(sq->vq) && virtqueue_notify(sq->vq))
671                         kicks = 1;
672         }
673 out:
674         u64_stats_update_begin(&sq->stats.syncp);
675         sq->stats.bytes += bytes;
676         sq->stats.packets += packets;
677         sq->stats.xdp_tx += n;
678         sq->stats.xdp_tx_drops += n - nxmit;
679         sq->stats.kicks += kicks;
680         u64_stats_update_end(&sq->stats.syncp);
681
682         virtnet_xdp_put_sq(vi, sq);
683         return ret;
684 }
685
686 static unsigned int virtnet_get_headroom(struct virtnet_info *vi)
687 {
688         return vi->xdp_enabled ? VIRTIO_XDP_HEADROOM : 0;
689 }
690
691 /* We copy the packet for XDP in the following cases:
692  *
693  * 1) Packet is scattered across multiple rx buffers.
694  * 2) Headroom space is insufficient.
695  *
696  * This is inefficient but it's a temporary condition that
697  * we hit right after XDP is enabled and until the queue is refilled
698  * with large buffers with sufficient headroom - so it should affect
699  * at most queue size packets.
700  * Afterwards, the conditions to enable
701  * XDP should preclude the underlying device from sending packets
702  * across multiple buffers (num_buf > 1), and we make sure buffers
703  * have enough headroom.
704  */
705 static struct page *xdp_linearize_page(struct receive_queue *rq,
706                                        u16 *num_buf,
707                                        struct page *p,
708                                        int offset,
709                                        int page_off,
710                                        unsigned int *len)
711 {
712         struct page *page = alloc_page(GFP_ATOMIC);
713
714         if (!page)
715                 return NULL;
716
717         memcpy(page_address(page) + page_off, page_address(p) + offset, *len);
718         page_off += *len;
719
720         while (--*num_buf) {
721                 int tailroom = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
722                 unsigned int buflen;
723                 void *buf;
724                 int off;
725
726                 buf = virtqueue_get_buf(rq->vq, &buflen);
727                 if (unlikely(!buf))
728                         goto err_buf;
729
730                 p = virt_to_head_page(buf);
731                 off = buf - page_address(p);
732
733                 /* guard against a misconfigured or uncooperative backend that
734                  * is sending packets larger than the MTU.
735                  */
736                 if ((page_off + buflen + tailroom) > PAGE_SIZE) {
737                         put_page(p);
738                         goto err_buf;
739                 }
740
741                 memcpy(page_address(page) + page_off,
742                        page_address(p) + off, buflen);
743                 page_off += buflen;
744                 put_page(p);
745         }
746
747         /* Headroom does not contribute to packet length */
748         *len = page_off - VIRTIO_XDP_HEADROOM;
749         return page;
750 err_buf:
751         __free_pages(page, 0);
752         return NULL;
753 }
754
755 static struct sk_buff *receive_small(struct net_device *dev,
756                                      struct virtnet_info *vi,
757                                      struct receive_queue *rq,
758                                      void *buf, void *ctx,
759                                      unsigned int len,
760                                      unsigned int *xdp_xmit,
761                                      struct virtnet_rq_stats *stats)
762 {
763         struct sk_buff *skb;
764         struct bpf_prog *xdp_prog;
765         unsigned int xdp_headroom = (unsigned long)ctx;
766         unsigned int header_offset = VIRTNET_RX_PAD + xdp_headroom;
767         unsigned int headroom = vi->hdr_len + header_offset;
768         unsigned int buflen = SKB_DATA_ALIGN(GOOD_PACKET_LEN + headroom) +
769                               SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
770         struct page *page = virt_to_head_page(buf);
771         unsigned int delta = 0;
772         struct page *xdp_page;
773         int err;
774         unsigned int metasize = 0;
775
776         len -= vi->hdr_len;
777         stats->bytes += len;
778
779         if (unlikely(len > GOOD_PACKET_LEN)) {
780                 pr_debug("%s: rx error: len %u exceeds max size %d\n",
781                          dev->name, len, GOOD_PACKET_LEN);
782                 dev->stats.rx_length_errors++;
783                 goto err;
784         }
785
786         if (likely(!vi->xdp_enabled)) {
787                 xdp_prog = NULL;
788                 goto skip_xdp;
789         }
790
791         rcu_read_lock();
792         xdp_prog = rcu_dereference(rq->xdp_prog);
793         if (xdp_prog) {
794                 struct virtio_net_hdr_mrg_rxbuf *hdr = buf + header_offset;
795                 struct xdp_frame *xdpf;
796                 struct xdp_buff xdp;
797                 void *orig_data;
798                 u32 act;
799
800                 if (unlikely(hdr->hdr.gso_type))
801                         goto err_xdp;
802
803                 if (unlikely(xdp_headroom < virtnet_get_headroom(vi))) {
804                         int offset = buf - page_address(page) + header_offset;
805                         unsigned int tlen = len + vi->hdr_len;
806                         u16 num_buf = 1;
807
808                         xdp_headroom = virtnet_get_headroom(vi);
809                         header_offset = VIRTNET_RX_PAD + xdp_headroom;
810                         headroom = vi->hdr_len + header_offset;
811                         buflen = SKB_DATA_ALIGN(GOOD_PACKET_LEN + headroom) +
812                                  SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
813                         xdp_page = xdp_linearize_page(rq, &num_buf, page,
814                                                       offset, header_offset,
815                                                       &tlen);
816                         if (!xdp_page)
817                                 goto err_xdp;
818
819                         buf = page_address(xdp_page);
820                         put_page(page);
821                         page = xdp_page;
822                 }
823
824                 xdp_init_buff(&xdp, buflen, &rq->xdp_rxq);
825                 xdp_prepare_buff(&xdp, buf + VIRTNET_RX_PAD + vi->hdr_len,
826                                  xdp_headroom, len, true);
827                 orig_data = xdp.data;
828                 act = bpf_prog_run_xdp(xdp_prog, &xdp);
829                 stats->xdp_packets++;
830
831                 switch (act) {
832                 case XDP_PASS:
833                         /* Recalculate length in case bpf program changed it */
834                         delta = orig_data - xdp.data;
835                         len = xdp.data_end - xdp.data;
836                         metasize = xdp.data - xdp.data_meta;
837                         break;
838                 case XDP_TX:
839                         stats->xdp_tx++;
840                         xdpf = xdp_convert_buff_to_frame(&xdp);
841                         if (unlikely(!xdpf))
842                                 goto err_xdp;
843                         err = virtnet_xdp_xmit(dev, 1, &xdpf, 0);
844                         if (unlikely(!err)) {
845                                 xdp_return_frame_rx_napi(xdpf);
846                         } else if (unlikely(err < 0)) {
847                                 trace_xdp_exception(vi->dev, xdp_prog, act);
848                                 goto err_xdp;
849                         }
850                         *xdp_xmit |= VIRTIO_XDP_TX;
851                         rcu_read_unlock();
852                         goto xdp_xmit;
853                 case XDP_REDIRECT:
854                         stats->xdp_redirects++;
855                         err = xdp_do_redirect(dev, &xdp, xdp_prog);
856                         if (err)
857                                 goto err_xdp;
858                         *xdp_xmit |= VIRTIO_XDP_REDIR;
859                         rcu_read_unlock();
860                         goto xdp_xmit;
861                 default:
862                         bpf_warn_invalid_xdp_action(vi->dev, xdp_prog, act);
863                         fallthrough;
864                 case XDP_ABORTED:
865                         trace_xdp_exception(vi->dev, xdp_prog, act);
866                         goto err_xdp;
867                 case XDP_DROP:
868                         goto err_xdp;
869                 }
870         }
871         rcu_read_unlock();
872
873 skip_xdp:
874         skb = build_skb(buf, buflen);
875         if (!skb)
876                 goto err;
877         skb_reserve(skb, headroom - delta);
878         skb_put(skb, len);
879         if (!xdp_prog) {
880                 buf += header_offset;
881                 memcpy(skb_vnet_hdr(skb), buf, vi->hdr_len);
882         } /* keep zeroed vnet hdr since XDP is loaded */
883
884         if (metasize)
885                 skb_metadata_set(skb, metasize);
886
887         return skb;
888
889 err_xdp:
890         rcu_read_unlock();
891         stats->xdp_drops++;
892 err:
893         stats->drops++;
894         put_page(page);
895 xdp_xmit:
896         return NULL;
897 }
898
899 static struct sk_buff *receive_big(struct net_device *dev,
900                                    struct virtnet_info *vi,
901                                    struct receive_queue *rq,
902                                    void *buf,
903                                    unsigned int len,
904                                    struct virtnet_rq_stats *stats)
905 {
906         struct page *page = buf;
907         struct sk_buff *skb =
908                 page_to_skb(vi, rq, page, 0, len, PAGE_SIZE, true, 0, 0);
909
910         stats->bytes += len - vi->hdr_len;
911         if (unlikely(!skb))
912                 goto err;
913
914         return skb;
915
916 err:
917         stats->drops++;
918         give_pages(rq, page);
919         return NULL;
920 }
921
922 static struct sk_buff *receive_mergeable(struct net_device *dev,
923                                          struct virtnet_info *vi,
924                                          struct receive_queue *rq,
925                                          void *buf,
926                                          void *ctx,
927                                          unsigned int len,
928                                          unsigned int *xdp_xmit,
929                                          struct virtnet_rq_stats *stats)
930 {
931         struct virtio_net_hdr_mrg_rxbuf *hdr = buf;
932         u16 num_buf = virtio16_to_cpu(vi->vdev, hdr->num_buffers);
933         struct page *page = virt_to_head_page(buf);
934         int offset = buf - page_address(page);
935         struct sk_buff *head_skb, *curr_skb;
936         struct bpf_prog *xdp_prog;
937         unsigned int truesize = mergeable_ctx_to_truesize(ctx);
938         unsigned int headroom = mergeable_ctx_to_headroom(ctx);
939         unsigned int metasize = 0;
940         unsigned int frame_sz;
941         int err;
942
943         head_skb = NULL;
944         stats->bytes += len - vi->hdr_len;
945
946         if (unlikely(len > truesize)) {
947                 pr_debug("%s: rx error: len %u exceeds truesize %lu\n",
948                          dev->name, len, (unsigned long)ctx);
949                 dev->stats.rx_length_errors++;
950                 goto err_skb;
951         }
952
953         if (likely(!vi->xdp_enabled)) {
954                 xdp_prog = NULL;
955                 goto skip_xdp;
956         }
957
958         rcu_read_lock();
959         xdp_prog = rcu_dereference(rq->xdp_prog);
960         if (xdp_prog) {
961                 struct xdp_frame *xdpf;
962                 struct page *xdp_page;
963                 struct xdp_buff xdp;
964                 void *data;
965                 u32 act;
966
967                 /* Transient failure which in theory could occur if
968                  * in-flight packets from before XDP was enabled reach
969                  * the receive path after XDP is loaded.
970                  */
971                 if (unlikely(hdr->hdr.gso_type))
972                         goto err_xdp;
973
974                 /* Buffers with headroom use PAGE_SIZE as alloc size,
975                  * see add_recvbuf_mergeable() + get_mergeable_buf_len()
976                  */
977                 frame_sz = headroom ? PAGE_SIZE : truesize;
978
979                 /* This happens when rx buffer size is underestimated
980                  * or headroom is not enough because the buffer
981                  * was refilled before XDP was set. This should only
982                  * happen for the first several packets, so we don't
983                  * care much about its performance.
984                  */
985                 if (unlikely(num_buf > 1 ||
986                              headroom < virtnet_get_headroom(vi))) {
987                         /* linearize data for XDP */
988                         xdp_page = xdp_linearize_page(rq, &num_buf,
989                                                       page, offset,
990                                                       VIRTIO_XDP_HEADROOM,
991                                                       &len);
992                         frame_sz = PAGE_SIZE;
993
994                         if (!xdp_page)
995                                 goto err_xdp;
996                         offset = VIRTIO_XDP_HEADROOM;
997                 } else {
998                         xdp_page = page;
999                 }
1000
1001                 /* Allow consuming headroom but reserve enough space to push
1002                  * the descriptor on if we get an XDP_TX return code.
1003                  */
1004                 data = page_address(xdp_page) + offset;
1005                 xdp_init_buff(&xdp, frame_sz - vi->hdr_len, &rq->xdp_rxq);
1006                 xdp_prepare_buff(&xdp, data - VIRTIO_XDP_HEADROOM + vi->hdr_len,
1007                                  VIRTIO_XDP_HEADROOM, len - vi->hdr_len, true);
1008
1009                 act = bpf_prog_run_xdp(xdp_prog, &xdp);
1010                 stats->xdp_packets++;
1011
1012                 switch (act) {
1013                 case XDP_PASS:
1014                         metasize = xdp.data - xdp.data_meta;
1015
1016                         /* recalculate offset to account for any header
1017                          * adjustments, minus the metasize, so that the
1018                          * metadata is copied in page_to_skb(). Note other cases do not
1019                          * build an skb and avoid using offset
1020                          */
1021                         offset = xdp.data - page_address(xdp_page) -
1022                                  vi->hdr_len - metasize;
1023
1024                         /* recalculate len if xdp.data, xdp.data_end or
1025                          * xdp.data_meta were adjusted
1026                          */
1027                         len = xdp.data_end - xdp.data + vi->hdr_len + metasize;
1028
1029                 /* recalculate headroom if xdp.data or xdp.data_meta
1030                          * were adjusted, note that offset should always point
1031                          * to the start of the reserved bytes for virtio_net
1032                          * header which are followed by xdp.data, that means
1033                          * that offset is equal to the headroom (when buf is
1034                          * starting at the beginning of the page, otherwise
1035                          * there is a base offset inside the page) but it's used
1036                          * with a different starting point (buf start) than
1037                          * xdp.data (buf start + vnet hdr size). If xdp.data or
1038                          * data_meta were adjusted by the xdp prog then the
1039                          * headroom size has changed and so has the offset, we
1040                          * can use data_hard_start, which points at buf start +
1041                          * vnet hdr size, to calculate the new headroom and use
1042                          * it later to compute buf start in page_to_skb()
1043                          */
1044                         headroom = xdp.data - xdp.data_hard_start - metasize;
1045
1046                         /* We can only create skb based on xdp_page. */
1047                         if (unlikely(xdp_page != page)) {
1048                                 rcu_read_unlock();
1049                                 put_page(page);
1050                                 head_skb = page_to_skb(vi, rq, xdp_page, offset,
1051                                                        len, PAGE_SIZE, false,
1052                                                        metasize,
1053                                                        headroom);
1054                                 return head_skb;
1055                         }
1056                         break;
1057                 case XDP_TX:
1058                         stats->xdp_tx++;
1059                         xdpf = xdp_convert_buff_to_frame(&xdp);
1060                         if (unlikely(!xdpf))
1061                                 goto err_xdp;
1062                         err = virtnet_xdp_xmit(dev, 1, &xdpf, 0);
1063                         if (unlikely(!err)) {
1064                                 xdp_return_frame_rx_napi(xdpf);
1065                         } else if (unlikely(err < 0)) {
1066                                 trace_xdp_exception(vi->dev, xdp_prog, act);
1067                                 if (unlikely(xdp_page != page))
1068                                         put_page(xdp_page);
1069                                 goto err_xdp;
1070                         }
1071                         *xdp_xmit |= VIRTIO_XDP_TX;
1072                         if (unlikely(xdp_page != page))
1073                                 put_page(page);
1074                         rcu_read_unlock();
1075                         goto xdp_xmit;
1076                 case XDP_REDIRECT:
1077                         stats->xdp_redirects++;
1078                         err = xdp_do_redirect(dev, &xdp, xdp_prog);
1079                         if (err) {
1080                                 if (unlikely(xdp_page != page))
1081                                         put_page(xdp_page);
1082                                 goto err_xdp;
1083                         }
1084                         *xdp_xmit |= VIRTIO_XDP_REDIR;
1085                         if (unlikely(xdp_page != page))
1086                                 put_page(page);
1087                         rcu_read_unlock();
1088                         goto xdp_xmit;
1089                 default:
1090                         bpf_warn_invalid_xdp_action(vi->dev, xdp_prog, act);
1091                         fallthrough;
1092                 case XDP_ABORTED:
1093                         trace_xdp_exception(vi->dev, xdp_prog, act);
1094                         fallthrough;
1095                 case XDP_DROP:
1096                         if (unlikely(xdp_page != page))
1097                                 __free_pages(xdp_page, 0);
1098                         goto err_xdp;
1099                 }
1100         }
1101         rcu_read_unlock();
1102
1103 skip_xdp:
1104         head_skb = page_to_skb(vi, rq, page, offset, len, truesize, !xdp_prog,
1105                                metasize, headroom);
1106         curr_skb = head_skb;
1107
1108         if (unlikely(!curr_skb))
1109                 goto err_skb;
1110         while (--num_buf) {
1111                 int num_skb_frags;
1112
1113                 buf = virtqueue_get_buf_ctx(rq->vq, &len, &ctx);
1114                 if (unlikely(!buf)) {
1115                         pr_debug("%s: rx error: %d buffers out of %d missing\n",
1116                                  dev->name, num_buf,
1117                                  virtio16_to_cpu(vi->vdev,
1118                                                  hdr->num_buffers));
1119                         dev->stats.rx_length_errors++;
1120                         goto err_buf;
1121                 }
1122
1123                 stats->bytes += len;
1124                 page = virt_to_head_page(buf);
1125
1126                 truesize = mergeable_ctx_to_truesize(ctx);
1127                 if (unlikely(len > truesize)) {
1128                         pr_debug("%s: rx error: len %u exceeds truesize %lu\n",
1129                                  dev->name, len, (unsigned long)ctx);
1130                         dev->stats.rx_length_errors++;
1131                         goto err_skb;
1132                 }
1133
1134                 num_skb_frags = skb_shinfo(curr_skb)->nr_frags;
1135                 if (unlikely(num_skb_frags == MAX_SKB_FRAGS)) {
1136                         struct sk_buff *nskb = alloc_skb(0, GFP_ATOMIC);
1137
1138                         if (unlikely(!nskb))
1139                                 goto err_skb;
1140                         if (curr_skb == head_skb)
1141                                 skb_shinfo(curr_skb)->frag_list = nskb;
1142                         else
1143                                 curr_skb->next = nskb;
1144                         curr_skb = nskb;
1145                         head_skb->truesize += nskb->truesize;
1146                         num_skb_frags = 0;
1147                 }
1148                 if (curr_skb != head_skb) {
1149                         head_skb->data_len += len;
1150                         head_skb->len += len;
1151                         head_skb->truesize += truesize;
1152                 }
1153                 offset = buf - page_address(page);
1154                 if (skb_can_coalesce(curr_skb, num_skb_frags, page, offset)) {
1155                         put_page(page);
1156                         skb_coalesce_rx_frag(curr_skb, num_skb_frags - 1,
1157                                              len, truesize);
1158                 } else {
1159                         skb_add_rx_frag(curr_skb, num_skb_frags, page,
1160                                         offset, len, truesize);
1161                 }
1162         }
1163
1164         ewma_pkt_len_add(&rq->mrg_avg_pkt_len, head_skb->len);
1165         return head_skb;
1166
1167 err_xdp:
1168         rcu_read_unlock();
1169         stats->xdp_drops++;
1170 err_skb:
1171         put_page(page);
1172         while (num_buf-- > 1) {
1173                 buf = virtqueue_get_buf(rq->vq, &len);
1174                 if (unlikely(!buf)) {
1175                         pr_debug("%s: rx error: %d buffers missing\n",
1176                                  dev->name, num_buf);
1177                         dev->stats.rx_length_errors++;
1178                         break;
1179                 }
1180                 stats->bytes += len;
1181                 page = virt_to_head_page(buf);
1182                 put_page(page);
1183         }
1184 err_buf:
1185         stats->drops++;
1186         dev_kfree_skb(head_skb);
1187 xdp_xmit:
1188         return NULL;
1189 }
1190
1191 static void virtio_skb_set_hash(const struct virtio_net_hdr_v1_hash *hdr_hash,
1192                                 struct sk_buff *skb)
1193 {
1194         enum pkt_hash_types rss_hash_type;
1195
1196         if (!hdr_hash || !skb)
1197                 return;
1198
1199         switch ((int)hdr_hash->hash_report) {
1200         case VIRTIO_NET_HASH_REPORT_TCPv4:
1201         case VIRTIO_NET_HASH_REPORT_UDPv4:
1202         case VIRTIO_NET_HASH_REPORT_TCPv6:
1203         case VIRTIO_NET_HASH_REPORT_UDPv6:
1204         case VIRTIO_NET_HASH_REPORT_TCPv6_EX:
1205         case VIRTIO_NET_HASH_REPORT_UDPv6_EX:
1206                 rss_hash_type = PKT_HASH_TYPE_L4;
1207                 break;
1208         case VIRTIO_NET_HASH_REPORT_IPv4:
1209         case VIRTIO_NET_HASH_REPORT_IPv6:
1210         case VIRTIO_NET_HASH_REPORT_IPv6_EX:
1211                 rss_hash_type = PKT_HASH_TYPE_L3;
1212                 break;
1213         case VIRTIO_NET_HASH_REPORT_NONE:
1214         default:
1215                 rss_hash_type = PKT_HASH_TYPE_NONE;
1216         }
1217         skb_set_hash(skb, (unsigned int)hdr_hash->hash_value, rss_hash_type);
1218 }
1219
1220 static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
1221                         void *buf, unsigned int len, void **ctx,
1222                         unsigned int *xdp_xmit,
1223                         struct virtnet_rq_stats *stats)
1224 {
1225         struct net_device *dev = vi->dev;
1226         struct sk_buff *skb;
1227         struct virtio_net_hdr_mrg_rxbuf *hdr;
1228
1229         if (unlikely(len < vi->hdr_len + ETH_HLEN)) {
1230                 pr_debug("%s: short packet %i\n", dev->name, len);
1231                 dev->stats.rx_length_errors++;
1232                 if (vi->mergeable_rx_bufs) {
1233                         put_page(virt_to_head_page(buf));
1234                 } else if (vi->big_packets) {
1235                         give_pages(rq, buf);
1236                 } else {
1237                         put_page(virt_to_head_page(buf));
1238                 }
1239                 return;
1240         }
1241
1242         if (vi->mergeable_rx_bufs)
1243                 skb = receive_mergeable(dev, vi, rq, buf, ctx, len, xdp_xmit,
1244                                         stats);
1245         else if (vi->big_packets)
1246                 skb = receive_big(dev, vi, rq, buf, len, stats);
1247         else
1248                 skb = receive_small(dev, vi, rq, buf, ctx, len, xdp_xmit, stats);
1249
1250         if (unlikely(!skb))
1251                 return;
1252
1253         hdr = skb_vnet_hdr(skb);
1254         if (dev->features & NETIF_F_RXHASH && vi->has_rss_hash_report)
1255                 virtio_skb_set_hash((const struct virtio_net_hdr_v1_hash *)hdr, skb);
1256
1257         if (hdr->hdr.flags & VIRTIO_NET_HDR_F_DATA_VALID)
1258                 skb->ip_summed = CHECKSUM_UNNECESSARY;
1259
1260         if (virtio_net_hdr_to_skb(skb, &hdr->hdr,
1261                                   virtio_is_little_endian(vi->vdev))) {
1262                 net_warn_ratelimited("%s: bad gso: type: %u, size: %u\n",
1263                                      dev->name, hdr->hdr.gso_type,
1264                                      hdr->hdr.gso_size);
1265                 goto frame_err;
1266         }
1267
1268         skb_record_rx_queue(skb, vq2rxq(rq->vq));
1269         skb->protocol = eth_type_trans(skb, dev);
1270         pr_debug("Receiving skb proto 0x%04x len %i type %i\n",
1271                  ntohs(skb->protocol), skb->len, skb->pkt_type);
1272
1273         napi_gro_receive(&rq->napi, skb);
1274         return;
1275
1276 frame_err:
1277         dev->stats.rx_frame_errors++;
1278         dev_kfree_skb(skb);
1279 }
1280
1281 /* Unlike with mergeable buffers, all buffers are allocated with the
1282  * same size, except for the headroom. For this reason we do
1283  * not need to use mergeable_len_to_ctx here - it is enough
1284  * to store the headroom as the context, ignoring the truesize.
1285  */
1286 static int add_recvbuf_small(struct virtnet_info *vi, struct receive_queue *rq,
1287                              gfp_t gfp)
1288 {
1289         struct page_frag *alloc_frag = &rq->alloc_frag;
1290         char *buf;
1291         unsigned int xdp_headroom = virtnet_get_headroom(vi);
1292         void *ctx = (void *)(unsigned long)xdp_headroom;
1293         int len = vi->hdr_len + VIRTNET_RX_PAD + GOOD_PACKET_LEN + xdp_headroom;
1294         int err;
1295
1296         len = SKB_DATA_ALIGN(len) +
1297               SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
1298         if (unlikely(!skb_page_frag_refill(len, alloc_frag, gfp)))
1299                 return -ENOMEM;
1300
1301         buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
1302         get_page(alloc_frag->page);
1303         alloc_frag->offset += len;
1304         sg_init_one(rq->sg, buf + VIRTNET_RX_PAD + xdp_headroom,
1305                     vi->hdr_len + GOOD_PACKET_LEN);
1306         err = virtqueue_add_inbuf_ctx(rq->vq, rq->sg, 1, buf, ctx, gfp);
1307         if (err < 0)
1308                 put_page(virt_to_head_page(buf));
1309         return err;
1310 }
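/* Sketch of the small receive buffer laid out above (offsets depend on
 * whether XDP headroom is reserved):
 *
 *	buf
 *	|- VIRTNET_RX_PAD -|- xdp_headroom -|- hdr_len + GOOD_PACKET_LEN -|
 *	                                    ^
 *	                        start of the area posted to the device
 *
 * followed by SKB_DATA_ALIGN(sizeof(struct skb_shared_info)) of tailroom
 * so that receive_small() can build_skb() around the buffer.
 */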
1311
1312 static int add_recvbuf_big(struct virtnet_info *vi, struct receive_queue *rq,
1313                            gfp_t gfp)
1314 {
1315         struct page *first, *list = NULL;
1316         char *p;
1317         int i, err, offset;
1318
1319         sg_init_table(rq->sg, MAX_SKB_FRAGS + 2);
1320
1321         /* page in rq->sg[MAX_SKB_FRAGS + 1] is list tail */
1322         for (i = MAX_SKB_FRAGS + 1; i > 1; --i) {
1323                 first = get_a_page(rq, gfp);
1324                 if (!first) {
1325                         if (list)
1326                                 give_pages(rq, list);
1327                         return -ENOMEM;
1328                 }
1329                 sg_set_buf(&rq->sg[i], page_address(first), PAGE_SIZE);
1330
1331                 /* chain new page in list head to match sg */
1332                 first->private = (unsigned long)list;
1333                 list = first;
1334         }
1335
1336         first = get_a_page(rq, gfp);
1337         if (!first) {
1338                 give_pages(rq, list);
1339                 return -ENOMEM;
1340         }
1341         p = page_address(first);
1342
1343         /* rq->sg[0], rq->sg[1] share the same page */
1344         /* a separate rq->sg[0] for the header - required in case !any_header_sg */
1345         sg_set_buf(&rq->sg[0], p, vi->hdr_len);
1346
1347         /* rq->sg[1] for data packet, from offset */
1348         offset = sizeof(struct padded_vnet_hdr);
1349         sg_set_buf(&rq->sg[1], p + offset, PAGE_SIZE - offset);
1350
1351         /* chain first in list head */
1352         first->private = (unsigned long)list;
1353         err = virtqueue_add_inbuf(rq->vq, rq->sg, MAX_SKB_FRAGS + 2,
1354                                   first, gfp);
1355         if (err < 0)
1356                 give_pages(rq, first);
1357
1358         return err;
1359 }
1360
1361 static unsigned int get_mergeable_buf_len(struct receive_queue *rq,
1362                                           struct ewma_pkt_len *avg_pkt_len,
1363                                           unsigned int room)
1364 {
1365         struct virtnet_info *vi = rq->vq->vdev->priv;
1366         const size_t hdr_len = vi->hdr_len;
1367         unsigned int len;
1368
1369         if (room)
1370                 return PAGE_SIZE - room;
1371
1372         len = hdr_len + clamp_t(unsigned int, ewma_pkt_len_read(avg_pkt_len),
1373                                 rq->min_buf_len, PAGE_SIZE - hdr_len);
1374
1375         return ALIGN(len, L1_CACHE_BYTES);
1376 }
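/* Example for the helper above (illustrative, assuming hdr_len == 12 and
 * 64 byte cache lines): with an EWMA average of 1500 bytes, the average is
 * clamped to [rq->min_buf_len, PAGE_SIZE - hdr_len] and the refill size
 * becomes ALIGN(12 + 1500, L1_CACHE_BYTES) == 1536 bytes; with headroom
 * (room != 0) the whole page minus the reserved room is used instead.
 */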
1377
1378 static int add_recvbuf_mergeable(struct virtnet_info *vi,
1379                                  struct receive_queue *rq, gfp_t gfp)
1380 {
1381         struct page_frag *alloc_frag = &rq->alloc_frag;
1382         unsigned int headroom = virtnet_get_headroom(vi);
1383         unsigned int tailroom = headroom ? sizeof(struct skb_shared_info) : 0;
1384         unsigned int room = SKB_DATA_ALIGN(headroom + tailroom);
1385         char *buf;
1386         void *ctx;
1387         int err;
1388         unsigned int len, hole;
1389
1390         /* Extra tailroom is needed to satisfy XDP's assumption. This
1391          * means rx frags coalescing won't work, but considering we've
1392          * disabled GSO for XDP, it won't be a big issue.
1393          */
1394         len = get_mergeable_buf_len(rq, &rq->mrg_avg_pkt_len, room);
1395         if (unlikely(!skb_page_frag_refill(len + room, alloc_frag, gfp)))
1396                 return -ENOMEM;
1397
1398         buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
1399         buf += headroom; /* advance address leaving hole at front of pkt */
1400         get_page(alloc_frag->page);
1401         alloc_frag->offset += len + room;
1402         hole = alloc_frag->size - alloc_frag->offset;
1403         if (hole < len + room) {
1404                 /* To avoid internal fragmentation, if there is very likely not
1405                  * enough space for another buffer, add the remaining space to
1406                  * the current buffer.
1407                  */
1408                 len += hole;
1409                 alloc_frag->offset += hole;
1410         }
1411
1412         sg_init_one(rq->sg, buf, len);
1413         ctx = mergeable_len_to_ctx(len, headroom);
1414         err = virtqueue_add_inbuf_ctx(rq->vq, rq->sg, 1, buf, ctx, gfp);
1415         if (err < 0)
1416                 put_page(virt_to_head_page(buf));
1417
1418         return err;
1419 }
1420
1421 /*
1422  * Returns false if we couldn't fill entirely (OOM).
1423  *
1424  * Normally run in the receive path, but can also be run from ndo_open
1425  * before we're receiving packets, or from refill_work which is
1426  * careful to disable receiving (using napi_disable).
1427  */
1428 static bool try_fill_recv(struct virtnet_info *vi, struct receive_queue *rq,
1429                           gfp_t gfp)
1430 {
1431         int err;
1432         bool oom;
1433
1434         do {
1435                 if (vi->mergeable_rx_bufs)
1436                         err = add_recvbuf_mergeable(vi, rq, gfp);
1437                 else if (vi->big_packets)
1438                         err = add_recvbuf_big(vi, rq, gfp);
1439                 else
1440                         err = add_recvbuf_small(vi, rq, gfp);
1441
1442                 oom = err == -ENOMEM;
1443                 if (err)
1444                         break;
1445         } while (rq->vq->num_free);
1446         if (virtqueue_kick_prepare(rq->vq) && virtqueue_notify(rq->vq)) {
1447                 unsigned long flags;
1448
1449                 flags = u64_stats_update_begin_irqsave(&rq->stats.syncp);
1450                 rq->stats.kicks++;
1451                 u64_stats_update_end_irqrestore(&rq->stats.syncp, flags);
1452         }
1453
1454         return !oom;
1455 }
1456
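     /* Interrupt callback for a receive virtqueue: schedule NAPI for that queue. */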
1457 static void skb_recv_done(struct virtqueue *rvq)
1458 {
1459         struct virtnet_info *vi = rvq->vdev->priv;
1460         struct receive_queue *rq = &vi->rq[vq2rxq(rvq)];
1461
1462         virtqueue_napi_schedule(&rq->napi, rvq);
1463 }
1464
1465 static void virtnet_napi_enable(struct virtqueue *vq, struct napi_struct *napi)
1466 {
1467         napi_enable(napi);
1468
1469         /* If all buffers were filled by the other side before we enabled napi,
1470          * we won't get another interrupt, so process any outstanding packets now.
1471          * Call local_bh_enable after to trigger softIRQ processing.
1472          */
1473         local_bh_disable();
1474         virtqueue_napi_schedule(napi, vq);
1475         local_bh_enable();
1476 }
1477
1478 static void virtnet_napi_tx_enable(struct virtnet_info *vi,
1479                                    struct virtqueue *vq,
1480                                    struct napi_struct *napi)
1481 {
1482         if (!napi->weight)
1483                 return;
1484
1485         /* Tx napi touches cachelines on the cpu handling tx interrupts. Only
1486          * enable the feature if this is likely affine with the transmit path.
1487          */
1488         if (!vi->affinity_hint_set) {
1489                 napi->weight = 0;
1490                 return;
1491         }
1492
1493         virtnet_napi_enable(vq, napi);
1494 }
1495
1496 static void virtnet_napi_tx_disable(struct napi_struct *napi)
1497 {
1498         if (napi->weight)
1499                 napi_disable(napi);
1500 }
1501
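     /*
      * Delayed work that retries filling the receive rings after an earlier
      * allocation failure. NAPI is disabled around try_fill_recv() so the
      * refill can't race with the receive path; if a ring still can't be
      * filled completely, the work is rescheduled.
      */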
1502 static void refill_work(struct work_struct *work)
1503 {
1504         struct virtnet_info *vi =
1505                 container_of(work, struct virtnet_info, refill.work);
1506         bool still_empty;
1507         int i;
1508
1509         for (i = 0; i < vi->curr_queue_pairs; i++) {
1510                 struct receive_queue *rq = &vi->rq[i];
1511
1512                 napi_disable(&rq->napi);
1513                 still_empty = !try_fill_recv(vi, rq, GFP_KERNEL);
1514                 virtnet_napi_enable(rq->vq, &rq->napi);
1515
1516                 /* In theory, this can happen: if we don't get any buffers in,
1517                  * we will *never* try to fill again.
1518                  */
1519                 if (still_empty)
1520                         schedule_delayed_work(&vi->refill, HZ/2);
1521         }
1522 }
1523
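     /*
      * Receive-side NAPI work: pull completed buffers off the virtqueue,
      * hand them to receive_buf(), refill the ring when it runs low, and
      * fold this run's counters into the queue statistics. Returns the
      * number of packets processed, bounded by the NAPI budget.
      */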
1524 static int virtnet_receive(struct receive_queue *rq, int budget,
1525                            unsigned int *xdp_xmit)
1526 {
1527         struct virtnet_info *vi = rq->vq->vdev->priv;
1528         struct virtnet_rq_stats stats = {};
1529         unsigned int len;
1530         void *buf;
1531         int i;
1532
1533         if (!vi->big_packets || vi->mergeable_rx_bufs) {
1534                 void *ctx;
1535
1536                 while (stats.packets < budget &&
1537                        (buf = virtqueue_get_buf_ctx(rq->vq, &len, &ctx))) {
1538                         receive_buf(vi, rq, buf, len, ctx, xdp_xmit, &stats);
1539                         stats.packets++;
1540                 }
1541         } else {
1542                 while (stats.packets < budget &&
1543                        (buf = virtqueue_get_buf(rq->vq, &len)) != NULL) {
1544                         receive_buf(vi, rq, buf, len, NULL, xdp_xmit, &stats);
1545                         stats.packets++;
1546                 }
1547         }
1548
1549         if (rq->vq->num_free > min((unsigned int)budget, virtqueue_get_vring_size(rq->vq)) / 2) {
1550                 if (!try_fill_recv(vi, rq, GFP_ATOMIC)) {
1551                         spin_lock(&vi->refill_lock);
1552                         if (vi->refill_enabled)
1553                                 schedule_delayed_work(&vi->refill, 0);
1554                         spin_unlock(&vi->refill_lock);
1555                 }
1556         }
1557
1558         u64_stats_update_begin(&rq->stats.syncp);
1559         for (i = 0; i < VIRTNET_RQ_STATS_LEN; i++) {
1560                 size_t offset = virtnet_rq_stats_desc[i].offset;
1561                 u64 *item;
1562
1563                 item = (u64 *)((u8 *)&rq->stats + offset);
1564                 *item += *(u64 *)((u8 *)&stats + offset);
1565         }
1566         u64_stats_update_end(&rq->stats.syncp);
1567
1568         return stats.packets;
1569 }
1570
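     /*
      * Reclaim completed transmit buffers. Each completed entry is either
      * an skb or an XDP frame; byte and packet totals are accumulated
      * locally and flushed to the send queue stats in a single update.
      */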
1571 static void free_old_xmit_skbs(struct send_queue *sq, bool in_napi)
1572 {
1573         unsigned int len;
1574         unsigned int packets = 0;
1575         unsigned int bytes = 0;
1576         void *ptr;
1577
1578         while ((ptr = virtqueue_get_buf(sq->vq, &len)) != NULL) {
1579                 if (likely(!is_xdp_frame(ptr))) {
1580                         struct sk_buff *skb = ptr;
1581
1582                         pr_debug("Sent skb %p\n", skb);
1583
1584                         bytes += skb->len;
1585                         napi_consume_skb(skb, in_napi);
1586                 } else {
1587                         struct xdp_frame *frame = ptr_to_xdp(ptr);
1588
1589                         bytes += frame->len;
1590                         xdp_return_frame(frame);
1591                 }
1592                 packets++;
1593         }
1594
1595         /* Avoid overhead when no packets have been processed; this
1596          * happens when we are called speculatively from start_xmit.
1597          */
1598         if (!packets)
1599                 return;
1600
1601         u64_stats_update_begin(&sq->stats.syncp);
1602         sq->stats.bytes += bytes;
1603         sq->stats.packets += packets;
1604         u64_stats_update_end(&sq->stats.syncp);
1605 }
1606
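     /*
      * True only for the send queues reserved for XDP transmission, i.e.
      * the last xdp_queue_pairs queues among the currently active pairs.
      */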
1607 static bool is_xdp_raw_buffer_queue(struct virtnet_info *vi, int q)
1608 {
1609         if (q < (vi->curr_queue_pairs - vi->xdp_queue_pairs))
1610                 return false;
1611         else if (q < vi->curr_queue_pairs)
1612                 return true;
1613         else
1614                 return false;
1615 }
1616
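     /*
      * Opportunistically clean the paired send queue from the receive NAPI
      * handler. Only done when tx NAPI is in use, the queue isn't reserved
      * for XDP, and the tx queue lock can be taken without contention.
      */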
1617 static void virtnet_poll_cleantx(struct receive_queue *rq)
1618 {
1619         struct virtnet_info *vi = rq->vq->vdev->priv;
1620         unsigned int index = vq2rxq(rq->vq);
1621         struct send_queue *sq = &vi->sq[index];
1622         struct netdev_queue *txq = netdev_get_tx_queue(vi->dev, index);
1623
1624         if (!sq->napi.weight || is_xdp_raw_buffer_queue(vi, index))
1625                 return;
1626
1627         if (__netif_tx_trylock(txq)) {
1628                 do {
1629                         virtqueue_disable_cb(sq->vq);
1630                         free_old_xmit_skbs(sq, true);
1631                 } while (unlikely(!virtqueue_enable_cb_delayed(sq->vq)));
1632
1633                 if (sq->vq->num_free >= 2 + MAX_SKB_FRAGS)
1634                         netif_tx_wake_queue(txq);
1635
1636                 __netif_tx_unlock(txq);
1637         }
1638 }
1639
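     /*
      * Receive NAPI poll handler: clean the paired send queue, process up
      * to budget received packets, then flush any XDP transmissions or
      * redirects generated while processing them.
      */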
1640 static int virtnet_poll(struct napi_struct *napi, int budget)
1641 {
1642         struct receive_queue *rq =
1643                 container_of(napi, struct receive_queue, napi);
1644         struct virtnet_info *vi = rq->vq->vdev->priv;
1645         struct send_queue *sq;
1646         unsigned int received;
1647         unsigned int xdp_xmit = 0;
1648
1649         virtnet_poll_cleantx(rq);
1650
1651         received = virtnet_receive(rq, budget, &xdp_xmit);
1652
1653         /* Out of packets? */
1654         if (received < budget)
1655                 virtqueue_napi_complete(napi, rq->vq, received);
1656
1657         if (xdp_xmit & VIRTIO_XDP_REDIR)
1658                 xdp_do_flush();
1659
1660         if (xdp_xmit & VIRTIO_XDP_TX) {
1661                 sq = virtnet_xdp_get_sq(vi);
1662                 if (virtqueue_kick_prepare(sq->vq) && virtqueue_notify(sq->vq)) {
1663                         u64_stats_update_begin(&sq->stats.syncp);
1664                         sq->stats.kicks++;
1665                         u64_stats_update_end(&sq->stats.syncp);
1666                 }
1667                 virtnet_xdp_put_sq(vi, sq);
1668         }
1669
1670         return received;
1671 }
1672
1673 static int virtnet_open(struct net_device *dev)
1674 {
1675         struct virtnet_info *vi = netdev_priv(dev);
1676         int i, err;
1677
1678         enable_delayed_refill(vi);
1679
1680         for (i = 0; i < vi->max_queue_pairs; i++) {
1681                 if (i < vi->curr_queue_pairs)
1682                         /* Make sure we have some buffers: if oom use wq. */
1683                         if (!try_fill_recv(vi, &vi->rq[i], GFP_KERNEL))
1684                                 schedule_delayed_work(&vi->refill, 0);
1685
1686                 err = xdp_rxq_info_reg(&vi->rq[i].xdp_rxq, dev, i, vi->rq[i].napi.napi_id);
1687                 if (err < 0)
1688                         return err;
1689
1690                 err = xdp_rxq_info_reg_mem_model(&vi->rq[i].xdp_rxq,
1691                                                  MEM_TYPE_PAGE_SHARED, NULL);
1692                 if (err < 0) {
1693                         xdp_rxq_info_unreg(&vi->rq[i].xdp_rxq);
1694                         return err;
1695                 }
1696
1697                 virtnet_napi_enable(vi->rq[i].vq, &vi->rq[i].napi);
1698                 virtnet_napi_tx_enable(vi, vi->sq[i].vq, &vi->sq[i].napi);
1699         }
1700
1701         return 0;
1702 }
1703
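     /*
      * Transmit NAPI poll handler: free completed skbs under the tx queue
      * lock, wake the queue once enough descriptors are free, and re-arm
      * the virtqueue callback, rescheduling NAPI if completions raced in.
      */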
1704 static int virtnet_poll_tx(struct napi_struct *napi, int budget)
1705 {
1706         struct send_queue *sq = container_of(napi, struct send_queue, napi);
1707         struct virtnet_info *vi = sq->vq->vdev->priv;
1708         unsigned int index = vq2txq(sq->vq);
1709         struct netdev_queue *txq;
1710         int opaque;
1711         bool done;
1712
1713         if (unlikely(is_xdp_raw_buffer_queue(vi, index))) {
1714                 /* We don't need to enable cb for XDP */
1715                 napi_complete_done(napi, 0);
1716                 return 0;
1717         }
1718
1719         txq = netdev_get_tx_queue(vi->dev, index);
1720         __netif_tx_lock(txq, raw_smp_processor_id());
1721         virtqueue_disable_cb(sq->vq);
1722         free_old_xmit_skbs(sq, true);
1723
1724         if (sq->vq->num_free >= 2 + MAX_SKB_FRAGS)
1725                 netif_tx_wake_queue(txq);
1726
1727         opaque = virtqueue_enable_cb_prepare(sq->vq);
1728
1729         done = napi_complete_done(napi, 0);
1730
1731         if (!done)
1732                 virtqueue_disable_cb(sq->vq);
1733
1734         __netif_tx_unlock(txq);
1735
1736         if (done) {
1737                 if (unlikely(virtqueue_poll(sq->vq, opaque))) {
1738                         if (napi_schedule_prep(napi)) {
1739                                 __netif_tx_lock(txq, raw_smp_processor_id());
1740                                 virtqueue_disable_cb(sq->vq);
1741                                 __netif_tx_unlock(txq);
1742                                 __napi_schedule(napi);
1743                         }
1744                 }
1745         }
1746
1747         return 0;
1748 }
1749
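     /*
      * Build the scatterlist for one skb and add it to the send virtqueue.
      * If the device accepts the virtio header anywhere (any_header_sg) and
      * the skb has enough headroom, the header is pushed into the skb
      * itself; otherwise a separate scatterlist entry is used for it.
      */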
1750 static int xmit_skb(struct send_queue *sq, struct sk_buff *skb)
1751 {
1752         struct virtio_net_hdr_mrg_rxbuf *hdr;
1753         const unsigned char *dest = ((struct ethhdr *)skb->data)->h_dest;
1754         struct virtnet_info *vi = sq->vq->vdev->priv;
1755         int num_sg;
1756         unsigned int hdr_len = vi->hdr_len;
1757         bool can_push;
1758
1759         pr_debug("%s: xmit %p %pM\n", vi->dev->name, skb, dest);
1760
1761         can_push = vi->any_header_sg &&
1762                 !((unsigned long)skb->data & (__alignof__(*hdr) - 1)) &&
1763                 !skb_header_cloned(skb) && skb_headroom(skb) >= hdr_len;
1764         /* Even if we can, don't push here yet as this would skew
1765          * csum_start offset below. */
1766         if (can_push)
1767                 hdr = (struct virtio_net_hdr_mrg_rxbuf *)(skb->data - hdr_len);
1768         else
1769                 hdr = skb_vnet_hdr(skb);
1770
1771         if (virtio_net_hdr_from_skb(skb, &hdr->hdr,
1772                                     virtio_is_little_endian(vi->vdev), false,
1773                                     0))
1774                 return -EPROTO;
1775
1776         if (vi->mergeable_rx_bufs)
1777                 hdr->num_buffers = 0;
1778
1779         sg_init_table(sq->sg, skb_shinfo(skb)->nr_frags + (can_push ? 1 : 2));
1780         if (can_push) {
1781                 __skb_push(skb, hdr_len);
1782                 num_sg = skb_to_sgvec(skb, sq->sg, 0, skb->len);
1783                 if (unlikely(num_sg < 0))
1784                         return num_sg;
1785                 /* Pull header back to avoid skew in tx bytes calculations. */
1786                 __skb_pull(skb, hdr_len);
1787         } else {
1788                 sg_set_buf(sq->sg, hdr, hdr_len);
1789                 num_sg = skb_to_sgvec(skb, sq->sg + 1, 0, skb->len);
1790                 if (unlikely(num_sg < 0))
1791                         return num_sg;
1792                 num_sg++;
1793         }
1794         return virtqueue_add_outbuf(sq->vq, sq->sg, num_sg, skb, GFP_ATOMIC);
1795 }
1796
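     /*
      * .ndo_start_xmit: reclaim old completions, queue the skb on the send
      * queue selected by its queue mapping, and stop the queue early when
      * fewer than 2 + MAX_SKB_FRAGS descriptors remain.
      */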
1797 static netdev_tx_t start_xmit(struct sk_buff *skb, struct net_device *dev)
1798 {
1799         struct virtnet_info *vi = netdev_priv(dev);
1800         int qnum = skb_get_queue_mapping(skb);
1801         struct send_queue *sq = &vi->sq[qnum];
1802         int err;
1803         struct netdev_queue *txq = netdev_get_tx_queue(dev, qnum);
1804         bool kick = !netdev_xmit_more();
1805         bool use_napi = sq->napi.weight;
1806
1807         /* Free up any pending old buffers before queueing new ones. */
1808         do {
1809                 if (use_napi)
1810                         virtqueue_disable_cb(sq->vq);
1811
1812                 free_old_xmit_skbs(sq, false);
1813
1814         } while (use_napi && kick &&
1815                unlikely(!virtqueue_enable_cb_delayed(sq->vq)));
1816
1817         /* timestamp packet in software */
1818         skb_tx_timestamp(skb);
1819
1820         /* Try to transmit */
1821         err = xmit_skb(sq, skb);
1822
1823         /* This should not happen! */
1824         if (unlikely(err)) {
1825                 dev->stats.tx_fifo_errors++;
1826                 if (net_ratelimit())
1827                         dev_warn(&dev->dev,
1828                                  "Unexpected TXQ (%d) queue failure: %d\n",
1829                                  qnum, err);
1830                 dev->stats.tx_dropped++;
1831                 dev_kfree_skb_any(skb);
1832                 return NETDEV_TX_OK;
1833         }
1834
1835         /* Don't wait up for transmitted skbs to be freed. */
1836         if (!use_napi) {
1837                 skb_orphan(skb);
1838                 nf_reset_ct(skb);
1839         }
1840
1841         /* If running out of space, stop queue to avoid getting packets that we
1842          * are then unable to transmit.
1843          * An alternative would be to force queuing layer to requeue the skb by
1844          * returning NETDEV_TX_BUSY. However, NETDEV_TX_BUSY should not be
1845          * returned in a normal path of operation: it means that driver is not
1846          * maintaining the TX queue stop/start state properly, and causes
1847          * the stack to do a non-trivial amount of useless work.
1848          * Since most packets only take 1 or 2 ring slots, stopping the queue
1849          * early means 16 slots are typically wasted.
1850          */
1851         if (sq->vq->num_free < 2+MAX_SKB_FRAGS) {
1852                 netif_stop_subqueue(dev, qnum);
1853                 if (!use_napi &&
1854                     unlikely(!virtqueue_enable_cb_delayed(sq->vq))) {
1855                         /* More just got used, free them then recheck. */
1856                         free_old_xmit_skbs(sq, false);
1857                         if (sq->vq->num_free >= 2+MAX_SKB_FRAGS) {
1858                                 netif_start_subqueue(dev, qnum);
1859                                 virtqueue_disable_cb(sq->vq);
1860                         }
1861                 }
1862         }
1863
1864         if (kick || netif_xmit_stopped(txq)) {
1865                 if (virtqueue_kick_prepare(sq->vq) && virtqueue_notify(sq->vq)) {
1866                         u64_stats_update_begin(&sq->stats.syncp);
1867                         sq->stats.kicks++;
1868                         u64_stats_update_end(&sq->stats.syncp);
1869                 }
1870         }
1871
1872         return NETDEV_TX_OK;
1873 }
1874
1875 /*
1876  * Send command via the control virtqueue and check status.  Commands
1877  * supported by the hypervisor, as indicated by feature bits, should
1878  * never fail unless improperly formatted.
1879  */
1880 static bool virtnet_send_command(struct virtnet_info *vi, u8 class, u8 cmd,
1881                                  struct scatterlist *out)
1882 {
1883         struct scatterlist *sgs[4], hdr, stat;
1884         unsigned int out_num = 0, tmp;
1885         int ret;
1886
1887         /* Caller should know better */
1888         BUG_ON(!virtio_has_feature(vi->vdev, VIRTIO_NET_F_CTRL_VQ));
1889
1890         vi->ctrl->status = ~0;
1891         vi->ctrl->hdr.class = class;
1892         vi->ctrl->hdr.cmd = cmd;
1893         /* Add header */
1894         sg_init_one(&hdr, &vi->ctrl->hdr, sizeof(vi->ctrl->hdr));
1895         sgs[out_num++] = &hdr;
1896
1897         if (out)
1898                 sgs[out_num++] = out;
1899
1900         /* Add return status. */
1901         sg_init_one(&stat, &vi->ctrl->status, sizeof(vi->ctrl->status));
1902         sgs[out_num] = &stat;
1903
1904         BUG_ON(out_num + 1 > ARRAY_SIZE(sgs));
1905         ret = virtqueue_add_sgs(vi->cvq, sgs, out_num, 1, vi, GFP_ATOMIC);
1906         if (ret < 0) {
1907                 dev_warn(&vi->vdev->dev,
1908                          "Failed to add sgs for command vq: %d\n", ret);
1909                 return false;
1910         }
1911
1912         if (unlikely(!virtqueue_kick(vi->cvq)))
1913                 return vi->ctrl->status == VIRTIO_NET_OK;
1914
1915         /* Spin for a response; the kick causes an ioport write, trapping
1916          * into the hypervisor, so the request should be handled immediately.
1917          */
1918         while (!virtqueue_get_buf(vi->cvq, &tmp) &&
1919                !virtqueue_is_broken(vi->cvq))
1920                 cpu_relax();
1921
1922         return vi->ctrl->status == VIRTIO_NET_OK;
1923 }
1924
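     /*
      * .ndo_set_mac_address: prefer the CTRL_MAC_ADDR_SET control command;
      * for legacy devices without it, write the MAC byte by byte into
      * config space. Not supported when VIRTIO_NET_F_STANDBY is negotiated.
      */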
1925 static int virtnet_set_mac_address(struct net_device *dev, void *p)
1926 {
1927         struct virtnet_info *vi = netdev_priv(dev);
1928         struct virtio_device *vdev = vi->vdev;
1929         int ret;
1930         struct sockaddr *addr;
1931         struct scatterlist sg;
1932
1933         if (virtio_has_feature(vi->vdev, VIRTIO_NET_F_STANDBY))
1934                 return -EOPNOTSUPP;
1935
1936         addr = kmemdup(p, sizeof(*addr), GFP_KERNEL);
1937         if (!addr)
1938                 return -ENOMEM;
1939
1940         ret = eth_prepare_mac_addr_change(dev, addr);
1941         if (ret)
1942                 goto out;
1943
1944         if (virtio_has_feature(vdev, VIRTIO_NET_F_CTRL_MAC_ADDR)) {
1945                 sg_init_one(&sg, addr->sa_data, dev->addr_len);
1946                 if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_MAC,
1947                                           VIRTIO_NET_CTRL_MAC_ADDR_SET, &sg)) {
1948                         dev_warn(&vdev->dev,
1949                                  "Failed to set mac address by vq command.\n");
1950                         ret = -EINVAL;
1951                         goto out;
1952                 }
1953         } else if (virtio_has_feature(vdev, VIRTIO_NET_F_MAC) &&
1954                    !virtio_has_feature(vdev, VIRTIO_F_VERSION_1)) {
1955                 unsigned int i;
1956
1957                 /* Naturally, this has an atomicity problem. */
1958                 for (i = 0; i < dev->addr_len; i++)
1959                         virtio_cwrite8(vdev,
1960                                        offsetof(struct virtio_net_config, mac) +
1961                                        i, addr->sa_data[i]);
1962         }
1963
1964         eth_commit_mac_addr_change(dev, p);
1965         ret = 0;
1966
1967 out:
1968         kfree(addr);
1969         return ret;
1970 }
1971
1972 static void virtnet_stats(struct net_device *dev,
1973                           struct rtnl_link_stats64 *tot)
1974 {
1975         struct virtnet_info *vi = netdev_priv(dev);
1976         unsigned int start;
1977         int i;
1978
1979         for (i = 0; i < vi->max_queue_pairs; i++) {
1980                 u64 tpackets, tbytes, terrors, rpackets, rbytes, rdrops;
1981                 struct receive_queue *rq = &vi->rq[i];
1982                 struct send_queue *sq = &vi->sq[i];
1983
1984                 do {
1985                         start = u64_stats_fetch_begin_irq(&sq->stats.syncp);
1986                         tpackets = sq->stats.packets;
1987                         tbytes   = sq->stats.bytes;
1988                         terrors  = sq->stats.tx_timeouts;
1989                 } while (u64_stats_fetch_retry_irq(&sq->stats.syncp, start));
1990
1991                 do {
1992                         start = u64_stats_fetch_begin_irq(&rq->stats.syncp);
1993                         rpackets = rq->stats.packets;
1994                         rbytes   = rq->stats.bytes;
1995                         rdrops   = rq->stats.drops;
1996                 } while (u64_stats_fetch_retry_irq(&rq->stats.syncp, start));
1997
1998                 tot->rx_packets += rpackets;
1999                 tot->tx_packets += tpackets;
2000                 tot->rx_bytes   += rbytes;
2001                 tot->tx_bytes   += tbytes;
2002                 tot->rx_dropped += rdrops;
2003                 tot->tx_errors  += terrors;
2004         }
2005
2006         tot->tx_dropped = dev->stats.tx_dropped;
2007         tot->tx_fifo_errors = dev->stats.tx_fifo_errors;
2008         tot->rx_length_errors = dev->stats.rx_length_errors;
2009         tot->rx_frame_errors = dev->stats.rx_frame_errors;
2010 }
2011
2012 static void virtnet_ack_link_announce(struct virtnet_info *vi)
2013 {
2014         rtnl_lock();
2015         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_ANNOUNCE,
2016                                   VIRTIO_NET_CTRL_ANNOUNCE_ACK, NULL))
2017                 dev_warn(&vi->dev->dev, "Failed to ack link announce.\n");
2018         rtnl_unlock();
2019 }
2020
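     /*
      * Ask the device to use the given number of queue pairs via the MQ
      * control command. Without a control virtqueue or VIRTIO_NET_F_MQ this
      * silently succeeds without changing anything.
      */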
2021 static int _virtnet_set_queues(struct virtnet_info *vi, u16 queue_pairs)
2022 {
2023         struct scatterlist sg;
2024         struct net_device *dev = vi->dev;
2025
2026         if (!vi->has_cvq || !virtio_has_feature(vi->vdev, VIRTIO_NET_F_MQ))
2027                 return 0;
2028
2029         vi->ctrl->mq.virtqueue_pairs = cpu_to_virtio16(vi->vdev, queue_pairs);
2030         sg_init_one(&sg, &vi->ctrl->mq, sizeof(vi->ctrl->mq));
2031
2032         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_MQ,
2033                                   VIRTIO_NET_CTRL_MQ_VQ_PAIRS_SET, &sg)) {
2034                 dev_warn(&dev->dev, "Failed to set the number of queue pairs to %d\n",
2035                          queue_pairs);
2036                 return -EINVAL;
2037         } else {
2038                 vi->curr_queue_pairs = queue_pairs;
2039                 /* virtnet_open() will refill when the device is brought up. */
2040                 if (dev->flags & IFF_UP)
2041                         schedule_delayed_work(&vi->refill, 0);
2042         }
2043
2044         return 0;
2045 }
2046
2047 static int virtnet_set_queues(struct virtnet_info *vi, u16 queue_pairs)
2048 {
2049         int err;
2050
2051         rtnl_lock();
2052         err = _virtnet_set_queues(vi, queue_pairs);
2053         rtnl_unlock();
2054         return err;
2055 }
2056
2057 static int virtnet_close(struct net_device *dev)
2058 {
2059         struct virtnet_info *vi = netdev_priv(dev);
2060         int i;
2061
2062         /* Make sure NAPI doesn't schedule refill work */
2063         disable_delayed_refill(vi);
2064         /* Make sure refill_work doesn't re-enable napi! */
2065         cancel_delayed_work_sync(&vi->refill);
2066
2067         for (i = 0; i < vi->max_queue_pairs; i++) {
2068                 xdp_rxq_info_unreg(&vi->rq[i].xdp_rxq);
2069                 napi_disable(&vi->rq[i].napi);
2070                 virtnet_napi_tx_disable(&vi->sq[i].napi);
2071         }
2072
2073         return 0;
2074 }
2075
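     /*
      * .ndo_set_rx_mode: push the promiscuous/allmulti flags and the
      * unicast and multicast MAC filter tables to the device over the
      * control virtqueue. Requires VIRTIO_NET_F_CTRL_RX.
      */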
2076 static void virtnet_set_rx_mode(struct net_device *dev)
2077 {
2078         struct virtnet_info *vi = netdev_priv(dev);
2079         struct scatterlist sg[2];
2080         struct virtio_net_ctrl_mac *mac_data;
2081         struct netdev_hw_addr *ha;
2082         int uc_count;
2083         int mc_count;
2084         void *buf;
2085         int i;
2086
2087         /* We can't dynamically set ndo_set_rx_mode, so return gracefully */
2088         if (!virtio_has_feature(vi->vdev, VIRTIO_NET_F_CTRL_RX))
2089                 return;
2090
2091         vi->ctrl->promisc = ((dev->flags & IFF_PROMISC) != 0);
2092         vi->ctrl->allmulti = ((dev->flags & IFF_ALLMULTI) != 0);
2093
2094         sg_init_one(sg, &vi->ctrl->promisc, sizeof(vi->ctrl->promisc));
2095
2096         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_RX,
2097                                   VIRTIO_NET_CTRL_RX_PROMISC, sg))
2098                 dev_warn(&dev->dev, "Failed to %sable promisc mode.\n",
2099                          vi->ctrl->promisc ? "en" : "dis");
2100
2101         sg_init_one(sg, &vi->ctrl->allmulti, sizeof(vi->ctrl->allmulti));
2102
2103         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_RX,
2104                                   VIRTIO_NET_CTRL_RX_ALLMULTI, sg))
2105                 dev_warn(&dev->dev, "Failed to %sable allmulti mode.\n",
2106                          vi->ctrl->allmulti ? "en" : "dis");
2107
2108         uc_count = netdev_uc_count(dev);
2109         mc_count = netdev_mc_count(dev);
2110         /* MAC filter - use one buffer for both lists */
2111         buf = kzalloc(((uc_count + mc_count) * ETH_ALEN) +
2112                       (2 * sizeof(mac_data->entries)), GFP_ATOMIC);
2113         mac_data = buf;
2114         if (!buf)
2115                 return;
2116
2117         sg_init_table(sg, 2);
2118
2119         /* Store the unicast list and count in the front of the buffer */
2120         mac_data->entries = cpu_to_virtio32(vi->vdev, uc_count);
2121         i = 0;
2122         netdev_for_each_uc_addr(ha, dev)
2123                 memcpy(&mac_data->macs[i++][0], ha->addr, ETH_ALEN);
2124
2125         sg_set_buf(&sg[0], mac_data,
2126                    sizeof(mac_data->entries) + (uc_count * ETH_ALEN));
2127
2128         /* multicast list and count fill the end */
2129         mac_data = (void *)&mac_data->macs[uc_count][0];
2130
2131         mac_data->entries = cpu_to_virtio32(vi->vdev, mc_count);
2132         i = 0;
2133         netdev_for_each_mc_addr(ha, dev)
2134                 memcpy(&mac_data->macs[i++][0], ha->addr, ETH_ALEN);
2135
2136         sg_set_buf(&sg[1], mac_data,
2137                    sizeof(mac_data->entries) + (mc_count * ETH_ALEN));
2138
2139         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_MAC,
2140                                   VIRTIO_NET_CTRL_MAC_TABLE_SET, sg))
2141                 dev_warn(&dev->dev, "Failed to set MAC filter table.\n");
2142
2143         kfree(buf);
2144 }
2145
2146 static int virtnet_vlan_rx_add_vid(struct net_device *dev,
2147                                    __be16 proto, u16 vid)
2148 {
2149         struct virtnet_info *vi = netdev_priv(dev);
2150         struct scatterlist sg;
2151
2152         vi->ctrl->vid = cpu_to_virtio16(vi->vdev, vid);
2153         sg_init_one(&sg, &vi->ctrl->vid, sizeof(vi->ctrl->vid));
2154
2155         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_VLAN,
2156                                   VIRTIO_NET_CTRL_VLAN_ADD, &sg))
2157                 dev_warn(&dev->dev, "Failed to add VLAN ID %d.\n", vid);
2158         return 0;
2159 }
2160
2161 static int virtnet_vlan_rx_kill_vid(struct net_device *dev,
2162                                     __be16 proto, u16 vid)
2163 {
2164         struct virtnet_info *vi = netdev_priv(dev);
2165         struct scatterlist sg;
2166
2167         vi->ctrl->vid = cpu_to_virtio16(vi->vdev, vid);
2168         sg_init_one(&sg, &vi->ctrl->vid, sizeof(vi->ctrl->vid));
2169
2170         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_VLAN,
2171                                   VIRTIO_NET_CTRL_VLAN_DEL, &sg))
2172                 dev_warn(&dev->dev, "Failed to kill VLAN ID %d.\n", vid);
2173         return 0;
2174 }
2175
2176 static void virtnet_clean_affinity(struct virtnet_info *vi)
2177 {
2178         int i;
2179
2180         if (vi->affinity_hint_set) {
2181                 for (i = 0; i < vi->max_queue_pairs; i++) {
2182                         virtqueue_set_affinity(vi->rq[i].vq, NULL);
2183                         virtqueue_set_affinity(vi->sq[i].vq, NULL);
2184                 }
2185
2186                 vi->affinity_hint_set = false;
2187         }
2188 }
2189
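     /*
      * Spread the online CPUs across the active queue pairs, giving each
      * pair a contiguous group of CPUs as its virtqueue affinity hint and
      * as its XPS mask.
      */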
2190 static void virtnet_set_affinity(struct virtnet_info *vi)
2191 {
2192         cpumask_var_t mask;
2193         int stragglers;
2194         int group_size;
2195         int i, j, cpu;
2196         int num_cpu;
2197         int stride;
2198
2199         if (!zalloc_cpumask_var(&mask, GFP_KERNEL)) {
2200                 virtnet_clean_affinity(vi);
2201                 return;
2202         }
2203
2204         num_cpu = num_online_cpus();
2205         stride = max_t(int, num_cpu / vi->curr_queue_pairs, 1);
2206         stragglers = num_cpu >= vi->curr_queue_pairs ?
2207                         num_cpu % vi->curr_queue_pairs :
2208                         0;
2209         cpu = cpumask_first(cpu_online_mask);
2210
2211         for (i = 0; i < vi->curr_queue_pairs; i++) {
2212                 group_size = stride + (i < stragglers ? 1 : 0);
2213
2214                 for (j = 0; j < group_size; j++) {
2215                         cpumask_set_cpu(cpu, mask);
2216                         cpu = cpumask_next_wrap(cpu, cpu_online_mask,
2217                                                 nr_cpu_ids, false);
2218                 }
2219                 virtqueue_set_affinity(vi->rq[i].vq, mask);
2220                 virtqueue_set_affinity(vi->sq[i].vq, mask);
2221                 __netif_set_xps_queue(vi->dev, cpumask_bits(mask), i, XPS_CPUS);
2222                 cpumask_clear(mask);
2223         }
2224
2225         vi->affinity_hint_set = true;
2226         free_cpumask_var(mask);
2227 }
2228
2229 static int virtnet_cpu_online(unsigned int cpu, struct hlist_node *node)
2230 {
2231         struct virtnet_info *vi = hlist_entry_safe(node, struct virtnet_info,
2232                                                    node);
2233         virtnet_set_affinity(vi);
2234         return 0;
2235 }
2236
2237 static int virtnet_cpu_dead(unsigned int cpu, struct hlist_node *node)
2238 {
2239         struct virtnet_info *vi = hlist_entry_safe(node, struct virtnet_info,
2240                                                    node_dead);
2241         virtnet_set_affinity(vi);
2242         return 0;
2243 }
2244
2245 static int virtnet_cpu_down_prep(unsigned int cpu, struct hlist_node *node)
2246 {
2247         struct virtnet_info *vi = hlist_entry_safe(node, struct virtnet_info,
2248                                                    node);
2249
2250         virtnet_clean_affinity(vi);
2251         return 0;
2252 }
2253
2254 static enum cpuhp_state virtionet_online;
2255
2256 static int virtnet_cpu_notif_add(struct virtnet_info *vi)
2257 {
2258         int ret;
2259
2260         ret = cpuhp_state_add_instance_nocalls(virtionet_online, &vi->node);
2261         if (ret)
2262                 return ret;
2263         ret = cpuhp_state_add_instance_nocalls(CPUHP_VIRT_NET_DEAD,
2264                                                &vi->node_dead);
2265         if (!ret)
2266                 return ret;
2267         cpuhp_state_remove_instance_nocalls(virtionet_online, &vi->node);
2268         return ret;
2269 }
2270
2271 static void virtnet_cpu_notif_remove(struct virtnet_info *vi)
2272 {
2273         cpuhp_state_remove_instance_nocalls(virtionet_online, &vi->node);
2274         cpuhp_state_remove_instance_nocalls(CPUHP_VIRT_NET_DEAD,
2275                                             &vi->node_dead);
2276 }
2277
2278 static void virtnet_get_ringparam(struct net_device *dev,
2279                                   struct ethtool_ringparam *ring,
2280                                   struct kernel_ethtool_ringparam *kernel_ring,
2281                                   struct netlink_ext_ack *extack)
2282 {
2283         struct virtnet_info *vi = netdev_priv(dev);
2284
2285         ring->rx_max_pending = virtqueue_get_vring_size(vi->rq[0].vq);
2286         ring->tx_max_pending = virtqueue_get_vring_size(vi->sq[0].vq);
2287         ring->rx_pending = ring->rx_max_pending;
2288         ring->tx_pending = ring->tx_max_pending;
2289 }
2290
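     /*
      * Push the current RSS/hash configuration to the device. The control
      * structure is split into four scatterlist entries because the
      * indirection table and the key are variable-length.
      */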
2291 static bool virtnet_commit_rss_command(struct virtnet_info *vi)
2292 {
2293         struct net_device *dev = vi->dev;
2294         struct scatterlist sgs[4];
2295         unsigned int sg_buf_size;
2296
2297         /* prepare sgs */
2298         sg_init_table(sgs, 4);
2299
2300         sg_buf_size = offsetof(struct virtio_net_ctrl_rss, indirection_table);
2301         sg_set_buf(&sgs[0], &vi->ctrl->rss, sg_buf_size);
2302
2303         sg_buf_size = sizeof(uint16_t) * (vi->ctrl->rss.indirection_table_mask + 1);
2304         sg_set_buf(&sgs[1], vi->ctrl->rss.indirection_table, sg_buf_size);
2305
2306         sg_buf_size = offsetof(struct virtio_net_ctrl_rss, key)
2307                         - offsetof(struct virtio_net_ctrl_rss, max_tx_vq);
2308         sg_set_buf(&sgs[2], &vi->ctrl->rss.max_tx_vq, sg_buf_size);
2309
2310         sg_buf_size = vi->rss_key_size;
2311         sg_set_buf(&sgs[3], vi->ctrl->rss.key, sg_buf_size);
2312
2313         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_MQ,
2314                                   vi->has_rss ? VIRTIO_NET_CTRL_MQ_RSS_CONFIG
2315                                   : VIRTIO_NET_CTRL_MQ_HASH_CONFIG, sgs)) {
2316                 dev_warn(&dev->dev, "VIRTIONET issue with committing RSS sgs\n");
2317                 return false;
2318         }
2319         return true;
2320 }
2321
2322 static void virtnet_init_default_rss(struct virtnet_info *vi)
2323 {
2324         u32 indir_val = 0;
2325         int i = 0;
2326
2327         vi->ctrl->rss.hash_types = vi->rss_hash_types_supported;
2328         vi->rss_hash_types_saved = vi->rss_hash_types_supported;
2329         vi->ctrl->rss.indirection_table_mask = vi->rss_indir_table_size
2330                                                 ? vi->rss_indir_table_size - 1 : 0;
2331         vi->ctrl->rss.unclassified_queue = 0;
2332
2333         for (; i < vi->rss_indir_table_size; ++i) {
2334                 indir_val = ethtool_rxfh_indir_default(i, vi->curr_queue_pairs);
2335                 vi->ctrl->rss.indirection_table[i] = indir_val;
2336         }
2337
2338         vi->ctrl->rss.max_tx_vq = vi->curr_queue_pairs;
2339         vi->ctrl->rss.hash_key_length = vi->rss_key_size;
2340
2341         netdev_rss_key_fill(vi->ctrl->rss.key, vi->rss_key_size);
2342 }
2343
2344 static void virtnet_get_hashflow(const struct virtnet_info *vi, struct ethtool_rxnfc *info)
2345 {
2346         info->data = 0;
2347         switch (info->flow_type) {
2348         case TCP_V4_FLOW:
2349                 if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_TCPv4) {
2350                         info->data = RXH_IP_SRC | RXH_IP_DST |
2351                                                  RXH_L4_B_0_1 | RXH_L4_B_2_3;
2352                 } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv4) {
2353                         info->data = RXH_IP_SRC | RXH_IP_DST;
2354                 }
2355                 break;
2356         case TCP_V6_FLOW:
2357                 if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_TCPv6) {
2358                         info->data = RXH_IP_SRC | RXH_IP_DST |
2359                                                  RXH_L4_B_0_1 | RXH_L4_B_2_3;
2360                 } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv6) {
2361                         info->data = RXH_IP_SRC | RXH_IP_DST;
2362                 }
2363                 break;
2364         case UDP_V4_FLOW:
2365                 if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_UDPv4) {
2366                         info->data = RXH_IP_SRC | RXH_IP_DST |
2367                                                  RXH_L4_B_0_1 | RXH_L4_B_2_3;
2368                 } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv4) {
2369                         info->data = RXH_IP_SRC | RXH_IP_DST;
2370                 }
2371                 break;
2372         case UDP_V6_FLOW:
2373                 if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_UDPv6) {
2374                         info->data = RXH_IP_SRC | RXH_IP_DST |
2375                                                  RXH_L4_B_0_1 | RXH_L4_B_2_3;
2376                 } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv6) {
2377                         info->data = RXH_IP_SRC | RXH_IP_DST;
2378                 }
2379                 break;
2380         case IPV4_FLOW:
2381                 if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv4)
2382                         info->data = RXH_IP_SRC | RXH_IP_DST;
2383
2384                 break;
2385         case IPV6_FLOW:
2386                 if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv6)
2387                         info->data = RXH_IP_SRC | RXH_IP_DST;
2388
2389                 break;
2390         default:
2391                 info->data = 0;
2392                 break;
2393         }
2394 }
2395
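     /*
      * Translate an ETHTOOL_SRXFH request into virtio RSS hash type bits.
      * Only 'sd' (addresses), 'sdfn' (addresses + ports) and discard are
      * representable; unsupported combinations are rejected.
      */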
2396 static bool virtnet_set_hashflow(struct virtnet_info *vi, struct ethtool_rxnfc *info)
2397 {
2398         u32 new_hashtypes = vi->rss_hash_types_saved;
2399         bool is_disable = info->data & RXH_DISCARD;
2400         bool is_l4 = info->data == (RXH_IP_SRC | RXH_IP_DST | RXH_L4_B_0_1 | RXH_L4_B_2_3);
2401
2402         /* supports only 'sd', 'sdfn' and 'r' */
2403         if (!((info->data == (RXH_IP_SRC | RXH_IP_DST)) | is_l4 | is_disable))
2404                 return false;
2405
2406         switch (info->flow_type) {
2407         case TCP_V4_FLOW:
2408                 new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv4 | VIRTIO_NET_RSS_HASH_TYPE_TCPv4);
2409                 if (!is_disable)
2410                         new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv4
2411                                 | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_TCPv4 : 0);
2412                 break;
2413         case UDP_V4_FLOW:
2414                 new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv4 | VIRTIO_NET_RSS_HASH_TYPE_UDPv4);
2415                 if (!is_disable)
2416                         new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv4
2417                                 | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_UDPv4 : 0);
2418                 break;
2419         case IPV4_FLOW:
2420                 new_hashtypes &= ~VIRTIO_NET_RSS_HASH_TYPE_IPv4;
2421                 if (!is_disable)
2422                         new_hashtypes = VIRTIO_NET_RSS_HASH_TYPE_IPv4;
2423                 break;
2424         case TCP_V6_FLOW:
2425                 new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv6 | VIRTIO_NET_RSS_HASH_TYPE_TCPv6);
2426                 if (!is_disable)
2427                         new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv6
2428                                 | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_TCPv6 : 0);
2429                 break;
2430         case UDP_V6_FLOW:
2431                 new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv6 | VIRTIO_NET_RSS_HASH_TYPE_UDPv6);
2432                 if (!is_disable)
2433                         new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv6
2434                                 | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_UDPv6 : 0);
2435                 break;
2436         case IPV6_FLOW:
2437                 new_hashtypes &= ~VIRTIO_NET_RSS_HASH_TYPE_IPv6;
2438                 if (!is_disable)
2439                         new_hashtypes = VIRTIO_NET_RSS_HASH_TYPE_IPv6;
2440                 break;
2441         default:
2442                 /* unsupported flow */
2443                 return false;
2444         }
2445
2446         /* if unsupported hashtype was set */
2447         if (new_hashtypes != (new_hashtypes & vi->rss_hash_types_supported))
2448                 return false;
2449
2450         if (new_hashtypes != vi->rss_hash_types_saved) {
2451                 vi->rss_hash_types_saved = new_hashtypes;
2452                 vi->ctrl->rss.hash_types = vi->rss_hash_types_saved;
2453                 if (vi->dev->features & NETIF_F_RXHASH)
2454                         return virtnet_commit_rss_command(vi);
2455         }
2456
2457         return true;
2458 }
2459
2460 static void virtnet_get_drvinfo(struct net_device *dev,
2461                                 struct ethtool_drvinfo *info)
2462 {
2463         struct virtnet_info *vi = netdev_priv(dev);
2464         struct virtio_device *vdev = vi->vdev;
2465
2466         strlcpy(info->driver, KBUILD_MODNAME, sizeof(info->driver));
2467         strlcpy(info->version, VIRTNET_DRIVER_VERSION, sizeof(info->version));
2468         strlcpy(info->bus_info, virtio_bus_name(vdev), sizeof(info->bus_info));
2469
2470 }
2471
2472 /* TODO: Eliminate OOO packets during switching */
2473 static int virtnet_set_channels(struct net_device *dev,
2474                                 struct ethtool_channels *channels)
2475 {
2476         struct virtnet_info *vi = netdev_priv(dev);
2477         u16 queue_pairs = channels->combined_count;
2478         int err;
2479
2480         /* We don't support separate rx/tx channels.
2481          * We don't allow setting 'other' channels.
2482          */
2483         if (channels->rx_count || channels->tx_count || channels->other_count)
2484                 return -EINVAL;
2485
2486         if (queue_pairs > vi->max_queue_pairs || queue_pairs == 0)
2487                 return -EINVAL;
2488
2489         /* For now we don't support modifying channels while XDP is loaded.
2490          * Also, when XDP is loaded all RX queues have XDP programs, so we
2491          * only need to check a single RX queue.
2492          */
2493         if (vi->rq[0].xdp_prog)
2494                 return -EINVAL;
2495
2496         cpus_read_lock();
2497         err = _virtnet_set_queues(vi, queue_pairs);
2498         if (err) {
2499                 cpus_read_unlock();
2500                 goto err;
2501         }
2502         virtnet_set_affinity(vi);
2503         cpus_read_unlock();
2504
2505         netif_set_real_num_tx_queues(dev, queue_pairs);
2506         netif_set_real_num_rx_queues(dev, queue_pairs);
2507  err:
2508         return err;
2509 }
2510
2511 static void virtnet_get_strings(struct net_device *dev, u32 stringset, u8 *data)
2512 {
2513         struct virtnet_info *vi = netdev_priv(dev);
2514         unsigned int i, j;
2515         u8 *p = data;
2516
2517         switch (stringset) {
2518         case ETH_SS_STATS:
2519                 for (i = 0; i < vi->curr_queue_pairs; i++) {
2520                         for (j = 0; j < VIRTNET_RQ_STATS_LEN; j++)
2521                                 ethtool_sprintf(&p, "rx_queue_%u_%s", i,
2522                                                 virtnet_rq_stats_desc[j].desc);
2523                 }
2524
2525                 for (i = 0; i < vi->curr_queue_pairs; i++) {
2526                         for (j = 0; j < VIRTNET_SQ_STATS_LEN; j++)
2527                                 ethtool_sprintf(&p, "tx_queue_%u_%s", i,
2528                                                 virtnet_sq_stats_desc[j].desc);
2529                 }
2530                 break;
2531         }
2532 }
2533
2534 static int virtnet_get_sset_count(struct net_device *dev, int sset)
2535 {
2536         struct virtnet_info *vi = netdev_priv(dev);
2537
2538         switch (sset) {
2539         case ETH_SS_STATS:
2540                 return vi->curr_queue_pairs * (VIRTNET_RQ_STATS_LEN +
2541                                                VIRTNET_SQ_STATS_LEN);
2542         default:
2543                 return -EOPNOTSUPP;
2544         }
2545 }
2546
2547 static void virtnet_get_ethtool_stats(struct net_device *dev,
2548                                       struct ethtool_stats *stats, u64 *data)
2549 {
2550         struct virtnet_info *vi = netdev_priv(dev);
2551         unsigned int idx = 0, start, i, j;
2552         const u8 *stats_base;
2553         size_t offset;
2554
2555         for (i = 0; i < vi->curr_queue_pairs; i++) {
2556                 struct receive_queue *rq = &vi->rq[i];
2557
2558                 stats_base = (u8 *)&rq->stats;
2559                 do {
2560                         start = u64_stats_fetch_begin_irq(&rq->stats.syncp);
2561                         for (j = 0; j < VIRTNET_RQ_STATS_LEN; j++) {
2562                                 offset = virtnet_rq_stats_desc[j].offset;
2563                                 data[idx + j] = *(u64 *)(stats_base + offset);
2564                         }
2565                 } while (u64_stats_fetch_retry_irq(&rq->stats.syncp, start));
2566                 idx += VIRTNET_RQ_STATS_LEN;
2567         }
2568
2569         for (i = 0; i < vi->curr_queue_pairs; i++) {
2570                 struct send_queue *sq = &vi->sq[i];
2571
2572                 stats_base = (u8 *)&sq->stats;
2573                 do {
2574                         start = u64_stats_fetch_begin_irq(&sq->stats.syncp);
2575                         for (j = 0; j < VIRTNET_SQ_STATS_LEN; j++) {
2576                                 offset = virtnet_sq_stats_desc[j].offset;
2577                                 data[idx + j] = *(u64 *)(stats_base + offset);
2578                         }
2579                 } while (u64_stats_fetch_retry_irq(&sq->stats.syncp, start));
2580                 idx += VIRTNET_SQ_STATS_LEN;
2581         }
2582 }
2583
2584 static void virtnet_get_channels(struct net_device *dev,
2585                                  struct ethtool_channels *channels)
2586 {
2587         struct virtnet_info *vi = netdev_priv(dev);
2588
2589         channels->combined_count = vi->curr_queue_pairs;
2590         channels->max_combined = vi->max_queue_pairs;
2591         channels->max_other = 0;
2592         channels->rx_count = 0;
2593         channels->tx_count = 0;
2594         channels->other_count = 0;
2595 }
2596
2597 static int virtnet_set_link_ksettings(struct net_device *dev,
2598                                       const struct ethtool_link_ksettings *cmd)
2599 {
2600         struct virtnet_info *vi = netdev_priv(dev);
2601
2602         return ethtool_virtdev_set_link_ksettings(dev, cmd,
2603                                                   &vi->speed, &vi->duplex);
2604 }
2605
2606 static int virtnet_get_link_ksettings(struct net_device *dev,
2607                                       struct ethtool_link_ksettings *cmd)
2608 {
2609         struct virtnet_info *vi = netdev_priv(dev);
2610
2611         cmd->base.speed = vi->speed;
2612         cmd->base.duplex = vi->duplex;
2613         cmd->base.port = PORT_OTHER;
2614
2615         return 0;
2616 }
2617
2618 static int virtnet_set_coalesce(struct net_device *dev,
2619                                 struct ethtool_coalesce *ec,
2620                                 struct kernel_ethtool_coalesce *kernel_coal,
2621                                 struct netlink_ext_ack *extack)
2622 {
2623         struct virtnet_info *vi = netdev_priv(dev);
2624         int i, napi_weight;
2625
2626         if (ec->tx_max_coalesced_frames > 1 ||
2627             ec->rx_max_coalesced_frames != 1)
2628                 return -EINVAL;
2629
2630         napi_weight = ec->tx_max_coalesced_frames ? NAPI_POLL_WEIGHT : 0;
2631         if (napi_weight ^ vi->sq[0].napi.weight) {
2632                 if (dev->flags & IFF_UP)
2633                         return -EBUSY;
2634                 for (i = 0; i < vi->max_queue_pairs; i++)
2635                         vi->sq[i].napi.weight = napi_weight;
2636         }
2637
2638         return 0;
2639 }
2640
2641 static int virtnet_get_coalesce(struct net_device *dev,
2642                                 struct ethtool_coalesce *ec,
2643                                 struct kernel_ethtool_coalesce *kernel_coal,
2644                                 struct netlink_ext_ack *extack)
2645 {
2646         struct ethtool_coalesce ec_default = {
2647                 .cmd = ETHTOOL_GCOALESCE,
2648                 .rx_max_coalesced_frames = 1,
2649         };
2650         struct virtnet_info *vi = netdev_priv(dev);
2651
2652         memcpy(ec, &ec_default, sizeof(ec_default));
2653
2654         if (vi->sq[0].napi.weight)
2655                 ec->tx_max_coalesced_frames = 1;
2656
2657         return 0;
2658 }
2659
2660 static void virtnet_init_settings(struct net_device *dev)
2661 {
2662         struct virtnet_info *vi = netdev_priv(dev);
2663
2664         vi->speed = SPEED_UNKNOWN;
2665         vi->duplex = DUPLEX_UNKNOWN;
2666 }
2667
2668 static void virtnet_update_settings(struct virtnet_info *vi)
2669 {
2670         u32 speed;
2671         u8 duplex;
2672
2673         if (!virtio_has_feature(vi->vdev, VIRTIO_NET_F_SPEED_DUPLEX))
2674                 return;
2675
2676         virtio_cread_le(vi->vdev, struct virtio_net_config, speed, &speed);
2677
2678         if (ethtool_validate_speed(speed))
2679                 vi->speed = speed;
2680
2681         virtio_cread_le(vi->vdev, struct virtio_net_config, duplex, &duplex);
2682
2683         if (ethtool_validate_duplex(duplex))
2684                 vi->duplex = duplex;
2685 }
2686
2687 static u32 virtnet_get_rxfh_key_size(struct net_device *dev)
2688 {
2689         return ((struct virtnet_info *)netdev_priv(dev))->rss_key_size;
2690 }
2691
2692 static u32 virtnet_get_rxfh_indir_size(struct net_device *dev)
2693 {
2694         return ((struct virtnet_info *)netdev_priv(dev))->rss_indir_table_size;
2695 }
2696
2697 static int virtnet_get_rxfh(struct net_device *dev, u32 *indir, u8 *key, u8 *hfunc)
2698 {
2699         struct virtnet_info *vi = netdev_priv(dev);
2700         int i;
2701
2702         if (indir) {
2703                 for (i = 0; i < vi->rss_indir_table_size; ++i)
2704                         indir[i] = vi->ctrl->rss.indirection_table[i];
2705         }
2706
2707         if (key)
2708                 memcpy(key, vi->ctrl->rss.key, vi->rss_key_size);
2709
2710         if (hfunc)
2711                 *hfunc = ETH_RSS_HASH_TOP;
2712
2713         return 0;
2714 }
2715
2716 static int virtnet_set_rxfh(struct net_device *dev, const u32 *indir, const u8 *key, const u8 hfunc)
2717 {
2718         struct virtnet_info *vi = netdev_priv(dev);
2719         int i;
2720
2721         if (hfunc != ETH_RSS_HASH_NO_CHANGE && hfunc != ETH_RSS_HASH_TOP)
2722                 return -EOPNOTSUPP;
2723
2724         if (indir) {
2725                 for (i = 0; i < vi->rss_indir_table_size; ++i)
2726                         vi->ctrl->rss.indirection_table[i] = indir[i];
2727         }
2728         if (key)
2729                 memcpy(vi->ctrl->rss.key, key, vi->rss_key_size);
2730
2731         virtnet_commit_rss_command(vi);
2732
2733         return 0;
2734 }
2735
2736 static int virtnet_get_rxnfc(struct net_device *dev, struct ethtool_rxnfc *info, u32 *rule_locs)
2737 {
2738         struct virtnet_info *vi = netdev_priv(dev);
2739         int rc = 0;
2740
2741         switch (info->cmd) {
2742         case ETHTOOL_GRXRINGS:
2743                 info->data = vi->curr_queue_pairs;
2744                 break;
2745         case ETHTOOL_GRXFH:
2746                 virtnet_get_hashflow(vi, info);
2747                 break;
2748         default:
2749                 rc = -EOPNOTSUPP;
2750         }
2751
2752         return rc;
2753 }
2754
2755 static int virtnet_set_rxnfc(struct net_device *dev, struct ethtool_rxnfc *info)
2756 {
2757         struct virtnet_info *vi = netdev_priv(dev);
2758         int rc = 0;
2759
2760         switch (info->cmd) {
2761         case ETHTOOL_SRXFH:
2762                 if (!virtnet_set_hashflow(vi, info))
2763                         rc = -EINVAL;
2764
2765                 break;
2766         default:
2767                 rc = -EOPNOTSUPP;
2768         }
2769
2770         return rc;
2771 }
2772
2773 static const struct ethtool_ops virtnet_ethtool_ops = {
2774         .supported_coalesce_params = ETHTOOL_COALESCE_MAX_FRAMES,
2775         .get_drvinfo = virtnet_get_drvinfo,
2776         .get_link = ethtool_op_get_link,
2777         .get_ringparam = virtnet_get_ringparam,
2778         .get_strings = virtnet_get_strings,
2779         .get_sset_count = virtnet_get_sset_count,
2780         .get_ethtool_stats = virtnet_get_ethtool_stats,
2781         .set_channels = virtnet_set_channels,
2782         .get_channels = virtnet_get_channels,
2783         .get_ts_info = ethtool_op_get_ts_info,
2784         .get_link_ksettings = virtnet_get_link_ksettings,
2785         .set_link_ksettings = virtnet_set_link_ksettings,
2786         .set_coalesce = virtnet_set_coalesce,
2787         .get_coalesce = virtnet_get_coalesce,
2788         .get_rxfh_key_size = virtnet_get_rxfh_key_size,
2789         .get_rxfh_indir_size = virtnet_get_rxfh_indir_size,
2790         .get_rxfh = virtnet_get_rxfh,
2791         .set_rxfh = virtnet_set_rxfh,
2792         .get_rxnfc = virtnet_get_rxnfc,
2793         .set_rxnfc = virtnet_set_rxnfc,
2794 };
2795
2796 static void virtnet_freeze_down(struct virtio_device *vdev)
2797 {
2798         struct virtnet_info *vi = vdev->priv;
2799
2800         /* Make sure no work handler is accessing the device */
2801         flush_work(&vi->config_work);
2802
2803         netif_tx_lock_bh(vi->dev);
2804         netif_device_detach(vi->dev);
2805         netif_tx_unlock_bh(vi->dev);
2806         if (netif_running(vi->dev))
2807                 virtnet_close(vi->dev);
2808 }
2809
2810 static int init_vqs(struct virtnet_info *vi);
2811
2812 static int virtnet_restore_up(struct virtio_device *vdev)
2813 {
2814         struct virtnet_info *vi = vdev->priv;
2815         int err;
2816
2817         err = init_vqs(vi);
2818         if (err)
2819                 return err;
2820
2821         virtio_device_ready(vdev);
2822
2823         enable_delayed_refill(vi);
2824
2825         if (netif_running(vi->dev)) {
2826                 err = virtnet_open(vi->dev);
2827                 if (err)
2828                         return err;
2829         }
2830
2831         netif_tx_lock_bh(vi->dev);
2832         netif_device_attach(vi->dev);
2833         netif_tx_unlock_bh(vi->dev);
2834         return err;
2835 }
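/* Editorial note: virtnet_freeze_down() and virtnet_restore_up() bracket the
 * suspend/resume path used by virtnet_freeze()/virtnet_restore() further
 * below; the virtqueues themselves are torn down in remove_vq_common() and
 * rebuilt here through init_vqs().
 */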
2836
2837 static int virtnet_set_guest_offloads(struct virtnet_info *vi, u64 offloads)
2838 {
2839         struct scatterlist sg;
2840
             vi->ctrl->offloads = cpu_to_virtio64(vi->vdev, offloads);
2841
2842         sg_init_one(&sg, &vi->ctrl->offloads, sizeof(vi->ctrl->offloads));
2843
2844         if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_GUEST_OFFLOADS,
2845                                   VIRTIO_NET_CTRL_GUEST_OFFLOADS_SET, &sg)) {
2846                 dev_warn(&vi->dev->dev, "Failed to set guest offloads.\n");
2847                 return -EINVAL;
2848         }
2849
2850         return 0;
2851 }
2852
2853 static int virtnet_clear_guest_offloads(struct virtnet_info *vi)
2854 {
2855         u64 offloads = 0;
2856
2857         if (!vi->guest_offloads)
2858                 return 0;
2859
2860         return virtnet_set_guest_offloads(vi, offloads);
2861 }
2862
2863 static int virtnet_restore_guest_offloads(struct virtnet_info *vi)
2864 {
2865         u64 offloads = vi->guest_offloads;
2866
2867         if (!vi->guest_offloads)
2868                 return 0;
2869
2870         return virtnet_set_guest_offloads(vi, offloads);
2871 }
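/* Editorial sketch: 'offloads' in the three helpers above is a bitmap keyed
 * by the VIRTIO_NET_F_GUEST_* feature bit numbers, e.g.
 * (1ULL << VIRTIO_NET_F_GUEST_TSO4) | (1ULL << VIRTIO_NET_F_GUEST_CSUM),
 * sent to the device as a VIRTIO_NET_CTRL_GUEST_OFFLOADS_SET command over
 * the control virtqueue.  Clearing it while an XDP program is attached
 * disables the receive offloads the XDP data path cannot handle
 * (coalesced/oversized frames in particular).
 */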
2872
2873 static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog,
2874                            struct netlink_ext_ack *extack)
2875 {
2876         unsigned long max_sz = PAGE_SIZE - sizeof(struct padded_vnet_hdr);
2877         struct virtnet_info *vi = netdev_priv(dev);
2878         struct bpf_prog *old_prog;
2879         u16 xdp_qp = 0, curr_qp;
2880         int i, err;
2881
2882         if (!virtio_has_feature(vi->vdev, VIRTIO_NET_F_CTRL_GUEST_OFFLOADS)
2883             && (virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_TSO4) ||
2884                 virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_TSO6) ||
2885                 virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_ECN) ||
2886                 virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_UFO) ||
2887                 virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_CSUM))) {
2888                 NL_SET_ERR_MSG_MOD(extack, "Can't set XDP while host is implementing GRO_HW/CSUM, disable GRO_HW/CSUM first");
2889                 return -EOPNOTSUPP;
2890         }
2891
2892         if (vi->mergeable_rx_bufs && !vi->any_header_sg) {
2893                 NL_SET_ERR_MSG_MOD(extack, "XDP expects header/data in single page, any_header_sg required");
2894                 return -EINVAL;
2895         }
2896
2897         if (dev->mtu > max_sz) {
2898                 NL_SET_ERR_MSG_MOD(extack, "MTU too large to enable XDP");
2899                 netdev_warn(dev, "XDP requires MTU less than %lu\n", max_sz);
2900                 return -EINVAL;
2901         }
2902
2903         curr_qp = vi->curr_queue_pairs - vi->xdp_queue_pairs;
2904         if (prog)
2905                 xdp_qp = nr_cpu_ids;
2906
2907         /* XDP requires extra queues for XDP_TX */
2908         if (curr_qp + xdp_qp > vi->max_queue_pairs) {
2909                 netdev_warn_once(dev, "XDP requests %i queues but max is %i. XDP_TX and XDP_REDIRECT will operate in a slower locked tx mode.\n",
2910                                  curr_qp + xdp_qp, vi->max_queue_pairs);
2911                 xdp_qp = 0;
2912         }
2913
2914         old_prog = rtnl_dereference(vi->rq[0].xdp_prog);
2915         if (!prog && !old_prog)
2916                 return 0;
2917
2918         if (prog)
2919                 bpf_prog_add(prog, vi->max_queue_pairs - 1);
2920
2921         /* Make sure NAPI is not using any XDP TX queues for RX. */
2922         if (netif_running(dev)) {
2923                 for (i = 0; i < vi->max_queue_pairs; i++) {
2924                         napi_disable(&vi->rq[i].napi);
2925                         virtnet_napi_tx_disable(&vi->sq[i].napi);
2926                 }
2927         }
2928
2929         if (!prog) {
2930                 for (i = 0; i < vi->max_queue_pairs; i++) {
2931                         rcu_assign_pointer(vi->rq[i].xdp_prog, prog);
2932                         if (i == 0)
2933                                 virtnet_restore_guest_offloads(vi);
2934                 }
2935                 synchronize_net();
2936         }
2937
2938         err = _virtnet_set_queues(vi, curr_qp + xdp_qp);
2939         if (err)
2940                 goto err;
2941         netif_set_real_num_rx_queues(dev, curr_qp + xdp_qp);
2942         vi->xdp_queue_pairs = xdp_qp;
2943
2944         if (prog) {
2945                 vi->xdp_enabled = true;
2946                 for (i = 0; i < vi->max_queue_pairs; i++) {
2947                         rcu_assign_pointer(vi->rq[i].xdp_prog, prog);
2948                         if (i == 0 && !old_prog)
2949                                 virtnet_clear_guest_offloads(vi);
2950                 }
2951         } else {
2952                 vi->xdp_enabled = false;
2953         }
2954
2955         for (i = 0; i < vi->max_queue_pairs; i++) {
2956                 if (old_prog)
2957                         bpf_prog_put(old_prog);
2958                 if (netif_running(dev)) {
2959                         virtnet_napi_enable(vi->rq[i].vq, &vi->rq[i].napi);
2960                         virtnet_napi_tx_enable(vi, vi->sq[i].vq,
2961                                                &vi->sq[i].napi);
2962                 }
2963         }
2964
2965         return 0;
2966
2967 err:
2968         if (!prog) {
2969                 virtnet_clear_guest_offloads(vi);
2970                 for (i = 0; i < vi->max_queue_pairs; i++)
2971                         rcu_assign_pointer(vi->rq[i].xdp_prog, old_prog);
2972         }
2973
2974         if (netif_running(dev)) {
2975                 for (i = 0; i < vi->max_queue_pairs; i++) {
2976                         virtnet_napi_enable(vi->rq[i].vq, &vi->rq[i].napi);
2977                         virtnet_napi_tx_enable(vi, vi->sq[i].vq,
2978                                                &vi->sq[i].napi);
2979                 }
2980         }
2981         if (prog)
2982                 bpf_prog_sub(prog, vi->max_queue_pairs - 1);
2983         return err;
2984 }
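/* Illustrative only: virtnet_xdp_set() is reached via the XDP_SETUP_PROG
 * command in virtnet_xdp() below, which userspace typically triggers with
 * something like
 *
 *   ip link set dev eth0 xdp obj prog.o sec xdp
 *
 * (eth0, prog.o and the section name are placeholders, not part of this
 * driver).
 */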
2985
2986 static int virtnet_xdp(struct net_device *dev, struct netdev_bpf *xdp)
2987 {
2988         switch (xdp->command) {
2989         case XDP_SETUP_PROG:
2990                 return virtnet_xdp_set(dev, xdp->prog, xdp->extack);
2991         default:
2992                 return -EINVAL;
2993         }
2994 }
2995
2996 static int virtnet_get_phys_port_name(struct net_device *dev, char *buf,
2997                                       size_t len)
2998 {
2999         struct virtnet_info *vi = netdev_priv(dev);
3000         int ret;
3001
3002         if (!virtio_has_feature(vi->vdev, VIRTIO_NET_F_STANDBY))
3003                 return -EOPNOTSUPP;
3004
3005         ret = snprintf(buf, len, "sby");
3006         if (ret >= len)
3007                 return -EOPNOTSUPP;
3008
3009         return 0;
3010 }
3011
3012 static int virtnet_set_features(struct net_device *dev,
3013                                 netdev_features_t features)
3014 {
3015         struct virtnet_info *vi = netdev_priv(dev);
3016         u64 offloads;
3017         int err;
3018
3019         if ((dev->features ^ features) & NETIF_F_GRO_HW) {
3020                 if (vi->xdp_enabled)
3021                         return -EBUSY;
3022
3023                 if (features & NETIF_F_GRO_HW)
3024                         offloads = vi->guest_offloads_capable;
3025                 else
3026                         offloads = vi->guest_offloads_capable &
3027                                    ~GUEST_OFFLOAD_GRO_HW_MASK;
3028
3029                 err = virtnet_set_guest_offloads(vi, offloads);
3030                 if (err)
3031                         return err;
3032                 vi->guest_offloads = offloads;
3033         }
3034
3035         if ((dev->features ^ features) & NETIF_F_RXHASH) {
3036                 if (features & NETIF_F_RXHASH)
3037                         vi->ctrl->rss.hash_types = vi->rss_hash_types_saved;
3038                 else
3039                         vi->ctrl->rss.hash_types = VIRTIO_NET_HASH_REPORT_NONE;
3040
3041                 if (!virtnet_commit_rss_command(vi))
3042                         return -EINVAL;
3043         }
3044
3045         return 0;
3046 }
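/* Editorial note: the two feature groups handled above are toggled from
 * userspace with ethtool, e.g. "ethtool -K eth0 rx-gro-hw off" for
 * NETIF_F_GRO_HW and "ethtool -K eth0 rxhash on" for NETIF_F_RXHASH
 * (eth0 is a placeholder interface name).
 */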
3047
3048 static void virtnet_tx_timeout(struct net_device *dev, unsigned int txqueue)
3049 {
3050         struct virtnet_info *priv = netdev_priv(dev);
3051         struct send_queue *sq = &priv->sq[txqueue];
3052         struct netdev_queue *txq = netdev_get_tx_queue(dev, txqueue);
3053
3054         u64_stats_update_begin(&sq->stats.syncp);
3055         sq->stats.tx_timeouts++;
3056         u64_stats_update_end(&sq->stats.syncp);
3057
3058         netdev_err(dev, "TX timeout on queue: %u, sq: %s, vq: 0x%x, name: %s, %u usecs ago\n",
3059                    txqueue, sq->name, sq->vq->index, sq->vq->name,
3060                    jiffies_to_usecs(jiffies - READ_ONCE(txq->trans_start)));
3061 }
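/* Editorial note: virtnet_tx_timeout() is invoked by the core netdev watchdog
 * when a stopped TX queue has made no progress within dev->watchdog_timeo;
 * it only records and logs the event, recovery is left to the device side.
 */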
3062
3063 static const struct net_device_ops virtnet_netdev = {
3064         .ndo_open            = virtnet_open,
3065         .ndo_stop            = virtnet_close,
3066         .ndo_start_xmit      = start_xmit,
3067         .ndo_validate_addr   = eth_validate_addr,
3068         .ndo_set_mac_address = virtnet_set_mac_address,
3069         .ndo_set_rx_mode     = virtnet_set_rx_mode,
3070         .ndo_get_stats64     = virtnet_stats,
3071         .ndo_vlan_rx_add_vid = virtnet_vlan_rx_add_vid,
3072         .ndo_vlan_rx_kill_vid = virtnet_vlan_rx_kill_vid,
3073         .ndo_bpf                = virtnet_xdp,
3074         .ndo_xdp_xmit           = virtnet_xdp_xmit,
3075         .ndo_features_check     = passthru_features_check,
3076         .ndo_get_phys_port_name = virtnet_get_phys_port_name,
3077         .ndo_set_features       = virtnet_set_features,
3078         .ndo_tx_timeout         = virtnet_tx_timeout,
3079 };
3080
3081 static void virtnet_config_changed_work(struct work_struct *work)
3082 {
3083         struct virtnet_info *vi =
3084                 container_of(work, struct virtnet_info, config_work);
3085         u16 v;
3086
3087         if (virtio_cread_feature(vi->vdev, VIRTIO_NET_F_STATUS,
3088                                  struct virtio_net_config, status, &v) < 0)
3089                 return;
3090
3091         if (v & VIRTIO_NET_S_ANNOUNCE) {
3092                 netdev_notify_peers(vi->dev);
3093                 virtnet_ack_link_announce(vi);
3094         }
3095
3096         /* Ignore unknown (future) status bits */
3097         v &= VIRTIO_NET_S_LINK_UP;
3098
3099         if (vi->status == v)
3100                 return;
3101
3102         vi->status = v;
3103
3104         if (vi->status & VIRTIO_NET_S_LINK_UP) {
3105                 virtnet_update_settings(vi);
3106                 netif_carrier_on(vi->dev);
3107                 netif_tx_wake_all_queues(vi->dev);
3108         } else {
3109                 netif_carrier_off(vi->dev);
3110                 netif_tx_stop_all_queues(vi->dev);
3111         }
3112 }
3113
3114 static void virtnet_config_changed(struct virtio_device *vdev)
3115 {
3116         struct virtnet_info *vi = vdev->priv;
3117
3118         schedule_work(&vi->config_work);
3119 }
3120
3121 static void virtnet_free_queues(struct virtnet_info *vi)
3122 {
3123         int i;
3124
3125         for (i = 0; i < vi->max_queue_pairs; i++) {
3126                 __netif_napi_del(&vi->rq[i].napi);
3127                 __netif_napi_del(&vi->sq[i].napi);
3128         }
3129
3130         /* We called __netif_napi_del(),
3131          * we need to respect an RCU grace period before freeing vi->rq
3132          */
3133         synchronize_net();
3134
3135         kfree(vi->rq);
3136         kfree(vi->sq);
3137         kfree(vi->ctrl);
3138 }
3139
3140 static void _free_receive_bufs(struct virtnet_info *vi)
3141 {
3142         struct bpf_prog *old_prog;
3143         int i;
3144
3145         for (i = 0; i < vi->max_queue_pairs; i++) {
3146                 while (vi->rq[i].pages)
3147                         __free_pages(get_a_page(&vi->rq[i], GFP_KERNEL), 0);
3148
3149                 old_prog = rtnl_dereference(vi->rq[i].xdp_prog);
3150                 RCU_INIT_POINTER(vi->rq[i].xdp_prog, NULL);
3151                 if (old_prog)
3152                         bpf_prog_put(old_prog);
3153         }
3154 }
3155
3156 static void free_receive_bufs(struct virtnet_info *vi)
3157 {
3158         rtnl_lock();
3159         _free_receive_bufs(vi);
3160         rtnl_unlock();
3161 }
3162
3163 static void free_receive_page_frags(struct virtnet_info *vi)
3164 {
3165         int i;
3166         for (i = 0; i < vi->max_queue_pairs; i++)
3167                 if (vi->rq[i].alloc_frag.page)
3168                         put_page(vi->rq[i].alloc_frag.page);
3169 }
3170
3171 static void free_unused_bufs(struct virtnet_info *vi)
3172 {
3173         void *buf;
3174         int i;
3175
3176         for (i = 0; i < vi->max_queue_pairs; i++) {
3177                 struct virtqueue *vq = vi->sq[i].vq;
3178                 while ((buf = virtqueue_detach_unused_buf(vq)) != NULL) {
3179                         if (!is_xdp_frame(buf))
3180                                 dev_kfree_skb(buf);
3181                         else
3182                                 xdp_return_frame(ptr_to_xdp(buf));
3183                 }
3184         }
3185
3186         for (i = 0; i < vi->max_queue_pairs; i++) {
3187                 struct virtqueue *vq = vi->rq[i].vq;
3188
3189                 while ((buf = virtqueue_detach_unused_buf(vq)) != NULL) {
3190                         if (vi->mergeable_rx_bufs) {
3191                                 put_page(virt_to_head_page(buf));
3192                         } else if (vi->big_packets) {
3193                                 give_pages(&vi->rq[i], buf);
3194                         } else {
3195                                 put_page(virt_to_head_page(buf));
3196                         }
3197                 }
3198         }
3199 }
3200
3201 static void virtnet_del_vqs(struct virtnet_info *vi)
3202 {
3203         struct virtio_device *vdev = vi->vdev;
3204
3205         virtnet_clean_affinity(vi);
3206
3207         vdev->config->del_vqs(vdev);
3208
3209         virtnet_free_queues(vi);
3210 }
3211
3212 /* How large should a single buffer be so a queue full of these can fit at
3213  * least one full packet?
3214  * Logic below assumes the mergeable buffer header is used.
3215  */
3216 static unsigned int mergeable_min_buf_len(struct virtnet_info *vi, struct virtqueue *vq)
3217 {
3218         const unsigned int hdr_len = vi->hdr_len;
3219         unsigned int rq_size = virtqueue_get_vring_size(vq);
3220         unsigned int packet_len = vi->big_packets ? IP_MAX_MTU : vi->dev->max_mtu;
3221         unsigned int buf_len = hdr_len + ETH_HLEN + VLAN_HLEN + packet_len;
3222         unsigned int min_buf_len = DIV_ROUND_UP(buf_len, rq_size);
3223
3224         return max(max(min_buf_len, hdr_len) - hdr_len,
3225                    (unsigned int)GOOD_PACKET_LEN);
3226 }
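/* Worked example (editorial, assuming hdr_len = 12 for the mergeable header,
 * a 256-entry ring and max_mtu = 1500): buf_len = 12 + 14 + 4 + 1500 = 1530,
 * min_buf_len = DIV_ROUND_UP(1530, 256) = 6, and the final max() clamps the
 * result up to GOOD_PACKET_LEN (1518).
 */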
3227
3228 static int virtnet_find_vqs(struct virtnet_info *vi)
3229 {
3230         vq_callback_t **callbacks;
3231         struct virtqueue **vqs;
3232         int ret = -ENOMEM;
3233         int i, total_vqs;
3234         const char **names;
3235         bool *ctx;
3236
3237         /* We expect 1 RX virtqueue followed by 1 TX virtqueue, followed by
3238          * possible N-1 RX/TX queue pairs used in multiqueue mode, followed by
3239          * possible control vq.
3240          */
3241         total_vqs = vi->max_queue_pairs * 2 +
3242                     virtio_has_feature(vi->vdev, VIRTIO_NET_F_CTRL_VQ);
3243
3244         /* Allocate space for find_vqs parameters */
3245         vqs = kcalloc(total_vqs, sizeof(*vqs), GFP_KERNEL);
3246         if (!vqs)
3247                 goto err_vq;
3248         callbacks = kmalloc_array(total_vqs, sizeof(*callbacks), GFP_KERNEL);
3249         if (!callbacks)
3250                 goto err_callback;
3251         names = kmalloc_array(total_vqs, sizeof(*names), GFP_KERNEL);
3252         if (!names)
3253                 goto err_names;
3254         if (!vi->big_packets || vi->mergeable_rx_bufs) {
3255                 ctx = kcalloc(total_vqs, sizeof(*ctx), GFP_KERNEL);
3256                 if (!ctx)
3257                         goto err_ctx;
3258         } else {
3259                 ctx = NULL;
3260         }
3261
3262         /* Parameters for control virtqueue, if any */
3263         if (vi->has_cvq) {
3264                 callbacks[total_vqs - 1] = NULL;
3265                 names[total_vqs - 1] = "control";
3266         }
3267
3268         /* Allocate/initialize parameters for send/receive virtqueues */
3269         for (i = 0; i < vi->max_queue_pairs; i++) {
3270                 callbacks[rxq2vq(i)] = skb_recv_done;
3271                 callbacks[txq2vq(i)] = skb_xmit_done;
3272                 sprintf(vi->rq[i].name, "input.%d", i);
3273                 sprintf(vi->sq[i].name, "output.%d", i);
3274                 names[rxq2vq(i)] = vi->rq[i].name;
3275                 names[txq2vq(i)] = vi->sq[i].name;
3276                 if (ctx)
3277                         ctx[rxq2vq(i)] = true;
3278         }
3279
3280         ret = virtio_find_vqs_ctx(vi->vdev, total_vqs, vqs, callbacks,
3281                                   names, ctx, NULL);
3282         if (ret)
3283                 goto err_find;
3284
3285         if (vi->has_cvq) {
3286                 vi->cvq = vqs[total_vqs - 1];
3287                 if (virtio_has_feature(vi->vdev, VIRTIO_NET_F_CTRL_VLAN))
3288                         vi->dev->features |= NETIF_F_HW_VLAN_CTAG_FILTER;
3289         }
3290
3291         for (i = 0; i < vi->max_queue_pairs; i++) {
3292                 vi->rq[i].vq = vqs[rxq2vq(i)];
3293                 vi->rq[i].min_buf_len = mergeable_min_buf_len(vi, vi->rq[i].vq);
3294                 vi->sq[i].vq = vqs[txq2vq(i)];
3295         }
3296
3297         /* Success: ret == 0; fall through to free the temporary arrays. */
3298
3299
3300 err_find:
3301         kfree(ctx);
3302 err_ctx:
3303         kfree(names);
3304 err_names:
3305         kfree(callbacks);
3306 err_callback:
3307         kfree(vqs);
3308 err_vq:
3309         return ret;
3310 }
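/* Editorial sketch of the resulting layout: with N queue pairs plus a control
 * queue the device exposes 2N + 1 virtqueues ordered
 * rx0, tx0, rx1, tx1, ..., rx(N-1), tx(N-1), ctrl; rxq2vq()/txq2vq()
 * translate queue-pair indices into this flat array.
 */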
3311
3312 static int virtnet_alloc_queues(struct virtnet_info *vi)
3313 {
3314         int i;
3315
3316         if (vi->has_cvq) {
3317                 vi->ctrl = kzalloc(sizeof(*vi->ctrl), GFP_KERNEL);
3318                 if (!vi->ctrl)
3319                         goto err_ctrl;
3320         } else {
3321                 vi->ctrl = NULL;
3322         }
3323         vi->sq = kcalloc(vi->max_queue_pairs, sizeof(*vi->sq), GFP_KERNEL);
3324         if (!vi->sq)
3325                 goto err_sq;
3326         vi->rq = kcalloc(vi->max_queue_pairs, sizeof(*vi->rq), GFP_KERNEL);
3327         if (!vi->rq)
3328                 goto err_rq;
3329
3330         INIT_DELAYED_WORK(&vi->refill, refill_work);
3331         for (i = 0; i < vi->max_queue_pairs; i++) {
3332                 vi->rq[i].pages = NULL;
3333                 netif_napi_add_weight(vi->dev, &vi->rq[i].napi, virtnet_poll,
3334                                       napi_weight);
3335                 netif_napi_add_tx_weight(vi->dev, &vi->sq[i].napi,
3336                                          virtnet_poll_tx,
3337                                          napi_tx ? napi_weight : 0);
3338
3339                 sg_init_table(vi->rq[i].sg, ARRAY_SIZE(vi->rq[i].sg));
3340                 ewma_pkt_len_init(&vi->rq[i].mrg_avg_pkt_len);
3341                 sg_init_table(vi->sq[i].sg, ARRAY_SIZE(vi->sq[i].sg));
3342
3343                 u64_stats_init(&vi->rq[i].stats.syncp);
3344                 u64_stats_init(&vi->sq[i].stats.syncp);
3345         }
3346
3347         return 0;
3348
3349 err_rq:
3350         kfree(vi->sq);
3351 err_sq:
3352         kfree(vi->ctrl);
3353 err_ctrl:
3354         return -ENOMEM;
3355 }
3356
3357 static int init_vqs(struct virtnet_info *vi)
3358 {
3359         int ret;
3360
3361         /* Allocate send & receive queues */
3362         ret = virtnet_alloc_queues(vi);
3363         if (ret)
3364                 goto err;
3365
3366         ret = virtnet_find_vqs(vi);
3367         if (ret)
3368                 goto err_free;
3369
3370         cpus_read_lock();
3371         virtnet_set_affinity(vi);
3372         cpus_read_unlock();
3373
3374         return 0;
3375
3376 err_free:
3377         virtnet_free_queues(vi);
3378 err:
3379         return ret;
3380 }
3381
3382 #ifdef CONFIG_SYSFS
3383 static ssize_t mergeable_rx_buffer_size_show(struct netdev_rx_queue *queue,
3384                 char *buf)
3385 {
3386         struct virtnet_info *vi = netdev_priv(queue->dev);
3387         unsigned int queue_index = get_netdev_rx_queue_index(queue);
3388         unsigned int headroom = virtnet_get_headroom(vi);
3389         unsigned int tailroom = headroom ? sizeof(struct skb_shared_info) : 0;
3390         struct ewma_pkt_len *avg;
3391
3392         BUG_ON(queue_index >= vi->max_queue_pairs);
3393         avg = &vi->rq[queue_index].mrg_avg_pkt_len;
3394         return sprintf(buf, "%u\n",
3395                        get_mergeable_buf_len(&vi->rq[queue_index], avg,
3396                                        SKB_DATA_ALIGN(headroom + tailroom)));
3397 }
3398
3399 static struct rx_queue_attribute mergeable_rx_buffer_size_attribute =
3400         __ATTR_RO(mergeable_rx_buffer_size);
3401
3402 static struct attribute *virtio_net_mrg_rx_attrs[] = {
3403         &mergeable_rx_buffer_size_attribute.attr,
3404         NULL
3405 };
3406
3407 static const struct attribute_group virtio_net_mrg_rx_group = {
3408         .name = "virtio_net",
3409         .attrs = virtio_net_mrg_rx_attrs
3410 };
3411 #endif
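/* Editorial note: when mergeable RX buffers are in use, the attribute above
 * should appear as
 * /sys/class/net/<dev>/queues/rx-<n>/virtio_net/mergeable_rx_buffer_size,
 * reporting the current EWMA-based buffer size estimate per RX queue.
 */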
3412
3413 static bool virtnet_fail_on_feature(struct virtio_device *vdev,
3414                                     unsigned int fbit,
3415                                     const char *fname, const char *dname)
3416 {
3417         if (!virtio_has_feature(vdev, fbit))
3418                 return false;
3419
3420         dev_err(&vdev->dev, "device advertises feature %s but not %s",
3421                 fname, dname);
3422
3423         return true;
3424 }
3425
3426 #define VIRTNET_FAIL_ON(vdev, fbit, dbit)                       \
3427         virtnet_fail_on_feature(vdev, fbit, #fbit, dbit)
3428
3429 static bool virtnet_validate_features(struct virtio_device *vdev)
3430 {
3431         if (!virtio_has_feature(vdev, VIRTIO_NET_F_CTRL_VQ) &&
3432             (VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_CTRL_RX,
3433                              "VIRTIO_NET_F_CTRL_VQ") ||
3434              VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_CTRL_VLAN,
3435                              "VIRTIO_NET_F_CTRL_VQ") ||
3436              VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_GUEST_ANNOUNCE,
3437                              "VIRTIO_NET_F_CTRL_VQ") ||
3438              VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_MQ, "VIRTIO_NET_F_CTRL_VQ") ||
3439              VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_CTRL_MAC_ADDR,
3440                              "VIRTIO_NET_F_CTRL_VQ") ||
3441              VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_RSS,
3442                              "VIRTIO_NET_F_CTRL_VQ") ||
3443              VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_HASH_REPORT,
3444                              "VIRTIO_NET_F_CTRL_VQ"))) {
3445                 return false;
3446         }
3447
3448         return true;
3449 }
3450
3451 #define MIN_MTU ETH_MIN_MTU
3452 #define MAX_MTU ETH_MAX_MTU
3453
3454 static int virtnet_validate(struct virtio_device *vdev)
3455 {
3456         if (!vdev->config->get) {
3457                 dev_err(&vdev->dev, "%s failure: config access disabled\n",
3458                         __func__);
3459                 return -EINVAL;
3460         }
3461
3462         if (!virtnet_validate_features(vdev))
3463                 return -EINVAL;
3464
3465         if (virtio_has_feature(vdev, VIRTIO_NET_F_MTU)) {
3466                 int mtu = virtio_cread16(vdev,
3467                                          offsetof(struct virtio_net_config,
3468                                                   mtu));
3469                 if (mtu < MIN_MTU)
3470                         __virtio_clear_bit(vdev, VIRTIO_NET_F_MTU);
3471         }
3472
3473         return 0;
3474 }
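/* Editorial note: clearing VIRTIO_NET_F_MTU above (instead of failing
 * validation outright) lets virtnet_probe() fall back to the generic
 * 68..65535 MTU range when the device advertises an MTU below ETH_MIN_MTU.
 */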
3475
3476 static int virtnet_probe(struct virtio_device *vdev)
3477 {
3478         int i, err = -ENOMEM;
3479         struct net_device *dev;
3480         struct virtnet_info *vi;
3481         u16 max_queue_pairs;
3482         int mtu;
3483
3484         /* Find out whether the host supports a multiqueue/RSS virtio_net device */
3485         max_queue_pairs = 1;
3486         if (virtio_has_feature(vdev, VIRTIO_NET_F_MQ) || virtio_has_feature(vdev, VIRTIO_NET_F_RSS))
3487                 max_queue_pairs =
3488                      virtio_cread16(vdev, offsetof(struct virtio_net_config, max_virtqueue_pairs));
3489
3490         /* We need at least 2 queues */
3491         if (max_queue_pairs < VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN ||
3492             max_queue_pairs > VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX ||
3493             !virtio_has_feature(vdev, VIRTIO_NET_F_CTRL_VQ))
3494                 max_queue_pairs = 1;
3495
3496         /* Allocate ourselves a network device with room for our info */
3497         dev = alloc_etherdev_mq(sizeof(struct virtnet_info), max_queue_pairs);
3498         if (!dev)
3499                 return -ENOMEM;
3500
3501         /* Set up network device as normal. */
3502         dev->priv_flags |= IFF_UNICAST_FLT | IFF_LIVE_ADDR_CHANGE |
3503                            IFF_TX_SKB_NO_LINEAR;
3504         dev->netdev_ops = &virtnet_netdev;
3505         dev->features = NETIF_F_HIGHDMA;
3506
3507         dev->ethtool_ops = &virtnet_ethtool_ops;
3508         SET_NETDEV_DEV(dev, &vdev->dev);
3509
3510         /* Do we support "hardware" checksums? */
3511         if (virtio_has_feature(vdev, VIRTIO_NET_F_CSUM)) {
3512                 /* This opens up the world of extra features. */
3513                 dev->hw_features |= NETIF_F_HW_CSUM | NETIF_F_SG;
3514                 if (csum)
3515                         dev->features |= NETIF_F_HW_CSUM | NETIF_F_SG;
3516
3517                 if (virtio_has_feature(vdev, VIRTIO_NET_F_GSO)) {
3518                         dev->hw_features |= NETIF_F_TSO
3519                                 | NETIF_F_TSO_ECN | NETIF_F_TSO6;
3520                 }
3521                 /* Individual feature bits: what can host handle? */
3522                 if (virtio_has_feature(vdev, VIRTIO_NET_F_HOST_TSO4))
3523                         dev->hw_features |= NETIF_F_TSO;
3524                 if (virtio_has_feature(vdev, VIRTIO_NET_F_HOST_TSO6))
3525                         dev->hw_features |= NETIF_F_TSO6;
3526                 if (virtio_has_feature(vdev, VIRTIO_NET_F_HOST_ECN))
3527                         dev->hw_features |= NETIF_F_TSO_ECN;
3528
3529                 dev->features |= NETIF_F_GSO_ROBUST;
3530
3531                 if (gso)
3532                         dev->features |= dev->hw_features & NETIF_F_ALL_TSO;
3533                 /* (!csum && gso) case will be fixed by register_netdev() */
3534         }
3535         if (virtio_has_feature(vdev, VIRTIO_NET_F_GUEST_CSUM))
3536                 dev->features |= NETIF_F_RXCSUM;
3537         if (virtio_has_feature(vdev, VIRTIO_NET_F_GUEST_TSO4) ||
3538             virtio_has_feature(vdev, VIRTIO_NET_F_GUEST_TSO6))
3539                 dev->features |= NETIF_F_GRO_HW;
3540         if (virtio_has_feature(vdev, VIRTIO_NET_F_CTRL_GUEST_OFFLOADS))
3541                 dev->hw_features |= NETIF_F_GRO_HW;
3542
3543         dev->vlan_features = dev->features;
3544
3545         /* MTU range: 68 - 65535 */
3546         dev->min_mtu = MIN_MTU;
3547         dev->max_mtu = MAX_MTU;
3548
3549         /* Configuration may specify what MAC to use.  Otherwise random. */
3550         if (virtio_has_feature(vdev, VIRTIO_NET_F_MAC)) {
3551                 u8 addr[ETH_ALEN];
3552
3553                 virtio_cread_bytes(vdev,
3554                                    offsetof(struct virtio_net_config, mac),
3555                                    addr, ETH_ALEN);
3556                 eth_hw_addr_set(dev, addr);
3557         } else {
3558                 eth_hw_addr_random(dev);
3559         }
3560
3561         /* Set up our device-specific information */
3562         vi = netdev_priv(dev);
3563         vi->dev = dev;
3564         vi->vdev = vdev;
3565         vdev->priv = vi;
3566
3567         INIT_WORK(&vi->config_work, virtnet_config_changed_work);
3568         spin_lock_init(&vi->refill_lock);
3569
3570         /* If we can receive ANY GSO packets, we must allocate large ones. */
3571         if (virtio_has_feature(vdev, VIRTIO_NET_F_GUEST_TSO4) ||
3572             virtio_has_feature(vdev, VIRTIO_NET_F_GUEST_TSO6) ||
3573             virtio_has_feature(vdev, VIRTIO_NET_F_GUEST_ECN) ||
3574             virtio_has_feature(vdev, VIRTIO_NET_F_GUEST_UFO))
3575                 vi->big_packets = true;
3576
3577         if (virtio_has_feature(vdev, VIRTIO_NET_F_MRG_RXBUF))
3578                 vi->mergeable_rx_bufs = true;
3579
3580         if (virtio_has_feature(vdev, VIRTIO_NET_F_HASH_REPORT))
3581                 vi->has_rss_hash_report = true;
3582
3583         if (virtio_has_feature(vdev, VIRTIO_NET_F_RSS))
3584                 vi->has_rss = true;
3585
3586         if (vi->has_rss || vi->has_rss_hash_report) {
3587                 vi->rss_indir_table_size =
3588                         virtio_cread16(vdev, offsetof(struct virtio_net_config,
3589                                 rss_max_indirection_table_length));
3590                 vi->rss_key_size =
3591                         virtio_cread8(vdev, offsetof(struct virtio_net_config, rss_max_key_size));
3592
3593                 vi->rss_hash_types_supported =
3594                     virtio_cread32(vdev, offsetof(struct virtio_net_config, supported_hash_types));
3595                 vi->rss_hash_types_supported &=
3596                                 ~(VIRTIO_NET_RSS_HASH_TYPE_IP_EX |
3597                                   VIRTIO_NET_RSS_HASH_TYPE_TCP_EX |
3598                                   VIRTIO_NET_RSS_HASH_TYPE_UDP_EX);
3599
3600                 dev->hw_features |= NETIF_F_RXHASH;
3601         }
3602
3603         if (vi->has_rss_hash_report)
3604                 vi->hdr_len = sizeof(struct virtio_net_hdr_v1_hash);
3605         else if (virtio_has_feature(vdev, VIRTIO_NET_F_MRG_RXBUF) ||
3606                  virtio_has_feature(vdev, VIRTIO_F_VERSION_1))
3607                 vi->hdr_len = sizeof(struct virtio_net_hdr_mrg_rxbuf);
3608         else
3609                 vi->hdr_len = sizeof(struct virtio_net_hdr);
3610
3611         if (virtio_has_feature(vdev, VIRTIO_F_ANY_LAYOUT) ||
3612             virtio_has_feature(vdev, VIRTIO_F_VERSION_1))
3613                 vi->any_header_sg = true;
3614
3615         if (virtio_has_feature(vdev, VIRTIO_NET_F_CTRL_VQ))
3616                 vi->has_cvq = true;
3617
3618         if (virtio_has_feature(vdev, VIRTIO_NET_F_MTU)) {
3619                 mtu = virtio_cread16(vdev,
3620                                      offsetof(struct virtio_net_config,
3621                                               mtu));
3622                 if (mtu < dev->min_mtu) {
3623                         /* Should never trigger: MTU was previously validated
3624                          * in virtnet_validate.
3625                          */
3626                         dev_err(&vdev->dev,
3627                                 "device MTU appears to have changed it is now %d < %d",
3628                                 mtu, dev->min_mtu);
3629                         err = -EINVAL;
3630                         goto free;
3631                 }
3632
3633                 dev->mtu = mtu;
3634                 dev->max_mtu = mtu;
3635
3636                 /* TODO: size buffers correctly in this case. */
3637                 if (dev->mtu > ETH_DATA_LEN)
3638                         vi->big_packets = true;
3639         }
3640
3641         if (vi->any_header_sg)
3642                 dev->needed_headroom = vi->hdr_len;
3643
3644         /* Enable multiqueue by default */
3645         if (num_online_cpus() >= max_queue_pairs)
3646                 vi->curr_queue_pairs = max_queue_pairs;
3647         else
3648                 vi->curr_queue_pairs = num_online_cpus();
3649         vi->max_queue_pairs = max_queue_pairs;
3650
3651         /* Allocate/initialize the rx/tx queues, and invoke find_vqs */
3652         err = init_vqs(vi);
3653         if (err)
3654                 goto free;
3655
3656 #ifdef CONFIG_SYSFS
3657         if (vi->mergeable_rx_bufs)
3658                 dev->sysfs_rx_queue_group = &virtio_net_mrg_rx_group;
3659 #endif
3660         netif_set_real_num_tx_queues(dev, vi->curr_queue_pairs);
3661         netif_set_real_num_rx_queues(dev, vi->curr_queue_pairs);
3662
3663         virtnet_init_settings(dev);
3664
3665         if (virtio_has_feature(vdev, VIRTIO_NET_F_STANDBY)) {
3666                 vi->failover = net_failover_create(vi->dev);
3667                 if (IS_ERR(vi->failover)) {
3668                         err = PTR_ERR(vi->failover);
3669                         goto free_vqs;
3670                 }
3671         }
3672
3673         if (vi->has_rss || vi->has_rss_hash_report)
3674                 virtnet_init_default_rss(vi);
3675
3676         /* serialize netdev register + virtio_device_ready() with ndo_open() */
3677         rtnl_lock();
3678
3679         err = register_netdevice(dev);
3680         if (err) {
3681                 pr_debug("virtio_net: registering device failed\n");
3682                 rtnl_unlock();
3683                 goto free_failover;
3684         }
3685
3686         virtio_device_ready(vdev);
3687
3688         rtnl_unlock();
3689
3690         err = virtnet_cpu_notif_add(vi);
3691         if (err) {
3692                 pr_debug("virtio_net: registering cpu notifier failed\n");
3693                 goto free_unregister_netdev;
3694         }
3695
3696         virtnet_set_queues(vi, vi->curr_queue_pairs);
3697
3698         /* Assume link up if device can't report link status,
3699          * otherwise get link status from config. */
3700         netif_carrier_off(dev);
3701         if (virtio_has_feature(vi->vdev, VIRTIO_NET_F_STATUS)) {
3702                 schedule_work(&vi->config_work);
3703         } else {
3704                 vi->status = VIRTIO_NET_S_LINK_UP;
3705                 virtnet_update_settings(vi);
3706                 netif_carrier_on(dev);
3707         }
3708
3709         for (i = 0; i < ARRAY_SIZE(guest_offloads); i++)
3710                 if (virtio_has_feature(vi->vdev, guest_offloads[i]))
3711                         set_bit(guest_offloads[i], &vi->guest_offloads);
3712         vi->guest_offloads_capable = vi->guest_offloads;
3713
3714         pr_debug("virtnet: registered device %s with %d RX and TX vqs\n",
3715                  dev->name, max_queue_pairs);
3716
3717         return 0;
3718
3719 free_unregister_netdev:
3720         virtio_reset_device(vdev);
3721
3722         unregister_netdev(dev);
3723 free_failover:
3724         net_failover_destroy(vi->failover);
3725 free_vqs:
3726         cancel_delayed_work_sync(&vi->refill);
3727         free_receive_page_frags(vi);
3728         virtnet_del_vqs(vi);
3729 free:
3730         free_netdev(dev);
3731         return err;
3732 }
3733
3734 static void remove_vq_common(struct virtnet_info *vi)
3735 {
3736         virtio_reset_device(vi->vdev);
3737
3738         /* Free unused buffers in both send and recv, if any. */
3739         free_unused_bufs(vi);
3740
3741         free_receive_bufs(vi);
3742
3743         free_receive_page_frags(vi);
3744
3745         virtnet_del_vqs(vi);
3746 }
3747
3748 static void virtnet_remove(struct virtio_device *vdev)
3749 {
3750         struct virtnet_info *vi = vdev->priv;
3751
3752         virtnet_cpu_notif_remove(vi);
3753
3754         /* Make sure no work handler is accessing the device. */
3755         flush_work(&vi->config_work);
3756
3757         unregister_netdev(vi->dev);
3758
3759         net_failover_destroy(vi->failover);
3760
3761         remove_vq_common(vi);
3762
3763         free_netdev(vi->dev);
3764 }
3765
3766 static __maybe_unused int virtnet_freeze(struct virtio_device *vdev)
3767 {
3768         struct virtnet_info *vi = vdev->priv;
3769
3770         virtnet_cpu_notif_remove(vi);
3771         virtnet_freeze_down(vdev);
3772         remove_vq_common(vi);
3773
3774         return 0;
3775 }
3776
3777 static __maybe_unused int virtnet_restore(struct virtio_device *vdev)
3778 {
3779         struct virtnet_info *vi = vdev->priv;
3780         int err;
3781
3782         err = virtnet_restore_up(vdev);
3783         if (err)
3784                 return err;
3785         virtnet_set_queues(vi, vi->curr_queue_pairs);
3786
3787         err = virtnet_cpu_notif_add(vi);
3788         if (err) {
3789                 virtnet_freeze_down(vdev);
3790                 remove_vq_common(vi);
3791                 return err;
3792         }
3793
3794         return 0;
3795 }
3796
3797 static struct virtio_device_id id_table[] = {
3798         { VIRTIO_ID_NET, VIRTIO_DEV_ANY_ID },
3799         { 0 },
3800 };
3801
3802 #define VIRTNET_FEATURES \
3803         VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, \
3804         VIRTIO_NET_F_MAC, \
3805         VIRTIO_NET_F_HOST_TSO4, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_HOST_TSO6, \
3806         VIRTIO_NET_F_HOST_ECN, VIRTIO_NET_F_GUEST_TSO4, VIRTIO_NET_F_GUEST_TSO6, \
3807         VIRTIO_NET_F_GUEST_ECN, VIRTIO_NET_F_GUEST_UFO, \
3808         VIRTIO_NET_F_MRG_RXBUF, VIRTIO_NET_F_STATUS, VIRTIO_NET_F_CTRL_VQ, \
3809         VIRTIO_NET_F_CTRL_RX, VIRTIO_NET_F_CTRL_VLAN, \
3810         VIRTIO_NET_F_GUEST_ANNOUNCE, VIRTIO_NET_F_MQ, \
3811         VIRTIO_NET_F_CTRL_MAC_ADDR, \
3812         VIRTIO_NET_F_MTU, VIRTIO_NET_F_CTRL_GUEST_OFFLOADS, \
3813         VIRTIO_NET_F_SPEED_DUPLEX, VIRTIO_NET_F_STANDBY, \
3814         VIRTIO_NET_F_RSS, VIRTIO_NET_F_HASH_REPORT
3815
3816 static unsigned int features[] = {
3817         VIRTNET_FEATURES,
3818 };
3819
3820 static unsigned int features_legacy[] = {
3821         VIRTNET_FEATURES,
3822         VIRTIO_NET_F_GSO,
3823         VIRTIO_F_ANY_LAYOUT,
3824 };
3825
3826 static struct virtio_driver virtio_net_driver = {
3827         .feature_table = features,
3828         .feature_table_size = ARRAY_SIZE(features),
3829         .feature_table_legacy = features_legacy,
3830         .feature_table_size_legacy = ARRAY_SIZE(features_legacy),
3831         .driver.name =  KBUILD_MODNAME,
3832         .driver.owner = THIS_MODULE,
3833         .id_table =     id_table,
3834         .validate =     virtnet_validate,
3835         .probe =        virtnet_probe,
3836         .remove =       virtnet_remove,
3837         .config_changed = virtnet_config_changed,
3838 #ifdef CONFIG_PM_SLEEP
3839         .freeze =       virtnet_freeze,
3840         .restore =      virtnet_restore,
3841 #endif
3842 };
3843
3844 static __init int virtio_net_driver_init(void)
3845 {
3846         int ret;
3847
3848         ret = cpuhp_setup_state_multi(CPUHP_AP_ONLINE_DYN, "virtio/net:online",
3849                                       virtnet_cpu_online,
3850                                       virtnet_cpu_down_prep);
3851         if (ret < 0)
3852                 goto out;
3853         virtionet_online = ret;
3854         ret = cpuhp_setup_state_multi(CPUHP_VIRT_NET_DEAD, "virtio/net:dead",
3855                                       NULL, virtnet_cpu_dead);
3856         if (ret)
3857                 goto err_dead;
3858         ret = register_virtio_driver(&virtio_net_driver);
3859         if (ret)
3860                 goto err_virtio;
3861         return 0;
3862 err_virtio:
3863         cpuhp_remove_multi_state(CPUHP_VIRT_NET_DEAD);
3864 err_dead:
3865         cpuhp_remove_multi_state(virtionet_online);
3866 out:
3867         return ret;
3868 }
3869 module_init(virtio_net_driver_init);
3870
3871 static __exit void virtio_net_driver_exit(void)
3872 {
3873         unregister_virtio_driver(&virtio_net_driver);
3874         cpuhp_remove_multi_state(CPUHP_VIRT_NET_DEAD);
3875         cpuhp_remove_multi_state(virtionet_online);
3876 }
3877 module_exit(virtio_net_driver_exit);
3878
3879 MODULE_DEVICE_TABLE(virtio, id_table);
3880 MODULE_DESCRIPTION("Virtio network driver");
3881 MODULE_LICENSE("GPL");