Merge tag 'armsoc-defconfig' of git://git.kernel.org/pub/scm/linux/kernel/git/arm...
[linux-2.6-microblaze.git] / net / sunrpc / auth.c
1 /*
2  * linux/net/sunrpc/auth.c
3  *
4  * Generic RPC client authentication API.
5  *
6  * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
7  */
8
9 #include <linux/types.h>
10 #include <linux/sched.h>
11 #include <linux/cred.h>
12 #include <linux/module.h>
13 #include <linux/slab.h>
14 #include <linux/errno.h>
15 #include <linux/hash.h>
16 #include <linux/sunrpc/clnt.h>
17 #include <linux/sunrpc/gss_api.h>
18 #include <linux/spinlock.h>
19
20 #if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
21 # define RPCDBG_FACILITY        RPCDBG_AUTH
22 #endif
23
24 #define RPC_CREDCACHE_DEFAULT_HASHBITS  (4)
25 struct rpc_cred_cache {
26         struct hlist_head       *hashtable;
27         unsigned int            hashbits;
28         spinlock_t              lock;
29 };
30
31 static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
32
33 static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
34         [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
35         [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
36         NULL,                   /* others can be loadable modules */
37 };
38
39 static LIST_HEAD(cred_unused);
40 static unsigned long number_cred_unused;
41
42 #define MAX_HASHTABLE_BITS (14)
43 static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
44 {
45         unsigned long num;
46         unsigned int nbits;
47         int ret;
48
49         if (!val)
50                 goto out_inval;
51         ret = kstrtoul(val, 0, &num);
52         if (ret)
53                 goto out_inval;
54         nbits = fls(num - 1);
55         if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
56                 goto out_inval;
57         *(unsigned int *)kp->arg = nbits;
58         return 0;
59 out_inval:
60         return -EINVAL;
61 }
62
63 static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
64 {
65         unsigned int nbits;
66
67         nbits = *(unsigned int *)kp->arg;
68         return sprintf(buffer, "%u", 1U << nbits);
69 }
70
71 #define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
72
73 static const struct kernel_param_ops param_ops_hashtbl_sz = {
74         .set = param_set_hashtbl_sz,
75         .get = param_get_hashtbl_sz,
76 };
77
78 module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
79 MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
80
81 static unsigned long auth_max_cred_cachesize = ULONG_MAX;
82 module_param(auth_max_cred_cachesize, ulong, 0644);
83 MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
84
85 static u32
86 pseudoflavor_to_flavor(u32 flavor) {
87         if (flavor > RPC_AUTH_MAXFLAVOR)
88                 return RPC_AUTH_GSS;
89         return flavor;
90 }
91
92 int
93 rpcauth_register(const struct rpc_authops *ops)
94 {
95         const struct rpc_authops *old;
96         rpc_authflavor_t flavor;
97
98         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
99                 return -EINVAL;
100         old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
101         if (old == NULL || old == ops)
102                 return 0;
103         return -EPERM;
104 }
105 EXPORT_SYMBOL_GPL(rpcauth_register);
106
107 int
108 rpcauth_unregister(const struct rpc_authops *ops)
109 {
110         const struct rpc_authops *old;
111         rpc_authflavor_t flavor;
112
113         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
114                 return -EINVAL;
115
116         old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
117         if (old == ops || old == NULL)
118                 return 0;
119         return -EPERM;
120 }
121 EXPORT_SYMBOL_GPL(rpcauth_unregister);
122
123 static const struct rpc_authops *
124 rpcauth_get_authops(rpc_authflavor_t flavor)
125 {
126         const struct rpc_authops *ops;
127
128         if (flavor >= RPC_AUTH_MAXFLAVOR)
129                 return NULL;
130
131         rcu_read_lock();
132         ops = rcu_dereference(auth_flavors[flavor]);
133         if (ops == NULL) {
134                 rcu_read_unlock();
135                 request_module("rpc-auth-%u", flavor);
136                 rcu_read_lock();
137                 ops = rcu_dereference(auth_flavors[flavor]);
138                 if (ops == NULL)
139                         goto out;
140         }
141         if (!try_module_get(ops->owner))
142                 ops = NULL;
143 out:
144         rcu_read_unlock();
145         return ops;
146 }
147
148 static void
149 rpcauth_put_authops(const struct rpc_authops *ops)
150 {
151         module_put(ops->owner);
152 }
153
154 /**
155  * rpcauth_get_pseudoflavor - check if security flavor is supported
156  * @flavor: a security flavor
157  * @info: a GSS mech OID, quality of protection, and service value
158  *
159  * Verifies that an appropriate kernel module is available or already loaded.
160  * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
161  * not supported locally.
162  */
163 rpc_authflavor_t
164 rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
165 {
166         const struct rpc_authops *ops = rpcauth_get_authops(flavor);
167         rpc_authflavor_t pseudoflavor;
168
169         if (!ops)
170                 return RPC_AUTH_MAXFLAVOR;
171         pseudoflavor = flavor;
172         if (ops->info2flavor != NULL)
173                 pseudoflavor = ops->info2flavor(info);
174
175         rpcauth_put_authops(ops);
176         return pseudoflavor;
177 }
178 EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
179
180 /**
181  * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
182  * @pseudoflavor: GSS pseudoflavor to match
183  * @info: rpcsec_gss_info structure to fill in
184  *
185  * Returns zero and fills in "info" if pseudoflavor matches a
186  * supported mechanism.
187  */
188 int
189 rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
190 {
191         rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
192         const struct rpc_authops *ops;
193         int result;
194
195         ops = rpcauth_get_authops(flavor);
196         if (ops == NULL)
197                 return -ENOENT;
198
199         result = -ENOENT;
200         if (ops->flavor2info != NULL)
201                 result = ops->flavor2info(pseudoflavor, info);
202
203         rpcauth_put_authops(ops);
204         return result;
205 }
206 EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
207
208 /**
209  * rpcauth_list_flavors - discover registered flavors and pseudoflavors
210  * @array: array to fill in
211  * @size: size of "array"
212  *
213  * Returns the number of array items filled in, or a negative errno.
214  *
215  * The returned array is not sorted by any policy.  Callers should not
216  * rely on the order of the items in the returned array.
217  */
218 int
219 rpcauth_list_flavors(rpc_authflavor_t *array, int size)
220 {
221         const struct rpc_authops *ops;
222         rpc_authflavor_t flavor, pseudos[4];
223         int i, len, result = 0;
224
225         rcu_read_lock();
226         for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
227                 ops = rcu_dereference(auth_flavors[flavor]);
228                 if (result >= size) {
229                         result = -ENOMEM;
230                         break;
231                 }
232
233                 if (ops == NULL)
234                         continue;
235                 if (ops->list_pseudoflavors == NULL) {
236                         array[result++] = ops->au_flavor;
237                         continue;
238                 }
239                 len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
240                 if (len < 0) {
241                         result = len;
242                         break;
243                 }
244                 for (i = 0; i < len; i++) {
245                         if (result >= size) {
246                                 result = -ENOMEM;
247                                 break;
248                         }
249                         array[result++] = pseudos[i];
250                 }
251         }
252         rcu_read_unlock();
253
254         dprintk("RPC:       %s returns %d\n", __func__, result);
255         return result;
256 }
257 EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
258
259 struct rpc_auth *
260 rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
261 {
262         struct rpc_auth *auth = ERR_PTR(-EINVAL);
263         const struct rpc_authops *ops;
264         u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
265
266         ops = rpcauth_get_authops(flavor);
267         if (ops == NULL)
268                 goto out;
269
270         auth = ops->create(args, clnt);
271
272         rpcauth_put_authops(ops);
273         if (IS_ERR(auth))
274                 return auth;
275         if (clnt->cl_auth)
276                 rpcauth_release(clnt->cl_auth);
277         clnt->cl_auth = auth;
278
279 out:
280         return auth;
281 }
282 EXPORT_SYMBOL_GPL(rpcauth_create);
283
284 void
285 rpcauth_release(struct rpc_auth *auth)
286 {
287         if (!refcount_dec_and_test(&auth->au_count))
288                 return;
289         auth->au_ops->destroy(auth);
290 }
291
292 static DEFINE_SPINLOCK(rpc_credcache_lock);
293
294 /*
295  * On success, the caller is responsible for freeing the reference
296  * held by the hashtable
297  */
298 static bool
299 rpcauth_unhash_cred_locked(struct rpc_cred *cred)
300 {
301         if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
302                 return false;
303         hlist_del_rcu(&cred->cr_hash);
304         return true;
305 }
306
307 static bool
308 rpcauth_unhash_cred(struct rpc_cred *cred)
309 {
310         spinlock_t *cache_lock;
311         bool ret;
312
313         if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
314                 return false;
315         cache_lock = &cred->cr_auth->au_credcache->lock;
316         spin_lock(cache_lock);
317         ret = rpcauth_unhash_cred_locked(cred);
318         spin_unlock(cache_lock);
319         return ret;
320 }
321
322 /*
323  * Initialize RPC credential cache
324  */
325 int
326 rpcauth_init_credcache(struct rpc_auth *auth)
327 {
328         struct rpc_cred_cache *new;
329         unsigned int hashsize;
330
331         new = kmalloc(sizeof(*new), GFP_KERNEL);
332         if (!new)
333                 goto out_nocache;
334         new->hashbits = auth_hashbits;
335         hashsize = 1U << new->hashbits;
336         new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
337         if (!new->hashtable)
338                 goto out_nohashtbl;
339         spin_lock_init(&new->lock);
340         auth->au_credcache = new;
341         return 0;
342 out_nohashtbl:
343         kfree(new);
344 out_nocache:
345         return -ENOMEM;
346 }
347 EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
348
349 /*
350  * Setup a credential key lifetime timeout notification
351  */
352 int
353 rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
354 {
355         if (!cred->cr_auth->au_ops->key_timeout)
356                 return 0;
357         return cred->cr_auth->au_ops->key_timeout(auth, cred);
358 }
359 EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);
360
361 bool
362 rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred)
363 {
364         if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT)
365                 return false;
366         if (!cred->cr_ops->crkey_to_expire)
367                 return false;
368         return cred->cr_ops->crkey_to_expire(cred);
369 }
370 EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);
371
372 char *
373 rpcauth_stringify_acceptor(struct rpc_cred *cred)
374 {
375         if (!cred->cr_ops->crstringify_acceptor)
376                 return NULL;
377         return cred->cr_ops->crstringify_acceptor(cred);
378 }
379 EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
380
381 /*
382  * Destroy a list of credentials
383  */
384 static inline
385 void rpcauth_destroy_credlist(struct list_head *head)
386 {
387         struct rpc_cred *cred;
388
389         while (!list_empty(head)) {
390                 cred = list_entry(head->next, struct rpc_cred, cr_lru);
391                 list_del_init(&cred->cr_lru);
392                 put_rpccred(cred);
393         }
394 }
395
396 static void
397 rpcauth_lru_add_locked(struct rpc_cred *cred)
398 {
399         if (!list_empty(&cred->cr_lru))
400                 return;
401         number_cred_unused++;
402         list_add_tail(&cred->cr_lru, &cred_unused);
403 }
404
405 static void
406 rpcauth_lru_add(struct rpc_cred *cred)
407 {
408         if (!list_empty(&cred->cr_lru))
409                 return;
410         spin_lock(&rpc_credcache_lock);
411         rpcauth_lru_add_locked(cred);
412         spin_unlock(&rpc_credcache_lock);
413 }
414
415 static void
416 rpcauth_lru_remove_locked(struct rpc_cred *cred)
417 {
418         if (list_empty(&cred->cr_lru))
419                 return;
420         number_cred_unused--;
421         list_del_init(&cred->cr_lru);
422 }
423
424 static void
425 rpcauth_lru_remove(struct rpc_cred *cred)
426 {
427         if (list_empty(&cred->cr_lru))
428                 return;
429         spin_lock(&rpc_credcache_lock);
430         rpcauth_lru_remove_locked(cred);
431         spin_unlock(&rpc_credcache_lock);
432 }
433
434 /*
435  * Clear the RPC credential cache, and delete those credentials
436  * that are not referenced.
437  */
438 void
439 rpcauth_clear_credcache(struct rpc_cred_cache *cache)
440 {
441         LIST_HEAD(free);
442         struct hlist_head *head;
443         struct rpc_cred *cred;
444         unsigned int hashsize = 1U << cache->hashbits;
445         int             i;
446
447         spin_lock(&rpc_credcache_lock);
448         spin_lock(&cache->lock);
449         for (i = 0; i < hashsize; i++) {
450                 head = &cache->hashtable[i];
451                 while (!hlist_empty(head)) {
452                         cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
453                         rpcauth_unhash_cred_locked(cred);
454                         /* Note: We now hold a reference to cred */
455                         rpcauth_lru_remove_locked(cred);
456                         list_add_tail(&cred->cr_lru, &free);
457                 }
458         }
459         spin_unlock(&cache->lock);
460         spin_unlock(&rpc_credcache_lock);
461         rpcauth_destroy_credlist(&free);
462 }
463
464 /*
465  * Destroy the RPC credential cache
466  */
467 void
468 rpcauth_destroy_credcache(struct rpc_auth *auth)
469 {
470         struct rpc_cred_cache *cache = auth->au_credcache;
471
472         if (cache) {
473                 auth->au_credcache = NULL;
474                 rpcauth_clear_credcache(cache);
475                 kfree(cache->hashtable);
476                 kfree(cache);
477         }
478 }
479 EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
480
481
482 #define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
483
484 /*
485  * Remove stale credentials. Avoid sleeping inside the loop.
486  */
487 static long
488 rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
489 {
490         struct rpc_cred *cred, *next;
491         unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
492         long freed = 0;
493
494         list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
495
496                 if (nr_to_scan-- == 0)
497                         break;
498                 if (refcount_read(&cred->cr_count) > 1) {
499                         rpcauth_lru_remove_locked(cred);
500                         continue;
501                 }
502                 /*
503                  * Enforce a 60 second garbage collection moratorium
504                  * Note that the cred_unused list must be time-ordered.
505                  */
506                 if (!time_in_range(cred->cr_expire, expired, jiffies))
507                         continue;
508                 if (!rpcauth_unhash_cred(cred))
509                         continue;
510
511                 rpcauth_lru_remove_locked(cred);
512                 freed++;
513                 list_add_tail(&cred->cr_lru, free);
514         }
515         return freed ? freed : SHRINK_STOP;
516 }
517
518 static unsigned long
519 rpcauth_cache_do_shrink(int nr_to_scan)
520 {
521         LIST_HEAD(free);
522         unsigned long freed;
523
524         spin_lock(&rpc_credcache_lock);
525         freed = rpcauth_prune_expired(&free, nr_to_scan);
526         spin_unlock(&rpc_credcache_lock);
527         rpcauth_destroy_credlist(&free);
528
529         return freed;
530 }
531
532 /*
533  * Run memory cache shrinker.
534  */
535 static unsigned long
536 rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
537
538 {
539         if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
540                 return SHRINK_STOP;
541
542         /* nothing left, don't come back */
543         if (list_empty(&cred_unused))
544                 return SHRINK_STOP;
545
546         return rpcauth_cache_do_shrink(sc->nr_to_scan);
547 }
548
549 static unsigned long
550 rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
551
552 {
553         return number_cred_unused * sysctl_vfs_cache_pressure / 100;
554 }
555
556 static void
557 rpcauth_cache_enforce_limit(void)
558 {
559         unsigned long diff;
560         unsigned int nr_to_scan;
561
562         if (number_cred_unused <= auth_max_cred_cachesize)
563                 return;
564         diff = number_cred_unused - auth_max_cred_cachesize;
565         nr_to_scan = 100;
566         if (diff < nr_to_scan)
567                 nr_to_scan = diff;
568         rpcauth_cache_do_shrink(nr_to_scan);
569 }
570
571 /*
572  * Look up a process' credentials in the authentication cache
573  */
574 struct rpc_cred *
575 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
576                 int flags, gfp_t gfp)
577 {
578         LIST_HEAD(free);
579         struct rpc_cred_cache *cache = auth->au_credcache;
580         struct rpc_cred *cred = NULL,
581                         *entry, *new;
582         unsigned int nr;
583
584         nr = auth->au_ops->hash_cred(acred, cache->hashbits);
585
586         rcu_read_lock();
587         hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
588                 if (!entry->cr_ops->crmatch(acred, entry, flags))
589                         continue;
590                 if (flags & RPCAUTH_LOOKUP_RCU) {
591                         if (test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags) ||
592                             refcount_read(&entry->cr_count) == 0)
593                                 continue;
594                         cred = entry;
595                         break;
596                 }
597                 cred = get_rpccred(entry);
598                 if (cred)
599                         break;
600         }
601         rcu_read_unlock();
602
603         if (cred != NULL)
604                 goto found;
605
606         if (flags & RPCAUTH_LOOKUP_RCU)
607                 return ERR_PTR(-ECHILD);
608
609         new = auth->au_ops->crcreate(auth, acred, flags, gfp);
610         if (IS_ERR(new)) {
611                 cred = new;
612                 goto out;
613         }
614
615         spin_lock(&cache->lock);
616         hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
617                 if (!entry->cr_ops->crmatch(acred, entry, flags))
618                         continue;
619                 cred = get_rpccred(entry);
620                 if (cred)
621                         break;
622         }
623         if (cred == NULL) {
624                 cred = new;
625                 set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
626                 refcount_inc(&cred->cr_count);
627                 hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
628         } else
629                 list_add_tail(&new->cr_lru, &free);
630         spin_unlock(&cache->lock);
631         rpcauth_cache_enforce_limit();
632 found:
633         if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
634             cred->cr_ops->cr_init != NULL &&
635             !(flags & RPCAUTH_LOOKUP_NEW)) {
636                 int res = cred->cr_ops->cr_init(auth, cred);
637                 if (res < 0) {
638                         put_rpccred(cred);
639                         cred = ERR_PTR(res);
640                 }
641         }
642         rpcauth_destroy_credlist(&free);
643 out:
644         return cred;
645 }
646 EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
647
648 struct rpc_cred *
649 rpcauth_lookupcred(struct rpc_auth *auth, int flags)
650 {
651         struct auth_cred acred;
652         struct rpc_cred *ret;
653         const struct cred *cred = current_cred();
654
655         dprintk("RPC:       looking up %s cred\n",
656                 auth->au_ops->au_name);
657
658         memset(&acred, 0, sizeof(acred));
659         acred.uid = cred->fsuid;
660         acred.gid = cred->fsgid;
661         acred.group_info = cred->group_info;
662         ret = auth->au_ops->lookup_cred(auth, &acred, flags);
663         return ret;
664 }
665 EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
666
667 void
668 rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
669                   struct rpc_auth *auth, const struct rpc_credops *ops)
670 {
671         INIT_HLIST_NODE(&cred->cr_hash);
672         INIT_LIST_HEAD(&cred->cr_lru);
673         refcount_set(&cred->cr_count, 1);
674         cred->cr_auth = auth;
675         cred->cr_ops = ops;
676         cred->cr_expire = jiffies;
677         cred->cr_uid = acred->uid;
678 }
679 EXPORT_SYMBOL_GPL(rpcauth_init_cred);
680
681 struct rpc_cred *
682 rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
683 {
684         dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
685                         cred->cr_auth->au_ops->au_name, cred);
686         return get_rpccred(cred);
687 }
688 EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
689
690 static struct rpc_cred *
691 rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
692 {
693         struct rpc_auth *auth = task->tk_client->cl_auth;
694         struct auth_cred acred = {
695                 .uid = GLOBAL_ROOT_UID,
696                 .gid = GLOBAL_ROOT_GID,
697         };
698
699         dprintk("RPC: %5u looking up %s cred\n",
700                 task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
701         return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
702 }
703
704 static struct rpc_cred *
705 rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
706 {
707         struct rpc_auth *auth = task->tk_client->cl_auth;
708
709         dprintk("RPC: %5u looking up %s cred\n",
710                 task->tk_pid, auth->au_ops->au_name);
711         return rpcauth_lookupcred(auth, lookupflags);
712 }
713
714 static int
715 rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
716 {
717         struct rpc_rqst *req = task->tk_rqstp;
718         struct rpc_cred *new;
719         int lookupflags = 0;
720
721         if (flags & RPC_TASK_ASYNC)
722                 lookupflags |= RPCAUTH_LOOKUP_NEW;
723         if (cred != NULL)
724                 new = cred->cr_ops->crbind(task, cred, lookupflags);
725         else if (flags & RPC_TASK_ROOTCREDS)
726                 new = rpcauth_bind_root_cred(task, lookupflags);
727         else
728                 new = rpcauth_bind_new_cred(task, lookupflags);
729         if (IS_ERR(new))
730                 return PTR_ERR(new);
731         put_rpccred(req->rq_cred);
732         req->rq_cred = new;
733         return 0;
734 }
735
736 void
737 put_rpccred(struct rpc_cred *cred)
738 {
739         if (cred == NULL)
740                 return;
741         rcu_read_lock();
742         if (refcount_dec_and_test(&cred->cr_count))
743                 goto destroy;
744         if (refcount_read(&cred->cr_count) != 1 ||
745             !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
746                 goto out;
747         if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
748                 cred->cr_expire = jiffies;
749                 rpcauth_lru_add(cred);
750                 /* Race breaker */
751                 if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
752                         rpcauth_lru_remove(cred);
753         } else if (rpcauth_unhash_cred(cred)) {
754                 rpcauth_lru_remove(cred);
755                 if (refcount_dec_and_test(&cred->cr_count))
756                         goto destroy;
757         }
758 out:
759         rcu_read_unlock();
760         return;
761 destroy:
762         rcu_read_unlock();
763         cred->cr_ops->crdestroy(cred);
764 }
765 EXPORT_SYMBOL_GPL(put_rpccred);
766
767 __be32 *
768 rpcauth_marshcred(struct rpc_task *task, __be32 *p)
769 {
770         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
771
772         dprintk("RPC: %5u marshaling %s cred %p\n",
773                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
774
775         return cred->cr_ops->crmarshal(task, p);
776 }
777
778 __be32 *
779 rpcauth_checkverf(struct rpc_task *task, __be32 *p)
780 {
781         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
782
783         dprintk("RPC: %5u validating %s cred %p\n",
784                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
785
786         return cred->cr_ops->crvalidate(task, p);
787 }
788
789 static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
790                                    __be32 *data, void *obj)
791 {
792         struct xdr_stream xdr;
793
794         xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data);
795         encode(rqstp, &xdr, obj);
796 }
797
798 int
799 rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
800                 __be32 *data, void *obj)
801 {
802         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
803
804         dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
805                         task->tk_pid, cred->cr_ops->cr_name, cred);
806         if (cred->cr_ops->crwrap_req)
807                 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
808         /* By default, we encode the arguments normally. */
809         rpcauth_wrap_req_encode(encode, rqstp, data, obj);
810         return 0;
811 }
812
813 static int
814 rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
815                           __be32 *data, void *obj)
816 {
817         struct xdr_stream xdr;
818
819         xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data);
820         return decode(rqstp, &xdr, obj);
821 }
822
823 int
824 rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
825                 __be32 *data, void *obj)
826 {
827         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
828
829         dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
830                         task->tk_pid, cred->cr_ops->cr_name, cred);
831         if (cred->cr_ops->crunwrap_resp)
832                 return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
833                                                    data, obj);
834         /* By default, we decode the arguments normally. */
835         return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
836 }
837
838 bool
839 rpcauth_xmit_need_reencode(struct rpc_task *task)
840 {
841         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
842
843         if (!cred || !cred->cr_ops->crneed_reencode)
844                 return false;
845         return cred->cr_ops->crneed_reencode(task);
846 }
847
848 int
849 rpcauth_refreshcred(struct rpc_task *task)
850 {
851         struct rpc_cred *cred;
852         int err;
853
854         cred = task->tk_rqstp->rq_cred;
855         if (cred == NULL) {
856                 err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
857                 if (err < 0)
858                         goto out;
859                 cred = task->tk_rqstp->rq_cred;
860         }
861         dprintk("RPC: %5u refreshing %s cred %p\n",
862                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
863
864         err = cred->cr_ops->crrefresh(task);
865 out:
866         if (err < 0)
867                 task->tk_status = err;
868         return err;
869 }
870
871 void
872 rpcauth_invalcred(struct rpc_task *task)
873 {
874         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
875
876         dprintk("RPC: %5u invalidating %s cred %p\n",
877                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
878         if (cred)
879                 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
880 }
881
882 int
883 rpcauth_uptodatecred(struct rpc_task *task)
884 {
885         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
886
887         return cred == NULL ||
888                 test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
889 }
890
891 static struct shrinker rpc_cred_shrinker = {
892         .count_objects = rpcauth_cache_shrink_count,
893         .scan_objects = rpcauth_cache_shrink_scan,
894         .seeks = DEFAULT_SEEKS,
895 };
896
897 int __init rpcauth_init_module(void)
898 {
899         int err;
900
901         err = rpc_init_authunix();
902         if (err < 0)
903                 goto out1;
904         err = rpc_init_generic_auth();
905         if (err < 0)
906                 goto out2;
907         err = register_shrinker(&rpc_cred_shrinker);
908         if (err < 0)
909                 goto out3;
910         return 0;
911 out3:
912         rpc_destroy_generic_auth();
913 out2:
914         rpc_destroy_authunix();
915 out1:
916         return err;
917 }
918
919 void rpcauth_remove_module(void)
920 {
921         rpc_destroy_authunix();
922         rpc_destroy_generic_auth();
923         unregister_shrinker(&rpc_cred_shrinker);
924 }