201a22681945fa845d8b22d802d454ca7000d405
[linux-2.6-microblaze.git] / drivers / net / wireguard / noise.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
19
20 /* This implements Noise_IKpsk2:
21  *
22  * <- s
23  * ******
24  * -> e, es, s, ss, {t}
25  * <- e, ee, se, psk, {}
26  */
27
/* Noise protocol name hashed into the initial chaining key. The array is
 * sized to exactly the string's 37 characters, so it deliberately holds no
 * trailing NUL: wg_noise_init() hashes sizeof(handshake_name) bytes.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
/* Identifier mixed into the initial hash. Like handshake_name it is sized
 * without a trailing NUL (34 characters), and sizeof(identifier_name) bytes
 * are hashed.
 */
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Constant starting hash/chaining key for every handshake; computed once by
 * wg_noise_init() and read-only afterwards (__ro_after_init).
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Source of keypair->internal_id values handed out by keypair_create(). */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33
34 void __init wg_noise_init(void)
35 {
36         struct blake2s_state blake;
37
38         blake2s(handshake_init_chaining_key, handshake_name, NULL,
39                 NOISE_HASH_LEN, sizeof(handshake_name), 0);
40         blake2s_init(&blake, NOISE_HASH_LEN);
41         blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42         blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43         blake2s_final(&blake, handshake_init_hash);
44 }
45
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49         down_write(&peer->handshake.lock);
50         if (!peer->handshake.static_identity->has_identity ||
51             !curve25519(peer->handshake.precomputed_static_static,
52                         peer->handshake.static_identity->static_private,
53                         peer->handshake.remote_static))
54                 memset(peer->handshake.precomputed_static_static, 0,
55                        NOISE_PUBLIC_KEY_LEN);
56         up_write(&peer->handshake.lock);
57 }
58
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60                              struct noise_static_identity *static_identity,
61                              const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62                              const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63                              struct wg_peer *peer)
64 {
65         memset(handshake, 0, sizeof(*handshake));
66         init_rwsem(&handshake->lock);
67         handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68         handshake->entry.peer = peer;
69         memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70         if (peer_preshared_key)
71                 memcpy(handshake->preshared_key, peer_preshared_key,
72                        NOISE_SYMMETRIC_KEY_LEN);
73         handshake->static_identity = static_identity;
74         handshake->state = HANDSHAKE_ZEROED;
75         wg_noise_precompute_static_static(peer);
76 }
77
78 static void handshake_zero(struct noise_handshake *handshake)
79 {
80         memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81         memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82         memset(&handshake->hash, 0, NOISE_HASH_LEN);
83         memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84         handshake->remote_index = 0;
85         handshake->state = HANDSHAKE_ZEROED;
86 }
87
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
89 {
90         wg_index_hashtable_remove(
91                         handshake->entry.peer->device->index_hashtable,
92                         &handshake->entry);
93         down_write(&handshake->lock);
94         handshake_zero(handshake);
95         up_write(&handshake->lock);
96         wg_index_hashtable_remove(
97                         handshake->entry.peer->device->index_hashtable,
98                         &handshake->entry);
99 }
100
101 static struct noise_keypair *keypair_create(struct wg_peer *peer)
102 {
103         struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
104
105         if (unlikely(!keypair))
106                 return NULL;
107         spin_lock_init(&keypair->receiving_counter.lock);
108         keypair->internal_id = atomic64_inc_return(&keypair_counter);
109         keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
110         keypair->entry.peer = peer;
111         kref_init(&keypair->refcount);
112         return keypair;
113 }
114
115 static void keypair_free_rcu(struct rcu_head *rcu)
116 {
117         kzfree(container_of(rcu, struct noise_keypair, rcu));
118 }
119
120 static void keypair_free_kref(struct kref *kref)
121 {
122         struct noise_keypair *keypair =
123                 container_of(kref, struct noise_keypair, refcount);
124
125         net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
126                             keypair->entry.peer->device->dev->name,
127                             keypair->internal_id,
128                             keypair->entry.peer->internal_id);
129         wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
130                                   &keypair->entry);
131         call_rcu(&keypair->rcu, keypair_free_rcu);
132 }
133
134 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
135 {
136         if (unlikely(!keypair))
137                 return;
138         if (unlikely(unreference_now))
139                 wg_index_hashtable_remove(
140                         keypair->entry.peer->device->index_hashtable,
141                         &keypair->entry);
142         kref_put(&keypair->refcount, keypair_free_kref);
143 }
144
145 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
146 {
147         RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
148                 "Taking noise keypair reference without holding the RCU BH read lock");
149         if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
150                 return NULL;
151         return keypair;
152 }
153
/* Empty all three keypair slots of @keypairs (next, then previous, then
 * current) under keypair_update_lock with BHs disabled.
 *
 * Each displaced keypair is released with wg_noise_keypair_put(old, true),
 * which also unpublishes it from the index hashtable immediately.
 * RCU_INIT_POINTER() is sufficient for storing NULL, because nothing new is
 * being published to readers.
 */
void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
{
	struct noise_keypair *old;

	spin_lock_bh(&keypairs->keypair_update_lock);

	/* We zero the next_keypair before zeroing the others, so that
	 * wg_noise_received_with_keypair returns early before subsequent ones
	 * are zeroed.
	 */
	old = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
	wg_noise_keypair_put(old, true);

	old = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
	wg_noise_keypair_put(old, true);

	old = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
	wg_noise_keypair_put(old, true);

	spin_unlock_bh(&keypairs->keypair_update_lock);
}
181
182 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
183 {
184         struct noise_keypair *keypair;
185
186         wg_noise_handshake_clear(&peer->handshake);
187         wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
188
189         spin_lock_bh(&peer->keypairs.keypair_update_lock);
190         keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
191                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
192         if (keypair)
193                 keypair->sending.is_valid = false;
194         keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
195                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
196         if (keypair)
197                 keypair->sending.is_valid = false;
198         spin_unlock_bh(&peer->keypairs.keypair_update_lock);
199 }
200
/* Install @new_keypair into @keypairs under keypair_update_lock.
 *
 * Initiator (new_keypair->i_am_the_initiator): the new keypair becomes the
 * current one immediately, and the old current/next keypair slides into the
 * previous slot.
 *
 * Responder: the new keypair is parked in the next slot. It is promoted
 * later by wg_noise_received_with_keypair().
 *
 * Keypairs that drop out of every slot are released with
 * wg_noise_keypair_put(..., true), which also unpublishes them from the
 * index hashtable. rcu_assign_pointer() is used where a non-NULL keypair is
 * published to readers; RCU_INIT_POINTER() is used for NULL stores.
 */
static void add_new_keypair(struct noise_keypairs *keypairs,
			    struct noise_keypair *new_keypair)
{
	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;

	spin_lock_bh(&keypairs->keypair_update_lock);
	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	if (new_keypair->i_am_the_initiator) {
		/* If we're the initiator, it means we've sent a handshake, and
		 * received a confirmation response, which means this new
		 * keypair can now be used.
		 */
		if (next_keypair) {
			/* If there already was a next keypair pending, we
			 * demote it to be the previous keypair, and free the
			 * existing current. Note that this means KCI can result
			 * in this transition. It would perhaps be more sound to
			 * always just get rid of the unused next keypair
			 * instead of putting it in the previous slot, but this
			 * might be a bit less robust. Something to think about
			 * for the future.
			 */
			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
			rcu_assign_pointer(keypairs->previous_keypair,
					   next_keypair);
			wg_noise_keypair_put(current_keypair, true);
		} else /* If there wasn't an existing next keypair, we replace
			* the previous with the current one.
			*/
			rcu_assign_pointer(keypairs->previous_keypair,
					   current_keypair);
		/* At this point we can get rid of the old previous keypair, and
		 * set up the new keypair.
		 */
		wg_noise_keypair_put(previous_keypair, true);
		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
	} else {
		/* If we're the responder, it means we can't use the new keypair
		 * until we receive confirmation via the first data packet, so
		 * we get rid of the existing previous one, the possibly
		 * existing next one, and slide in the new next one.
		 * The current keypair is left untouched here.
		 */
		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
		wg_noise_keypair_put(next_keypair, true);
		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
		wg_noise_keypair_put(previous_keypair, true);
	}
	spin_unlock_bh(&keypairs->keypair_update_lock);
}
255
/* Called with the keypair a packet was received on.
 *
 * If @received_keypair is the pending next keypair, rotate the slots:
 * previous <- current, current <- next, next <- NULL. The old previous
 * keypair is released.
 *
 * Returns true only when this call performed the rotation.
 */
bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
				    struct noise_keypair *received_keypair)
{
	struct noise_keypair *old_keypair;
	bool key_is_new;

	/* We first check without taking the spinlock. rcu_access_pointer() is
	 * enough here because the pointer is only compared, never dereferenced.
	 */
	key_is_new = received_keypair ==
		     rcu_access_pointer(keypairs->next_keypair);
	if (likely(!key_is_new))
		return false;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* After locking, we double check that things didn't change from
	 * beneath us.
	 */
	if (unlikely(received_keypair !=
		    rcu_dereference_protected(keypairs->next_keypair,
			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
		spin_unlock_bh(&keypairs->keypair_update_lock);
		return false;
	}

	/* When we've finally received the confirmation, we slide the next
	 * into the current, the current into the previous, and get rid of
	 * the old previous.
	 */
	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	rcu_assign_pointer(keypairs->previous_keypair,
		rcu_dereference_protected(keypairs->current_keypair,
			lockdep_is_held(&keypairs->keypair_update_lock)));
	wg_noise_keypair_put(old_keypair, true);
	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);

	spin_unlock_bh(&keypairs->keypair_update_lock);
	return true;
}
295
296 /* Must hold static_identity->lock */
297 void wg_noise_set_static_identity_private_key(
298         struct noise_static_identity *static_identity,
299         const u8 private_key[NOISE_PUBLIC_KEY_LEN])
300 {
301         memcpy(static_identity->static_private, private_key,
302                NOISE_PUBLIC_KEY_LEN);
303         curve25519_clamp_secret(static_identity->static_private);
304         static_identity->has_identity = curve25519_generate_public(
305                 static_identity->static_public, private_key);
306 }
307
/* This is Hugo Krawczyk's HKDF:
 *  - https://eprint.iacr.org/2010/264.pdf
 *  - https://tools.ietf.org/html/rfc5869
 *
 * HKDF instantiated with HMAC-BLAKE2s. The extract step is keyed by
 * @chaining_key over @data (@data_len bytes). The expand step produces up to
 * three outputs, in order; output i receives the first *_len bytes of the
 * i-th expand block.
 *
 * Expansion stops at the first NULL or zero-length destination, so passing
 * second_dst == NULL yields only the first output. Each *_len must be at most
 * BLAKE2S_HASH_SIZE, and a later output requires all earlier ones. When DEBUG
 * is enabled, the WARN_ON below flags violations of these rules.
 *
 * @first_dst may alias @chaining_key (callers use this to update the chaining
 * key in place). This is safe because @chaining_key is only read by the
 * extract step, before any output is written.
 */
static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
		size_t first_len, size_t second_len, size_t third_len,
		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
	/* Holds T(i) in its first BLAKE2S_HASH_SIZE bytes; the extra byte lets
	 * the next expand step hash T(i) || counter without another buffer.
	 */
	u8 output[BLAKE2S_HASH_SIZE + 1];
	u8 secret[BLAKE2S_HASH_SIZE];

	WARN_ON(IS_ENABLED(DEBUG) &&
		(first_len > BLAKE2S_HASH_SIZE ||
		 second_len > BLAKE2S_HASH_SIZE ||
		 third_len > BLAKE2S_HASH_SIZE ||
		 ((second_len || second_dst || third_len || third_dst) &&
		  (!first_len || !first_dst)) ||
		 ((third_len || third_dst) && (!second_len || !second_dst))));

	/* Extract entropy from data into secret */
	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);

	if (!first_dst || !first_len)
		goto out;

	/* Expand first key: key = secret, data = 0x1 */
	output[0] = 1;
	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
	memcpy(first_dst, output, first_len);

	if (!second_dst || !second_len)
		goto out;

	/* Expand second key: key = secret, data = first-key || 0x2 */
	output[BLAKE2S_HASH_SIZE] = 2;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(second_dst, output, second_len);

	if (!third_dst || !third_len)
		goto out;

	/* Expand third key: key = secret, data = second-key || 0x3 */
	output[BLAKE2S_HASH_SIZE] = 3;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(third_dst, output, third_len);

out:
	/* Clear sensitive data from stack */
	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
361
362 static void derive_keys(struct noise_symmetric_key *first_dst,
363                         struct noise_symmetric_key *second_dst,
364                         const u8 chaining_key[NOISE_HASH_LEN])
365 {
366         u64 birthdate = ktime_get_coarse_boottime_ns();
367         kdf(first_dst->key, second_dst->key, NULL, NULL,
368             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
369             chaining_key);
370         first_dst->birthdate = second_dst->birthdate = birthdate;
371         first_dst->is_valid = second_dst->is_valid = true;
372 }
373
374 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
375                                 u8 key[NOISE_SYMMETRIC_KEY_LEN],
376                                 const u8 private[NOISE_PUBLIC_KEY_LEN],
377                                 const u8 public[NOISE_PUBLIC_KEY_LEN])
378 {
379         u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
380
381         if (unlikely(!curve25519(dh_calculation, private, public)))
382                 return false;
383         kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
384             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
385         memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
386         return true;
387 }
388
389 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
390                                             u8 key[NOISE_SYMMETRIC_KEY_LEN],
391                                             const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
392 {
393         static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
394         if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
395                 return false;
396         kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
397             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
398             chaining_key);
399         return true;
400 }
401
402 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
403 {
404         struct blake2s_state blake;
405
406         blake2s_init(&blake, NOISE_HASH_LEN);
407         blake2s_update(&blake, hash, NOISE_HASH_LEN);
408         blake2s_update(&blake, src, src_len);
409         blake2s_final(&blake, hash);
410 }
411
412 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
413                     u8 key[NOISE_SYMMETRIC_KEY_LEN],
414                     const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
415 {
416         u8 temp_hash[NOISE_HASH_LEN];
417
418         kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
419             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
420         mix_hash(hash, temp_hash, NOISE_HASH_LEN);
421         memzero_explicit(temp_hash, NOISE_HASH_LEN);
422 }
423
424 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
425                            u8 hash[NOISE_HASH_LEN],
426                            const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
427 {
428         memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
429         memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
430         mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
431 }
432
433 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
434                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
435                             u8 hash[NOISE_HASH_LEN])
436 {
437         chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
438                                  NOISE_HASH_LEN,
439                                  0 /* Always zero for Noise_IK */, key);
440         mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
441 }
442
443 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
444                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
445                             u8 hash[NOISE_HASH_LEN])
446 {
447         if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
448                                       hash, NOISE_HASH_LEN,
449                                       0 /* Always zero for Noise_IK */, key))
450                 return false;
451         mix_hash(hash, src_ciphertext, src_len);
452         return true;
453 }
454
455 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
456                               const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
457                               u8 chaining_key[NOISE_HASH_LEN],
458                               u8 hash[NOISE_HASH_LEN])
459 {
460         if (ephemeral_dst != ephemeral_src)
461                 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
462         mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
463         kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
464             NOISE_PUBLIC_KEY_LEN, chaining_key);
465 }
466
467 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
468 {
469         struct timespec64 now;
470
471         ktime_get_real_ts64(&now);
472
473         /* In order to prevent some sort of infoleak from precise timers, we
474          * round down the nanoseconds part to the closest rounded-down power of
475          * two to the maximum initiations per second allowed anyway by the
476          * implementation.
477          */
478         now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
479                 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
480
481         /* https://cr.yp.to/libtai/tai64.html */
482         *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
483         *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
484 }
485
/* Build a handshake initiation message for @handshake into @dst.
 *
 * Locks: static_identity->lock is taken for reading, then handshake->lock
 * for writing.
 *
 * On success the handshake entry is inserted into the device index
 * hashtable (the returned index becomes dst->sender_index), the state
 * becomes HANDSHAKE_CREATED_INITIATION, and true is returned.
 *
 * On failure false is returned and the state is unchanged. However,
 * handshake->hash, handshake->chaining_key and handshake->ephemeral_private
 * may already have been overwritten, and @dst may be partially filled.
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
				     struct noise_handshake *handshake)
{
	u8 timestamp[NOISE_TIMESTAMP_LEN];
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (unlikely(!handshake->static_identity->has_identity))
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

	handshake_init(handshake->chaining_key, handshake->hash,
		       handshake->remote_static);

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* es */
	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* s */
	message_encrypt(dst->encrypted_static,
			handshake->static_identity->static_public,
			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

	/* ss */
	if (!mix_precomputed_dh(handshake->chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t} */
	tai64n_now(timestamp);
	message_encrypt(dst->encrypted_timestamp, timestamp,
			NOISE_TIMESTAMP_LEN, key, handshake->hash);

	/* Publish the handshake under handshake->lock so a response can be
	 * matched to it by index.
	 */
	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_INITIATION;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* key holds the transient AEAD key; wipe it on every path. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
552
/* Process a received handshake initiation @src on device @wg.
 *
 * The message is verified and decrypted with local copies of the handshake
 * state. The sender is identified by its decrypted static key, and replayed
 * (timestamp not newer than the last accepted one) or too frequent
 * initiations are rejected.
 *
 * On success the peer's handshake is updated, its state becomes
 * HANDSHAKE_CONSUMED_INITIATION, and the peer is returned. The reference
 * obtained from wg_pubkey_hashtable_lookup() is handed to the caller.
 *
 * On any failure NULL is returned and that reference, if any, is dropped
 * with wg_peer_put().
 *
 * wg->static_identity.lock is held for reading throughout.
 */
struct wg_peer *
wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
				      struct wg_device *wg)
{
	struct wg_peer *peer = NULL, *ret_peer = NULL;
	struct noise_handshake *handshake;
	bool replay_attack, flood_attack;
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	u8 chaining_key[NOISE_HASH_LEN];
	u8 hash[NOISE_HASH_LEN];
	u8 s[NOISE_PUBLIC_KEY_LEN];
	u8 e[NOISE_PUBLIC_KEY_LEN];
	u8 t[NOISE_TIMESTAMP_LEN];
	u64 initiation_consumption;

	down_read(&wg->static_identity.lock);
	if (unlikely(!wg->static_identity.has_identity))
		goto out;

	handshake_init(chaining_key, hash, wg->static_identity.static_public);

	/* e */
	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);

	/* es */
	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
		goto out;

	/* s */
	if (!message_decrypt(s, src->encrypted_static,
			     sizeof(src->encrypted_static), key, hash))
		goto out;

	/* Lookup which peer we're actually talking to */
	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
	if (!peer)
		goto out;
	handshake = &peer->handshake;

	/* ss */
	/* NOTE(review): precomputed_static_static is read here without
	 * handshake->lock. wg_noise_precompute_static_static() writes it under
	 * handshake->lock and requires static_identity->lock, which is held
	 * here for reading. Presumably that is what serializes the two; confirm
	 * that all writers hold static_identity->lock for writing.
	 */
	if (!mix_precomputed_dh(chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t} */
	if (!message_decrypt(t, src->encrypted_timestamp,
			     sizeof(src->encrypted_timestamp), key, hash))
		goto out;

	/* Replay: the timestamp must be strictly newer than the newest one
	 * accepted so far. Flood: initiations closer together than
	 * NSEC_PER_SEC / INITIATIONS_PER_SECOND ns are dropped.
	 */
	down_read(&handshake->lock);
	replay_attack = memcmp(t, handshake->latest_timestamp,
			       NOISE_TIMESTAMP_LEN) <= 0;
	flood_attack = (s64)handshake->last_initiation_consumption +
			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
		       (s64)ktime_get_coarse_boottime_ns();
	up_read(&handshake->lock);
	if (replay_attack || flood_attack)
		goto out;

	/* Success! Copy everything to peer */
	/* The lock was dropped between the checks above and here, so
	 * latest_timestamp and last_initiation_consumption are only moved
	 * forward, never backward.
	 */
	down_write(&handshake->lock);
	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
	handshake->remote_index = src->sender_index;
	initiation_consumption = ktime_get_coarse_boottime_ns();
	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
		handshake->last_initiation_consumption = initiation_consumption;
	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
	up_write(&handshake->lock);
	ret_peer = peer;

out:
	/* Wipe the local copies of secret handshake state on every path. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	memzero_explicit(hash, NOISE_HASH_LEN);
	memzero_explicit(chaining_key, NOISE_HASH_LEN);
	up_read(&wg->static_identity.lock);
	if (!ret_peer)
		wg_peer_put(peer);
	return ret_peer;
}
636
/* Build a handshake response for @handshake into @dst.
 *
 * Requires the state to be HANDSHAKE_CONSUMED_INITIATION; otherwise false is
 * returned. Locks: static_identity->lock is taken for reading, then
 * handshake->lock for writing.
 *
 * On success the handshake entry is inserted into the device index
 * hashtable (the returned index becomes dst->sender_index), the state
 * becomes HANDSHAKE_CREATED_RESPONSE, and true is returned.
 *
 * On failure the state is unchanged, but the handshake's hash, chaining key
 * and ephemeral_private may already have been overwritten.
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
					struct noise_handshake *handshake)
{
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
	dst->receiver_index = handshake->remote_index;

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* ee */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_ephemeral))
		goto out;

	/* se */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* psk */
	mix_psk(handshake->chaining_key, handshake->hash, key,
		handshake->preshared_key);

	/* {} */
	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

	/* Publish under handshake->lock so the confirming data can be matched
	 * to this handshake by index.
	 */
	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_RESPONSE;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* key holds the transient AEAD key; wipe it on every path. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
696
697 struct wg_peer *
698 wg_noise_handshake_consume_response(struct message_handshake_response *src,
699                                     struct wg_device *wg)
700 {
701         enum noise_handshake_state state = HANDSHAKE_ZEROED;
702         struct wg_peer *peer = NULL, *ret_peer = NULL;
703         struct noise_handshake *handshake;
704         u8 key[NOISE_SYMMETRIC_KEY_LEN];
705         u8 hash[NOISE_HASH_LEN];
706         u8 chaining_key[NOISE_HASH_LEN];
707         u8 e[NOISE_PUBLIC_KEY_LEN];
708         u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
709         u8 static_private[NOISE_PUBLIC_KEY_LEN];
710         u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
711
712         down_read(&wg->static_identity.lock);
713
714         if (unlikely(!wg->static_identity.has_identity))
715                 goto out;
716
717         handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
718                 wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
719                 src->receiver_index, &peer);
720         if (unlikely(!handshake))
721                 goto out;
722
723         down_read(&handshake->lock);
724         state = handshake->state;
725         memcpy(hash, handshake->hash, NOISE_HASH_LEN);
726         memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
727         memcpy(ephemeral_private, handshake->ephemeral_private,
728                NOISE_PUBLIC_KEY_LEN);
729         memcpy(preshared_key, handshake->preshared_key,
730                NOISE_SYMMETRIC_KEY_LEN);
731         up_read(&handshake->lock);
732
733         if (state != HANDSHAKE_CREATED_INITIATION)
734                 goto fail;
735
736         /* e */
737         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
738
739         /* ee */
740         if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
741                 goto fail;
742
743         /* se */
744         if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
745                 goto fail;
746
747         /* psk */
748         mix_psk(chaining_key, hash, key, preshared_key);
749
750         /* {} */
751         if (!message_decrypt(NULL, src->encrypted_nothing,
752                              sizeof(src->encrypted_nothing), key, hash))
753                 goto fail;
754
755         /* Success! Copy everything to peer */
756         down_write(&handshake->lock);
757         /* It's important to check that the state is still the same, while we
758          * have an exclusive lock.
759          */
760         if (handshake->state != state) {
761                 up_write(&handshake->lock);
762                 goto fail;
763         }
764         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
765         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
766         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
767         handshake->remote_index = src->sender_index;
768         handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
769         up_write(&handshake->lock);
770         ret_peer = peer;
771         goto out;
772
773 fail:
774         wg_peer_put(peer);
775 out:
776         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
777         memzero_explicit(hash, NOISE_HASH_LEN);
778         memzero_explicit(chaining_key, NOISE_HASH_LEN);
779         memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
780         memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
781         memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
782         up_read(&wg->static_identity.lock);
783         return ret_peer;
784 }
785
786 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
787                                       struct noise_keypairs *keypairs)
788 {
789         struct noise_keypair *new_keypair;
790         bool ret = false;
791
792         down_write(&handshake->lock);
793         if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
794             handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
795                 goto out;
796
797         new_keypair = keypair_create(handshake->entry.peer);
798         if (!new_keypair)
799                 goto out;
800         new_keypair->i_am_the_initiator = handshake->state ==
801                                           HANDSHAKE_CONSUMED_RESPONSE;
802         new_keypair->remote_index = handshake->remote_index;
803
804         if (new_keypair->i_am_the_initiator)
805                 derive_keys(&new_keypair->sending, &new_keypair->receiving,
806                             handshake->chaining_key);
807         else
808                 derive_keys(&new_keypair->receiving, &new_keypair->sending,
809                             handshake->chaining_key);
810
811         handshake_zero(handshake);
812         rcu_read_lock_bh();
813         if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
814                                            handshake)->is_dead))) {
815                 add_new_keypair(keypairs, new_keypair);
816                 net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
817                                     handshake->entry.peer->device->dev->name,
818                                     new_keypair->internal_id,
819                                     handshake->entry.peer->internal_id);
820                 ret = wg_index_hashtable_replace(
821                         handshake->entry.peer->device->index_hashtable,
822                         &handshake->entry, &new_keypair->entry);
823         } else {
824                 kzfree(new_keypair);
825         }
826         rcu_read_unlock_bh();
827
828 out:
829         up_write(&handshake->lock);
830         return ret;
831 }