diff options
Diffstat (limited to 'net/sunrpc/auth.c')
| -rw-r--r-- | net/sunrpc/auth.c | 239 | 
1 files changed, 211 insertions, 28 deletions
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index afe67849269..f7736671742 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -13,6 +13,7 @@  #include <linux/errno.h>  #include <linux/hash.h>  #include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/gss_api.h>  #include <linux/spinlock.h>  #ifdef RPC_DEBUG @@ -81,7 +82,7 @@ MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");  static u32  pseudoflavor_to_flavor(u32 flavor) { -	if (flavor >= RPC_AUTH_MAXFLAVOR) +	if (flavor > RPC_AUTH_MAXFLAVOR)  		return RPC_AUTH_GSS;  	return flavor;  } @@ -122,12 +123,138 @@ rpcauth_unregister(const struct rpc_authops *ops)  }  EXPORT_SYMBOL_GPL(rpcauth_unregister); +/** + * rpcauth_get_pseudoflavor - check if security flavor is supported + * @flavor: a security flavor + * @info: a GSS mech OID, quality of protection, and service value + * + * Verifies that an appropriate kernel module is available or already loaded. + * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is + * not supported locally. + */ +rpc_authflavor_t +rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) +{ +	const struct rpc_authops *ops; +	rpc_authflavor_t pseudoflavor; + +	ops = auth_flavors[flavor]; +	if (ops == NULL) +		request_module("rpc-auth-%u", flavor); +	spin_lock(&rpc_authflavor_lock); +	ops = auth_flavors[flavor]; +	if (ops == NULL || !try_module_get(ops->owner)) { +		spin_unlock(&rpc_authflavor_lock); +		return RPC_AUTH_MAXFLAVOR; +	} +	spin_unlock(&rpc_authflavor_lock); + +	pseudoflavor = flavor; +	if (ops->info2flavor != NULL) +		pseudoflavor = ops->info2flavor(info); + +	module_put(ops->owner); +	return pseudoflavor; +} +EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); + +/** + * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor + * @pseudoflavor: GSS pseudoflavor to match + * @info: rpcsec_gss_info structure to fill in + * + * Returns zero and fills in "info" if pseudoflavor matches a + * supported mechanism. + */ +int +rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) +{ +	rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor); +	const struct rpc_authops *ops; +	int result; + +	if (flavor >= RPC_AUTH_MAXFLAVOR) +		return -EINVAL; + +	ops = auth_flavors[flavor]; +	if (ops == NULL) +		request_module("rpc-auth-%u", flavor); +	spin_lock(&rpc_authflavor_lock); +	ops = auth_flavors[flavor]; +	if (ops == NULL || !try_module_get(ops->owner)) { +		spin_unlock(&rpc_authflavor_lock); +		return -ENOENT; +	} +	spin_unlock(&rpc_authflavor_lock); + +	result = -ENOENT; +	if (ops->flavor2info != NULL) +		result = ops->flavor2info(pseudoflavor, info); + +	module_put(ops->owner); +	return result; +} +EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); + +/** + * rpcauth_list_flavors - discover registered flavors and pseudoflavors + * @array: array to fill in + * @size: size of "array" + * + * Returns the number of array items filled in, or a negative errno. + * + * The returned array is not sorted by any policy.  Callers should not + * rely on the order of the items in the returned array. + */ +int +rpcauth_list_flavors(rpc_authflavor_t *array, int size) +{ +	rpc_authflavor_t flavor; +	int result = 0; + +	spin_lock(&rpc_authflavor_lock); +	for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) { +		const struct rpc_authops *ops = auth_flavors[flavor]; +		rpc_authflavor_t pseudos[4]; +		int i, len; + +		if (result >= size) { +			result = -ENOMEM; +			break; +		} + +		if (ops == NULL) +			continue; +		if (ops->list_pseudoflavors == NULL) { +			array[result++] = ops->au_flavor; +			continue; +		} +		len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos)); +		if (len < 0) { +			result = len; +			break; +		} +		for (i = 0; i < len; i++) { +			if (result >= size) { +				result = -ENOMEM; +				break; +			} +			array[result++] = pseudos[i]; +		} +	} +	spin_unlock(&rpc_authflavor_lock); + +	dprintk("RPC:       %s returns %d\n", __func__, result); +	return result; +} +EXPORT_SYMBOL_GPL(rpcauth_list_flavors); +  struct rpc_auth * -rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) +rpcauth_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt)  {  	struct rpc_auth		*auth;  	const struct rpc_authops *ops; -	u32			flavor = pseudoflavor_to_flavor(pseudoflavor); +	u32			flavor = pseudoflavor_to_flavor(args->pseudoflavor);  	auth = ERR_PTR(-EINVAL);  	if (flavor >= RPC_AUTH_MAXFLAVOR) @@ -142,7 +269,7 @@ rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt)  		goto out;  	}  	spin_unlock(&rpc_authflavor_lock); -	auth = ops->create(clnt, pseudoflavor); +	auth = ops->create(args, clnt);  	module_put(ops->owner);  	if (IS_ERR(auth))  		return auth; @@ -169,7 +296,7 @@ static void  rpcauth_unhash_cred_locked(struct rpc_cred *cred)  {  	hlist_del_rcu(&cred->cr_hash); -	smp_mb__before_clear_bit(); +	smp_mb__before_atomic();  	clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);  } @@ -216,6 +343,27 @@ out_nocache:  EXPORT_SYMBOL_GPL(rpcauth_init_credcache);  /* + * Setup a credential key lifetime timeout notification + */ +int +rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred) +{ +	if (!cred->cr_auth->au_ops->key_timeout) +		return 0; +	return cred->cr_auth->au_ops->key_timeout(auth, cred); +} +EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify); + +bool +rpcauth_cred_key_to_expire(struct rpc_cred *cred) +{ +	if (!cred->cr_ops->crkey_to_expire) +		return false; +	return cred->cr_ops->crkey_to_expire(cred); +} +EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire); + +/*   * Destroy a list of credentials   */  static inline @@ -286,12 +434,13 @@ EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);  /*   * Remove stale credentials. Avoid sleeping inside the loop.   */ -static int +static long  rpcauth_prune_expired(struct list_head *free, int nr_to_scan)  {  	spinlock_t *cache_lock;  	struct rpc_cred *cred, *next;  	unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; +	long freed = 0;  	list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { @@ -303,10 +452,11 @@ rpcauth_prune_expired(struct list_head *free, int nr_to_scan)  		 */  		if (time_in_range(cred->cr_expire, expired, jiffies) &&  		    test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) -			return 0; +			break;  		list_del_init(&cred->cr_lru);  		number_cred_unused--; +		freed++;  		if (atomic_read(&cred->cr_count) != 0)  			continue; @@ -319,27 +469,39 @@ rpcauth_prune_expired(struct list_head *free, int nr_to_scan)  		}  		spin_unlock(cache_lock);  	} -	return (number_cred_unused / 100) * sysctl_vfs_cache_pressure; +	return freed;  }  /*   * Run memory cache shrinker.   */ -static int -rpcauth_cache_shrinker(struct shrinker *shrink, int nr_to_scan, gfp_t gfp_mask) +static unsigned long +rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc) +  {  	LIST_HEAD(free); -	int res; +	unsigned long freed; -	if ((gfp_mask & GFP_KERNEL) != GFP_KERNEL) -		return (nr_to_scan == 0) ? 0 : -1; +	if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL) +		return SHRINK_STOP; + +	/* nothing left, don't come back */  	if (list_empty(&cred_unused)) -		return 0; +		return SHRINK_STOP; +  	spin_lock(&rpc_credcache_lock); -	res = rpcauth_prune_expired(&free, nr_to_scan); +	freed = rpcauth_prune_expired(&free, sc->nr_to_scan);  	spin_unlock(&rpc_credcache_lock);  	rpcauth_destroy_credlist(&free); -	return res; + +	return freed; +} + +static unsigned long +rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc) + +{ +	return (number_cred_unused / 100) * sysctl_vfs_cache_pressure;  }  /* @@ -351,15 +513,14 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,  {  	LIST_HEAD(free);  	struct rpc_cred_cache *cache = auth->au_credcache; -	struct hlist_node *pos;  	struct rpc_cred	*cred = NULL,  			*entry, *new;  	unsigned int nr; -	nr = hash_long(acred->uid, cache->hashbits); +	nr = hash_long(from_kuid(&init_user_ns, acred->uid), cache->hashbits);  	rcu_read_lock(); -	hlist_for_each_entry_rcu(entry, pos, &cache->hashtable[nr], cr_hash) { +	hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {  		if (!entry->cr_ops->crmatch(acred, entry, flags))  			continue;  		spin_lock(&cache->lock); @@ -383,7 +544,7 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,  	}  	spin_lock(&cache->lock); -	hlist_for_each_entry(entry, pos, &cache->hashtable[nr], cr_hash) { +	hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {  		if (!entry->cr_ops->crmatch(acred, entry, flags))  			continue;  		cred = get_rpccred(entry); @@ -431,6 +592,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags)  	put_group_info(acred.group_info);  	return ret;  } +EXPORT_SYMBOL_GPL(rpcauth_lookupcred);  void  rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, @@ -463,8 +625,8 @@ rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)  {  	struct rpc_auth *auth = task->tk_client->cl_auth;  	struct auth_cred acred = { -		.uid = 0, -		.gid = 0, +		.uid = GLOBAL_ROOT_UID, +		.gid = GLOBAL_ROOT_GID,  	};  	dprintk("RPC: %5u looking up %s cred\n", @@ -563,8 +725,17 @@ rpcauth_checkverf(struct rpc_task *task, __be32 *p)  	return cred->cr_ops->crvalidate(task, p);  } +static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, +				   __be32 *data, void *obj) +{ +	struct xdr_stream xdr; + +	xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data); +	encode(rqstp, &xdr, obj); +} +  int -rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp, +rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,  		__be32 *data, void *obj)  {  	struct rpc_cred *cred = task->tk_rqstp->rq_cred; @@ -574,11 +745,22 @@ rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp,  	if (cred->cr_ops->crwrap_req)  		return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);  	/* By default, we encode the arguments normally. */ -	return encode(rqstp, data, obj); +	rpcauth_wrap_req_encode(encode, rqstp, data, obj); +	return 0; +} + +static int +rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, +			  __be32 *data, void *obj) +{ +	struct xdr_stream xdr; + +	xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data); +	return decode(rqstp, &xdr, obj);  }  int -rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, +rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,  		__be32 *data, void *obj)  {  	struct rpc_cred *cred = task->tk_rqstp->rq_cred; @@ -589,7 +771,7 @@ rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp,  		return cred->cr_ops->crunwrap_resp(task, decode, rqstp,  						   data, obj);  	/* By default, we decode the arguments normally. */ -	return decode(rqstp, data, obj); +	return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);  }  int @@ -604,7 +786,7 @@ rpcauth_refreshcred(struct rpc_task *task)  		if (err < 0)  			goto out;  		cred = task->tk_rqstp->rq_cred; -	}; +	}  	dprintk("RPC: %5u refreshing %s cred %p\n",  		task->tk_pid, cred->cr_auth->au_ops->au_name, cred); @@ -636,7 +818,8 @@ rpcauth_uptodatecred(struct rpc_task *task)  }  static struct shrinker rpc_cred_shrinker = { -	.shrink = rpcauth_cache_shrinker, +	.count_objects = rpcauth_cache_shrink_count, +	.scan_objects = rpcauth_cache_shrink_scan,  	.seeks = DEFAULT_SEEKS,  };  | 
