Merge tag 'hyperv-next-signed-20220807' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-microblaze.git] / net / mac80211 / key.c
index 0fcf8ae..6befb57 100644 (file)
@@ -6,7 +6,7 @@
  * Copyright 2007-2008 Johannes Berg <johannes@sipsolutions.net>
  * Copyright 2013-2014  Intel Mobile Communications GmbH
  * Copyright 2015-2017 Intel Deutschland GmbH
- * Copyright 2018-2020  Intel Corporation
+ * Copyright 2018-2020, 2022  Intel Corporation
  */
 
 #include <linux/if_ether.h>
@@ -351,8 +351,11 @@ static void __ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata,
 
        assert_key_lock(sdata->local);
 
-       if (idx >= 0 && idx < NUM_DEFAULT_KEYS)
+       if (idx >= 0 && idx < NUM_DEFAULT_KEYS) {
                key = key_mtx_dereference(sdata->local, sdata->keys[idx]);
+               if (!key)
+                       key = key_mtx_dereference(sdata->local, sdata->deflink.gtk[idx]);
+       }
 
        if (uni) {
                rcu_assign_pointer(sdata->default_unicast_key, key);
@@ -362,7 +365,7 @@ static void __ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata,
        }
 
        if (multi)
-               rcu_assign_pointer(sdata->default_multicast_key, key);
+               rcu_assign_pointer(sdata->deflink.default_multicast_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
@@ -384,9 +387,10 @@ __ieee80211_set_default_mgmt_key(struct ieee80211_sub_if_data *sdata, int idx)
 
        if (idx >= NUM_DEFAULT_KEYS &&
            idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS)
-               key = key_mtx_dereference(sdata->local, sdata->keys[idx]);
+               key = key_mtx_dereference(sdata->local,
+                                         sdata->deflink.gtk[idx]);
 
-       rcu_assign_pointer(sdata->default_mgmt_key, key);
+       rcu_assign_pointer(sdata->deflink.default_mgmt_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
@@ -409,9 +413,10 @@ __ieee80211_set_default_beacon_key(struct ieee80211_sub_if_data *sdata, int idx)
        if (idx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS &&
            idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS +
            NUM_DEFAULT_BEACON_KEYS)
-               key = key_mtx_dereference(sdata->local, sdata->keys[idx]);
+               key = key_mtx_dereference(sdata->local,
+                                         sdata->deflink.gtk[idx]);
 
-       rcu_assign_pointer(sdata->default_beacon_key, key);
+       rcu_assign_pointer(sdata->deflink.default_beacon_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
@@ -433,13 +438,25 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
        int idx;
        int ret = 0;
        bool defunikey, defmultikey, defmgmtkey, defbeaconkey;
+       bool is_wep;
 
        /* caller must provide at least one old/new */
        if (WARN_ON(!new && !old))
                return 0;
 
-       if (new)
+       if (new) {
+               idx = new->conf.keyidx;
                list_add_tail_rcu(&new->list, &sdata->key_list);
+               is_wep = new->conf.cipher == WLAN_CIPHER_SUITE_WEP40 ||
+                        new->conf.cipher == WLAN_CIPHER_SUITE_WEP104;
+       } else {
+               idx = old->conf.keyidx;
+               is_wep = old->conf.cipher == WLAN_CIPHER_SUITE_WEP40 ||
+                        old->conf.cipher == WLAN_CIPHER_SUITE_WEP104;
+       }
+
+       if ((is_wep || pairwise) && idx >= NUM_DEFAULT_KEYS)
+               return -EINVAL;
 
        WARN_ON(new && old && new->conf.keyidx != old->conf.keyidx);
 
@@ -451,8 +468,6 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
        }
 
        if (old) {
-               idx = old->conf.keyidx;
-
                if (old->flags & KEY_FLAG_UPLOADED_TO_HARDWARE) {
                        ieee80211_key_disable_hw_accel(old);
 
@@ -460,8 +475,6 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
                                ret = ieee80211_key_enable_hw_accel(new);
                }
        } else {
-               /* new must be provided in case old is not */
-               idx = new->conf.keyidx;
                if (!new->local->wowlan)
                        ret = ieee80211_key_enable_hw_accel(new);
        }
@@ -490,13 +503,13 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
                                                sdata->default_unicast_key);
                defmultikey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                               sdata->default_multicast_key);
+                                               sdata->deflink.default_multicast_key);
                defmgmtkey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                               sdata->default_mgmt_key);
+                                               sdata->deflink.default_mgmt_key);
                defbeaconkey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                                  sdata->default_beacon_key);
+                                                  sdata->deflink.default_beacon_key);
 
                if (defunikey && !new)
                        __ieee80211_set_default_key(sdata, -1, true, false);
@@ -507,7 +520,11 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
                if (defbeaconkey && !new)
                        __ieee80211_set_default_beacon_key(sdata, -1);
 
-               rcu_assign_pointer(sdata->keys[idx], new);
+               if (is_wep || pairwise)
+                       rcu_assign_pointer(sdata->keys[idx], new);
+               else
+                       rcu_assign_pointer(sdata->deflink.gtk[idx], new);
+
                if (defunikey && new)
                        __ieee80211_set_default_key(sdata, new->conf.keyidx,
                                                    true, false);
@@ -531,8 +548,7 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
 struct ieee80211_key *
 ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
                    const u8 *key_data,
-                   size_t seq_len, const u8 *seq,
-                   const struct ieee80211_cipher_scheme *cs)
+                   size_t seq_len, const u8 *seq)
 {
        struct ieee80211_key *key;
        int i, j, err;
@@ -675,21 +691,6 @@ ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
                        return ERR_PTR(err);
                }
                break;
-       default:
-               if (cs) {
-                       if (seq_len && seq_len != cs->pn_len) {
-                               kfree(key);
-                               return ERR_PTR(-EINVAL);
-                       }
-
-                       key->conf.iv_len = cs->hdr_len;
-                       key->conf.icv_len = cs->mic_len;
-                       for (i = 0; i < IEEE80211_NUM_TIDS + 1; i++)
-                               for (j = 0; j < seq_len; j++)
-                                       key->u.gen.rx_pn[i][j] =
-                                                       seq[seq_len - j - 1];
-                       key->flags |= KEY_FLAG_CIPHER_SCHEME;
-               }
        }
        memcpy(key->conf.key, key_data, key_len);
        INIT_LIST_HEAD(&key->list);
@@ -800,7 +801,7 @@ int ieee80211_key_link(struct ieee80211_key *key,
                       struct sta_info *sta)
 {
        static atomic_t key_color = ATOMIC_INIT(0);
-       struct ieee80211_key *old_key;
+       struct ieee80211_key *old_key = NULL;
        int idx = key->conf.keyidx;
        bool pairwise = key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE;
        /*
@@ -829,7 +830,12 @@ int ieee80211_key_link(struct ieee80211_key *key,
                old_key = key_mtx_dereference(sdata->local,
                                              sta->deflink.gtk[idx]);
        } else {
-               old_key = key_mtx_dereference(sdata->local, sdata->keys[idx]);
+               if (idx < NUM_DEFAULT_KEYS)
+                       old_key = key_mtx_dereference(sdata->local,
+                                                     sdata->keys[idx]);
+               if (!old_key)
+                       old_key = key_mtx_dereference(sdata->local,
+                                                     sdata->deflink.gtk[idx]);
        }
 
        /* Non-pairwise keys must also not switch the cipher on rekey */
@@ -1294,7 +1300,7 @@ ieee80211_gtk_rekey_add(struct ieee80211_vif *vif,
 
        key = ieee80211_key_alloc(keyconf->cipher, keyconf->keyidx,
                                  keyconf->keylen, keyconf->key,
-                                 0, NULL, NULL);
+                                 0, NULL);
        if (IS_ERR(key))
                return ERR_CAST(key);