Linux 6.9-rc1
[linux-2.6-microblaze.git] / net / ipv6 / inet6_hashtables.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * INET         An implementation of the TCP/IP protocol suite for the LINUX
4  *              operating system.  INET is implemented using the BSD Socket
5  *              interface as the means of communication with the user level.
6  *
7  *              Generic INET6 transport hashtables
8  *
9  * Authors:     Lotsa people, from code originally in tcp, generalised here
10  *              by Arnaldo Carvalho de Melo <acme@mandriva.com>
11  */
12
13 #include <linux/module.h>
14 #include <linux/random.h>
15
16 #include <net/addrconf.h>
17 #include <net/hotdata.h>
18 #include <net/inet_connection_sock.h>
19 #include <net/inet_hashtables.h>
20 #include <net/inet6_hashtables.h>
21 #include <net/secure_seq.h>
22 #include <net/ip.h>
23 #include <net/sock_reuseport.h>
24
25 u32 inet6_ehashfn(const struct net *net,
26                   const struct in6_addr *laddr, const u16 lport,
27                   const struct in6_addr *faddr, const __be16 fport)
28 {
29         u32 lhash, fhash;
30
31         net_get_random_once(&inet6_ehash_secret, sizeof(inet6_ehash_secret));
32         net_get_random_once(&tcp_ipv6_hash_secret, sizeof(tcp_ipv6_hash_secret));
33
34         lhash = (__force u32)laddr->s6_addr32[3];
35         fhash = __ipv6_addr_jhash(faddr, tcp_ipv6_hash_secret);
36
37         return __inet6_ehashfn(lhash, lport, fhash, fport,
38                                inet6_ehash_secret + net_hash_mix(net));
39 }
40 EXPORT_SYMBOL_GPL(inet6_ehashfn);
41
42 /*
43  * Sockets in TCP_CLOSE state are _always_ taken out of the hash, so
44  * we need not check it for TCP lookups anymore, thanks Alexey. -DaveM
45  *
46  * The sockhash lock must be held as a reader here.
47  */
48 struct sock *__inet6_lookup_established(struct net *net,
49                                         struct inet_hashinfo *hashinfo,
50                                            const struct in6_addr *saddr,
51                                            const __be16 sport,
52                                            const struct in6_addr *daddr,
53                                            const u16 hnum,
54                                            const int dif, const int sdif)
55 {
56         struct sock *sk;
57         const struct hlist_nulls_node *node;
58         const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
59         /* Optimize here for direct hit, only listening connections can
60          * have wildcards anyways.
61          */
62         unsigned int hash = inet6_ehashfn(net, daddr, hnum, saddr, sport);
63         unsigned int slot = hash & hashinfo->ehash_mask;
64         struct inet_ehash_bucket *head = &hashinfo->ehash[slot];
65
66
67 begin:
68         sk_nulls_for_each_rcu(sk, node, &head->chain) {
69                 if (sk->sk_hash != hash)
70                         continue;
71                 if (!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))
72                         continue;
73                 if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
74                         goto out;
75
76                 if (unlikely(!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))) {
77                         sock_gen_put(sk);
78                         goto begin;
79                 }
80                 goto found;
81         }
82         if (get_nulls_value(node) != slot)
83                 goto begin;
84 out:
85         sk = NULL;
86 found:
87         return sk;
88 }
89 EXPORT_SYMBOL(__inet6_lookup_established);
90
91 static inline int compute_score(struct sock *sk, struct net *net,
92                                 const unsigned short hnum,
93                                 const struct in6_addr *daddr,
94                                 const int dif, const int sdif)
95 {
96         int score = -1;
97
98         if (net_eq(sock_net(sk), net) && inet_sk(sk)->inet_num == hnum &&
99             sk->sk_family == PF_INET6) {
100                 if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
101                         return -1;
102
103                 if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
104                         return -1;
105
106                 score =  sk->sk_bound_dev_if ? 2 : 1;
107                 if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
108                         score++;
109         }
110         return score;
111 }
112
113 /**
114  * inet6_lookup_reuseport() - execute reuseport logic on AF_INET6 socket if necessary.
115  * @net: network namespace.
116  * @sk: AF_INET6 socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
117  * @skb: context for a potential SK_REUSEPORT program.
118  * @doff: header offset.
119  * @saddr: source address.
120  * @sport: source port.
121  * @daddr: destination address.
122  * @hnum: destination port in host byte order.
123  * @ehashfn: hash function used to generate the fallback hash.
124  *
125  * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
126  *         the selected sock or an error.
127  */
128 struct sock *inet6_lookup_reuseport(struct net *net, struct sock *sk,
129                                     struct sk_buff *skb, int doff,
130                                     const struct in6_addr *saddr,
131                                     __be16 sport,
132                                     const struct in6_addr *daddr,
133                                     unsigned short hnum,
134                                     inet6_ehashfn_t *ehashfn)
135 {
136         struct sock *reuse_sk = NULL;
137         u32 phash;
138
139         if (sk->sk_reuseport) {
140                 phash = INDIRECT_CALL_INET(ehashfn, udp6_ehashfn, inet6_ehashfn,
141                                            net, daddr, hnum, saddr, sport);
142                 reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
143         }
144         return reuse_sk;
145 }
146 EXPORT_SYMBOL_GPL(inet6_lookup_reuseport);
147
148 /* called with rcu_read_lock() */
149 static struct sock *inet6_lhash2_lookup(struct net *net,
150                 struct inet_listen_hashbucket *ilb2,
151                 struct sk_buff *skb, int doff,
152                 const struct in6_addr *saddr,
153                 const __be16 sport, const struct in6_addr *daddr,
154                 const unsigned short hnum, const int dif, const int sdif)
155 {
156         struct sock *sk, *result = NULL;
157         struct hlist_nulls_node *node;
158         int score, hiscore = 0;
159
160         sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
161                 score = compute_score(sk, net, hnum, daddr, dif, sdif);
162                 if (score > hiscore) {
163                         result = inet6_lookup_reuseport(net, sk, skb, doff,
164                                                         saddr, sport, daddr, hnum, inet6_ehashfn);
165                         if (result)
166                                 return result;
167
168                         result = sk;
169                         hiscore = score;
170                 }
171         }
172
173         return result;
174 }
175
176 struct sock *inet6_lookup_run_sk_lookup(struct net *net,
177                                         int protocol,
178                                         struct sk_buff *skb, int doff,
179                                         const struct in6_addr *saddr,
180                                         const __be16 sport,
181                                         const struct in6_addr *daddr,
182                                         const u16 hnum, const int dif,
183                                         inet6_ehashfn_t *ehashfn)
184 {
185         struct sock *sk, *reuse_sk;
186         bool no_reuseport;
187
188         no_reuseport = bpf_sk_lookup_run_v6(net, protocol, saddr, sport,
189                                             daddr, hnum, dif, &sk);
190         if (no_reuseport || IS_ERR_OR_NULL(sk))
191                 return sk;
192
193         reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
194                                           saddr, sport, daddr, hnum, ehashfn);
195         if (reuse_sk)
196                 sk = reuse_sk;
197         return sk;
198 }
199 EXPORT_SYMBOL_GPL(inet6_lookup_run_sk_lookup);
200
201 struct sock *inet6_lookup_listener(struct net *net,
202                 struct inet_hashinfo *hashinfo,
203                 struct sk_buff *skb, int doff,
204                 const struct in6_addr *saddr,
205                 const __be16 sport, const struct in6_addr *daddr,
206                 const unsigned short hnum, const int dif, const int sdif)
207 {
208         struct inet_listen_hashbucket *ilb2;
209         struct sock *result = NULL;
210         unsigned int hash2;
211
212         /* Lookup redirect from BPF */
213         if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
214             hashinfo == net->ipv4.tcp_death_row.hashinfo) {
215                 result = inet6_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
216                                                     saddr, sport, daddr, hnum, dif,
217                                                     inet6_ehashfn);
218                 if (result)
219                         goto done;
220         }
221
222         hash2 = ipv6_portaddr_hash(net, daddr, hnum);
223         ilb2 = inet_lhash2_bucket(hashinfo, hash2);
224
225         result = inet6_lhash2_lookup(net, ilb2, skb, doff,
226                                      saddr, sport, daddr, hnum,
227                                      dif, sdif);
228         if (result)
229                 goto done;
230
231         /* Lookup lhash2 with in6addr_any */
232         hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
233         ilb2 = inet_lhash2_bucket(hashinfo, hash2);
234
235         result = inet6_lhash2_lookup(net, ilb2, skb, doff,
236                                      saddr, sport, &in6addr_any, hnum,
237                                      dif, sdif);
238 done:
239         if (IS_ERR(result))
240                 return NULL;
241         return result;
242 }
243 EXPORT_SYMBOL_GPL(inet6_lookup_listener);
244
245 struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
246                           struct sk_buff *skb, int doff,
247                           const struct in6_addr *saddr, const __be16 sport,
248                           const struct in6_addr *daddr, const __be16 dport,
249                           const int dif)
250 {
251         struct sock *sk;
252         bool refcounted;
253
254         sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
255                             ntohs(dport), dif, 0, &refcounted);
256         if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
257                 sk = NULL;
258         return sk;
259 }
260 EXPORT_SYMBOL_GPL(inet6_lookup);
261
/*
 * __inet6_check_established() - claim local port @lport for connecting @sk.
 *
 * inet6_hash_connect() passes this to __inet_hash_connect() as the
 * uniqueness check.  It scans the established-hash chain for @sk's 4-tuple
 * with @lport as the local port.  If the tuple is free, or is held only by a
 * TIME_WAIT socket that twsk_unique() allows to be reused, @sk is inserted.
 * The scan and the insertion both happen under the bucket spinlock returned
 * by inet_ehash_lockp().
 *
 * Return: 0 on success.  In that case inet_num, inet_sport and sk_hash are
 * set and @sk is on the chain.  A recycled TIME_WAIT socket is unlinked and
 * then either stored in *@twp (when @twp is non-NULL) or released with
 * inet_twsk_deschedule_put().  When @twp is non-NULL, *@twp is always
 * written, and is NULL if nothing was recycled.
 * -EADDRNOTAVAIL if the tuple is held by a socket that may not be reused.
 */
static int __inet6_check_established(struct inet_timewait_death_row *death_row,
				     struct sock *sk, const __u16 lport,
				     struct inet_timewait_sock **twp)
{
	struct inet_hashinfo *hinfo = death_row->hashinfo;
	struct inet_sock *inet = inet_sk(sk);
	/* Addresses are named from the viewpoint of an incoming packet: its
	 * destination is our local address, its source is the peer.
	 */
	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
	const struct in6_addr *saddr = &sk->sk_v6_daddr;
	const int dif = sk->sk_bound_dev_if;
	struct net *net = sock_net(sk);
	/* L3 master device ifindex for dif, as inet6_match() expects. */
	const int sdif = l3mdev_master_ifindex_by_index(net, dif);
	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
	const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr,
						inet->inet_dport);
	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
	spinlock_t *lock = inet_ehash_lockp(hinfo, hash);
	struct sock *sk2;
	const struct hlist_nulls_node *node;
	/* Set only when a reusable TIME_WAIT entry is found (loop "break"). */
	struct inet_timewait_sock *tw = NULL;

	spin_lock(lock);

	sk_nulls_for_each(sk2, node, &head->chain) {
		/* Cheap hash compare before the full tuple match. */
		if (sk2->sk_hash != hash)
			continue;

		if (likely(inet6_match(net, sk2, saddr, daddr, ports,
				       dif, sdif))) {
			/* Tuple taken by TIME_WAIT: reuse it only if
			 * twsk_unique() agrees.
			 */
			if (sk2->sk_state == TCP_TIME_WAIT) {
				tw = inet_twsk(sk2);
				if (twsk_unique(sk, sk2, twp))
					break;
			}
			/* Live socket, or TIME_WAIT that may not be reused. */
			goto not_unique;
		}
	}

	/* Must record num and sport now. Otherwise we will see
	 * in hash table socket with a funny identity.
	 */
	inet->inet_num = lport;
	inet->inet_sport = htons(lport);
	sk->sk_hash = hash;
	/* @sk is expected not to be hashed yet at this point. */
	WARN_ON(!sk_unhashed(sk));
	__sk_nulls_add_node_rcu(sk, &head->chain);
	if (tw) {
		/* Unlink the recycled TIME_WAIT entry while still locked. */
		sk_nulls_del_node_init_rcu((struct sock *)tw);
		__NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED);
	}
	spin_unlock(lock);
	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);

	if (twp) {
		/* Caller disposes of the TIME_WAIT socket (may be NULL). */
		*twp = tw;
	} else if (tw) {
		/* Silly. Should hash-dance instead... */
		inet_twsk_deschedule_put(tw);
	}
	return 0;

not_unique:
	spin_unlock(lock);
	return -EADDRNOTAVAIL;
}
326
327 static u64 inet6_sk_port_offset(const struct sock *sk)
328 {
329         const struct inet_sock *inet = inet_sk(sk);
330
331         return secure_ipv6_port_ephemeral(sk->sk_v6_rcv_saddr.s6_addr32,
332                                           sk->sk_v6_daddr.s6_addr32,
333                                           inet->inet_dport);
334 }
335
336 int inet6_hash_connect(struct inet_timewait_death_row *death_row,
337                        struct sock *sk)
338 {
339         u64 port_offset = 0;
340
341         if (!inet_sk(sk)->inet_num)
342                 port_offset = inet6_sk_port_offset(sk);
343         return __inet_hash_connect(death_row, sk, port_offset,
344                                    __inet6_check_established);
345 }
346 EXPORT_SYMBOL_GPL(inet6_hash_connect);
347
348 int inet6_hash(struct sock *sk)
349 {
350         int err = 0;
351
352         if (sk->sk_state != TCP_CLOSE)
353                 err = __inet_hash(sk, NULL);
354
355         return err;
356 }
357 EXPORT_SYMBOL_GPL(inet6_hash);