Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/dtor/input
[linux-2.6-microblaze.git] / drivers / net / wireguard / receive.c
index da3b782..9143814 100644 (file)
@@ -226,40 +226,39 @@ void wg_packet_handshake_receive_worker(struct work_struct *work)
 static void keep_key_fresh(struct wg_peer *peer)
 {
        struct noise_keypair *keypair;
-       bool send = false;
+       bool send;
 
        if (peer->sent_lastminute_handshake)
                return;
 
        rcu_read_lock_bh();
        keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
-       if (likely(keypair && READ_ONCE(keypair->sending.is_valid)) &&
-           keypair->i_am_the_initiator &&
-           unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
-                       REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT)))
-               send = true;
+       send = keypair && READ_ONCE(keypair->sending.is_valid) &&
+              keypair->i_am_the_initiator &&
+              wg_birthdate_has_expired(keypair->sending.birthdate,
+                       REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT);
        rcu_read_unlock_bh();
 
-       if (send) {
+       if (unlikely(send)) {
                peer->sent_lastminute_handshake = true;
                wg_packet_send_queued_handshake_initiation(peer, false);
        }
 }
 
-static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
+static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
 {
        struct scatterlist sg[MAX_SKB_FRAGS + 8];
        struct sk_buff *trailer;
        unsigned int offset;
        int num_frags;
 
-       if (unlikely(!key))
+       if (unlikely(!keypair))
                return false;
 
-       if (unlikely(!READ_ONCE(key->is_valid) ||
-                 wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
-                 key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
-               WRITE_ONCE(key->is_valid, false);
+       if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
+                 wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
+                 keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
+               WRITE_ONCE(keypair->receiving.is_valid, false);
                return false;
        }
 
@@ -284,7 +283,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
 
        if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
                                                 PACKET_CB(skb)->nonce,
-                                                key->key))
+                                                keypair->receiving.key))
                return false;
 
        /* Another ugly situation of pushing and pulling the header so as to
@@ -299,41 +298,41 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
 }
 
 /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
-static bool counter_validate(union noise_counter *counter, u64 their_counter)
+static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
 {
        unsigned long index, index_current, top, i;
        bool ret = false;
 
-       spin_lock_bh(&counter->receive.lock);
+       spin_lock_bh(&counter->lock);
 
-       if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
+       if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
                     their_counter >= REJECT_AFTER_MESSAGES))
                goto out;
 
        ++their_counter;
 
        if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
-                    counter->receive.counter))
+                    counter->counter))
                goto out;
 
        index = their_counter >> ilog2(BITS_PER_LONG);
 
-       if (likely(their_counter > counter->receive.counter)) {
-               index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
+       if (likely(their_counter > counter->counter)) {
+               index_current = counter->counter >> ilog2(BITS_PER_LONG);
                top = min_t(unsigned long, index - index_current,
                            COUNTER_BITS_TOTAL / BITS_PER_LONG);
                for (i = 1; i <= top; ++i)
-                       counter->receive.backtrack[(i + index_current) &
+                       counter->backtrack[(i + index_current) &
                                ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
-               counter->receive.counter = their_counter;
+               counter->counter = their_counter;
        }
 
        index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
        ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
-                               &counter->receive.backtrack[index]);
+                               &counter->backtrack[index]);
 
 out:
-       spin_unlock_bh(&counter->receive.lock);
+       spin_unlock_bh(&counter->lock);
        return ret;
 }
 
@@ -393,13 +392,11 @@ static void wg_packet_consume_data_done(struct wg_peer *peer,
                len = ntohs(ip_hdr(skb)->tot_len);
                if (unlikely(len < sizeof(struct iphdr)))
                        goto dishonest_packet_size;
-               if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
-                       IP_ECN_set_ce(ip_hdr(skb));
+               INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ip_hdr(skb)->tos);
        } else if (skb->protocol == htons(ETH_P_IPV6)) {
                len = ntohs(ipv6_hdr(skb)->payload_len) +
                      sizeof(struct ipv6hdr);
-               if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
-                       IP6_ECN_set_ce(skb, ipv6_hdr(skb));
+               INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ipv6_get_dsfield(ipv6_hdr(skb)));
        } else {
                goto dishonest_packet_type;
        }
@@ -475,19 +472,19 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget)
                if (unlikely(state != PACKET_STATE_CRYPTED))
                        goto next;
 
-               if (unlikely(!counter_validate(&keypair->receiving.counter,
+               if (unlikely(!counter_validate(&keypair->receiving_counter,
                                               PACKET_CB(skb)->nonce))) {
                        net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
                                            peer->device->dev->name,
                                            PACKET_CB(skb)->nonce,
-                                           keypair->receiving.counter.receive.counter);
+                                           keypair->receiving_counter.counter);
                        goto next;
                }
 
                if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
                        goto next;
 
-               wg_reset_packet(skb);
+               wg_reset_packet(skb, false);
                wg_packet_consume_data_done(peer, skb, &endpoint);
                free = false;
 
@@ -514,10 +511,12 @@ void wg_packet_decrypt_worker(struct work_struct *work)
        struct sk_buff *skb;
 
        while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
-               enum packet_state state = likely(decrypt_packet(skb,
-                               &PACKET_CB(skb)->keypair->receiving)) ?
+               enum packet_state state =
+                       likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
                                PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
                wg_queue_enqueue_per_peer_napi(skb, state);
+               if (need_resched())
+                       cond_resched();
        }
 }