Merge tag 'gpio-updates-for-v5.18' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-microblaze.git] / net / mctp / route.c
index e52cef7..d5e7db8 100644 (file)
@@ -64,8 +64,7 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
                if (msk->bind_type != type)
                        continue;
 
-               if (msk->bind_addr != MCTP_ADDR_ANY &&
-                   msk->bind_addr != mh->dest)
+               if (!mctp_address_matches(msk->bind_addr, mh->dest))
                        continue;
 
                return msk;
@@ -77,7 +76,7 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
 static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
                           mctp_eid_t peer, u8 tag)
 {
-       if (key->local_addr != local)
+       if (!mctp_address_matches(key->local_addr, local))
                return false;
 
        if (key->peer_addr != peer)
@@ -204,29 +203,38 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
        return rc;
 }
 
-/* We're done with the key; unset valid and remove from lists. There may still
- * be outstanding refs on the key though...
+/* Helper for mctp_route_input().
+ * We're done with the key; unlock and unref the key.
+ * For the usual case of automatic expiry we remove the key from lists.
+ * In the case that manual allocation is set on a key we release the lock
+ * and local ref, reset reassembly, but don't remove from lists.
  */
-static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
-                                  unsigned long flags)
-       __releases(&key->lock)
+static void __mctp_key_done_in(struct mctp_sk_key *key, struct net *net,
+                              unsigned long flags, unsigned long reason)
+__releases(&key->lock)
 {
        struct sk_buff *skb;
 
+       trace_mctp_key_release(key, reason);
        skb = key->reasm_head;
        key->reasm_head = NULL;
-       key->reasm_dead = true;
-       key->valid = false;
-       mctp_dev_release_key(key->dev, key);
+
+       if (!key->manual_alloc) {
+               key->reasm_dead = true;
+               key->valid = false;
+               mctp_dev_release_key(key->dev, key);
+       }
        spin_unlock_irqrestore(&key->lock, flags);
 
-       spin_lock_irqsave(&net->mctp.keys_lock, flags);
-       hlist_del(&key->hlist);
-       hlist_del(&key->sklist);
-       spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+       if (!key->manual_alloc) {
+               spin_lock_irqsave(&net->mctp.keys_lock, flags);
+               hlist_del(&key->hlist);
+               hlist_del(&key->sklist);
+               spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 
-       /* one unref for the lists */
-       mctp_key_unref(key);
+               /* unref for the lists */
+               mctp_key_unref(key);
+       }
 
        /* and one for the local reference */
        mctp_key_unref(key);
@@ -380,9 +388,8 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
                                /* we've hit a pending reassembly; not much we
                                 * can do but drop it
                                 */
-                               trace_mctp_key_release(key,
-                                                      MCTP_TRACE_KEY_REPLIED);
-                               __mctp_key_unlock_drop(key, net, f);
+                               __mctp_key_done_in(key, net, f,
+                                                  MCTP_TRACE_KEY_REPLIED);
                                key = NULL;
                        }
                        rc = 0;
@@ -425,9 +432,8 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
                } else {
                        if (key->reasm_head || key->reasm_dead) {
                                /* duplicate start? drop everything */
-                               trace_mctp_key_release(key,
-                                                      MCTP_TRACE_KEY_INVALIDATED);
-                               __mctp_key_unlock_drop(key, net, f);
+                               __mctp_key_done_in(key, net, f,
+                                                  MCTP_TRACE_KEY_INVALIDATED);
                                rc = -EEXIST;
                                key = NULL;
                        } else {
@@ -452,8 +458,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
                if (!rc && flags & MCTP_HDR_FLAG_EOM) {
                        sock_queue_rcv_skb(key->sk, key->reasm_head);
                        key->reasm_head = NULL;
-                       trace_mctp_key_release(key, MCTP_TRACE_KEY_REPLIED);
-                       __mctp_key_unlock_drop(key, net, f);
+                       __mctp_key_done_in(key, net, f, MCTP_TRACE_KEY_REPLIED);
                        key = NULL;
                }
 
@@ -581,9 +586,9 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
 /* Allocate a locally-owned tag value for (saddr, daddr), and reserve
  * it for the socket msk
  */
-static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
-                                               mctp_eid_t saddr,
-                                               mctp_eid_t daddr, u8 *tagp)
+struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
+                                        mctp_eid_t daddr, mctp_eid_t saddr,
+                                        bool manual, u8 *tagp)
 {
        struct net *net = sock_net(&msk->sk);
        struct netns_mctp *mns = &net->mctp;
@@ -617,9 +622,8 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
                if (tmp->tag & MCTP_HDR_FLAG_TO)
                        continue;
 
-               if (!((tmp->peer_addr == daddr ||
-                      tmp->peer_addr == MCTP_ADDR_ANY) &&
-                      tmp->local_addr == saddr))
+               if (!(mctp_address_matches(tmp->peer_addr, daddr) &&
+                     mctp_address_matches(tmp->local_addr, saddr)))
                        continue;
 
                spin_lock(&tmp->lock);
@@ -639,6 +643,7 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
                mctp_reserve_tag(net, key, msk);
                trace_mctp_key_acquire(key);
 
+               key->manual_alloc = manual;
                *tagp = key->tag;
        }
 
@@ -652,6 +657,50 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
        return key;
 }
 
+static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
+                                                   mctp_eid_t daddr,
+                                                   u8 req_tag, u8 *tagp)
+{
+       struct net *net = sock_net(&msk->sk);
+       struct netns_mctp *mns = &net->mctp;
+       struct mctp_sk_key *key, *tmp;
+       unsigned long flags;
+
+       req_tag &= ~(MCTP_TAG_PREALLOC | MCTP_TAG_OWNER);
+       key = NULL;
+
+       spin_lock_irqsave(&mns->keys_lock, flags);
+
+       hlist_for_each_entry(tmp, &mns->keys, hlist) {
+               if (tmp->tag != req_tag)
+                       continue;
+
+               if (!mctp_address_matches(tmp->peer_addr, daddr))
+                       continue;
+
+               if (!tmp->manual_alloc)
+                       continue;
+
+               spin_lock(&tmp->lock);
+               if (tmp->valid) {
+                       key = tmp;
+                       refcount_inc(&key->refs);
+                       spin_unlock(&tmp->lock);
+                       break;
+               }
+               spin_unlock(&tmp->lock);
+       }
+       spin_unlock_irqrestore(&mns->keys_lock, flags);
+
+       if (!key)
+               return ERR_PTR(-ENOENT);
+
+       if (tagp)
+               *tagp = key->tag;
+
+       return key;
+}
+
 /* routing lookups */
 static bool mctp_rt_match_eid(struct mctp_route *rt,
                              unsigned int net, mctp_eid_t eid)
@@ -786,9 +835,8 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 {
        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
        struct mctp_skb_cb *cb = mctp_cb(skb);
-       struct mctp_route tmp_rt;
+       struct mctp_route tmp_rt = {0};
        struct mctp_sk_key *key;
-       struct net_device *dev;
        struct mctp_hdr *hdr;
        unsigned long flags;
        unsigned int mtu;
@@ -801,12 +849,12 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 
        if (rt) {
                ext_rt = false;
-               dev = NULL;
-
                if (WARN_ON(!rt->dev))
                        goto out_release;
 
        } else if (cb->ifindex) {
+               struct net_device *dev;
+
                ext_rt = true;
                rt = &tmp_rt;
 
@@ -816,7 +864,6 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
                        rcu_read_unlock();
                        return rc;
                }
-
                rt->dev = __mctp_dev_get(dev);
                rcu_read_unlock();
 
@@ -846,8 +893,14 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
        if (rc)
                goto out_release;
 
-       if (req_tag & MCTP_HDR_FLAG_TO) {
-               key = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
+       if (req_tag & MCTP_TAG_OWNER) {
+               if (req_tag & MCTP_TAG_PREALLOC)
+                       key = mctp_lookup_prealloc_tag(msk, daddr,
+                                                      req_tag, &tag);
+               else
+                       key = mctp_alloc_local_tag(msk, daddr, saddr,
+                                                  false, &tag);
+
                if (IS_ERR(key)) {
                        rc = PTR_ERR(key);
                        goto out_release;
@@ -858,7 +911,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
                tag |= MCTP_HDR_FLAG_TO;
        } else {
                key = NULL;
-               tag = req_tag;
+               tag = req_tag & MCTP_TAG_MASK;
        }
 
        skb->protocol = htons(ETH_P_MCTP);
@@ -891,10 +944,9 @@ out_release:
        if (!ext_rt)
                mctp_route_release(rt);
 
-       dev_put(dev);
+       mctp_dev_put(tmp_rt.dev);
 
        return rc;
-
 }
 
 /* route management */
@@ -906,7 +958,7 @@ static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start,
        struct net *net = dev_net(mdev->dev);
        struct mctp_route *rt, *ert;
 
-       if (!mctp_address_ok(daddr_start))
+       if (!mctp_address_unicast(daddr_start))
                return -EINVAL;
 
        if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
@@ -1036,6 +1088,17 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
        if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX)
                goto err_drop;
 
+       /* source must be valid unicast or null; drop reserved ranges and
+        * broadcast
+        */
+       if (!(mctp_address_unicast(mh->src) || mctp_address_null(mh->src)))
+               goto err_drop;
+
+       /* dest address: as above, but allow broadcast */
+       if (!(mctp_address_unicast(mh->dest) || mctp_address_null(mh->dest) ||
+             mctp_address_broadcast(mh->dest)))
+               goto err_drop;
+
        /* MCTP drivers must populate halen/haddr */
        if (dev->type == ARPHRD_MCTP) {
                cb = mctp_cb(skb);
@@ -1057,11 +1120,13 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
 
        rt->output(rt, skb);
        mctp_route_release(rt);
+       mctp_dev_put(mdev);
 
        return NET_RX_SUCCESS;
 
 err_drop:
        kfree_skb(skb);
+       mctp_dev_put(mdev);
        return NET_RX_DROP;
 }