Merge tag 'ext4_for_linus_stable' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-microblaze.git] / drivers / net / wireguard / noise.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
19
20 /* This implements Noise_IKpsk2:
21  *
22  * <- s
23  * ******
24  * -> e, es, s, ss, {t}
25  * <- e, ee, se, psk, {}
26  */
27
28 static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
29 static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
30 static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
31 static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
32 static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33
34 void __init wg_noise_init(void)
35 {
36         struct blake2s_state blake;
37
38         blake2s(handshake_init_chaining_key, handshake_name, NULL,
39                 NOISE_HASH_LEN, sizeof(handshake_name), 0);
40         blake2s_init(&blake, NOISE_HASH_LEN);
41         blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42         blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43         blake2s_final(&blake, handshake_init_hash);
44 }
45
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49         down_write(&peer->handshake.lock);
50         if (!peer->handshake.static_identity->has_identity ||
51             !curve25519(peer->handshake.precomputed_static_static,
52                         peer->handshake.static_identity->static_private,
53                         peer->handshake.remote_static))
54                 memset(peer->handshake.precomputed_static_static, 0,
55                        NOISE_PUBLIC_KEY_LEN);
56         up_write(&peer->handshake.lock);
57 }
58
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60                              struct noise_static_identity *static_identity,
61                              const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62                              const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63                              struct wg_peer *peer)
64 {
65         memset(handshake, 0, sizeof(*handshake));
66         init_rwsem(&handshake->lock);
67         handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68         handshake->entry.peer = peer;
69         memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70         if (peer_preshared_key)
71                 memcpy(handshake->preshared_key, peer_preshared_key,
72                        NOISE_SYMMETRIC_KEY_LEN);
73         handshake->static_identity = static_identity;
74         handshake->state = HANDSHAKE_ZEROED;
75         wg_noise_precompute_static_static(peer);
76 }
77
78 static void handshake_zero(struct noise_handshake *handshake)
79 {
80         memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81         memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82         memset(&handshake->hash, 0, NOISE_HASH_LEN);
83         memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84         handshake->remote_index = 0;
85         handshake->state = HANDSHAKE_ZEROED;
86 }
87
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
89 {
90         down_write(&handshake->lock);
91         wg_index_hashtable_remove(
92                         handshake->entry.peer->device->index_hashtable,
93                         &handshake->entry);
94         handshake_zero(handshake);
95         up_write(&handshake->lock);
96 }
97
98 static struct noise_keypair *keypair_create(struct wg_peer *peer)
99 {
100         struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
101
102         if (unlikely(!keypair))
103                 return NULL;
104         spin_lock_init(&keypair->receiving_counter.lock);
105         keypair->internal_id = atomic64_inc_return(&keypair_counter);
106         keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
107         keypair->entry.peer = peer;
108         kref_init(&keypair->refcount);
109         return keypair;
110 }
111
112 static void keypair_free_rcu(struct rcu_head *rcu)
113 {
114         kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
115 }
116
117 static void keypair_free_kref(struct kref *kref)
118 {
119         struct noise_keypair *keypair =
120                 container_of(kref, struct noise_keypair, refcount);
121
122         net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
123                             keypair->entry.peer->device->dev->name,
124                             keypair->internal_id,
125                             keypair->entry.peer->internal_id);
126         wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
127                                   &keypair->entry);
128         call_rcu(&keypair->rcu, keypair_free_rcu);
129 }
130
131 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
132 {
133         if (unlikely(!keypair))
134                 return;
135         if (unlikely(unreference_now))
136                 wg_index_hashtable_remove(
137                         keypair->entry.peer->device->index_hashtable,
138                         &keypair->entry);
139         kref_put(&keypair->refcount, keypair_free_kref);
140 }
141
142 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
143 {
144         RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
145                 "Taking noise keypair reference without holding the RCU BH read lock");
146         if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
147                 return NULL;
148         return keypair;
149 }
150
151 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
152 {
153         struct noise_keypair *old;
154
155         spin_lock_bh(&keypairs->keypair_update_lock);
156
157         /* We zero the next_keypair before zeroing the others, so that
158          * wg_noise_received_with_keypair returns early before subsequent ones
159          * are zeroed.
160          */
161         old = rcu_dereference_protected(keypairs->next_keypair,
162                 lockdep_is_held(&keypairs->keypair_update_lock));
163         RCU_INIT_POINTER(keypairs->next_keypair, NULL);
164         wg_noise_keypair_put(old, true);
165
166         old = rcu_dereference_protected(keypairs->previous_keypair,
167                 lockdep_is_held(&keypairs->keypair_update_lock));
168         RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
169         wg_noise_keypair_put(old, true);
170
171         old = rcu_dereference_protected(keypairs->current_keypair,
172                 lockdep_is_held(&keypairs->keypair_update_lock));
173         RCU_INIT_POINTER(keypairs->current_keypair, NULL);
174         wg_noise_keypair_put(old, true);
175
176         spin_unlock_bh(&keypairs->keypair_update_lock);
177 }
178
179 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
180 {
181         struct noise_keypair *keypair;
182
183         wg_noise_handshake_clear(&peer->handshake);
184         wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
185
186         spin_lock_bh(&peer->keypairs.keypair_update_lock);
187         keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
188                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
189         if (keypair)
190                 keypair->sending.is_valid = false;
191         keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
192                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
193         if (keypair)
194                 keypair->sending.is_valid = false;
195         spin_unlock_bh(&peer->keypairs.keypair_update_lock);
196 }
197
198 static void add_new_keypair(struct noise_keypairs *keypairs,
199                             struct noise_keypair *new_keypair)
200 {
201         struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
202
203         spin_lock_bh(&keypairs->keypair_update_lock);
204         previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
205                 lockdep_is_held(&keypairs->keypair_update_lock));
206         next_keypair = rcu_dereference_protected(keypairs->next_keypair,
207                 lockdep_is_held(&keypairs->keypair_update_lock));
208         current_keypair = rcu_dereference_protected(keypairs->current_keypair,
209                 lockdep_is_held(&keypairs->keypair_update_lock));
210         if (new_keypair->i_am_the_initiator) {
211                 /* If we're the initiator, it means we've sent a handshake, and
212                  * received a confirmation response, which means this new
213                  * keypair can now be used.
214                  */
215                 if (next_keypair) {
216                         /* If there already was a next keypair pending, we
217                          * demote it to be the previous keypair, and free the
218                          * existing current. Note that this means KCI can result
219                          * in this transition. It would perhaps be more sound to
220                          * always just get rid of the unused next keypair
221                          * instead of putting it in the previous slot, but this
222                          * might be a bit less robust. Something to think about
223                          * for the future.
224                          */
225                         RCU_INIT_POINTER(keypairs->next_keypair, NULL);
226                         rcu_assign_pointer(keypairs->previous_keypair,
227                                            next_keypair);
228                         wg_noise_keypair_put(current_keypair, true);
229                 } else /* If there wasn't an existing next keypair, we replace
230                         * the previous with the current one.
231                         */
232                         rcu_assign_pointer(keypairs->previous_keypair,
233                                            current_keypair);
234                 /* At this point we can get rid of the old previous keypair, and
235                  * set up the new keypair.
236                  */
237                 wg_noise_keypair_put(previous_keypair, true);
238                 rcu_assign_pointer(keypairs->current_keypair, new_keypair);
239         } else {
240                 /* If we're the responder, it means we can't use the new keypair
241                  * until we receive confirmation via the first data packet, so
242                  * we get rid of the existing previous one, the possibly
243                  * existing next one, and slide in the new next one.
244                  */
245                 rcu_assign_pointer(keypairs->next_keypair, new_keypair);
246                 wg_noise_keypair_put(next_keypair, true);
247                 RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
248                 wg_noise_keypair_put(previous_keypair, true);
249         }
250         spin_unlock_bh(&keypairs->keypair_update_lock);
251 }
252
253 bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
254                                     struct noise_keypair *received_keypair)
255 {
256         struct noise_keypair *old_keypair;
257         bool key_is_new;
258
259         /* We first check without taking the spinlock. */
260         key_is_new = received_keypair ==
261                      rcu_access_pointer(keypairs->next_keypair);
262         if (likely(!key_is_new))
263                 return false;
264
265         spin_lock_bh(&keypairs->keypair_update_lock);
266         /* After locking, we double check that things didn't change from
267          * beneath us.
268          */
269         if (unlikely(received_keypair !=
270                     rcu_dereference_protected(keypairs->next_keypair,
271                             lockdep_is_held(&keypairs->keypair_update_lock)))) {
272                 spin_unlock_bh(&keypairs->keypair_update_lock);
273                 return false;
274         }
275
276         /* When we've finally received the confirmation, we slide the next
277          * into the current, the current into the previous, and get rid of
278          * the old previous.
279          */
280         old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
281                 lockdep_is_held(&keypairs->keypair_update_lock));
282         rcu_assign_pointer(keypairs->previous_keypair,
283                 rcu_dereference_protected(keypairs->current_keypair,
284                         lockdep_is_held(&keypairs->keypair_update_lock)));
285         wg_noise_keypair_put(old_keypair, true);
286         rcu_assign_pointer(keypairs->current_keypair, received_keypair);
287         RCU_INIT_POINTER(keypairs->next_keypair, NULL);
288
289         spin_unlock_bh(&keypairs->keypair_update_lock);
290         return true;
291 }
292
293 /* Must hold static_identity->lock */
294 void wg_noise_set_static_identity_private_key(
295         struct noise_static_identity *static_identity,
296         const u8 private_key[NOISE_PUBLIC_KEY_LEN])
297 {
298         memcpy(static_identity->static_private, private_key,
299                NOISE_PUBLIC_KEY_LEN);
300         curve25519_clamp_secret(static_identity->static_private);
301         static_identity->has_identity = curve25519_generate_public(
302                 static_identity->static_public, private_key);
303 }
304
305 /* This is Hugo Krawczyk's HKDF:
306  *  - https://eprint.iacr.org/2010/264.pdf
307  *  - https://tools.ietf.org/html/rfc5869
308  */
309 static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
310                 size_t first_len, size_t second_len, size_t third_len,
311                 size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
312 {
313         u8 output[BLAKE2S_HASH_SIZE + 1];
314         u8 secret[BLAKE2S_HASH_SIZE];
315
316         WARN_ON(IS_ENABLED(DEBUG) &&
317                 (first_len > BLAKE2S_HASH_SIZE ||
318                  second_len > BLAKE2S_HASH_SIZE ||
319                  third_len > BLAKE2S_HASH_SIZE ||
320                  ((second_len || second_dst || third_len || third_dst) &&
321                   (!first_len || !first_dst)) ||
322                  ((third_len || third_dst) && (!second_len || !second_dst))));
323
324         /* Extract entropy from data into secret */
325         blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
326
327         if (!first_dst || !first_len)
328                 goto out;
329
330         /* Expand first key: key = secret, data = 0x1 */
331         output[0] = 1;
332         blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
333         memcpy(first_dst, output, first_len);
334
335         if (!second_dst || !second_len)
336                 goto out;
337
338         /* Expand second key: key = secret, data = first-key || 0x2 */
339         output[BLAKE2S_HASH_SIZE] = 2;
340         blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
341                         BLAKE2S_HASH_SIZE);
342         memcpy(second_dst, output, second_len);
343
344         if (!third_dst || !third_len)
345                 goto out;
346
347         /* Expand third key: key = secret, data = second-key || 0x3 */
348         output[BLAKE2S_HASH_SIZE] = 3;
349         blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
350                         BLAKE2S_HASH_SIZE);
351         memcpy(third_dst, output, third_len);
352
353 out:
354         /* Clear sensitive data from stack */
355         memzero_explicit(secret, BLAKE2S_HASH_SIZE);
356         memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
357 }
358
359 static void derive_keys(struct noise_symmetric_key *first_dst,
360                         struct noise_symmetric_key *second_dst,
361                         const u8 chaining_key[NOISE_HASH_LEN])
362 {
363         u64 birthdate = ktime_get_coarse_boottime_ns();
364         kdf(first_dst->key, second_dst->key, NULL, NULL,
365             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
366             chaining_key);
367         first_dst->birthdate = second_dst->birthdate = birthdate;
368         first_dst->is_valid = second_dst->is_valid = true;
369 }
370
371 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
372                                 u8 key[NOISE_SYMMETRIC_KEY_LEN],
373                                 const u8 private[NOISE_PUBLIC_KEY_LEN],
374                                 const u8 public[NOISE_PUBLIC_KEY_LEN])
375 {
376         u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
377
378         if (unlikely(!curve25519(dh_calculation, private, public)))
379                 return false;
380         kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
381             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
382         memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
383         return true;
384 }
385
386 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
387                                             u8 key[NOISE_SYMMETRIC_KEY_LEN],
388                                             const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
389 {
390         static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
391         if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
392                 return false;
393         kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
394             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
395             chaining_key);
396         return true;
397 }
398
399 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
400 {
401         struct blake2s_state blake;
402
403         blake2s_init(&blake, NOISE_HASH_LEN);
404         blake2s_update(&blake, hash, NOISE_HASH_LEN);
405         blake2s_update(&blake, src, src_len);
406         blake2s_final(&blake, hash);
407 }
408
409 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
410                     u8 key[NOISE_SYMMETRIC_KEY_LEN],
411                     const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
412 {
413         u8 temp_hash[NOISE_HASH_LEN];
414
415         kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
416             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
417         mix_hash(hash, temp_hash, NOISE_HASH_LEN);
418         memzero_explicit(temp_hash, NOISE_HASH_LEN);
419 }
420
421 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
422                            u8 hash[NOISE_HASH_LEN],
423                            const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
424 {
425         memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
426         memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
427         mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
428 }
429
430 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
431                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
432                             u8 hash[NOISE_HASH_LEN])
433 {
434         chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
435                                  NOISE_HASH_LEN,
436                                  0 /* Always zero for Noise_IK */, key);
437         mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
438 }
439
440 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
441                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
442                             u8 hash[NOISE_HASH_LEN])
443 {
444         if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
445                                       hash, NOISE_HASH_LEN,
446                                       0 /* Always zero for Noise_IK */, key))
447                 return false;
448         mix_hash(hash, src_ciphertext, src_len);
449         return true;
450 }
451
452 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
453                               const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
454                               u8 chaining_key[NOISE_HASH_LEN],
455                               u8 hash[NOISE_HASH_LEN])
456 {
457         if (ephemeral_dst != ephemeral_src)
458                 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
459         mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
460         kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
461             NOISE_PUBLIC_KEY_LEN, chaining_key);
462 }
463
464 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
465 {
466         struct timespec64 now;
467
468         ktime_get_real_ts64(&now);
469
470         /* In order to prevent some sort of infoleak from precise timers, we
471          * round down the nanoseconds part to the closest rounded-down power of
472          * two to the maximum initiations per second allowed anyway by the
473          * implementation.
474          */
475         now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
476                 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
477
478         /* https://cr.yp.to/libtai/tai64.html */
479         *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
480         *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
481 }
482
483 bool
484 wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
485                                      struct noise_handshake *handshake)
486 {
487         u8 timestamp[NOISE_TIMESTAMP_LEN];
488         u8 key[NOISE_SYMMETRIC_KEY_LEN];
489         bool ret = false;
490
491         /* We need to wait for crng _before_ taking any locks, since
492          * curve25519_generate_secret uses get_random_bytes_wait.
493          */
494         wait_for_random_bytes();
495
496         down_read(&handshake->static_identity->lock);
497         down_write(&handshake->lock);
498
499         if (unlikely(!handshake->static_identity->has_identity))
500                 goto out;
501
502         dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
503
504         handshake_init(handshake->chaining_key, handshake->hash,
505                        handshake->remote_static);
506
507         /* e */
508         curve25519_generate_secret(handshake->ephemeral_private);
509         if (!curve25519_generate_public(dst->unencrypted_ephemeral,
510                                         handshake->ephemeral_private))
511                 goto out;
512         message_ephemeral(dst->unencrypted_ephemeral,
513                           dst->unencrypted_ephemeral, handshake->chaining_key,
514                           handshake->hash);
515
516         /* es */
517         if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
518                     handshake->remote_static))
519                 goto out;
520
521         /* s */
522         message_encrypt(dst->encrypted_static,
523                         handshake->static_identity->static_public,
524                         NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
525
526         /* ss */
527         if (!mix_precomputed_dh(handshake->chaining_key, key,
528                                 handshake->precomputed_static_static))
529                 goto out;
530
531         /* {t} */
532         tai64n_now(timestamp);
533         message_encrypt(dst->encrypted_timestamp, timestamp,
534                         NOISE_TIMESTAMP_LEN, key, handshake->hash);
535
536         dst->sender_index = wg_index_hashtable_insert(
537                 handshake->entry.peer->device->index_hashtable,
538                 &handshake->entry);
539
540         handshake->state = HANDSHAKE_CREATED_INITIATION;
541         ret = true;
542
543 out:
544         up_write(&handshake->lock);
545         up_read(&handshake->static_identity->lock);
546         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
547         return ret;
548 }
549
550 struct wg_peer *
551 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
552                                       struct wg_device *wg)
553 {
554         struct wg_peer *peer = NULL, *ret_peer = NULL;
555         struct noise_handshake *handshake;
556         bool replay_attack, flood_attack;
557         u8 key[NOISE_SYMMETRIC_KEY_LEN];
558         u8 chaining_key[NOISE_HASH_LEN];
559         u8 hash[NOISE_HASH_LEN];
560         u8 s[NOISE_PUBLIC_KEY_LEN];
561         u8 e[NOISE_PUBLIC_KEY_LEN];
562         u8 t[NOISE_TIMESTAMP_LEN];
563         u64 initiation_consumption;
564
565         down_read(&wg->static_identity.lock);
566         if (unlikely(!wg->static_identity.has_identity))
567                 goto out;
568
569         handshake_init(chaining_key, hash, wg->static_identity.static_public);
570
571         /* e */
572         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
573
574         /* es */
575         if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
576                 goto out;
577
578         /* s */
579         if (!message_decrypt(s, src->encrypted_static,
580                              sizeof(src->encrypted_static), key, hash))
581                 goto out;
582
583         /* Lookup which peer we're actually talking to */
584         peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
585         if (!peer)
586                 goto out;
587         handshake = &peer->handshake;
588
589         /* ss */
590         if (!mix_precomputed_dh(chaining_key, key,
591                                 handshake->precomputed_static_static))
592             goto out;
593
594         /* {t} */
595         if (!message_decrypt(t, src->encrypted_timestamp,
596                              sizeof(src->encrypted_timestamp), key, hash))
597                 goto out;
598
599         down_read(&handshake->lock);
600         replay_attack = memcmp(t, handshake->latest_timestamp,
601                                NOISE_TIMESTAMP_LEN) <= 0;
602         flood_attack = (s64)handshake->last_initiation_consumption +
603                                NSEC_PER_SEC / INITIATIONS_PER_SECOND >
604                        (s64)ktime_get_coarse_boottime_ns();
605         up_read(&handshake->lock);
606         if (replay_attack || flood_attack)
607                 goto out;
608
609         /* Success! Copy everything to peer */
610         down_write(&handshake->lock);
611         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
612         if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
613                 memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
614         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
615         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
616         handshake->remote_index = src->sender_index;
617         initiation_consumption = ktime_get_coarse_boottime_ns();
618         if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
619                 handshake->last_initiation_consumption = initiation_consumption;
620         handshake->state = HANDSHAKE_CONSUMED_INITIATION;
621         up_write(&handshake->lock);
622         ret_peer = peer;
623
624 out:
625         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
626         memzero_explicit(hash, NOISE_HASH_LEN);
627         memzero_explicit(chaining_key, NOISE_HASH_LEN);
628         up_read(&wg->static_identity.lock);
629         if (!ret_peer)
630                 wg_peer_put(peer);
631         return ret_peer;
632 }
633
634 bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
635                                         struct noise_handshake *handshake)
636 {
637         u8 key[NOISE_SYMMETRIC_KEY_LEN];
638         bool ret = false;
639
640         /* We need to wait for crng _before_ taking any locks, since
641          * curve25519_generate_secret uses get_random_bytes_wait.
642          */
643         wait_for_random_bytes();
644
645         down_read(&handshake->static_identity->lock);
646         down_write(&handshake->lock);
647
648         if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
649                 goto out;
650
651         dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
652         dst->receiver_index = handshake->remote_index;
653
654         /* e */
655         curve25519_generate_secret(handshake->ephemeral_private);
656         if (!curve25519_generate_public(dst->unencrypted_ephemeral,
657                                         handshake->ephemeral_private))
658                 goto out;
659         message_ephemeral(dst->unencrypted_ephemeral,
660                           dst->unencrypted_ephemeral, handshake->chaining_key,
661                           handshake->hash);
662
663         /* ee */
664         if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
665                     handshake->remote_ephemeral))
666                 goto out;
667
668         /* se */
669         if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
670                     handshake->remote_static))
671                 goto out;
672
673         /* psk */
674         mix_psk(handshake->chaining_key, handshake->hash, key,
675                 handshake->preshared_key);
676
677         /* {} */
678         message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
679
680         dst->sender_index = wg_index_hashtable_insert(
681                 handshake->entry.peer->device->index_hashtable,
682                 &handshake->entry);
683
684         handshake->state = HANDSHAKE_CREATED_RESPONSE;
685         ret = true;
686
687 out:
688         up_write(&handshake->lock);
689         up_read(&handshake->static_identity->lock);
690         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
691         return ret;
692 }
693
694 struct wg_peer *
695 wg_noise_handshake_consume_response(struct message_handshake_response *src,
696                                     struct wg_device *wg)
697 {
698         enum noise_handshake_state state = HANDSHAKE_ZEROED;
699         struct wg_peer *peer = NULL, *ret_peer = NULL;
700         struct noise_handshake *handshake;
701         u8 key[NOISE_SYMMETRIC_KEY_LEN];
702         u8 hash[NOISE_HASH_LEN];
703         u8 chaining_key[NOISE_HASH_LEN];
704         u8 e[NOISE_PUBLIC_KEY_LEN];
705         u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
706         u8 static_private[NOISE_PUBLIC_KEY_LEN];
707         u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
708
709         down_read(&wg->static_identity.lock);
710
711         if (unlikely(!wg->static_identity.has_identity))
712                 goto out;
713
714         handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
715                 wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
716                 src->receiver_index, &peer);
717         if (unlikely(!handshake))
718                 goto out;
719
720         down_read(&handshake->lock);
721         state = handshake->state;
722         memcpy(hash, handshake->hash, NOISE_HASH_LEN);
723         memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
724         memcpy(ephemeral_private, handshake->ephemeral_private,
725                NOISE_PUBLIC_KEY_LEN);
726         memcpy(preshared_key, handshake->preshared_key,
727                NOISE_SYMMETRIC_KEY_LEN);
728         up_read(&handshake->lock);
729
730         if (state != HANDSHAKE_CREATED_INITIATION)
731                 goto fail;
732
733         /* e */
734         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
735
736         /* ee */
737         if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
738                 goto fail;
739
740         /* se */
741         if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
742                 goto fail;
743
744         /* psk */
745         mix_psk(chaining_key, hash, key, preshared_key);
746
747         /* {} */
748         if (!message_decrypt(NULL, src->encrypted_nothing,
749                              sizeof(src->encrypted_nothing), key, hash))
750                 goto fail;
751
752         /* Success! Copy everything to peer */
753         down_write(&handshake->lock);
754         /* It's important to check that the state is still the same, while we
755          * have an exclusive lock.
756          */
757         if (handshake->state != state) {
758                 up_write(&handshake->lock);
759                 goto fail;
760         }
761         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
762         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
763         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
764         handshake->remote_index = src->sender_index;
765         handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
766         up_write(&handshake->lock);
767         ret_peer = peer;
768         goto out;
769
770 fail:
771         wg_peer_put(peer);
772 out:
773         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
774         memzero_explicit(hash, NOISE_HASH_LEN);
775         memzero_explicit(chaining_key, NOISE_HASH_LEN);
776         memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
777         memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
778         memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
779         up_read(&wg->static_identity.lock);
780         return ret_peer;
781 }
782
783 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
784                                       struct noise_keypairs *keypairs)
785 {
786         struct noise_keypair *new_keypair;
787         bool ret = false;
788
789         down_write(&handshake->lock);
790         if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
791             handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
792                 goto out;
793
794         new_keypair = keypair_create(handshake->entry.peer);
795         if (!new_keypair)
796                 goto out;
797         new_keypair->i_am_the_initiator = handshake->state ==
798                                           HANDSHAKE_CONSUMED_RESPONSE;
799         new_keypair->remote_index = handshake->remote_index;
800
801         if (new_keypair->i_am_the_initiator)
802                 derive_keys(&new_keypair->sending, &new_keypair->receiving,
803                             handshake->chaining_key);
804         else
805                 derive_keys(&new_keypair->receiving, &new_keypair->sending,
806                             handshake->chaining_key);
807
808         handshake_zero(handshake);
809         rcu_read_lock_bh();
810         if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
811                                            handshake)->is_dead))) {
812                 add_new_keypair(keypairs, new_keypair);
813                 net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
814                                     handshake->entry.peer->device->dev->name,
815                                     new_keypair->internal_id,
816                                     handshake->entry.peer->internal_id);
817                 ret = wg_index_hashtable_replace(
818                         handshake->entry.peer->device->index_hashtable,
819                         &handshake->entry, &new_keypair->entry);
820         } else {
821                 kfree_sensitive(new_keypair);
822         }
823         rcu_read_unlock_bh();
824
825 out:
826         up_write(&handshake->lock);
827         return ret;
828 }