Merge tag 'for-linus-20190524' of git://git.kernel.dk/linux-block
[linux-2.6-microblaze.git] / fs / afs / addr_list.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Server address list management
3  *
4  * Copyright (C) 2017 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #include <linux/slab.h>
9 #include <linux/ctype.h>
10 #include <linux/dns_resolver.h>
11 #include <linux/inet.h>
12 #include <keys/rxrpc-type.h>
13 #include "internal.h"
14 #include "afs_fs.h"
15
16 /*
17  * Release an address list.
18  */
19 void afs_put_addrlist(struct afs_addr_list *alist)
20 {
21         if (alist && refcount_dec_and_test(&alist->usage))
22                 call_rcu(&alist->rcu, (rcu_callback_t)kfree);
23 }
24
25 /*
26  * Allocate an address list.
27  */
28 struct afs_addr_list *afs_alloc_addrlist(unsigned int nr,
29                                          unsigned short service,
30                                          unsigned short port)
31 {
32         struct afs_addr_list *alist;
33         unsigned int i;
34
35         _enter("%u,%u,%u", nr, service, port);
36
37         if (nr > AFS_MAX_ADDRESSES)
38                 nr = AFS_MAX_ADDRESSES;
39
40         alist = kzalloc(struct_size(alist, addrs, nr), GFP_KERNEL);
41         if (!alist)
42                 return NULL;
43
44         refcount_set(&alist->usage, 1);
45         alist->max_addrs = nr;
46
47         for (i = 0; i < nr; i++) {
48                 struct sockaddr_rxrpc *srx = &alist->addrs[i];
49                 srx->srx_family                 = AF_RXRPC;
50                 srx->srx_service                = service;
51                 srx->transport_type             = SOCK_DGRAM;
52                 srx->transport_len              = sizeof(srx->transport.sin6);
53                 srx->transport.sin6.sin6_family = AF_INET6;
54                 srx->transport.sin6.sin6_port   = htons(port);
55         }
56
57         return alist;
58 }
59
60 /*
61  * Parse a text string consisting of delimited addresses.
62  */
63 struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net,
64                                                const char *text, size_t len,
65                                                char delim,
66                                                unsigned short service,
67                                                unsigned short port)
68 {
69         struct afs_vlserver_list *vllist;
70         struct afs_addr_list *alist;
71         const char *p, *end = text + len;
72         const char *problem;
73         unsigned int nr = 0;
74         int ret = -ENOMEM;
75
76         _enter("%*.*s,%c", (int)len, (int)len, text, delim);
77
78         if (!len) {
79                 _leave(" = -EDESTADDRREQ [empty]");
80                 return ERR_PTR(-EDESTADDRREQ);
81         }
82
83         if (delim == ':' && (memchr(text, ',', len) || !memchr(text, '.', len)))
84                 delim = ',';
85
86         /* Count the addresses */
87         p = text;
88         do {
89                 if (!*p) {
90                         problem = "nul";
91                         goto inval;
92                 }
93                 if (*p == delim)
94                         continue;
95                 nr++;
96                 if (*p == '[') {
97                         p++;
98                         if (p == end) {
99                                 problem = "brace1";
100                                 goto inval;
101                         }
102                         p = memchr(p, ']', end - p);
103                         if (!p) {
104                                 problem = "brace2";
105                                 goto inval;
106                         }
107                         p++;
108                         if (p >= end)
109                                 break;
110                 }
111
112                 p = memchr(p, delim, end - p);
113                 if (!p)
114                         break;
115                 p++;
116         } while (p < end);
117
118         _debug("%u/%u addresses", nr, AFS_MAX_ADDRESSES);
119
120         vllist = afs_alloc_vlserver_list(1);
121         if (!vllist)
122                 return ERR_PTR(-ENOMEM);
123
124         vllist->nr_servers = 1;
125         vllist->servers[0].server = afs_alloc_vlserver("<dummy>", 7, AFS_VL_PORT);
126         if (!vllist->servers[0].server)
127                 goto error_vl;
128
129         alist = afs_alloc_addrlist(nr, service, AFS_VL_PORT);
130         if (!alist)
131                 goto error;
132
133         /* Extract the addresses */
134         p = text;
135         do {
136                 const char *q, *stop;
137                 unsigned int xport = port;
138                 __be32 x[4];
139                 int family;
140
141                 if (*p == delim) {
142                         p++;
143                         continue;
144                 }
145
146                 if (*p == '[') {
147                         p++;
148                         q = memchr(p, ']', end - p);
149                 } else {
150                         for (q = p; q < end; q++)
151                                 if (*q == '+' || *q == delim)
152                                         break;
153                 }
154
155                 if (in4_pton(p, q - p, (u8 *)&x[0], -1, &stop)) {
156                         family = AF_INET;
157                 } else if (in6_pton(p, q - p, (u8 *)x, -1, &stop)) {
158                         family = AF_INET6;
159                 } else {
160                         problem = "family";
161                         goto bad_address;
162                 }
163
164                 p = q;
165                 if (stop != p) {
166                         problem = "nostop";
167                         goto bad_address;
168                 }
169
170                 if (q < end && *q == ']')
171                         p++;
172
173                 if (p < end) {
174                         if (*p == '+') {
175                                 /* Port number specification "+1234" */
176                                 xport = 0;
177                                 p++;
178                                 if (p >= end || !isdigit(*p)) {
179                                         problem = "port";
180                                         goto bad_address;
181                                 }
182                                 do {
183                                         xport *= 10;
184                                         xport += *p - '0';
185                                         if (xport > 65535) {
186                                                 problem = "pval";
187                                                 goto bad_address;
188                                         }
189                                         p++;
190                                 } while (p < end && isdigit(*p));
191                         } else if (*p == delim) {
192                                 p++;
193                         } else {
194                                 problem = "weird";
195                                 goto bad_address;
196                         }
197                 }
198
199                 if (family == AF_INET)
200                         afs_merge_fs_addr4(alist, x[0], xport);
201                 else
202                         afs_merge_fs_addr6(alist, x, xport);
203
204         } while (p < end);
205
206         rcu_assign_pointer(vllist->servers[0].server->addresses, alist);
207         _leave(" = [nr %u]", alist->nr_addrs);
208         return vllist;
209
210 inval:
211         _leave(" = -EINVAL [%s %zu %*.*s]",
212                problem, p - text, (int)len, (int)len, text);
213         return ERR_PTR(-EINVAL);
214 bad_address:
215         _leave(" = -EINVAL [%s %zu %*.*s]",
216                problem, p - text, (int)len, (int)len, text);
217         ret = -EINVAL;
218 error:
219         afs_put_addrlist(alist);
220 error_vl:
221         afs_put_vlserverlist(net, vllist);
222         return ERR_PTR(ret);
223 }
224
225 /*
226  * Compare old and new address lists to see if there's been any change.
227  * - How to do this in better than O(Nlog(N)) time?
228  *   - We don't really want to sort the address list, but would rather take the
229  *     list as we got it so as not to undo record rotation by the DNS server.
230  */
231 #if 0
232 static int afs_cmp_addr_list(const struct afs_addr_list *a1,
233                              const struct afs_addr_list *a2)
234 {
235 }
236 #endif
237
238 /*
239  * Perform a DNS query for VL servers and build a up an address list.
240  */
241 struct afs_vlserver_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry)
242 {
243         struct afs_vlserver_list *vllist;
244         char *result = NULL;
245         int ret;
246
247         _enter("%s", cell->name);
248
249         ret = dns_query("afsdb", cell->name, cell->name_len, "srv=1",
250                         &result, _expiry, true);
251         if (ret < 0) {
252                 _leave(" = %d [dns]", ret);
253                 return ERR_PTR(ret);
254         }
255
256         if (*_expiry == 0)
257                 *_expiry = ktime_get_real_seconds() + 60;
258
259         if (ret > 1 && result[0] == 0)
260                 vllist = afs_extract_vlserver_list(cell, result, ret);
261         else
262                 vllist = afs_parse_text_addrs(cell->net, result, ret, ',',
263                                               VL_SERVICE, AFS_VL_PORT);
264         kfree(result);
265         if (IS_ERR(vllist) && vllist != ERR_PTR(-ENOMEM))
266                 pr_err("Failed to parse DNS data %ld\n", PTR_ERR(vllist));
267
268         return vllist;
269 }
270
271 /*
272  * Merge an IPv4 entry into a fileserver address list.
273  */
274 void afs_merge_fs_addr4(struct afs_addr_list *alist, __be32 xdr, u16 port)
275 {
276         struct sockaddr_rxrpc *srx;
277         u32 addr = ntohl(xdr);
278         int i;
279
280         if (alist->nr_addrs >= alist->max_addrs)
281                 return;
282
283         for (i = 0; i < alist->nr_ipv4; i++) {
284                 struct sockaddr_in *a = &alist->addrs[i].transport.sin;
285                 u32 a_addr = ntohl(a->sin_addr.s_addr);
286                 u16 a_port = ntohs(a->sin_port);
287
288                 if (addr == a_addr && port == a_port)
289                         return;
290                 if (addr == a_addr && port < a_port)
291                         break;
292                 if (addr < a_addr)
293                         break;
294         }
295
296         if (i < alist->nr_addrs)
297                 memmove(alist->addrs + i + 1,
298                         alist->addrs + i,
299                         sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
300
301         srx = &alist->addrs[i];
302         srx->srx_family = AF_RXRPC;
303         srx->transport_type = SOCK_DGRAM;
304         srx->transport_len = sizeof(srx->transport.sin);
305         srx->transport.sin.sin_family = AF_INET;
306         srx->transport.sin.sin_port = htons(port);
307         srx->transport.sin.sin_addr.s_addr = xdr;
308         alist->nr_ipv4++;
309         alist->nr_addrs++;
310 }
311
312 /*
313  * Merge an IPv6 entry into a fileserver address list.
314  */
315 void afs_merge_fs_addr6(struct afs_addr_list *alist, __be32 *xdr, u16 port)
316 {
317         struct sockaddr_rxrpc *srx;
318         int i, diff;
319
320         if (alist->nr_addrs >= alist->max_addrs)
321                 return;
322
323         for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) {
324                 struct sockaddr_in6 *a = &alist->addrs[i].transport.sin6;
325                 u16 a_port = ntohs(a->sin6_port);
326
327                 diff = memcmp(xdr, &a->sin6_addr, 16);
328                 if (diff == 0 && port == a_port)
329                         return;
330                 if (diff == 0 && port < a_port)
331                         break;
332                 if (diff < 0)
333                         break;
334         }
335
336         if (i < alist->nr_addrs)
337                 memmove(alist->addrs + i + 1,
338                         alist->addrs + i,
339                         sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
340
341         srx = &alist->addrs[i];
342         srx->srx_family = AF_RXRPC;
343         srx->transport_type = SOCK_DGRAM;
344         srx->transport_len = sizeof(srx->transport.sin6);
345         srx->transport.sin6.sin6_family = AF_INET6;
346         srx->transport.sin6.sin6_port = htons(port);
347         memcpy(&srx->transport.sin6.sin6_addr, xdr, 16);
348         alist->nr_addrs++;
349 }
350
351 /*
352  * Get an address to try.
353  */
354 bool afs_iterate_addresses(struct afs_addr_cursor *ac)
355 {
356         unsigned long set, failed;
357         int index;
358
359         if (!ac->alist)
360                 return false;
361
362         set = ac->alist->responded;
363         failed = ac->alist->failed;
364         _enter("%lx-%lx-%lx,%d", set, failed, ac->tried, ac->index);
365
366         ac->nr_iterations++;
367
368         set &= ~(failed | ac->tried);
369
370         if (!set)
371                 return false;
372
373         index = READ_ONCE(ac->alist->preferred);
374         if (test_bit(index, &set))
375                 goto selected;
376
377         index = __ffs(set);
378
379 selected:
380         ac->index = index;
381         set_bit(index, &ac->tried);
382         ac->responded = false;
383         return true;
384 }
385
386 /*
387  * Release an address list cursor.
388  */
389 int afs_end_cursor(struct afs_addr_cursor *ac)
390 {
391         struct afs_addr_list *alist;
392
393         alist = ac->alist;
394         if (alist) {
395                 if (ac->responded &&
396                     ac->index != alist->preferred &&
397                     test_bit(ac->alist->preferred, &ac->tried))
398                         WRITE_ONCE(alist->preferred, ac->index);
399                 afs_put_addrlist(alist);
400                 ac->alist = NULL;
401         }
402
403         return ac->error;
404 }