Diffstat (limited to 'net/sunrpc')
59 files changed, 13081 insertions, 6272 deletions
diff --git a/net/sunrpc/Kconfig b/net/sunrpc/Kconfig
index 443c161eb8b..0754d0f466d 100644
--- a/net/sunrpc/Kconfig
+++ b/net/sunrpc/Kconfig
@@ -3,28 +3,24 @@ config SUNRPC
 
 config SUNRPC_GSS
 	tristate
+	select OID_REGISTRY
 
-config SUNRPC_XPRT_RDMA
-	tristate
-	depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS && EXPERIMENTAL
-	default SUNRPC && INFINIBAND
-	help
-	  This option allows the NFS client and server to support
-	  an RDMA-enabled transport.
+config SUNRPC_BACKCHANNEL
+	bool
+	depends on SUNRPC
 
-	  To compile RPC client RDMA transport support as a module,
-	  choose M here: the module will be called xprtrdma.
-
-	  If unsure, say N.
+config SUNRPC_SWAP
+	bool
+	depends on SUNRPC
 
 config RPCSEC_GSS_KRB5
-	tristate "Secure RPC: Kerberos V mechanism (EXPERIMENTAL)"
-	depends on SUNRPC && EXPERIMENTAL
+	tristate "Secure RPC: Kerberos V mechanism"
+	depends on SUNRPC && CRYPTO
+	depends on CRYPTO_MD5 && CRYPTO_DES && CRYPTO_CBC && CRYPTO_CTS
+	depends on CRYPTO_ECB && CRYPTO_HMAC && CRYPTO_SHA1 && CRYPTO_AES
+	depends on CRYPTO_ARC4
+	default y
 	select SUNRPC_GSS
-	select CRYPTO
-	select CRYPTO_MD5
-	select CRYPTO_DES
-	select CRYPTO_CBC
 	help
 	  Choose Y here to enable Secure RPC using the Kerberos version 5
 	  GSS-API mechanism (RFC 1964).
@@ -34,23 +30,43 @@ config RPCSEC_GSS_KRB5
 	  available from http://linux-nfs.org/.  In addition, user-space
 	  Kerberos support should be installed.
 
+	  If unsure, say Y.
+
+config SUNRPC_DEBUG
+	bool "RPC: Enable dprintk debugging"
+	depends on SUNRPC && SYSCTL
+	help
+	  This option enables a sysctl-based debugging interface
+	  that can be used by the 'rpcdebug' utility to turn on or off
+	  logging of different aspects of the kernel RPC activity.
+
+	  Disabling this option will make your kernel slightly smaller,
+	  but makes troubleshooting NFS issues significantly harder.
+
+	  If unsure, say Y.
+
+config SUNRPC_XPRT_RDMA_CLIENT
+	tristate "RPC over RDMA Client Support"
+	depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS
+	default SUNRPC && INFINIBAND
+	help
+	  This option allows the NFS client to support an RDMA-enabled
+	  transport.
+
+	  To compile RPC client RDMA transport support as a module,
+	  choose M here: the module will be called xprtrdma.
+
 	  If unsure, say N.
 
-config RPCSEC_GSS_SPKM3
-	tristate "Secure RPC: SPKM3 mechanism (EXPERIMENTAL)"
-	depends on SUNRPC && EXPERIMENTAL
-	select SUNRPC_GSS
-	select CRYPTO
-	select CRYPTO_MD5
-	select CRYPTO_DES
-	select CRYPTO_CAST5
-	select CRYPTO_CBC
+config SUNRPC_XPRT_RDMA_SERVER
+	tristate "RPC over RDMA Server Support"
+	depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS
+	default SUNRPC && INFINIBAND
 	help
-	  Choose Y here to enable Secure RPC using the SPKM3 public key
-	  GSS-API mechanism (RFC 2025).
+	  This option allows the NFS server to support an RDMA-enabled
+	  transport.
 
-	  Secure RPC calls with SPKM3 require an auxiliary userspace
-	  daemon which may be found in the Linux nfs-utils package
-	  available from http://linux-nfs.org/.
+	  To compile RPC server RDMA transport support as a module,
+	  choose M here: the module will be called svcrdma.
 
 	  If unsure, say N.
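The addr.c hunks in the next diff replace the open-coded defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) tests with the IS_ENABLED() helper, which covers both the built-in (=y) and modular (=m) states of a tristate option such as the RDMA ones above. A minimal sketch of the idiom follows; demo_ntop6() is a hypothetical stand-in for the real rpc_ntop6() and is not part of the patch:

#include <linux/kernel.h>
#include <linux/kconfig.h>	/* IS_ENABLED(), IS_BUILTIN(), IS_MODULE() */

/*
 * IS_ENABLED(CONFIG_IPV6) evaluates to 1 when CONFIG_IPV6=y or
 * CONFIG_IPV6=m, so one test replaces the old two-macro check.
 */
#if IS_ENABLED(CONFIG_IPV6)
static size_t demo_ntop6(char *buf, size_t buflen)
{
	/* real code would format a full IPv6 presentation address here */
	return snprintf(buf, buflen, "::1");
}
#else
static size_t demo_ntop6(char *buf, size_t buflen)
{
	return 0;	/* IPv6 compiled out: report conversion failure */
}
#endif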
diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile index 9d2fca5ad14..e5a7a1cac8f 100644 --- a/net/sunrpc/Makefile +++ b/net/sunrpc/Makefile @@ -5,7 +5,8 @@ obj-$(CONFIG_SUNRPC) += sunrpc.o obj-$(CONFIG_SUNRPC_GSS) += auth_gss/ -obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/ + +obj-y += xprtrdma/ sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \ auth.o auth_null.o auth_unix.o auth_generic.o \ @@ -13,6 +14,6 @@ sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \ addr.o rpcb_clnt.o timer.o xdr.o \ sunrpc_syms.o cache.o rpc_pipe.o \ svc_xprt.o -sunrpc-$(CONFIG_NFS_V4_1) += backchannel_rqst.o bc_svc.o +sunrpc-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel_rqst.o bc_svc.o sunrpc-$(CONFIG_PROC_FS) += stats.o sunrpc-$(CONFIG_SYSCTL) += sysctl.o diff --git a/net/sunrpc/addr.c b/net/sunrpc/addr.c index f845d9d72f7..a622ad64acd 100644 --- a/net/sunrpc/addr.c +++ b/net/sunrpc/addr.c @@ -17,9 +17,12 @@ */ #include <net/ipv6.h> -#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/msg_prot.h> +#include <linux/slab.h> +#include <linux/export.h> -#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap, char *buf, const int buflen) @@ -89,7 +92,7 @@ static size_t rpc_ntop6(const struct sockaddr *sap, return len; } -#else /* !(defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)) */ +#else /* !IS_ENABLED(CONFIG_IPV6) */ static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap, char *buf, const int buflen) @@ -103,7 +106,7 @@ static size_t rpc_ntop6(const struct sockaddr *sap, return 0; } -#endif /* !(defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)) */ +#endif /* !IS_ENABLED(CONFIG_IPV6) */ static int rpc_ntop4(const struct sockaddr *sap, char *buf, const size_t buflen) @@ -150,12 +153,13 @@ static size_t rpc_pton4(const char *buf, const size_t buflen, return 0; sin->sin_family = AF_INET; - return sizeof(struct sockaddr_in);; + return sizeof(struct sockaddr_in); } -#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) -static int rpc_parse_scope_id(const char *buf, const size_t buflen, - const char *delim, struct sockaddr_in6 *sin6) +#if IS_ENABLED(CONFIG_IPV6) +static int rpc_parse_scope_id(struct net *net, const char *buf, + const size_t buflen, const char *delim, + struct sockaddr_in6 *sin6) { char *p; size_t len; @@ -175,7 +179,7 @@ static int rpc_parse_scope_id(const char *buf, const size_t buflen, unsigned long scope_id = 0; struct net_device *dev; - dev = dev_get_by_name(&init_net, p); + dev = dev_get_by_name(net, p); if (dev != NULL) { scope_id = dev->ifindex; dev_put(dev); @@ -195,7 +199,7 @@ static int rpc_parse_scope_id(const char *buf, const size_t buflen, return 0; } -static size_t rpc_pton6(const char *buf, const size_t buflen, +static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen, struct sockaddr *sap, const size_t salen) { struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; @@ -211,14 +215,14 @@ static size_t rpc_pton6(const char *buf, const size_t buflen, if (in6_pton(buf, buflen, addr, IPV6_SCOPE_DELIMITER, &delim) == 0) return 0; - if (!rpc_parse_scope_id(buf, buflen, delim, sin6)) + if (!rpc_parse_scope_id(net, buf, buflen, delim, sin6)) return 0; sin6->sin6_family = AF_INET6; return sizeof(struct sockaddr_in6); } #else -static size_t rpc_pton6(const char *buf, const size_t buflen, +static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen, struct sockaddr *sap, const size_t salen) { return 0; @@ 
-227,6 +231,7 @@ static size_t rpc_pton6(const char *buf, const size_t buflen, /** * rpc_pton - Construct a sockaddr in @sap + * @net: applicable network namespace * @buf: C string containing presentation format IP address * @buflen: length of presentation address in bytes * @sap: buffer into which to plant socket address @@ -239,14 +244,14 @@ static size_t rpc_pton6(const char *buf, const size_t buflen, * socket address, if successful. Returns zero if an error * occurred. */ -size_t rpc_pton(const char *buf, const size_t buflen, +size_t rpc_pton(struct net *net, const char *buf, const size_t buflen, struct sockaddr *sap, const size_t salen) { unsigned int i; for (i = 0; i < buflen; i++) if (buf[i] == ':') - return rpc_pton6(buf, buflen, sap, salen); + return rpc_pton6(net, buf, buflen, sap, salen); return rpc_pton4(buf, buflen, sap, salen); } EXPORT_SYMBOL_GPL(rpc_pton); @@ -254,12 +259,13 @@ EXPORT_SYMBOL_GPL(rpc_pton); /** * rpc_sockaddr2uaddr - Construct a universal address string from @sap. * @sap: socket address + * @gfp_flags: allocation mode * * Returns a %NUL-terminated string in dynamically allocated memory; * otherwise NULL is returned if an error occurred. Caller must * free the returned string. */ -char *rpc_sockaddr2uaddr(const struct sockaddr *sap) +char *rpc_sockaddr2uaddr(const struct sockaddr *sap, gfp_t gfp_flags) { char portbuf[RPCBIND_MAXUADDRPLEN]; char addrbuf[RPCBIND_MAXUADDRLEN]; @@ -287,12 +293,12 @@ char *rpc_sockaddr2uaddr(const struct sockaddr *sap) if (strlcat(addrbuf, portbuf, sizeof(addrbuf)) > sizeof(addrbuf)) return NULL; - return kstrdup(addrbuf, GFP_KERNEL); + return kstrdup(addrbuf, gfp_flags); } -EXPORT_SYMBOL_GPL(rpc_sockaddr2uaddr); /** * rpc_uaddr2sockaddr - convert a universal address to a socket address. + * @net: applicable network namespace * @uaddr: C string containing universal address to convert * @uaddr_len: length of universal address string * @sap: buffer into which to plant socket address @@ -304,8 +310,9 @@ EXPORT_SYMBOL_GPL(rpc_sockaddr2uaddr); * Returns the size of the socket address if successful; otherwise * zero is returned. 
*/ -size_t rpc_uaddr2sockaddr(const char *uaddr, const size_t uaddr_len, - struct sockaddr *sap, const size_t salen) +size_t rpc_uaddr2sockaddr(struct net *net, const char *uaddr, + const size_t uaddr_len, struct sockaddr *sap, + const size_t salen) { char *c, buf[RPCBIND_MAXUADDRLEN + sizeof('\0')]; unsigned long portlo, porthi; @@ -337,7 +344,7 @@ size_t rpc_uaddr2sockaddr(const char *uaddr, const size_t uaddr_len, port = (unsigned short)((porthi << 8) | portlo); *c = '\0'; - if (rpc_pton(buf, strlen(buf), sap, salen) == 0) + if (rpc_pton(net, buf, strlen(buf), sap, salen) == 0) return 0; switch (sap->sa_family) { diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index f394fc190a4..f7736671742 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -13,12 +13,22 @@ #include <linux/errno.h> #include <linux/hash.h> #include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/gss_api.h> #include <linux/spinlock.h> #ifdef RPC_DEBUG # define RPCDBG_FACILITY RPCDBG_AUTH #endif +#define RPC_CREDCACHE_DEFAULT_HASHBITS (4) +struct rpc_cred_cache { + struct hlist_head *hashtable; + unsigned int hashbits; + spinlock_t lock; +}; + +static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; + static DEFINE_SPINLOCK(rpc_authflavor_lock); static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { &authnull_ops, /* AUTH_NULL */ @@ -29,9 +39,50 @@ static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { static LIST_HEAD(cred_unused); static unsigned long number_cred_unused; +#define MAX_HASHTABLE_BITS (14) +static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp) +{ + unsigned long num; + unsigned int nbits; + int ret; + + if (!val) + goto out_inval; + ret = strict_strtoul(val, 0, &num); + if (ret == -EINVAL) + goto out_inval; + nbits = fls(num); + if (num > (1U << nbits)) + nbits++; + if (nbits > MAX_HASHTABLE_BITS || nbits < 2) + goto out_inval; + *(unsigned int *)kp->arg = nbits; + return 0; +out_inval: + return -EINVAL; +} + +static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp) +{ + unsigned int nbits; + + nbits = *(unsigned int *)kp->arg; + return sprintf(buffer, "%u", 1U << nbits); +} + +#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int); + +static struct kernel_param_ops param_ops_hashtbl_sz = { + .set = param_set_hashtbl_sz, + .get = param_get_hashtbl_sz, +}; + +module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644); +MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size"); + static u32 pseudoflavor_to_flavor(u32 flavor) { - if (flavor >= RPC_AUTH_MAXFLAVOR) + if (flavor > RPC_AUTH_MAXFLAVOR) return RPC_AUTH_GSS; return flavor; } @@ -72,12 +123,138 @@ rpcauth_unregister(const struct rpc_authops *ops) } EXPORT_SYMBOL_GPL(rpcauth_unregister); +/** + * rpcauth_get_pseudoflavor - check if security flavor is supported + * @flavor: a security flavor + * @info: a GSS mech OID, quality of protection, and service value + * + * Verifies that an appropriate kernel module is available or already loaded. + * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is + * not supported locally. 
+ */ +rpc_authflavor_t +rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) +{ + const struct rpc_authops *ops; + rpc_authflavor_t pseudoflavor; + + ops = auth_flavors[flavor]; + if (ops == NULL) + request_module("rpc-auth-%u", flavor); + spin_lock(&rpc_authflavor_lock); + ops = auth_flavors[flavor]; + if (ops == NULL || !try_module_get(ops->owner)) { + spin_unlock(&rpc_authflavor_lock); + return RPC_AUTH_MAXFLAVOR; + } + spin_unlock(&rpc_authflavor_lock); + + pseudoflavor = flavor; + if (ops->info2flavor != NULL) + pseudoflavor = ops->info2flavor(info); + + module_put(ops->owner); + return pseudoflavor; +} +EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); + +/** + * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor + * @pseudoflavor: GSS pseudoflavor to match + * @info: rpcsec_gss_info structure to fill in + * + * Returns zero and fills in "info" if pseudoflavor matches a + * supported mechanism. + */ +int +rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) +{ + rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor); + const struct rpc_authops *ops; + int result; + + if (flavor >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + + ops = auth_flavors[flavor]; + if (ops == NULL) + request_module("rpc-auth-%u", flavor); + spin_lock(&rpc_authflavor_lock); + ops = auth_flavors[flavor]; + if (ops == NULL || !try_module_get(ops->owner)) { + spin_unlock(&rpc_authflavor_lock); + return -ENOENT; + } + spin_unlock(&rpc_authflavor_lock); + + result = -ENOENT; + if (ops->flavor2info != NULL) + result = ops->flavor2info(pseudoflavor, info); + + module_put(ops->owner); + return result; +} +EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); + +/** + * rpcauth_list_flavors - discover registered flavors and pseudoflavors + * @array: array to fill in + * @size: size of "array" + * + * Returns the number of array items filled in, or a negative errno. + * + * The returned array is not sorted by any policy. Callers should not + * rely on the order of the items in the returned array. 
+ */ +int +rpcauth_list_flavors(rpc_authflavor_t *array, int size) +{ + rpc_authflavor_t flavor; + int result = 0; + + spin_lock(&rpc_authflavor_lock); + for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) { + const struct rpc_authops *ops = auth_flavors[flavor]; + rpc_authflavor_t pseudos[4]; + int i, len; + + if (result >= size) { + result = -ENOMEM; + break; + } + + if (ops == NULL) + continue; + if (ops->list_pseudoflavors == NULL) { + array[result++] = ops->au_flavor; + continue; + } + len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos)); + if (len < 0) { + result = len; + break; + } + for (i = 0; i < len; i++) { + if (result >= size) { + result = -ENOMEM; + break; + } + array[result++] = pseudos[i]; + } + } + spin_unlock(&rpc_authflavor_lock); + + dprintk("RPC: %s returns %d\n", __func__, result); + return result; +} +EXPORT_SYMBOL_GPL(rpcauth_list_flavors); + struct rpc_auth * -rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) +rpcauth_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { struct rpc_auth *auth; const struct rpc_authops *ops; - u32 flavor = pseudoflavor_to_flavor(pseudoflavor); + u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor); auth = ERR_PTR(-EINVAL); if (flavor >= RPC_AUTH_MAXFLAVOR) @@ -92,7 +269,7 @@ rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) goto out; } spin_unlock(&rpc_authflavor_lock); - auth = ops->create(clnt, pseudoflavor); + auth = ops->create(args, clnt); module_put(ops->owner); if (IS_ERR(auth)) return auth; @@ -119,7 +296,7 @@ static void rpcauth_unhash_cred_locked(struct rpc_cred *cred) { hlist_del_rcu(&cred->cr_hash); - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); } @@ -145,20 +322,48 @@ int rpcauth_init_credcache(struct rpc_auth *auth) { struct rpc_cred_cache *new; - int i; + unsigned int hashsize; new = kmalloc(sizeof(*new), GFP_KERNEL); if (!new) - return -ENOMEM; - for (i = 0; i < RPC_CREDCACHE_NR; i++) - INIT_HLIST_HEAD(&new->hashtable[i]); + goto out_nocache; + new->hashbits = auth_hashbits; + hashsize = 1U << new->hashbits; + new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL); + if (!new->hashtable) + goto out_nohashtbl; spin_lock_init(&new->lock); auth->au_credcache = new; return 0; +out_nohashtbl: + kfree(new); +out_nocache: + return -ENOMEM; } EXPORT_SYMBOL_GPL(rpcauth_init_credcache); /* + * Setup a credential key lifetime timeout notification + */ +int +rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred) +{ + if (!cred->cr_auth->au_ops->key_timeout) + return 0; + return cred->cr_auth->au_ops->key_timeout(auth, cred); +} +EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify); + +bool +rpcauth_cred_key_to_expire(struct rpc_cred *cred) +{ + if (!cred->cr_ops->crkey_to_expire) + return false; + return cred->cr_ops->crkey_to_expire(cred); +} +EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire); + +/* * Destroy a list of credentials */ static inline @@ -183,11 +388,12 @@ rpcauth_clear_credcache(struct rpc_cred_cache *cache) LIST_HEAD(free); struct hlist_head *head; struct rpc_cred *cred; + unsigned int hashsize = 1U << cache->hashbits; int i; spin_lock(&rpc_credcache_lock); spin_lock(&cache->lock); - for (i = 0; i < RPC_CREDCACHE_NR; i++) { + for (i = 0; i < hashsize; i++) { head = &cache->hashtable[i]; while (!hlist_empty(head)) { cred = hlist_entry(head->first, struct rpc_cred, cr_hash); @@ -216,6 +422,7 @@ rpcauth_destroy_credcache(struct rpc_auth *auth) if (cache) { 
auth->au_credcache = NULL; rpcauth_clear_credcache(cache); + kfree(cache->hashtable); kfree(cache); } } @@ -227,22 +434,29 @@ EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); /* * Remove stale credentials. Avoid sleeping inside the loop. */ -static int +static long rpcauth_prune_expired(struct list_head *free, int nr_to_scan) { spinlock_t *cache_lock; struct rpc_cred *cred, *next; unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; + long freed = 0; list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { - /* Enforce a 60 second garbage collection moratorium */ - if (time_in_range_open(cred->cr_expire, expired, jiffies) && + if (nr_to_scan-- == 0) + break; + /* + * Enforce a 60 second garbage collection moratorium + * Note that the cred_unused list must be time-ordered. + */ + if (time_in_range(cred->cr_expire, expired, jiffies) && test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) - continue; + break; list_del_init(&cred->cr_lru); number_cred_unused--; + freed++; if (atomic_read(&cred->cr_count) != 0) continue; @@ -252,32 +466,42 @@ rpcauth_prune_expired(struct list_head *free, int nr_to_scan) get_rpccred(cred); list_add_tail(&cred->cr_lru, free); rpcauth_unhash_cred_locked(cred); - nr_to_scan--; } spin_unlock(cache_lock); - if (nr_to_scan == 0) - break; } - return nr_to_scan; + return freed; } /* * Run memory cache shrinker. */ -static int -rpcauth_cache_shrinker(int nr_to_scan, gfp_t gfp_mask) +static unsigned long +rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc) + { LIST_HEAD(free); - int res; + unsigned long freed; + + if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL) + return SHRINK_STOP; + /* nothing left, don't come back */ if (list_empty(&cred_unused)) - return 0; + return SHRINK_STOP; + spin_lock(&rpc_credcache_lock); - nr_to_scan = rpcauth_prune_expired(&free, nr_to_scan); - res = (number_cred_unused / 100) * sysctl_vfs_cache_pressure; + freed = rpcauth_prune_expired(&free, sc->nr_to_scan); spin_unlock(&rpc_credcache_lock); rpcauth_destroy_credlist(&free); - return res; + + return freed; +} + +static unsigned long +rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc) + +{ + return (number_cred_unused / 100) * sysctl_vfs_cache_pressure; } /* @@ -289,15 +513,14 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, { LIST_HEAD(free); struct rpc_cred_cache *cache = auth->au_credcache; - struct hlist_node *pos; struct rpc_cred *cred = NULL, *entry, *new; unsigned int nr; - nr = hash_long(acred->uid, RPC_CREDCACHE_HASHBITS); + nr = hash_long(from_kuid(&init_user_ns, acred->uid), cache->hashbits); rcu_read_lock(); - hlist_for_each_entry_rcu(entry, pos, &cache->hashtable[nr], cr_hash) { + hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) { if (!entry->cr_ops->crmatch(acred, entry, flags)) continue; spin_lock(&cache->lock); @@ -321,7 +544,7 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, } spin_lock(&cache->lock); - hlist_for_each_entry(entry, pos, &cache->hashtable[nr], cr_hash) { + hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) { if (!entry->cr_ops->crmatch(acred, entry, flags)) continue; cred = get_rpccred(entry); @@ -369,6 +592,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags) put_group_info(acred.group_info); return ret; } +EXPORT_SYMBOL_GPL(rpcauth_lookupcred); void rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, @@ -387,62 +611,60 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, } 
EXPORT_SYMBOL_GPL(rpcauth_init_cred); -void +struct rpc_cred * rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags) { - task->tk_msg.rpc_cred = get_rpccred(cred); dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + return get_rpccred(cred); } EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred); -static void +static struct rpc_cred * rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) { struct rpc_auth *auth = task->tk_client->cl_auth; struct auth_cred acred = { - .uid = 0, - .gid = 0, + .uid = GLOBAL_ROOT_UID, + .gid = GLOBAL_ROOT_GID, }; - struct rpc_cred *ret; dprintk("RPC: %5u looking up %s cred\n", task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); - ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags); - if (!IS_ERR(ret)) - task->tk_msg.rpc_cred = ret; - else - task->tk_status = PTR_ERR(ret); + return auth->au_ops->lookup_cred(auth, &acred, lookupflags); } -static void +static struct rpc_cred * rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags) { struct rpc_auth *auth = task->tk_client->cl_auth; - struct rpc_cred *ret; dprintk("RPC: %5u looking up %s cred\n", task->tk_pid, auth->au_ops->au_name); - ret = rpcauth_lookupcred(auth, lookupflags); - if (!IS_ERR(ret)) - task->tk_msg.rpc_cred = ret; - else - task->tk_status = PTR_ERR(ret); + return rpcauth_lookupcred(auth, lookupflags); } -void +static int rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags) { + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *new; int lookupflags = 0; if (flags & RPC_TASK_ASYNC) lookupflags |= RPCAUTH_LOOKUP_NEW; if (cred != NULL) - cred->cr_ops->crbind(task, cred, lookupflags); + new = cred->cr_ops->crbind(task, cred, lookupflags); else if (flags & RPC_TASK_ROOTCREDS) - rpcauth_bind_root_cred(task, lookupflags); + new = rpcauth_bind_root_cred(task, lookupflags); else - rpcauth_bind_new_cred(task, lookupflags); + new = rpcauth_bind_new_cred(task, lookupflags); + if (IS_ERR(new)) + return PTR_ERR(new); + if (req->rq_cred != NULL) + put_rpccred(req->rq_cred); + req->rq_cred = new; + return 0; } void @@ -481,22 +703,10 @@ out_nodestroy: } EXPORT_SYMBOL_GPL(put_rpccred); -void -rpcauth_unbindcred(struct rpc_task *task) -{ - struct rpc_cred *cred = task->tk_msg.rpc_cred; - - dprintk("RPC: %5u releasing %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); - - put_rpccred(cred); - task->tk_msg.rpc_cred = NULL; -} - __be32 * rpcauth_marshcred(struct rpc_task *task, __be32 *p) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; dprintk("RPC: %5u marshaling %s cred %p\n", task->tk_pid, cred->cr_auth->au_ops->au_name, cred); @@ -507,7 +717,7 @@ rpcauth_marshcred(struct rpc_task *task, __be32 *p) __be32 * rpcauth_checkverf(struct rpc_task *task, __be32 *p) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; dprintk("RPC: %5u validating %s cred %p\n", task->tk_pid, cred->cr_auth->au_ops->au_name, cred); @@ -515,25 +725,45 @@ rpcauth_checkverf(struct rpc_task *task, __be32 *p) return cred->cr_ops->crvalidate(task, p); } +static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *data, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data); + encode(rqstp, &xdr, obj); +} + int -rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp, +rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t 
encode, void *rqstp, __be32 *data, void *obj) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", task->tk_pid, cred->cr_ops->cr_name, cred); if (cred->cr_ops->crwrap_req) return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); /* By default, we encode the arguments normally. */ - return encode(rqstp, data, obj); + rpcauth_wrap_req_encode(encode, rqstp, data, obj); + return 0; +} + +static int +rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, + __be32 *data, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data); + return decode(rqstp, &xdr, obj); } int -rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, +rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, __be32 *data, void *obj) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", task->tk_pid, cred->cr_ops->cr_name, cred); @@ -541,19 +771,27 @@ rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, return cred->cr_ops->crunwrap_resp(task, decode, rqstp, data, obj); /* By default, we decode the arguments normally. */ - return decode(rqstp, data, obj); + return rpcauth_unwrap_req_decode(decode, rqstp, data, obj); } int rpcauth_refreshcred(struct rpc_task *task) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred; int err; + cred = task->tk_rqstp->rq_cred; + if (cred == NULL) { + err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags); + if (err < 0) + goto out; + cred = task->tk_rqstp->rq_cred; + } dprintk("RPC: %5u refreshing %s cred %p\n", task->tk_pid, cred->cr_auth->au_ops->au_name, cred); err = cred->cr_ops->crrefresh(task); +out: if (err < 0) task->tk_status = err; return err; @@ -562,7 +800,7 @@ rpcauth_refreshcred(struct rpc_task *task) void rpcauth_invalcred(struct rpc_task *task) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; dprintk("RPC: %5u invalidating %s cred %p\n", task->tk_pid, cred->cr_auth->au_ops->au_name, cred); @@ -573,25 +811,39 @@ rpcauth_invalcred(struct rpc_task *task) int rpcauth_uptodatecred(struct rpc_task *task) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; return cred == NULL || test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0; } static struct shrinker rpc_cred_shrinker = { - .shrink = rpcauth_cache_shrinker, + .count_objects = rpcauth_cache_shrink_count, + .scan_objects = rpcauth_cache_shrink_scan, .seeks = DEFAULT_SEEKS, }; -void __init rpcauth_init_module(void) +int __init rpcauth_init_module(void) { - rpc_init_authunix(); - rpc_init_generic_auth(); + int err; + + err = rpc_init_authunix(); + if (err < 0) + goto out1; + err = rpc_init_generic_auth(); + if (err < 0) + goto out2; register_shrinker(&rpc_cred_shrinker); + return 0; +out2: + rpc_destroy_authunix(); +out1: + return err; } -void __exit rpcauth_remove_module(void) +void rpcauth_remove_module(void) { + rpc_destroy_authunix(); + rpc_destroy_generic_auth(); unregister_shrinker(&rpc_cred_shrinker); } diff --git a/net/sunrpc/auth_generic.c b/net/sunrpc/auth_generic.c index bf88bf8e936..ed04869b2d4 100644 --- a/net/sunrpc/auth_generic.c +++ b/net/sunrpc/auth_generic.c @@ -5,6 +5,7 @@ */ #include <linux/err.h> +#include <linux/slab.h> #include <linux/types.h> 
#include <linux/module.h> #include <linux/sched.h> @@ -17,8 +18,8 @@ # define RPCDBG_FACILITY RPCDBG_AUTH #endif -#define RPC_MACHINE_CRED_USERID ((uid_t)0) -#define RPC_MACHINE_CRED_GROUPID ((gid_t)0) +#define RPC_MACHINE_CRED_USERID GLOBAL_ROOT_UID +#define RPC_MACHINE_CRED_GROUPID GLOBAL_ROOT_GID struct generic_cred { struct rpc_cred gc_base; @@ -26,7 +27,6 @@ struct generic_cred { }; static struct rpc_auth generic_auth; -static struct rpc_cred_cache generic_cred_cache; static const struct rpc_credops generic_credops; /* @@ -41,31 +41,28 @@ EXPORT_SYMBOL_GPL(rpc_lookup_cred); /* * Public call interface for looking up machine creds. */ -struct rpc_cred *rpc_lookup_machine_cred(void) +struct rpc_cred *rpc_lookup_machine_cred(const char *service_name) { struct auth_cred acred = { .uid = RPC_MACHINE_CRED_USERID, .gid = RPC_MACHINE_CRED_GROUPID, + .principal = service_name, .machine_cred = 1, }; - dprintk("RPC: looking up machine cred\n"); + dprintk("RPC: looking up machine cred for service %s\n", + service_name); return generic_auth.au_ops->lookup_cred(&generic_auth, &acred, 0); } EXPORT_SYMBOL_GPL(rpc_lookup_machine_cred); -static void -generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags) +static struct rpc_cred *generic_bind_cred(struct rpc_task *task, + struct rpc_cred *cred, int lookupflags) { struct rpc_auth *auth = task->tk_client->cl_auth; struct auth_cred *acred = &container_of(cred, struct generic_cred, gc_base)->acred; - struct rpc_cred *ret; - ret = auth->au_ops->lookup_cred(auth, acred, lookupflags); - if (!IS_ERR(ret)) - task->tk_msg.rpc_cred = ret; - else - task->tk_status = PTR_ERR(ret); + return auth->au_ops->lookup_cred(auth, acred, lookupflags); } /* @@ -92,13 +89,17 @@ generic_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) gcred->acred.uid = acred->uid; gcred->acred.gid = acred->gid; gcred->acred.group_info = acred->group_info; + gcred->acred.ac_flags = 0; if (gcred->acred.group_info != NULL) get_group_info(gcred->acred.group_info); gcred->acred.machine_cred = acred->machine_cred; + gcred->acred.principal = acred->principal; dprintk("RPC: allocated %s cred %p for uid %d gid %d\n", gcred->acred.machine_cred ? "machine" : "generic", - gcred, acred->uid, acred->gid); + gcred, + from_kuid(&init_user_ns, acred->uid), + from_kgid(&init_user_ns, acred->gid)); return &gcred->gc_base; } @@ -126,6 +127,17 @@ generic_destroy_cred(struct rpc_cred *cred) call_rcu(&cred->cr_rcu, generic_free_cred_callback); } +static int +machine_cred_match(struct auth_cred *acred, struct generic_cred *gcred, int flags) +{ + if (!gcred->acred.machine_cred || + gcred->acred.principal != acred->principal || + !uid_eq(gcred->acred.uid, acred->uid) || + !gid_eq(gcred->acred.gid, acred->gid)) + return 0; + return 1; +} + /* * Match credentials against current process creds. */ @@ -135,9 +147,12 @@ generic_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) struct generic_cred *gcred = container_of(cred, struct generic_cred, gc_base); int i; - if (gcred->acred.uid != acred->uid || - gcred->acred.gid != acred->gid || - gcred->acred.machine_cred != acred->machine_cred) + if (acred->machine_cred) + return machine_cred_match(acred, gcred, flags); + + if (!uid_eq(gcred->acred.uid, acred->uid) || + !gid_eq(gcred->acred.gid, acred->gid) || + gcred->acred.machine_cred != 0) goto out_nomatch; /* Optimisation in the case where pointers are identical... 
*/ @@ -148,8 +163,8 @@ generic_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) if (gcred->acred.group_info->ngroups != acred->group_info->ngroups) goto out_nomatch; for (i = 0; i < gcred->acred.group_info->ngroups; i++) { - if (GROUP_AT(gcred->acred.group_info, i) != - GROUP_AT(acred->group_info, i)) + if (!gid_eq(GROUP_AT(gcred->acred.group_info, i), + GROUP_AT(acred->group_info, i))) goto out_nomatch; } out_match: @@ -158,36 +173,112 @@ out_nomatch: return 0; } -void __init rpc_init_generic_auth(void) +int __init rpc_init_generic_auth(void) { - spin_lock_init(&generic_cred_cache.lock); + return rpcauth_init_credcache(&generic_auth); } -void __exit rpc_destroy_generic_auth(void) +void rpc_destroy_generic_auth(void) { - rpcauth_clear_credcache(&generic_cred_cache); + rpcauth_destroy_credcache(&generic_auth); } -static struct rpc_cred_cache generic_cred_cache = { - {{ NULL, },}, -}; +/* + * Test the the current time (now) against the underlying credential key expiry + * minus a timeout and setup notification. + * + * The normal case: + * If 'now' is before the key expiry minus RPC_KEY_EXPIRE_TIMEO, set + * the RPC_CRED_NOTIFY_TIMEOUT flag to setup the underlying credential + * rpc_credops crmatch routine to notify this generic cred when it's key + * expiration is within RPC_KEY_EXPIRE_TIMEO, and return 0. + * + * The error case: + * If the underlying cred lookup fails, return -EACCES. + * + * The 'almost' error case: + * If 'now' is within key expiry minus RPC_KEY_EXPIRE_TIMEO, but not within + * key expiry minus RPC_KEY_EXPIRE_FAIL, set the RPC_CRED_EXPIRE_SOON bit + * on the acred ac_flags and return 0. + */ +static int +generic_key_timeout(struct rpc_auth *auth, struct rpc_cred *cred) +{ + struct auth_cred *acred = &container_of(cred, struct generic_cred, + gc_base)->acred; + struct rpc_cred *tcred; + int ret = 0; + + + /* Fast track for non crkey_timeout (no key) underlying credentials */ + if (test_bit(RPC_CRED_NO_CRKEY_TIMEOUT, &acred->ac_flags)) + return 0; + + /* Fast track for the normal case */ + if (test_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags)) + return 0; + + /* lookup_cred either returns a valid referenced rpc_cred, or PTR_ERR */ + tcred = auth->au_ops->lookup_cred(auth, acred, 0); + if (IS_ERR(tcred)) + return -EACCES; + + if (!tcred->cr_ops->crkey_timeout) { + set_bit(RPC_CRED_NO_CRKEY_TIMEOUT, &acred->ac_flags); + ret = 0; + goto out_put; + } + + /* Test for the almost error case */ + ret = tcred->cr_ops->crkey_timeout(tcred); + if (ret != 0) { + set_bit(RPC_CRED_KEY_EXPIRE_SOON, &acred->ac_flags); + ret = 0; + } else { + /* In case underlying cred key has been reset */ + if (test_and_clear_bit(RPC_CRED_KEY_EXPIRE_SOON, + &acred->ac_flags)) + dprintk("RPC: UID %d Credential key reset\n", + from_kuid(&init_user_ns, tcred->cr_uid)); + /* set up fasttrack for the normal case */ + set_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags); + } + +out_put: + put_rpccred(tcred); + return ret; +} static const struct rpc_authops generic_auth_ops = { .owner = THIS_MODULE, .au_name = "Generic", .lookup_cred = generic_lookup_cred, .crcreate = generic_create_cred, + .key_timeout = generic_key_timeout, }; static struct rpc_auth generic_auth = { .au_ops = &generic_auth_ops, .au_count = ATOMIC_INIT(0), - .au_credcache = &generic_cred_cache, }; +static bool generic_key_to_expire(struct rpc_cred *cred) +{ + struct auth_cred *acred = &container_of(cred, struct generic_cred, + gc_base)->acred; + bool ret; + + get_rpccred(cred); + ret = test_bit(RPC_CRED_KEY_EXPIRE_SOON, 
&acred->ac_flags); + put_rpccred(cred); + + return ret; +} + static const struct rpc_credops generic_credops = { .cr_name = "Generic cred", .crdestroy = generic_destroy_cred, .crbind = generic_bind_cred, .crmatch = generic_match, + .crkey_to_expire = generic_key_to_expire, }; diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile index 4de8bcf26fa..14e9e53e63d 100644 --- a/net/sunrpc/auth_gss/Makefile +++ b/net/sunrpc/auth_gss/Makefile @@ -4,15 +4,11 @@ obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o -auth_rpcgss-objs := auth_gss.o gss_generic_token.o \ - gss_mech_switch.o svcauth_gss.o +auth_rpcgss-y := auth_gss.o gss_generic_token.o \ + gss_mech_switch.o svcauth_gss.o \ + gss_rpc_upcall.o gss_rpc_xdr.o obj-$(CONFIG_RPCSEC_GSS_KRB5) += rpcsec_gss_krb5.o -rpcsec_gss_krb5-objs := gss_krb5_mech.o gss_krb5_seal.o gss_krb5_unseal.o \ - gss_krb5_seqnum.o gss_krb5_wrap.o gss_krb5_crypto.o - -obj-$(CONFIG_RPCSEC_GSS_SPKM3) += rpcsec_gss_spkm3.o - -rpcsec_gss_spkm3-objs := gss_spkm3_mech.o gss_spkm3_seal.o gss_spkm3_unseal.o \ - gss_spkm3_token.o +rpcsec_gss_krb5-y := gss_krb5_mech.o gss_krb5_seal.o gss_krb5_unseal.o \ + gss_krb5_seqnum.o gss_krb5_wrap.o gss_krb5_crypto.o gss_krb5_keys.o diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c index 0cfccc2a029..b6e440baccc 100644 --- a/net/sunrpc/auth_gss/auth_gss.c +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -51,42 +51,64 @@ #include <linux/sunrpc/rpc_pipe_fs.h> #include <linux/sunrpc/gss_api.h> #include <asm/uaccess.h> +#include <linux/hashtable.h> + +#include "../netns.h" static const struct rpc_authops authgss_ops; static const struct rpc_credops gss_credops; static const struct rpc_credops gss_nullops; +#define GSS_RETRY_EXPIRED 5 +static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED; + +#define GSS_KEY_EXPIRE_TIMEO 240 +static unsigned int gss_key_expire_timeo = GSS_KEY_EXPIRE_TIMEO; + #ifdef RPC_DEBUG # define RPCDBG_FACILITY RPCDBG_AUTH #endif -#define GSS_CRED_SLACK 1024 +#define GSS_CRED_SLACK (RPC_MAX_AUTH_SIZE * 2) /* length of a krb5 verifier (48), plus data added before arguments when * using integrity (two 4-byte integers): */ #define GSS_VERF_SLACK 100 +static DEFINE_HASHTABLE(gss_auth_hash_table, 4); +static DEFINE_SPINLOCK(gss_auth_hash_lock); + +struct gss_pipe { + struct rpc_pipe_dir_object pdo; + struct rpc_pipe *pipe; + struct rpc_clnt *clnt; + const char *name; + struct kref kref; +}; + struct gss_auth { struct kref kref; + struct hlist_node hash; struct rpc_auth rpc_auth; struct gss_api_mech *mech; enum rpc_gss_svc service; struct rpc_clnt *client; + struct net *net; /* * There are two upcall pipes; dentry[1], named "gssd", is used * for the new text-based upcall; dentry[0] is named after the * mechanism (for example, "krb5") and exists for * backwards-compatibility with older gssd's. */ - struct dentry *dentry[2]; + struct gss_pipe *gss_pipe[2]; + const char *target_name; }; /* pipe_version >= 0 if and only if someone has a pipe open. */ -static int pipe_version = -1; -static atomic_t pipe_users = ATOMIC_INIT(0); static DEFINE_SPINLOCK(pipe_version_lock); static struct rpc_wait_queue pipe_version_rpc_waitqueue; static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue); +static void gss_put_auth(struct gss_auth *gss_auth); static void gss_free_ctx(struct gss_cl_ctx *); static const struct rpc_pipe_ops gss_upcall_ops_v0; @@ -109,7 +131,7 @@ gss_put_ctx(struct gss_cl_ctx *ctx) /* gss_cred_set_ctx: * called by gss_upcall_callback and gss_create_upcall in order * to set the gss context. 
The actual exchange of an old context - * and a new one is protected by the inode->i_lock. + * and a new one is protected by the pipe->lock. */ static void gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx) @@ -121,7 +143,7 @@ gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx) gss_get_ctx(ctx); rcu_assign_pointer(gss_cred->gc_ctx, ctx); set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags); } @@ -189,17 +211,23 @@ gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct const void *q; unsigned int seclen; unsigned int timeout; + unsigned long now = jiffies; u32 window_size; int ret; - /* First unsigned int gives the lifetime (in seconds) of the cred */ + /* First unsigned int gives the remaining lifetime in seconds of the + * credential - e.g. the remaining TGT lifetime for Kerberos or + * the -t value passed to GSSD. + */ p = simple_get_bytes(p, end, &timeout, sizeof(timeout)); if (IS_ERR(p)) goto err; if (timeout == 0) timeout = GSSD_MIN_TIMEOUT; - ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4; - /* Sequence number window. Determines the maximum number of simultaneous requests */ + ctx->gc_expiry = now + ((unsigned long)timeout * HZ); + /* Sequence number window. Determines the maximum number of + * simultaneous requests + */ p = simple_get_bytes(p, end, &window_size, sizeof(window_size)); if (IS_ERR(p)) goto err; @@ -229,14 +257,16 @@ gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct p = ERR_PTR(-EFAULT); goto err; } - ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx); + ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_NOFS); if (ret < 0) { p = ERR_PTR(ret); goto err; } + dprintk("RPC: %s Success. 
gc_expiry %lu now %lu timeout %u\n", + __func__, ctx->gc_expiry, now, timeout); return q; err: - dprintk("RPC: gss_fill_context returning %ld\n", -PTR_ERR(p)); + dprintk("RPC: %s returns error %ld\n", __func__, -PTR_ERR(p)); return p; } @@ -244,35 +274,38 @@ err: struct gss_upcall_msg { atomic_t count; - uid_t uid; + kuid_t uid; struct rpc_pipe_msg msg; struct list_head list; struct gss_auth *auth; - struct rpc_inode *inode; + struct rpc_pipe *pipe; struct rpc_wait_queue rpc_waitqueue; wait_queue_head_t waitqueue; struct gss_cl_ctx *ctx; char databuf[UPCALL_BUF_LEN]; }; -static int get_pipe_version(void) +static int get_pipe_version(struct net *net) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); int ret; spin_lock(&pipe_version_lock); - if (pipe_version >= 0) { - atomic_inc(&pipe_users); - ret = pipe_version; + if (sn->pipe_version >= 0) { + atomic_inc(&sn->pipe_users); + ret = sn->pipe_version; } else ret = -EAGAIN; spin_unlock(&pipe_version_lock); return ret; } -static void put_pipe_version(void) +static void put_pipe_version(struct net *net) { - if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) { - pipe_version = -1; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (atomic_dec_and_lock(&sn->pipe_users, &pipe_version_lock)) { + sn->pipe_version = -1; spin_unlock(&pipe_version_lock); } } @@ -280,28 +313,30 @@ static void put_pipe_version(void) static void gss_release_msg(struct gss_upcall_msg *gss_msg) { + struct net *net = gss_msg->auth->net; if (!atomic_dec_and_test(&gss_msg->count)) return; - put_pipe_version(); + put_pipe_version(net); BUG_ON(!list_empty(&gss_msg->list)); if (gss_msg->ctx != NULL) gss_put_ctx(gss_msg->ctx); rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue); + gss_put_auth(gss_msg->auth); kfree(gss_msg); } static struct gss_upcall_msg * -__gss_find_upcall(struct rpc_inode *rpci, uid_t uid) +__gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid) { struct gss_upcall_msg *pos; - list_for_each_entry(pos, &rpci->in_downcall, list) { - if (pos->uid != uid) + list_for_each_entry(pos, &pipe->in_downcall, list) { + if (!uid_eq(pos->uid, uid)) continue; atomic_inc(&pos->count); - dprintk("RPC: gss_find_upcall found msg %p\n", pos); + dprintk("RPC: %s found msg %p\n", __func__, pos); return pos; } - dprintk("RPC: gss_find_upcall found nothing\n"); + dprintk("RPC: %s found nothing\n", __func__); return NULL; } @@ -312,18 +347,17 @@ __gss_find_upcall(struct rpc_inode *rpci, uid_t uid) static inline struct gss_upcall_msg * gss_add_msg(struct gss_upcall_msg *gss_msg) { - struct rpc_inode *rpci = gss_msg->inode; - struct inode *inode = &rpci->vfs_inode; + struct rpc_pipe *pipe = gss_msg->pipe; struct gss_upcall_msg *old; - spin_lock(&inode->i_lock); - old = __gss_find_upcall(rpci, gss_msg->uid); + spin_lock(&pipe->lock); + old = __gss_find_upcall(pipe, gss_msg->uid); if (old == NULL) { atomic_inc(&gss_msg->count); - list_add(&gss_msg->list, &rpci->in_downcall); + list_add(&gss_msg->list, &pipe->in_downcall); } else gss_msg = old; - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); return gss_msg; } @@ -339,122 +373,159 @@ __gss_unhash_msg(struct gss_upcall_msg *gss_msg) static void gss_unhash_msg(struct gss_upcall_msg *gss_msg) { - struct inode *inode = &gss_msg->inode->vfs_inode; + struct rpc_pipe *pipe = gss_msg->pipe; if (list_empty(&gss_msg->list)) return; - spin_lock(&inode->i_lock); + spin_lock(&pipe->lock); if (!list_empty(&gss_msg->list)) __gss_unhash_msg(gss_msg); - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); +} + +static 
void +gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg) +{ + switch (gss_msg->msg.errno) { + case 0: + if (gss_msg->ctx == NULL) + break; + clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags); + gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx); + break; + case -EKEYEXPIRED: + set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags); + } + gss_cred->gc_upcall_timestamp = jiffies; + gss_cred->gc_upcall = NULL; + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); } static void gss_upcall_callback(struct rpc_task *task) { - struct gss_cred *gss_cred = container_of(task->tk_msg.rpc_cred, + struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred, struct gss_cred, gc_base); struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall; - struct inode *inode = &gss_msg->inode->vfs_inode; + struct rpc_pipe *pipe = gss_msg->pipe; - spin_lock(&inode->i_lock); - if (gss_msg->ctx) - gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx); - else - task->tk_status = gss_msg->msg.errno; - gss_cred->gc_upcall = NULL; - rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); - spin_unlock(&inode->i_lock); + spin_lock(&pipe->lock); + gss_handle_downcall_result(gss_cred, gss_msg); + spin_unlock(&pipe->lock); + task->tk_status = gss_msg->msg.errno; gss_release_msg(gss_msg); } static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg) { - gss_msg->msg.data = &gss_msg->uid; - gss_msg->msg.len = sizeof(gss_msg->uid); + uid_t uid = from_kuid(&init_user_ns, gss_msg->uid); + memcpy(gss_msg->databuf, &uid, sizeof(uid)); + gss_msg->msg.data = gss_msg->databuf; + gss_msg->msg.len = sizeof(uid); + + BUILD_BUG_ON(sizeof(uid) > sizeof(gss_msg->databuf)); } -static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, - struct rpc_clnt *clnt, int machine_cred) +static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, + const char *service_name, + const char *target_name) { + struct gss_api_mech *mech = gss_msg->auth->mech; char *p = gss_msg->databuf; - int len = 0; - - gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ", - gss_msg->auth->mech->gm_name, - gss_msg->uid); - p += gss_msg->msg.len; - if (clnt->cl_principal) { - len = sprintf(p, "target=%s ", clnt->cl_principal); + size_t buflen = sizeof(gss_msg->databuf); + int len; + + len = scnprintf(p, buflen, "mech=%s uid=%d ", mech->gm_name, + from_kuid(&init_user_ns, gss_msg->uid)); + buflen -= len; + p += len; + gss_msg->msg.len = len; + if (target_name) { + len = scnprintf(p, buflen, "target=%s ", target_name); + buflen -= len; p += len; gss_msg->msg.len += len; } - if (machine_cred) { - len = sprintf(p, "service=* "); + if (service_name != NULL) { + len = scnprintf(p, buflen, "service=%s ", service_name); + buflen -= len; p += len; gss_msg->msg.len += len; - } else if (!strcmp(clnt->cl_program->name, "nfs4_cb")) { - len = sprintf(p, "service=nfs "); + } + if (mech->gm_upcall_enctypes) { + len = scnprintf(p, buflen, "enctypes=%s ", + mech->gm_upcall_enctypes); + buflen -= len; p += len; gss_msg->msg.len += len; } - len = sprintf(p, "\n"); + len = scnprintf(p, buflen, "\n"); + if (len == 0) + goto out_overflow; gss_msg->msg.len += len; gss_msg->msg.data = gss_msg->databuf; - BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN); -} - -static void gss_encode_msg(struct gss_upcall_msg *gss_msg, - struct rpc_clnt *clnt, int machine_cred) -{ - if (pipe_version == 0) - gss_encode_v0_msg(gss_msg); - else /* pipe_version == 1 */ - gss_encode_v1_msg(gss_msg, clnt, machine_cred); + return 0; 
+out_overflow: + WARN_ON_ONCE(1); + return -ENOMEM; } -static inline struct gss_upcall_msg * -gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid, struct rpc_clnt *clnt, - int machine_cred) +static struct gss_upcall_msg * +gss_alloc_msg(struct gss_auth *gss_auth, + kuid_t uid, const char *service_name) { struct gss_upcall_msg *gss_msg; int vers; + int err = -ENOMEM; gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS); if (gss_msg == NULL) - return ERR_PTR(-ENOMEM); - vers = get_pipe_version(); - if (vers < 0) { - kfree(gss_msg); - return ERR_PTR(vers); - } - gss_msg->inode = RPC_I(gss_auth->dentry[vers]->d_inode); + goto err; + vers = get_pipe_version(gss_auth->net); + err = vers; + if (err < 0) + goto err_free_msg; + gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe; INIT_LIST_HEAD(&gss_msg->list); rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq"); init_waitqueue_head(&gss_msg->waitqueue); atomic_set(&gss_msg->count, 1); gss_msg->uid = uid; gss_msg->auth = gss_auth; - gss_encode_msg(gss_msg, clnt, machine_cred); + switch (vers) { + case 0: + gss_encode_v0_msg(gss_msg); + break; + default: + err = gss_encode_v1_msg(gss_msg, service_name, gss_auth->target_name); + if (err) + goto err_put_pipe_version; + }; + kref_get(&gss_auth->kref); return gss_msg; +err_put_pipe_version: + put_pipe_version(gss_auth->net); +err_free_msg: + kfree(gss_msg); +err: + return ERR_PTR(err); } static struct gss_upcall_msg * -gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred) +gss_setup_upcall(struct gss_auth *gss_auth, struct rpc_cred *cred) { struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_upcall_msg *gss_new, *gss_msg; - uid_t uid = cred->cr_uid; + kuid_t uid = cred->cr_uid; - gss_new = gss_alloc_msg(gss_auth, uid, clnt, gss_cred->gc_machine_cred); + gss_new = gss_alloc_msg(gss_auth, uid, gss_cred->gc_principal); if (IS_ERR(gss_new)) return gss_new; gss_msg = gss_add_msg(gss_new); if (gss_msg == gss_new) { - struct inode *inode = &gss_new->inode->vfs_inode; - int res = rpc_queue_upcall(inode, &gss_new->msg); + int res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg); if (res) { gss_unhash_msg(gss_new); gss_msg = ERR_PTR(res); @@ -466,101 +537,104 @@ gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cr static void warn_gssd(void) { - static unsigned long ratelimit; - unsigned long now = jiffies; - - if (time_after(now, ratelimit)) { - printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n" - "Please check user daemon is running.\n"); - ratelimit = now + 15*HZ; - } + dprintk("AUTH_GSS upcall failed. Please check user daemon is running.\n"); } static inline int gss_refresh_upcall(struct rpc_task *task) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_upcall_msg *gss_msg; - struct inode *inode; + struct rpc_pipe *pipe; int err = 0; - dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid, - cred->cr_uid); - gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred); + dprintk("RPC: %5u %s for uid %u\n", + task->tk_pid, __func__, from_kuid(&init_user_ns, cred->cr_uid)); + gss_msg = gss_setup_upcall(gss_auth, cred); if (PTR_ERR(gss_msg) == -EAGAIN) { /* XXX: warning on the first, under the assumption we * shouldn't normally hit this case on a refresh. 
*/ warn_gssd(); task->tk_timeout = 15*HZ; rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL); - return 0; + return -EAGAIN; } if (IS_ERR(gss_msg)) { err = PTR_ERR(gss_msg); goto out; } - inode = &gss_msg->inode->vfs_inode; - spin_lock(&inode->i_lock); + pipe = gss_msg->pipe; + spin_lock(&pipe->lock); if (gss_cred->gc_upcall != NULL) rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL); - else if (gss_msg->ctx != NULL) { - gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx); - gss_cred->gc_upcall = NULL; - rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); - } else if (gss_msg->msg.errno >= 0) { + else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) { task->tk_timeout = 0; gss_cred->gc_upcall = gss_msg; /* gss_upcall_callback will release the reference to gss_upcall_msg */ atomic_inc(&gss_msg->count); rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback); - } else + } else { + gss_handle_downcall_result(gss_cred, gss_msg); err = gss_msg->msg.errno; - spin_unlock(&inode->i_lock); + } + spin_unlock(&pipe->lock); gss_release_msg(gss_msg); out: - dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n", - task->tk_pid, cred->cr_uid, err); + dprintk("RPC: %5u %s for uid %u result %d\n", + task->tk_pid, __func__, + from_kuid(&init_user_ns, cred->cr_uid), err); return err; } static inline int gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred) { - struct inode *inode; + struct net *net = gss_auth->net; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe *pipe; struct rpc_cred *cred = &gss_cred->gc_base; struct gss_upcall_msg *gss_msg; DEFINE_WAIT(wait); - int err = 0; + int err; - dprintk("RPC: gss_upcall for uid %u\n", cred->cr_uid); + dprintk("RPC: %s for uid %u\n", + __func__, from_kuid(&init_user_ns, cred->cr_uid)); retry: - gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred); + err = 0; + /* if gssd is down, just skip upcalling altogether */ + if (!gssd_running(net)) { + warn_gssd(); + return -EACCES; + } + gss_msg = gss_setup_upcall(gss_auth, cred); if (PTR_ERR(gss_msg) == -EAGAIN) { err = wait_event_interruptible_timeout(pipe_version_waitqueue, - pipe_version >= 0, 15*HZ); - if (err) - goto out; - if (pipe_version < 0) + sn->pipe_version >= 0, 15 * HZ); + if (sn->pipe_version < 0) { warn_gssd(); + err = -EACCES; + } + if (err < 0) + goto out; goto retry; } if (IS_ERR(gss_msg)) { err = PTR_ERR(gss_msg); goto out; } - inode = &gss_msg->inode->vfs_inode; + pipe = gss_msg->pipe; for (;;) { - prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE); - spin_lock(&inode->i_lock); + prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE); + spin_lock(&pipe->lock); if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) { break; } - spin_unlock(&inode->i_lock); - if (signalled()) { + spin_unlock(&pipe->lock); + if (fatal_signal_pending(current)) { err = -ERESTARTSYS; goto out_intr; } @@ -570,36 +644,16 @@ retry: gss_cred_set_ctx(cred, gss_msg->ctx); else err = gss_msg->msg.errno; - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); out_intr: finish_wait(&gss_msg->waitqueue, &wait); gss_release_msg(gss_msg); out: - dprintk("RPC: gss_create_upcall for uid %u result %d\n", - cred->cr_uid, err); + dprintk("RPC: %s for uid %u result %d\n", + __func__, from_kuid(&init_user_ns, cred->cr_uid), err); return err; } -static ssize_t -gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg, - char __user *dst, size_t buflen) -{ - char *data = (char *)msg->data + msg->copied; - size_t mlen = 
min(msg->len, buflen); - unsigned long left; - - left = copy_to_user(dst, data, mlen); - if (left == mlen) { - msg->errno = -EFAULT; - return -EFAULT; - } - - mlen -= left; - msg->copied += mlen; - msg->errno = 0; - return mlen; -} - #define MSG_BUF_MAXSIZE 1024 static ssize_t @@ -608,9 +662,10 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) const void *p, *end; void *buf; struct gss_upcall_msg *gss_msg; - struct inode *inode = filp->f_path.dentry->d_inode; + struct rpc_pipe *pipe = RPC_I(file_inode(filp))->pipe; struct gss_cl_ctx *ctx; - uid_t uid; + uid_t id; + kuid_t uid; ssize_t err = -EFBIG; if (mlen > MSG_BUF_MAXSIZE) @@ -625,12 +680,18 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) goto err; end = (const void *)((char *)buf + mlen); - p = simple_get_bytes(buf, end, &uid, sizeof(uid)); + p = simple_get_bytes(buf, end, &id, sizeof(id)); if (IS_ERR(p)) { err = PTR_ERR(p); goto err; } + uid = make_kuid(&init_user_ns, id); + if (!uid_valid(uid)) { + err = -EINVAL; + goto err; + } + err = -ENOMEM; ctx = gss_alloc_context(); if (ctx == NULL) @@ -638,14 +699,14 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) err = -ENOENT; /* Find a matching upcall */ - spin_lock(&inode->i_lock); - gss_msg = __gss_find_upcall(RPC_I(inode), uid); + spin_lock(&pipe->lock); + gss_msg = __gss_find_upcall(pipe, uid); if (gss_msg == NULL) { - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); goto err_put_ctx; } list_del_init(&gss_msg->list); - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); p = gss_fill_context(p, end, ctx, gss_msg->auth->mech); if (IS_ERR(p)) { @@ -673,35 +734,37 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) err = mlen; err_release_msg: - spin_lock(&inode->i_lock); + spin_lock(&pipe->lock); __gss_unhash_msg(gss_msg); - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); gss_release_msg(gss_msg); err_put_ctx: gss_put_ctx(ctx); err: kfree(buf); out: - dprintk("RPC: gss_pipe_downcall returning %Zd\n", err); + dprintk("RPC: %s returning %Zd\n", __func__, err); return err; } static int gss_pipe_open(struct inode *inode, int new_version) { + struct net *net = inode->i_sb->s_fs_info; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); int ret = 0; spin_lock(&pipe_version_lock); - if (pipe_version < 0) { + if (sn->pipe_version < 0) { /* First open of any gss pipe determines the version: */ - pipe_version = new_version; + sn->pipe_version = new_version; rpc_wake_up(&pipe_version_rpc_waitqueue); wake_up(&pipe_version_waitqueue); - } else if (pipe_version != new_version) { + } else if (sn->pipe_version != new_version) { /* Trying to open a pipe of a different version */ ret = -EBUSY; goto out; } - atomic_inc(&pipe_users); + atomic_inc(&sn->pipe_users); out: spin_unlock(&pipe_version_lock); return ret; @@ -721,24 +784,26 @@ static int gss_pipe_open_v1(struct inode *inode) static void gss_pipe_release(struct inode *inode) { - struct rpc_inode *rpci = RPC_I(inode); + struct net *net = inode->i_sb->s_fs_info; + struct rpc_pipe *pipe = RPC_I(inode)->pipe; struct gss_upcall_msg *gss_msg; - spin_lock(&inode->i_lock); - while (!list_empty(&rpci->in_downcall)) { +restart: + spin_lock(&pipe->lock); + list_for_each_entry(gss_msg, &pipe->in_downcall, list) { - gss_msg = list_entry(rpci->in_downcall.next, - struct gss_upcall_msg, list); + if (!list_empty(&gss_msg->msg.list)) + continue; gss_msg->msg.errno = -EPIPE; atomic_inc(&gss_msg->count); __gss_unhash_msg(gss_msg); - 
spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); gss_release_msg(gss_msg); - spin_lock(&inode->i_lock); + goto restart; } - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); - put_pipe_version(); + put_pipe_version(net); } static void @@ -747,8 +812,8 @@ gss_pipe_destroy_msg(struct rpc_pipe_msg *msg) struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg); if (msg->errno < 0) { - dprintk("RPC: gss_pipe_destroy_msg releasing msg %p\n", - gss_msg); + dprintk("RPC: %s releasing msg %p\n", + __func__, gss_msg); atomic_inc(&gss_msg->count); gss_unhash_msg(gss_msg); if (msg->errno == -ETIMEDOUT) @@ -757,14 +822,153 @@ gss_pipe_destroy_msg(struct rpc_pipe_msg *msg) } } +static void gss_pipe_dentry_destroy(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) +{ + struct gss_pipe *gss_pipe = pdo->pdo_data; + struct rpc_pipe *pipe = gss_pipe->pipe; + + if (pipe->dentry != NULL) { + rpc_unlink(pipe->dentry); + pipe->dentry = NULL; + } +} + +static int gss_pipe_dentry_create(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) +{ + struct gss_pipe *p = pdo->pdo_data; + struct dentry *dentry; + + dentry = rpc_mkpipe_dentry(dir, p->name, p->clnt, p->pipe); + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + p->pipe->dentry = dentry; + return 0; +} + +static const struct rpc_pipe_dir_object_ops gss_pipe_dir_object_ops = { + .create = gss_pipe_dentry_create, + .destroy = gss_pipe_dentry_destroy, +}; + +static struct gss_pipe *gss_pipe_alloc(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct gss_pipe *p; + int err = -ENOMEM; + + p = kmalloc(sizeof(*p), GFP_KERNEL); + if (p == NULL) + goto err; + p->pipe = rpc_mkpipe_data(upcall_ops, RPC_PIPE_WAIT_FOR_OPEN); + if (IS_ERR(p->pipe)) { + err = PTR_ERR(p->pipe); + goto err_free_gss_pipe; + } + p->name = name; + p->clnt = clnt; + kref_init(&p->kref); + rpc_init_pipe_dir_object(&p->pdo, + &gss_pipe_dir_object_ops, + p); + return p; +err_free_gss_pipe: + kfree(p); +err: + return ERR_PTR(err); +} + +struct gss_alloc_pdo { + struct rpc_clnt *clnt; + const char *name; + const struct rpc_pipe_ops *upcall_ops; +}; + +static int gss_pipe_match_pdo(struct rpc_pipe_dir_object *pdo, void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + if (pdo->pdo_ops != &gss_pipe_dir_object_ops) + return 0; + gss_pipe = container_of(pdo, struct gss_pipe, pdo); + if (strcmp(gss_pipe->name, args->name) != 0) + return 0; + if (!kref_get_unless_zero(&gss_pipe->kref)) + return 0; + return 1; +} + +static struct rpc_pipe_dir_object *gss_pipe_alloc_pdo(void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + gss_pipe = gss_pipe_alloc(args->clnt, args->name, args->upcall_ops); + if (!IS_ERR(gss_pipe)) + return &gss_pipe->pdo; + return NULL; +} + +static struct gss_pipe *gss_pipe_get(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct net *net = rpc_net_ns(clnt); + struct rpc_pipe_dir_object *pdo; + struct gss_alloc_pdo args = { + .clnt = clnt, + .name = name, + .upcall_ops = upcall_ops, + }; + + pdo = rpc_find_or_alloc_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + gss_pipe_match_pdo, + gss_pipe_alloc_pdo, + &args); + if (pdo != NULL) + return container_of(pdo, struct gss_pipe, pdo); + return ERR_PTR(-ENOMEM); +} + +static void __gss_pipe_free(struct gss_pipe *p) +{ + struct rpc_clnt *clnt = p->clnt; + struct net *net = rpc_net_ns(clnt); + + rpc_remove_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + 
&p->pdo); + rpc_destroy_pipe_data(p->pipe); + kfree(p); +} + +static void __gss_pipe_release(struct kref *kref) +{ + struct gss_pipe *p = container_of(kref, struct gss_pipe, kref); + + __gss_pipe_free(p); +} + +static void gss_pipe_free(struct gss_pipe *p) +{ + if (p != NULL) + kref_put(&p->kref, __gss_pipe_release); +} + /* * NOTE: we have the opportunity to use different * parameters based on the input flavor (which must be a pseudoflavor) */ -static struct rpc_auth * -gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +static struct gss_auth * +gss_create_new(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { + rpc_authflavor_t flavor = args->pseudoflavor; struct gss_auth *gss_auth; + struct gss_pipe *gss_pipe; struct rpc_auth * auth; int err = -ENOMEM; /* XXX? */ @@ -774,17 +978,26 @@ gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) return ERR_PTR(err); if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL))) goto out_dec; + INIT_HLIST_NODE(&gss_auth->hash); + gss_auth->target_name = NULL; + if (args->target_name) { + gss_auth->target_name = kstrdup(args->target_name, GFP_KERNEL); + if (gss_auth->target_name == NULL) + goto err_free; + } gss_auth->client = clnt; + gss_auth->net = get_net(rpc_net_ns(clnt)); err = -EINVAL; gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor); if (!gss_auth->mech) { - printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n", - __func__, flavor); - goto err_free; + dprintk("RPC: Pseudoflavor %d not found!\n", flavor); + goto err_put_net; } gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor); if (gss_auth->service == 0) goto err_put_mech; + if (!gssd_running(gss_auth->net)) + goto err_put_mech; auth = &gss_auth->rpc_auth; auth->au_cslack = GSS_CRED_SLACK >> 2; auth->au_rslack = GSS_VERF_SLACK >> 2; @@ -793,41 +1006,41 @@ gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) atomic_set(&auth->au_count, 1); kref_init(&gss_auth->kref); + err = rpcauth_init_credcache(auth); + if (err) + goto err_put_mech; /* * Note: if we created the old pipe first, then someone who * examined the directory at the right moment might conclude * that we supported only the old pipe. So we instead create * the new pipe first. 
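An aside on the lookup idiom above: gss_pipe_match_pdo() takes its reference with kref_get_unless_zero() rather than kref_get(), because a matching gss_pipe may have just dropped its last reference and be on its way to __gss_pipe_free(); bumping a zero kref would resurrect a dying object. A minimal self-contained sketch of the same find-or-alloc pattern (generic names, not the sunrpc API):

	#include <linux/kref.h>
	#include <linux/list.h>
	#include <linux/spinlock.h>

	struct obj {
		struct list_head node;
		struct kref kref;
		int key;
	};

	static LIST_HEAD(obj_list);
	static DEFINE_SPINLOCK(obj_lock);

	/* Return an existing live object with a fresh reference, or
	 * publish the caller-allocated 'new' under the lock. */
	static struct obj *obj_find_or_add(struct obj *new, int key)
	{
		struct obj *o;

		spin_lock(&obj_lock);
		list_for_each_entry(o, &obj_list, node) {
			if (o->key != key)
				continue;
			/* skip objects whose last ref is already gone */
			if (!kref_get_unless_zero(&o->kref))
				continue;
			spin_unlock(&obj_lock);
			return o;
		}
		if (new)
			list_add(&new->node, &obj_list);
		spin_unlock(&obj_lock);
		return new;
	}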
*/ - gss_auth->dentry[1] = rpc_mkpipe(clnt->cl_path.dentry, - "gssd", - clnt, &gss_upcall_ops_v1, - RPC_PIPE_WAIT_FOR_OPEN); - if (IS_ERR(gss_auth->dentry[1])) { - err = PTR_ERR(gss_auth->dentry[1]); - goto err_put_mech; + gss_pipe = gss_pipe_get(clnt, "gssd", &gss_upcall_ops_v1); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); + goto err_destroy_credcache; } + gss_auth->gss_pipe[1] = gss_pipe; - gss_auth->dentry[0] = rpc_mkpipe(clnt->cl_path.dentry, - gss_auth->mech->gm_name, - clnt, &gss_upcall_ops_v0, - RPC_PIPE_WAIT_FOR_OPEN); - if (IS_ERR(gss_auth->dentry[0])) { - err = PTR_ERR(gss_auth->dentry[0]); - goto err_unlink_pipe_1; + gss_pipe = gss_pipe_get(clnt, gss_auth->mech->gm_name, + &gss_upcall_ops_v0); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); + goto err_destroy_pipe_1; } - err = rpcauth_init_credcache(auth); - if (err) - goto err_unlink_pipe_0; + gss_auth->gss_pipe[0] = gss_pipe; - return auth; -err_unlink_pipe_0: - rpc_unlink(gss_auth->dentry[0]); -err_unlink_pipe_1: - rpc_unlink(gss_auth->dentry[1]); + return gss_auth; +err_destroy_pipe_1: + gss_pipe_free(gss_auth->gss_pipe[1]); +err_destroy_credcache: + rpcauth_destroy_credcache(auth); err_put_mech: gss_mech_put(gss_auth->mech); +err_put_net: + put_net(gss_auth->net); err_free: + kfree(gss_auth->target_name); kfree(gss_auth); out_dec: module_put(THIS_MODULE); @@ -837,9 +1050,11 @@ out_dec: static void gss_free(struct gss_auth *gss_auth) { - rpc_unlink(gss_auth->dentry[1]); - rpc_unlink(gss_auth->dentry[0]); + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_pipe_free(gss_auth->gss_pipe[1]); gss_mech_put(gss_auth->mech); + put_net(gss_auth->net); + kfree(gss_auth->target_name); kfree(gss_auth); module_put(THIS_MODULE); @@ -854,17 +1069,118 @@ gss_free_callback(struct kref *kref) } static void +gss_put_auth(struct gss_auth *gss_auth) +{ + kref_put(&gss_auth->kref, gss_free_callback); +} + +static void gss_destroy(struct rpc_auth *auth) { - struct gss_auth *gss_auth; + struct gss_auth *gss_auth = container_of(auth, + struct gss_auth, rpc_auth); dprintk("RPC: destroying GSS authenticator %p flavor %d\n", auth, auth->au_flavor); + if (hash_hashed(&gss_auth->hash)) { + spin_lock(&gss_auth_hash_lock); + hash_del(&gss_auth->hash); + spin_unlock(&gss_auth_hash_lock); + } + + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_auth->gss_pipe[0] = NULL; + gss_pipe_free(gss_auth->gss_pipe[1]); + gss_auth->gss_pipe[1] = NULL; rpcauth_destroy_credcache(auth); - gss_auth = container_of(auth, struct gss_auth, rpc_auth); - kref_put(&gss_auth->kref, gss_free_callback); + gss_put_auth(gss_auth); +} + +/* + * Auths may be shared between rpc clients that were cloned from a + * common client with the same xprt, if they also share the flavor and + * target_name. + * + * The auth is looked up from the oldest parent sharing the same + * cl_xprt, and the auth itself references only that common parent + * (which is guaranteed to last as long as any of its descendants). 
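To make that sharing rule concrete, a hedged usage sketch (error handling elided; it assumes rpc_clone_client() preserves cl_xprt and links cl_parent, as the comment describes):

	static void demo_shared_gss_auth(struct rpc_clnt *parent)
	{
		struct rpc_auth_create_args args = {
			.pseudoflavor = RPC_AUTH_GSS_KRB5,
			.target_name  = NULL,
		};
		struct rpc_clnt *child = rpc_clone_client(parent);
		struct rpc_auth *a1 = rpcauth_create(&args, parent);
		struct rpc_auth *a2 = rpcauth_create(&args, child);

		/* gss_create() walks child->cl_parent back to 'parent'
		 * over the shared xprt, so a2 == a1 here and the second
		 * call only bumped au_count. */
	}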
+ */ +static struct gss_auth * +gss_auth_find_or_add_hashed(struct rpc_auth_create_args *args, + struct rpc_clnt *clnt, + struct gss_auth *new) +{ + struct gss_auth *gss_auth; + unsigned long hashval = (unsigned long)clnt; + + spin_lock(&gss_auth_hash_lock); + hash_for_each_possible(gss_auth_hash_table, + gss_auth, + hash, + hashval) { + if (gss_auth->client != clnt) + continue; + if (gss_auth->rpc_auth.au_flavor != args->pseudoflavor) + continue; + if (gss_auth->target_name != args->target_name) { + if (gss_auth->target_name == NULL) + continue; + if (args->target_name == NULL) + continue; + if (strcmp(gss_auth->target_name, args->target_name)) + continue; + } + if (!atomic_inc_not_zero(&gss_auth->rpc_auth.au_count)) + continue; + goto out; + } + if (new) + hash_add(gss_auth_hash_table, &new->hash, hashval); + gss_auth = new; +out: + spin_unlock(&gss_auth_hash_lock); + return gss_auth; +} + +static struct gss_auth * +gss_create_hashed(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct gss_auth *new; + + gss_auth = gss_auth_find_or_add_hashed(args, clnt, NULL); + if (gss_auth != NULL) + goto out; + new = gss_create_new(args, clnt); + if (IS_ERR(new)) + return new; + gss_auth = gss_auth_find_or_add_hashed(args, clnt, new); + if (gss_auth != new) + gss_destroy(&new->rpc_auth); +out: + return gss_auth; +} + +static struct rpc_auth * +gss_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct rpc_xprt *xprt = rcu_access_pointer(clnt->cl_xprt); + + while (clnt != clnt->cl_parent) { + struct rpc_clnt *parent = clnt->cl_parent; + /* Find the original parent for this transport */ + if (rcu_access_pointer(parent->cl_xprt) != xprt) + break; + clnt = parent; + } + + gss_auth = gss_create_hashed(args, clnt); + if (IS_ERR(gss_auth)) + return ERR_CAST(gss_auth); + return &gss_auth->rpc_auth; } /* @@ -905,8 +1221,9 @@ gss_destroying_context(struct rpc_cred *cred) static void gss_do_free_ctx(struct gss_cl_ctx *ctx) { - dprintk("RPC: gss_free_ctx\n"); + dprintk("RPC: %s\n", __func__); + gss_delete_sec_context(&ctx->gc_gss_ctx); kfree(ctx->gc_wire_ctx.data); kfree(ctx); } @@ -921,19 +1238,13 @@ gss_free_ctx_callback(struct rcu_head *head) static void gss_free_ctx(struct gss_cl_ctx *ctx) { - struct gss_ctx *gc_gss_ctx; - - gc_gss_ctx = rcu_dereference(ctx->gc_gss_ctx); - rcu_assign_pointer(ctx->gc_gss_ctx, NULL); call_rcu(&ctx->gc_rcu, gss_free_ctx_callback); - if (gc_gss_ctx) - gss_delete_sec_context(&gc_gss_ctx); } static void gss_free_cred(struct gss_cred *gss_cred) { - dprintk("RPC: gss_free_cred %p\n", gss_cred); + dprintk("RPC: %s cred=%p\n", __func__, gss_cred); kfree(gss_cred); } @@ -951,11 +1262,11 @@ gss_destroy_nullcred(struct rpc_cred *cred) struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); struct gss_cl_ctx *ctx = gss_cred->gc_ctx; - rcu_assign_pointer(gss_cred->gc_ctx, NULL); + RCU_INIT_POINTER(gss_cred->gc_ctx, NULL); call_rcu(&cred->cr_rcu, gss_free_cred_callback); if (ctx) gss_put_ctx(ctx); - kref_put(&gss_auth->kref, gss_free_callback); + gss_put_auth(gss_auth); } static void @@ -983,8 +1294,9 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) struct gss_cred *cred = NULL; int err = -ENOMEM; - dprintk("RPC: gss_create_cred for uid %d, flavor %d\n", - acred->uid, auth->au_flavor); + dprintk("RPC: %s for uid %d, flavor %d\n", + __func__, from_kuid(&init_user_ns, acred->uid), + auth->au_flavor); if (!(cred = kzalloc(sizeof(*cred), 
GFP_NOFS))) goto out_err; @@ -996,12 +1308,14 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) */ cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW; cred->gc_service = gss_auth->service; - cred->gc_machine_cred = acred->machine_cred; + cred->gc_principal = NULL; + if (acred->machine_cred) + cred->gc_principal = acred->principal; kref_get(&gss_auth->kref); return &cred->gc_base; out_err: - dprintk("RPC: gss_create_cred failed with error %d\n", err); + dprintk("RPC: %s failed with error %d\n", __func__, err); return ERR_PTR(err); } @@ -1018,10 +1332,32 @@ gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred) return err; } +/* + * Returns -EACCES if GSS context is NULL or will expire within the + * timeout (miliseconds) + */ +static int +gss_key_timeout(struct rpc_cred *rc) +{ + struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + unsigned long now = jiffies; + unsigned long expire; + + if (gss_cred->gc_ctx == NULL) + return -EACCES; + + expire = gss_cred->gc_ctx->gc_expiry - (gss_key_expire_timeo * HZ); + + if (time_after(now, expire)) + return -EACCES; + return 0; +} + static int gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags) { struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + int ret; if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags)) goto out; @@ -1031,9 +1367,29 @@ gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags) if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags)) return 0; out: - if (acred->machine_cred != gss_cred->gc_machine_cred) + if (acred->principal != NULL) { + if (gss_cred->gc_principal == NULL) + return 0; + ret = strcmp(acred->principal, gss_cred->gc_principal) == 0; + goto check_expire; + } + if (gss_cred->gc_principal != NULL) return 0; - return (rc->cr_uid == acred->uid); + ret = uid_eq(rc->cr_uid, acred->uid); + +check_expire: + if (ret == 0) + return ret; + + /* Notify acred users of GSS context expiration timeout */ + if (test_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags) && + (gss_key_timeout(rc) != 0)) { + /* test will now be done from generic cred */ + test_and_clear_bit(RPC_CRED_NOTIFY_TIMEOUT, &acred->ac_flags); + /* tell NFS layer that key will expire soon */ + set_bit(RPC_CRED_KEY_EXPIRE_SOON, &acred->ac_flags); + } + return ret; } /* @@ -1043,18 +1399,18 @@ out: static __be32 * gss_marshal(struct rpc_task *task, __be32 *p) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *cred = req->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); __be32 *cred_len; - struct rpc_rqst *req = task->tk_rqstp; u32 maj_stat = 0; struct xdr_netobj mic; struct kvec iov; struct xdr_buf verf_buf; - dprintk("RPC: %5u gss_marshal\n", task->tk_pid); + dprintk("RPC: %5u %s\n", task->tk_pid, __func__); *p++ = htonl(RPC_AUTH_GSS); cred_len = p++; @@ -1072,7 +1428,7 @@ gss_marshal(struct rpc_task *task, __be32 *p) /* We compute the checksum for the verifier over the xdr-encoded bytes * starting with the xid and ending at the end of the credential: */ - iov.iov_base = xprt_skip_transport_header(task->tk_xprt, + iov.iov_base = xprt_skip_transport_header(req->rq_xprt, req->rq_snd_buf.head[0].iov_base); iov.iov_len = (u8 *)p - (u8 *)iov.iov_base; xdr_buf_from_iov(&iov, &verf_buf); @@ -1098,40 +1454,61 @@ out_put_ctx: static int gss_renew_cred(struct rpc_task *task) { - struct rpc_cred *oldcred = task->tk_msg.rpc_cred; + struct rpc_cred *oldcred 
= task->tk_rqstp->rq_cred; struct gss_cred *gss_cred = container_of(oldcred, struct gss_cred, gc_base); struct rpc_auth *auth = oldcred->cr_auth; struct auth_cred acred = { .uid = oldcred->cr_uid, - .machine_cred = gss_cred->gc_machine_cred, + .principal = gss_cred->gc_principal, + .machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0), }; struct rpc_cred *new; new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW); if (IS_ERR(new)) return PTR_ERR(new); - task->tk_msg.rpc_cred = new; + task->tk_rqstp->rq_cred = new; put_rpccred(oldcred); return 0; } +static int gss_cred_is_negative_entry(struct rpc_cred *cred) +{ + if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) { + unsigned long now = jiffies; + unsigned long begin, expire; + struct gss_cred *gss_cred; + + gss_cred = container_of(cred, struct gss_cred, gc_base); + begin = gss_cred->gc_upcall_timestamp; + expire = begin + gss_expired_cred_retry_delay * HZ; + + if (time_in_range_open(now, begin, expire)) + return 1; + } + return 0; +} + /* * Refresh credentials. XXX - finish */ static int gss_refresh(struct rpc_task *task) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; int ret = 0; + if (gss_cred_is_negative_entry(cred)) + return -EKEYEXPIRED; + if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) { ret = gss_renew_cred(task); if (ret < 0) goto out; - cred = task->tk_msg.rpc_cred; + cred = task->tk_rqstp->rq_cred; } if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)) @@ -1144,13 +1521,13 @@ out: static int gss_refresh_null(struct rpc_task *task) { - return -EACCES; + return 0; } static __be32 * gss_validate(struct rpc_task *task, __be32 *p) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); __be32 seq; struct kvec iov; @@ -1158,8 +1535,9 @@ gss_validate(struct rpc_task *task, __be32 *p) struct xdr_netobj mic; u32 flav,len; u32 maj_stat; + __be32 *ret = ERR_PTR(-EIO); - dprintk("RPC: %5u gss_validate\n", task->tk_pid); + dprintk("RPC: %5u %s\n", task->tk_pid, __func__); flav = ntohl(*p++); if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE) @@ -1173,30 +1551,42 @@ gss_validate(struct rpc_task *task, __be32 *p) mic.data = (u8 *)p; mic.len = len; + ret = ERR_PTR(-EACCES); maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); if (maj_stat) { - dprintk("RPC: %5u gss_validate: gss_verify_mic returned " - "error 0x%08x\n", task->tk_pid, maj_stat); + dprintk("RPC: %5u %s: gss_verify_mic returned error 0x%08x\n", + task->tk_pid, __func__, maj_stat); goto out_bad; } /* We leave it to unwrap to calculate au_rslack. 
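The verifier-size bookkeeping here is easy to check by hand: au_verfsize = XDR_QUADLEN(len) + 2 counts the XDR words of the opaque MIC plus one word each for the verifier's flavor and length fields. For example, with a 20-byte MIC, XDR_QUADLEN(20) = (20 + 3) / 4 = 5, so au_verfsize = 7 four-byte words, i.e. 28 bytes of reply verifier.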
For now we just * calculate the length of the verifier: */ cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2; gss_put_ctx(ctx); - dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n", - task->tk_pid); + dprintk("RPC: %5u %s: gss_verify_mic succeeded.\n", + task->tk_pid, __func__); return p + XDR_QUADLEN(len); out_bad: gss_put_ctx(ctx); - dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid); - return NULL; + dprintk("RPC: %5u %s failed ret %ld.\n", task->tk_pid, __func__, + PTR_ERR(ret)); + return ret; +} + +static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p); + encode(rqstp, &xdr, obj); } static inline int gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj) + kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) { struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; struct xdr_buf integ_buf; @@ -1212,9 +1602,7 @@ gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; *p++ = htonl(rqstp->rq_seqno); - status = encode(rqstp, p, obj); - if (status) - return status; + gss_wrap_req_encode(encode, rqstp, p, obj); if (xdr_buf_subsegment(snd_buf, &integ_buf, offset, snd_buf->len - offset)) @@ -1280,16 +1668,16 @@ alloc_enc_pages(struct rpc_rqst *rqstp) rqstp->rq_release_snd_buf = priv_release_snd_buf; return 0; out_free: - for (i--; i >= 0; i--) { - __free_page(rqstp->rq_enc_pages[i]); - } + rqstp->rq_enc_pages_num = i; + priv_release_snd_buf(rqstp); out: return -EAGAIN; } static inline int gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj) + kxdreproc_t encode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) { struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; u32 offset; @@ -1306,9 +1694,7 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; *p++ = htonl(rqstp->rq_seqno); - status = encode(rqstp, p, obj); - if (status) - return status; + gss_wrap_req_encode(encode, rqstp, p, obj); status = alloc_enc_pages(rqstp); if (status) @@ -1317,15 +1703,21 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, inpages = snd_buf->pages + first; snd_buf->pages = rqstp->rq_enc_pages; snd_buf->page_base -= first << PAGE_CACHE_SHIFT; - /* Give the tail its own page, in case we need extra space in the - * head when wrapping: */ + /* + * Give the tail its own page, in case we need extra space in the + * head when wrapping: + * + * call_allocate() allocates twice the slack space required + * by the authentication flavor to rq_callsize. + * For GSS, slack is GSS_CRED_SLACK. 
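A quick worked check of the page arithmetic above, assuming 4 KiB pages (PAGE_CACHE_SHIFT = 12) and first computed earlier in this function as snd_buf->page_base >> PAGE_CACHE_SHIFT: with page_base = 9000, first = 2, so the plaintext input starts at the third page (inpages = snd_buf->pages + 2) and page_base is rebased to 9000 - (2 << 12) = 808, the offset within that page; the swapped-in rq_enc_pages then receive the ciphertext at the same alignment.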
+ */ if (snd_buf->page_len || snd_buf->tail[0].iov_len) { tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]); memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len); snd_buf->tail[0].iov_base = tmp; } maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages); - /* RPC_SLACK_SPACE should prevent this ever happening: */ + /* slack space should prevent this ever happening: */ BUG_ON(snd_buf->len > snd_buf->buflen); status = -EIO; /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was @@ -1352,38 +1744,38 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, static int gss_wrap_req(struct rpc_task *task, - kxdrproc_t encode, void *rqstp, __be32 *p, void *obj) + kxdreproc_t encode, void *rqstp, __be32 *p, void *obj) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); int status = -EIO; - dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid); + dprintk("RPC: %5u %s\n", task->tk_pid, __func__); if (ctx->gc_proc != RPC_GSS_PROC_DATA) { /* The spec seems a little ambiguous here, but I think that not * wrapping context destruction requests makes the most sense. */ - status = encode(rqstp, p, obj); + gss_wrap_req_encode(encode, rqstp, p, obj); + status = 0; goto out; } switch (gss_cred->gc_service) { - case RPC_GSS_SVC_NONE: - status = encode(rqstp, p, obj); - break; - case RPC_GSS_SVC_INTEGRITY: - status = gss_wrap_req_integ(cred, ctx, encode, - rqstp, p, obj); - break; - case RPC_GSS_SVC_PRIVACY: - status = gss_wrap_req_priv(cred, ctx, encode, - rqstp, p, obj); - break; + case RPC_GSS_SVC_NONE: + gss_wrap_req_encode(encode, rqstp, p, obj); + status = 0; + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj); + break; + case RPC_GSS_SVC_PRIVACY: + status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj); + break; } out: gss_put_ctx(ctx); - dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status); + dprintk("RPC: %5u %s returning %d\n", task->tk_pid, __func__, status); return status; } @@ -1452,12 +1844,21 @@ gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, return 0; } +static int +gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, + __be32 *p, void *obj) +{ + struct xdr_stream xdr; + + xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p); + return decode(rqstp, &xdr, obj); +} static int gss_unwrap_resp(struct rpc_task *task, - kxdrproc_t decode, void *rqstp, __be32 *p, void *obj) + kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj) { - struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); @@ -1469,28 +1870,28 @@ gss_unwrap_resp(struct rpc_task *task, if (ctx->gc_proc != RPC_GSS_PROC_DATA) goto out_decode; switch (gss_cred->gc_service) { - case RPC_GSS_SVC_NONE: - break; - case RPC_GSS_SVC_INTEGRITY: - status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p); - if (status) - goto out; - break; - case RPC_GSS_SVC_PRIVACY: - status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p); - if (status) - goto out; - break; + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p); + if (status) + goto out; + break; + case RPC_GSS_SVC_PRIVACY: + status = 
gss_unwrap_resp_priv(cred, ctx, rqstp, &p); + if (status) + goto out; + break; } /* take into account extra slack for integrity and privacy cases: */ cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp) + (savedlen - head->iov_len); out_decode: - status = decode(rqstp, p, obj); + status = gss_unwrap_req_decode(decode, rqstp, p, obj); out: gss_put_ctx(ctx); - dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid, - status); + dprintk("RPC: %5u %s returning %d\n", + task->tk_pid, __func__, status); return status; } @@ -1501,7 +1902,10 @@ static const struct rpc_authops authgss_ops = { .create = gss_create, .destroy = gss_destroy, .lookup_cred = gss_lookup_cred, - .crcreate = gss_create_cred + .crcreate = gss_create_cred, + .list_pseudoflavors = gss_mech_list_pseudoflavors, + .info2flavor = gss_mech_info2flavor, + .flavor2info = gss_mech_flavor2info, }; static const struct rpc_credops gss_credops = { @@ -1515,6 +1919,7 @@ static const struct rpc_credops gss_credops = { .crvalidate = gss_validate, .crwrap_req = gss_wrap_req, .crunwrap_resp = gss_unwrap_resp, + .crkey_timeout = gss_key_timeout, }; static const struct rpc_credops gss_nullops = { @@ -1530,7 +1935,7 @@ static const struct rpc_credops gss_nullops = { }; static const struct rpc_pipe_ops gss_upcall_ops_v0 = { - .upcall = gss_pipe_upcall, + .upcall = rpc_pipe_generic_upcall, .downcall = gss_pipe_downcall, .destroy_msg = gss_pipe_destroy_msg, .open_pipe = gss_pipe_open_v0, @@ -1538,13 +1943,28 @@ static const struct rpc_pipe_ops gss_upcall_ops_v0 = { }; static const struct rpc_pipe_ops gss_upcall_ops_v1 = { - .upcall = gss_pipe_upcall, + .upcall = rpc_pipe_generic_upcall, .downcall = gss_pipe_downcall, .destroy_msg = gss_pipe_destroy_msg, .open_pipe = gss_pipe_open_v1, .release_pipe = gss_pipe_release, }; +static __net_init int rpcsec_gss_init_net(struct net *net) +{ + return gss_svc_init_net(net); +} + +static __net_exit void rpcsec_gss_exit_net(struct net *net) +{ + gss_svc_shutdown_net(net); +} + +static struct pernet_operations rpcsec_gss_net_ops = { + .init = rpcsec_gss_init_net, + .exit = rpcsec_gss_exit_net, +}; + /* * Initialize RPCSEC_GSS module */ @@ -1558,8 +1978,13 @@ static int __init init_rpcsec_gss(void) err = gss_svc_init(); if (err) goto out_unregister; + err = register_pernet_subsys(&rpcsec_gss_net_ops); + if (err) + goto out_svc_exit; rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version"); return 0; +out_svc_exit: + gss_svc_shutdown(); out_unregister: rpcauth_unregister(&authgss_ops); out: @@ -1568,11 +1993,26 @@ out: static void __exit exit_rpcsec_gss(void) { + unregister_pernet_subsys(&rpcsec_gss_net_ops); gss_svc_shutdown(); rpcauth_unregister(&authgss_ops); rcu_barrier(); /* Wait for completion of call_rcu()'s */ } +MODULE_ALIAS("rpc-auth-6"); MODULE_LICENSE("GPL"); +module_param_named(expired_cred_retry_delay, + gss_expired_cred_retry_delay, + uint, 0644); +MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until " + "the RPC engine retries an expired credential"); + +module_param_named(key_expire_timeo, + gss_key_expire_timeo, + uint, 0644); +MODULE_PARM_DESC(key_expire_timeo, "Time (in seconds) at the end of a " + "credential keys lifetime where the NFS layer cleans up " + "prior to key expiration"); + module_init(init_rpcsec_gss) module_exit(exit_rpcsec_gss) diff --git a/net/sunrpc/auth_gss/gss_generic_token.c b/net/sunrpc/auth_gss/gss_generic_token.c index c0ba39c4f5f..c586e92bcf7 100644 --- a/net/sunrpc/auth_gss/gss_generic_token.c +++ 
b/net/sunrpc/auth_gss/gss_generic_token.c @@ -33,7 +33,6 @@ #include <linux/types.h> #include <linux/module.h> -#include <linux/slab.h> #include <linux/string.h> #include <linux/sunrpc/sched.h> #include <linux/sunrpc/gss_asn1.h> @@ -77,19 +76,19 @@ static int der_length_size( int length) { if (length < (1<<7)) - return(1); + return 1; else if (length < (1<<8)) - return(2); + return 2; #if (SIZEOF_INT == 2) else - return(3); + return 3; #else else if (length < (1<<16)) - return(3); + return 3; else if (length < (1<<24)) - return(4); + return 4; else - return(5); + return 5; #endif } @@ -122,14 +121,14 @@ der_read_length(unsigned char **buf, int *bufsize) int ret; if (*bufsize < 1) - return(-1); + return -1; sf = *(*buf)++; (*bufsize)--; if (sf & 0x80) { if ((sf &= 0x7f) > ((*bufsize)-1)) - return(-1); + return -1; if (sf > SIZEOF_INT) - return (-1); + return -1; ret = 0; for (; sf; sf--) { ret = (ret<<8) + (*(*buf)++); @@ -139,7 +138,7 @@ der_read_length(unsigned char **buf, int *bufsize) ret = sf; } - return(ret); + return ret; } /* returns the length of a token, given the mech oid and the body size */ @@ -149,7 +148,7 @@ g_token_size(struct xdr_netobj *mech, unsigned int body_size) { /* set body_size to sequence contents size */ body_size += 2 + (int) mech->len; /* NEED overflow check */ - return(1 + der_length_size(body_size) + body_size); + return 1 + der_length_size(body_size) + body_size; } EXPORT_SYMBOL_GPL(g_token_size); @@ -187,27 +186,27 @@ g_verify_token_header(struct xdr_netobj *mech, int *body_size, int ret = 0; if ((toksize-=1) < 0) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; if (*buf++ != 0x60) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; if ((seqsize = der_read_length(&buf, &toksize)) < 0) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; if (seqsize != toksize) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; if ((toksize-=1) < 0) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; if (*buf++ != 0x06) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; if ((toksize-=1) < 0) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; toid.len = *buf++; if ((toksize-=toid.len) < 0) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; toid.data = buf; buf+=toid.len; @@ -218,17 +217,17 @@ g_verify_token_header(struct xdr_netobj *mech, int *body_size, to return G_BAD_TOK_HEADER if the token header is in fact bad */ if ((toksize-=2) < 0) - return(G_BAD_TOK_HEADER); + return G_BAD_TOK_HEADER; if (ret) - return(ret); + return ret; if (!ret) { *buf_in = buf; *body_size = toksize; } - return(ret); + return ret; } EXPORT_SYMBOL_GPL(g_verify_token_header); diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c index c93fca20455..0f43e894bc0 100644 --- a/net/sunrpc/auth_gss/gss_krb5_crypto.c +++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c @@ -1,7 +1,7 @@ /* * linux/net/sunrpc/gss_krb5_crypto.c * - * Copyright (c) 2000 The Regents of the University of Michigan. + * Copyright (c) 2000-2008 The Regents of the University of Michigan. * All rights reserved. 
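The DER length rules behind der_length_size() and der_read_length() in the hunk above, with two worked encodings: a length of 5 uses the short form, the single byte 0x05, so der_length_size() returns 1; a length of 300 needs the long form 0x82 0x01 0x2c (a count byte with bit 7 set saying "two length octets follow", then 300 big-endian), so der_length_size() returns 3, matching its length < (1<<16) branch, and der_read_length() reconstructs (0x01 << 8) + 0x2c = 300.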
* * Andy Adamson <andros@umich.edu> @@ -37,11 +37,11 @@ #include <linux/err.h> #include <linux/types.h> #include <linux/mm.h> -#include <linux/slab.h> #include <linux/scatterlist.h> #include <linux/crypto.h> #include <linux/highmem.h> #include <linux/pagemap.h> +#include <linux/random.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/sunrpc/xdr.h> @@ -59,13 +59,13 @@ krb5_encrypt( { u32 ret = -EINVAL; struct scatterlist sg[1]; - u8 local_iv[16] = {0}; + u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; struct blkcipher_desc desc = { .tfm = tfm, .info = local_iv }; if (length % crypto_blkcipher_blocksize(tfm) != 0) goto out; - if (crypto_blkcipher_ivsize(tfm) > 16) { + if (crypto_blkcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { dprintk("RPC: gss_k5encrypt: tfm iv size too large %d\n", crypto_blkcipher_ivsize(tfm)); goto out; @@ -93,13 +93,13 @@ krb5_decrypt( { u32 ret = -EINVAL; struct scatterlist sg[1]; - u8 local_iv[16] = {0}; + u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; struct blkcipher_desc desc = { .tfm = tfm, .info = local_iv }; if (length % crypto_blkcipher_blocksize(tfm) != 0) goto out; - if (crypto_blkcipher_ivsize(tfm) > 16) { + if (crypto_blkcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n", crypto_blkcipher_ivsize(tfm)); goto out; @@ -124,21 +124,155 @@ checksummer(struct scatterlist *sg, void *data) return crypto_hash_update(desc, sg, sg->length); } -/* checksum the plaintext data and hdrlen bytes of the token header */ -s32 -make_checksum(char *cksumname, char *header, int hdrlen, struct xdr_buf *body, - int body_offset, struct xdr_netobj *cksum) +static int +arcfour_hmac_md5_usage_to_salt(unsigned int usage, u8 salt[4]) +{ + unsigned int ms_usage; + + switch (usage) { + case KG_USAGE_SIGN: + ms_usage = 15; + break; + case KG_USAGE_SEAL: + ms_usage = 13; + break; + default: + return -EINVAL; + } + salt[0] = (ms_usage >> 0) & 0xff; + salt[1] = (ms_usage >> 8) & 0xff; + salt[2] = (ms_usage >> 16) & 0xff; + salt[3] = (ms_usage >> 24) & 0xff; + + return 0; +} + +static u32 +make_checksum_hmac_md5(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) +{ + struct hash_desc desc; + struct scatterlist sg[1]; + int err; + u8 checksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + u8 rc4salt[4]; + struct crypto_hash *md5; + struct crypto_hash *hmac_md5; + + if (cksumkey == NULL) + return GSS_S_FAILURE; + + if (cksumout->len < kctx->gk5e->cksumlength) { + dprintk("%s: checksum buffer length, %u, too small for %s\n", + __func__, cksumout->len, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + if (arcfour_hmac_md5_usage_to_salt(usage, rc4salt)) { + dprintk("%s: invalid usage value %u\n", __func__, usage); + return GSS_S_FAILURE; + } + + md5 = crypto_alloc_hash("md5", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(md5)) + return GSS_S_FAILURE; + + hmac_md5 = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(hmac_md5)) { + crypto_free_hash(md5); + return GSS_S_FAILURE; + } + + desc.tfm = md5; + desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; + + err = crypto_hash_init(&desc); + if (err) + goto out; + sg_init_one(sg, rc4salt, 4); + err = crypto_hash_update(&desc, sg, 4); + if (err) + goto out; + + sg_init_one(sg, header, hdrlen); + err = crypto_hash_update(&desc, sg, hdrlen); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, &desc); + if (err) + goto out; + err = crypto_hash_final(&desc, 
checksumdata); + if (err) + goto out; + + desc.tfm = hmac_md5; + desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; + + err = crypto_hash_init(&desc); + if (err) + goto out; + err = crypto_hash_setkey(hmac_md5, cksumkey, kctx->gk5e->keylength); + if (err) + goto out; + + sg_init_one(sg, checksumdata, crypto_hash_digestsize(md5)); + err = crypto_hash_digest(&desc, sg, crypto_hash_digestsize(md5), + checksumdata); + if (err) + goto out; + + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + cksumout->len = kctx->gk5e->cksumlength; +out: + crypto_free_hash(md5); + crypto_free_hash(hmac_md5); + return err ? GSS_S_FAILURE : 0; +} + +/* + * checksum the plaintext data and hdrlen bytes of the token header + * The checksum is performed over the first 8 bytes of the + * gss token header and then over the data body + */ +u32 +make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) { - struct hash_desc desc; /* XXX add to ctx? */ + struct hash_desc desc; struct scatterlist sg[1]; int err; + u8 checksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + unsigned int checksumlen; + + if (kctx->gk5e->ctype == CKSUMTYPE_HMAC_MD5_ARCFOUR) + return make_checksum_hmac_md5(kctx, header, hdrlen, + body, body_offset, + cksumkey, usage, cksumout); - desc.tfm = crypto_alloc_hash(cksumname, 0, CRYPTO_ALG_ASYNC); + if (cksumout->len < kctx->gk5e->cksumlength) { + dprintk("%s: checksum buffer length, %u, too small for %s\n", + __func__, cksumout->len, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + desc.tfm = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); if (IS_ERR(desc.tfm)) return GSS_S_FAILURE; - cksum->len = crypto_hash_digestsize(desc.tfm); desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; + checksumlen = crypto_hash_digestsize(desc.tfm); + + if (cksumkey != NULL) { + err = crypto_hash_setkey(desc.tfm, cksumkey, + kctx->gk5e->keylength); + if (err) + goto out; + } + err = crypto_hash_init(&desc); if (err) goto out; @@ -150,15 +284,109 @@ make_checksum(char *cksumname, char *header, int hdrlen, struct xdr_buf *body, checksummer, &desc); if (err) goto out; - err = crypto_hash_final(&desc, cksum->data); + err = crypto_hash_final(&desc, checksumdata); + if (err) + goto out; + switch (kctx->gk5e->ctype) { + case CKSUMTYPE_RSA_MD5: + err = kctx->gk5e->encrypt(kctx->seq, NULL, checksumdata, + checksumdata, checksumlen); + if (err) + goto out; + memcpy(cksumout->data, + checksumdata + checksumlen - kctx->gk5e->cksumlength, + kctx->gk5e->cksumlength); + break; + case CKSUMTYPE_HMAC_SHA1_DES3: + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + break; + default: + BUG(); + break; + } + cksumout->len = kctx->gk5e->cksumlength; +out: + crypto_free_hash(desc.tfm); + return err ? GSS_S_FAILURE : 0; +} + +/* + * checksum the plaintext data and hdrlen bytes of the token header + * Per rfc4121, sec. 4.2.4, the checksum is performed over the data + * body then over the first 16 octets of the MIC token + * Inclusion of the header data in the calculation of the + * checksum is optional. 
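The two checksum helpers differ only in the order the pieces are fed to the hash; compressed to pseudo-C (keys and lengths elided):

	/* make_checksum()    (v1 tokens, RFC 1964 style):
	 *     cksum = H(header[0..hdrlen) || body)
	 * make_checksum_v2() (v2 tokens, RFC 4121 sec. 4.2.4):
	 *     cksum = HMAC(key, body || header[0..16))
	 * v2 hashes the data body first, then appends the first 16
	 * octets of the MIC token header, and may omit the header
	 * entirely (header == NULL), as noted above.
	 */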
+ */ +u32 +make_checksum_v2(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) +{ + struct hash_desc desc; + struct scatterlist sg[1]; + int err; + u8 checksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + unsigned int checksumlen; + + if (kctx->gk5e->keyed_cksum == 0) { + dprintk("%s: expected keyed hash for %s\n", + __func__, kctx->gk5e->name); + return GSS_S_FAILURE; + } + if (cksumkey == NULL) { + dprintk("%s: no key supplied for %s\n", + __func__, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + desc.tfm = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(desc.tfm)) + return GSS_S_FAILURE; + checksumlen = crypto_hash_digestsize(desc.tfm); + desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; + + err = crypto_hash_setkey(desc.tfm, cksumkey, kctx->gk5e->keylength); + if (err) + goto out; + + err = crypto_hash_init(&desc); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, &desc); + if (err) + goto out; + if (header != NULL) { + sg_init_one(sg, header, hdrlen); + err = crypto_hash_update(&desc, sg, hdrlen); + if (err) + goto out; + } + err = crypto_hash_final(&desc, checksumdata); + if (err) + goto out; + + cksumout->len = kctx->gk5e->cksumlength; + + switch (kctx->gk5e->ctype) { + case CKSUMTYPE_HMAC_SHA1_96_AES128: + case CKSUMTYPE_HMAC_SHA1_96_AES256: + /* note that this truncates the hash */ + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + break; + default: + BUG(); + break; + } out: crypto_free_hash(desc.tfm); return err ? GSS_S_FAILURE : 0; } struct encryptor_desc { - u8 iv[8]; /* XXX hard-coded blocksize */ + u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; struct blkcipher_desc desc; int pos; struct xdr_buf *outbuf; @@ -199,7 +427,7 @@ encryptor(struct scatterlist *sg, void *data) desc->fraglen += sg->length; desc->pos += sg->length; - fraglen = thislen & 7; /* XXX hardcoded blocksize */ + fraglen = thislen & (crypto_blkcipher_blocksize(desc->desc.tfm) - 1); thislen -= fraglen; if (thislen == 0) @@ -257,7 +485,7 @@ gss_encrypt_xdr_buf(struct crypto_blkcipher *tfm, struct xdr_buf *buf, } struct decryptor_desc { - u8 iv[8]; /* XXX hard-coded blocksize */ + u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; struct blkcipher_desc desc; struct scatterlist frags[4]; int fragno; @@ -279,7 +507,7 @@ decryptor(struct scatterlist *sg, void *data) desc->fragno++; desc->fraglen += sg->length; - fraglen = thislen & 7; /* XXX hardcoded blocksize */ + fraglen = thislen & (crypto_blkcipher_blocksize(desc->desc.tfm) - 1); thislen -= fraglen; if (thislen == 0) @@ -326,3 +554,440 @@ gss_decrypt_xdr_buf(struct crypto_blkcipher *tfm, struct xdr_buf *buf, return xdr_process_buf(buf, offset, buf->len - offset, decryptor, &desc); } + +/* + * This function makes the assumption that it was ultimately called + * from gss_wrap(). + * + * The client auth_gss code moves any existing tail data into a + * separate page before calling gss_wrap. + * The server svcauth_gss code ensures that both the head and the + * tail have slack space of RPC_MAX_AUTH_SIZE before calling gss_wrap. + * + * Even with that guarantee, this function may be called more than + * once in the processing of gss_wrap(). The best we can do is + * verify at compile-time (see GSS_KRB5_SLACK_CHECK) that the + * largest expected shift will fit within RPC_MAX_AUTH_SIZE. + * At run-time we can verify that a single invocation of this + * function doesn't attempt to use more the RPC_MAX_AUTH_SIZE. 
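What xdr_extend_head() does to the buffer is easiest to see as a before/after picture; a sketch for an AES enctype with a 16-byte confounder:

	/* before:
	 *   head = | rpc + gss headers (base bytes) | payload ....... |
	 * after xdr_extend_head(buf, base, 16):
	 *   head = | rpc + gss headers | 16-byte hole | payload ..... |
	 * memmove() shifts everything from 'base' onward right by
	 * shiftlen, head[0].iov_len and buf->len grow by shiftlen, and
	 * the caller writes the random confounder into the hole.
	 */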
+ */ + +int +xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen) +{ + u8 *p; + + if (shiftlen == 0) + return 0; + + BUILD_BUG_ON(GSS_KRB5_MAX_SLACK_NEEDED > RPC_MAX_AUTH_SIZE); + BUG_ON(shiftlen > RPC_MAX_AUTH_SIZE); + + p = buf->head[0].iov_base + base; + + memmove(p + shiftlen, p, buf->head[0].iov_len - base); + + buf->head[0].iov_len += shiftlen; + buf->len += shiftlen; + + return 0; +} + +static u32 +gss_krb5_cts_crypt(struct crypto_blkcipher *cipher, struct xdr_buf *buf, + u32 offset, u8 *iv, struct page **pages, int encrypt) +{ + u32 ret; + struct scatterlist sg[1]; + struct blkcipher_desc desc = { .tfm = cipher, .info = iv }; + u8 data[GSS_KRB5_MAX_BLOCKSIZE * 2]; + struct page **save_pages; + u32 len = buf->len - offset; + + if (len > ARRAY_SIZE(data)) { + WARN_ON(0); + return -ENOMEM; + } + + /* + * For encryption, we want to read from the cleartext + * page cache pages, and write the encrypted data to + * the supplied xdr_buf pages. + */ + save_pages = buf->pages; + if (encrypt) + buf->pages = pages; + + ret = read_bytes_from_xdr_buf(buf, offset, data, len); + buf->pages = save_pages; + if (ret) + goto out; + + sg_init_one(sg, data, len); + + if (encrypt) + ret = crypto_blkcipher_encrypt_iv(&desc, sg, sg, len); + else + ret = crypto_blkcipher_decrypt_iv(&desc, sg, sg, len); + + if (ret) + goto out; + + ret = write_bytes_to_xdr_buf(buf, offset, data, len); + +out: + return ret; +} + +u32 +gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, int ec, struct page **pages) +{ + u32 err; + struct xdr_netobj hmac; + u8 *cksumkey; + u8 *ecptr; + struct crypto_blkcipher *cipher, *aux_cipher; + int blocksize; + struct page **save_pages; + int nblocks, nbytes; + struct encryptor_desc desc; + u32 cbcbytes; + unsigned int usage; + + if (kctx->initiate) { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + cksumkey = kctx->initiator_integ; + usage = KG_USAGE_INITIATOR_SEAL; + } else { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + cksumkey = kctx->acceptor_integ; + usage = KG_USAGE_ACCEPTOR_SEAL; + } + blocksize = crypto_blkcipher_blocksize(cipher); + + /* hide the gss token header and insert the confounder */ + offset += GSS_KRB5_TOK_HDR_LEN; + if (xdr_extend_head(buf, offset, kctx->gk5e->conflen)) + return GSS_S_FAILURE; + gss_krb5_make_confounder(buf->head[0].iov_base + offset, kctx->gk5e->conflen); + offset -= GSS_KRB5_TOK_HDR_LEN; + + if (buf->tail[0].iov_base != NULL) { + ecptr = buf->tail[0].iov_base + buf->tail[0].iov_len; + } else { + buf->tail[0].iov_base = buf->head[0].iov_base + + buf->head[0].iov_len; + buf->tail[0].iov_len = 0; + ecptr = buf->tail[0].iov_base; + } + + memset(ecptr, 'X', ec); + buf->tail[0].iov_len += ec; + buf->len += ec; + + /* copy plaintext gss token header after filler (if any) */ + memcpy(ecptr + ec, buf->head[0].iov_base + offset, + GSS_KRB5_TOK_HDR_LEN); + buf->tail[0].iov_len += GSS_KRB5_TOK_HDR_LEN; + buf->len += GSS_KRB5_TOK_HDR_LEN; + + /* Do the HMAC */ + hmac.len = GSS_KRB5_MAX_CKSUM_LEN; + hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len; + + /* + * When we are called, pages points to the real page cache + * data -- which we can't go and encrypt! buf->pages points + * to scratch pages which we are going to send off to the + * client/server. Swap in the plaintext pages to calculate + * the hmac. 
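The split between the CBC bulk pass and the final ciphertext-stealing step below is pure arithmetic; a worked case with blocksize = 16: for nbytes = 100, nblocks = (100 + 15) / 16 = 7 and cbcbytes = (7 - 2) * 16 = 80, so 80 bytes go through the aux_cipher CBC path and the last 20 bytes (the final two, partially filled, blocks) are handled by gss_krb5_cts_crypt(), which inherits desc.iv from the CBC pass so the chaining is unbroken.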
+ */ + save_pages = buf->pages; + buf->pages = pages; + + err = make_checksum_v2(kctx, NULL, 0, buf, + offset + GSS_KRB5_TOK_HDR_LEN, + cksumkey, usage, &hmac); + buf->pages = save_pages; + if (err) + return GSS_S_FAILURE; + + nbytes = buf->len - offset - GSS_KRB5_TOK_HDR_LEN; + nblocks = (nbytes + blocksize - 1) / blocksize; + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + if (cbcbytes) { + desc.pos = offset + GSS_KRB5_TOK_HDR_LEN; + desc.fragno = 0; + desc.fraglen = 0; + desc.pages = pages; + desc.outbuf = buf; + desc.desc.info = desc.iv; + desc.desc.flags = 0; + desc.desc.tfm = aux_cipher; + + sg_init_table(desc.infrags, 4); + sg_init_table(desc.outfrags, 4); + + err = xdr_process_buf(buf, offset + GSS_KRB5_TOK_HDR_LEN, + cbcbytes, encryptor, &desc); + if (err) + goto out_err; + } + + /* Make sure IV carries forward from any CBC results. */ + err = gss_krb5_cts_crypt(cipher, buf, + offset + GSS_KRB5_TOK_HDR_LEN + cbcbytes, + desc.iv, pages, 1); + if (err) { + err = GSS_S_FAILURE; + goto out_err; + } + + /* Now update buf to account for HMAC */ + buf->tail[0].iov_len += kctx->gk5e->cksumlength; + buf->len += kctx->gk5e->cksumlength; + +out_err: + if (err) + err = GSS_S_FAILURE; + return err; +} + +u32 +gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, + u32 *headskip, u32 *tailskip) +{ + struct xdr_buf subbuf; + u32 ret = 0; + u8 *cksum_key; + struct crypto_blkcipher *cipher, *aux_cipher; + struct xdr_netobj our_hmac_obj; + u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + int nblocks, blocksize, cbcbytes; + struct decryptor_desc desc; + unsigned int usage; + + if (kctx->initiate) { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + cksum_key = kctx->acceptor_integ; + usage = KG_USAGE_ACCEPTOR_SEAL; + } else { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + cksum_key = kctx->initiator_integ; + usage = KG_USAGE_INITIATOR_SEAL; + } + blocksize = crypto_blkcipher_blocksize(cipher); + + + /* create a segment skipping the header and leaving out the checksum */ + xdr_buf_subsegment(buf, &subbuf, offset + GSS_KRB5_TOK_HDR_LEN, + (buf->len - offset - GSS_KRB5_TOK_HDR_LEN - + kctx->gk5e->cksumlength)); + + nblocks = (subbuf.len + blocksize - 1) / blocksize; + + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + if (cbcbytes) { + desc.fragno = 0; + desc.fraglen = 0; + desc.desc.info = desc.iv; + desc.desc.flags = 0; + desc.desc.tfm = aux_cipher; + + sg_init_table(desc.frags, 4); + + ret = xdr_process_buf(&subbuf, 0, cbcbytes, decryptor, &desc); + if (ret) + goto out_err; + } + + /* Make sure IV carries forward from any CBC results. 
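The decrypt side walks the same layout in reverse; its contract, summarized as comments (GSS_KRB5_TOK_HDR_LEN is 16):

	/* gss_krb5_aes_decrypt(), as implemented below:
	 *   subbuf = buf[offset + 16 .. len - cksumlength)
	 *            (token header skipped, trailing HMAC excluded)
	 *   CBC-decrypt cbcbytes, CTS-decrypt the rest, IV chained
	 *   through desc.iv exactly as on the encrypt side;
	 *   our_hmac = HMAC(key, decrypted subbuf);
	 *   return GSS_S_BAD_SIG unless our_hmac == pkt_hmac.
	 * The caller then strips conflen confounder bytes from the head
	 * (*headskip) and cksumlength bytes from the tail (*tailskip).
	 */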
*/ + ret = gss_krb5_cts_crypt(cipher, &subbuf, cbcbytes, desc.iv, NULL, 0); + if (ret) + goto out_err; + + + /* Calculate our hmac over the plaintext data */ + our_hmac_obj.len = sizeof(our_hmac); + our_hmac_obj.data = our_hmac; + + ret = make_checksum_v2(kctx, NULL, 0, &subbuf, 0, + cksum_key, usage, &our_hmac_obj); + if (ret) + goto out_err; + + /* Get the packet's hmac value */ + ret = read_bytes_from_xdr_buf(buf, buf->len - kctx->gk5e->cksumlength, + pkt_hmac, kctx->gk5e->cksumlength); + if (ret) + goto out_err; + + if (memcmp(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) { + ret = GSS_S_BAD_SIG; + goto out_err; + } + *headskip = kctx->gk5e->conflen; + *tailskip = kctx->gk5e->cksumlength; +out_err: + if (ret && ret != GSS_S_BAD_SIG) + ret = GSS_S_FAILURE; + return ret; +} + +/* + * Compute Kseq given the initial session key and the checksum. + * Set the key of the given cipher. + */ +int +krb5_rc4_setup_seq_key(struct krb5_ctx *kctx, struct crypto_blkcipher *cipher, + unsigned char *cksum) +{ + struct crypto_hash *hmac; + struct hash_desc desc; + struct scatterlist sg[1]; + u8 Kseq[GSS_KRB5_MAX_KEYLEN]; + u32 zeroconstant = 0; + int err; + + dprintk("%s: entered\n", __func__); + + hmac = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(hmac)) { + dprintk("%s: error %ld, allocating hash '%s'\n", + __func__, PTR_ERR(hmac), kctx->gk5e->cksum_name); + return PTR_ERR(hmac); + } + + desc.tfm = hmac; + desc.flags = 0; + + err = crypto_hash_init(&desc); + if (err) + goto out_err; + + /* Compute intermediate Kseq from session key */ + err = crypto_hash_setkey(hmac, kctx->Ksess, kctx->gk5e->keylength); + if (err) + goto out_err; + + sg_init_table(sg, 1); + sg_set_buf(sg, &zeroconstant, 4); + + err = crypto_hash_digest(&desc, sg, 4, Kseq); + if (err) + goto out_err; + + /* Compute final Kseq from the checksum and intermediate Kseq */ + err = crypto_hash_setkey(hmac, Kseq, kctx->gk5e->keylength); + if (err) + goto out_err; + + sg_set_buf(sg, cksum, 8); + + err = crypto_hash_digest(&desc, sg, 8, Kseq); + if (err) + goto out_err; + + err = crypto_blkcipher_setkey(cipher, Kseq, kctx->gk5e->keylength); + if (err) + goto out_err; + + err = 0; + +out_err: + crypto_free_hash(hmac); + dprintk("%s: returning %d\n", __func__, err); + return err; +} + +/* + * Compute Kcrypt given the initial session key and the plaintext seqnum. + * Set the key of cipher kctx->enc. 
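Both RC4 helpers below implement the same two-stage HMAC-MD5 derivation (from the RC4-HMAC specification, RFC 4757); written out:

	/* krb5_rc4_setup_seq_key():
	 *   Kseq  = HMAC-MD5(Ksess, 0x00000000)
	 *   Kseq' = HMAC-MD5(Kseq, checksum[0..8))          sequence key
	 * krb5_rc4_setup_enc_key():
	 *   Kcrypt[i] = Ksess[i] ^ 0xf0
	 *   Kc  = HMAC-MD5(Kcrypt, 0x00000000)
	 *   Kc' = HMAC-MD5(Kc, seqnum as 4 big-endian bytes)    data key
	 */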
+ */ +int +krb5_rc4_setup_enc_key(struct krb5_ctx *kctx, struct crypto_blkcipher *cipher, + s32 seqnum) +{ + struct crypto_hash *hmac; + struct hash_desc desc; + struct scatterlist sg[1]; + u8 Kcrypt[GSS_KRB5_MAX_KEYLEN]; + u8 zeroconstant[4] = {0}; + u8 seqnumarray[4]; + int err, i; + + dprintk("%s: entered, seqnum %u\n", __func__, seqnum); + + hmac = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(hmac)) { + dprintk("%s: error %ld, allocating hash '%s'\n", + __func__, PTR_ERR(hmac), kctx->gk5e->cksum_name); + return PTR_ERR(hmac); + } + + desc.tfm = hmac; + desc.flags = 0; + + err = crypto_hash_init(&desc); + if (err) + goto out_err; + + /* Compute intermediate Kcrypt from session key */ + for (i = 0; i < kctx->gk5e->keylength; i++) + Kcrypt[i] = kctx->Ksess[i] ^ 0xf0; + + err = crypto_hash_setkey(hmac, Kcrypt, kctx->gk5e->keylength); + if (err) + goto out_err; + + sg_init_table(sg, 1); + sg_set_buf(sg, zeroconstant, 4); + + err = crypto_hash_digest(&desc, sg, 4, Kcrypt); + if (err) + goto out_err; + + /* Compute final Kcrypt from the seqnum and intermediate Kcrypt */ + err = crypto_hash_setkey(hmac, Kcrypt, kctx->gk5e->keylength); + if (err) + goto out_err; + + seqnumarray[0] = (unsigned char) ((seqnum >> 24) & 0xff); + seqnumarray[1] = (unsigned char) ((seqnum >> 16) & 0xff); + seqnumarray[2] = (unsigned char) ((seqnum >> 8) & 0xff); + seqnumarray[3] = (unsigned char) ((seqnum >> 0) & 0xff); + + sg_set_buf(sg, seqnumarray, 4); + + err = crypto_hash_digest(&desc, sg, 4, Kcrypt); + if (err) + goto out_err; + + err = crypto_blkcipher_setkey(cipher, Kcrypt, kctx->gk5e->keylength); + if (err) + goto out_err; + + err = 0; + +out_err: + crypto_free_hash(hmac); + dprintk("%s: returning %d\n", __func__, err); + return err; +} + diff --git a/net/sunrpc/auth_gss/gss_krb5_keys.c b/net/sunrpc/auth_gss/gss_krb5_keys.c new file mode 100644 index 00000000000..24589bd2a4b --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_keys.c @@ -0,0 +1,327 @@ +/* + * COPYRIGHT (c) 2008 + * The Regents of the University of Michigan + * ALL RIGHTS RESERVED + * + * Permission is granted to use, copy, create derivative works + * and redistribute this software and such derivative works + * for any purpose, so long as the name of The University of + * Michigan is not used in any advertising or publicity + * pertaining to the use of distribution of this software + * without specific, written prior authorization. If the + * above copyright notice or any other identification of the + * University of Michigan is included in any copy of any + * portion of this software, then the disclaimer below must + * also be included. + * + * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION + * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY + * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF + * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE + * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE + * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR + * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING + * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN + * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGES. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. 
+ * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <linux/err.h> +#include <linux/types.h> +#include <linux/crypto.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/sunrpc/xdr.h> +#include <linux/lcm.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +/* + * This is the n-fold function as described in rfc3961, sec 5.1 + * Taken from MIT Kerberos and modified. + */ + +static void krb5_nfold(u32 inbits, const u8 *in, + u32 outbits, u8 *out) +{ + unsigned long ulcm; + int byte, i, msbit; + + /* the code below is more readable if I make these bytes + instead of bits */ + + inbits >>= 3; + outbits >>= 3; + + /* first compute lcm(n,k) */ + ulcm = lcm(inbits, outbits); + + /* now do the real work */ + + memset(out, 0, outbits); + byte = 0; + + /* this will end up cycling through k lcm(k,n)/k times, which + is correct */ + for (i = ulcm-1; i >= 0; i--) { + /* compute the msbit in k which gets added into this byte */ + msbit = ( + /* first, start with the msbit in the first, + * unrotated byte */ + ((inbits << 3) - 1) + /* then, for each byte, shift to the right + * for each repetition */ + + (((inbits << 3) + 13) * (i/inbits)) + /* last, pick out the correct byte within + * that shifted repetition */ + + ((inbits - (i % inbits)) << 3) + ) % (inbits << 3); + + /* pull out the byte value itself */ + byte += (((in[((inbits - 1) - (msbit >> 3)) % inbits] << 8)| + (in[((inbits) - (msbit >> 3)) % inbits])) + >> ((msbit & 7) + 1)) & 0xff; + + /* do the addition */ + byte += out[i % outbits]; + out[i % outbits] = byte & 0xff; + + /* keep around the carry bit, if any */ + byte >>= 8; + + } + + /* if there's a carry bit left over, add it back in */ + if (byte) { + for (i = outbits - 1; i >= 0; i--) { + /* do the addition */ + byte += out[i]; + out[i] = byte & 0xff; + + /* keep around the carry bit, if any */ + byte >>= 8; + } + } +} + +/* + * This is the DK (derive_key) function as described in rfc3961, sec 5.1 + * Taken from MIT Kerberos and modified. 
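A handy sanity check for krb5_nfold() comes from the RFC 3961 test vectors: folding the 8-byte string "kerberos" to 64 bits returns the input unchanged (0x6b65726265726f73), since the input and output sizes coincide and lcm(8, 8) = 8, so the rotate-and-add loop makes a single pass over the data.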
+ */ + +u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *in_constant, + gfp_t gfp_mask) +{ + size_t blocksize, keybytes, keylength, n; + unsigned char *inblockdata, *outblockdata, *rawkey; + struct xdr_netobj inblock, outblock; + struct crypto_blkcipher *cipher; + u32 ret = EINVAL; + + blocksize = gk5e->blocksize; + keybytes = gk5e->keybytes; + keylength = gk5e->keylength; + + if ((inkey->len != keylength) || (outkey->len != keylength)) + goto err_return; + + cipher = crypto_alloc_blkcipher(gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + goto err_return; + if (crypto_blkcipher_setkey(cipher, inkey->data, inkey->len)) + goto err_return; + + /* allocate and set up buffers */ + + ret = ENOMEM; + inblockdata = kmalloc(blocksize, gfp_mask); + if (inblockdata == NULL) + goto err_free_cipher; + + outblockdata = kmalloc(blocksize, gfp_mask); + if (outblockdata == NULL) + goto err_free_in; + + rawkey = kmalloc(keybytes, gfp_mask); + if (rawkey == NULL) + goto err_free_out; + + inblock.data = (char *) inblockdata; + inblock.len = blocksize; + + outblock.data = (char *) outblockdata; + outblock.len = blocksize; + + /* initialize the input block */ + + if (in_constant->len == inblock.len) { + memcpy(inblock.data, in_constant->data, inblock.len); + } else { + krb5_nfold(in_constant->len * 8, in_constant->data, + inblock.len * 8, inblock.data); + } + + /* loop encrypting the blocks until enough key bytes are generated */ + + n = 0; + while (n < keybytes) { + (*(gk5e->encrypt))(cipher, NULL, inblock.data, + outblock.data, inblock.len); + + if ((keybytes - n) <= outblock.len) { + memcpy(rawkey + n, outblock.data, (keybytes - n)); + break; + } + + memcpy(rawkey + n, outblock.data, outblock.len); + memcpy(inblock.data, outblock.data, outblock.len); + n += outblock.len; + } + + /* postprocess the key */ + + inblock.data = (char *) rawkey; + inblock.len = keybytes; + + BUG_ON(gk5e->mk_key == NULL); + ret = (*(gk5e->mk_key))(gk5e, &inblock, outkey); + if (ret) { + dprintk("%s: got %d from mk_key function for '%s'\n", + __func__, ret, gk5e->encrypt_name); + goto err_free_raw; + } + + /* clean memory, free resources and exit */ + + ret = 0; + +err_free_raw: + memset(rawkey, 0, keybytes); + kfree(rawkey); +err_free_out: + memset(outblockdata, 0, blocksize); + kfree(outblockdata); +err_free_in: + memset(inblockdata, 0, blocksize); + kfree(inblockdata); +err_free_cipher: + crypto_free_blkcipher(cipher); +err_return: + return ret; +} + +#define smask(step) ((1<<step)-1) +#define pstep(x, step) (((x)&smask(step))^(((x)>>step)&smask(step))) +#define parity_char(x) pstep(pstep(pstep((x), 4), 2), 1) + +static void mit_des_fixup_key_parity(u8 key[8]) +{ + int i; + for (i = 0; i < 8; i++) { + key[i] &= 0xfe; + key[i] |= 1^parity_char(key[i]); + } +} + +/* + * This is the des3 key derivation postprocess function + */ +u32 gss_krb5_des3_make_key(const struct gss_krb5_enctype *gk5e, + struct xdr_netobj *randombits, + struct xdr_netobj *key) +{ + int i; + u32 ret = EINVAL; + + if (key->len != 24) { + dprintk("%s: key->len is %d\n", __func__, key->len); + goto err_out; + } + if (randombits->len != 21) { + dprintk("%s: randombits->len is %d\n", + __func__, randombits->len); + goto err_out; + } + + /* take the seven bytes, move them around into the top 7 bits of the + 8 key bytes, then compute the parity bits. Do this three times. 
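
For orientation before the parity post-processing below: a caller of krb5_derive_key() points keyin at the raw session key, keyout at the subkey slot in the context, and passes a 5-byte usage constant. A hedged sketch modeled on context_derive_keys_des3() later in this patch (set_cdata() is also introduced there):

    u8 cdata[GSS_KRB5_K5CLENGTH];
    struct xdr_netobj c = { .len = GSS_KRB5_K5CLENGTH, .data = cdata };
    struct xdr_netobj keyin = { .len = ctx->gk5e->keylength, .data = ctx->Ksess };
    struct xdr_netobj keyout = { .len = ctx->gk5e->keylength, .data = ctx->cksum };

    /* derive the checksum subkey from the session key */
    set_cdata(cdata, KG_USAGE_SIGN, KEY_USAGE_SEED_CHECKSUM);
    if (krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask))
            goto out_err;   /* as in context_derive_keys_des3() */
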
*/ + + for (i = 0; i < 3; i++) { + memcpy(key->data + i*8, randombits->data + i*7, 7); + key->data[i*8+7] = (((key->data[i*8]&1)<<1) | + ((key->data[i*8+1]&1)<<2) | + ((key->data[i*8+2]&1)<<3) | + ((key->data[i*8+3]&1)<<4) | + ((key->data[i*8+4]&1)<<5) | + ((key->data[i*8+5]&1)<<6) | + ((key->data[i*8+6]&1)<<7)); + + mit_des_fixup_key_parity(key->data + i*8); + } + ret = 0; +err_out: + return ret; +} + +/* + * This is the aes key derivation postprocess function + */ +u32 gss_krb5_aes_make_key(const struct gss_krb5_enctype *gk5e, + struct xdr_netobj *randombits, + struct xdr_netobj *key) +{ + u32 ret = EINVAL; + + if (key->len != 16 && key->len != 32) { + dprintk("%s: key->len is %d\n", __func__, key->len); + goto err_out; + } + if (randombits->len != 16 && randombits->len != 32) { + dprintk("%s: randombits->len is %d\n", + __func__, randombits->len); + goto err_out; + } + if (randombits->len != key->len) { + dprintk("%s: randombits->len is %d, key->len is %d\n", + __func__, randombits->len, key->len); + goto err_out; + } + memcpy(key->data, randombits->data, key->len); + ret = 0; +err_out: + return ret; +} + diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c index 2deb0ed72ff..0d3c158ef8f 100644 --- a/net/sunrpc/auth_gss/gss_krb5_mech.c +++ b/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -1,7 +1,7 @@ /* * linux/net/sunrpc/gss_krb5_mech.c * - * Copyright (c) 2001 The Regents of the University of Michigan. + * Copyright (c) 2001-2008 The Regents of the University of Michigan. * All rights reserved. * * Andy Adamson <andros@umich.edu> @@ -43,11 +43,149 @@ #include <linux/sunrpc/gss_krb5.h> #include <linux/sunrpc/xdr.h> #include <linux/crypto.h> +#include <linux/sunrpc/gss_krb5_enctypes.h> #ifdef RPC_DEBUG # define RPCDBG_FACILITY RPCDBG_AUTH #endif +static struct gss_api_mech gss_kerberos_mech; /* forward declaration */ + +static const struct gss_krb5_enctype supported_gss_krb5_enctypes[] = { + /* + * DES (All DES enctypes are mapped to the same gss functionality) + */ + { + .etype = ENCTYPE_DES_CBC_RAW, + .ctype = CKSUMTYPE_RSA_MD5, + .name = "des-cbc-crc", + .encrypt_name = "cbc(des)", + .cksum_name = "md5", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = NULL, + .signalg = SGN_ALG_DES_MAC_MD5, + .sealalg = SEAL_ALG_DES, + .keybytes = 7, + .keylength = 8, + .blocksize = 8, + .conflen = 8, + .cksumlength = 8, + .keyed_cksum = 0, + }, + /* + * RC4-HMAC + */ + { + .etype = ENCTYPE_ARCFOUR_HMAC, + .ctype = CKSUMTYPE_HMAC_MD5_ARCFOUR, + .name = "rc4-hmac", + .encrypt_name = "ecb(arc4)", + .cksum_name = "hmac(md5)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = NULL, + .signalg = SGN_ALG_HMAC_MD5, + .sealalg = SEAL_ALG_MICROSOFT_RC4, + .keybytes = 16, + .keylength = 16, + .blocksize = 1, + .conflen = 8, + .cksumlength = 8, + .keyed_cksum = 1, + }, + /* + * 3DES + */ + { + .etype = ENCTYPE_DES3_CBC_RAW, + .ctype = CKSUMTYPE_HMAC_SHA1_DES3, + .name = "des3-hmac-sha1", + .encrypt_name = "cbc(des3_ede)", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_des3_make_key, + .signalg = SGN_ALG_HMAC_SHA1_DES3_KD, + .sealalg = SEAL_ALG_DES3KD, + .keybytes = 21, + .keylength = 24, + .blocksize = 8, + .conflen = 8, + .cksumlength = 20, + .keyed_cksum = 1, + }, + /* + * AES128 + */ + { + .etype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .ctype = CKSUMTYPE_HMAC_SHA1_96_AES128, + .name = "aes128-cts", + .encrypt_name = "cts(cbc(aes))", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + 
.decrypt = krb5_decrypt, + .mk_key = gss_krb5_aes_make_key, + .encrypt_v2 = gss_krb5_aes_encrypt, + .decrypt_v2 = gss_krb5_aes_decrypt, + .signalg = -1, + .sealalg = -1, + .keybytes = 16, + .keylength = 16, + .blocksize = 16, + .conflen = 16, + .cksumlength = 12, + .keyed_cksum = 1, + }, + /* + * AES256 + */ + { + .etype = ENCTYPE_AES256_CTS_HMAC_SHA1_96, + .ctype = CKSUMTYPE_HMAC_SHA1_96_AES256, + .name = "aes256-cts", + .encrypt_name = "cts(cbc(aes))", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_aes_make_key, + .encrypt_v2 = gss_krb5_aes_encrypt, + .decrypt_v2 = gss_krb5_aes_decrypt, + .signalg = -1, + .sealalg = -1, + .keybytes = 32, + .keylength = 32, + .blocksize = 16, + .conflen = 16, + .cksumlength = 12, + .keyed_cksum = 1, + }, +}; + +static const int num_supported_enctypes = + ARRAY_SIZE(supported_gss_krb5_enctypes); + +static int +supported_gss_krb5_enctype(int etype) +{ + int i; + for (i = 0; i < num_supported_enctypes; i++) + if (supported_gss_krb5_enctypes[i].etype == etype) + return 1; + return 0; +} + +static const struct gss_krb5_enctype * +get_gss_krb5_enctype(int etype) +{ + int i; + for (i = 0; i < num_supported_enctypes; i++) + if (supported_gss_krb5_enctypes[i].etype == etype) + return &supported_gss_krb5_enctypes[i]; + return NULL; +} + static const void * simple_get_bytes(const void *p, const void *end, void *res, int len) { @@ -78,35 +216,46 @@ simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) } static inline const void * -get_key(const void *p, const void *end, struct crypto_blkcipher **res) +get_key(const void *p, const void *end, + struct krb5_ctx *ctx, struct crypto_blkcipher **res) { struct xdr_netobj key; int alg; - char *alg_name; p = simple_get_bytes(p, end, &alg, sizeof(alg)); if (IS_ERR(p)) goto out_err; + + switch (alg) { + case ENCTYPE_DES_CBC_CRC: + case ENCTYPE_DES_CBC_MD4: + case ENCTYPE_DES_CBC_MD5: + /* Map all these key types to ENCTYPE_DES_CBC_RAW */ + alg = ENCTYPE_DES_CBC_RAW; + break; + } + + if (!supported_gss_krb5_enctype(alg)) { + printk(KERN_WARNING "gss_kerberos_mech: unsupported " + "encryption key algorithm %d\n", alg); + p = ERR_PTR(-EINVAL); + goto out_err; + } p = simple_get_netobj(p, end, &key); if (IS_ERR(p)) goto out_err; - switch (alg) { - case ENCTYPE_DES_CBC_RAW: - alg_name = "cbc(des)"; - break; - default: - printk("gss_kerberos_mech: unsupported algorithm %d\n", alg); - goto out_err_free_key; - } - *res = crypto_alloc_blkcipher(alg_name, 0, CRYPTO_ALG_ASYNC); + *res = crypto_alloc_blkcipher(ctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); if (IS_ERR(*res)) { - printk("gss_kerberos_mech: unable to initialize crypto algorithm %s\n", alg_name); + printk(KERN_WARNING "gss_kerberos_mech: unable to initialize " + "crypto algorithm %s\n", ctx->gk5e->encrypt_name); *res = NULL; goto out_err_free_key; } if (crypto_blkcipher_setkey(*res, key.data, key.len)) { - printk("gss_kerberos_mech: error setting key for crypto algorithm %s\n", alg_name); + printk(KERN_WARNING "gss_kerberos_mech: error setting key for " + "crypto algorithm %s\n", ctx->gk5e->encrypt_name); goto out_err_free_tfm; } @@ -123,56 +272,59 @@ out_err: } static int -gss_import_sec_context_kerberos(const void *p, - size_t len, - struct gss_ctx *ctx_id) +gss_import_v1_context(const void *p, const void *end, struct krb5_ctx *ctx) { - const void *end = (const void *)((const char *)p + len); - struct krb5_ctx *ctx; int tmp; - if (!(ctx = kzalloc(sizeof(*ctx), GFP_NOFS))) { - p = 
ERR_PTR(-ENOMEM); + p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); + if (IS_ERR(p)) + goto out_err; + + /* Old format supports only DES! Any other enctype uses new format */ + ctx->enctype = ENCTYPE_DES_CBC_RAW; + + ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); + if (ctx->gk5e == NULL) { + p = ERR_PTR(-EINVAL); goto out_err; } - p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); - if (IS_ERR(p)) - goto out_err_free_ctx; /* The downcall format was designed before we completely understood * the uses of the context fields; so it includes some stuff we * just give some minimal sanity-checking, and some we ignore * completely (like the next twenty bytes): */ - if (unlikely(p + 20 > end || p + 20 < p)) - goto out_err_free_ctx; + if (unlikely(p + 20 > end || p + 20 < p)) { + p = ERR_PTR(-EFAULT); + goto out_err; + } p += 20; p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); if (IS_ERR(p)) - goto out_err_free_ctx; + goto out_err; if (tmp != SGN_ALG_DES_MAC_MD5) { p = ERR_PTR(-ENOSYS); - goto out_err_free_ctx; + goto out_err; } p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); if (IS_ERR(p)) - goto out_err_free_ctx; + goto out_err; if (tmp != SEAL_ALG_DES) { p = ERR_PTR(-ENOSYS); - goto out_err_free_ctx; + goto out_err; } p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); if (IS_ERR(p)) - goto out_err_free_ctx; + goto out_err; p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send)); if (IS_ERR(p)) - goto out_err_free_ctx; + goto out_err; p = simple_get_netobj(p, end, &ctx->mech_used); if (IS_ERR(p)) - goto out_err_free_ctx; - p = get_key(p, end, &ctx->enc); + goto out_err; + p = get_key(p, end, ctx, &ctx->enc); if (IS_ERR(p)) goto out_err_free_mech; - p = get_key(p, end, &ctx->seq); + p = get_key(p, end, ctx, &ctx->seq); if (IS_ERR(p)) goto out_err_free_key1; if (p != end) { @@ -180,9 +332,6 @@ gss_import_sec_context_kerberos(const void *p, goto out_err_free_key2; } - ctx_id->internal_ctx_id = ctx; - - dprintk("RPC: Successfully imported new context.\n"); return 0; out_err_free_key2: @@ -191,18 +340,382 @@ out_err_free_key1: crypto_free_blkcipher(ctx->enc); out_err_free_mech: kfree(ctx->mech_used.data); -out_err_free_ctx: - kfree(ctx); out_err: return PTR_ERR(p); } +static struct crypto_blkcipher * +context_v2_alloc_cipher(struct krb5_ctx *ctx, const char *cname, u8 *key) +{ + struct crypto_blkcipher *cp; + + cp = crypto_alloc_blkcipher(cname, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(cp)) { + dprintk("gss_kerberos_mech: unable to initialize " + "crypto algorithm %s\n", cname); + return NULL; + } + if (crypto_blkcipher_setkey(cp, key, ctx->gk5e->keylength)) { + dprintk("gss_kerberos_mech: error setting key for " + "crypto algorithm %s\n", cname); + crypto_free_blkcipher(cp); + return NULL; + } + return cp; +} + +static inline void +set_cdata(u8 cdata[GSS_KRB5_K5CLENGTH], u32 usage, u8 seed) +{ + cdata[0] = (usage>>24)&0xff; + cdata[1] = (usage>>16)&0xff; + cdata[2] = (usage>>8)&0xff; + cdata[3] = usage&0xff; + cdata[4] = seed; +} + +static int +context_derive_keys_des3(struct krb5_ctx *ctx, gfp_t gfp_mask) +{ + struct xdr_netobj c, keyin, keyout; + u8 cdata[GSS_KRB5_K5CLENGTH]; + u32 err; + + c.len = GSS_KRB5_K5CLENGTH; + c.data = cdata; + + keyin.data = ctx->Ksess; + keyin.len = ctx->gk5e->keylength; + keyout.len = ctx->gk5e->keylength; + + /* seq uses the raw key */ + ctx->seq = context_v2_alloc_cipher(ctx, ctx->gk5e->encrypt_name, + ctx->Ksess); + if (ctx->seq == NULL) + goto out_err; + + ctx->enc = context_v2_alloc_cipher(ctx, 
ctx->gk5e->encrypt_name, + ctx->Ksess); + if (ctx->enc == NULL) + goto out_free_seq; + + /* derive cksum */ + set_cdata(cdata, KG_USAGE_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->cksum; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving cksum key\n", + __func__, err); + goto out_free_enc; + } + + return 0; + +out_free_enc: + crypto_free_blkcipher(ctx->enc); +out_free_seq: + crypto_free_blkcipher(ctx->seq); +out_err: + return -EINVAL; +} + +/* + * Note that RC4 depends on deriving keys using the sequence + * number or the checksum of a token. Therefore, the final keys + * cannot be calculated until the token is being constructed! + */ +static int +context_derive_keys_rc4(struct krb5_ctx *ctx) +{ + struct crypto_hash *hmac; + char sigkeyconstant[] = "signaturekey"; + int slen = strlen(sigkeyconstant) + 1; /* include null terminator */ + struct hash_desc desc; + struct scatterlist sg[1]; + int err; + + dprintk("RPC: %s: entered\n", __func__); + /* + * derive cksum (aka Ksign) key + */ + hmac = crypto_alloc_hash(ctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(hmac)) { + dprintk("%s: error %ld allocating hash '%s'\n", + __func__, PTR_ERR(hmac), ctx->gk5e->cksum_name); + err = PTR_ERR(hmac); + goto out_err; + } + + err = crypto_hash_setkey(hmac, ctx->Ksess, ctx->gk5e->keylength); + if (err) + goto out_err_free_hmac; + + sg_init_table(sg, 1); + sg_set_buf(sg, sigkeyconstant, slen); + + desc.tfm = hmac; + desc.flags = 0; + + err = crypto_hash_init(&desc); + if (err) + goto out_err_free_hmac; + + err = crypto_hash_digest(&desc, sg, slen, ctx->cksum); + if (err) + goto out_err_free_hmac; + /* + * allocate hash, and blkciphers for data and seqnum encryption + */ + ctx->enc = crypto_alloc_blkcipher(ctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(ctx->enc)) { + err = PTR_ERR(ctx->enc); + goto out_err_free_hmac; + } + + ctx->seq = crypto_alloc_blkcipher(ctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(ctx->seq)) { + crypto_free_blkcipher(ctx->enc); + err = PTR_ERR(ctx->seq); + goto out_err_free_hmac; + } + + dprintk("RPC: %s: returning success\n", __func__); + + err = 0; + +out_err_free_hmac: + crypto_free_hash(hmac); +out_err: + dprintk("RPC: %s: returning %d\n", __func__, err); + return err; +} + +static int +context_derive_keys_new(struct krb5_ctx *ctx, gfp_t gfp_mask) +{ + struct xdr_netobj c, keyin, keyout; + u8 cdata[GSS_KRB5_K5CLENGTH]; + u32 err; + + c.len = GSS_KRB5_K5CLENGTH; + c.data = cdata; + + keyin.data = ctx->Ksess; + keyin.len = ctx->gk5e->keylength; + keyout.len = ctx->gk5e->keylength; + + /* initiator seal encryption */ + set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); + keyout.data = ctx->initiator_seal; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_seal key\n", + __func__, err); + goto out_err; + } + ctx->initiator_enc = context_v2_alloc_cipher(ctx, + ctx->gk5e->encrypt_name, + ctx->initiator_seal); + if (ctx->initiator_enc == NULL) + goto out_err; + + /* acceptor seal encryption */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); + keyout.data = ctx->acceptor_seal; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_seal key\n", + __func__, err); + goto out_free_initiator_enc; + } + ctx->acceptor_enc = context_v2_alloc_cipher(ctx, + ctx->gk5e->encrypt_name, + ctx->acceptor_seal); + if 
(ctx->acceptor_enc == NULL) + goto out_free_initiator_enc; + + /* initiator sign checksum */ + set_cdata(cdata, KG_USAGE_INITIATOR_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->initiator_sign; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_sign key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* acceptor sign checksum */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->acceptor_sign; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_sign key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* initiator seal integrity */ + set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_INTEGRITY); + keyout.data = ctx->initiator_integ; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_integ key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* acceptor seal integrity */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_INTEGRITY); + keyout.data = ctx->acceptor_integ; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_integ key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + switch (ctx->enctype) { + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + ctx->initiator_enc_aux = + context_v2_alloc_cipher(ctx, "cbc(aes)", + ctx->initiator_seal); + if (ctx->initiator_enc_aux == NULL) + goto out_free_acceptor_enc; + ctx->acceptor_enc_aux = + context_v2_alloc_cipher(ctx, "cbc(aes)", + ctx->acceptor_seal); + if (ctx->acceptor_enc_aux == NULL) { + crypto_free_blkcipher(ctx->initiator_enc_aux); + goto out_free_acceptor_enc; + } + } + + return 0; + +out_free_acceptor_enc: + crypto_free_blkcipher(ctx->acceptor_enc); +out_free_initiator_enc: + crypto_free_blkcipher(ctx->initiator_enc); +out_err: + return -EINVAL; +} + +static int +gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx, + gfp_t gfp_mask) +{ + int keylen; + + p = simple_get_bytes(p, end, &ctx->flags, sizeof(ctx->flags)); + if (IS_ERR(p)) + goto out_err; + ctx->initiate = ctx->flags & KRB5_CTX_FLAG_INITIATOR; + + p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); + if (IS_ERR(p)) + goto out_err; + p = simple_get_bytes(p, end, &ctx->seq_send64, sizeof(ctx->seq_send64)); + if (IS_ERR(p)) + goto out_err; + /* set seq_send for use by "older" enctypes */ + ctx->seq_send = ctx->seq_send64; + if (ctx->seq_send64 != ctx->seq_send) { + dprintk("%s: seq_send64 %lx, seq_send %x overflow?\n", __func__, + (unsigned long)ctx->seq_send64, ctx->seq_send); + p = ERR_PTR(-EINVAL); + goto out_err; + } + p = simple_get_bytes(p, end, &ctx->enctype, sizeof(ctx->enctype)); + if (IS_ERR(p)) + goto out_err; + /* Map ENCTYPE_DES3_CBC_SHA1 to ENCTYPE_DES3_CBC_RAW */ + if (ctx->enctype == ENCTYPE_DES3_CBC_SHA1) + ctx->enctype = ENCTYPE_DES3_CBC_RAW; + ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); + if (ctx->gk5e == NULL) { + dprintk("gss_kerberos_mech: unsupported krb5 enctype %u\n", + ctx->enctype); + p = ERR_PTR(-EINVAL); + goto out_err; + } + keylen = ctx->gk5e->keylength; + + p = simple_get_bytes(p, end, ctx->Ksess, keylen); + if (IS_ERR(p)) + goto out_err; + + if (p != end) { + p = ERR_PTR(-EINVAL); + goto out_err; + } + + ctx->mech_used.data = kmemdup(gss_kerberos_mech.gm_oid.data, + gss_kerberos_mech.gm_oid.len, 
gfp_mask); + if (unlikely(ctx->mech_used.data == NULL)) { + p = ERR_PTR(-ENOMEM); + goto out_err; + } + ctx->mech_used.len = gss_kerberos_mech.gm_oid.len; + + switch (ctx->enctype) { + case ENCTYPE_DES3_CBC_RAW: + return context_derive_keys_des3(ctx, gfp_mask); + case ENCTYPE_ARCFOUR_HMAC: + return context_derive_keys_rc4(ctx); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return context_derive_keys_new(ctx, gfp_mask); + default: + return -EINVAL; + } + +out_err: + return PTR_ERR(p); +} + +static int +gss_import_sec_context_kerberos(const void *p, size_t len, + struct gss_ctx *ctx_id, + time_t *endtime, + gfp_t gfp_mask) +{ + const void *end = (const void *)((const char *)p + len); + struct krb5_ctx *ctx; + int ret; + + ctx = kzalloc(sizeof(*ctx), gfp_mask); + if (ctx == NULL) + return -ENOMEM; + + if (len == 85) + ret = gss_import_v1_context(p, end, ctx); + else + ret = gss_import_v2_context(p, end, ctx, gfp_mask); + + if (ret == 0) { + ctx_id->internal_ctx_id = ctx; + if (endtime) + *endtime = ctx->endtime; + } else + kfree(ctx); + + dprintk("RPC: %s: returning %d\n", __func__, ret); + return ret; +} + static void gss_delete_sec_context_kerberos(void *internal_ctx) { struct krb5_ctx *kctx = internal_ctx; crypto_free_blkcipher(kctx->seq); crypto_free_blkcipher(kctx->enc); + crypto_free_blkcipher(kctx->acceptor_enc); + crypto_free_blkcipher(kctx->initiator_enc); + crypto_free_blkcipher(kctx->acceptor_enc_aux); + crypto_free_blkcipher(kctx->initiator_enc_aux); kfree(kctx->mech_used.data); kfree(kctx); } @@ -219,28 +732,40 @@ static const struct gss_api_ops gss_kerberos_ops = { static struct pf_desc gss_kerberos_pfs[] = { [0] = { .pseudoflavor = RPC_AUTH_GSS_KRB5, + .qop = GSS_C_QOP_DEFAULT, .service = RPC_GSS_SVC_NONE, .name = "krb5", }, [1] = { .pseudoflavor = RPC_AUTH_GSS_KRB5I, + .qop = GSS_C_QOP_DEFAULT, .service = RPC_GSS_SVC_INTEGRITY, .name = "krb5i", }, [2] = { .pseudoflavor = RPC_AUTH_GSS_KRB5P, + .qop = GSS_C_QOP_DEFAULT, .service = RPC_GSS_SVC_PRIVACY, .name = "krb5p", }, }; +MODULE_ALIAS("rpc-auth-gss-krb5"); +MODULE_ALIAS("rpc-auth-gss-krb5i"); +MODULE_ALIAS("rpc-auth-gss-krb5p"); +MODULE_ALIAS("rpc-auth-gss-390003"); +MODULE_ALIAS("rpc-auth-gss-390004"); +MODULE_ALIAS("rpc-auth-gss-390005"); +MODULE_ALIAS("rpc-auth-gss-1.2.840.113554.1.2.2"); + static struct gss_api_mech gss_kerberos_mech = { .gm_name = "krb5", .gm_owner = THIS_MODULE, - .gm_oid = {9, (void *)"\x2a\x86\x48\x86\xf7\x12\x01\x02\x02"}, + .gm_oid = { 9, "\x2a\x86\x48\x86\xf7\x12\x01\x02\x02" }, .gm_ops = &gss_kerberos_ops, .gm_pf_num = ARRAY_SIZE(gss_kerberos_pfs), .gm_pfs = gss_kerberos_pfs, + .gm_upcall_enctypes = KRB5_SUPPORTED_ENCTYPES, }; static int __init init_kerberos_module(void) diff --git a/net/sunrpc/auth_gss/gss_krb5_seal.c b/net/sunrpc/auth_gss/gss_krb5_seal.c index b8f42ef7178..62ae3273186 100644 --- a/net/sunrpc/auth_gss/gss_krb5_seal.c +++ b/net/sunrpc/auth_gss/gss_krb5_seal.c @@ -3,7 +3,7 @@ * * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5seal.c * - * Copyright (c) 2000 The Regents of the University of Michigan. + * Copyright (c) 2000-2008 The Regents of the University of Michigan. * All rights reserved. 
* * Andy Adamson <andros@umich.edu> @@ -59,7 +59,6 @@ */ #include <linux/types.h> -#include <linux/slab.h> #include <linux/jiffies.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/random.h> @@ -71,53 +70,154 @@ DEFINE_SPINLOCK(krb5_seq_lock); -u32 -gss_get_mic_kerberos(struct gss_ctx *gss_ctx, struct xdr_buf *text, +static char * +setup_token(struct krb5_ctx *ctx, struct xdr_netobj *token) +{ + __be16 *ptr, *krb5_hdr; + int body_size = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; + + token->len = g_token_size(&ctx->mech_used, body_size); + + ptr = (__be16 *)token->data; + g_make_token_header(&ctx->mech_used, body_size, (unsigned char **)&ptr); + + /* ptr now at start of header described in rfc 1964, section 1.2.1: */ + krb5_hdr = ptr; + *ptr++ = KG_TOK_MIC_MSG; + *ptr++ = cpu_to_le16(ctx->gk5e->signalg); + *ptr++ = SEAL_ALG_NONE; + *ptr++ = 0xffff; + + return (char *)krb5_hdr; +} + +static void * +setup_token_v2(struct krb5_ctx *ctx, struct xdr_netobj *token) +{ + __be16 *ptr, *krb5_hdr; + u8 *p, flags = 0x00; + + if ((ctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0) + flags |= 0x01; + if (ctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY) + flags |= 0x04; + + /* Per rfc 4121, sec 4.2.6.1, there is no header, + * just start the token */ + krb5_hdr = ptr = (__be16 *)token->data; + + *ptr++ = KG2_TOK_MIC; + p = (u8 *)ptr; + *p++ = flags; + *p++ = 0xff; + ptr = (__be16 *)p; + *ptr++ = 0xffff; + *ptr++ = 0xffff; + + token->len = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; + return krb5_hdr; +} + +static u32 +gss_get_mic_v1(struct krb5_ctx *ctx, struct xdr_buf *text, struct xdr_netobj *token) { - struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; - char cksumdata[16]; - struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; - unsigned char *ptr, *msg_start; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + void *ptr; s32 now; u32 seq_send; + u8 *cksumkey; - dprintk("RPC: gss_krb5_seal\n"); + dprintk("RPC: %s\n", __func__); BUG_ON(ctx == NULL); now = get_seconds(); - token->len = g_token_size(&ctx->mech_used, GSS_KRB5_TOK_HDR_LEN + 8); + ptr = setup_token(ctx, token); - ptr = token->data; - g_make_token_header(&ctx->mech_used, GSS_KRB5_TOK_HDR_LEN + 8, &ptr); + if (ctx->gk5e->keyed_cksum) + cksumkey = ctx->cksum; + else + cksumkey = NULL; - /* ptr now at header described in rfc 1964, section 1.2.1: */ - ptr[0] = (unsigned char) ((KG_TOK_MIC_MSG >> 8) & 0xff); - ptr[1] = (unsigned char) (KG_TOK_MIC_MSG & 0xff); + if (make_checksum(ctx, ptr, 8, text, 0, cksumkey, + KG_USAGE_SIGN, &md5cksum)) + return GSS_S_FAILURE; - msg_start = ptr + GSS_KRB5_TOK_HDR_LEN + 8; + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); - *(__be16 *)(ptr + 2) = htons(SGN_ALG_DES_MAC_MD5); - memset(ptr + 4, 0xff, 4); + spin_lock(&krb5_seq_lock); + seq_send = ctx->seq_send++; + spin_unlock(&krb5_seq_lock); - if (make_checksum("md5", ptr, 8, text, 0, &md5cksum)) + if (krb5_make_seq_num(ctx, ctx->seq, ctx->initiate ? 0 : 0xff, + seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8)) return GSS_S_FAILURE; - if (krb5_encrypt(ctx->seq, NULL, md5cksum.data, - md5cksum.data, md5cksum.len)) - return GSS_S_FAILURE; + return (ctx->endtime < now) ? 
GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj cksumobj = { .len = sizeof(cksumdata), + .data = cksumdata}; + void *krb5_hdr; + s32 now; + u64 seq_send; + u8 *cksumkey; + unsigned int cksum_usage; + + dprintk("RPC: %s\n", __func__); - memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data + md5cksum.len - 8, 8); + krb5_hdr = setup_token_v2(ctx, token); + /* Set up the sequence number. Now 64-bits in clear + * text and w/o direction indicator */ spin_lock(&krb5_seq_lock); - seq_send = ctx->seq_send++; + seq_send = ctx->seq_send64++; spin_unlock(&krb5_seq_lock); - - if (krb5_make_seq_num(ctx->seq, ctx->initiate ? 0 : 0xff, - seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, - ptr + 8)) + *((u64 *)(krb5_hdr + 8)) = cpu_to_be64(seq_send); + + if (ctx->initiate) { + cksumkey = ctx->initiator_sign; + cksum_usage = KG_USAGE_INITIATOR_SIGN; + } else { + cksumkey = ctx->acceptor_sign; + cksum_usage = KG_USAGE_ACCEPTOR_SIGN; + } + + if (make_checksum_v2(ctx, krb5_hdr, GSS_KRB5_TOK_HDR_LEN, + text, 0, cksumkey, cksum_usage, &cksumobj)) return GSS_S_FAILURE; + memcpy(krb5_hdr + GSS_KRB5_TOK_HDR_LEN, cksumobj.data, cksumobj.len); + + now = get_seconds(); + return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; } + +u32 +gss_get_mic_kerberos(struct gss_ctx *gss_ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + + switch (ctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_get_mic_v1(ctx, text, token); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_get_mic_v2(ctx, text, token); + } +} + diff --git a/net/sunrpc/auth_gss/gss_krb5_seqnum.c b/net/sunrpc/auth_gss/gss_krb5_seqnum.c index 17562b4c35f..62ac90c62cb 100644 --- a/net/sunrpc/auth_gss/gss_krb5_seqnum.c +++ b/net/sunrpc/auth_gss/gss_krb5_seqnum.c @@ -32,7 +32,6 @@ */ #include <linux/types.h> -#include <linux/slab.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/crypto.h> @@ -40,14 +39,51 @@ # define RPCDBG_FACILITY RPCDBG_AUTH #endif +static s32 +krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, + unsigned char *cksum, unsigned char *buf) +{ + struct crypto_blkcipher *cipher; + unsigned char plain[8]; + s32 code; + + dprintk("RPC: %s:\n", __func__); + cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return PTR_ERR(cipher); + + plain[0] = (unsigned char) ((seqnum >> 24) & 0xff); + plain[1] = (unsigned char) ((seqnum >> 16) & 0xff); + plain[2] = (unsigned char) ((seqnum >> 8) & 0xff); + plain[3] = (unsigned char) ((seqnum >> 0) & 0xff); + plain[4] = direction; + plain[5] = direction; + plain[6] = direction; + plain[7] = direction; + + code = krb5_rc4_setup_seq_key(kctx, cipher, cksum); + if (code) + goto out; + + code = krb5_encrypt(cipher, cksum, plain, buf, 8); +out: + crypto_free_blkcipher(cipher); + return code; +} s32 -krb5_make_seq_num(struct crypto_blkcipher *key, +krb5_make_seq_num(struct krb5_ctx *kctx, + struct crypto_blkcipher *key, int direction, u32 seqnum, unsigned char *cksum, unsigned char *buf) { unsigned char plain[8]; + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) + return krb5_make_rc4_seq_num(kctx, direction, seqnum, + cksum, buf); + plain[0] = (unsigned char) (seqnum & 0xff); plain[1] = (unsigned char) 
((seqnum >> 8) & 0xff); plain[2] = (unsigned char) ((seqnum >> 16) & 0xff); @@ -61,17 +97,59 @@ krb5_make_seq_num(struct crypto_blkcipher *key, return krb5_encrypt(key, cksum, plain, buf, 8); } +static s32 +krb5_get_rc4_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, + unsigned char *buf, int *direction, s32 *seqnum) +{ + struct crypto_blkcipher *cipher; + unsigned char plain[8]; + s32 code; + + dprintk("RPC: %s:\n", __func__); + cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return PTR_ERR(cipher); + + code = krb5_rc4_setup_seq_key(kctx, cipher, cksum); + if (code) + goto out; + + code = krb5_decrypt(cipher, cksum, buf, plain, 8); + if (code) + goto out; + + if ((plain[4] != plain[5]) || (plain[4] != plain[6]) + || (plain[4] != plain[7])) { + code = (s32)KG_BAD_SEQ; + goto out; + } + + *direction = plain[4]; + + *seqnum = ((plain[0] << 24) | (plain[1] << 16) | + (plain[2] << 8) | (plain[3])); +out: + crypto_free_blkcipher(cipher); + return code; +} + s32 -krb5_get_seq_num(struct crypto_blkcipher *key, +krb5_get_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, unsigned char *buf, int *direction, u32 *seqnum) { s32 code; unsigned char plain[8]; + struct crypto_blkcipher *key = kctx->seq; dprintk("RPC: krb5_get_seq_num:\n"); + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) + return krb5_get_rc4_seq_num(kctx, cksum, buf, + direction, seqnum); + if ((code = krb5_decrypt(key, cksum, buf, plain, 8))) return code; @@ -84,5 +162,5 @@ krb5_get_seq_num(struct crypto_blkcipher *key, *seqnum = ((plain[0]) | (plain[1] << 8) | (plain[2] << 16) | (plain[3] << 24)); - return (0); + return 0; } diff --git a/net/sunrpc/auth_gss/gss_krb5_unseal.c b/net/sunrpc/auth_gss/gss_krb5_unseal.c index 066ec73c84d..6c981ddc19f 100644 --- a/net/sunrpc/auth_gss/gss_krb5_unseal.c +++ b/net/sunrpc/auth_gss/gss_krb5_unseal.c @@ -3,7 +3,7 @@ * * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5unseal.c * - * Copyright (c) 2000 The Regents of the University of Michigan. + * Copyright (c) 2000-2008 The Regents of the University of Michigan. * All rights reserved. * * Andy Adamson <andros@umich.edu> @@ -58,7 +58,6 @@ */ #include <linux/types.h> -#include <linux/slab.h> #include <linux/jiffies.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/crypto.h> @@ -71,20 +70,21 @@ /* read_token is a mic token, and message_buffer is the data that the mic was * supposedly taken over. */ -u32 -gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, +static u32 +gss_verify_mic_v1(struct krb5_ctx *ctx, struct xdr_buf *message_buffer, struct xdr_netobj *read_token) { - struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; int signalg; int sealalg; - char cksumdata[16]; - struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; s32 now; int direction; u32 seqnum; unsigned char *ptr = (unsigned char *)read_token->data; int bodysize; + u8 *cksumkey; dprintk("RPC: krb5_read_token\n"); @@ -99,7 +99,7 @@ gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, /* XXX sanity-check bodysize?? 
*/ signalg = ptr[2] + (ptr[3] << 8); - if (signalg != SGN_ALG_DES_MAC_MD5) + if (signalg != ctx->gk5e->signalg) return GSS_S_DEFECTIVE_TOKEN; sealalg = ptr[4] + (ptr[5] << 8); @@ -109,13 +109,17 @@ gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) return GSS_S_DEFECTIVE_TOKEN; - if (make_checksum("md5", ptr, 8, message_buffer, 0, &md5cksum)) - return GSS_S_FAILURE; + if (ctx->gk5e->keyed_cksum) + cksumkey = ctx->cksum; + else + cksumkey = NULL; - if (krb5_encrypt(ctx->seq, NULL, md5cksum.data, md5cksum.data, 16)) + if (make_checksum(ctx, ptr, 8, message_buffer, 0, + cksumkey, KG_USAGE_SIGN, &md5cksum)) return GSS_S_FAILURE; - if (memcmp(md5cksum.data + 8, ptr + GSS_KRB5_TOK_HDR_LEN, 8)) + if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, + ctx->gk5e->cksumlength)) return GSS_S_BAD_SIG; /* it got through unscathed. Make sure the context is unexpired */ @@ -127,7 +131,8 @@ gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, /* do sequencing checks */ - if (krb5_get_seq_num(ctx->seq, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, &direction, &seqnum)) + if (krb5_get_seq_num(ctx, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, + &direction, &seqnum)) return GSS_S_FAILURE; if ((ctx->initiate && direction != 0xff) || @@ -136,3 +141,86 @@ gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, return GSS_S_COMPLETE; } + +static u32 +gss_verify_mic_v2(struct krb5_ctx *ctx, + struct xdr_buf *message_buffer, struct xdr_netobj *read_token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj cksumobj = {.len = sizeof(cksumdata), + .data = cksumdata}; + s32 now; + u8 *ptr = read_token->data; + u8 *cksumkey; + u8 flags; + int i; + unsigned int cksum_usage; + + dprintk("RPC: %s\n", __func__); + + if (be16_to_cpu(*((__be16 *)ptr)) != KG2_TOK_MIC) + return GSS_S_DEFECTIVE_TOKEN; + + flags = ptr[2]; + if ((!ctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) || + (ctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR))) + return GSS_S_BAD_SIG; + + if (flags & KG2_TOKEN_FLAG_SEALED) { + dprintk("%s: token has unexpected sealed flag\n", __func__); + return GSS_S_FAILURE; + } + + for (i = 3; i < 8; i++) + if (ptr[i] != 0xff) + return GSS_S_DEFECTIVE_TOKEN; + + if (ctx->initiate) { + cksumkey = ctx->acceptor_sign; + cksum_usage = KG_USAGE_ACCEPTOR_SIGN; + } else { + cksumkey = ctx->initiator_sign; + cksum_usage = KG_USAGE_INITIATOR_SIGN; + } + + if (make_checksum_v2(ctx, ptr, GSS_KRB5_TOK_HDR_LEN, message_buffer, 0, + cksumkey, cksum_usage, &cksumobj)) + return GSS_S_FAILURE; + + if (memcmp(cksumobj.data, ptr + GSS_KRB5_TOK_HDR_LEN, + ctx->gk5e->cksumlength)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + now = get_seconds(); + if (now > ctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. 
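
For reference, the token verified above follows the RFC 4121 MIC layout: a 16-byte header, then the keyed checksum. A hedged struct view (the code parses it with raw pointer arithmetic, not a struct):

    struct rfc4121_mic_token {
            __be16 tok_id;          /* 0x0404, i.e. KG2_TOK_MIC */
            u8     flags;           /* 0x01 sent-by-acceptor, 0x02 sealed,
                                     * 0x04 acceptor-subkey */
            u8     filler[5];       /* all 0xff */
            __be64 seqnum;          /* skipped here, per RFC 2203 */
            /* followed by gk5e->cksumlength bytes of keyed checksum */
    } __packed;
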
+ */ + + return GSS_S_COMPLETE; +} + +u32 +gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, + struct xdr_buf *message_buffer, + struct xdr_netobj *read_token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + + switch (ctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_verify_mic_v1(ctx, message_buffer, read_token); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_verify_mic_v2(ctx, message_buffer, read_token); + } +} + diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c index ae8e69b59c4..42560e55d97 100644 --- a/net/sunrpc/auth_gss/gss_krb5_wrap.c +++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c @@ -1,5 +1,34 @@ +/* + * COPYRIGHT (c) 2008 + * The Regents of the University of Michigan + * ALL RIGHTS RESERVED + * + * Permission is granted to use, copy, create derivative works + * and redistribute this software and such derivative works + * for any purpose, so long as the name of The University of + * Michigan is not used in any advertising or publicity + * pertaining to the use of distribution of this software + * without specific, written prior authorization. If the + * above copyright notice or any other identification of the + * University of Michigan is included in any copy of any + * portion of this software, then the disclaimer below must + * also be included. + * + * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION + * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY + * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF + * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE + * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE + * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR + * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING + * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN + * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGES. + */ + #include <linux/types.h> -#include <linux/slab.h> #include <linux/jiffies.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/random.h> @@ -13,10 +42,7 @@ static inline int gss_krb5_padding(int blocksize, int length) { - /* Most of the code is block-size independent but currently we - * use only 8: */ - BUG_ON(blocksize != 8); - return 8 - (length & 7); + return blocksize - (length % blocksize); } static inline void @@ -56,9 +82,9 @@ gss_krb5_remove_padding(struct xdr_buf *buf, int blocksize) >>PAGE_CACHE_SHIFT; unsigned int offset = (buf->page_base + len - 1) & (PAGE_CACHE_SIZE - 1); - ptr = kmap_atomic(buf->pages[last], KM_USER0); + ptr = kmap_atomic(buf->pages[last]); pad = *(ptr + offset); - kunmap_atomic(ptr, KM_USER0); + kunmap_atomic(ptr); goto out; } else len -= buf->page_len; @@ -87,8 +113,8 @@ out: return 0; } -static void -make_confounder(char *p, u32 conflen) +void +gss_krb5_make_confounder(char *p, u32 conflen) { static u64 i = 0; u64 *q = (u64 *)p; @@ -104,8 +130,8 @@ make_confounder(char *p, u32 conflen) /* initialize to random value */ if (i == 0) { - i = random32(); - i = (i << 32) | random32(); + i = prandom_u32(); + i = (i << 32) | prandom_u32(); } switch (conflen) { @@ -128,69 +154,73 @@ make_confounder(char *p, u32 conflen) /* XXX factor out common code with seal/unseal. 
*/ -u32 -gss_wrap_kerberos(struct gss_ctx *ctx, int offset, +static u32 +gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf, struct page **pages) { - struct krb5_ctx *kctx = ctx->internal_ctx_id; - char cksumdata[16]; - struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; int blocksize = 0, plainlen; unsigned char *ptr, *msg_start; s32 now; int headlen; struct page **tmp_pages; u32 seq_send; + u8 *cksumkey; + u32 conflen = kctx->gk5e->conflen; - dprintk("RPC: gss_wrap_kerberos\n"); + dprintk("RPC: %s\n", __func__); now = get_seconds(); blocksize = crypto_blkcipher_blocksize(kctx->enc); gss_krb5_add_padding(buf, offset, blocksize); BUG_ON((buf->len - offset) % blocksize); - plainlen = blocksize + buf->len - offset; + plainlen = conflen + buf->len - offset; - headlen = g_token_size(&kctx->mech_used, 24 + plainlen) - - (buf->len - offset); + headlen = g_token_size(&kctx->mech_used, + GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength + plainlen) - + (buf->len - offset); ptr = buf->head[0].iov_base + offset; /* shift data to make room for header. */ + xdr_extend_head(buf, offset, headlen); + /* XXX Would be cleverer to encrypt while copying. */ - /* XXX bounds checking, slack, etc. */ - memmove(ptr + headlen, ptr, buf->head[0].iov_len - offset); - buf->head[0].iov_len += headlen; - buf->len += headlen; BUG_ON((buf->len - offset - headlen) % blocksize); g_make_token_header(&kctx->mech_used, - GSS_KRB5_TOK_HDR_LEN + 8 + plainlen, &ptr); + GSS_KRB5_TOK_HDR_LEN + + kctx->gk5e->cksumlength + plainlen, &ptr); /* ptr now at header described in rfc 1964, section 1.2.1: */ ptr[0] = (unsigned char) ((KG_TOK_WRAP_MSG >> 8) & 0xff); ptr[1] = (unsigned char) (KG_TOK_WRAP_MSG & 0xff); - msg_start = ptr + 24; + msg_start = ptr + GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength; - *(__be16 *)(ptr + 2) = htons(SGN_ALG_DES_MAC_MD5); + *(__be16 *)(ptr + 2) = cpu_to_le16(kctx->gk5e->signalg); memset(ptr + 4, 0xff, 4); - *(__be16 *)(ptr + 4) = htons(SEAL_ALG_DES); + *(__be16 *)(ptr + 4) = cpu_to_le16(kctx->gk5e->sealalg); + + gss_krb5_make_confounder(msg_start, conflen); - make_confounder(msg_start, blocksize); + if (kctx->gk5e->keyed_cksum) + cksumkey = kctx->cksum; + else + cksumkey = NULL; /* XXXJBF: UGH!: */ tmp_pages = buf->pages; buf->pages = pages; - if (make_checksum("md5", ptr, 8, buf, - offset + headlen - blocksize, &md5cksum)) + if (make_checksum(kctx, ptr, 8, buf, offset + headlen - conflen, + cksumkey, KG_USAGE_SEAL, &md5cksum)) return GSS_S_FAILURE; buf->pages = tmp_pages; - if (krb5_encrypt(kctx->seq, NULL, md5cksum.data, - md5cksum.data, md5cksum.len)) - return GSS_S_FAILURE; - memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data + md5cksum.len - 8, 8); + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); spin_lock(&krb5_seq_lock); seq_send = kctx->seq_send++; @@ -198,25 +228,42 @@ gss_wrap_kerberos(struct gss_ctx *ctx, int offset, /* XXX would probably be more efficient to compute checksum * and encrypt at the same time: */ - if ((krb5_make_seq_num(kctx->seq, kctx->initiate ? 0 : 0xff, + if ((krb5_make_seq_num(kctx, kctx->seq, kctx->initiate ? 
0 : 0xff, seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8))) return GSS_S_FAILURE; - if (gss_encrypt_xdr_buf(kctx->enc, buf, offset + headlen - blocksize, - pages)) - return GSS_S_FAILURE; + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { + struct crypto_blkcipher *cipher; + int err; + cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return GSS_S_FAILURE; + + krb5_rc4_setup_enc_key(kctx, cipher, seq_send); + + err = gss_encrypt_xdr_buf(cipher, buf, + offset + headlen - conflen, pages); + crypto_free_blkcipher(cipher); + if (err) + return GSS_S_FAILURE; + } else { + if (gss_encrypt_xdr_buf(kctx->enc, buf, + offset + headlen - conflen, pages)) + return GSS_S_FAILURE; + } return (kctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; } -u32 -gss_unwrap_kerberos(struct gss_ctx *ctx, int offset, struct xdr_buf *buf) +static u32 +gss_unwrap_kerberos_v1(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) { - struct krb5_ctx *kctx = ctx->internal_ctx_id; int signalg; int sealalg; - char cksumdata[16]; - struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; s32 now; int direction; s32 seqnum; @@ -225,6 +272,9 @@ gss_unwrap_kerberos(struct gss_ctx *ctx, int offset, struct xdr_buf *buf) void *data_start, *orig_start; int data_len; int blocksize; + u32 conflen = kctx->gk5e->conflen; + int crypt_offset; + u8 *cksumkey; dprintk("RPC: gss_unwrap_kerberos\n"); @@ -242,29 +292,65 @@ gss_unwrap_kerberos(struct gss_ctx *ctx, int offset, struct xdr_buf *buf) /* get the sign and seal algorithms */ signalg = ptr[2] + (ptr[3] << 8); - if (signalg != SGN_ALG_DES_MAC_MD5) + if (signalg != kctx->gk5e->signalg) return GSS_S_DEFECTIVE_TOKEN; sealalg = ptr[4] + (ptr[5] << 8); - if (sealalg != SEAL_ALG_DES) + if (sealalg != kctx->gk5e->sealalg) return GSS_S_DEFECTIVE_TOKEN; if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) return GSS_S_DEFECTIVE_TOKEN; - if (gss_decrypt_xdr_buf(kctx->enc, buf, - ptr + GSS_KRB5_TOK_HDR_LEN + 8 - (unsigned char *)buf->head[0].iov_base)) - return GSS_S_DEFECTIVE_TOKEN; + /* + * Data starts after token header and checksum. 
ptr points + * to the beginning of the token header + */ + crypt_offset = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) - + (unsigned char *)buf->head[0].iov_base; + + /* + * Need plaintext seqnum to derive encryption key for arcfour-hmac + */ + if (krb5_get_seq_num(kctx, ptr + GSS_KRB5_TOK_HDR_LEN, + ptr + 8, &direction, &seqnum)) + return GSS_S_BAD_SIG; - if (make_checksum("md5", ptr, 8, buf, - ptr + GSS_KRB5_TOK_HDR_LEN + 8 - (unsigned char *)buf->head[0].iov_base, &md5cksum)) - return GSS_S_FAILURE; + if ((kctx->initiate && direction != 0xff) || + (!kctx->initiate && direction != 0)) + return GSS_S_BAD_SIG; - if (krb5_encrypt(kctx->seq, NULL, md5cksum.data, - md5cksum.data, md5cksum.len)) + if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { + struct crypto_blkcipher *cipher; + int err; + + cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(cipher)) + return GSS_S_FAILURE; + + krb5_rc4_setup_enc_key(kctx, cipher, seqnum); + + err = gss_decrypt_xdr_buf(cipher, buf, crypt_offset); + crypto_free_blkcipher(cipher); + if (err) + return GSS_S_DEFECTIVE_TOKEN; + } else { + if (gss_decrypt_xdr_buf(kctx->enc, buf, crypt_offset)) + return GSS_S_DEFECTIVE_TOKEN; + } + + if (kctx->gk5e->keyed_cksum) + cksumkey = kctx->cksum; + else + cksumkey = NULL; + + if (make_checksum(kctx, ptr, 8, buf, crypt_offset, + cksumkey, KG_USAGE_SEAL, &md5cksum)) return GSS_S_FAILURE; - if (memcmp(md5cksum.data + 8, ptr + GSS_KRB5_TOK_HDR_LEN, 8)) + if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, + kctx->gk5e->cksumlength)) return GSS_S_BAD_SIG; /* it got through unscathed. Make sure the context is unexpired */ @@ -276,19 +362,12 @@ gss_unwrap_kerberos(struct gss_ctx *ctx, int offset, struct xdr_buf *buf) /* do sequencing checks */ - if (krb5_get_seq_num(kctx->seq, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, - &direction, &seqnum)) - return GSS_S_BAD_SIG; - - if ((kctx->initiate && direction != 0xff) || - (!kctx->initiate && direction != 0)) - return GSS_S_BAD_SIG; - /* Copy the data back to the right position. XXX: Would probably be * better to copy and encrypt at the same time. */ blocksize = crypto_blkcipher_blocksize(kctx->enc); - data_start = ptr + GSS_KRB5_TOK_HDR_LEN + 8 + blocksize; + data_start = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) + + conflen; orig_start = buf->head[0].iov_base + offset; data_len = (buf->head[0].iov_base + buf->head[0].iov_len) - data_start; memmove(orig_start, data_start, data_len); @@ -300,3 +379,242 @@ gss_unwrap_kerberos(struct gss_ctx *ctx, int offset, struct xdr_buf *buf) return GSS_S_COMPLETE; } + +/* + * We can shift data by up to LOCAL_BUF_LEN bytes in a pass. If we need + * to do more than that, we shift repeatedly. Kevin Coffman reports + * seeing 28 bytes as the value used by Microsoft clients and servers + * with AES, so this constant is chosen to allow handling 28 in one pass + * without using too much stack space. + * + * If that proves to be a problem perhaps we could use a more clever + * algorithm. 
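
On a flat buffer the rotation described above is just three moves; the xdr_buf version below chunks through LOCAL_BUF_LEN-sized copies only because the data is scattered across head, pages, and tail. A hedged flat-array equivalent:

    #include <string.h>

    /* rotate buf left by shift bytes; shift <= 32 is assumed, matching
     * the LOCAL_BUF_LEN cap that rotate_buf_a_little() relies on */
    static void rotate_left_flat(unsigned char *buf, unsigned int len,
                                 unsigned int shift)
    {
            unsigned char head[32];

            shift %= len;
            memcpy(head, buf, shift);                 /* save the head */
            memmove(buf, buf + shift, len - shift);   /* slide the rest down */
            memcpy(buf + len - shift, head, shift);   /* reattach the head */
    }
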
+ */ +#define LOCAL_BUF_LEN 32u + +static void rotate_buf_a_little(struct xdr_buf *buf, unsigned int shift) +{ + char head[LOCAL_BUF_LEN]; + char tmp[LOCAL_BUF_LEN]; + unsigned int this_len, i; + + BUG_ON(shift > LOCAL_BUF_LEN); + + read_bytes_from_xdr_buf(buf, 0, head, shift); + for (i = 0; i + shift < buf->len; i += LOCAL_BUF_LEN) { + this_len = min(LOCAL_BUF_LEN, buf->len - (i + shift)); + read_bytes_from_xdr_buf(buf, i+shift, tmp, this_len); + write_bytes_to_xdr_buf(buf, i, tmp, this_len); + } + write_bytes_to_xdr_buf(buf, buf->len - shift, head, shift); +} + +static void _rotate_left(struct xdr_buf *buf, unsigned int shift) +{ + int shifted = 0; + int this_shift; + + shift %= buf->len; + while (shifted < shift) { + this_shift = min(shift - shifted, LOCAL_BUF_LEN); + rotate_buf_a_little(buf, this_shift); + shifted += this_shift; + } +} + +static void rotate_left(u32 base, struct xdr_buf *buf, unsigned int shift) +{ + struct xdr_buf subbuf; + + xdr_buf_subsegment(buf, &subbuf, base, buf->len - base); + _rotate_left(&subbuf, shift); +} + +static u32 +gss_wrap_kerberos_v2(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages) +{ + int blocksize; + u8 *ptr, *plainhdr; + s32 now; + u8 flags = 0x00; + __be16 *be16ptr, ec = 0; + __be64 *be64ptr; + u32 err; + + dprintk("RPC: %s\n", __func__); + + if (kctx->gk5e->encrypt_v2 == NULL) + return GSS_S_FAILURE; + + /* make room for gss token header */ + if (xdr_extend_head(buf, offset, GSS_KRB5_TOK_HDR_LEN)) + return GSS_S_FAILURE; + + /* construct gss token header */ + ptr = plainhdr = buf->head[0].iov_base + offset; + *ptr++ = (unsigned char) ((KG2_TOK_WRAP>>8) & 0xff); + *ptr++ = (unsigned char) (KG2_TOK_WRAP & 0xff); + + if ((kctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0) + flags |= KG2_TOKEN_FLAG_SENTBYACCEPTOR; + if ((kctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY) != 0) + flags |= KG2_TOKEN_FLAG_ACCEPTORSUBKEY; + /* We always do confidentiality in wrap tokens */ + flags |= KG2_TOKEN_FLAG_SEALED; + + *ptr++ = flags; + *ptr++ = 0xff; + be16ptr = (__be16 *)ptr; + + blocksize = crypto_blkcipher_blocksize(kctx->acceptor_enc); + *be16ptr++ = cpu_to_be16(ec); + /* "inner" token header always uses 0 for RRC */ + *be16ptr++ = cpu_to_be16(0); + + be64ptr = (__be64 *)be16ptr; + spin_lock(&krb5_seq_lock); + *be64ptr = cpu_to_be64(kctx->seq_send64++); + spin_unlock(&krb5_seq_lock); + + err = (*kctx->gk5e->encrypt_v2)(kctx, offset, buf, ec, pages); + if (err) + return err; + + now = get_seconds(); + return (kctx->endtime < now) ? 
GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) +{ + s32 now; + u8 *ptr; + u8 flags = 0x00; + u16 ec, rrc; + int err; + u32 headskip, tailskip; + u8 decrypted_hdr[GSS_KRB5_TOK_HDR_LEN]; + unsigned int movelen; + + + dprintk("RPC: %s\n", __func__); + + if (kctx->gk5e->decrypt_v2 == NULL) + return GSS_S_FAILURE; + + ptr = buf->head[0].iov_base + offset; + + if (be16_to_cpu(*((__be16 *)ptr)) != KG2_TOK_WRAP) + return GSS_S_DEFECTIVE_TOKEN; + + flags = ptr[2]; + if ((!kctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) || + (kctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR))) + return GSS_S_BAD_SIG; + + if ((flags & KG2_TOKEN_FLAG_SEALED) == 0) { + dprintk("%s: token missing expected sealed flag\n", __func__); + return GSS_S_DEFECTIVE_TOKEN; + } + + if (ptr[3] != 0xff) + return GSS_S_DEFECTIVE_TOKEN; + + ec = be16_to_cpup((__be16 *)(ptr + 4)); + rrc = be16_to_cpup((__be16 *)(ptr + 6)); + + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. + */ + + if (rrc != 0) + rotate_left(offset + 16, buf, rrc); + + err = (*kctx->gk5e->decrypt_v2)(kctx, offset, buf, + &headskip, &tailskip); + if (err) + return GSS_S_FAILURE; + + /* + * Retrieve the decrypted gss token header and verify + * it against the original + */ + err = read_bytes_from_xdr_buf(buf, + buf->len - GSS_KRB5_TOK_HDR_LEN - tailskip, + decrypted_hdr, GSS_KRB5_TOK_HDR_LEN); + if (err) { + dprintk("%s: error %u getting decrypted_hdr\n", __func__, err); + return GSS_S_FAILURE; + } + if (memcmp(ptr, decrypted_hdr, 6) + || memcmp(ptr + 8, decrypted_hdr + 8, 8)) { + dprintk("%s: token hdr, plaintext hdr mismatch!\n", __func__); + return GSS_S_FAILURE; + } + + /* do sequencing checks */ + + /* it got through unscathed. Make sure the context is unexpired */ + now = get_seconds(); + if (now > kctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* + * Move the head data back to the right position in xdr_buf. + * We ignore any "ec" data since it might be in the head or + * the tail, and we really don't need to deal with it. + * Note that buf->head[0].iov_len may indicate the available + * head buffer space rather than that actually occupied. 
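
One subtlety in the header comparison above: the two memcmp() calls cover bytes 0-5 and 8-15, deliberately skipping the 2-byte RRC at offset 6. The encrypted inner copy always carries RRC 0 (gss_wrap_kerberos_v2() above writes it that way), while the outer header holds the sender's real rotation count, so those bytes legitimately differ. A hedged struct view of the RFC 4121 wrap header being compared:

    struct rfc4121_wrap_token {
            __be16 tok_id;          /* 0x0504, i.e. KG2_TOK_WRAP */
            u8     flags;
            u8     filler;          /* 0xff */
            __be16 ec;              /* extra count: compared */
            __be16 rrc;             /* rotation count: skipped, differs
                                     * between outer header and inner copy */
            __be64 seqnum;          /* compared */
    } __packed;
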
+ */ + movelen = min_t(unsigned int, buf->head[0].iov_len, buf->len); + movelen -= offset + GSS_KRB5_TOK_HDR_LEN + headskip; + BUG_ON(offset + GSS_KRB5_TOK_HDR_LEN + headskip + movelen > + buf->head[0].iov_len); + memmove(ptr, ptr + GSS_KRB5_TOK_HDR_LEN + headskip, movelen); + buf->head[0].iov_len -= GSS_KRB5_TOK_HDR_LEN + headskip; + buf->len -= GSS_KRB5_TOK_HDR_LEN + headskip; + + /* Trim off the trailing "extra count" and checksum blob */ + xdr_buf_trim(buf, ec + GSS_KRB5_TOK_HDR_LEN + tailskip); + return GSS_S_COMPLETE; +} + +u32 +gss_wrap_kerberos(struct gss_ctx *gctx, int offset, + struct xdr_buf *buf, struct page **pages) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + switch (kctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_wrap_kerberos_v1(kctx, offset, buf, pages); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_wrap_kerberos_v2(kctx, offset, buf, pages); + } +} + +u32 +gss_unwrap_kerberos(struct gss_ctx *gctx, int offset, struct xdr_buf *buf) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + switch (kctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + case ENCTYPE_ARCFOUR_HMAC: + return gss_unwrap_kerberos_v1(kctx, offset, buf); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_unwrap_kerberos_v2(kctx, offset, buf); + } +} + diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c index 76e4c6f4ac3..92d5ab99fbf 100644 --- a/net/sunrpc/auth_gss/gss_mech_switch.c +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -36,6 +36,7 @@ #include <linux/types.h> #include <linux/slab.h> #include <linux/module.h> +#include <linux/oid_registry.h> #include <linux/sunrpc/msg_prot.h> #include <linux/sunrpc/gss_asn1.h> #include <linux/sunrpc/auth_gss.h> @@ -102,8 +103,13 @@ out: return status; } -int -gss_mech_register(struct gss_api_mech *gm) +/** + * gss_mech_register - register a GSS mechanism + * @gm: GSS mechanism handle + * + * Returns zero if successful, or a negative errno. 
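
These registration hooks are typically wired up from a mechanism module's init/exit pair, as the krb5 module earlier in this patch does. A hedged sketch, where my_gss_mech stands in for a fully populated struct gss_api_mech:

    static int __init init_my_gss_module(void)
    {
            return gss_mech_register(&my_gss_mech);
    }

    static void __exit cleanup_my_gss_module(void)
    {
            gss_mech_unregister(&my_gss_mech);
    }

    module_init(init_my_gss_module);
    module_exit(cleanup_my_gss_module);
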
+ */ +int gss_mech_register(struct gss_api_mech *gm) { int status; @@ -116,11 +122,14 @@ gss_mech_register(struct gss_api_mech *gm) dprintk("RPC: registered gss mechanism %s\n", gm->gm_name); return 0; } - EXPORT_SYMBOL_GPL(gss_mech_register); +EXPORT_SYMBOL_GPL(gss_mech_register); -void -gss_mech_unregister(struct gss_api_mech *gm) +/** + * gss_mech_unregister - release a GSS mechanism + * @gm: GSS mechanism handle + * + */ +void gss_mech_unregister(struct gss_api_mech *gm) { spin_lock(&registered_mechs_lock); list_del(&gm->gm_list); @@ -128,20 +137,17 @@ gss_mech_unregister(struct gss_api_mech *gm) dprintk("RPC: unregistered gss mechanism %s\n", gm->gm_name); gss_mech_free(gm); } - EXPORT_SYMBOL_GPL(gss_mech_unregister); -struct gss_api_mech * -gss_mech_get(struct gss_api_mech *gm) +struct gss_api_mech *gss_mech_get(struct gss_api_mech *gm) { __module_get(gm->gm_owner); return gm; } +EXPORT_SYMBOL(gss_mech_get); -EXPORT_SYMBOL_GPL(gss_mech_get); - -struct gss_api_mech * -gss_mech_get_by_name(const char *name) +static struct gss_api_mech * +_gss_mech_get_by_name(const char *name) { struct gss_api_mech *pos, *gm = NULL; @@ -158,7 +164,41 @@ gss_mech_get_by_name(const char *name) } -EXPORT_SYMBOL_GPL(gss_mech_get_by_name); +struct gss_api_mech * gss_mech_get_by_name(const char *name) +{ + struct gss_api_mech *gm = NULL; + + gm = _gss_mech_get_by_name(name); + if (!gm) { + request_module("rpc-auth-gss-%s", name); + gm = _gss_mech_get_by_name(name); + } + return gm; +} + +struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj) +{ + struct gss_api_mech *pos, *gm = NULL; + char buf[32]; + + if (sprint_oid(obj->data, obj->len, buf, sizeof(buf)) < 0) + return NULL; + dprintk("RPC: %s(%s)\n", __func__, buf); + request_module("rpc-auth-gss-%s", buf); + + spin_lock(&registered_mechs_lock); + list_for_each_entry(pos, &registered_mechs, gm_list) { + if (obj->len == pos->gm_oid.len) { + if (0 == memcmp(obj->data, pos->gm_oid.data, obj->len)) { + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + } + } + spin_unlock(&registered_mechs_lock); + return gm; +} static inline int mech_supports_pseudoflavor(struct gss_api_mech *gm, u32 pseudoflavor) @@ -172,17 +212,14 @@ mech_supports_pseudoflavor(struct gss_api_mech *gm, u32 pseudoflavor) return 0; } -struct gss_api_mech * -gss_mech_get_by_pseudoflavor(u32 pseudoflavor) +static struct gss_api_mech *_gss_mech_get_by_pseudoflavor(u32 pseudoflavor) { - struct gss_api_mech *pos, *gm = NULL; + struct gss_api_mech *gm = NULL, *pos; spin_lock(&registered_mechs_lock); list_for_each_entry(pos, &registered_mechs, gm_list) { - if (!mech_supports_pseudoflavor(pos, pseudoflavor)) { - module_put(pos->gm_owner); + if (!mech_supports_pseudoflavor(pos, pseudoflavor)) continue; - } if (try_module_get(pos->gm_owner)) gm = pos; break; @@ -191,21 +228,125 @@ gss_mech_get_by_pseudoflavor(u32 pseudoflavor) return gm; } -EXPORT_SYMBOL_GPL(gss_mech_get_by_pseudoflavor); +struct gss_api_mech * +gss_mech_get_by_pseudoflavor(u32 pseudoflavor) +{ + struct gss_api_mech *gm; -u32 -gss_svc_to_pseudoflavor(struct gss_api_mech *gm, u32 service) + gm = _gss_mech_get_by_pseudoflavor(pseudoflavor); + + if (!gm) { + request_module("rpc-auth-gss-%u", pseudoflavor); + gm = _gss_mech_get_by_pseudoflavor(pseudoflavor); + } + return gm; +} + +/** + * gss_mech_list_pseudoflavors - Discover registered GSS pseudoflavors + * @array: array to fill in + * @size: size of "array" + * + * Returns the number of array items filled in, or a negative errno. + * + * The returned array is not sorted by any policy. 
+ */ +int gss_mech_list_pseudoflavors(rpc_authflavor_t *array_ptr, int size) +{ + struct gss_api_mech *pos = NULL; + int j, i = 0; + + spin_lock(&registered_mechs_lock); + list_for_each_entry(pos, &registered_mechs, gm_list) { + for (j = 0; j < pos->gm_pf_num; j++) { + if (i >= size) { + spin_unlock(&registered_mechs_lock); + return -ENOMEM; + } + array_ptr[i++] = pos->gm_pfs[j].pseudoflavor; + } + } + spin_unlock(&registered_mechs_lock); + return i; +} + +/** + * gss_svc_to_pseudoflavor - map a GSS service number to a pseudoflavor + * @gm: GSS mechanism handle + * @qop: GSS quality-of-protection value + * @service: GSS service value + * + * Returns a matching security flavor, or RPC_AUTH_MAXFLAVOR if none is found. + */ +rpc_authflavor_t gss_svc_to_pseudoflavor(struct gss_api_mech *gm, u32 qop, + u32 service) { int i; for (i = 0; i < gm->gm_pf_num; i++) { - if (gm->gm_pfs[i].service == service) { + if (gm->gm_pfs[i].qop == qop && + gm->gm_pfs[i].service == service) { return gm->gm_pfs[i].pseudoflavor; } } - return RPC_AUTH_MAXFLAVOR; /* illegal value */ + return RPC_AUTH_MAXFLAVOR; +} + +/** + * gss_mech_info2flavor - look up a pseudoflavor given a GSS tuple + * @info: a GSS mech OID, quality of protection, and service value + * + * Returns a matching pseudoflavor, or RPC_AUTH_MAXFLAVOR if the tuple is + * not supported. + */ +rpc_authflavor_t gss_mech_info2flavor(struct rpcsec_gss_info *info) +{ + rpc_authflavor_t pseudoflavor; + struct gss_api_mech *gm; + + gm = gss_mech_get_by_OID(&info->oid); + if (gm == NULL) + return RPC_AUTH_MAXFLAVOR; + + pseudoflavor = gss_svc_to_pseudoflavor(gm, info->qop, info->service); + + gss_mech_put(gm); + return pseudoflavor; +} + +/** + * gss_mech_flavor2info - look up a GSS tuple for a given pseudoflavor + * @pseudoflavor: GSS pseudoflavor to match + * @info: rpcsec_gss_info structure to fill in + * + * Returns zero and fills in "info" if pseudoflavor matches a + * supported mechanism. Otherwise a negative errno is returned. + */ +int gss_mech_flavor2info(rpc_authflavor_t pseudoflavor, + struct rpcsec_gss_info *info) +{ + struct gss_api_mech *gm; + int i; + + gm = gss_mech_get_by_pseudoflavor(pseudoflavor); + if (gm == NULL) + return -ENOENT; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) { + memcpy(info->oid.data, gm->gm_oid.data, gm->gm_oid.len); + info->oid.len = gm->gm_oid.len; + info->qop = gm->gm_pfs[i].qop; + info->service = gm->gm_pfs[i].service; + gss_mech_put(gm); + return 0; + } + } + + gss_mech_put(gm); + return -ENOENT; } -EXPORT_SYMBOL_GPL(gss_svc_to_pseudoflavor); u32 gss_pseudoflavor_to_service(struct gss_api_mech *gm, u32 pseudoflavor) @@ -218,8 +359,7 @@ gss_pseudoflavor_to_service(struct gss_api_mech *gm, u32 pseudoflavor) } return 0; } - -EXPORT_SYMBOL_GPL(gss_pseudoflavor_to_service); +EXPORT_SYMBOL(gss_pseudoflavor_to_service); char * gss_service_to_auth_domain_name(struct gss_api_mech *gm, u32 service) @@ -233,30 +373,29 @@ gss_service_to_auth_domain_name(struct gss_api_mech *gm, u32 service) return NULL; } -EXPORT_SYMBOL_GPL(gss_service_to_auth_domain_name); - void gss_mech_put(struct gss_api_mech * gm) { if (gm) module_put(gm->gm_owner); } - -EXPORT_SYMBOL_GPL(gss_mech_put); +EXPORT_SYMBOL(gss_mech_put); /* The mech could probably be determined from the token instead, but it's just * as easy for now to pass it in.
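+ * + * A minimal caller sketch against the new signature below ("token", "toklen" and "mech" are assumed valid; GFP_KERNEL is only an example mask): + * + * struct gss_ctx *ctx; + * time_t expiry; + * int err = gss_import_sec_context(token, toklen, mech, + * &ctx, &expiry, GFP_KERNEL);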
*/ int gss_import_sec_context(const void *input_token, size_t bufsize, struct gss_api_mech *mech, - struct gss_ctx **ctx_id) + struct gss_ctx **ctx_id, + time_t *endtime, + gfp_t gfp_mask) { - if (!(*ctx_id = kzalloc(sizeof(**ctx_id), GFP_KERNEL))) + if (!(*ctx_id = kzalloc(sizeof(**ctx_id), gfp_mask))) return -ENOMEM; (*ctx_id)->mech_type = gss_mech_get(mech); - return mech->gm_ops - ->gss_import_sec_context(input_token, bufsize, *ctx_id); + return mech->gm_ops->gss_import_sec_context(input_token, bufsize, + *ctx_id, endtime, gfp_mask); } /* gss_get_mic: compute a mic over message and return mic_token. */ @@ -285,6 +424,20 @@ gss_verify_mic(struct gss_ctx *context_handle, mic_token); } +/* + * This function is called from both the client and server code. + * Each makes guarantees about how much "slack" space is available + * for the underlying function in "buf"'s head and tail while + * performing the wrap. + * + * The client and server code allocate RPC_MAX_AUTH_SIZE extra + * space in both the head and tail which is available for use by + * the wrap function. + * + * Underlying functions should verify they do not use more than + * RPC_MAX_AUTH_SIZE of extra space in either the head or tail + * when performing the wrap. + */ u32 gss_wrap(struct gss_ctx *ctx_id, int offset, @@ -316,7 +469,7 @@ gss_delete_sec_context(struct gss_ctx **context_handle) *context_handle); if (!*context_handle) - return(GSS_S_NO_CONTEXT); + return GSS_S_NO_CONTEXT; if ((*context_handle)->internal_ctx_id) (*context_handle)->mech_type->gm_ops ->gss_delete_sec_context((*context_handle) diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.c b/net/sunrpc/auth_gss/gss_rpc_upcall.c new file mode 100644 index 00000000000..abbb7dcd168 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.c @@ -0,0 +1,382 @@ +/* + * linux/net/sunrpc/gss_rpc_upcall.c + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. 
+ */ + +#include <linux/types.h> +#include <linux/un.h> + +#include <linux/sunrpc/svcauth.h> +#include "gss_rpc_upcall.h" + +#define GSSPROXY_SOCK_PATHNAME "/var/run/gssproxy.sock" + +#define GSSPROXY_PROGRAM (400112u) +#define GSSPROXY_VERS_1 (1u) + +/* + * Encoding/Decoding functions + */ + +enum { + GSSX_NULL = 0, /* Unused */ + GSSX_INDICATE_MECHS = 1, + GSSX_GET_CALL_CONTEXT = 2, + GSSX_IMPORT_AND_CANON_NAME = 3, + GSSX_EXPORT_CRED = 4, + GSSX_IMPORT_CRED = 5, + GSSX_ACQUIRE_CRED = 6, + GSSX_STORE_CRED = 7, + GSSX_INIT_SEC_CONTEXT = 8, + GSSX_ACCEPT_SEC_CONTEXT = 9, + GSSX_RELEASE_HANDLE = 10, + GSSX_GET_MIC = 11, + GSSX_VERIFY = 12, + GSSX_WRAP = 13, + GSSX_UNWRAP = 14, + GSSX_WRAP_SIZE_LIMIT = 15, +}; + +#define PROC(proc, name) \ +[GSSX_##proc] = { \ + .p_proc = GSSX_##proc, \ + .p_encode = (kxdreproc_t)gssx_enc_##name, \ + .p_decode = (kxdrdproc_t)gssx_dec_##name, \ + .p_arglen = GSSX_ARG_##name##_sz, \ + .p_replen = GSSX_RES_##name##_sz, \ + .p_statidx = GSSX_##proc, \ + .p_name = #proc, \ +} + +static struct rpc_procinfo gssp_procedures[] = { + PROC(INDICATE_MECHS, indicate_mechs), + PROC(GET_CALL_CONTEXT, get_call_context), + PROC(IMPORT_AND_CANON_NAME, import_and_canon_name), + PROC(EXPORT_CRED, export_cred), + PROC(IMPORT_CRED, import_cred), + PROC(ACQUIRE_CRED, acquire_cred), + PROC(STORE_CRED, store_cred), + PROC(INIT_SEC_CONTEXT, init_sec_context), + PROC(ACCEPT_SEC_CONTEXT, accept_sec_context), + PROC(RELEASE_HANDLE, release_handle), + PROC(GET_MIC, get_mic), + PROC(VERIFY, verify), + PROC(WRAP, wrap), + PROC(UNWRAP, unwrap), + PROC(WRAP_SIZE_LIMIT, wrap_size_limit), +}; + + + +/* + * Common transport functions + */ + +static const struct rpc_program gssp_program; + +static int gssp_rpc_create(struct net *net, struct rpc_clnt **_clnt) +{ + static const struct sockaddr_un gssp_localaddr = { + .sun_family = AF_LOCAL, + .sun_path = GSSPROXY_SOCK_PATHNAME, + }; + struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_LOCAL, + .address = (struct sockaddr *)&gssp_localaddr, + .addrsize = sizeof(gssp_localaddr), + .servername = "localhost", + .program = &gssp_program, + .version = GSSPROXY_VERS_1, + .authflavor = RPC_AUTH_NULL, + /* + * Note we want connection to be done in the caller's + * filesystem namespace. 
We therefore turn off the idle + * timeout, which would result in reconnections being + * done without the correct namespace: + */ + .flags = RPC_CLNT_CREATE_NOPING | + RPC_CLNT_CREATE_NO_IDLE_TIMEOUT + }; + struct rpc_clnt *clnt; + int result = 0; + + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + dprintk("RPC: failed to create AF_LOCAL gssproxy " + "client (errno %ld).\n", PTR_ERR(clnt)); + result = PTR_ERR(clnt); + *_clnt = NULL; + goto out; + } + + dprintk("RPC: created new gssp local client (gssp_local_clnt: " + "%p)\n", clnt); + *_clnt = clnt; + +out: + return result; +} + +void init_gssp_clnt(struct sunrpc_net *sn) +{ + mutex_init(&sn->gssp_lock); + sn->gssp_clnt = NULL; +} + +int set_gssp_clnt(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + int ret; + + mutex_lock(&sn->gssp_lock); + ret = gssp_rpc_create(net, &clnt); + if (!ret) { + if (sn->gssp_clnt) + rpc_shutdown_client(sn->gssp_clnt); + sn->gssp_clnt = clnt; + } + mutex_unlock(&sn->gssp_lock); + return ret; +} + +void clear_gssp_clnt(struct sunrpc_net *sn) +{ + mutex_lock(&sn->gssp_lock); + if (sn->gssp_clnt) { + rpc_shutdown_client(sn->gssp_clnt); + sn->gssp_clnt = NULL; + } + mutex_unlock(&sn->gssp_lock); +} + +static struct rpc_clnt *get_gssp_clnt(struct sunrpc_net *sn) +{ + struct rpc_clnt *clnt; + + mutex_lock(&sn->gssp_lock); + clnt = sn->gssp_clnt; + if (clnt) + atomic_inc(&clnt->cl_count); + mutex_unlock(&sn->gssp_lock); + return clnt; +} + +static int gssp_call(struct net *net, struct rpc_message *msg) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + int status; + + clnt = get_gssp_clnt(sn); + if (!clnt) + return -EIO; + status = rpc_call_sync(clnt, msg, 0); + if (status < 0) { + dprintk("gssp: rpc_call returned error %d\n", -status); + switch (status) { + case -EPROTONOSUPPORT: + status = -EINVAL; + break; + case -ECONNREFUSED: + case -ETIMEDOUT: + case -ENOTCONN: + status = -EAGAIN; + break; + case -ERESTARTSYS: + if (signalled ()) + status = -EINTR; + break; + default: + break; + } + } + rpc_release_client(clnt); + return status; +} + +static void gssp_free_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + int i; + + for (i = 0; i < arg->npages && arg->pages[i]; i++) + __free_page(arg->pages[i]); +} + +static int gssp_alloc_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + arg->npages = DIV_ROUND_UP(NGROUPS_MAX * 4, PAGE_SIZE); + arg->pages = kzalloc(arg->npages * sizeof(struct page *), GFP_KERNEL); + /* + * XXX: actual pages are allocated by xdr layer in + * xdr_partial_copy_from_skb. 
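+ * The page count above is sized so that the largest expected reply, a group list of NGROUPS_MAX 32-bit gids, fits in the receive buffer.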
+ */ + if (!arg->pages) + return -ENOMEM; + return 0; +} + +/* + * Public functions + */ + +/* numbers somewhat arbitrary but large enough for current needs */ +#define GSSX_MAX_OUT_HANDLE 128 +#define GSSX_MAX_SRC_PRINC 256 +#define GSSX_KMEMBUF (GSSX_max_output_handle_sz + \ + GSSX_max_oid_sz + \ + GSSX_max_princ_sz + \ + sizeof(struct svc_cred)) + +int gssp_accept_sec_context_upcall(struct net *net, + struct gssp_upcall_data *data) +{ + struct gssx_ctx ctxh = { + .state = data->in_handle + }; + struct gssx_arg_accept_sec_context arg = { + .input_token = data->in_token, + }; + struct gssx_ctx rctxh = { + /* + * pass in the max length we expect for each of these + * buffers but let the xdr code kmalloc them: + */ + .exported_context_token.len = GSSX_max_output_handle_sz, + .mech.len = GSS_OID_MAX_LEN, + .src_name.display_name.len = GSSX_max_princ_sz + }; + struct gssx_res_accept_sec_context res = { + .context_handle = &rctxh, + .output_token = &data->out_token + }; + struct rpc_message msg = { + .rpc_proc = &gssp_procedures[GSSX_ACCEPT_SEC_CONTEXT], + .rpc_argp = &arg, + .rpc_resp = &res, + .rpc_cred = NULL, /* FIXME ? */ + }; + struct xdr_netobj client_name = { 0, NULL }; + int ret; + + if (data->in_handle.len != 0) + arg.context_handle = &ctxh; + res.output_token->len = GSSX_max_output_token_sz; + + ret = gssp_alloc_receive_pages(&arg); + if (ret) + return ret; + + /* use nfs/ for targ_name ? */ + + ret = gssp_call(net, &msg); + + gssp_free_receive_pages(&arg); + + /* we need to fetch all data even in case of error so + * that we can free special structures if they have been allocated */ + data->major_status = res.status.major_status; + data->minor_status = res.status.minor_status; + if (res.context_handle) { + data->out_handle = rctxh.exported_context_token; + data->mech_oid.len = rctxh.mech.len; + if (rctxh.mech.data) + memcpy(data->mech_oid.data, rctxh.mech.data, + data->mech_oid.len); + client_name = rctxh.src_name.display_name; + } + + if (res.options.count == 1) { + gssx_buffer *value = &res.options.data[0].value; + /* Currently we only decode CREDS_VALUE; if we add + * anything else we'll have to loop and match on the + * option name */ + if (value->len == 1) { + /* steal group info from struct svc_cred */ + data->creds = *(struct svc_cred *)value->data; + data->found_creds = 1; + } + /* whether we use it or not, free data */ + kfree(value->data); + } + + if (res.options.count != 0) { + kfree(res.options.data); + } + + /* convert to GSS_NT_HOSTBASED_SERVICE form and set into creds */ + if (data->found_creds && client_name.data != NULL) { + char *c; + + data->creds.cr_principal = kstrndup(client_name.data, + client_name.len, GFP_KERNEL); + if (data->creds.cr_principal) { + /* terminate and remove realm part */ + c = strchr(data->creds.cr_principal, '@'); + if (c) { + *c = '\0'; + + /* change service-hostname delimiter */ + c = strchr(data->creds.cr_principal, '/'); + if (c) *c = '@'; + } + if (!c) { + /* not a service principal */ + kfree(data->creds.cr_principal); + data->creds.cr_principal = NULL; + } + } + } + kfree(client_name.data); + + return ret; +} + +void gssp_free_upcall_data(struct gssp_upcall_data *data) +{ + kfree(data->in_handle.data); + kfree(data->out_handle.data); + kfree(data->out_token.data); + free_svc_cred(&data->creds); +} + +/* + * Initialization stuff + */ + +static const struct rpc_version gssp_version1 = { + .number = GSSPROXY_VERS_1, + .nrprocs = ARRAY_SIZE(gssp_procedures), + .procs = gssp_procedures, +}; + +static const struct rpc_version
*gssp_version[] = { + NULL, + &gssp_version1, +}; + +static struct rpc_stat gssp_stats; + +static const struct rpc_program gssp_program = { + .name = "gssproxy", + .number = GSSPROXY_PROGRAM, + .nrvers = ARRAY_SIZE(gssp_version), + .version = gssp_version, + .stats = &gssp_stats, +}; diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.h b/net/sunrpc/auth_gss/gss_rpc_upcall.h new file mode 100644 index 00000000000..1e542aded90 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.h @@ -0,0 +1,48 @@ +/* + * linux/net/sunrpc/gss_rpc_upcall.h + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + */ + +#ifndef _GSS_RPC_UPCALL_H +#define _GSS_RPC_UPCALL_H + +#include <linux/sunrpc/gss_api.h> +#include <linux/sunrpc/auth_gss.h> +#include "gss_rpc_xdr.h" +#include "../netns.h" + +struct gssp_upcall_data { + struct xdr_netobj in_handle; + struct gssp_in_token in_token; + struct xdr_netobj out_handle; + struct xdr_netobj out_token; + struct rpcsec_gss_oid mech_oid; + struct svc_cred creds; + int found_creds; + int major_status; + int minor_status; +}; + +int gssp_accept_sec_context_upcall(struct net *net, + struct gssp_upcall_data *data); +void gssp_free_upcall_data(struct gssp_upcall_data *data); + +void init_gssp_clnt(struct sunrpc_net *); +int set_gssp_clnt(struct net *); +void clear_gssp_clnt(struct sunrpc_net *); +#endif /* _GSS_RPC_UPCALL_H */ diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.c b/net/sunrpc/auth_gss/gss_rpc_xdr.c new file mode 100644 index 00000000000..1ec19f6f0c2 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.c @@ -0,0 +1,839 @@ +/* + * GSS Proxy upcall module + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + */ + +#include <linux/sunrpc/svcauth.h> +#include "gss_rpc_xdr.h" + +static int gssx_enc_bool(struct xdr_stream *xdr, int v) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + *p = v ? 
xdr_one : xdr_zero; + return 0; +} + +static int gssx_dec_bool(struct xdr_stream *xdr, u32 *v) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + *v = be32_to_cpu(*p); + return 0; +} + +static int gssx_enc_buffer(struct xdr_stream *xdr, + gssx_buffer *buf) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, sizeof(u32) + buf->len); + if (!p) + return -ENOSPC; + xdr_encode_opaque(p, buf->data, buf->len); + return 0; +} + +static int gssx_enc_in_token(struct xdr_stream *xdr, + struct gssp_in_token *in) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = cpu_to_be32(in->page_len); + + /* all we need to do is to write pages */ + xdr_write_pages(xdr, in->pages, in->page_base, in->page_len); + + return 0; +} + + +static int gssx_dec_buffer(struct xdr_stream *xdr, + gssx_buffer *buf) +{ + u32 length; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + p = xdr_inline_decode(xdr, length); + if (unlikely(p == NULL)) + return -ENOSPC; + + if (buf->len == 0) { + /* we intentionally are not interested in this buffer */ + return 0; + } + if (length > buf->len) + return -ENOSPC; + + if (!buf->data) { + buf->data = kmemdup(p, length, GFP_KERNEL); + if (!buf->data) + return -ENOMEM; + } else { + memcpy(buf->data, p, length); + } + buf->len = length; + return 0; +} + +static int gssx_enc_option(struct xdr_stream *xdr, + struct gssx_option *opt) +{ + int err; + + err = gssx_enc_buffer(xdr, &opt->option); + if (err) + return err; + err = gssx_enc_buffer(xdr, &opt->value); + return err; +} + +static int gssx_dec_option(struct xdr_stream *xdr, + struct gssx_option *opt) +{ + int err; + + err = gssx_dec_buffer(xdr, &opt->option); + if (err) + return err; + err = gssx_dec_buffer(xdr, &opt->value); + return err; +} + +static int dummy_enc_opt_array(struct xdr_stream *xdr, + struct gssx_option_array *oa) +{ + __be32 *p; + + if (oa->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int dummy_dec_opt_array(struct xdr_stream *xdr, + struct gssx_option_array *oa) +{ + struct gssx_option dummy; + u32 count, i; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + memset(&dummy, 0, sizeof(dummy)); + for (i = 0; i < count; i++) { + gssx_dec_option(xdr, &dummy); + } + + oa->count = 0; + oa->data = NULL; + return 0; +} + +static int get_host_u32(struct xdr_stream *xdr, u32 *res) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (!p) + return -EINVAL; + /* Contents of linux creds are all host-endian: */ + memcpy(res, p, sizeof(u32)); + return 0; +} + +static int gssx_dec_linux_creds(struct xdr_stream *xdr, + struct svc_cred *creds) +{ + u32 length; + __be32 *p; + u32 tmp; + u32 N; + int i, err; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + + if (length > (3 + NGROUPS_MAX) * sizeof(u32)) + return -ENOSPC; + + /* uid */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + creds->cr_uid = make_kuid(&init_user_ns, tmp); + + /* gid */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + creds->cr_gid = make_kgid(&init_user_ns, tmp); + + /* number of additional gid's */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + N = tmp; + if ((3 + N) * sizeof(u32) != length) + return -EINVAL; + creds->cr_group_info = groups_alloc(N); + if 
(creds->cr_group_info == NULL) + return -ENOMEM; + + /* gid's */ + for (i = 0; i < N; i++) { + kgid_t kgid; + err = get_host_u32(xdr, &tmp); + if (err) + goto out_free_groups; + err = -EINVAL; + kgid = make_kgid(&init_user_ns, tmp); + if (!gid_valid(kgid)) + goto out_free_groups; + GROUP_AT(creds->cr_group_info, i) = kgid; + } + + return 0; +out_free_groups: + groups_free(creds->cr_group_info); + return err; +} + +static int gssx_dec_option_array(struct xdr_stream *xdr, + struct gssx_option_array *oa) +{ + struct svc_cred *creds; + u32 count, i; + __be32 *p; + int err; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + if (!count) + return 0; + + /* we recognize only 1 currently: CREDS_VALUE */ + oa->count = 1; + + oa->data = kmalloc(sizeof(struct gssx_option), GFP_KERNEL); + if (!oa->data) + return -ENOMEM; + + creds = kmalloc(sizeof(struct svc_cred), GFP_KERNEL); + if (!creds) { + kfree(oa->data); + return -ENOMEM; + } + + oa->data[0].option.data = CREDS_VALUE; + oa->data[0].option.len = sizeof(CREDS_VALUE); + oa->data[0].value.data = (void *)creds; + oa->data[0].value.len = 0; + + for (i = 0; i < count; i++) { + gssx_buffer dummy = { 0, NULL }; + u32 length; + + /* option buffer */ + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + p = xdr_inline_decode(xdr, length); + if (unlikely(p == NULL)) + return -ENOSPC; + + if (length == sizeof(CREDS_VALUE) && + memcmp(p, CREDS_VALUE, sizeof(CREDS_VALUE)) == 0) { + /* We have creds here. parse them */ + err = gssx_dec_linux_creds(xdr, creds); + if (err) + return err; + oa->data[0].value.len = 1; /* presence */ + } else { + /* consume uninteresting buffer */ + err = gssx_dec_buffer(xdr, &dummy); + if (err) + return err; + } + } + return 0; +} + +static int gssx_dec_status(struct xdr_stream *xdr, + struct gssx_status *status) +{ + __be32 *p; + int err; + + /* status->major_status */ + p = xdr_inline_decode(xdr, 8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &status->major_status); + + /* status->mech */ + err = gssx_dec_buffer(xdr, &status->mech); + if (err) + return err; + + /* status->minor_status */ + p = xdr_inline_decode(xdr, 8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &status->minor_status); + + /* status->major_status_string */ + err = gssx_dec_buffer(xdr, &status->major_status_string); + if (err) + return err; + + /* status->minor_status_string */ + err = gssx_dec_buffer(xdr, &status->minor_status_string); + if (err) + return err; + + /* status->server_ctx */ + err = gssx_dec_buffer(xdr, &status->server_ctx); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* status->options */ + err = dummy_dec_opt_array(xdr, &status->options); + + return err; +} + +static int gssx_enc_call_ctx(struct xdr_stream *xdr, + struct gssx_call_ctx *ctx) +{ + struct gssx_option opt; + __be32 *p; + int err; + + /* ctx->locale */ + err = gssx_enc_buffer(xdr, &ctx->locale); + if (err) + return err; + + /* ctx->server_ctx */ + err = gssx_enc_buffer(xdr, &ctx->server_ctx); + if (err) + return err; + + /* we always want to ask for lucid contexts */ + /* ctx->options */ + p = xdr_reserve_space(xdr, 4); + *p = cpu_to_be32(2); + + /* we want a lucid_v1 context */ + opt.option.data = LUCID_OPTION; + opt.option.len = sizeof(LUCID_OPTION); + opt.value.data = LUCID_VALUE; + opt.value.len = sizeof(LUCID_VALUE); + err = gssx_enc_option(xdr, &opt); + + /* 
..and user creds */ + opt.option.data = CREDS_OPTION; + opt.option.len = sizeof(CREDS_OPTION); + opt.value.data = CREDS_VALUE; + opt.value.len = sizeof(CREDS_VALUE); + err = gssx_enc_option(xdr, &opt); + + return err; +} + +static int gssx_dec_name_attr(struct xdr_stream *xdr, + struct gssx_name_attr *attr) +{ + int err; + + /* attr->attr */ + err = gssx_dec_buffer(xdr, &attr->attr); + if (err) + return err; + + /* attr->value */ + err = gssx_dec_buffer(xdr, &attr->value); + if (err) + return err; + + /* attr->extensions */ + err = dummy_dec_opt_array(xdr, &attr->extensions); + + return err; +} + +static int dummy_enc_nameattr_array(struct xdr_stream *xdr, + struct gssx_name_attr_array *naa) +{ + __be32 *p; + + if (naa->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int dummy_dec_nameattr_array(struct xdr_stream *xdr, + struct gssx_name_attr_array *naa) +{ + struct gssx_name_attr dummy = { .attr = {.len = 0} }; + u32 count, i; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + for (i = 0; i < count; i++) { + gssx_dec_name_attr(xdr, &dummy); + } + + naa->count = 0; + naa->data = NULL; + return 0; +} + +static struct xdr_netobj zero_netobj = {}; + +static struct gssx_name_attr_array zero_name_attr_array = {}; + +static struct gssx_option_array zero_option_array = {}; + +static int gssx_enc_name(struct xdr_stream *xdr, + struct gssx_name *name) +{ + int err; + + /* name->display_name */ + err = gssx_enc_buffer(xdr, &name->display_name); + if (err) + return err; + + /* name->name_type */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* name->exported_name */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* name->exported_composite_name */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* leave name_attributes empty for now, will add once we have any + * to pass up at all */ + /* name->name_attributes */ + err = dummy_enc_nameattr_array(xdr, &zero_name_attr_array); + if (err) + return err; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* name->extensions */ + err = dummy_enc_opt_array(xdr, &zero_option_array); + + return err; +} + + +static int gssx_dec_name(struct xdr_stream *xdr, + struct gssx_name *name) +{ + struct xdr_netobj dummy_netobj = { .len = 0 }; + struct gssx_name_attr_array dummy_name_attr_array = { .count = 0 }; + struct gssx_option_array dummy_option_array = { .count = 0 }; + int err; + + /* name->display_name */ + err = gssx_dec_buffer(xdr, &name->display_name); + if (err) + return err; + + /* name->name_type */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* name->exported_name */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* name->exported_composite_name */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* we assume we have no attributes for now, so simply consume them */ + /* name->name_attributes */ + err = dummy_dec_nameattr_array(xdr, &dummy_name_attr_array); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* name->extensions */ + err = dummy_dec_opt_array(xdr, &dummy_option_array); + + return err; +} + +static int dummy_enc_credel_array(struct xdr_stream *xdr, + struct gssx_cred_element_array *cea) +{ + __be32 *p; + + if (cea->count != 0) + return 
-EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int gssx_enc_cred(struct xdr_stream *xdr, + struct gssx_cred *cred) +{ + int err; + + /* cred->desired_name */ + err = gssx_enc_name(xdr, &cred->desired_name); + if (err) + return err; + + /* cred->elements */ + err = dummy_enc_credel_array(xdr, &cred->elements); + if (err) + return err; + + /* cred->cred_handle_reference */ + err = gssx_enc_buffer(xdr, &cred->cred_handle_reference); + if (err) + return err; + + /* cred->needs_release */ + err = gssx_enc_bool(xdr, cred->needs_release); + + return err; +} + +static int gssx_enc_ctx(struct xdr_stream *xdr, + struct gssx_ctx *ctx) +{ + __be32 *p; + int err; + + /* ctx->exported_context_token */ + err = gssx_enc_buffer(xdr, &ctx->exported_context_token); + if (err) + return err; + + /* ctx->state */ + err = gssx_enc_buffer(xdr, &ctx->state); + if (err) + return err; + + /* ctx->need_release */ + err = gssx_enc_bool(xdr, ctx->need_release); + if (err) + return err; + + /* ctx->mech */ + err = gssx_enc_buffer(xdr, &ctx->mech); + if (err) + return err; + + /* ctx->src_name */ + err = gssx_enc_name(xdr, &ctx->src_name); + if (err) + return err; + + /* ctx->targ_name */ + err = gssx_enc_name(xdr, &ctx->targ_name); + if (err) + return err; + + /* ctx->lifetime */ + p = xdr_reserve_space(xdr, 8+8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, ctx->lifetime); + + /* ctx->ctx_flags */ + p = xdr_encode_hyper(p, ctx->ctx_flags); + + /* ctx->locally_initiated */ + err = gssx_enc_bool(xdr, ctx->locally_initiated); + if (err) + return err; + + /* ctx->open */ + err = gssx_enc_bool(xdr, ctx->open); + if (err) + return err; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* ctx->options */ + err = dummy_enc_opt_array(xdr, &ctx->options); + + return err; +} + +static int gssx_dec_ctx(struct xdr_stream *xdr, + struct gssx_ctx *ctx) +{ + __be32 *p; + int err; + + /* ctx->exported_context_token */ + err = gssx_dec_buffer(xdr, &ctx->exported_context_token); + if (err) + return err; + + /* ctx->state */ + err = gssx_dec_buffer(xdr, &ctx->state); + if (err) + return err; + + /* ctx->need_release */ + err = gssx_dec_bool(xdr, &ctx->need_release); + if (err) + return err; + + /* ctx->mech */ + err = gssx_dec_buffer(xdr, &ctx->mech); + if (err) + return err; + + /* ctx->src_name */ + err = gssx_dec_name(xdr, &ctx->src_name); + if (err) + return err; + + /* ctx->targ_name */ + err = gssx_dec_name(xdr, &ctx->targ_name); + if (err) + return err; + + /* ctx->lifetime */ + p = xdr_inline_decode(xdr, 8+8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &ctx->lifetime); + + /* ctx->ctx_flags */ + p = xdr_decode_hyper(p, &ctx->ctx_flags); + + /* ctx->locally_initiated */ + err = gssx_dec_bool(xdr, &ctx->locally_initiated); + if (err) + return err; + + /* ctx->open */ + err = gssx_dec_bool(xdr, &ctx->open); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* ctx->options */ + err = dummy_dec_opt_array(xdr, &ctx->options); + + return err; +} + +static int gssx_enc_cb(struct xdr_stream *xdr, struct gssx_cb *cb) +{ + __be32 *p; + int err; + + /* cb->initiator_addrtype */ + p = xdr_reserve_space(xdr, 8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, cb->initiator_addrtype); + + /* cb->initiator_address */ + err = gssx_enc_buffer(xdr, &cb->initiator_address); + if (err) + return err; + + /* cb->acceptor_addrtype */ + p = 
xdr_reserve_space(xdr, 8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, cb->acceptor_addrtype); + + /* cb->acceptor_address */ + err = gssx_enc_buffer(xdr, &cb->acceptor_address); + if (err) + return err; + + /* cb->application_data */ + err = gssx_enc_buffer(xdr, &cb->application_data); + + return err; +} + +void gssx_enc_accept_sec_context(struct rpc_rqst *req, + struct xdr_stream *xdr, + struct gssx_arg_accept_sec_context *arg) +{ + int err; + + err = gssx_enc_call_ctx(xdr, &arg->call_ctx); + if (err) + goto done; + + /* arg->context_handle */ + if (arg->context_handle) + err = gssx_enc_ctx(xdr, arg->context_handle); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + /* arg->cred_handle */ + if (arg->cred_handle) + err = gssx_enc_cred(xdr, arg->cred_handle); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + /* arg->input_token */ + err = gssx_enc_in_token(xdr, &arg->input_token); + if (err) + goto done; + + /* arg->input_cb */ + if (arg->input_cb) + err = gssx_enc_cb(xdr, arg->input_cb); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + err = gssx_enc_bool(xdr, arg->ret_deleg_cred); + if (err) + goto done; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* arg->options */ + err = dummy_enc_opt_array(xdr, &arg->options); + + xdr_inline_pages(&req->rq_rcv_buf, + PAGE_SIZE/2 /* pretty arbitrary */, + arg->pages, 0 /* page base */, arg->npages * PAGE_SIZE); +done: + if (err) + dprintk("RPC: gssx_enc_accept_sec_context: %d\n", err); +} + +int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, + struct xdr_stream *xdr, + struct gssx_res_accept_sec_context *res) +{ + u32 value_follows; + int err; + + /* res->status */ + err = gssx_dec_status(xdr, &res->status); + if (err) + return err; + + /* res->context_handle */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + return err; + if (value_follows) { + err = gssx_dec_ctx(xdr, res->context_handle); + if (err) + return err; + } else { + res->context_handle = NULL; + } + + /* res->output_token */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + return err; + if (value_follows) { + err = gssx_dec_buffer(xdr, res->output_token); + if (err) + return err; + } else { + res->output_token = NULL; + } + + /* res->delegated_cred_handle */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + return err; + if (value_follows) { + /* we do not support upcall servers sending this data. */ + return -EINVAL; + } + + /* res->options */ + err = gssx_dec_option_array(xdr, &res->options); + + return err; +} diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.h b/net/sunrpc/auth_gss/gss_rpc_xdr.h new file mode 100644 index 00000000000..685a688f3d8 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.h @@ -0,0 +1,267 @@ +/* + * GSS Proxy upcall module + * + * Copyright (C) 2012 Simo Sorce <simo@redhat.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. 
+ * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + */ + +#ifndef _LINUX_GSS_RPC_XDR_H +#define _LINUX_GSS_RPC_XDR_H + +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/xprtsock.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define LUCID_OPTION "exported_context_type" +#define LUCID_VALUE "linux_lucid_v1" +#define CREDS_OPTION "exported_creds_type" +#define CREDS_VALUE "linux_creds_v1" + +typedef struct xdr_netobj gssx_buffer; +typedef struct xdr_netobj utf8string; +typedef struct xdr_netobj gssx_OID; + +enum gssx_cred_usage { + GSSX_C_INITIATE = 1, + GSSX_C_ACCEPT = 2, + GSSX_C_BOTH = 3, +}; + +struct gssx_option { + gssx_buffer option; + gssx_buffer value; +}; + +struct gssx_option_array { + u32 count; + struct gssx_option *data; +}; + +struct gssx_status { + u64 major_status; + gssx_OID mech; + u64 minor_status; + utf8string major_status_string; + utf8string minor_status_string; + gssx_buffer server_ctx; + struct gssx_option_array options; +}; + +struct gssx_call_ctx { + utf8string locale; + gssx_buffer server_ctx; + struct gssx_option_array options; +}; + +struct gssx_name_attr { + gssx_buffer attr; + gssx_buffer value; + struct gssx_option_array extensions; +}; + +struct gssx_name_attr_array { + u32 count; + struct gssx_name_attr *data; +}; + +struct gssx_name { + gssx_buffer display_name; +}; +typedef struct gssx_name gssx_name; + +struct gssx_cred_element { + gssx_name MN; + gssx_OID mech; + u32 cred_usage; + u64 initiator_time_rec; + u64 acceptor_time_rec; + struct gssx_option_array options; +}; + +struct gssx_cred_element_array { + u32 count; + struct gssx_cred_element *data; +}; + +struct gssx_cred { + gssx_name desired_name; + struct gssx_cred_element_array elements; + gssx_buffer cred_handle_reference; + u32 needs_release; +}; + +struct gssx_ctx { + gssx_buffer exported_context_token; + gssx_buffer state; + u32 need_release; + gssx_OID mech; + gssx_name src_name; + gssx_name targ_name; + u64 lifetime; + u64 ctx_flags; + u32 locally_initiated; + u32 open; + struct gssx_option_array options; +}; + +struct gssx_cb { + u64 initiator_addrtype; + gssx_buffer initiator_address; + u64 acceptor_addrtype; + gssx_buffer acceptor_address; + gssx_buffer application_data; +}; + + +/* This structure is not defined in the protocol. 
+ * It is used in the kernel to carry around a big buffer + * as a set of pages */ +struct gssp_in_token { + struct page **pages; /* Array of contiguous pages */ + unsigned int page_base; /* Start of page data */ + unsigned int page_len; /* Length of page data */ +}; + +struct gssx_arg_accept_sec_context { + struct gssx_call_ctx call_ctx; + struct gssx_ctx *context_handle; + struct gssx_cred *cred_handle; + struct gssp_in_token input_token; + struct gssx_cb *input_cb; + u32 ret_deleg_cred; + struct gssx_option_array options; + struct page **pages; + unsigned int npages; +}; + +struct gssx_res_accept_sec_context { + struct gssx_status status; + struct gssx_ctx *context_handle; + gssx_buffer *output_token; + /* struct gssx_cred *delegated_cred_handle; not used in kernel */ + struct gssx_option_array options; +}; + + + +#define gssx_enc_indicate_mechs NULL +#define gssx_dec_indicate_mechs NULL +#define gssx_enc_get_call_context NULL +#define gssx_dec_get_call_context NULL +#define gssx_enc_import_and_canon_name NULL +#define gssx_dec_import_and_canon_name NULL +#define gssx_enc_export_cred NULL +#define gssx_dec_export_cred NULL +#define gssx_enc_import_cred NULL +#define gssx_dec_import_cred NULL +#define gssx_enc_acquire_cred NULL +#define gssx_dec_acquire_cred NULL +#define gssx_enc_store_cred NULL +#define gssx_dec_store_cred NULL +#define gssx_enc_init_sec_context NULL +#define gssx_dec_init_sec_context NULL +void gssx_enc_accept_sec_context(struct rpc_rqst *req, + struct xdr_stream *xdr, + struct gssx_arg_accept_sec_context *args); +int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, + struct xdr_stream *xdr, + struct gssx_res_accept_sec_context *res); +#define gssx_enc_release_handle NULL +#define gssx_dec_release_handle NULL +#define gssx_enc_get_mic NULL +#define gssx_dec_get_mic NULL +#define gssx_enc_verify NULL +#define gssx_dec_verify NULL +#define gssx_enc_wrap NULL +#define gssx_dec_wrap NULL +#define gssx_enc_unwrap NULL +#define gssx_dec_unwrap NULL +#define gssx_enc_wrap_size_limit NULL +#define gssx_dec_wrap_size_limit NULL + +/* non implemented calls are set to 0 size */ +#define GSSX_ARG_indicate_mechs_sz 0 +#define GSSX_RES_indicate_mechs_sz 0 +#define GSSX_ARG_get_call_context_sz 0 +#define GSSX_RES_get_call_context_sz 0 +#define GSSX_ARG_import_and_canon_name_sz 0 +#define GSSX_RES_import_and_canon_name_sz 0 +#define GSSX_ARG_export_cred_sz 0 +#define GSSX_RES_export_cred_sz 0 +#define GSSX_ARG_import_cred_sz 0 +#define GSSX_RES_import_cred_sz 0 +#define GSSX_ARG_acquire_cred_sz 0 +#define GSSX_RES_acquire_cred_sz 0 +#define GSSX_ARG_store_cred_sz 0 +#define GSSX_RES_store_cred_sz 0 +#define GSSX_ARG_init_sec_context_sz 0 +#define GSSX_RES_init_sec_context_sz 0 + +#define GSSX_default_in_call_ctx_sz (4 + 4 + 4 + \ + 8 + sizeof(LUCID_OPTION) + sizeof(LUCID_VALUE) + \ + 8 + sizeof(CREDS_OPTION) + sizeof(CREDS_VALUE)) +#define GSSX_default_in_ctx_hndl_sz (4 + 4+8 + 4 + 4 + 6*4 + 6*4 + 8 + 8 + \ + 4 + 4 + 4) +#define GSSX_default_in_cred_sz 4 /* we send in no cred_handle */ +#define GSSX_default_in_token_sz 4 /* does *not* include token data */ +#define GSSX_default_in_cb_sz 4 /* we do not use channel bindings */ +#define GSSX_ARG_accept_sec_context_sz (GSSX_default_in_call_ctx_sz + \ + GSSX_default_in_ctx_hndl_sz + \ + GSSX_default_in_cred_sz + \ + GSSX_default_in_token_sz + \ + GSSX_default_in_cb_sz + \ + 4 /* no deleg creds boolean */ + \ + 4) /* empty options */ + +/* somewhat arbitrary numbers but large enough (we ignore some of the data + * sent down, but it is 
part of the protocol so we need enough space to take + * it in) */ +#define GSSX_default_status_sz 8 + 24 + 8 + 256 + 256 + 16 + 4 +#define GSSX_max_output_handle_sz 128 +#define GSSX_max_oid_sz 16 +#define GSSX_max_princ_sz 256 +#define GSSX_default_ctx_sz (GSSX_max_output_handle_sz + \ + 16 + 4 + GSSX_max_oid_sz + \ + 2 * GSSX_max_princ_sz + \ + 8 + 8 + 4 + 4 + 4) +#define GSSX_max_output_token_sz 1024 +/* grouplist not included; we allocate separate pages for that: */ +#define GSSX_max_creds_sz (4 + 4 + 4 /* + NGROUPS_MAX*4 */) +#define GSSX_RES_accept_sec_context_sz (GSSX_default_status_sz + \ + GSSX_default_ctx_sz + \ + GSSX_max_output_token_sz + \ + 4 + GSSX_max_creds_sz) + +#define GSSX_ARG_release_handle_sz 0 +#define GSSX_RES_release_handle_sz 0 +#define GSSX_ARG_get_mic_sz 0 +#define GSSX_RES_get_mic_sz 0 +#define GSSX_ARG_verify_sz 0 +#define GSSX_RES_verify_sz 0 +#define GSSX_ARG_wrap_sz 0 +#define GSSX_RES_wrap_sz 0 +#define GSSX_ARG_unwrap_sz 0 +#define GSSX_RES_unwrap_sz 0 +#define GSSX_ARG_wrap_size_limit_sz 0 +#define GSSX_RES_wrap_size_limit_sz 0 + + + +#endif /* _LINUX_GSS_RPC_XDR_H */ diff --git a/net/sunrpc/auth_gss/gss_spkm3_mech.c b/net/sunrpc/auth_gss/gss_spkm3_mech.c deleted file mode 100644 index 035e1dd6af1..00000000000 --- a/net/sunrpc/auth_gss/gss_spkm3_mech.c +++ /dev/null @@ -1,243 +0,0 @@ -/* - * linux/net/sunrpc/gss_spkm3_mech.c - * - * Copyright (c) 2003 The Regents of the University of Michigan. - * All rights reserved. - * - * Andy Adamson <andros@umich.edu> - * J. Bruce Fields <bfields@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - */ - -#include <linux/err.h> -#include <linux/module.h> -#include <linux/init.h> -#include <linux/types.h> -#include <linux/slab.h> -#include <linux/sunrpc/auth.h> -#include <linux/in.h> -#include <linux/sunrpc/svcauth_gss.h> -#include <linux/sunrpc/gss_spkm3.h> -#include <linux/sunrpc/xdr.h> -#include <linux/crypto.h> - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - -static const void * -simple_get_bytes(const void *p, const void *end, void *res, int len) -{ - const void *q = (const void *)((const char *)p + len); - if (unlikely(q > end || q < p)) - return ERR_PTR(-EFAULT); - memcpy(res, p, len); - return q; -} - -static const void * -simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) -{ - const void *q; - unsigned int len; - p = simple_get_bytes(p, end, &len, sizeof(len)); - if (IS_ERR(p)) - return p; - res->len = len; - if (len == 0) { - res->data = NULL; - return p; - } - q = (const void *)((const char *)p + len); - if (unlikely(q > end || q < p)) - return ERR_PTR(-EFAULT); - res->data = kmemdup(p, len, GFP_NOFS); - if (unlikely(res->data == NULL)) - return ERR_PTR(-ENOMEM); - return q; -} - -static int -gss_import_sec_context_spkm3(const void *p, size_t len, - struct gss_ctx *ctx_id) -{ - const void *end = (const void *)((const char *)p + len); - struct spkm3_ctx *ctx; - int version; - - if (!(ctx = kzalloc(sizeof(*ctx), GFP_NOFS))) - goto out_err; - - p = simple_get_bytes(p, end, &version, sizeof(version)); - if (IS_ERR(p)) - goto out_err_free_ctx; - if (version != 1) { - dprintk("RPC: unknown spkm3 token format: " - "obsolete nfs-utils?\n"); - goto out_err_free_ctx; - } - - p = simple_get_netobj(p, end, &ctx->ctx_id); - if (IS_ERR(p)) - goto out_err_free_ctx; - - p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); - if (IS_ERR(p)) - goto out_err_free_ctx_id; - - p = simple_get_netobj(p, end, &ctx->mech_used); - if (IS_ERR(p)) - goto out_err_free_ctx_id; - - p = simple_get_bytes(p, end, &ctx->ret_flags, sizeof(ctx->ret_flags)); - if (IS_ERR(p)) - goto out_err_free_mech; - - p = simple_get_netobj(p, end, &ctx->conf_alg); - if (IS_ERR(p)) - goto out_err_free_mech; - - p = simple_get_netobj(p, end, &ctx->derived_conf_key); - if (IS_ERR(p)) - goto out_err_free_conf_alg; - - p = simple_get_netobj(p, end, &ctx->intg_alg); - if (IS_ERR(p)) - goto out_err_free_conf_key; - - p = simple_get_netobj(p, end, &ctx->derived_integ_key); - if (IS_ERR(p)) - goto out_err_free_intg_alg; - - if (p != end) - goto out_err_free_intg_key; - - ctx_id->internal_ctx_id = ctx; - - dprintk("RPC: Successfully imported new spkm context.\n"); - return 0; - -out_err_free_intg_key: - kfree(ctx->derived_integ_key.data); -out_err_free_intg_alg: - kfree(ctx->intg_alg.data); -out_err_free_conf_key: - kfree(ctx->derived_conf_key.data); -out_err_free_conf_alg: - kfree(ctx->conf_alg.data); -out_err_free_mech: - kfree(ctx->mech_used.data); -out_err_free_ctx_id: - kfree(ctx->ctx_id.data); -out_err_free_ctx: - kfree(ctx); -out_err: - return PTR_ERR(p); -} - -static void -gss_delete_sec_context_spkm3(void *internal_ctx) -{ - struct spkm3_ctx *sctx = internal_ctx; - - kfree(sctx->derived_integ_key.data); - kfree(sctx->intg_alg.data); - kfree(sctx->derived_conf_key.data); - kfree(sctx->conf_alg.data); - kfree(sctx->mech_used.data); - kfree(sctx->ctx_id.data); - kfree(sctx); -} - -static u32 -gss_verify_mic_spkm3(struct gss_ctx *ctx, - struct xdr_buf *signbuf, - struct xdr_netobj *checksum) -{ - u32 maj_stat = 0; - struct spkm3_ctx *sctx = ctx->internal_ctx_id; - - 
maj_stat = spkm3_read_token(sctx, checksum, signbuf, SPKM_MIC_TOK); - - dprintk("RPC: gss_verify_mic_spkm3 returning %d\n", maj_stat); - return maj_stat; -} - -static u32 -gss_get_mic_spkm3(struct gss_ctx *ctx, - struct xdr_buf *message_buffer, - struct xdr_netobj *message_token) -{ - u32 err = 0; - struct spkm3_ctx *sctx = ctx->internal_ctx_id; - - err = spkm3_make_token(sctx, message_buffer, - message_token, SPKM_MIC_TOK); - dprintk("RPC: gss_get_mic_spkm3 returning %d\n", err); - return err; -} - -static const struct gss_api_ops gss_spkm3_ops = { - .gss_import_sec_context = gss_import_sec_context_spkm3, - .gss_get_mic = gss_get_mic_spkm3, - .gss_verify_mic = gss_verify_mic_spkm3, - .gss_delete_sec_context = gss_delete_sec_context_spkm3, -}; - -static struct pf_desc gss_spkm3_pfs[] = { - {RPC_AUTH_GSS_SPKM, RPC_GSS_SVC_NONE, "spkm3"}, - {RPC_AUTH_GSS_SPKMI, RPC_GSS_SVC_INTEGRITY, "spkm3i"}, -}; - -static struct gss_api_mech gss_spkm3_mech = { - .gm_name = "spkm3", - .gm_owner = THIS_MODULE, - .gm_oid = {7, "\053\006\001\005\005\001\003"}, - .gm_ops = &gss_spkm3_ops, - .gm_pf_num = ARRAY_SIZE(gss_spkm3_pfs), - .gm_pfs = gss_spkm3_pfs, -}; - -static int __init init_spkm3_module(void) -{ - int status; - - status = gss_mech_register(&gss_spkm3_mech); - if (status) - printk("Failed to register spkm3 gss mechanism!\n"); - return status; -} - -static void __exit cleanup_spkm3_module(void) -{ - gss_mech_unregister(&gss_spkm3_mech); -} - -MODULE_LICENSE("GPL"); -module_init(init_spkm3_module); -module_exit(cleanup_spkm3_module); diff --git a/net/sunrpc/auth_gss/gss_spkm3_seal.c b/net/sunrpc/auth_gss/gss_spkm3_seal.c deleted file mode 100644 index c832712f8d5..00000000000 --- a/net/sunrpc/auth_gss/gss_spkm3_seal.c +++ /dev/null @@ -1,187 +0,0 @@ -/* - * linux/net/sunrpc/gss_spkm3_seal.c - * - * Copyright (c) 2003 The Regents of the University of Michigan. - * All rights reserved. - * - * Andy Adamson <andros@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - */ - -#include <linux/types.h> -#include <linux/slab.h> -#include <linux/jiffies.h> -#include <linux/sunrpc/gss_spkm3.h> -#include <linux/random.h> -#include <linux/crypto.h> -#include <linux/pagemap.h> -#include <linux/scatterlist.h> -#include <linux/sunrpc/xdr.h> - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - -const struct xdr_netobj hmac_md5_oid = { 8, "\x2B\x06\x01\x05\x05\x08\x01\x01"}; -const struct xdr_netobj cast5_cbc_oid = {9, "\x2A\x86\x48\x86\xF6\x7D\x07\x42\x0A"}; - -/* - * spkm3_make_token() - * - * Only SPKM_MIC_TOK with md5 intg-alg is supported - */ - -u32 -spkm3_make_token(struct spkm3_ctx *ctx, - struct xdr_buf * text, struct xdr_netobj * token, - int toktype) -{ - s32 checksum_type; - char tokhdrbuf[25]; - char cksumdata[16]; - struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; - struct xdr_netobj mic_hdr = {.len = 0, .data = tokhdrbuf}; - int tokenlen = 0; - unsigned char *ptr; - s32 now; - int ctxelen = 0, ctxzbit = 0; - int md5elen = 0, md5zbit = 0; - - now = jiffies; - - if (ctx->ctx_id.len != 16) { - dprintk("RPC: spkm3_make_token BAD ctx_id.len %d\n", - ctx->ctx_id.len); - goto out_err; - } - - if (!g_OID_equal(&ctx->intg_alg, &hmac_md5_oid)) { - dprintk("RPC: gss_spkm3_seal: unsupported I-ALG " - "algorithm. only support hmac-md5 I-ALG.\n"); - goto out_err; - } else - checksum_type = CKSUMTYPE_HMAC_MD5; - - if (!g_OID_equal(&ctx->conf_alg, &cast5_cbc_oid)) { - dprintk("RPC: gss_spkm3_seal: unsupported C-ALG " - "algorithm\n"); - goto out_err; - } - - if (toktype == SPKM_MIC_TOK) { - /* Calculate checksum over the mic-header */ - asn1_bitstring_len(&ctx->ctx_id, &ctxelen, &ctxzbit); - spkm3_mic_header(&mic_hdr.data, &mic_hdr.len, ctx->ctx_id.data, - ctxelen, ctxzbit); - if (make_spkm3_checksum(checksum_type, &ctx->derived_integ_key, - (char *)mic_hdr.data, mic_hdr.len, - text, 0, &md5cksum)) - goto out_err; - - asn1_bitstring_len(&md5cksum, &md5elen, &md5zbit); - tokenlen = 10 + ctxelen + 1 + md5elen + 1; - - /* Create token header using generic routines */ - token->len = g_token_size(&ctx->mech_used, tokenlen + 2); - - ptr = token->data; - g_make_token_header(&ctx->mech_used, tokenlen + 2, &ptr); - - spkm3_make_mic_token(&ptr, tokenlen, &mic_hdr, &md5cksum, md5elen, md5zbit); - } else if (toktype == SPKM_WRAP_TOK) { /* Not Supported */ - dprintk("RPC: gss_spkm3_seal: SPKM_WRAP_TOK " - "not supported\n"); - goto out_err; - } - - /* XXX need to implement sequence numbers, and ctx->expired */ - - return GSS_S_COMPLETE; -out_err: - token->data = NULL; - token->len = 0; - return GSS_S_FAILURE; -} - -static int -spkm3_checksummer(struct scatterlist *sg, void *data) -{ - struct hash_desc *desc = data; - - return crypto_hash_update(desc, sg, sg->length); -} - -/* checksum the plaintext data and hdrlen bytes of the token header */ -s32 -make_spkm3_checksum(s32 cksumtype, struct xdr_netobj *key, char *header, - unsigned int hdrlen, struct xdr_buf *body, - unsigned int body_offset, struct xdr_netobj *cksum) -{ - char *cksumname; - struct hash_desc desc; /* XXX add to ctx? 
*/ - struct scatterlist sg[1]; - int err; - - switch (cksumtype) { - case CKSUMTYPE_HMAC_MD5: - cksumname = "hmac(md5)"; - break; - default: - dprintk("RPC: spkm3_make_checksum:" - " unsupported checksum %d", cksumtype); - return GSS_S_FAILURE; - } - - if (key->data == NULL || key->len <= 0) return GSS_S_FAILURE; - - desc.tfm = crypto_alloc_hash(cksumname, 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(desc.tfm)) - return GSS_S_FAILURE; - cksum->len = crypto_hash_digestsize(desc.tfm); - desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; - - err = crypto_hash_setkey(desc.tfm, key->data, key->len); - if (err) - goto out; - - err = crypto_hash_init(&desc); - if (err) - goto out; - - sg_init_one(sg, header, hdrlen); - crypto_hash_update(&desc, sg, sg->length); - - xdr_process_buf(body, body_offset, body->len - body_offset, - spkm3_checksummer, &desc); - crypto_hash_final(&desc, cksum->data); - -out: - crypto_free_hash(desc.tfm); - - return err ? GSS_S_FAILURE : 0; -} diff --git a/net/sunrpc/auth_gss/gss_spkm3_token.c b/net/sunrpc/auth_gss/gss_spkm3_token.c deleted file mode 100644 index 3308157436d..00000000000 --- a/net/sunrpc/auth_gss/gss_spkm3_token.c +++ /dev/null @@ -1,267 +0,0 @@ -/* - * linux/net/sunrpc/gss_spkm3_token.c - * - * Copyright (c) 2003 The Regents of the University of Michigan. - * All rights reserved. - * - * Andy Adamson <andros@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - */ - -#include <linux/types.h> -#include <linux/slab.h> -#include <linux/jiffies.h> -#include <linux/sunrpc/gss_spkm3.h> -#include <linux/random.h> -#include <linux/crypto.h> - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - -/* - * asn1_bitstring_len() - * - * calculate the asn1 bitstring length of the xdr_netobject - */ -void -asn1_bitstring_len(struct xdr_netobj *in, int *enclen, int *zerobits) -{ - int i, zbit = 0,elen = in->len; - char *ptr; - - ptr = &in->data[in->len -1]; - - /* count trailing 0's */ - for(i = in->len; i > 0; i--) { - if (*ptr == 0) { - ptr--; - elen--; - } else - break; - } - - /* count number of 0 bits in final octet */ - ptr = &in->data[elen - 1]; - for(i = 0; i < 8; i++) { - short mask = 0x01; - - if (!((mask << i) & *ptr)) - zbit++; - else - break; - } - *enclen = elen; - *zerobits = zbit; -} - -/* - * decode_asn1_bitstring() - * - * decode a bitstring into a buffer of the expected length. - * enclen = bit string length - * explen = expected length (define in rfc) - */ -int -decode_asn1_bitstring(struct xdr_netobj *out, char *in, int enclen, int explen) -{ - if (!(out->data = kzalloc(explen,GFP_NOFS))) - return 0; - out->len = explen; - memcpy(out->data, in, enclen); - return 1; -} - -/* - * SPKMInnerContextToken choice SPKM_MIC asn1 token layout - * - * contextid is always 16 bytes plain data. max asn1 bitstring len = 17. - * - * tokenlen = pos[0] to end of token (max pos[45] with MD5 cksum) - * - * pos value - * ---------- - * [0] a4 SPKM-MIC tag - * [1] ?? innertoken length (max 44) - * - * - * tok_hdr piece of checksum data starts here - * - * the maximum mic-header len = 9 + 17 = 26 - * mic-header - * ---------- - * [2] 30 SEQUENCE tag - * [3] ?? mic-header length: (max 23) = TokenID + ContextID - * - * TokenID - all fields constant and can be hardcoded - * ------- - * [4] 02 Type 2 - * [5] 02 Length 2 - * [6][7] 01 01 TokenID (SPKM_MIC_TOK) - * - * ContextID - encoded length not constant, calculated - * --------- - * [8] 03 Type 3 - * [9] ?? encoded length - * [10] ?? ctxzbit - * [11] contextid - * - * mic_header piece of checksum data ends here. - * - * int-cksum - encoded length not constant, calculated - * --------- - * [??] 03 Type 3 - * [??] ?? encoded length - * [??] ?? md5zbit - * [??] 
int-cksum (NID_md5 = 16) - * - * maximum SPKM-MIC innercontext token length = - * 10 + encoded contextid_size(17 max) + 2 + encoded - * cksum_size (17 maxfor NID_md5) = 46 - */ - -/* - * spkm3_mic_header() - * - * Prepare the SPKM_MIC_TOK mic-header for check-sum calculation - * elen: 16 byte context id asn1 bitstring encoded length - */ -void -spkm3_mic_header(unsigned char **hdrbuf, unsigned int *hdrlen, unsigned char *ctxdata, int elen, int zbit) -{ - char *hptr = *hdrbuf; - char *top = *hdrbuf; - - *(u8 *)hptr++ = 0x30; - *(u8 *)hptr++ = elen + 7; /* on the wire header length */ - - /* tokenid */ - *(u8 *)hptr++ = 0x02; - *(u8 *)hptr++ = 0x02; - *(u8 *)hptr++ = 0x01; - *(u8 *)hptr++ = 0x01; - - /* coniextid */ - *(u8 *)hptr++ = 0x03; - *(u8 *)hptr++ = elen + 1; /* add 1 to include zbit */ - *(u8 *)hptr++ = zbit; - memcpy(hptr, ctxdata, elen); - hptr += elen; - *hdrlen = hptr - top; -} - -/* - * spkm3_mic_innercontext_token() - * - * *tokp points to the beginning of the SPKM_MIC token described - * in rfc 2025, section 3.2.1: - * - * toklen is the inner token length - */ -void -spkm3_make_mic_token(unsigned char **tokp, int toklen, struct xdr_netobj *mic_hdr, struct xdr_netobj *md5cksum, int md5elen, int md5zbit) -{ - unsigned char *ict = *tokp; - - *(u8 *)ict++ = 0xa4; - *(u8 *)ict++ = toklen; - memcpy(ict, mic_hdr->data, mic_hdr->len); - ict += mic_hdr->len; - - *(u8 *)ict++ = 0x03; - *(u8 *)ict++ = md5elen + 1; /* add 1 to include zbit */ - *(u8 *)ict++ = md5zbit; - memcpy(ict, md5cksum->data, md5elen); -} - -u32 -spkm3_verify_mic_token(unsigned char **tokp, int *mic_hdrlen, unsigned char **cksum) -{ - struct xdr_netobj spkm3_ctx_id = {.len =0, .data = NULL}; - unsigned char *ptr = *tokp; - int ctxelen; - u32 ret = GSS_S_DEFECTIVE_TOKEN; - - /* spkm3 innercontext token preamble */ - if ((ptr[0] != 0xa4) || (ptr[2] != 0x30)) { - dprintk("RPC: BAD SPKM ictoken preamble\n"); - goto out; - } - - *mic_hdrlen = ptr[3]; - - /* token type */ - if ((ptr[4] != 0x02) || (ptr[5] != 0x02)) { - dprintk("RPC: BAD asn1 SPKM3 token type\n"); - goto out; - } - - /* only support SPKM_MIC_TOK */ - if((ptr[6] != 0x01) || (ptr[7] != 0x01)) { - dprintk("RPC: ERROR unsupported SPKM3 token \n"); - goto out; - } - - /* contextid */ - if (ptr[8] != 0x03) { - dprintk("RPC: BAD SPKM3 asn1 context-id type\n"); - goto out; - } - - ctxelen = ptr[9]; - if (ctxelen > 17) { /* length includes asn1 zbit octet */ - dprintk("RPC: BAD SPKM3 contextid len %d\n", ctxelen); - goto out; - } - - /* ignore ptr[10] */ - - if(!decode_asn1_bitstring(&spkm3_ctx_id, &ptr[11], ctxelen - 1, 16)) - goto out; - - /* - * in the current implementation: the optional int-alg is not present - * so the default int-alg (md5) is used the optional snd-seq field is - * also not present - */ - - if (*mic_hdrlen != 6 + ctxelen) { - dprintk("RPC: BAD SPKM_ MIC_TOK header len %d: we only " - "support default int-alg (should be absent) " - "and do not support snd-seq\n", *mic_hdrlen); - goto out; - } - /* checksum */ - *cksum = (&ptr[10] + ctxelen); /* ctxelen includes ptr[10] */ - - ret = GSS_S_COMPLETE; -out: - kfree(spkm3_ctx_id.data); - return ret; -} - diff --git a/net/sunrpc/auth_gss/gss_spkm3_unseal.c b/net/sunrpc/auth_gss/gss_spkm3_unseal.c deleted file mode 100644 index cc21ee860bb..00000000000 --- a/net/sunrpc/auth_gss/gss_spkm3_unseal.c +++ /dev/null @@ -1,127 +0,0 @@ -/* - * linux/net/sunrpc/gss_spkm3_unseal.c - * - * Copyright (c) 2003 The Regents of the University of Michigan. - * All rights reserved. 
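The asn1_bitstring_len() helper deleted above implements the DER rule the layout table relies on: trailing all-zero octets are dropped from the context id, and the bit string's prefix octet (ctxzbit/md5zbit) records how many low-order bits of the last remaining octet are unused. A standalone demonstration of that arithmetic, with an empty-input guard the original did not have:

#include <stdio.h>

static void bitstring_len(const unsigned char *data, int len,
			  int *enclen, int *zerobits)
{
	int elen = len, zbit = 0, i;

	while (elen > 0 && data[elen - 1] == 0)	/* drop trailing zero octets */
		elen--;
	if (elen == 0) {			/* guard: all-zero input */
		*enclen = *zerobits = 0;
		return;
	}
	for (i = 0; i < 8; i++) {	/* unused low-order bits, last octet */
		if ((0x01 << i) & data[elen - 1])
			break;
		zbit++;
	}
	*enclen = elen;
	*zerobits = zbit;
}

int main(void)
{
	const unsigned char ctx_id[3] = { 0x01, 0x80, 0x00 };
	int elen, zbit;

	bitstring_len(ctx_id, sizeof(ctx_id), &elen, &zbit);
	/* prints enclen=2 zerobits=7: the 0x00 octet is trimmed and
	   0x80 has seven unused low-order bits */
	printf("enclen=%d zerobits=%d\n", elen, zbit);
	return 0;
}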
- * - * Andy Adamson <andros@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - */ - -#include <linux/types.h> -#include <linux/slab.h> -#include <linux/jiffies.h> -#include <linux/sunrpc/gss_spkm3.h> -#include <linux/crypto.h> - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - -/* - * spkm3_read_token() - * - * only SPKM_MIC_TOK with md5 intg-alg is supported - */ -u32 -spkm3_read_token(struct spkm3_ctx *ctx, - struct xdr_netobj *read_token, /* checksum */ - struct xdr_buf *message_buffer, /* signbuf */ - int toktype) -{ - s32 checksum_type; - s32 code; - struct xdr_netobj wire_cksum = {.len =0, .data = NULL}; - char cksumdata[16]; - struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; - unsigned char *ptr = (unsigned char *)read_token->data; - unsigned char *cksum; - int bodysize, md5elen; - int mic_hdrlen; - u32 ret = GSS_S_DEFECTIVE_TOKEN; - - if (g_verify_token_header((struct xdr_netobj *) &ctx->mech_used, - &bodysize, &ptr, read_token->len)) - goto out; - - /* decode the token */ - - if (toktype != SPKM_MIC_TOK) { - dprintk("RPC: BAD SPKM3 token type: %d\n", toktype); - goto out; - } - - if ((ret = spkm3_verify_mic_token(&ptr, &mic_hdrlen, &cksum))) - goto out; - - if (*cksum++ != 0x03) { - dprintk("RPC: spkm3_read_token BAD checksum type\n"); - goto out; - } - md5elen = *cksum++; - cksum++; /* move past the zbit */ - - if (!decode_asn1_bitstring(&wire_cksum, cksum, md5elen - 1, 16)) - goto out; - - /* HARD CODED FOR MD5 */ - - /* compute the checksum of the message. 
- * ptr + 2 = start of header piece of checksum - * mic_hdrlen + 2 = length of header piece of checksum - */ - ret = GSS_S_DEFECTIVE_TOKEN; - if (!g_OID_equal(&ctx->intg_alg, &hmac_md5_oid)) { - dprintk("RPC: gss_spkm3_seal: unsupported I-ALG " - "algorithm\n"); - goto out; - } - - checksum_type = CKSUMTYPE_HMAC_MD5; - - code = make_spkm3_checksum(checksum_type, - &ctx->derived_integ_key, ptr + 2, mic_hdrlen + 2, - message_buffer, 0, &md5cksum); - - if (code) - goto out; - - ret = GSS_S_BAD_SIG; - code = memcmp(md5cksum.data, wire_cksum.data, wire_cksum.len); - if (code) { - dprintk("RPC: bad MIC checksum\n"); - goto out; - } - - - /* XXX: need to add expiration and sequencing */ - ret = GSS_S_COMPLETE; -out: - kfree(wire_cksum.data); - return ret; -} diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c index e34bc531fcb..4ce5eccec1f 100644 --- a/net/sunrpc/auth_gss/svcauth_gss.c +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -37,15 +37,19 @@ * */ +#include <linux/slab.h> #include <linux/types.h> #include <linux/module.h> #include <linux/pagemap.h> +#include <linux/user_namespace.h> #include <linux/sunrpc/auth_gss.h> #include <linux/sunrpc/gss_err.h> #include <linux/sunrpc/svcauth.h> #include <linux/sunrpc/svcauth_gss.h> #include <linux/sunrpc/cache.h> +#include "gss_rpc_upcall.h" + #ifdef RPC_DEBUG # define RPCDBG_FACILITY RPCDBG_AUTH @@ -66,7 +70,6 @@ static int netobj_equal(struct xdr_netobj *a, struct xdr_netobj *b) #define RSI_HASHBITS 6 #define RSI_HASHMAX (1<<RSI_HASHBITS) -#define RSI_HASHMASK (RSI_HASHMAX-1) struct rsi { struct cache_head h; @@ -75,10 +78,8 @@ struct rsi { int major_status, minor_status; }; -static struct cache_head *rsi_table[RSI_HASHMAX]; -static struct cache_detail rsi_cache; -static struct rsi *rsi_update(struct rsi *new, struct rsi *old); -static struct rsi *rsi_lookup(struct rsi *item); +static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old); +static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item); static void rsi_free(struct rsi *rsii) { @@ -181,12 +182,6 @@ static void rsi_request(struct cache_detail *cd, (*bpp)[-1] = '\n'; } -static int rsi_upcall(struct cache_detail *cd, struct cache_head *h) -{ - return sunrpc_cache_pipe_upcall(cd, h, rsi_request); -} - - static int rsi_parse(struct cache_detail *cd, char *mesg, int mlen) { @@ -216,7 +211,7 @@ static int rsi_parse(struct cache_detail *cd, if (dup_to_netobj(&rsii.in_token, buf, len)) goto out; - rsip = rsi_lookup(&rsii); + rsip = rsi_lookup(cd, &rsii); if (!rsip) goto out; @@ -258,24 +253,23 @@ static int rsi_parse(struct cache_detail *cd, if (dup_to_netobj(&rsii.out_token, buf, len)) goto out; rsii.h.expiry_time = expiry; - rsip = rsi_update(&rsii, rsip); + rsip = rsi_update(cd, &rsii, rsip); status = 0; out: rsi_free(&rsii); if (rsip) - cache_put(&rsip->h, &rsi_cache); + cache_put(&rsip->h, cd); else status = -ENOMEM; return status; } -static struct cache_detail rsi_cache = { +static struct cache_detail rsi_cache_template = { .owner = THIS_MODULE, .hash_size = RSI_HASHMAX, - .hash_table = rsi_table, .name = "auth.rpcsec.init", .cache_put = rsi_put, - .cache_upcall = rsi_upcall, + .cache_request = rsi_request, .cache_parse = rsi_parse, .match = rsi_match, .init = rsi_init, @@ -283,24 +277,24 @@ static struct cache_detail rsi_cache = { .alloc = rsi_alloc, }; -static struct rsi *rsi_lookup(struct rsi *item) +static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item) { struct cache_head *ch; int hash = 
rsi_hash(item); - ch = sunrpc_cache_lookup(&rsi_cache, &item->h, hash); + ch = sunrpc_cache_lookup(cd, &item->h, hash); if (ch) return container_of(ch, struct rsi, h); else return NULL; } -static struct rsi *rsi_update(struct rsi *new, struct rsi *old) +static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old) { struct cache_head *ch; int hash = rsi_hash(new); - ch = sunrpc_cache_update(&rsi_cache, &new->h, + ch = sunrpc_cache_update(cd, &new->h, &old->h, hash); if (ch) return container_of(ch, struct rsi, h); @@ -318,7 +312,6 @@ static struct rsi *rsi_update(struct rsi *new, struct rsi *old) #define RSC_HASHBITS 10 #define RSC_HASHMAX (1<<RSC_HASHBITS) -#define RSC_HASHMASK (RSC_HASHMAX-1) #define GSS_SEQ_WIN 128 @@ -337,22 +330,17 @@ struct rsc { struct svc_cred cred; struct gss_svc_seq_data seqdata; struct gss_ctx *mechctx; - char *client_name; }; -static struct cache_head *rsc_table[RSC_HASHMAX]; -static struct cache_detail rsc_cache; -static struct rsc *rsc_update(struct rsc *new, struct rsc *old); -static struct rsc *rsc_lookup(struct rsc *item); +static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old); +static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item); static void rsc_free(struct rsc *rsci) { kfree(rsci->handle.data); if (rsci->mechctx) gss_delete_sec_context(&rsci->mechctx); - if (rsci->cred.cr_group_info) - put_group_info(rsci->cred.cr_group_info); - kfree(rsci->client_name); + free_svc_cred(&rsci->cred); } static void rsc_put(struct kref *ref) @@ -389,8 +377,7 @@ rsc_init(struct cache_head *cnew, struct cache_head *ctmp) new->handle.data = tmp->handle.data; tmp->handle.data = NULL; new->mechctx = NULL; - new->cred.cr_group_info = NULL; - new->client_name = NULL; + init_svc_cred(&new->cred); } static void @@ -404,9 +391,7 @@ update_rsc(struct cache_head *cnew, struct cache_head *ctmp) memset(&new->seqdata, 0, sizeof(new->seqdata)); spin_lock_init(&new->seqdata.sd_lock); new->cred = tmp->cred; - tmp->cred.cr_group_info = NULL; - new->client_name = tmp->client_name; - tmp->client_name = NULL; + init_svc_cred(&tmp->cred); } static struct cache_head * @@ -424,6 +409,7 @@ static int rsc_parse(struct cache_detail *cd, { /* contexthandle expiry [ uid gid N <n gids> mechname ...mechdata... ] */ char *buf = mesg; + int id; int len, rv; struct rsc rsci, *rscp = NULL; time_t expiry; @@ -445,12 +431,12 @@ static int rsc_parse(struct cache_detail *cd, if (expiry == 0) goto out; - rscp = rsc_lookup(&rsci); + rscp = rsc_lookup(cd, &rsci); if (!rscp) goto out; /* uid, or NEGATIVE */ - rv = get_int(&mesg, &rsci.cred.cr_uid); + rv = get_int(&mesg, &id); if (rv == -EINVAL) goto out; if (rv == -ENOENT) @@ -458,9 +444,21 @@ static int rsc_parse(struct cache_detail *cd, else { int N, i; + /* + * NOTE: we skip uid_valid()/gid_valid() checks here: + * instead, * -1 id's are later mapped to the + * (export-specific) anonymous id by nfsd_setuser. + * + * (But supplementary gid's get no such special + * treatment so are checked for validity here.) 
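This hunk is part of the kuid_t/kgid_t conversion: ids parsed off the wire are mapped through the initial user namespace, and, per the note above, only supplementary gids are rejected when the mapping fails. The per-gid conversion, restated as a hypothetical helper (get_one_kgid is not in the patch; get_int() is the existing sunrpc cache-parsing primitive):

#include <linux/sunrpc/cache.h>
#include <linux/uidgid.h>
#include <linux/user_namespace.h>

/* Parse one raw id and map it into the initial user namespace. */
static int get_one_kgid(char **mesg, kgid_t *kgid)
{
	int id;

	if (get_int(mesg, &id))
		return -EINVAL;
	*kgid = make_kgid(&init_user_ns, id);
	if (!gid_valid(*kgid))
		return -EINVAL;	/* id has no mapping in this namespace */
	return 0;
}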
+ */ + /* uid */ + rsci.cred.cr_uid = make_kuid(&init_user_ns, id); + /* gid */ - if (get_int(&mesg, &rsci.cred.cr_gid)) + if (get_int(&mesg, &id)) goto out; + rsci.cred.cr_gid = make_kgid(&init_user_ns, id); /* number of additional gid's */ if (get_int(&mesg, &N)) @@ -473,17 +471,20 @@ static int rsc_parse(struct cache_detail *cd, /* gid's */ status = -EINVAL; for (i=0; i<N; i++) { - gid_t gid; - if (get_int(&mesg, &gid)) + kgid_t kgid; + if (get_int(&mesg, &id)) goto out; - GROUP_AT(rsci.cred.cr_group_info, i) = gid; + kgid = make_kgid(&init_user_ns, id); + if (!gid_valid(kgid)) + goto out; + GROUP_AT(rsci.cred.cr_group_info, i) = kgid; } /* mech name */ len = qword_get(&mesg, buf, mlen); if (len < 0) goto out; - gm = gss_mech_get_by_name(buf); + gm = rsci.cred.cr_gss_mech = gss_mech_get_by_name(buf); status = -EOPNOTSUPP; if (!gm) goto out; @@ -493,36 +494,37 @@ static int rsc_parse(struct cache_detail *cd, len = qword_get(&mesg, buf, mlen); if (len < 0) goto out; - status = gss_import_sec_context(buf, len, gm, &rsci.mechctx); + status = gss_import_sec_context(buf, len, gm, &rsci.mechctx, + NULL, GFP_KERNEL); if (status) goto out; /* get client name */ len = qword_get(&mesg, buf, mlen); if (len > 0) { - rsci.client_name = kstrdup(buf, GFP_KERNEL); - if (!rsci.client_name) + rsci.cred.cr_principal = kstrdup(buf, GFP_KERNEL); + if (!rsci.cred.cr_principal) { + status = -ENOMEM; goto out; + } } } rsci.h.expiry_time = expiry; - rscp = rsc_update(&rsci, rscp); + rscp = rsc_update(cd, &rsci, rscp); status = 0; out: - gss_mech_put(gm); rsc_free(&rsci); if (rscp) - cache_put(&rscp->h, &rsc_cache); + cache_put(&rscp->h, cd); else status = -ENOMEM; return status; } -static struct cache_detail rsc_cache = { +static struct cache_detail rsc_cache_template = { .owner = THIS_MODULE, .hash_size = RSC_HASHMAX, - .hash_table = rsc_table, .name = "auth.rpcsec.context", .cache_put = rsc_put, .cache_parse = rsc_parse, @@ -532,24 +534,24 @@ static struct cache_detail rsc_cache = { .alloc = rsc_alloc, }; -static struct rsc *rsc_lookup(struct rsc *item) +static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item) { struct cache_head *ch; int hash = rsc_hash(item); - ch = sunrpc_cache_lookup(&rsc_cache, &item->h, hash); + ch = sunrpc_cache_lookup(cd, &item->h, hash); if (ch) return container_of(ch, struct rsc, h); else return NULL; } -static struct rsc *rsc_update(struct rsc *new, struct rsc *old) +static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old) { struct cache_head *ch; int hash = rsc_hash(new); - ch = sunrpc_cache_update(&rsc_cache, &new->h, + ch = sunrpc_cache_update(cd, &new->h, &old->h, hash); if (ch) return container_of(ch, struct rsc, h); @@ -559,7 +561,7 @@ static struct rsc *rsc_update(struct rsc *new, struct rsc *old) static struct rsc * -gss_svc_searchbyctx(struct xdr_netobj *handle) +gss_svc_searchbyctx(struct cache_detail *cd, struct xdr_netobj *handle) { struct rsc rsci; struct rsc *found; @@ -567,11 +569,11 @@ gss_svc_searchbyctx(struct xdr_netobj *handle) memset(&rsci, 0, sizeof(rsci)); if (dup_to_netobj(&rsci.handle, handle->data, handle->len)) return NULL; - found = rsc_lookup(&rsci); + found = rsc_lookup(cd, &rsci); rsc_free(&rsci); if (!found) return NULL; - if (cache_check(&rsc_cache, &found->h, NULL)) + if (cache_check(cd, &found->h, NULL)) return NULL; return found; } @@ -820,13 +822,17 @@ read_u32_from_xdr_buf(struct xdr_buf *buf, int base, u32 *obj) * The server uses base of head iovec as read pointer, while the * client uses separate 
pointer. */ static int -unwrap_integ_data(struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) { int stat = -EINVAL; u32 integ_len, maj_stat; struct xdr_netobj mic; struct xdr_buf integ_buf; + /* Did we already verify the signature on the original pass through? */ + if (rqstp->rq_deferred) + return 0; + integ_len = svc_getnl(&buf->head[0]); if (integ_len & 3) return stat; @@ -849,6 +855,8 @@ unwrap_integ_data(struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) goto out; if (svc_getnl(&buf->head[0]) != seq) goto out; + /* trim off the mic at the end before returning */ + xdr_buf_trim(buf, mic.len + 4); stat = 0; out: kfree(mic.data); @@ -932,16 +940,6 @@ struct gss_svc_data { struct rsc *rsci; }; -char *svc_gss_principal(struct svc_rqst *rqstp) -{ - struct gss_svc_data *gd = (struct gss_svc_data *)rqstp->rq_auth_data; - - if (gd && gd->rsci) - return gd->rsci->client_name; - return NULL; -} -EXPORT_SYMBOL_GPL(svc_gss_principal); - static int svcauth_gss_set_client(struct svc_rqst *rqstp) { @@ -963,45 +961,35 @@ svcauth_gss_set_client(struct svc_rqst *rqstp) if (rqstp->rq_gssclient == NULL) return SVC_DENIED; stat = svcauth_unix_set_client(rqstp); - if (stat == SVC_DROP) + if (stat == SVC_DROP || stat == SVC_CLOSE) return stat; return SVC_OK; } static inline int -gss_write_init_verf(struct svc_rqst *rqstp, struct rsi *rsip) +gss_write_init_verf(struct cache_detail *cd, struct svc_rqst *rqstp, + struct xdr_netobj *out_handle, int *major_status) { struct rsc *rsci; int rc; - if (rsip->major_status != GSS_S_COMPLETE) + if (*major_status != GSS_S_COMPLETE) return gss_write_null_verf(rqstp); - rsci = gss_svc_searchbyctx(&rsip->out_handle); + rsci = gss_svc_searchbyctx(cd, out_handle); if (rsci == NULL) { - rsip->major_status = GSS_S_NO_CONTEXT; + *major_status = GSS_S_NO_CONTEXT; return gss_write_null_verf(rqstp); } rc = gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN); - cache_put(&rsci->h, &rsc_cache); + cache_put(&rsci->h, cd); return rc; } -/* - * Having read the cred already and found we're in the context - * initiation case, read the verifier and initiate (or check the results - * of) upcalls to userspace for help with context initiation. If - * the upcall results are available, write the verifier and result. - * Otherwise, drop the request pending an answer to the upcall. 
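For orientation, the body unwrap_integ_data() above is picking apart is the RPCSEC_GSS integrity format of RFC 2203; sketched as a comment, since it explains both the subsegment bounds and the new xdr_buf_trim() call:

/*
 * RPC_GSS_SVC_INTEGRITY request body as consumed by unwrap_integ_data():
 *
 *	u32	integ_len;	length of the protected region below
 *	u32	seq_num;	must equal gc_seq from the credential
 *	opaque	args[];		the actual RPC procedure arguments
 *	opaque	mic<>;		checksum computed over (seq_num, args)
 *
 * Once gss_verify_mic() succeeds, the checksum and its 4-byte XDR
 * length word are dead weight at the end of rq_arg, hence the added
 * xdr_buf_trim(buf, mic.len + 4): a 16-byte checksum, for example,
 * trims 20 bytes off buf->len.
 */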
- */ -static int svcauth_gss_handle_init(struct svc_rqst *rqstp, - struct rpc_gss_wire_cred *gc, __be32 *authp) +static inline int +gss_read_common_verf(struct rpc_gss_wire_cred *gc, + struct kvec *argv, __be32 *authp, + struct xdr_netobj *in_handle) { - struct kvec *argv = &rqstp->rq_arg.head[0]; - struct kvec *resv = &rqstp->rq_res.head[0]; - struct xdr_netobj tmpobj; - struct rsi *rsip, rsikey; - int ret; - /* Read the verifier; should be NULL: */ *authp = rpc_autherr_badverf; if (argv->iov_len < 2 * 4) @@ -1010,60 +998,392 @@ static int svcauth_gss_handle_init(struct svc_rqst *rqstp, return SVC_DENIED; if (svc_getnl(argv) != 0) return SVC_DENIED; - /* Martial context handle and token for upcall: */ *authp = rpc_autherr_badcred; if (gc->gc_proc == RPC_GSS_PROC_INIT && gc->gc_ctx.len != 0) return SVC_DENIED; - memset(&rsikey, 0, sizeof(rsikey)); - if (dup_netobj(&rsikey.in_handle, &gc->gc_ctx)) - return SVC_DROP; + if (dup_netobj(in_handle, &gc->gc_ctx)) + return SVC_CLOSE; *authp = rpc_autherr_badverf; + + return 0; +} + +static inline int +gss_read_verf(struct rpc_gss_wire_cred *gc, + struct kvec *argv, __be32 *authp, + struct xdr_netobj *in_handle, + struct xdr_netobj *in_token) +{ + struct xdr_netobj tmpobj; + int res; + + res = gss_read_common_verf(gc, argv, authp, in_handle); + if (res) + return res; + if (svc_safe_getnetobj(argv, &tmpobj)) { - kfree(rsikey.in_handle.data); + kfree(in_handle->data); return SVC_DENIED; } - if (dup_netobj(&rsikey.in_token, &tmpobj)) { - kfree(rsikey.in_handle.data); - return SVC_DROP; + if (dup_netobj(in_token, &tmpobj)) { + kfree(in_handle->data); + return SVC_CLOSE; } + return 0; +} + +/* Ok this is really heavily depending on a set of semantics in + * how rqstp is set up by svc_recv and pages laid down by the + * server when reading a request. We are basically guaranteed that + * the token lays all down linearly across a set of pages, starting + * at iov_base in rq_arg.head[0] which happens to be the first of a + * set of pages stored in rq_pages[]. + * rq_arg.head[0].iov_base will provide us the page_base to pass + * to the upcall. + */ +static inline int +gss_read_proxy_verf(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, __be32 *authp, + struct xdr_netobj *in_handle, + struct gssp_in_token *in_token) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + u32 inlen; + int res; + + res = gss_read_common_verf(gc, argv, authp, in_handle); + if (res) + return res; + + inlen = svc_getnl(argv); + if (inlen > (argv->iov_len + rqstp->rq_arg.page_len)) + return SVC_DENIED; + + in_token->pages = rqstp->rq_pages; + in_token->page_base = (ulong)argv->iov_base & ~PAGE_MASK; + in_token->page_len = inlen; + + return 0; +} + +static inline int +gss_write_resv(struct kvec *resv, size_t size_limit, + struct xdr_netobj *out_handle, struct xdr_netobj *out_token, + int major_status, int minor_status) +{ + if (resv->iov_len + 4 > size_limit) + return -1; + svc_putnl(resv, RPC_SUCCESS); + if (svc_safe_putnetobj(resv, out_handle)) + return -1; + if (resv->iov_len + 3 * 4 > size_limit) + return -1; + svc_putnl(resv, major_status); + svc_putnl(resv, minor_status); + svc_putnl(resv, GSS_SEQ_WIN); + if (svc_safe_putnetobj(resv, out_token)) + return -1; + return 0; +} + +/* + * Having read the cred already and found we're in the context + * initiation case, read the verifier and initiate (or check the results + * of) upcalls to userspace for help with context initiation. If + * the upcall results are available, write the verifier and result. 
+ * Otherwise, drop the request pending an answer to the upcall. + */ +static int svcauth_gss_legacy_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct rsi *rsip, rsikey; + int ret; + struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, sunrpc_net_id); + + memset(&rsikey, 0, sizeof(rsikey)); + ret = gss_read_verf(gc, argv, authp, + &rsikey.in_handle, &rsikey.in_token); + if (ret) + return ret; + /* Perform upcall, or find upcall result: */ - rsip = rsi_lookup(&rsikey); + rsip = rsi_lookup(sn->rsi_cache, &rsikey); rsi_free(&rsikey); if (!rsip) - return SVC_DROP; - switch (cache_check(&rsi_cache, &rsip->h, &rqstp->rq_chandle)) { - case -EAGAIN: - case -ETIMEDOUT: - case -ENOENT: + return SVC_CLOSE; + if (cache_check(sn->rsi_cache, &rsip->h, &rqstp->rq_chandle) < 0) /* No upcall result: */ - return SVC_DROP; - case 0: - ret = SVC_DROP; - /* Got an answer to the upcall; use it: */ - if (gss_write_init_verf(rqstp, rsip)) - goto out; - if (resv->iov_len + 4 > PAGE_SIZE) - goto out; - svc_putnl(resv, RPC_SUCCESS); - if (svc_safe_putnetobj(resv, &rsip->out_handle)) + return SVC_CLOSE; + + ret = SVC_CLOSE; + /* Got an answer to the upcall; use it: */ + if (gss_write_init_verf(sn->rsc_cache, rqstp, + &rsip->out_handle, &rsip->major_status)) + goto out; + if (gss_write_resv(resv, PAGE_SIZE, + &rsip->out_handle, &rsip->out_token, + rsip->major_status, rsip->minor_status)) + goto out; + + ret = SVC_COMPLETE; +out: + cache_put(&rsip->h, sn->rsi_cache); + return ret; +} + +static int gss_proxy_save_rsc(struct cache_detail *cd, + struct gssp_upcall_data *ud, + uint64_t *handle) +{ + struct rsc rsci, *rscp = NULL; + static atomic64_t ctxhctr; + long long ctxh; + struct gss_api_mech *gm = NULL; + time_t expiry; + int status = -EINVAL; + + memset(&rsci, 0, sizeof(rsci)); + /* context handle */ + status = -ENOMEM; + /* the handle needs to be just a unique id, + * use a static counter */ + ctxh = atomic64_inc_return(&ctxhctr); + + /* make a copy for the caller */ + *handle = ctxh; + + /* make a copy for the rsc cache */ + if (dup_to_netobj(&rsci.handle, (char *)handle, sizeof(uint64_t))) + goto out; + rscp = rsc_lookup(cd, &rsci); + if (!rscp) + goto out; + + /* creds */ + if (!ud->found_creds) { + /* userspace seem buggy, we should always get at least a + * mapping to nobody */ + dprintk("RPC: No creds found!\n"); + goto out; + } else { + + /* steal creds */ + rsci.cred = ud->creds; + memset(&ud->creds, 0, sizeof(struct svc_cred)); + + status = -EOPNOTSUPP; + /* get mech handle from OID */ + gm = gss_mech_get_by_OID(&ud->mech_oid); + if (!gm) goto out; - if (resv->iov_len + 3 * 4 > PAGE_SIZE) + rsci.cred.cr_gss_mech = gm; + + status = -EINVAL; + /* mech-specific data: */ + status = gss_import_sec_context(ud->out_handle.data, + ud->out_handle.len, + gm, &rsci.mechctx, + &expiry, GFP_KERNEL); + if (status) goto out; - svc_putnl(resv, rsip->major_status); - svc_putnl(resv, rsip->minor_status); - svc_putnl(resv, GSS_SEQ_WIN); - if (svc_safe_putnetobj(resv, &rsip->out_token)) + } + + rsci.h.expiry_time = expiry; + rscp = rsc_update(cd, &rsci, rscp); + status = 0; +out: + rsc_free(&rsci); + if (rscp) + cache_put(&rscp->h, cd); + else + status = -ENOMEM; + return status; +} + +static int svcauth_gss_proxy_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, __be32 *authp) +{ + struct kvec *resv = &rqstp->rq_res.head[0]; + struct xdr_netobj cli_handle; + struct 
gssp_upcall_data ud; + uint64_t handle; + int status; + int ret; + struct net *net = rqstp->rq_xprt->xpt_net; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + memset(&ud, 0, sizeof(ud)); + ret = gss_read_proxy_verf(rqstp, gc, authp, + &ud.in_handle, &ud.in_token); + if (ret) + return ret; + + ret = SVC_CLOSE; + + /* Perform synchronous upcall to gss-proxy */ + status = gssp_accept_sec_context_upcall(net, &ud); + if (status) + goto out; + + dprintk("RPC: svcauth_gss: gss major status = %d\n", + ud.major_status); + + switch (ud.major_status) { + case GSS_S_CONTINUE_NEEDED: + cli_handle = ud.out_handle; + break; + case GSS_S_COMPLETE: + status = gss_proxy_save_rsc(sn->rsc_cache, &ud, &handle); + if (status) goto out; + cli_handle.data = (u8 *)&handle; + cli_handle.len = sizeof(handle); + break; + default: + ret = SVC_CLOSE; + goto out; } + + /* Got an answer to the upcall; use it: */ + if (gss_write_init_verf(sn->rsc_cache, rqstp, + &cli_handle, &ud.major_status)) + goto out; + if (gss_write_resv(resv, PAGE_SIZE, + &cli_handle, &ud.out_token, + ud.major_status, ud.minor_status)) + goto out; + ret = SVC_COMPLETE; out: - cache_put(&rsip->h, &rsi_cache); + gssp_free_upcall_data(&ud); return ret; } /* + * Try to set the sn->use_gss_proxy variable to a new value. We only allow + * it to be changed if it's currently undefined (-1). If it's any other value + * then return -EBUSY unless the type wouldn't have changed anyway. + */ +static int set_gss_proxy(struct net *net, int type) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret; + + WARN_ON_ONCE(type != 0 && type != 1); + ret = cmpxchg(&sn->use_gss_proxy, -1, type); + if (ret != -1 && ret != type) + return -EBUSY; + return 0; +} + +static bool use_gss_proxy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + /* If use_gss_proxy is still undefined, then try to disable it */ + if (sn->use_gss_proxy == -1) + set_gss_proxy(net, 0); + return sn->use_gss_proxy; +} + +#ifdef CONFIG_PROC_FS + +static ssize_t write_gssp(struct file *file, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct net *net = PDE_DATA(file_inode(file)); + char tbuf[20]; + unsigned long i; + int res; + + if (*ppos || count > sizeof(tbuf)-1) + return -EINVAL; + if (copy_from_user(tbuf, buf, count)) + return -EFAULT; + + tbuf[count] = 0; + res = kstrtoul(tbuf, 0, &i); + if (res) + return res; + if (i != 1) + return -EINVAL; + res = set_gssp_clnt(net); + if (res) + return res; + res = set_gss_proxy(net, 1); + if (res) + return res; + return count; +} + +static ssize_t read_gssp(struct file *file, char __user *buf, + size_t count, loff_t *ppos) +{ + struct net *net = PDE_DATA(file_inode(file)); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + unsigned long p = *ppos; + char tbuf[10]; + size_t len; + + snprintf(tbuf, sizeof(tbuf), "%d\n", sn->use_gss_proxy); + len = strlen(tbuf); + if (p >= len) + return 0; + len -= p; + if (len > count) + len = count; + if (copy_to_user(buf, (void *)(tbuf+p), len)) + return -EFAULT; + *ppos += len; + return len; +} + +static const struct file_operations use_gss_proxy_ops = { + .open = nonseekable_open, + .write = write_gssp, + .read = read_gssp, +}; + +static int create_use_gss_proxy_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct proc_dir_entry **p = &sn->use_gssp_proc; + + sn->use_gss_proxy = -1; + *p = proc_create_data("use-gss-proxy", S_IFREG|S_IRUSR|S_IWUSR, + sn->proc_net_rpc, + &use_gss_proxy_ops, net); + 
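set_gss_proxy() above treats sn->use_gss_proxy as a tri-state latch: -1 while undefined, then 0 (legacy rsi upcall) or 1 (gss-proxy), whichever is written first, and cmpxchg() makes that first writer win. A minimal sketch of the latch semantics, with hypothetical names:

#include <linux/atomic.h>
#include <linux/errno.h>

/* First writer wins; later writers may only repeat the same choice. */
static int latch_mode(int *mode, int want)	/* want is 0 or 1 */
{
	int old = cmpxchg(mode, -1, want);	/* only -1 -> want succeeds */

	if (old != -1 && old != want)
		return -EBUSY;	/* already latched to the other mode */
	return 0;
}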
if (!*p) + return -ENOMEM; + init_gssp_clnt(sn); + return 0; +} + +static void destroy_use_gss_proxy_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (sn->use_gssp_proc) { + remove_proc_entry("use-gss-proxy", sn->proc_net_rpc); + clear_gssp_clnt(sn); + } +} +#else /* CONFIG_PROC_FS */ + +static int create_use_gss_proxy_proc_entry(struct net *net) +{ + return 0; +} + +static void destroy_use_gss_proxy_proc_entry(struct net *net) {} + +#endif /* CONFIG_PROC_FS */ + +/* * Accept an rpcsec packet. * If context establishment, punt to user space * If data exchange, verify/decrypt @@ -1083,6 +1403,7 @@ svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) __be32 *rpcstart; __be32 *reject_stat = resv->iov_base + resv->iov_len; int ret; + struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, sunrpc_net_id); dprintk("RPC: svcauth_gss: argv->iov_len = %zd\n", argv->iov_len); @@ -1105,7 +1426,7 @@ svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) /* credential is: * version(==1), proc(0,1,2,3), seq, service (1,2,3), handle - * at least 5 u32s, and is preceeded by length, so that makes 6. + * at least 5 u32s, and is preceded by length, so that makes 6. */ if (argv->iov_len < 5 * 4) @@ -1128,12 +1449,15 @@ svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) switch (gc->gc_proc) { case RPC_GSS_PROC_INIT: case RPC_GSS_PROC_CONTINUE_INIT: - return svcauth_gss_handle_init(rqstp, gc, authp); + if (use_gss_proxy(SVC_NET(rqstp))) + return svcauth_gss_proxy_init(rqstp, gc, authp); + else + return svcauth_gss_legacy_init(rqstp, gc, authp); case RPC_GSS_PROC_DATA: case RPC_GSS_PROC_DESTROY: /* Look up the context, and check the verifier: */ *authp = rpcsec_gsserr_credproblem; - rsci = gss_svc_searchbyctx(&gc->gc_ctx); + rsci = gss_svc_searchbyctx(sn->rsc_cache, &gc->gc_ctx); if (!rsci) goto auth_err; switch (gss_verify_header(rqstp, rsci, rpcstart, gc, authp)) { @@ -1176,9 +1500,10 @@ svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) /* placeholders for length and seq. number: */ svc_putnl(resv, 0); svc_putnl(resv, 0); - if (unwrap_integ_data(&rqstp->rq_arg, + if (unwrap_integ_data(rqstp, &rqstp->rq_arg, gc->gc_seq, rsci->mechctx)) goto garbage_args; + rqstp->rq_auth_slack = RPC_MAX_AUTH_SIZE; break; case RPC_GSS_SVC_PRIVACY: /* placeholders for length and seq. 
number: */ @@ -1187,14 +1512,17 @@ svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) if (unwrap_priv_data(rqstp, &rqstp->rq_arg, gc->gc_seq, rsci->mechctx)) goto garbage_args; + rqstp->rq_auth_slack = RPC_MAX_AUTH_SIZE * 2; break; default: goto auth_err; } svcdata->rsci = rsci; cache_get(&rsci->h); - rqstp->rq_flavor = gss_svc_to_pseudoflavor( - rsci->mechctx->mech_type, gc->gc_svc); + rqstp->rq_cred.cr_flavor = gss_svc_to_pseudoflavor( + rsci->mechctx->mech_type, + GSS_C_QOP_DEFAULT, + gc->gc_svc); ret = SVC_OK; goto out; } @@ -1213,7 +1541,7 @@ drop: ret = SVC_DROP; out: if (rsci) - cache_put(&rsci->h, &rsc_cache); + cache_put(&rsci->h, sn->rsc_cache); return ret; } @@ -1265,8 +1593,7 @@ svcauth_gss_wrap_resp_integ(struct svc_rqst *rqstp) BUG_ON(integ_len % 4); *p++ = htonl(integ_len); *p++ = htonl(gc->gc_seq); - if (xdr_buf_subsegment(resbuf, &integ_buf, integ_offset, - integ_len)) + if (xdr_buf_subsegment(resbuf, &integ_buf, integ_offset, integ_len)) BUG(); if (resbuf->tail[0].iov_base == NULL) { if (resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE > PAGE_SIZE) @@ -1274,10 +1601,8 @@ svcauth_gss_wrap_resp_integ(struct svc_rqst *rqstp) resbuf->tail[0].iov_base = resbuf->head[0].iov_base + resbuf->head[0].iov_len; resbuf->tail[0].iov_len = 0; - resv = &resbuf->tail[0]; - } else { - resv = &resbuf->tail[0]; } + resv = &resbuf->tail[0]; mic.data = (u8 *)resv->iov_base + resv->iov_len + 4; if (gss_get_mic(gsd->rsci->mechctx, &integ_buf, &mic)) goto out_err; @@ -1314,6 +1639,14 @@ svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) inpages = resbuf->pages; /* XXX: Would be better to write some xdr helper functions for * nfs{2,3,4}xdr.c that place the data right, instead of copying: */ + + /* + * If there is currently tail data, make sure there is + * room for the head, tail, and 2 * RPC_MAX_AUTH_SIZE in + * the page, and move the current tail data such that + * there is RPC_MAX_AUTH_SIZE slack space available in + * both the head and tail. + */ if (resbuf->tail[0].iov_base) { BUG_ON(resbuf->tail[0].iov_base >= resbuf->head[0].iov_base + PAGE_SIZE); @@ -1326,6 +1659,13 @@ svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) resbuf->tail[0].iov_len); resbuf->tail[0].iov_base += RPC_MAX_AUTH_SIZE; } + /* + * If there is no current tail data, make sure there is + * room for the head data, and 2 * RPC_MAX_AUTH_SIZE in the + * allotted page, and set up tail information such that there + * is RPC_MAX_AUTH_SIZE slack space available in both the + * head and tail. 
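The two comments added to svcauth_gss_wrap_resp_priv() describe one invariant: however the tail is created or shifted, the single response page must keep RPC_MAX_AUTH_SIZE bytes of slack in both the head and the tail for gss_wrap() to expand into. The no-tail bound from this hunk, restated as a hedged helper that is not part of the patch:

#include <linux/sunrpc/msg_prot.h>
#include <linux/sunrpc/xdr.h>
#include <linux/types.h>

/* Mirrors the -ENOMEM test above: head plus two slack regions must fit. */
static bool head_leaves_gss_slack(const struct xdr_buf *resbuf)
{
	return resbuf->head[0].iov_len + 2 * RPC_MAX_AUTH_SIZE <= PAGE_SIZE;
}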
+ */ if (resbuf->tail[0].iov_base == NULL) { if (resbuf->head[0].iov_len + 2*RPC_MAX_AUTH_SIZE > PAGE_SIZE) return -ENOMEM; @@ -1351,6 +1691,7 @@ svcauth_gss_release(struct svc_rqst *rqstp) struct rpc_gss_wire_cred *gc = &gsd->clcred; struct xdr_buf *resbuf = &rqstp->rq_res; int stat = -EINVAL; + struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, sunrpc_net_id); if (gc->gc_proc != RPC_GSS_PROC_DATA) goto out; @@ -1393,7 +1734,7 @@ out_err: put_group_info(rqstp->rq_cred.cr_group_info); rqstp->rq_cred.cr_group_info = NULL; if (gsd->rsci) - cache_put(&gsd->rsci->h, &rsc_cache); + cache_put(&gsd->rsci->h, sn->rsc_cache); gsd->rsci = NULL; return stat; @@ -1418,30 +1759,102 @@ static struct auth_ops svcauthops_gss = { .set_client = svcauth_gss_set_client, }; +static int rsi_cache_create_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&rsi_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->rsi_cache = cd; + return 0; +} + +static void rsi_cache_destroy_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->rsi_cache; + + sn->rsi_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +static int rsc_cache_create_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&rsc_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->rsc_cache = cd; + return 0; +} + +static void rsc_cache_destroy_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->rsc_cache; + + sn->rsc_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + int -gss_svc_init(void) +gss_svc_init_net(struct net *net) { - int rv = svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss); + int rv; + + rv = rsc_cache_create_net(net); if (rv) return rv; - rv = cache_register(&rsc_cache); + rv = rsi_cache_create_net(net); if (rv) goto out1; - rv = cache_register(&rsi_cache); + rv = create_use_gss_proxy_proc_entry(net); if (rv) goto out2; return 0; out2: - cache_unregister(&rsc_cache); + destroy_use_gss_proxy_proc_entry(net); out1: - svc_auth_unregister(RPC_AUTH_GSS); + rsc_cache_destroy_net(net); return rv; } void +gss_svc_shutdown_net(struct net *net) +{ + destroy_use_gss_proxy_proc_entry(net); + rsi_cache_destroy_net(net); + rsc_cache_destroy_net(net); +} + +int +gss_svc_init(void) +{ + return svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss); +} + +void gss_svc_shutdown(void) { - cache_unregister(&rsc_cache); - cache_unregister(&rsi_cache); svc_auth_unregister(RPC_AUTH_GSS); } diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c index 1db618f56ec..f0ebe07978a 100644 --- a/net/sunrpc/auth_null.c +++ b/net/sunrpc/auth_null.c @@ -18,7 +18,7 @@ static struct rpc_auth null_auth; static struct rpc_cred null_cred; static struct rpc_auth * -nul_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +nul_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { atomic_inc(&null_auth.au_count); return &null_auth; @@ -75,7 +75,7 @@ nul_marshal(struct rpc_task *task, __be32 *p) static int nul_refresh(struct rpc_task 
*task) { - set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_msg.rpc_cred->cr_flags); + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); return 0; } @@ -88,13 +88,13 @@ nul_validate(struct rpc_task *task, __be32 *p) flavor = ntohl(*p++); if (flavor != RPC_AUTH_NULL) { printk("RPC: bad verf flavor: %u\n", flavor); - return NULL; + return ERR_PTR(-EIO); } size = ntohl(*p++); if (size != 0) { printk("RPC: bad verf size: %u\n", size); - return NULL; + return ERR_PTR(-EIO); } return p; diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c index 46b2647c5bd..d5d69236629 100644 --- a/net/sunrpc/auth_unix.c +++ b/net/sunrpc/auth_unix.c @@ -6,18 +6,20 @@ * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> */ +#include <linux/slab.h> #include <linux/types.h> #include <linux/sched.h> #include <linux/module.h> #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/auth.h> +#include <linux/user_namespace.h> #define NFS_NGROUPS 16 struct unx_cred { struct rpc_cred uc_base; - gid_t uc_gid; - gid_t uc_gids[NFS_NGROUPS]; + kgid_t uc_gid; + kgid_t uc_gids[NFS_NGROUPS]; }; #define uc_uid uc_base.cr_uid @@ -28,11 +30,10 @@ struct unx_cred { #endif static struct rpc_auth unix_auth; -static struct rpc_cred_cache unix_cred_cache; static const struct rpc_credops unix_credops; static struct rpc_auth * -unx_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +unx_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { dprintk("RPC: creating UNIX authenticator for client %p\n", clnt); @@ -64,7 +65,8 @@ unx_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) unsigned int i; dprintk("RPC: allocating UNIX cred for uid %d gid %d\n", - acred->uid, acred->gid); + from_kuid(&init_user_ns, acred->uid), + from_kgid(&init_user_ns, acred->gid)); if (!(cred = kmalloc(sizeof(*cred), GFP_NOFS))) return ERR_PTR(-ENOMEM); @@ -81,7 +83,7 @@ unx_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) for (i = 0; i < groups; i++) cred->uc_gids[i] = GROUP_AT(acred->group_info, i); if (i < NFS_NGROUPS) - cred->uc_gids[i] = NOGROUP; + cred->uc_gids[i] = INVALID_GID; return &cred->uc_base; } @@ -119,7 +121,7 @@ unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int flags) unsigned int i; - if (cred->uc_uid != acred->uid || cred->uc_gid != acred->gid) + if (!uid_eq(cred->uc_uid, acred->uid) || !gid_eq(cred->uc_gid, acred->gid)) return 0; if (acred->group_info != NULL) @@ -127,8 +129,10 @@ unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int flags) if (groups > NFS_NGROUPS) groups = NFS_NGROUPS; for (i = 0; i < groups ; i++) - if (cred->uc_gids[i] != GROUP_AT(acred->group_info, i)) + if (!gid_eq(cred->uc_gids[i], GROUP_AT(acred->group_info, i))) return 0; + if (groups < NFS_NGROUPS && gid_valid(cred->uc_gids[groups])) + return 0; return 1; } @@ -140,7 +144,7 @@ static __be32 * unx_marshal(struct rpc_task *task, __be32 *p) { struct rpc_clnt *clnt = task->tk_client; - struct unx_cred *cred = container_of(task->tk_msg.rpc_cred, struct unx_cred, uc_base); + struct unx_cred *cred = container_of(task->tk_rqstp->rq_cred, struct unx_cred, uc_base); __be32 *base, *hold; int i; @@ -153,11 +157,11 @@ unx_marshal(struct rpc_task *task, __be32 *p) */ p = xdr_encode_array(p, clnt->cl_nodename, clnt->cl_nodelen); - *p++ = htonl((u32) cred->uc_uid); - *p++ = htonl((u32) cred->uc_gid); + *p++ = htonl((u32) from_kuid(&init_user_ns, cred->uc_uid)); + *p++ = htonl((u32) from_kgid(&init_user_ns, cred->uc_gid)); hold = p++; - for (i = 0; i < 16 && cred->uc_gids[i] != (gid_t) 
NOGROUP; i++) - *p++ = htonl((u32) cred->uc_gids[i]); + for (i = 0; i < 16 && gid_valid(cred->uc_gids[i]); i++) + *p++ = htonl((u32) from_kgid(&init_user_ns, cred->uc_gids[i])); *hold = htonl(p - hold - 1); /* gid array length */ *base = htonl((p - base - 1) << 2); /* cred length */ @@ -173,7 +177,7 @@ unx_marshal(struct rpc_task *task, __be32 *p) static int unx_refresh(struct rpc_task *task) { - set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_msg.rpc_cred->cr_flags); + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); return 0; } @@ -188,23 +192,28 @@ unx_validate(struct rpc_task *task, __be32 *p) flavor != RPC_AUTH_UNIX && flavor != RPC_AUTH_SHORT) { printk("RPC: bad verf flavor: %u\n", flavor); - return NULL; + return ERR_PTR(-EIO); } size = ntohl(*p++); if (size > RPC_MAX_AUTH_SIZE) { printk("RPC: giant verf size: %u\n", size); - return NULL; + return ERR_PTR(-EIO); } - task->tk_msg.rpc_cred->cr_auth->au_rslack = (size >> 2) + 2; + task->tk_rqstp->rq_cred->cr_auth->au_rslack = (size >> 2) + 2; p += (size >> 2); return p; } -void __init rpc_init_authunix(void) +int __init rpc_init_authunix(void) +{ + return rpcauth_init_credcache(&unix_auth); +} + +void rpc_destroy_authunix(void) { - spin_lock_init(&unix_cred_cache.lock); + rpcauth_destroy_credcache(&unix_auth); } const struct rpc_authops authunix_ops = { @@ -218,17 +227,12 @@ const struct rpc_authops authunix_ops = { }; static -struct rpc_cred_cache unix_cred_cache = { -}; - -static struct rpc_auth unix_auth = { .au_cslack = UNX_WRITESLACK, .au_rslack = 2, /* assume AUTH_NULL verf */ .au_ops = &authunix_ops, .au_flavor = RPC_AUTH_UNIX, .au_count = ATOMIC_INIT(0), - .au_credcache = &unix_cred_cache, }; static diff --git a/net/sunrpc/backchannel_rqst.c b/net/sunrpc/backchannel_rqst.c index 553621fb2c4..9761a0da964 100644 --- a/net/sunrpc/backchannel_rqst.c +++ b/net/sunrpc/backchannel_rqst.c @@ -22,14 +22,15 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ******************************************************************************/ #include <linux/tcp.h> +#include <linux/slab.h> #include <linux/sunrpc/xprt.h> +#include <linux/export.h> +#include <linux/sunrpc/bc_xprt.h> #ifdef RPC_DEBUG #define RPCDBG_FACILITY RPCDBG_TRANS #endif -#if defined(CONFIG_NFS_V4_1) - /* * Helper routines that track the number of preallocation elements * on the transport. @@ -58,12 +59,11 @@ static void xprt_free_allocation(struct rpc_rqst *req) struct xdr_buf *xbufp; dprintk("RPC: free allocations for req= %p\n", req); - BUG_ON(test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); + WARN_ON_ONCE(test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); xbufp = &req->rq_private_buf; free_page((unsigned long)xbufp->head[0].iov_base); xbufp = &req->rq_snd_buf; free_page((unsigned long)xbufp->head[0].iov_base); - list_del(&req->rq_bc_pa_list); kfree(req); } @@ -167,21 +167,24 @@ out_free: /* * Memory allocation failed, free the temporary list */ - list_for_each_entry_safe(req, tmp, &tmp_list, rq_bc_pa_list) + list_for_each_entry_safe(req, tmp, &tmp_list, rq_bc_pa_list) { + list_del(&req->rq_bc_pa_list); xprt_free_allocation(req); + } dprintk("RPC: setup backchannel transport failed\n"); - return -1; + return -ENOMEM; } -EXPORT_SYMBOL(xprt_setup_backchannel); +EXPORT_SYMBOL_GPL(xprt_setup_backchannel); -/* - * Destroys the backchannel preallocated structures. +/** + * xprt_destroy_backchannel - Destroys the backchannel preallocated structures. 
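With gid_t replaced by kgid_t there is no NOGROUP magic value left, so the uc_gids array is now terminated by INVALID_GID and scanned with gid_valid(), as in the marshalling loop above and the new early-exit check in unx_match(). The idiom in isolation, as a hypothetical helper:

#include <linux/uidgid.h>

#define UNX_NGROUPS 16	/* NFS_NGROUPS in auth_unix.c */

/* Count the valid entries of an INVALID_GID-terminated group array. */
static unsigned int count_unx_groups(const kgid_t gids[UNX_NGROUPS])
{
	unsigned int i;

	for (i = 0; i < UNX_NGROUPS && gid_valid(gids[i]); i++)
		;
	return i;
}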
+ * @xprt: the transport holding the preallocated structures + * @max_reqs: the maximum number of preallocated structures to destroy + * * Since these structures may have been allocated by multiple calls * to xprt_setup_backchannel, we only destroy up to the maximum number * of reqs specified by the caller. - * @xprt: the transport holding the preallocated strucures - * @max_reqs the maximum number of preallocated structures to destroy */ void xprt_destroy_backchannel(struct rpc_xprt *xprt, unsigned int max_reqs) { @@ -189,55 +192,43 @@ void xprt_destroy_backchannel(struct rpc_xprt *xprt, unsigned int max_reqs) dprintk("RPC: destroy backchannel transport\n"); - BUG_ON(max_reqs == 0); + if (max_reqs == 0) + goto out; + spin_lock_bh(&xprt->bc_pa_lock); xprt_dec_alloc_count(xprt, max_reqs); list_for_each_entry_safe(req, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { dprintk("RPC: req=%p\n", req); + list_del(&req->rq_bc_pa_list); xprt_free_allocation(req); if (--max_reqs == 0) break; } spin_unlock_bh(&xprt->bc_pa_lock); +out: dprintk("RPC: backchannel list empty= %s\n", list_empty(&xprt->bc_pa_list) ? "true" : "false"); } -EXPORT_SYMBOL(xprt_destroy_backchannel); +EXPORT_SYMBOL_GPL(xprt_destroy_backchannel); -/* - * One or more rpc_rqst structure have been preallocated during the - * backchannel setup. Buffer space for the send and private XDR buffers - * has been preallocated as well. Use xprt_alloc_bc_request to allocate - * to this request. Use xprt_free_bc_request to return it. - * - * We know that we're called in soft interrupt context, grab the spin_lock - * since there is no need to grab the bottom half spin_lock. - * - * Return an available rpc_rqst, otherwise NULL if non are available. - */ -struct rpc_rqst *xprt_alloc_bc_request(struct rpc_xprt *xprt) +static struct rpc_rqst *xprt_alloc_bc_request(struct rpc_xprt *xprt, __be32 xid) { - struct rpc_rqst *req; + struct rpc_rqst *req = NULL; dprintk("RPC: allocate a backchannel request\n"); - spin_lock(&xprt->bc_pa_lock); - if (!list_empty(&xprt->bc_pa_list)) { - req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst, - rq_bc_pa_list); - list_del(&req->rq_bc_pa_list); - } else { - req = NULL; - } - spin_unlock(&xprt->bc_pa_lock); + if (list_empty(&xprt->bc_pa_list)) + goto not_found; - if (req != NULL) { - set_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); - req->rq_reply_bytes_recvd = 0; - req->rq_bytes_sent = 0; - memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst, + rq_bc_pa_list); + req->rq_reply_bytes_recvd = 0; + req->rq_bytes_sent = 0; + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(req->rq_private_buf)); - } + req->rq_xid = xid; + req->rq_connect_cookie = xprt->connect_cookie; +not_found: dprintk("RPC: backchannel req=%p\n", req); return req; } @@ -252,10 +243,11 @@ void xprt_free_bc_request(struct rpc_rqst *req) dprintk("RPC: free backchannel req=%p\n", req); - smp_mb__before_clear_bit(); - BUG_ON(!test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); + req->rq_connect_cookie = xprt->connect_cookie - 1; + smp_mb__before_atomic(); + WARN_ON_ONCE(!test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); clear_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); if (!xprt_need_to_requeue(xprt)) { /* @@ -274,8 +266,57 @@ * may be reused by a new callback request. 
*/ spin_lock_bh(&xprt->bc_pa_lock); - list_add(&req->rq_bc_pa_list, &xprt->bc_pa_list); + list_add_tail(&req->rq_bc_pa_list, &xprt->bc_pa_list); spin_unlock_bh(&xprt->bc_pa_lock); } -#endif /* CONFIG_NFS_V4_1 */ +/* + * One or more rpc_rqst structures have been preallocated during the + * backchannel setup. Buffer space for the send and private XDR buffers + * has been preallocated as well. Use xprt_alloc_bc_request to allocate + * to this request. Use xprt_free_bc_request to return it. + * + * We know that we're called in soft interrupt context, grab the spin_lock + * since there is no need to grab the bottom half spin_lock. + * + * Return an available rpc_rqst, otherwise NULL if none are available. + */ +struct rpc_rqst *xprt_lookup_bc_request(struct rpc_xprt *xprt, __be32 xid) +{ + struct rpc_rqst *req; + + spin_lock(&xprt->bc_pa_lock); + list_for_each_entry(req, &xprt->bc_pa_list, rq_bc_pa_list) { + if (req->rq_connect_cookie != xprt->connect_cookie) + continue; + if (req->rq_xid == xid) + goto found; + } + req = xprt_alloc_bc_request(xprt, xid); +found: + spin_unlock(&xprt->bc_pa_lock); + return req; +} + +/* + * Add callback request to callback list. The callback + * service sleeps on the sv_cb_waitq waiting for new + * requests. Wake it up after enqueuing the + * request. + */ +void xprt_complete_bc_request(struct rpc_rqst *req, uint32_t copied) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct svc_serv *bc_serv = xprt->bc_serv; + + req->rq_private_buf.len = copied; + set_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); + + dprintk("RPC: add callback request to list\n"); + spin_lock(&bc_serv->sv_cb_lock); + list_del(&req->rq_bc_pa_list); + list_add(&req->rq_bc_list, &bc_serv->sv_cb_list); + wake_up(&bc_serv->sv_cb_waitq); + spin_unlock(&bc_serv->sv_cb_lock); +} + diff --git a/net/sunrpc/bc_svc.c b/net/sunrpc/bc_svc.c index 13f214f5312..15c7a8a1c24 100644 --- a/net/sunrpc/bc_svc.c +++ b/net/sunrpc/bc_svc.c @@ -27,8 +27,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * reply over an existing open connection previously established by the client. */ -#if defined(CONFIG_NFS_V4_1) - #include <linux/module.h> #include <linux/sunrpc/xprt.h> @@ -37,21 +35,6 @@ 
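xprt_lookup_bc_request() makes callback receive idempotent across retransmissions: a slot is matched by (connect_cookie, xid) first, and only if nothing matches is a fresh preallocated request claimed; xprt_free_bc_request() in turn rewinds rq_connect_cookie so a recycled slot cannot match a live xid. A hedged sketch of a transport receive path driving the new pair (the caller and its name are assumed, not shown in this patch):

/* Hypothetical data_ready-side consumer of the two new entry points. */
static void bc_deliver(struct rpc_xprt *xprt, __be32 xid, u32 copied)
{
	struct rpc_rqst *req;

	req = xprt_lookup_bc_request(xprt, xid);
	if (req == NULL)
		return;	/* no preallocated slot free: drop the callback */

	/* ...the transport copies 'copied' bytes into req->rq_rcv_buf... */

	xprt_complete_bc_request(req, copied);	/* wake the callback service */
}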
#define RPCDBG_FACILITY RPCDBG_SVCDSP -void bc_release_request(struct rpc_task *task) -{ - struct rpc_rqst *req = task->tk_rqstp; - - dprintk("RPC: bc_release_request: task= %p\n", task); - - /* - * Release this request only if it's a backchannel - * preallocated request - */ - if (!bc_prealloc(req)) - return; - xprt_free_bc_request(req); -} - /* Empty callback ops */ static const struct rpc_call_ops nfs41_callback_ops = { }; @@ -70,12 +53,11 @@ int bc_send(struct rpc_rqst *req) if (IS_ERR(task)) ret = PTR_ERR(task); else { - BUG_ON(atomic_read(&task->tk_count) != 1); + WARN_ON_ONCE(atomic_read(&task->tk_count) != 1); ret = task->tk_status; rpc_put_task(task); } + dprintk("RPC: bc_send ret= %d\n", ret); return ret; - dprintk("RPC: bc_send ret= %d \n", ret); } -#endif /* CONFIG_NFS_V4_1 */ diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c index 39bddba53ba..06636214113 100644 --- a/net/sunrpc/cache.c +++ b/net/sunrpc/cache.c @@ -33,15 +33,16 @@ #include <linux/sunrpc/cache.h> #include <linux/sunrpc/stats.h> #include <linux/sunrpc/rpc_pipe_fs.h> +#include "netns.h" #define RPCDBG_FACILITY RPCDBG_CACHE -static int cache_defer_req(struct cache_req *req, struct cache_head *item); +static bool cache_defer_req(struct cache_req *req, struct cache_head *item); static void cache_revisit_request(struct cache_head *item); static void cache_init(struct cache_head *h) { - time_t now = get_seconds(); + time_t now = seconds_since_boot(); h->next = NULL; h->flags = 0; kref_init(&h->ref); @@ -53,7 +54,7 @@ struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, struct cache_head *key, int hash) { struct cache_head **head, **hp; - struct cache_head *new = NULL; + struct cache_head *new = NULL, *freeme = NULL; head = &detail->hash_table[hash]; @@ -62,6 +63,9 @@ struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, for (hp=head; *hp != NULL ; hp = &(*hp)->next) { struct cache_head *tmp = *hp; if (detail->match(tmp, key)) { + if (cache_is_expired(detail, tmp)) + /* This entry is expired, we will discard it. 
*/ + break; cache_get(tmp); read_unlock(&detail->hash_lock); return tmp; @@ -86,6 +90,13 @@ struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, for (hp=head; *hp != NULL ; hp = &(*hp)->next) { struct cache_head *tmp = *hp; if (detail->match(tmp, key)) { + if (cache_is_expired(detail, tmp)) { + *hp = tmp->next; + tmp->next = NULL; + detail->entries --; + freeme = tmp; + break; + } cache_get(tmp); write_unlock(&detail->hash_lock); cache_put(new, detail); @@ -98,6 +109,8 @@ struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, cache_get(new); write_unlock(&detail->hash_lock); + if (freeme) + cache_put(freeme, detail); return new; } EXPORT_SYMBOL_GPL(sunrpc_cache_lookup); @@ -108,7 +121,8 @@ static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch); static void cache_fresh_locked(struct cache_head *head, time_t expiry) { head->expiry_time = expiry; - head->last_refresh = get_seconds(); + head->last_refresh = seconds_since_boot(); + smp_wmb(); /* paired with smp_rmb() in cache_is_valid() */ set_bit(CACHE_VALID, &head->flags); } @@ -176,25 +190,46 @@ EXPORT_SYMBOL_GPL(sunrpc_cache_update); static int cache_make_upcall(struct cache_detail *cd, struct cache_head *h) { - if (!cd->cache_upcall) - return -EINVAL; - return cd->cache_upcall(cd, h); + if (cd->cache_upcall) + return cd->cache_upcall(cd, h); + return sunrpc_cache_pipe_upcall(cd, h); } -static inline int cache_is_valid(struct cache_detail *detail, struct cache_head *h) +static inline int cache_is_valid(struct cache_head *h) { - if (!test_bit(CACHE_VALID, &h->flags) || - h->expiry_time < get_seconds()) - return -EAGAIN; - else if (detail->flush_time > h->last_refresh) + if (!test_bit(CACHE_VALID, &h->flags)) return -EAGAIN; else { /* entry is valid */ if (test_bit(CACHE_NEGATIVE, &h->flags)) return -ENOENT; - else + else { + /* + * In combination with write barrier in + * sunrpc_cache_update, ensures that anyone + * using the cache entry after this sees the + * updated contents: + */ + smp_rmb(); return 0; + } + } +} + +static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h) +{ + int rv; + + write_lock(&detail->hash_lock); + rv = cache_is_valid(h); + if (rv == -EAGAIN) { + set_bit(CACHE_NEGATIVE, &h->flags); + cache_fresh_locked(h, seconds_since_boot()+CACHE_NEW_EXPIRY); + rv = -ENOENT; } + write_unlock(&detail->hash_lock); + cache_fresh_unlocked(h, detail); + return rv; } /* @@ -218,43 +253,38 @@ int cache_check(struct cache_detail *detail, long refresh_age, age; /* First decide return status as best we can */ - rv = cache_is_valid(detail, h); + rv = cache_is_valid(h); /* now see if we want to start an upcall */ refresh_age = (h->expiry_time - h->last_refresh); - age = get_seconds() - h->last_refresh; + age = seconds_since_boot() - h->last_refresh; if (rqstp == NULL) { if (rv == -EAGAIN) rv = -ENOENT; - } else if (rv == -EAGAIN || age > refresh_age/2) { + } else if (rv == -EAGAIN || + (h->expiry_time != 0 && age > refresh_age/2)) { dprintk("RPC: Want update, refage=%ld, age=%ld\n", refresh_age, age); if (!test_and_set_bit(CACHE_PENDING, &h->flags)) { switch (cache_make_upcall(detail, h)) { case -EINVAL: - clear_bit(CACHE_PENDING, &h->flags); - cache_revisit_request(h); - if (rv == -EAGAIN) { - set_bit(CACHE_NEGATIVE, &h->flags); - cache_fresh_locked(h, get_seconds()+CACHE_NEW_EXPIRY); - cache_fresh_unlocked(h, detail); - rv = -ENOENT; - } + rv = try_to_negate_entry(detail, h); break; - case -EAGAIN: - clear_bit(CACHE_PENDING, &h->flags); - cache_revisit_request(h); 
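The smp_wmb() added to cache_fresh_locked() and the smp_rmb() added to cache_is_valid() form a standard publish/consume pairing around the CACHE_VALID bit: the updater must make the entry's contents globally visible before the bit, and the reader must observe the bit before touching the contents. Reduced to its essentials (a sketch, not the patch's code):

#include <linux/bitops.h>
#include <linux/sunrpc/cache.h>
#include <linux/types.h>

/* Writer side, cf. sunrpc_cache_update() -> cache_fresh_locked(). */
static void publish_entry(struct cache_head *h)
{
	/* ...store the entry's contents... */
	smp_wmb();		/* contents before CACHE_VALID */
	set_bit(CACHE_VALID, &h->flags);
}

/* Reader side, cf. cache_is_valid(). */
static bool entry_usable(struct cache_head *h)
{
	if (!test_bit(CACHE_VALID, &h->flags))
		return false;
	smp_rmb();		/* CACHE_VALID before contents */
	return true;		/* contents now safe to read */
}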
+ cache_fresh_unlocked(h, detail); break; } } } if (rv == -EAGAIN) { - if (cache_defer_req(rqstp, h) < 0) { - /* Request is not deferred */ - rv = cache_is_valid(detail, h); + if (!cache_defer_req(rqstp, h)) { + /* + * Request was not deferred; handle it as best + * we can ourselves: + */ + rv = cache_is_valid(h); if (rv == -EAGAIN) rv = -ETIMEDOUT; } @@ -271,7 +301,7 @@ EXPORT_SYMBOL_GPL(cache_check); * a current pointer into that list and into the table * for that entry. * - * Each time clean_cache is called it finds the next non-empty entry + * Each time cache_clean is called it finds the next non-empty entry * in the current table and walks the list in that entry * looking for entries that can be removed. * @@ -303,9 +333,9 @@ static struct cache_detail *current_detail; static int current_index; static void do_cache_clean(struct work_struct *work); -static DECLARE_DELAYED_WORK(cache_cleaner, do_cache_clean); +static struct delayed_work cache_cleaner; -static void sunrpc_init_cache_detail(struct cache_detail *cd) +void sunrpc_init_cache_detail(struct cache_detail *cd) { rwlock_init(&cd->hash_lock); INIT_LIST_HEAD(&cd->queue); @@ -321,8 +351,9 @@ static void sunrpc_init_cache_detail(struct cache_detail *cd) /* start the cleaning process */ schedule_delayed_work(&cache_cleaner, 0); } +EXPORT_SYMBOL_GPL(sunrpc_init_cache_detail); -static void sunrpc_destroy_cache_detail(struct cache_detail *cd) +void sunrpc_destroy_cache_detail(struct cache_detail *cd) { cache_purge(cd); spin_lock(&cache_list_lock); @@ -343,8 +374,9 @@ static void sunrpc_destroy_cache_detail(struct cache_detail *cd) } return; out: - printk(KERN_ERR "nfsd: failed to unregister %s cache\n", cd->name); + printk(KERN_ERR "RPC: failed to unregister %s cache\n", cd->name); } +EXPORT_SYMBOL_GPL(sunrpc_destroy_cache_detail); /* clean cache tries to find something to clean * and cleans it. 
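The hunks above pair the smp_wmb() added in cache_fresh_locked() (and the write barrier in sunrpc_cache_update) with an smp_rmb() in cache_is_valid(): the writer publishes the entry's contents before setting CACHE_VALID, and a reader that observes the bit is then guaranteed to see those contents. A minimal sketch of that publish/consume pattern, using hypothetical demo_* names rather than the real cache types:

#include <linux/bitops.h>
#include <linux/errno.h>
#include <asm/barrier.h>

struct demo_entry {
	unsigned long	flags;	/* DEMO_VALID lives here */
	int		data;	/* contents published by the writer */
};
#define DEMO_VALID	0

/* Writer: fill in the entry, then advertise it. */
static void demo_publish(struct demo_entry *e, int data)
{
	e->data = data;
	smp_wmb();	/* order ->data before the VALID bit */
	set_bit(DEMO_VALID, &e->flags);
}

/* Reader: touch ->data only after the bit test has been ordered. */
static int demo_read(struct demo_entry *e, int *out)
{
	if (!test_bit(DEMO_VALID, &e->flags))
		return -EAGAIN;	/* not published yet */
	smp_rmb();	/* pairs with smp_wmb() in demo_publish() */
	*out = e->data;
	return 0;
}

This discipline is what lets cache_check() test validity on the common path without taking detail->hash_lock.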
@@ -372,11 +404,11 @@ static int cache_clean(void) return -1; } current_detail = list_entry(next, struct cache_detail, others); - if (current_detail->nextcheck > get_seconds()) + if (current_detail->nextcheck > seconds_since_boot()) current_index = current_detail->hash_size; else { current_index = 0; - current_detail->nextcheck = get_seconds()+30*60; + current_detail->nextcheck = seconds_since_boot()+30*60; } } @@ -397,32 +429,27 @@ static int cache_clean(void) /* Ok, now to clean this strand */ cp = & current_detail->hash_table[current_index]; - ch = *cp; - for (; ch; cp= & ch->next, ch= *cp) { + for (ch = *cp ; ch ; cp = & ch->next, ch = *cp) { if (current_detail->nextcheck > ch->expiry_time) current_detail->nextcheck = ch->expiry_time+1; - if (ch->expiry_time >= get_seconds() && - ch->last_refresh >= current_detail->flush_time) + if (!cache_is_expired(current_detail, ch)) continue; - if (test_and_clear_bit(CACHE_PENDING, &ch->flags)) - cache_dequeue(current_detail, ch); - if (atomic_read(&ch->ref.refcount) == 1) - break; - } - if (ch) { *cp = ch->next; ch->next = NULL; current_detail->entries--; rv = 1; + break; } + write_unlock(&current_detail->hash_lock); d = current_detail; if (!ch) current_index ++; spin_unlock(&cache_list_lock); if (ch) { - cache_revisit_request(ch); + set_bit(CACHE_CLEANED, &ch->flags); + cache_fresh_unlocked(ch, d); cache_put(ch, d); } } else @@ -465,7 +492,7 @@ EXPORT_SYMBOL_GPL(cache_flush); void cache_purge(struct cache_detail *detail) { detail->flush_time = LONG_MAX; - detail->nextcheck = get_seconds(); + detail->nextcheck = seconds_since_boot(); cache_flush(); detail->flush_time = 1; } @@ -494,81 +521,157 @@ EXPORT_SYMBOL_GPL(cache_purge); static DEFINE_SPINLOCK(cache_defer_lock); static LIST_HEAD(cache_defer_list); -static struct list_head cache_defer_hash[DFR_HASHSIZE]; +static struct hlist_head cache_defer_hash[DFR_HASHSIZE]; static int cache_defer_cnt; -static int cache_defer_req(struct cache_req *req, struct cache_head *item) +static void __unhash_deferred_req(struct cache_deferred_req *dreq) +{ + hlist_del_init(&dreq->hash); + if (!list_empty(&dreq->recent)) { + list_del_init(&dreq->recent); + cache_defer_cnt--; + } +} + +static void __hash_deferred_req(struct cache_deferred_req *dreq, struct cache_head *item) { - struct cache_deferred_req *dreq, *discard; int hash = DFR_HASH(item); - if (cache_defer_cnt >= DFR_MAX) { - /* too much in the cache, randomly drop this one, - * or continue and drop the oldest below - */ - if (net_random()&1) - return -ENOMEM; - } - dreq = req->defer(req); - if (dreq == NULL) - return -ENOMEM; + INIT_LIST_HEAD(&dreq->recent); + hlist_add_head(&dreq->hash, &cache_defer_hash[hash]); +} + +static void setup_deferral(struct cache_deferred_req *dreq, + struct cache_head *item, + int count_me) +{ dreq->item = item; spin_lock(&cache_defer_lock); - list_add(&dreq->recent, &cache_defer_list); + __hash_deferred_req(dreq, item); - if (cache_defer_hash[hash].next == NULL) - INIT_LIST_HEAD(&cache_defer_hash[hash]); - list_add(&dreq->hash, &cache_defer_hash[hash]); - - /* it is in, now maybe clean up */ - discard = NULL; - if (++cache_defer_cnt > DFR_MAX) { - discard = list_entry(cache_defer_list.prev, - struct cache_deferred_req, recent); - list_del_init(&discard->recent); - list_del_init(&discard->hash); - cache_defer_cnt--; + if (count_me) { + cache_defer_cnt++; + list_add(&dreq->recent, &cache_defer_list); } + spin_unlock(&cache_defer_lock); +} + +struct thread_deferred_req { + struct cache_deferred_req handle; + struct completion
completion; +}; + +static void cache_restart_thread(struct cache_deferred_req *dreq, int too_many) +{ + struct thread_deferred_req *dr = + container_of(dreq, struct thread_deferred_req, handle); + complete(&dr->completion); +} + +static void cache_wait_req(struct cache_req *req, struct cache_head *item) +{ + struct thread_deferred_req sleeper; + struct cache_deferred_req *dreq = &sleeper.handle; + + sleeper.completion = COMPLETION_INITIALIZER_ONSTACK(sleeper.completion); + dreq->revisit = cache_restart_thread; + + setup_deferral(dreq, item, 0); + + if (!test_bit(CACHE_PENDING, &item->flags) || + wait_for_completion_interruptible_timeout( + &sleeper.completion, req->thread_wait) <= 0) { + /* The completion wasn't completed, so we need + * to clean up + */ + spin_lock(&cache_defer_lock); + if (!hlist_unhashed(&sleeper.handle.hash)) { + __unhash_deferred_req(&sleeper.handle); + spin_unlock(&cache_defer_lock); + } else { + /* cache_revisit_request already removed + * this from the hash table, but hasn't + * called ->revisit yet. It will very soon + * and we need to wait for it. + */ + spin_unlock(&cache_defer_lock); + wait_for_completion(&sleeper.completion); + } + } +} + +static void cache_limit_defers(void) +{ + /* Make sure we haven't exceed the limit of allowed deferred + * requests. + */ + struct cache_deferred_req *discard = NULL; + + if (cache_defer_cnt <= DFR_MAX) + return; + + spin_lock(&cache_defer_lock); + + /* Consider removing either the first or the last */ + if (cache_defer_cnt > DFR_MAX) { + if (prandom_u32() & 1) + discard = list_entry(cache_defer_list.next, + struct cache_deferred_req, recent); + else + discard = list_entry(cache_defer_list.prev, + struct cache_deferred_req, recent); + __unhash_deferred_req(discard); + } + spin_unlock(&cache_defer_lock); if (discard) - /* there was one too many */ discard->revisit(discard, 1); +} - if (!test_bit(CACHE_PENDING, &item->flags)) { - /* must have just been validated... */ - cache_revisit_request(item); - return -EAGAIN; +/* Return true if and only if a deferred request is queued. 
*/ +static bool cache_defer_req(struct cache_req *req, struct cache_head *item) +{ + struct cache_deferred_req *dreq; + + if (req->thread_wait) { + cache_wait_req(req, item); + if (!test_bit(CACHE_PENDING, &item->flags)) + return false; } - return 0; + dreq = req->defer(req); + if (dreq == NULL) + return false; + setup_deferral(dreq, item, 1); + if (!test_bit(CACHE_PENDING, &item->flags)) + /* Bit could have been cleared before we managed to + * set up the deferral, so need to revisit just in case + */ + cache_revisit_request(item); + + cache_limit_defers(); + return true; } static void cache_revisit_request(struct cache_head *item) { struct cache_deferred_req *dreq; struct list_head pending; - - struct list_head *lp; + struct hlist_node *tmp; int hash = DFR_HASH(item); INIT_LIST_HEAD(&pending); spin_lock(&cache_defer_lock); - lp = cache_defer_hash[hash].next; - if (lp) { - while (lp != &cache_defer_hash[hash]) { - dreq = list_entry(lp, struct cache_deferred_req, hash); - lp = lp->next; - if (dreq->item == item) { - list_del_init(&dreq->hash); - list_move(&dreq->recent, &pending); - cache_defer_cnt--; - } + hlist_for_each_entry_safe(dreq, tmp, &cache_defer_hash[hash], hash) + if (dreq->item == item) { + __unhash_deferred_req(dreq); + list_add(&dreq->recent, &pending); } - } + spin_unlock(&cache_defer_lock); while (!list_empty(&pending)) { @@ -589,9 +692,8 @@ void cache_clean_deferred(void *owner) list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) { if (dreq->owner == owner) { - list_del_init(&dreq->hash); - list_move(&dreq->recent, &pending); - cache_defer_cnt--; + __unhash_deferred_req(dreq); + list_add(&dreq->recent, &pending); } } spin_unlock(&cache_defer_lock); @@ -638,12 +740,24 @@ struct cache_reader { int offset; /* if non-0, we have a refcnt on next request */ }; +static int cache_request(struct cache_detail *detail, + struct cache_request *crq) +{ + char *bp = crq->buf; + int len = PAGE_SIZE; + + detail->cache_request(detail, crq->item, &bp, &len); + if (len < 0) + return -EAGAIN; + return PAGE_SIZE - len; +} + static ssize_t cache_read(struct file *filp, char __user *buf, size_t count, loff_t *ppos, struct cache_detail *cd) { struct cache_reader *rp = filp->private_data; struct cache_request *rq; - struct inode *inode = filp->f_path.dentry->d_inode; + struct inode *inode = file_inode(filp); int err; if (count == 0) @@ -663,15 +777,22 @@ static ssize_t cache_read(struct file *filp, char __user *buf, size_t count, if (rp->q.list.next == &cd->queue) { spin_unlock(&queue_lock); mutex_unlock(&inode->i_mutex); - BUG_ON(rp->offset); + WARN_ON_ONCE(rp->offset); return 0; } rq = container_of(rp->q.list.next, struct cache_request, q.list); - BUG_ON(rq->q.reader); + WARN_ON_ONCE(rq->q.reader); if (rp->offset == 0) rq->readers++; spin_unlock(&queue_lock); + if (rq->len == 0) { + err = cache_request(cd, rq); + if (err < 0) + goto out; + rq->len = err; + } + if (rp->offset == 0 && !test_bit(CACHE_PENDING, &rq->item->flags)) { err = -EAGAIN; spin_lock(&queue_lock); @@ -718,6 +839,8 @@ static ssize_t cache_do_downcall(char *kaddr, const char __user *buf, { ssize_t ret; + if (count == 0) + return -EINVAL; if (copy_from_user(kaddr, buf, count)) return -EFAULT; kaddr[count] = '\0'; @@ -772,7 +895,7 @@ static ssize_t cache_write(struct file *filp, const char __user *buf, struct cache_detail *cd) { struct address_space *mapping = filp->f_mapping; - struct inode *inode = filp->f_path.dentry->d_inode; + struct inode *inode = file_inode(filp); ssize_t ret = -EINVAL; if (!cd->cache_parse) @@ 
-853,8 +976,10 @@ static int cache_open(struct inode *inode, struct file *filp, nonseekable_open(inode, filp); if (filp->f_mode & FMODE_READ) { rp = kmalloc(sizeof(*rp), GFP_KERNEL); - if (!rp) + if (!rp) { + module_put(cd->owner); return -ENOMEM; + } rp->offset = 0; rp->q.reader = 1; atomic_inc(&cd->readers); @@ -890,7 +1015,7 @@ static int cache_release(struct inode *inode, struct file *filp, filp->private_data = NULL; kfree(rp); - cd->last_close = get_seconds(); + cd->last_close = seconds_since_boot(); atomic_dec(&cd->readers); } module_put(cd->owner); @@ -901,23 +1026,32 @@ static int cache_release(struct inode *inode, struct file *filp, static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch) { - struct cache_queue *cq; + struct cache_queue *cq, *tmp; + struct cache_request *cr; + struct list_head dequeued; + + INIT_LIST_HEAD(&dequeued); spin_lock(&queue_lock); - list_for_each_entry(cq, &detail->queue, list) + list_for_each_entry_safe(cq, tmp, &detail->queue, list) if (!cq->reader) { - struct cache_request *cr = container_of(cq, struct cache_request, q); + cr = container_of(cq, struct cache_request, q); if (cr->item != ch) continue; + if (test_bit(CACHE_PENDING, &ch->flags)) + /* Lost a race and it is pending again */ + break; if (cr->readers != 0) continue; - list_del(&cr->q.list); - spin_unlock(&queue_lock); - cache_put(cr->item, detail); - kfree(cr->buf); - kfree(cr); - return; + list_move(&cr->q.list, &dequeued); } spin_unlock(&queue_lock); + while (!list_empty(&dequeued)) { + cr = list_entry(dequeued.next, struct cache_request, q.list); + list_del(&cr->q.list); + cache_put(cr->item, detail); + kfree(cr->buf); + kfree(cr); + } } /* @@ -977,9 +1111,7 @@ void qword_addhex(char **bpp, int *lp, char *buf, int blen) *bp++ = 'x'; len -= 2; while (blen && len >= 2) { - unsigned char c = *buf++; - *bp++ = '0' + ((c&0xf0)>>4) + (c>=0xa0)*('a'-'9'-1); - *bp++ = '0' + (c&0x0f) + ((c&0x0f)>=0x0a)*('a'-'9'-1); + bp = hex_byte_pack(bp, *buf++); len -= 2; blen--; } @@ -1003,29 +1135,46 @@ static void warn_no_listener(struct cache_detail *detail) } } +static bool cache_listeners_exist(struct cache_detail *detail) +{ + if (atomic_read(&detail->readers)) + return true; + if (detail->last_close == 0) + /* This cache was never opened */ + return false; + if (detail->last_close < seconds_since_boot() - 30) + /* + * We allow for the possibility that someone might + * restart a userspace daemon without restarting the + * server; but after 30 seconds, we give up. + */ + return false; + return true; +} + /* * register an upcall request to user-space and queue it up for read() by the * upcall daemon. * * Each request is at most one page long. 
*/ -int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h, - void (*cache_request)(struct cache_detail *, - struct cache_head *, - char **, - int *)) +int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) { char *buf; struct cache_request *crq; - char *bp; - int len; + int ret = 0; + + if (!detail->cache_request) + return -EINVAL; - if (atomic_read(&detail->readers) == 0 && - detail->last_close < get_seconds() - 30) { - warn_no_listener(detail); - return -EINVAL; + if (!cache_listeners_exist(detail)) { + warn_no_listener(detail); + return -EINVAL; } + if (test_bit(CACHE_CLEANED, &h->flags)) + /* Too late to make an upcall */ + return -EAGAIN; buf = kmalloc(PAGE_SIZE, GFP_KERNEL); if (!buf) @@ -1037,25 +1186,24 @@ int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h, return -EAGAIN; } - bp = buf; len = PAGE_SIZE; - - cache_request(detail, h, &bp, &len); - - if (len < 0) { - kfree(buf); - kfree(crq); - return -EAGAIN; - } crq->q.reader = 0; crq->item = cache_get(h); crq->buf = buf; - crq->len = PAGE_SIZE - len; + crq->len = 0; crq->readers = 0; spin_lock(&queue_lock); - list_add_tail(&crq->q.list, &detail->queue); + if (test_bit(CACHE_PENDING, &h->flags)) + list_add_tail(&crq->q.list, &detail->queue); + else + /* Lost a race, no longer PENDING, so don't enqueue */ + ret = -EAGAIN; spin_unlock(&queue_lock); wake_up(&queue_wait); - return 0; + if (ret == -EAGAIN) { + kfree(buf); + kfree(crq); + } + return ret; } EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall); @@ -1071,7 +1219,6 @@ EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall); * key and content are both parsed by cache */ -#define isodigit(c) (isdigit(c) && c <= '7') int qword_get(char **bpp, char *dest, int bufsize) { /* return bytes copied, or -1 on error */ @@ -1083,13 +1230,19 @@ int qword_get(char **bpp, char *dest, int bufsize) if (bp[0] == '\\' && bp[1] == 'x') { /* HEX STRING */ bp += 2; - while (isxdigit(bp[0]) && isxdigit(bp[1]) && len < bufsize) { - int byte = isdigit(*bp) ? *bp-'0' : toupper(*bp)-'A'+10; - bp++; - byte <<= 4; - byte |= isdigit(*bp) ? 
*bp-'0' : toupper(*bp)-'A'+10; - *dest++ = byte; - bp++; + while (len < bufsize) { + int h, l; + + h = hex_to_bin(bp[0]); + if (h < 0) + break; + + l = hex_to_bin(bp[1]); + if (l < 0) + break; + + *dest++ = (h << 4) | l; + bp += 2; len++; } } else { @@ -1137,7 +1290,7 @@ static void *c_start(struct seq_file *m, loff_t *pos) __acquires(cd->hash_lock) { loff_t n = *pos; - unsigned hash, entry; + unsigned int hash, entry; struct cache_head *ch; struct cache_detail *cd = ((struct handle*)m->private)->cd; @@ -1207,13 +1360,17 @@ static int c_show(struct seq_file *m, void *p) ifdebug(CACHE) seq_printf(m, "# expiry=%ld refcnt=%d flags=%lx\n", - cp->expiry_time, atomic_read(&cp->ref.refcount), cp->flags); + convert_to_wallclock(cp->expiry_time), + atomic_read(&cp->ref.refcount), cp->flags); cache_get(cp); if (cache_check(cd, cp, NULL)) /* cache_check does a cache_put on failure */ seq_printf(m, "# "); - else + else { + if (cache_is_expired(cd, cp)) + seq_printf(m, "# "); cache_put(cp, cd); + } return cd->cache_show(m, cd, cp); } @@ -1233,8 +1390,10 @@ static int content_open(struct inode *inode, struct file *file, if (!cd || !try_module_get(cd->owner)) return -EACCES; han = __seq_open_private(file, &cache_content_op, sizeof(*han)); - if (han == NULL) + if (han == NULL) { + module_put(cd->owner); return -ENOMEM; + } han->cd = cd; return 0; @@ -1267,11 +1426,11 @@ static ssize_t read_flush(struct file *file, char __user *buf, size_t count, loff_t *ppos, struct cache_detail *cd) { - char tbuf[20]; + char tbuf[22]; unsigned long p = *ppos; size_t len; - sprintf(tbuf, "%lu\n", cd->flush_time); + snprintf(tbuf, sizeof(tbuf), "%lu\n", convert_to_wallclock(cd->flush_time)); len = strlen(tbuf); if (p >= len) return 0; @@ -1289,19 +1448,20 @@ static ssize_t write_flush(struct file *file, const char __user *buf, struct cache_detail *cd) { char tbuf[20]; - char *ep; - long flushtime; + char *bp, *ep; + if (*ppos || count > sizeof(tbuf)-1) return -EINVAL; if (copy_from_user(tbuf, buf, count)) return -EFAULT; tbuf[count] = 0; - flushtime = simple_strtoul(tbuf, &ep, 0); + simple_strtoul(tbuf, &ep, 0); if (*ep && *ep != '\n') return -EINVAL; - cd->flush_time = flushtime; - cd->nextcheck = get_seconds(); + bp = tbuf; + cd->flush_time = get_expiry(&bp); + cd->nextcheck = seconds_since_boot(); cache_flush(); *ppos += count; @@ -1311,7 +1471,7 @@ static ssize_t write_flush(struct file *file, const char __user *buf, static ssize_t cache_read_procfs(struct file *filp, char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + struct cache_detail *cd = PDE_DATA(file_inode(filp)); return cache_read(filp, buf, count, ppos, cd); } @@ -1319,36 +1479,37 @@ static ssize_t cache_read_procfs(struct file *filp, char __user *buf, static ssize_t cache_write_procfs(struct file *filp, const char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + struct cache_detail *cd = PDE_DATA(file_inode(filp)); return cache_write(filp, buf, count, ppos, cd); } static unsigned int cache_poll_procfs(struct file *filp, poll_table *wait) { - struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + struct cache_detail *cd = PDE_DATA(file_inode(filp)); return cache_poll(filp, wait, cd); } -static int cache_ioctl_procfs(struct inode *inode, struct file *filp, - unsigned int cmd, unsigned long arg) +static long cache_ioctl_procfs(struct file *filp, + unsigned int cmd, unsigned long arg) { - struct cache_detail *cd = 
PDE(inode)->data; + struct inode *inode = file_inode(filp); + struct cache_detail *cd = PDE_DATA(inode); return cache_ioctl(inode, filp, cmd, arg, cd); } static int cache_open_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE(inode)->data; + struct cache_detail *cd = PDE_DATA(inode); return cache_open(inode, filp, cd); } static int cache_release_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE(inode)->data; + struct cache_detail *cd = PDE_DATA(inode); return cache_release(inode, filp, cd); } @@ -1359,21 +1520,21 @@ static const struct file_operations cache_file_operations_procfs = { .read = cache_read_procfs, .write = cache_write_procfs, .poll = cache_poll_procfs, - .ioctl = cache_ioctl_procfs, /* for FIONREAD */ + .unlocked_ioctl = cache_ioctl_procfs, /* for FIONREAD */ .open = cache_open_procfs, .release = cache_release_procfs, }; static int content_open_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE(inode)->data; + struct cache_detail *cd = PDE_DATA(inode); return content_open(inode, filp, cd); } static int content_release_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE(inode)->data; + struct cache_detail *cd = PDE_DATA(inode); return content_release(inode, filp, cd); } @@ -1387,14 +1548,14 @@ static const struct file_operations content_file_operations_procfs = { static int open_flush_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE(inode)->data; + struct cache_detail *cd = PDE_DATA(inode); return open_flush(inode, filp, cd); } static int release_flush_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE(inode)->data; + struct cache_detail *cd = PDE_DATA(inode); return release_flush(inode, filp, cd); } @@ -1402,7 +1563,7 @@ static int release_flush_procfs(struct inode *inode, struct file *filp) static ssize_t read_flush_procfs(struct file *filp, char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + struct cache_detail *cd = PDE_DATA(file_inode(filp)); return read_flush(filp, buf, count, ppos, cd); } @@ -1411,7 +1572,7 @@ static ssize_t write_flush_procfs(struct file *filp, const char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + struct cache_detail *cd = PDE_DATA(file_inode(filp)); return write_flush(filp, buf, count, ppos, cd); } @@ -1421,10 +1582,13 @@ static const struct file_operations cache_flush_operations_procfs = { .read = read_flush_procfs, .write = write_flush_procfs, .release = release_flush_procfs, + .llseek = no_llseek, }; -static void remove_cache_proc_entries(struct cache_detail *cd) +static void remove_cache_proc_entries(struct cache_detail *cd, struct net *net) { + struct sunrpc_net *sn; + if (cd->u.procfs.proc_ent == NULL) return; if (cd->u.procfs.flush_ent) @@ -1434,15 +1598,18 @@ static void remove_cache_proc_entries(struct cache_detail *cd) if (cd->u.procfs.content_ent) remove_proc_entry("content", cd->u.procfs.proc_ent); cd->u.procfs.proc_ent = NULL; - remove_proc_entry(cd->name, proc_net_rpc); + sn = net_generic(net, sunrpc_net_id); + remove_proc_entry(cd->name, sn->proc_net_rpc); } #ifdef CONFIG_PROC_FS -static int create_cache_proc_entries(struct cache_detail *cd) +static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) { struct proc_dir_entry *p; + struct sunrpc_net *sn; - cd->u.procfs.proc_ent = proc_mkdir(cd->name, 
proc_net_rpc); + sn = net_generic(net, sunrpc_net_id); + cd->u.procfs.proc_ent = proc_mkdir(cd->name, sn->proc_net_rpc); if (cd->u.procfs.proc_ent == NULL) goto out_nomem; cd->u.procfs.channel_ent = NULL; @@ -1455,7 +1622,7 @@ static int create_cache_proc_entries(struct cache_detail *cd) if (p == NULL) goto out_nomem; - if (cd->cache_upcall || cd->cache_parse) { + if (cd->cache_request || cd->cache_parse) { p = proc_create_data("channel", S_IFREG|S_IRUSR|S_IWUSR, cd->u.procfs.proc_ent, &cache_file_operations_procfs, cd); @@ -1464,7 +1631,7 @@ static int create_cache_proc_entries(struct cache_detail *cd) goto out_nomem; } if (cd->cache_show) { - p = proc_create_data("content", S_IFREG|S_IRUSR|S_IWUSR, + p = proc_create_data("content", S_IFREG|S_IRUSR, cd->u.procfs.proc_ent, &content_file_operations_procfs, cd); cd->u.procfs.content_ent = p; @@ -1473,39 +1640,70 @@ static int create_cache_proc_entries(struct cache_detail *cd) } return 0; out_nomem: - remove_cache_proc_entries(cd); + remove_cache_proc_entries(cd, net); return -ENOMEM; } #else /* CONFIG_PROC_FS */ -static int create_cache_proc_entries(struct cache_detail *cd) +static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) { return 0; } #endif -int cache_register(struct cache_detail *cd) +void __init cache_initialize(void) +{ + INIT_DEFERRABLE_WORK(&cache_cleaner, do_cache_clean); +} + +int cache_register_net(struct cache_detail *cd, struct net *net) { int ret; sunrpc_init_cache_detail(cd); - ret = create_cache_proc_entries(cd); + ret = create_cache_proc_entries(cd, net); if (ret) sunrpc_destroy_cache_detail(cd); return ret; } -EXPORT_SYMBOL_GPL(cache_register); +EXPORT_SYMBOL_GPL(cache_register_net); -void cache_unregister(struct cache_detail *cd) +void cache_unregister_net(struct cache_detail *cd, struct net *net) { - remove_cache_proc_entries(cd); + remove_cache_proc_entries(cd, net); sunrpc_destroy_cache_detail(cd); } -EXPORT_SYMBOL_GPL(cache_unregister); +EXPORT_SYMBOL_GPL(cache_unregister_net); + +struct cache_detail *cache_create_net(struct cache_detail *tmpl, struct net *net) +{ + struct cache_detail *cd; + + cd = kmemdup(tmpl, sizeof(struct cache_detail), GFP_KERNEL); + if (cd == NULL) + return ERR_PTR(-ENOMEM); + + cd->hash_table = kzalloc(cd->hash_size * sizeof(struct cache_head *), + GFP_KERNEL); + if (cd->hash_table == NULL) { + kfree(cd); + return ERR_PTR(-ENOMEM); + } + cd->net = net; + return cd; +} +EXPORT_SYMBOL_GPL(cache_create_net); + +void cache_destroy_net(struct cache_detail *cd, struct net *net) +{ + kfree(cd->hash_table); + kfree(cd); +} +EXPORT_SYMBOL_GPL(cache_destroy_net); static ssize_t cache_read_pipefs(struct file *filp, char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = RPC_I(filp->f_path.dentry->d_inode)->private; + struct cache_detail *cd = RPC_I(file_inode(filp))->private; return cache_read(filp, buf, count, ppos, cd); } @@ -1513,21 +1711,22 @@ static ssize_t cache_read_pipefs(struct file *filp, char __user *buf, static ssize_t cache_write_pipefs(struct file *filp, const char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = RPC_I(filp->f_path.dentry->d_inode)->private; + struct cache_detail *cd = RPC_I(file_inode(filp))->private; return cache_write(filp, buf, count, ppos, cd); } static unsigned int cache_poll_pipefs(struct file *filp, poll_table *wait) { - struct cache_detail *cd = RPC_I(filp->f_path.dentry->d_inode)->private; + struct cache_detail *cd = RPC_I(file_inode(filp))->private; return cache_poll(filp, wait, cd); } -static 
int cache_ioctl_pipefs(struct inode *inode, struct file *filp, +static long cache_ioctl_pipefs(struct file *filp, unsigned int cmd, unsigned long arg) { + struct inode *inode = file_inode(filp); struct cache_detail *cd = RPC_I(inode)->private; return cache_ioctl(inode, filp, cmd, arg, cd); @@ -1553,7 +1752,7 @@ const struct file_operations cache_file_operations_pipefs = { .read = cache_read_pipefs, .write = cache_write_pipefs, .poll = cache_poll_pipefs, - .ioctl = cache_ioctl_pipefs, /* for FIONREAD */ + .unlocked_ioctl = cache_ioctl_pipefs, /* for FIONREAD */ .open = cache_open_pipefs, .release = cache_release_pipefs, }; @@ -1596,7 +1795,7 @@ static int release_flush_pipefs(struct inode *inode, struct file *filp) static ssize_t read_flush_pipefs(struct file *filp, char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = RPC_I(filp->f_path.dentry->d_inode)->private; + struct cache_detail *cd = RPC_I(file_inode(filp))->private; return read_flush(filp, buf, count, ppos, cd); } @@ -1605,7 +1804,7 @@ static ssize_t write_flush_pipefs(struct file *filp, const char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = RPC_I(filp->f_path.dentry->d_inode)->private; + struct cache_detail *cd = RPC_I(file_inode(filp))->private; return write_flush(filp, buf, count, ppos, cd); } @@ -1615,28 +1814,18 @@ const struct file_operations cache_flush_operations_pipefs = { .read = read_flush_pipefs, .write = write_flush_pipefs, .release = release_flush_pipefs, + .llseek = no_llseek, }; int sunrpc_cache_register_pipefs(struct dentry *parent, - const char *name, mode_t umode, + const char *name, umode_t umode, struct cache_detail *cd) { - struct qstr q; - struct dentry *dir; - int ret = 0; - - sunrpc_init_cache_detail(cd); - q.name = name; - q.len = strlen(name); - q.hash = full_name_hash(q.name, q.len); - dir = rpc_create_cache_dir(parent, &q, umode, cd); - if (!IS_ERR(dir)) - cd->u.pipefs.dir = dir; - else { - sunrpc_destroy_cache_detail(cd); - ret = PTR_ERR(dir); - } - return ret; + struct dentry *dir = rpc_create_cache_dir(parent, name, umode, cd); + if (IS_ERR(dir)) + return PTR_ERR(dir); + cd->u.pipefs.dir = dir; + return 0; } EXPORT_SYMBOL_GPL(sunrpc_cache_register_pipefs); @@ -1644,7 +1833,6 @@ void sunrpc_cache_unregister_pipefs(struct cache_detail *cd) { rpc_remove_cache_dir(cd->u.pipefs.dir); cd->u.pipefs.dir = NULL; - sunrpc_destroy_cache_detail(cd); } EXPORT_SYMBOL_GPL(sunrpc_cache_unregister_pipefs); diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index 154034b675b..2e6ab10734f 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -13,15 +13,10 @@ * and need to be refreshed, or when a packet was damaged in transit. * This may have to be moved to the VFS layer. * - * NB: BSD uses a more intelligent approach to guessing when a request - * or reply has been lost by keeping the RTO estimate for each procedure. - * We currently make do with a constant timeout value.
- * * Copyright (C) 1992,1993 Rick Sladkey <jrs@world.std.com> * Copyright (C) 1995,1996 Olaf Kirch <okir@monad.swb.de> */ -#include <asm/system.h> #include <linux/module.h> #include <linux/types.h> @@ -30,16 +25,22 @@ #include <linux/namei.h> #include <linux/mount.h> #include <linux/slab.h> +#include <linux/rcupdate.h> #include <linux/utsname.h> #include <linux/workqueue.h> +#include <linux/in.h> #include <linux/in6.h> +#include <linux/un.h> #include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/addr.h> #include <linux/sunrpc/rpc_pipe_fs.h> #include <linux/sunrpc/metrics.h> #include <linux/sunrpc/bc_xprt.h> +#include <trace/events/sunrpc.h> #include "sunrpc.h" +#include "netns.h" #ifdef RPC_DEBUG # define RPCDBG_FACILITY RPCDBG_CALL @@ -52,8 +53,6 @@ /* * All RPC clients are linked into this list */ -static LIST_HEAD(all_clients); -static DEFINE_SPINLOCK(rpc_client_lock); static DECLARE_WAIT_QUEUE_HEAD(destroy_wait); @@ -66,9 +65,9 @@ static void call_decode(struct rpc_task *task); static void call_bind(struct rpc_task *task); static void call_bind_status(struct rpc_task *task); static void call_transmit(struct rpc_task *task); -#if defined(CONFIG_NFS_V4_1) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) static void call_bc_transmit(struct rpc_task *task); -#endif /* CONFIG_NFS_V4_1 */ +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ static void call_status(struct rpc_task *task); static void call_transmit_status(struct rpc_task *task); static void call_refresh(struct rpc_task *task); @@ -83,93 +82,295 @@ static int rpc_ping(struct rpc_clnt *clnt); static void rpc_register_client(struct rpc_clnt *clnt) { - spin_lock(&rpc_client_lock); - list_add(&clnt->cl_clients, &all_clients); - spin_unlock(&rpc_client_lock); + struct net *net = rpc_net_ns(clnt); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); + list_add(&clnt->cl_clients, &sn->all_clients); + spin_unlock(&sn->rpc_client_lock); } static void rpc_unregister_client(struct rpc_clnt *clnt) { - spin_lock(&rpc_client_lock); + struct net *net = rpc_net_ns(clnt); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); list_del(&clnt->cl_clients); - spin_unlock(&rpc_client_lock); + spin_unlock(&sn->rpc_client_lock); } -static int -rpc_setup_pipedir(struct rpc_clnt *clnt, char *dir_name) +static void __rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) { - static uint32_t clntid; - struct nameidata nd; - struct path path; - char name[15]; - struct qstr q = { - .name = name, - }; - int error; + rpc_remove_client_dir(clnt); +} - clnt->cl_path.mnt = ERR_PTR(-ENOENT); - clnt->cl_path.dentry = ERR_PTR(-ENOENT); - if (dir_name == NULL) - return 0; +static void rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) +{ + struct net *net = rpc_net_ns(clnt); + struct super_block *pipefs_sb; - path.mnt = rpc_get_mount(); - if (IS_ERR(path.mnt)) - return PTR_ERR(path.mnt); - error = vfs_path_lookup(path.mnt->mnt_root, path.mnt, dir_name, 0, &nd); - if (error) - goto err; + pipefs_sb = rpc_get_sb_net(net); + if (pipefs_sb) { + __rpc_clnt_remove_pipedir(clnt); + rpc_put_sb_net(net); + } +} +static struct dentry *rpc_setup_pipedir_sb(struct super_block *sb, + struct rpc_clnt *clnt) +{ + static uint32_t clntid; + const char *dir_name = clnt->cl_program->pipe_dir_name; + char name[15]; + struct dentry *dir, *dentry; + + dir = rpc_d_lookup_sb(sb, dir_name); + if (dir == NULL) { + pr_info("RPC: pipefs directory doesn't exist: %s\n", dir_name); + return dir; + } for (;;) { - q.len = snprintf(name, 
sizeof(name), "clnt%x", (unsigned int)clntid++); + snprintf(name, sizeof(name), "clnt%x", (unsigned int)clntid++); name[sizeof(name) - 1] = '\0'; - q.hash = full_name_hash(q.name, q.len); - path.dentry = rpc_create_client_dir(nd.path.dentry, &q, clnt); - if (!IS_ERR(path.dentry)) + dentry = rpc_create_client_dir(dir, name, clnt); + if (!IS_ERR(dentry)) break; - error = PTR_ERR(path.dentry); - if (error != -EEXIST) { - printk(KERN_INFO "RPC: Couldn't create pipefs entry" - " %s/%s, error %d\n", - dir_name, name, error); - goto err_path_put; - } + if (dentry == ERR_PTR(-EEXIST)) + continue; + printk(KERN_INFO "RPC: Couldn't create pipefs entry" + " %s/%s, error %ld\n", + dir_name, name, PTR_ERR(dentry)); + break; + } + dput(dir); + return dentry; +} + +static int +rpc_setup_pipedir(struct super_block *pipefs_sb, struct rpc_clnt *clnt) +{ + struct dentry *dentry; + + if (clnt->cl_program->pipe_dir_name != NULL) { + dentry = rpc_setup_pipedir_sb(pipefs_sb, clnt); + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + } + return 0; +} + +static int rpc_clnt_skip_event(struct rpc_clnt *clnt, unsigned long event) +{ + if (clnt->cl_program->pipe_dir_name == NULL) + return 1; + + switch (event) { + case RPC_PIPEFS_MOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry != NULL) + return 1; + if (atomic_read(&clnt->cl_count) == 0) + return 1; + break; + case RPC_PIPEFS_UMOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry == NULL) + return 1; + break; } - path_put(&nd.path); - clnt->cl_path = path; return 0; -err_path_put: - path_put(&nd.path); -err: - rpc_put_mount(); +} + +static int __rpc_clnt_handle_event(struct rpc_clnt *clnt, unsigned long event, + struct super_block *sb) +{ + struct dentry *dentry; + int err = 0; + + switch (event) { + case RPC_PIPEFS_MOUNT: + dentry = rpc_setup_pipedir_sb(sb, clnt); + if (!dentry) + return -ENOENT; + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + break; + case RPC_PIPEFS_UMOUNT: + __rpc_clnt_remove_pipedir(clnt); + break; + default: + printk(KERN_ERR "%s: unknown event: %ld\n", __func__, event); + return -ENOTSUPP; + } + return err; +} + +static int __rpc_pipefs_event(struct rpc_clnt *clnt, unsigned long event, + struct super_block *sb) +{ + int error = 0; + + for (;; clnt = clnt->cl_parent) { + if (!rpc_clnt_skip_event(clnt, event)) + error = __rpc_clnt_handle_event(clnt, event, sb); + if (error || clnt == clnt->cl_parent) + break; + } + return error; +} + +static struct rpc_clnt *rpc_get_client_for_event(struct net *net, int event) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + + spin_lock(&sn->rpc_client_lock); + list_for_each_entry(clnt, &sn->all_clients, cl_clients) { + if (rpc_clnt_skip_event(clnt, event)) + continue; + spin_unlock(&sn->rpc_client_lock); + return clnt; + } + spin_unlock(&sn->rpc_client_lock); + return NULL; +} + +static int rpc_pipefs_event(struct notifier_block *nb, unsigned long event, + void *ptr) +{ + struct super_block *sb = ptr; + struct rpc_clnt *clnt; + int error = 0; + + while ((clnt = rpc_get_client_for_event(sb->s_fs_info, event))) { + error = __rpc_pipefs_event(clnt, event, sb); + if (error) + break; + } return error; } -static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, struct rpc_xprt *xprt) +static struct notifier_block rpc_clients_block = { + .notifier_call = rpc_pipefs_event, + .priority = SUNRPC_PIPEFS_RPC_PRIO, +}; + +int rpc_clients_notifier_register(void) +{ + return rpc_pipefs_notifier_register(&rpc_clients_block); +} + +void rpc_clients_notifier_unregister(void) 
{ - struct rpc_program *program = args->program; - struct rpc_version *version; - struct rpc_clnt *clnt = NULL; - struct rpc_auth *auth; + return rpc_pipefs_notifier_unregister(&rpc_clients_block); +} + +static struct rpc_xprt *rpc_clnt_set_transport(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + const struct rpc_timeout *timeout) +{ + struct rpc_xprt *old; + + spin_lock(&clnt->cl_lock); + old = rcu_dereference_protected(clnt->cl_xprt, + lockdep_is_held(&clnt->cl_lock)); + + if (!xprt_bound(xprt)) + clnt->cl_autobind = 1; + + clnt->cl_timeout = timeout; + rcu_assign_pointer(clnt->cl_xprt, xprt); + spin_unlock(&clnt->cl_lock); + + return old; +} + +static void rpc_clnt_set_nodename(struct rpc_clnt *clnt, const char *nodename) +{ + clnt->cl_nodelen = strlen(nodename); + if (clnt->cl_nodelen > UNX_MAXNODENAME) + clnt->cl_nodelen = UNX_MAXNODENAME; + memcpy(clnt->cl_nodename, nodename, clnt->cl_nodelen); +} + +static int rpc_client_register(struct rpc_clnt *clnt, + rpc_authflavor_t pseudoflavor, + const char *client_name) +{ + struct rpc_auth_create_args auth_args = { + .pseudoflavor = pseudoflavor, + .target_name = client_name, + }; + struct rpc_auth *auth; + struct net *net = rpc_net_ns(clnt); + struct super_block *pipefs_sb; int err; - size_t len; - /* sanity check the name before trying to print it */ - err = -EINVAL; - len = strlen(args->servername); - if (len > RPC_MAXNETNAMELEN) - goto out_no_rpciod; - len++; + pipefs_sb = rpc_get_sb_net(net); + if (pipefs_sb) { + err = rpc_setup_pipedir(pipefs_sb, clnt); + if (err) + goto out; + } + + rpc_register_client(clnt); + if (pipefs_sb) + rpc_put_sb_net(net); + + auth = rpcauth_create(&auth_args, clnt); + if (IS_ERR(auth)) { + dprintk("RPC: Couldn't create auth handle (flavor %u)\n", + pseudoflavor); + err = PTR_ERR(auth); + goto err_auth; + } + return 0; +err_auth: + pipefs_sb = rpc_get_sb_net(net); + rpc_unregister_client(clnt); + __rpc_clnt_remove_pipedir(clnt); +out: + if (pipefs_sb) + rpc_put_sb_net(net); + return err; +} + +static DEFINE_IDA(rpc_clids); +static int rpc_alloc_clid(struct rpc_clnt *clnt) +{ + int clid; + + clid = ida_simple_get(&rpc_clids, 0, 0, GFP_KERNEL); + if (clid < 0) + return clid; + clnt->cl_clid = clid; + return 0; +} + +static void rpc_free_clid(struct rpc_clnt *clnt) +{ + ida_simple_remove(&rpc_clids, clnt->cl_clid); +} + +static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, + struct rpc_xprt *xprt, + struct rpc_clnt *parent) +{ + const struct rpc_program *program = args->program; + const struct rpc_version *version; + struct rpc_clnt *clnt = NULL; + const struct rpc_timeout *timeout; + int err; + + /* sanity check the name before trying to print it */ dprintk("RPC: creating %s client for %s (xprt %p)\n", program->name, args->servername, xprt); err = rpciod_up(); if (err) goto out_no_rpciod; - err = -EINVAL; - if (!xprt) - goto out_no_xprt; + err = -EINVAL; if (args->version >= program->nrvers) goto out_err; version = program->version[args->version]; @@ -180,26 +381,19 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, stru clnt = kzalloc(sizeof(*clnt), GFP_KERNEL); if (!clnt) goto out_err; - clnt->cl_parent = clnt; + clnt->cl_parent = parent ? 
: clnt; - clnt->cl_server = clnt->cl_inline_name; - if (len > sizeof(clnt->cl_inline_name)) { - char *buf = kmalloc(len, GFP_KERNEL); - if (buf != NULL) - clnt->cl_server = buf; - else - len = sizeof(clnt->cl_inline_name); - } - strlcpy(clnt->cl_server, args->servername, len); + err = rpc_alloc_clid(clnt); + if (err) + goto out_no_clid; - clnt->cl_xprt = xprt; clnt->cl_procinfo = version->procs; clnt->cl_maxproc = version->nrprocs; - clnt->cl_protname = program->name; clnt->cl_prog = args->prognumber ? : program->number; clnt->cl_vers = version->number; clnt->cl_stats = program->stats; clnt->cl_metrics = rpc_alloc_iostats(clnt); + rpc_init_pipe_dir_head(&clnt->cl_pipedir_objects); err = -ENOMEM; if (clnt->cl_metrics == NULL) goto out_no_stats; @@ -207,69 +401,76 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, stru INIT_LIST_HEAD(&clnt->cl_tasks); spin_lock_init(&clnt->cl_lock); - if (!xprt_bound(clnt->cl_xprt)) - clnt->cl_autobind = 1; - - clnt->cl_timeout = xprt->timeout; + timeout = xprt->timeout; if (args->timeout != NULL) { memcpy(&clnt->cl_timeout_default, args->timeout, sizeof(clnt->cl_timeout_default)); - clnt->cl_timeout = &clnt->cl_timeout_default; + timeout = &clnt->cl_timeout_default; } + rpc_clnt_set_transport(clnt, xprt, timeout); + clnt->cl_rtt = &clnt->cl_rtt_default; rpc_init_rtt(&clnt->cl_rtt_default, clnt->cl_timeout->to_initval); - clnt->cl_principal = NULL; - if (args->client_name) { - clnt->cl_principal = kstrdup(args->client_name, GFP_KERNEL); - if (!clnt->cl_principal) - goto out_no_principal; - } - kref_init(&clnt->cl_kref); - - err = rpc_setup_pipedir(clnt, program->pipe_dir_name); - if (err < 0) - goto out_no_path; - - auth = rpcauth_create(args->authflavor, clnt); - if (IS_ERR(auth)) { - printk(KERN_INFO "RPC: Couldn't create auth handle (flavor %u)\n", - args->authflavor); - err = PTR_ERR(auth); - goto out_no_auth; - } + atomic_set(&clnt->cl_count, 1); /* save the nodename */ - clnt->cl_nodelen = strlen(init_utsname()->nodename); - if (clnt->cl_nodelen > UNX_MAXNODENAME) - clnt->cl_nodelen = UNX_MAXNODENAME; - memcpy(clnt->cl_nodename, init_utsname()->nodename, clnt->cl_nodelen); - rpc_register_client(clnt); + rpc_clnt_set_nodename(clnt, utsname()->nodename); + + err = rpc_client_register(clnt, args->authflavor, args->client_name); + if (err) + goto out_no_path; + if (parent) + atomic_inc(&parent->cl_count); return clnt; -out_no_auth: - if (!IS_ERR(clnt->cl_path.dentry)) { - rpc_remove_client_dir(clnt->cl_path.dentry); - rpc_put_mount(); - } out_no_path: - kfree(clnt->cl_principal); -out_no_principal: rpc_free_iostats(clnt->cl_metrics); out_no_stats: - if (clnt->cl_server != clnt->cl_inline_name) - kfree(clnt->cl_server); + rpc_free_clid(clnt); +out_no_clid: kfree(clnt); out_err: - xprt_put(xprt); -out_no_xprt: rpciod_down(); out_no_rpciod: + xprt_put(xprt); return ERR_PTR(err); } -/* +struct rpc_clnt *rpc_create_xprt(struct rpc_create_args *args, + struct rpc_xprt *xprt) +{ + struct rpc_clnt *clnt = NULL; + + clnt = rpc_new_client(args, xprt, NULL); + if (IS_ERR(clnt)) + return clnt; + + if (!(args->flags & RPC_CLNT_CREATE_NOPING)) { + int err = rpc_ping(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + return ERR_PTR(err); + } + } + + clnt->cl_softrtry = 1; + if (args->flags & RPC_CLNT_CREATE_HARDRTRY) + clnt->cl_softrtry = 0; + + if (args->flags & RPC_CLNT_CREATE_AUTOBIND) + clnt->cl_autobind = 1; + if (args->flags & RPC_CLNT_CREATE_DISCRTRY) + clnt->cl_discrtry = 1; + if (!(args->flags & RPC_CLNT_CREATE_QUIET)) + 
clnt->cl_chatty = 1; + + return clnt; +} +EXPORT_SYMBOL_GPL(rpc_create_xprt); + +/** * rpc_create - create an RPC client and transport with one call * @args: rpc_clnt create argument structure * @@ -282,43 +483,53 @@ out_no_rpciod: struct rpc_clnt *rpc_create(struct rpc_create_args *args) { struct rpc_xprt *xprt; - struct rpc_clnt *clnt; struct xprt_create xprtargs = { + .net = args->net, .ident = args->protocol, .srcaddr = args->saddress, .dstaddr = args->address, .addrlen = args->addrsize, + .servername = args->servername, .bc_xprt = args->bc_xprt, }; char servername[48]; + if (args->flags & RPC_CLNT_CREATE_INFINITE_SLOTS) + xprtargs.flags |= XPRT_CREATE_INFINITE_SLOTS; + if (args->flags & RPC_CLNT_CREATE_NO_IDLE_TIMEOUT) + xprtargs.flags |= XPRT_CREATE_NO_IDLE_TIMEOUT; /* * If the caller chooses not to specify a hostname, whip * up a string representation of the passed-in address. */ - if (args->servername == NULL) { + if (xprtargs.servername == NULL) { + struct sockaddr_un *sun = + (struct sockaddr_un *)args->address; + struct sockaddr_in *sin = + (struct sockaddr_in *)args->address; + struct sockaddr_in6 *sin6 = + (struct sockaddr_in6 *)args->address; + servername[0] = '\0'; switch (args->address->sa_family) { - case AF_INET: { - struct sockaddr_in *sin = - (struct sockaddr_in *)args->address; + case AF_LOCAL: + snprintf(servername, sizeof(servername), "%s", + sun->sun_path); + break; + case AF_INET: snprintf(servername, sizeof(servername), "%pI4", &sin->sin_addr.s_addr); break; - } - case AF_INET6: { - struct sockaddr_in6 *sin = - (struct sockaddr_in6 *)args->address; + case AF_INET6: snprintf(servername, sizeof(servername), "%pI6", - &sin->sin6_addr); + &sin6->sin6_addr); break; - } default: /* caller wants default server name, but * address family isn't recognized. */ return ERR_PTR(-EINVAL); } - args->servername = servername; + xprtargs.servername = servername; } xprt = xprt_create_transport(&xprtargs); @@ -335,30 +546,7 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) if (args->flags & RPC_CLNT_CREATE_NONPRIVPORT) xprt->resvport = 0; - clnt = rpc_new_client(args, xprt); - if (IS_ERR(clnt)) - return clnt; - - if (!(args->flags & RPC_CLNT_CREATE_NOPING)) { - int err = rpc_ping(clnt); - if (err != 0) { - rpc_shutdown_client(clnt); - return ERR_PTR(err); - } - } - - clnt->cl_softrtry = 1; - if (args->flags & RPC_CLNT_CREATE_HARDRTRY) - clnt->cl_softrtry = 0; - - if (args->flags & RPC_CLNT_CREATE_AUTOBIND) - clnt->cl_autobind = 1; - if (args->flags & RPC_CLNT_CREATE_DISCRTRY) - clnt->cl_discrtry = 1; - if (!(args->flags & RPC_CLNT_CREATE_QUIET)) - clnt->cl_chatty = 1; - - return clnt; + return rpc_create_xprt(args, xprt); } EXPORT_SYMBOL_GPL(rpc_create); @@ -367,60 +555,195 @@ EXPORT_SYMBOL_GPL(rpc_create); * same transport while varying parameters such as the authentication * flavour. 
*/ -struct rpc_clnt * -rpc_clone_client(struct rpc_clnt *clnt) +static struct rpc_clnt *__rpc_clone_client(struct rpc_create_args *args, + struct rpc_clnt *clnt) { + struct rpc_xprt *xprt; struct rpc_clnt *new; - int err = -ENOMEM; + int err; + + err = -ENOMEM; + rcu_read_lock(); + xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + rcu_read_unlock(); + if (xprt == NULL) + goto out_err; + args->servername = xprt->servername; + + new = rpc_new_client(args, xprt, clnt); + if (IS_ERR(new)) { + err = PTR_ERR(new); + goto out_err; + } - new = kmemdup(clnt, sizeof(*new), GFP_KERNEL); - if (!new) - goto out_no_clnt; - new->cl_parent = clnt; /* Turn off autobind on clones */ new->cl_autobind = 0; - INIT_LIST_HEAD(&new->cl_tasks); - spin_lock_init(&new->cl_lock); - rpc_init_rtt(&new->cl_rtt_default, clnt->cl_timeout->to_initval); - new->cl_metrics = rpc_alloc_iostats(clnt); - if (new->cl_metrics == NULL) - goto out_no_stats; - if (clnt->cl_principal) { - new->cl_principal = kstrdup(clnt->cl_principal, GFP_KERNEL); - if (new->cl_principal == NULL) - goto out_no_principal; - } - kref_init(&new->cl_kref); - err = rpc_setup_pipedir(new, clnt->cl_program->pipe_dir_name); - if (err != 0) - goto out_no_path; - if (new->cl_auth) - atomic_inc(&new->cl_auth->au_count); - xprt_get(clnt->cl_xprt); - kref_get(&clnt->cl_kref); - rpc_register_client(new); - rpciod_up(); + new->cl_softrtry = clnt->cl_softrtry; + new->cl_discrtry = clnt->cl_discrtry; + new->cl_chatty = clnt->cl_chatty; return new; -out_no_path: - kfree(new->cl_principal); -out_no_principal: - rpc_free_iostats(new->cl_metrics); -out_no_stats: - kfree(new); -out_no_clnt: + +out_err: dprintk("RPC: %s: returned error %d\n", __func__, err); return ERR_PTR(err); } + +/** + * rpc_clone_client - Clone an RPC client structure + * + * @clnt: RPC client whose parameters are copied + * + * Returns a fresh RPC client or an ERR_PTR. + */ +struct rpc_clnt *rpc_clone_client(struct rpc_clnt *clnt) +{ + struct rpc_create_args args = { + .program = clnt->cl_program, + .prognumber = clnt->cl_prog, + .version = clnt->cl_vers, + .authflavor = clnt->cl_auth->au_flavor, + }; + return __rpc_clone_client(&args, clnt); +} EXPORT_SYMBOL_GPL(rpc_clone_client); +/** + * rpc_clone_client_set_auth - Clone an RPC client structure and set its auth + * + * @clnt: RPC client whose parameters are copied + * @flavor: security flavor for new client + * + * Returns a fresh RPC client or an ERR_PTR. + */ +struct rpc_clnt * +rpc_clone_client_set_auth(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +{ + struct rpc_create_args args = { + .program = clnt->cl_program, + .prognumber = clnt->cl_prog, + .version = clnt->cl_vers, + .authflavor = flavor, + }; + return __rpc_clone_client(&args, clnt); +} +EXPORT_SYMBOL_GPL(rpc_clone_client_set_auth); + +/** + * rpc_switch_client_transport: switch the RPC transport on the fly + * @clnt: pointer to a struct rpc_clnt + * @args: pointer to the new transport arguments + * @timeout: pointer to the new timeout parameters + * + * This function allows the caller to switch the RPC transport for the + * rpc_clnt structure 'clnt' to allow it to connect to a mirrored NFS + * server, for instance. It assumes that the caller has ensured that + * there are no active RPC tasks by using some form of locking. + * + * Returns zero if "clnt" is now using the new xprt. Otherwise a + * negative errno is returned, and "clnt" continues to use the old + * xprt. 
+ */ +int rpc_switch_client_transport(struct rpc_clnt *clnt, + struct xprt_create *args, + const struct rpc_timeout *timeout) +{ + const struct rpc_timeout *old_timeo; + rpc_authflavor_t pseudoflavor; + struct rpc_xprt *xprt, *old; + struct rpc_clnt *parent; + int err; + + xprt = xprt_create_transport(args); + if (IS_ERR(xprt)) { + dprintk("RPC: failed to create new xprt for clnt %p\n", + clnt); + return PTR_ERR(xprt); + } + + pseudoflavor = clnt->cl_auth->au_flavor; + + old_timeo = clnt->cl_timeout; + old = rpc_clnt_set_transport(clnt, xprt, timeout); + + rpc_unregister_client(clnt); + __rpc_clnt_remove_pipedir(clnt); + + /* + * A new transport was created. "clnt" therefore + * becomes the root of a new cl_parent tree. clnt's + * children, if it has any, still point to the old xprt. + */ + parent = clnt->cl_parent; + clnt->cl_parent = clnt; + + /* + * The old rpc_auth cache cannot be re-used. GSS + * contexts in particular are between a single + * client and server. + */ + err = rpc_client_register(clnt, pseudoflavor, NULL); + if (err) + goto out_revert; + + synchronize_rcu(); + if (parent != clnt) + rpc_release_client(parent); + xprt_put(old); + dprintk("RPC: replaced xprt for clnt %p\n", clnt); + return 0; + +out_revert: + rpc_clnt_set_transport(clnt, old, old_timeo); + clnt->cl_parent = parent; + rpc_client_register(clnt, pseudoflavor, NULL); + xprt_put(xprt); + dprintk("RPC: failed to switch xprt for clnt %p\n", clnt); + return err; +} +EXPORT_SYMBOL_GPL(rpc_switch_client_transport); + +/* + * Kill all tasks for the given client. + * XXX: kill their descendants as well? + */ +void rpc_killall_tasks(struct rpc_clnt *clnt) +{ + struct rpc_task *rovr; + + + if (list_empty(&clnt->cl_tasks)) + return; + dprintk("RPC: killing all tasks for client %p\n", clnt); + /* + * Spin lock all_tasks to prevent changes... + */ + spin_lock(&clnt->cl_lock); + list_for_each_entry(rovr, &clnt->cl_tasks, tk_task) { + if (!RPC_IS_ACTIVATED(rovr)) + continue; + if (!(rovr->tk_flags & RPC_TASK_KILLED)) { + rovr->tk_flags |= RPC_TASK_KILLED; + rpc_exit(rovr, -EIO); + if (RPC_IS_QUEUED(rovr)) + rpc_wake_up_queued_task(rovr->tk_waitqueue, + rovr); + } + } + spin_unlock(&clnt->cl_lock); +} +EXPORT_SYMBOL_GPL(rpc_killall_tasks); + /* * Properly shut down an RPC client, terminating all outstanding * requests. 
*/ void rpc_shutdown_client(struct rpc_clnt *clnt) { - dprintk("RPC: shutting down %s client for %s\n", - clnt->cl_protname, clnt->cl_server); + might_sleep(); + + dprintk_rcu("RPC: shutting down %s client for %s\n", + clnt->cl_program->name, + rcu_dereference(clnt->cl_xprt)->servername); while (!list_empty(&clnt->cl_tasks)) { rpc_killall_tasks(clnt); @@ -435,55 +758,47 @@ EXPORT_SYMBOL_GPL(rpc_shutdown_client); /* * Free an RPC client */ -static void -rpc_free_client(struct kref *kref) +static struct rpc_clnt * +rpc_free_client(struct rpc_clnt *clnt) { - struct rpc_clnt *clnt = container_of(kref, struct rpc_clnt, cl_kref); - - dprintk("RPC: destroying %s client for %s\n", - clnt->cl_protname, clnt->cl_server); - if (!IS_ERR(clnt->cl_path.dentry)) { - rpc_remove_client_dir(clnt->cl_path.dentry); - rpc_put_mount(); - } - if (clnt->cl_parent != clnt) { - rpc_release_client(clnt->cl_parent); - goto out_free; - } - if (clnt->cl_server != clnt->cl_inline_name) - kfree(clnt->cl_server); -out_free: + struct rpc_clnt *parent = NULL; + + dprintk_rcu("RPC: destroying %s client for %s\n", + clnt->cl_program->name, + rcu_dereference(clnt->cl_xprt)->servername); + if (clnt->cl_parent != clnt) + parent = clnt->cl_parent; + rpc_clnt_remove_pipedir(clnt); rpc_unregister_client(clnt); rpc_free_iostats(clnt->cl_metrics); - kfree(clnt->cl_principal); clnt->cl_metrics = NULL; - xprt_put(clnt->cl_xprt); + xprt_put(rcu_dereference_raw(clnt->cl_xprt)); rpciod_down(); + rpc_free_clid(clnt); kfree(clnt); + return parent; } /* * Free an RPC client */ -static void -rpc_free_auth(struct kref *kref) +static struct rpc_clnt * +rpc_free_auth(struct rpc_clnt *clnt) { - struct rpc_clnt *clnt = container_of(kref, struct rpc_clnt, cl_kref); - - if (clnt->cl_auth == NULL) { - rpc_free_client(kref); - return; - } + if (clnt->cl_auth == NULL) + return rpc_free_client(clnt); /* * Note: RPCSEC_GSS may need to send NULL RPC calls in order to * release remaining GSS contexts. This mechanism ensures * that it can do so safely. */ - kref_init(kref); + atomic_inc(&clnt->cl_count); rpcauth_release(clnt->cl_auth); clnt->cl_auth = NULL; - kref_put(kref, rpc_free_client); + if (atomic_dec_and_test(&clnt->cl_count)) + return rpc_free_client(clnt); + return NULL; } /* @@ -494,10 +809,15 @@ rpc_release_client(struct rpc_clnt *clnt) { dprintk("RPC: rpc_release_client(%p)\n", clnt); - if (list_empty(&clnt->cl_tasks)) - wake_up(&destroy_wait); - kref_put(&clnt->cl_kref, rpc_free_auth); + do { + if (list_empty(&clnt->cl_tasks)) + wake_up(&destroy_wait); + if (!atomic_dec_and_test(&clnt->cl_count)) + break; + clnt = rpc_free_auth(clnt); + } while (clnt != NULL); } +EXPORT_SYMBOL_GPL(rpc_release_client); /** * rpc_bind_new_program - bind a new RPC program to an existing client @@ -510,24 +830,21 @@ rpc_release_client(struct rpc_clnt *clnt) * The Sun NFSv2/v3 ACL protocol can do this. 
*/ struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old, - struct rpc_program *program, + const struct rpc_program *program, u32 vers) { + struct rpc_create_args args = { + .program = program, + .prognumber = program->number, + .version = vers, + .authflavor = old->cl_auth->au_flavor, + }; struct rpc_clnt *clnt; - struct rpc_version *version; int err; - BUG_ON(vers >= program->nrvers || !program->version[vers]); - version = program->version[vers]; - clnt = rpc_clone_client(old); + clnt = __rpc_clone_client(&args, old); if (IS_ERR(clnt)) goto out; - clnt->cl_procinfo = version->procs; - clnt->cl_maxproc = version->nrprocs; - clnt->cl_protname = program->name; - clnt->cl_prog = program->number; - clnt->cl_vers = version->number; - clnt->cl_stats = program->stats; err = rpc_ping(clnt); if (err != 0) { rpc_shutdown_client(clnt); @@ -538,6 +855,68 @@ out: } EXPORT_SYMBOL_GPL(rpc_bind_new_program); +void rpc_task_release_client(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + if (clnt != NULL) { + /* Remove from client task list */ + spin_lock(&clnt->cl_lock); + list_del(&task->tk_task); + spin_unlock(&clnt->cl_lock); + task->tk_client = NULL; + + rpc_release_client(clnt); + } +} + +static +void rpc_task_set_client(struct rpc_task *task, struct rpc_clnt *clnt) +{ + if (clnt != NULL) { + rpc_task_release_client(task); + task->tk_client = clnt; + atomic_inc(&clnt->cl_count); + if (clnt->cl_softrtry) + task->tk_flags |= RPC_TASK_SOFT; + if (clnt->cl_noretranstimeo) + task->tk_flags |= RPC_TASK_NO_RETRANS_TIMEOUT; + if (sk_memalloc_socks()) { + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + if (xprt->swapper) + task->tk_flags |= RPC_TASK_SWAPPER; + rcu_read_unlock(); + } + /* Add to the client's list of all tasks */ + spin_lock(&clnt->cl_lock); + list_add_tail(&task->tk_task, &clnt->cl_tasks); + spin_unlock(&clnt->cl_lock); + } +} + +void rpc_task_reset_client(struct rpc_task *task, struct rpc_clnt *clnt) +{ + rpc_task_release_client(task); + rpc_task_set_client(task, clnt); +} +EXPORT_SYMBOL_GPL(rpc_task_reset_client); + + +static void +rpc_task_set_rpc_message(struct rpc_task *task, const struct rpc_message *msg) +{ + if (msg != NULL) { + task->tk_msg.rpc_proc = msg->rpc_proc; + task->tk_msg.rpc_argp = msg->rpc_argp; + task->tk_msg.rpc_resp = msg->rpc_resp; + if (msg->rpc_cred != NULL) + task->tk_msg.rpc_cred = get_rpccred(msg->rpc_cred); + } +} + /* * Default callback for async RPC calls */ @@ -556,26 +935,22 @@ static const struct rpc_call_ops rpc_default_ops = { */ struct rpc_task *rpc_run_task(const struct rpc_task_setup *task_setup_data) { - struct rpc_task *task, *ret; + struct rpc_task *task; task = rpc_new_task(task_setup_data); - if (task == NULL) { - rpc_release_calldata(task_setup_data->callback_ops, - task_setup_data->callback_data); - ret = ERR_PTR(-ENOMEM); + if (IS_ERR(task)) goto out; - } - if (task->tk_status != 0) { - ret = ERR_PTR(task->tk_status); - rpc_put_task(task); - goto out; - } + rpc_task_set_client(task, task_setup_data->rpc_client); + rpc_task_set_rpc_message(task, task_setup_data->rpc_message); + + if (task->tk_action == NULL) + rpc_call_start(task); + atomic_inc(&task->tk_count); rpc_execute(task); - ret = task; out: - return ret; + return task; } EXPORT_SYMBOL_GPL(rpc_run_task); @@ -596,7 +971,12 @@ int rpc_call_sync(struct rpc_clnt *clnt, const struct rpc_message *msg, int flag }; int status; - BUG_ON(flags & RPC_TASK_ASYNC); + WARN_ON_ONCE(flags & RPC_TASK_ASYNC); + if (flags & RPC_TASK_ASYNC) { 
+ rpc_release_calldata(task_setup_data.callback_ops, + task_setup_data.callback_data); + return -EINVAL; + } task = rpc_run_task(&task_setup_data); if (IS_ERR(task)) @@ -636,7 +1016,7 @@ rpc_call_async(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags, } EXPORT_SYMBOL_GPL(rpc_call_async); -#if defined(CONFIG_NFS_V4_1) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) /** * rpc_run_bc_task - Allocate a new RPC task for backchannel use, then run * rpc_execute against it @@ -657,7 +1037,7 @@ struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req, * Create an rpc_task to send the data */ task = rpc_new_task(&task_setup_data); - if (!task) { + if (IS_ERR(task)) { xprt_free_bc_request(req); goto out; } @@ -672,14 +1052,14 @@ struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req, task->tk_action = call_bc_transmit; atomic_inc(&task->tk_count); - BUG_ON(atomic_read(&task->tk_count) != 2); + WARN_ON_ONCE(atomic_read(&task->tk_count) != 2); rpc_execute(task); out: dprintk("RPC: rpc_run_bc_task: task= %p\n", task); return task; } -#endif /* CONFIG_NFS_V4_1 */ +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ void rpc_call_start(struct rpc_task *task) @@ -699,13 +1079,18 @@ EXPORT_SYMBOL_GPL(rpc_call_start); size_t rpc_peeraddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t bufsize) { size_t bytes; - struct rpc_xprt *xprt = clnt->cl_xprt; + struct rpc_xprt *xprt; - bytes = sizeof(xprt->addr); + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + + bytes = xprt->addrlen; if (bytes > bufsize) bytes = bufsize; - memcpy(buf, &clnt->cl_xprt->addr, bytes); - return xprt->addrlen; + memcpy(buf, &xprt->addr, bytes); + rcu_read_unlock(); + + return bytes; } EXPORT_SYMBOL_GPL(rpc_peeraddr); @@ -714,11 +1099,16 @@ EXPORT_SYMBOL_GPL(rpc_peeraddr); * @clnt: RPC client structure * @format: address format * + * NB: the lifetime of the memory referenced by the returned pointer is + * the same as the rpc_xprt itself. As long as the caller uses this + * pointer, it must hold the RCU read lock. */ const char *rpc_peeraddr2str(struct rpc_clnt *clnt, enum rpc_display_format_t format) { - struct rpc_xprt *xprt = clnt->cl_xprt; + struct rpc_xprt *xprt; + + xprt = rcu_dereference(clnt->cl_xprt); if (xprt->address_strings[format] != NULL) return xprt->address_strings[format]; @@ -727,17 +1117,203 @@ const char *rpc_peeraddr2str(struct rpc_clnt *clnt, } EXPORT_SYMBOL_GPL(rpc_peeraddr2str); +static const struct sockaddr_in rpc_inaddr_loopback = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), +}; + +static const struct sockaddr_in6 rpc_in6addr_loopback = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, +}; + +/* + * Try a getsockname() on a connected datagram socket. Using a + * connected datagram socket prevents leaving a socket in TIME_WAIT. + * This conserves the ephemeral port number space. + * + * Returns zero and fills in "buf" if successful; otherwise, a + * negative errno is returned. 
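For illustration, the connect()/getsockname() trick described above also works from ordinary userspace code, since connect() on a SOCK_DGRAM socket performs a route lookup without emitting a packet. A minimal sketch, assuming IPv4 and POSIX sockets (the helper name is made up for the example):

#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>

/* Discover which local address would be used to reach "dst",
 * without sending any data on the wire. */
static int local_addr_for(const struct sockaddr_in *dst,
                          struct sockaddr_in *out)
{
        socklen_t len = sizeof(*out);
        int fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);

        if (fd < 0)
                return -1;
        /* For datagram sockets, connect() only records the peer. */
        if (connect(fd, (const struct sockaddr *)dst, sizeof(*dst)) < 0 ||
            getsockname(fd, (struct sockaddr *)out, &len) < 0) {
                close(fd);
                return -1;
        }
        close(fd);
        return 0;
}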
+ */ +static int rpc_sockname(struct net *net, struct sockaddr *sap, size_t salen, + struct sockaddr *buf, int buflen) +{ + struct socket *sock; + int err; + + err = __sock_create(net, sap->sa_family, + SOCK_DGRAM, IPPROTO_UDP, &sock, 1); + if (err < 0) { + dprintk("RPC: can't create UDP socket (%d)\n", err); + goto out; + } + + switch (sap->sa_family) { + case AF_INET: + err = kernel_bind(sock, + (struct sockaddr *)&rpc_inaddr_loopback, + sizeof(rpc_inaddr_loopback)); + break; + case AF_INET6: + err = kernel_bind(sock, + (struct sockaddr *)&rpc_in6addr_loopback, + sizeof(rpc_in6addr_loopback)); + break; + default: + err = -EAFNOSUPPORT; + goto out; + } + if (err < 0) { + dprintk("RPC: can't bind UDP socket (%d)\n", err); + goto out_release; + } + + err = kernel_connect(sock, sap, salen, 0); + if (err < 0) { + dprintk("RPC: can't connect UDP socket (%d)\n", err); + goto out_release; + } + + err = kernel_getsockname(sock, buf, &buflen); + if (err < 0) { + dprintk("RPC: getsockname failed (%d)\n", err); + goto out_release; + } + + err = 0; + if (buf->sa_family == AF_INET6) { + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)buf; + sin6->sin6_scope_id = 0; + } + dprintk("RPC: %s succeeded\n", __func__); + +out_release: + sock_release(sock); +out: + return err; +} + +/* + * Scraping a connected socket failed, so we don't have a usable + * local address. Fallback: generate an address that will prevent + * the server from calling us back. + * + * Returns zero and fills in "buf" if successful; otherwise, a + * negative errno is returned. + */ +static int rpc_anyaddr(int family, struct sockaddr *buf, size_t buflen) +{ + switch (family) { + case AF_INET: + if (buflen < sizeof(rpc_inaddr_loopback)) + return -EINVAL; + memcpy(buf, &rpc_inaddr_loopback, + sizeof(rpc_inaddr_loopback)); + break; + case AF_INET6: + if (buflen < sizeof(rpc_in6addr_loopback)) + return -EINVAL; + memcpy(buf, &rpc_in6addr_loopback, + sizeof(rpc_in6addr_loopback)); + break; + default: + dprintk("RPC: %s: address family not supported\n", + __func__); + return -EAFNOSUPPORT; + } + dprintk("RPC: %s: succeeded\n", __func__); + return 0; +} + +/** + * rpc_localaddr - discover local endpoint address for an RPC client + * @clnt: RPC client structure + * @buf: target buffer + * @buflen: size of target buffer, in bytes + * + * Returns zero and fills in "buf" and "buflen" if successful; + * otherwise, a negative errno is returned. + * + * This works even if the underlying transport is not currently connected, + * or if the upper layer never previously provided a source address. + * + * The result of this function call is transient: multiple calls in + * succession may give different results, depending on how local + * networking configuration changes over time.
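From the caller's side, the contract above boils down to: pass a buffer large enough for either address family and treat a non-zero return as fatal. A sketch only; "clnt" is assumed to be a valid struct rpc_clnt:

        struct sockaddr_storage ss;
        int err;

        err = rpc_localaddr(clnt, (struct sockaddr *)&ss, sizeof(ss));
        if (err != 0)
                return err;     /* even the ANYADDR fallback failed */
        /* ss now holds a discovered source address, or ANYADDR */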
+ */ +int rpc_localaddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t buflen) +{ + struct sockaddr_storage address; + struct sockaddr *sap = (struct sockaddr *)&address; + struct rpc_xprt *xprt; + struct net *net; + size_t salen; + int err; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + salen = xprt->addrlen; + memcpy(sap, &xprt->addr, salen); + net = get_net(xprt->xprt_net); + rcu_read_unlock(); + + rpc_set_port(sap, 0); + err = rpc_sockname(net, sap, salen, buf, buflen); + put_net(net); + if (err != 0) + /* Couldn't discover local address, return ANYADDR */ + return rpc_anyaddr(sap->sa_family, buf, buflen); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_localaddr); + void rpc_setbufsize(struct rpc_clnt *clnt, unsigned int sndsize, unsigned int rcvsize) { - struct rpc_xprt *xprt = clnt->cl_xprt; + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); if (xprt->ops->set_buffer_size) xprt->ops->set_buffer_size(xprt, sndsize, rcvsize); + rcu_read_unlock(); } EXPORT_SYMBOL_GPL(rpc_setbufsize); -/* - * Return size of largest payload RPC client can support, in bytes +/** + * rpc_protocol - Get transport protocol number for an RPC client + * @clnt: RPC client to query + * + */ +int rpc_protocol(struct rpc_clnt *clnt) +{ + int protocol; + + rcu_read_lock(); + protocol = rcu_dereference(clnt->cl_xprt)->prot; + rcu_read_unlock(); + return protocol; +} +EXPORT_SYMBOL_GPL(rpc_protocol); + +/** + * rpc_net_ns - Get the network namespace for this RPC client + * @clnt: RPC client to query + * + */ +struct net *rpc_net_ns(struct rpc_clnt *clnt) +{ + struct net *ret; + + rcu_read_lock(); + ret = rcu_dereference(clnt->cl_xprt)->xprt_net; + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_net_ns); + +/** + * rpc_max_payload - Get maximum payload size for a transport, in bytes + * @clnt: RPC client to query * * For stream transports, this is one RPC record fragment (see RFC * 1831), as we don't support multi-record requests yet. For datagram @@ -746,19 +1322,42 @@ EXPORT_SYMBOL_GPL(rpc_setbufsize); */ size_t rpc_max_payload(struct rpc_clnt *clnt) { - return clnt->cl_xprt->max_payload; + size_t ret; + + rcu_read_lock(); + ret = rcu_dereference(clnt->cl_xprt)->max_payload; + rcu_read_unlock(); + return ret; } EXPORT_SYMBOL_GPL(rpc_max_payload); /** + * rpc_get_timeout - Get timeout for transport in units of HZ + * @clnt: RPC client to query + */ +unsigned long rpc_get_timeout(struct rpc_clnt *clnt) +{ + unsigned long ret; + + rcu_read_lock(); + ret = rcu_dereference(clnt->cl_xprt)->timeout->to_initval; + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_get_timeout); + +/** * rpc_force_rebind - force transport to check that remote port is unchanged * @clnt: client to rebind * */ void rpc_force_rebind(struct rpc_clnt *clnt) { - if (clnt->cl_autobind) - xprt_clear_bound(clnt->cl_xprt); + if (clnt->cl_autobind) { + rcu_read_lock(); + xprt_clear_bound(rcu_dereference(clnt->cl_xprt)); + rcu_read_unlock(); + } } EXPORT_SYMBOL_GPL(rpc_force_rebind); @@ -766,12 +1365,16 @@ EXPORT_SYMBOL_GPL(rpc_force_rebind); * Restart an (async) RPC call from the call_prepare state. * Usually called from within the exit handler. 
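Because the restart helpers below now report whether the task was actually requeued, an exit handler can distinguish "retry scheduled" from "task assassinated". A hedged sketch of the calling pattern; the handler name and retry condition are illustrative only:

static void demo_rpc_done(struct rpc_task *task, void *calldata)
{
        if (task->tk_status == -EAGAIN) {       /* transient server error */
                rpc_delay(task, HZ >> 1);       /* back off briefly */
                if (rpc_restart_call_prepare(task))
                        return;                 /* requeued at call_prepare */
                /* returned 0: task was assassinated, fall through */
        }
        /* normal completion handling goes here */
}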
*/ -void +int rpc_restart_call_prepare(struct rpc_task *task) { if (RPC_ASSASSINATED(task)) - return; - task->tk_action = rpc_prepare_task; + return 0; + task->tk_action = call_start; + task->tk_status = 0; + if (task->tk_ops->rpc_call_prepare != NULL) + task->tk_action = rpc_prepare_task; + return 1; } EXPORT_SYMBOL_GPL(rpc_restart_call_prepare); @@ -779,13 +1382,14 @@ EXPORT_SYMBOL_GPL(rpc_restart_call_prepare); * Restart an (async) RPC call. Usually called from within the * exit handler. */ -void +int rpc_restart_call(struct rpc_task *task) { if (RPC_ASSASSINATED(task)) - return; - + return 0; task->tk_action = call_start; + task->tk_status = 0; + return 1; } EXPORT_SYMBOL_GPL(rpc_restart_call); @@ -816,7 +1420,7 @@ call_start(struct rpc_task *task) struct rpc_clnt *clnt = task->tk_client; dprintk("RPC: %5u call_start %s%d proc %s (%s)\n", task->tk_pid, - clnt->cl_protname, clnt->cl_vers, + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), (RPC_IS_ASYNC(task) ? "async" : "sync")); @@ -834,16 +1438,13 @@ call_reserve(struct rpc_task *task) { dprint_status(task); - if (!rpcauth_uptodatecred(task)) { - task->tk_action = call_refresh; - return; - } - task->tk_status = 0; task->tk_action = call_reserveresult; xprt_reserve(task); } +static void call_retry_reserve(struct rpc_task *task); + /* * 1b. Grok the result of xprt_reserve() */ @@ -861,7 +1462,7 @@ call_reserveresult(struct rpc_task *task) task->tk_status = 0; if (status >= 0) { if (task->tk_rqstp) { - task->tk_action = call_allocate; + task->tk_action = call_refresh; return; } @@ -882,8 +1483,10 @@ call_reserveresult(struct rpc_task *task) } switch (status) { + case -ENOMEM: + rpc_delay(task, HZ >> 2); case -EAGAIN: /* woken up; retry */ - task->tk_action = call_reserve; + task->tk_action = call_retry_reserve; return; case -EIO: /* probably a shutdown */ break; @@ -896,15 +1499,80 @@ call_reserveresult(struct rpc_task *task) } /* - * 2. Allocate the buffer. For details, see sched.c:rpc_malloc. + * 1c. Retry reserving an RPC call slot + */ +static void +call_retry_reserve(struct rpc_task *task) +{ + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_reserveresult; + xprt_retry_reserve(task); +} + +/* + * 2. Bind and/or refresh the credentials + */ +static void +call_refresh(struct rpc_task *task) +{ + dprint_status(task); + + task->tk_action = call_refreshresult; + task->tk_status = 0; + task->tk_client->cl_stats->rpcauthrefresh++; + rpcauth_refreshcred(task); +} + +/* + * 2a. Process the results of a credential refresh + */ +static void +call_refreshresult(struct rpc_task *task) +{ + int status = task->tk_status; + + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_refresh; + switch (status) { + case 0: + if (rpcauth_uptodatecred(task)) { + task->tk_action = call_allocate; + return; + } + /* Use rate-limiting and a max number of retries if refresh + * had status 0 but failed to update the cred. + */ + case -ETIMEDOUT: + rpc_delay(task, 3*HZ); + case -EAGAIN: + status = -EACCES; + case -EKEYEXPIRED: + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + dprintk("RPC: %5u %s: retry refresh creds\n", + task->tk_pid, __func__); + return; + } + dprintk("RPC: %5u %s: refresh creds failed with error %d\n", + task->tk_pid, __func__, status); + rpc_exit(task, status); +} + +/* + * 2b. Allocate the buffer. For details, see sched.c:rpc_malloc. * (Note: buffer memory is freed in xprt_release). 
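With credential refresh moved behind slot reservation, the numbered states in this file now run in the following order. This is a summary of the transitions visible in the surrounding hunks, not new code:

/*
 *  call_start
 *   -> call_reserve          1.  reserve an rpc_rqst slot
 *   -> call_reserveresult    1b. may loop via call_retry_reserve (1c)
 *   -> call_refresh          2.  bind/refresh the credential
 *   -> call_refreshresult    2a. backs off on -ETIMEDOUT/-EKEYEXPIRED
 *   -> call_allocate         2b. allocate XDR buffers
 *   -> call_bind -> call_connect -> call_transmit -> call_status ...
 */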
*/ static void call_allocate(struct rpc_task *task) { - unsigned int slack = task->tk_msg.rpc_cred->cr_auth->au_cslack; + unsigned int slack = task->tk_rqstp->rq_cred->cr_auth->au_cslack; struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt = req->rq_xprt; struct rpc_procinfo *proc = task->tk_msg.rpc_proc; dprint_status(task); @@ -938,7 +1606,7 @@ call_allocate(struct rpc_task *task) dprintk("RPC: %5u rpc_buffer allocation failed\n", task->tk_pid); - if (RPC_IS_ASYNC(task) || !signalled()) { + if (RPC_IS_ASYNC(task) || !fatal_signal_pending(current)) { task->tk_action = call_allocate; rpc_delay(task, HZ>>4); return; @@ -979,7 +1647,7 @@ static void rpc_xdr_encode(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - kxdrproc_t encode; + kxdreproc_t encode; __be32 *p; dprint_status(task); @@ -1012,7 +1680,7 @@ rpc_xdr_encode(struct rpc_task *task) static void call_bind(struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; dprint_status(task); @@ -1039,6 +1707,7 @@ call_bind_status(struct rpc_task *task) return; } + trace_rpc_bind_status(task); switch (task->tk_status) { case -ENOMEM: dprintk("RPC: %5u rpcbind out of memory\n", task->tk_pid); @@ -1052,6 +1721,9 @@ call_bind_status(struct rpc_task *task) status = -EOPNOTSUPP; break; } + if (task->tk_rebind_retry == 0) + break; + task->tk_rebind_retry--; rpc_delay(task, 3*HZ); goto retry_timeout; case -ETIMEDOUT: @@ -1066,11 +1738,10 @@ call_bind_status(struct rpc_task *task) case -EPROTONOSUPPORT: dprintk("RPC: %5u remote rpcbind version unavailable, retrying\n", task->tk_pid); - task->tk_status = 0; - task->tk_action = call_bind; - return; + goto retry_timeout; case -ECONNREFUSED: /* connection problems */ case -ECONNRESET: + case -ECONNABORTED: case -ENOTCONN: case -EHOSTDOWN: case -EHOSTUNREACH: @@ -1093,6 +1764,7 @@ call_bind_status(struct rpc_task *task) return; retry_timeout: + task->tk_status = 0; task->tk_action = call_timeout; } @@ -1102,7 +1774,7 @@ retry_timeout: static void call_connect(struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; dprintk("RPC: %5u call_connect xprt %p %s connected\n", task->tk_pid, xprt, @@ -1113,6 +1785,10 @@ call_connect(struct rpc_task *task) task->tk_action = call_connect_status; if (task->tk_status < 0) return; + if (task->tk_flags & RPC_TASK_NOCONNECT) { + rpc_exit(task, -ENOTCONN); + return; + } xprt_connect(task); } } @@ -1128,21 +1804,29 @@ call_connect_status(struct rpc_task *task) dprint_status(task); + trace_rpc_connect_status(task, status); task->tk_status = 0; - if (status >= 0 || status == -EAGAIN) { - clnt->cl_stats->netreconn++; - task->tk_action = call_transmit; - return; - } - switch (status) { - /* if soft mounted, test if we've timed out */ + case -ECONNREFUSED: + case -ECONNRESET: + case -ECONNABORTED: + case -ENETUNREACH: + case -EHOSTUNREACH: + if (RPC_IS_SOFTCONN(task)) + break; + /* retry with existing socket, after a delay */ + rpc_delay(task, 3*HZ); + case -EAGAIN: + /* Check for timeouts before looping back to call_bind */ case -ETIMEDOUT: task->tk_action = call_timeout; - break; - default: - rpc_exit(task, -EIO); + return; + case 0: + clnt->cl_stats->netreconn++; + task->tk_action = call_transmit; + return; } + rpc_exit(task, status); } /* @@ -1151,18 +1835,18 @@ call_connect_status(struct rpc_task *task) static void call_transmit(struct rpc_task *task) { + int is_retrans = 
RPC_WAS_SENT(task); + dprint_status(task); task->tk_action = call_status; if (task->tk_status < 0) return; - task->tk_status = xprt_prepare_transmit(task); - if (task->tk_status != 0) + if (!xprt_prepare_transmit(task)) return; task->tk_action = call_transmit_status; /* Encode here so that rpcsec_gss can use correct sequence number. */ if (rpc_task_need_encode(task)) { - BUG_ON(task->tk_rqstp->rq_bytes_sent != 0); rpc_xdr_encode(task); /* Did the encode result in an error condition? */ if (task->tk_status != 0) { @@ -1177,6 +1861,8 @@ call_transmit(struct rpc_task *task) xprt_transmit(task); if (task->tk_status < 0) return; + if (is_retrans) + task->tk_client->cl_stats->rpcretrans++; /* * On success, ensure that we call xprt_end_transmit() before sleeping * in order to allow access to the socket to other RPC requests. @@ -1185,7 +1871,7 @@ call_transmit(struct rpc_task *task) if (rpc_reply_expected(task)) return; task->tk_action = rpc_exit_task; - rpc_wake_up_queued_task(&task->tk_xprt->pending, task); + rpc_wake_up_queued_task(&task->tk_rqstp->rq_xprt->pending, task); } /* @@ -1230,13 +1916,14 @@ call_transmit_status(struct rpc_task *task) break; } case -ECONNRESET: + case -ECONNABORTED: case -ENOTCONN: case -EPIPE: rpc_task_force_reencode(task); } } -#if defined(CONFIG_NFS_V4_1) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) /* * 5b. Send the backchannel RPC reply. On error, drop the reply. In * addition, disconnect on connectivity errors. @@ -1246,9 +1933,7 @@ call_bc_transmit(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - BUG_ON(task->tk_status != 0); - task->tk_status = xprt_prepare_transmit(task); - if (task->tk_status == -EAGAIN) { + if (!xprt_prepare_transmit(task)) { /* * Could not reserve the transport. Try again after the * transport is released. @@ -1285,7 +1970,7 @@ call_bc_transmit(struct rpc_task *task) */ printk(KERN_NOTICE "RPC: Could not send backchannel reply " "error: %d\n", task->tk_status); - xprt_conditional_disconnect(task->tk_xprt, + xprt_conditional_disconnect(req->rq_xprt, req->rq_connect_cookie); break; default: @@ -1293,14 +1978,14 @@ call_bc_transmit(struct rpc_task *task) * We were unable to reply and will have to drop the * request. The server should reconnect and retransmit. */ - BUG_ON(task->tk_status == -EAGAIN); + WARN_ON_ONCE(task->tk_status == -EAGAIN); printk(KERN_NOTICE "RPC: Could not send backchannel reply " "error: %d\n", task->tk_status); break; } rpc_wake_up_queued_task(&req->rq_xprt->pending, task); } -#endif /* CONFIG_NFS_V4_1 */ +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ /* * 6. Sort out the RPC call status @@ -1323,11 +2008,16 @@ call_status(struct rpc_task *task) return; } + trace_rpc_call_status(task); task->tk_status = 0; switch(status) { case -EHOSTDOWN: case -EHOSTUNREACH: case -ENETUNREACH: + if (RPC_IS_SOFTCONN(task)) { + rpc_exit(task, status); + break; + } /* * Delay any retries for 3 seconds, then handle as if it * were a timeout. 
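Whether the connection errors above terminate the task or merely delay it depends on flags the task was created with. A sketch of an async, connection-soft task setup, reusing the hypothetical demo_rpc_done handler from the earlier sketch; the message is a placeholder for whatever the caller fills in:

static const struct rpc_call_ops demo_call_ops = {
        .rpc_call_done  = demo_rpc_done,
};

static int demo_start_call(struct rpc_clnt *clnt, struct rpc_message *msg)
{
        struct rpc_task_setup setup = {
                .rpc_client     = clnt,
                .rpc_message    = msg,
                .callback_ops   = &demo_call_ops,
                /* SOFTCONN: fail fast on connect errors (see above);
                 * without it the task delays 3s, then call_timeout. */
                .flags          = RPC_TASK_ASYNC | RPC_TASK_SOFTCONN,
        };
        struct rpc_task *task = rpc_run_task(&setup);

        if (IS_ERR(task))
                return PTR_ERR(task);
        rpc_put_task(task);
        return 0;
}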
@@ -1335,12 +2025,14 @@ call_status(struct rpc_task *task) rpc_delay(task, 3*HZ); case -ETIMEDOUT: task->tk_action = call_timeout; - if (task->tk_client->cl_discrtry) - xprt_conditional_disconnect(task->tk_xprt, + if (!(task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) + && task->tk_client->cl_discrtry) + xprt_conditional_disconnect(req->rq_xprt, req->rq_connect_cookie); break; - case -ECONNRESET: case -ECONNREFUSED: + case -ECONNRESET: + case -ECONNABORTED: rpc_force_rebind(clnt); rpc_delay(task, 3*HZ); case -EPIPE: @@ -1357,7 +2049,7 @@ call_status(struct rpc_task *task) default: if (clnt->cl_chatty) printk("%s: RPC call returned error %d\n", - clnt->cl_protname, -status); + clnt->cl_program->name, -status); rpc_exit(task, status); } } @@ -1385,18 +2077,29 @@ call_timeout(struct rpc_task *task) return; } if (RPC_IS_SOFT(task)) { - if (clnt->cl_chatty) + if (clnt->cl_chatty) { + rcu_read_lock(); printk(KERN_NOTICE "%s: server %s not responding, timed out\n", - clnt->cl_protname, clnt->cl_server); - rpc_exit(task, -EIO); + clnt->cl_program->name, + rcu_dereference(clnt->cl_xprt)->servername); + rcu_read_unlock(); + } + if (task->tk_flags & RPC_TASK_TIMEOUT) + rpc_exit(task, -ETIMEDOUT); + else + rpc_exit(task, -EIO); return; } if (!(task->tk_flags & RPC_CALL_MAJORSEEN)) { task->tk_flags |= RPC_CALL_MAJORSEEN; - if (clnt->cl_chatty) + if (clnt->cl_chatty) { + rcu_read_lock(); printk(KERN_NOTICE "%s: server %s not responding, still trying\n", - clnt->cl_protname, clnt->cl_server); + clnt->cl_program->name, + rcu_dereference(clnt->cl_xprt)->servername); + rcu_read_unlock(); + } } rpc_force_rebind(clnt); /* @@ -1406,7 +2109,6 @@ call_timeout(struct rpc_task *task) rpcauth_invalcred(task); retry: - clnt->cl_stats->rpcretrans++; task->tk_action = call_bind; task->tk_status = 0; } @@ -1419,16 +2121,19 @@ call_decode(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; struct rpc_rqst *req = task->tk_rqstp; - kxdrproc_t decode = task->tk_msg.rpc_proc->p_decode; + kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode; __be32 *p; - dprintk("RPC: %5u call_decode (status %d)\n", - task->tk_pid, task->tk_status); + dprint_status(task); if (task->tk_flags & RPC_CALL_MAJORSEEN) { - if (clnt->cl_chatty) + if (clnt->cl_chatty) { + rcu_read_lock(); printk(KERN_NOTICE "%s: server %s OK\n", - clnt->cl_protname, clnt->cl_server); + clnt->cl_program->name, + rcu_dereference(clnt->cl_xprt)->servername); + rcu_read_unlock(); + } task->tk_flags &= ~RPC_CALL_MAJORSEEN; } @@ -1446,11 +2151,10 @@ call_decode(struct rpc_task *task) if (req->rq_rcv_buf.len < 12) { if (!RPC_IS_SOFT(task)) { task->tk_action = call_bind; - clnt->cl_stats->rpcretrans++; goto out_retry; } dprintk("RPC: %s: too small RPC reply size (%d bytes)\n", - clnt->cl_protname, task->tk_status); + clnt->cl_program->name, task->tk_status); task->tk_action = call_timeout; goto out_retry; } @@ -1477,49 +2181,11 @@ out_retry: if (task->tk_rqstp == req) { req->rq_reply_bytes_recvd = req->rq_rcv_buf.len = 0; if (task->tk_client->cl_discrtry) - xprt_conditional_disconnect(task->tk_xprt, + xprt_conditional_disconnect(req->rq_xprt, req->rq_connect_cookie); } } -/* - * 8. Refresh the credentials if rejected by the server - */ -static void -call_refresh(struct rpc_task *task) -{ - dprint_status(task); - - task->tk_action = call_refreshresult; - task->tk_status = 0; - task->tk_client->cl_stats->rpcauthrefresh++; - rpcauth_refreshcred(task); -} - -/* - * 8a. 
Process the results of a credential refresh - */ -static void -call_refreshresult(struct rpc_task *task) -{ - int status = task->tk_status; - - dprint_status(task); - - task->tk_status = 0; - task->tk_action = call_reserve; - if (status >= 0 && rpcauth_uptodatecred(task)) - return; - if (status == -EACCES) { - rpc_exit(task, -EACCES); - return; - } - task->tk_action = call_refresh; - if (status != -ETIMEDOUT) - rpc_delay(task, 3*HZ); - return; -} - static __be32 * rpc_encode_header(struct rpc_task *task) { @@ -1529,7 +2195,7 @@ rpc_encode_header(struct rpc_task *task) /* FIXME: check buffer size? */ - p = xprt_skip_transport_header(task->tk_xprt, p); + p = xprt_skip_transport_header(req->rq_xprt, p); *p++ = req->rq_xid; /* XID */ *p++ = htonl(RPC_CALL); /* CALL */ *p++ = htonl(RPC_VERSION); /* RPC version */ @@ -1544,6 +2210,7 @@ rpc_encode_header(struct rpc_task *task) static __be32 * rpc_verify_header(struct rpc_task *task) { + struct rpc_clnt *clnt = task->tk_client; struct kvec *iov = &task->tk_rqstp->rq_rcv_buf.head[0]; int len = task->tk_rqstp->rq_rcv_buf.len >> 2; __be32 *p = iov->iov_base; @@ -1559,7 +2226,8 @@ rpc_verify_header(struct rpc_task *task) dprintk("RPC: %5u %s: XDR representation not a multiple of" " 4 bytes: 0x%x\n", task->tk_pid, __func__, task->tk_rqstp->rq_rcv_buf.len); - goto out_eio; + error = -EIO; + goto out_err; } if ((len -= 3) < 0) goto out_overflow; @@ -1568,6 +2236,7 @@ rpc_verify_header(struct rpc_task *task) if ((n = ntohl(*p++)) != RPC_REPLY) { dprintk("RPC: %5u %s: not an RPC reply: %x\n", task->tk_pid, __func__, n); + error = -EIO; goto out_garbage; } @@ -1575,19 +2244,19 @@ rpc_verify_header(struct rpc_task *task) if (--len < 0) goto out_overflow; switch ((n = ntohl(*p++))) { - case RPC_AUTH_ERROR: - break; - case RPC_MISMATCH: - dprintk("RPC: %5u %s: RPC call version " - "mismatch!\n", - task->tk_pid, __func__); - error = -EPROTONOSUPPORT; - goto out_err; - default: - dprintk("RPC: %5u %s: RPC call rejected, " - "unknown error: %x\n", - task->tk_pid, __func__, n); - goto out_eio; + case RPC_AUTH_ERROR: + break; + case RPC_MISMATCH: + dprintk("RPC: %5u %s: RPC call version mismatch!\n", + task->tk_pid, __func__); + error = -EPROTONOSUPPORT; + goto out_err; + default: + dprintk("RPC: %5u %s: RPC call rejected, " + "unknown error: %x\n", + task->tk_pid, __func__, n); + error = -EIO; + goto out_err; } if (--len < 0) goto out_overflow; @@ -1604,7 +2273,7 @@ rpc_verify_header(struct rpc_task *task) rpcauth_invalcred(task); /* Ensure we obtain a new XID! 
*/ xprt_release(task); - task->tk_action = call_refresh; + task->tk_action = call_reserve; goto out_retry; case RPC_AUTH_BADCRED: case RPC_AUTH_BADVERF: @@ -1617,8 +2286,11 @@ rpc_verify_header(struct rpc_task *task) task->tk_action = call_bind; goto out_retry; case RPC_AUTH_TOOWEAK: + rcu_read_lock(); printk(KERN_NOTICE "RPC: server %s requires stronger " - "authentication.\n", task->tk_client->cl_server); + "authentication.\n", + rcu_dereference(clnt->cl_xprt)->servername); + rcu_read_unlock(); break; default: dprintk("RPC: %5u %s: unknown auth error: %x\n", @@ -1629,9 +2301,11 @@ rpc_verify_header(struct rpc_task *task) task->tk_pid, __func__, n); goto out_err; } - if (!(p = rpcauth_checkverf(task, p))) { - dprintk("RPC: %5u %s: auth check failed\n", - task->tk_pid, __func__); + p = rpcauth_checkverf(task, p); + if (IS_ERR(p)) { + error = PTR_ERR(p); + dprintk("RPC: %5u %s: auth check failed with %d\n", + task->tk_pid, __func__, error); goto out_garbage; /* bad verifier, retry */ } len = p - (__be32 *)iov->iov_base - 1; @@ -1641,28 +2315,27 @@ rpc_verify_header(struct rpc_task *task) case RPC_SUCCESS: return p; case RPC_PROG_UNAVAIL: - dprintk("RPC: %5u %s: program %u is unsupported by server %s\n", - task->tk_pid, __func__, - (unsigned int)task->tk_client->cl_prog, - task->tk_client->cl_server); + dprintk_rcu("RPC: %5u %s: program %u is unsupported " + "by server %s\n", task->tk_pid, __func__, + (unsigned int)clnt->cl_prog, + rcu_dereference(clnt->cl_xprt)->servername); error = -EPFNOSUPPORT; goto out_err; case RPC_PROG_MISMATCH: - dprintk("RPC: %5u %s: program %u, version %u unsupported by " - "server %s\n", task->tk_pid, __func__, - (unsigned int)task->tk_client->cl_prog, - (unsigned int)task->tk_client->cl_vers, - task->tk_client->cl_server); + dprintk_rcu("RPC: %5u %s: program %u, version %u unsupported " + "by server %s\n", task->tk_pid, __func__, + (unsigned int)clnt->cl_prog, + (unsigned int)clnt->cl_vers, + rcu_dereference(clnt->cl_xprt)->servername); error = -EPROTONOSUPPORT; goto out_err; case RPC_PROC_UNAVAIL: - dprintk("RPC: %5u %s: proc %s unsupported by program %u, " + dprintk_rcu("RPC: %5u %s: proc %s unsupported by program %u, " "version %u on server %s\n", task->tk_pid, __func__, rpc_proc_name(task), - task->tk_client->cl_prog, - task->tk_client->cl_vers, - task->tk_client->cl_server); + clnt->cl_prog, clnt->cl_vers, + rcu_dereference(clnt->cl_xprt)->servername); error = -EOPNOTSUPP; goto out_err; case RPC_GARBAGE_ARGS: @@ -1676,7 +2349,7 @@ rpc_verify_header(struct rpc_task *task) } out_garbage: - task->tk_client->cl_stats->rpcgarbage++; + clnt->cl_stats->rpcgarbage++; if (task->tk_garb_retry) { task->tk_garb_retry--; dprintk("RPC: %5u %s: retrying\n", @@ -1685,8 +2358,6 @@ out_garbage: out_retry: return ERR_PTR(-EAGAIN); } -out_eio: - error = -EIO; out_err: rpc_exit(task, error); dprintk("RPC: %5u %s: call failed with error %d\n", task->tk_pid, @@ -1698,12 +2369,11 @@ out_overflow: goto out_garbage; } -static int rpcproc_encode_null(void *rqstp, __be32 *data, void *obj) +static void rpcproc_encode_null(void *rqstp, struct xdr_stream *xdr, void *obj) { - return 0; } -static int rpcproc_decode_null(void *rqstp, __be32 *data, void *obj) +static int rpcproc_decode_null(void *rqstp, struct xdr_stream *xdr, void *obj) { return 0; } @@ -1752,33 +2422,26 @@ static void rpc_show_task(const struct rpc_clnt *clnt, const struct rpc_task *task) { const char *rpc_waitq = "none"; - char *p, action[KSYM_SYMBOL_LEN]; if (RPC_IS_QUEUED(task)) rpc_waitq = rpc_qname(task->tk_waitqueue); 
- /* map tk_action pointer to a function name; then trim off - * the "+0x0 [sunrpc]" */ - sprint_symbol(action, (unsigned long)task->tk_action); - p = strchr(action, '+'); - if (p) - *p = '\0'; - - printk(KERN_INFO "%5u %04x %6d %8p %8p %8ld %8p %sv%u %s a:%s q:%s\n", + printk(KERN_INFO "%5u %04x %6d %8p %8p %8ld %8p %sv%u %s a:%ps q:%s\n", task->tk_pid, task->tk_flags, task->tk_status, clnt, task->tk_rqstp, task->tk_timeout, task->tk_ops, - clnt->cl_protname, clnt->cl_vers, rpc_proc_name(task), - action, rpc_waitq); + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), + task->tk_action, rpc_waitq); } -void rpc_show_tasks(void) +void rpc_show_tasks(struct net *net) { struct rpc_clnt *clnt; struct rpc_task *task; int header = 0; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); - spin_lock(&rpc_client_lock); - list_for_each_entry(clnt, &all_clients, cl_clients) { + spin_lock(&sn->rpc_client_lock); + list_for_each_entry(clnt, &sn->all_clients, cl_clients) { spin_lock(&clnt->cl_lock); list_for_each_entry(task, &clnt->cl_tasks, tk_task) { if (!header) { @@ -1789,6 +2452,6 @@ void rpc_show_tasks(void) } spin_unlock(&clnt->cl_lock); } - spin_unlock(&rpc_client_lock); + spin_unlock(&sn->rpc_client_lock); } #endif diff --git a/net/sunrpc/netns.h b/net/sunrpc/netns.h new file mode 100644 index 00000000000..df582687653 --- /dev/null +++ b/net/sunrpc/netns.h @@ -0,0 +1,42 @@ +#ifndef __SUNRPC_NETNS_H__ +#define __SUNRPC_NETNS_H__ + +#include <net/net_namespace.h> +#include <net/netns/generic.h> + +struct cache_detail; + +struct sunrpc_net { + struct proc_dir_entry *proc_net_rpc; + struct cache_detail *ip_map_cache; + struct cache_detail *unix_gid_cache; + struct cache_detail *rsc_cache; + struct cache_detail *rsi_cache; + + struct super_block *pipefs_sb; + struct rpc_pipe *gssd_dummy; + struct mutex pipefs_sb_lock; + + struct list_head all_clients; + spinlock_t rpc_client_lock; + + struct rpc_clnt *rpcb_local_clnt; + struct rpc_clnt *rpcb_local_clnt4; + spinlock_t rpcb_clnt_lock; + unsigned int rpcb_users; + unsigned int rpcb_is_af_local : 1; + + struct mutex gssp_lock; + struct rpc_clnt *gssp_clnt; + int use_gss_proxy; + int pipe_version; + atomic_t pipe_users; + struct proc_dir_entry *use_gssp_proc; +}; + +extern int sunrpc_net_id; + +int ip_map_cache_create(struct net *); +void ip_map_cache_destroy(struct net *); + +#endif diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c index 8d63f8fd29b..b1855489856 100644 --- a/net/sunrpc/rpc_pipe.c +++ b/net/sunrpc/rpc_pipe.c @@ -16,9 +16,10 @@ #include <linux/namei.h> #include <linux/fsnotify.h> #include <linux/kernel.h> +#include <linux/rcupdate.h> +#include <linux/utsname.h> #include <asm/ioctls.h> -#include <linux/fs.h> #include <linux/poll.h> #include <linux/wait.h> #include <linux/seq_file.h> @@ -27,18 +28,38 @@ #include <linux/workqueue.h> #include <linux/sunrpc/rpc_pipe_fs.h> #include <linux/sunrpc/cache.h> +#include <linux/nsproxy.h> +#include <linux/notifier.h> -static struct vfsmount *rpc_mount __read_mostly; -static int rpc_mount_count; +#include "netns.h" +#include "sunrpc.h" -static struct file_system_type rpc_pipe_fs_type; +#define RPCDBG_FACILITY RPCDBG_DEBUG + +#define NET_NAME(net) ((net == &init_net) ? 
" (init_net)" : "") +static struct file_system_type rpc_pipe_fs_type; +static const struct rpc_pipe_ops gssd_dummy_pipe_ops; static struct kmem_cache *rpc_inode_cachep __read_mostly; #define RPC_UPCALL_TIMEOUT (30*HZ) -static void rpc_purge_list(struct rpc_inode *rpci, struct list_head *head, +static BLOCKING_NOTIFIER_HEAD(rpc_pipefs_notifier_list); + +int rpc_pipefs_notifier_register(struct notifier_block *nb) +{ + return blocking_notifier_chain_cond_register(&rpc_pipefs_notifier_list, nb); +} +EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_register); + +void rpc_pipefs_notifier_unregister(struct notifier_block *nb) +{ + blocking_notifier_chain_unregister(&rpc_pipefs_notifier_list, nb); +} +EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_unregister); + +static void rpc_purge_list(wait_queue_head_t *waitq, struct list_head *head, void (*destroy_msg)(struct rpc_pipe_msg *), int err) { struct rpc_pipe_msg *msg; @@ -47,39 +68,60 @@ static void rpc_purge_list(struct rpc_inode *rpci, struct list_head *head, return; do { msg = list_entry(head->next, struct rpc_pipe_msg, list); - list_del(&msg->list); + list_del_init(&msg->list); msg->errno = err; destroy_msg(msg); } while (!list_empty(head)); - wake_up(&rpci->waitq); + + if (waitq) + wake_up(waitq); } static void rpc_timeout_upcall_queue(struct work_struct *work) { LIST_HEAD(free_list); - struct rpc_inode *rpci = - container_of(work, struct rpc_inode, queue_timeout.work); - struct inode *inode = &rpci->vfs_inode; + struct rpc_pipe *pipe = + container_of(work, struct rpc_pipe, queue_timeout.work); void (*destroy_msg)(struct rpc_pipe_msg *); + struct dentry *dentry; - spin_lock(&inode->i_lock); - if (rpci->ops == NULL) { - spin_unlock(&inode->i_lock); - return; + spin_lock(&pipe->lock); + destroy_msg = pipe->ops->destroy_msg; + if (pipe->nreaders == 0) { + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; } - destroy_msg = rpci->ops->destroy_msg; - if (rpci->nreaders == 0) { - list_splice_init(&rpci->pipe, &free_list); - rpci->pipelen = 0; + dentry = dget(pipe->dentry); + spin_unlock(&pipe->lock); + rpc_purge_list(dentry ? &RPC_I(dentry->d_inode)->waitq : NULL, + &free_list, destroy_msg, -ETIMEDOUT); + dput(dentry); +} + +ssize_t rpc_pipe_generic_upcall(struct file *filp, struct rpc_pipe_msg *msg, + char __user *dst, size_t buflen) +{ + char *data = (char *)msg->data + msg->copied; + size_t mlen = min(msg->len - msg->copied, buflen); + unsigned long left; + + left = copy_to_user(dst, data, mlen); + if (left == mlen) { + msg->errno = -EFAULT; + return -EFAULT; } - spin_unlock(&inode->i_lock); - rpc_purge_list(rpci, &free_list, destroy_msg, -ETIMEDOUT); + + mlen -= left; + msg->copied += mlen; + msg->errno = 0; + return mlen; } +EXPORT_SYMBOL_GPL(rpc_pipe_generic_upcall); /** * rpc_queue_upcall - queue an upcall message to userspace - * @inode: inode of upcall pipe on which to queue given message + * @pipe: upcall pipe on which to queue given message * @msg: message to queue * * Call with an @inode created by rpc_mkpipe() to queue an upcall. @@ -88,30 +130,31 @@ rpc_timeout_upcall_queue(struct work_struct *work) * initialize the fields of @msg (other than @msg->list) appropriately. 
*/ int -rpc_queue_upcall(struct inode *inode, struct rpc_pipe_msg *msg) +rpc_queue_upcall(struct rpc_pipe *pipe, struct rpc_pipe_msg *msg) { - struct rpc_inode *rpci = RPC_I(inode); int res = -EPIPE; + struct dentry *dentry; - spin_lock(&inode->i_lock); - if (rpci->ops == NULL) - goto out; - if (rpci->nreaders) { - list_add_tail(&msg->list, &rpci->pipe); - rpci->pipelen += msg->len; + spin_lock(&pipe->lock); + if (pipe->nreaders) { + list_add_tail(&msg->list, &pipe->pipe); + pipe->pipelen += msg->len; res = 0; - } else if (rpci->flags & RPC_PIPE_WAIT_FOR_OPEN) { - if (list_empty(&rpci->pipe)) + } else if (pipe->flags & RPC_PIPE_WAIT_FOR_OPEN) { + if (list_empty(&pipe->pipe)) queue_delayed_work(rpciod_workqueue, - &rpci->queue_timeout, + &pipe->queue_timeout, RPC_UPCALL_TIMEOUT); - list_add_tail(&msg->list, &rpci->pipe); - rpci->pipelen += msg->len; + list_add_tail(&msg->list, &pipe->pipe); + pipe->pipelen += msg->len; res = 0; } -out: - spin_unlock(&inode->i_lock); - wake_up(&rpci->waitq); + dentry = dget(pipe->dentry); + spin_unlock(&pipe->lock); + if (dentry) { + wake_up(&RPC_I(dentry->d_inode)->waitq); + dput(dentry); + } return res; } EXPORT_SYMBOL_GPL(rpc_queue_upcall); @@ -125,29 +168,26 @@ rpc_inode_setowner(struct inode *inode, void *private) static void rpc_close_pipes(struct inode *inode) { - struct rpc_inode *rpci = RPC_I(inode); - const struct rpc_pipe_ops *ops; + struct rpc_pipe *pipe = RPC_I(inode)->pipe; int need_release; + LIST_HEAD(free_list); mutex_lock(&inode->i_mutex); - ops = rpci->ops; - if (ops != NULL) { - LIST_HEAD(free_list); - spin_lock(&inode->i_lock); - need_release = rpci->nreaders != 0 || rpci->nwriters != 0; - rpci->nreaders = 0; - list_splice_init(&rpci->in_upcall, &free_list); - list_splice_init(&rpci->pipe, &free_list); - rpci->pipelen = 0; - rpci->ops = NULL; - spin_unlock(&inode->i_lock); - rpc_purge_list(rpci, &free_list, ops->destroy_msg, -EPIPE); - rpci->nwriters = 0; - if (need_release && ops->release_pipe) - ops->release_pipe(inode); - cancel_delayed_work_sync(&rpci->queue_timeout); - } + spin_lock(&pipe->lock); + need_release = pipe->nreaders != 0 || pipe->nwriters != 0; + pipe->nreaders = 0; + list_splice_init(&pipe->in_upcall, &free_list); + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + pipe->dentry = NULL; + spin_unlock(&pipe->lock); + rpc_purge_list(&RPC_I(inode)->waitq, &free_list, pipe->ops->destroy_msg, -EPIPE); + pipe->nwriters = 0; + if (need_release && pipe->ops->release_pipe) + pipe->ops->release_pipe(inode); + cancel_delayed_work_sync(&pipe->queue_timeout); rpc_inode_setowner(inode, NULL); + RPC_I(inode)->pipe = NULL; mutex_unlock(&inode->i_mutex); } @@ -162,31 +202,39 @@ rpc_alloc_inode(struct super_block *sb) } static void -rpc_destroy_inode(struct inode *inode) +rpc_i_callback(struct rcu_head *head) { + struct inode *inode = container_of(head, struct inode, i_rcu); kmem_cache_free(rpc_inode_cachep, RPC_I(inode)); } +static void +rpc_destroy_inode(struct inode *inode) +{ + call_rcu(&inode->i_rcu, rpc_i_callback); +} + static int rpc_pipe_open(struct inode *inode, struct file *filp) { - struct rpc_inode *rpci = RPC_I(inode); + struct rpc_pipe *pipe; int first_open; int res = -ENXIO; mutex_lock(&inode->i_mutex); - if (rpci->ops == NULL) + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) goto out; - first_open = rpci->nreaders == 0 && rpci->nwriters == 0; - if (first_open && rpci->ops->open_pipe) { - res = rpci->ops->open_pipe(inode); + first_open = pipe->nreaders == 0 && pipe->nwriters == 0; + if (first_open && 
pipe->ops->open_pipe) { + res = pipe->ops->open_pipe(inode); if (res) goto out; } if (filp->f_mode & FMODE_READ) - rpci->nreaders++; + pipe->nreaders++; if (filp->f_mode & FMODE_WRITE) - rpci->nwriters++; + pipe->nwriters++; res = 0; out: mutex_unlock(&inode->i_mutex); @@ -196,38 +244,39 @@ out: static int rpc_pipe_release(struct inode *inode, struct file *filp) { - struct rpc_inode *rpci = RPC_I(inode); + struct rpc_pipe *pipe; struct rpc_pipe_msg *msg; int last_close; mutex_lock(&inode->i_mutex); - if (rpci->ops == NULL) + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) goto out; - msg = (struct rpc_pipe_msg *)filp->private_data; + msg = filp->private_data; if (msg != NULL) { - spin_lock(&inode->i_lock); + spin_lock(&pipe->lock); msg->errno = -EAGAIN; - list_del(&msg->list); - spin_unlock(&inode->i_lock); - rpci->ops->destroy_msg(msg); + list_del_init(&msg->list); + spin_unlock(&pipe->lock); + pipe->ops->destroy_msg(msg); } if (filp->f_mode & FMODE_WRITE) - rpci->nwriters --; + pipe->nwriters --; if (filp->f_mode & FMODE_READ) { - rpci->nreaders --; - if (rpci->nreaders == 0) { + pipe->nreaders --; + if (pipe->nreaders == 0) { LIST_HEAD(free_list); - spin_lock(&inode->i_lock); - list_splice_init(&rpci->pipe, &free_list); - rpci->pipelen = 0; - spin_unlock(&inode->i_lock); - rpc_purge_list(rpci, &free_list, - rpci->ops->destroy_msg, -EAGAIN); + spin_lock(&pipe->lock); + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + spin_unlock(&pipe->lock); + rpc_purge_list(&RPC_I(inode)->waitq, &free_list, + pipe->ops->destroy_msg, -EAGAIN); } } - last_close = rpci->nwriters == 0 && rpci->nreaders == 0; - if (last_close && rpci->ops->release_pipe) - rpci->ops->release_pipe(inode); + last_close = pipe->nwriters == 0 && pipe->nreaders == 0; + if (last_close && pipe->ops->release_pipe) + pipe->ops->release_pipe(inode); out: mutex_unlock(&inode->i_mutex); return 0; @@ -236,40 +285,41 @@ out: static ssize_t rpc_pipe_read(struct file *filp, char __user *buf, size_t len, loff_t *offset) { - struct inode *inode = filp->f_path.dentry->d_inode; - struct rpc_inode *rpci = RPC_I(inode); + struct inode *inode = file_inode(filp); + struct rpc_pipe *pipe; struct rpc_pipe_msg *msg; int res = 0; mutex_lock(&inode->i_mutex); - if (rpci->ops == NULL) { + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) { res = -EPIPE; goto out_unlock; } msg = filp->private_data; if (msg == NULL) { - spin_lock(&inode->i_lock); - if (!list_empty(&rpci->pipe)) { - msg = list_entry(rpci->pipe.next, + spin_lock(&pipe->lock); + if (!list_empty(&pipe->pipe)) { + msg = list_entry(pipe->pipe.next, struct rpc_pipe_msg, list); - list_move(&msg->list, &rpci->in_upcall); - rpci->pipelen -= msg->len; + list_move(&msg->list, &pipe->in_upcall); + pipe->pipelen -= msg->len; filp->private_data = msg; msg->copied = 0; } - spin_unlock(&inode->i_lock); + spin_unlock(&pipe->lock); if (msg == NULL) goto out_unlock; } /* NOTE: it is up to the callback to update msg->copied */ - res = rpci->ops->upcall(filp, msg, buf, len); + res = pipe->ops->upcall(filp, msg, buf, len); if (res < 0 || msg->len == msg->copied) { filp->private_data = NULL; - spin_lock(&inode->i_lock); - list_del(&msg->list); - spin_unlock(&inode->i_lock); - rpci->ops->destroy_msg(msg); + spin_lock(&pipe->lock); + list_del_init(&msg->list); + spin_unlock(&pipe->lock); + pipe->ops->destroy_msg(msg); } out_unlock: mutex_unlock(&inode->i_mutex); @@ -279,14 +329,13 @@ out_unlock: static ssize_t rpc_pipe_write(struct file *filp, const char __user *buf, size_t len, loff_t *offset) { - 
struct inode *inode = filp->f_path.dentry->d_inode; - struct rpc_inode *rpci = RPC_I(inode); + struct inode *inode = file_inode(filp); int res; mutex_lock(&inode->i_mutex); res = -EPIPE; - if (rpci->ops != NULL) - res = rpci->ops->downcall(filp, buf, len); + if (RPC_I(inode)->pipe != NULL) + res = RPC_I(inode)->pipe->ops->downcall(filp, buf, len); mutex_unlock(&inode->i_mutex); return res; } @@ -294,37 +343,45 @@ rpc_pipe_write(struct file *filp, const char __user *buf, size_t len, loff_t *of static unsigned int rpc_pipe_poll(struct file *filp, struct poll_table_struct *wait) { - struct rpc_inode *rpci; - unsigned int mask = 0; + struct inode *inode = file_inode(filp); + struct rpc_inode *rpci = RPC_I(inode); + unsigned int mask = POLLOUT | POLLWRNORM; - rpci = RPC_I(filp->f_path.dentry->d_inode); poll_wait(filp, &rpci->waitq, wait); - mask = POLLOUT | POLLWRNORM; - if (rpci->ops == NULL) + mutex_lock(&inode->i_mutex); + if (rpci->pipe == NULL) mask |= POLLERR | POLLHUP; - if (filp->private_data || !list_empty(&rpci->pipe)) + else if (filp->private_data || !list_empty(&rpci->pipe->pipe)) mask |= POLLIN | POLLRDNORM; + mutex_unlock(&inode->i_mutex); return mask; } -static int -rpc_pipe_ioctl(struct inode *ino, struct file *filp, - unsigned int cmd, unsigned long arg) +static long +rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) { - struct rpc_inode *rpci = RPC_I(filp->f_path.dentry->d_inode); + struct inode *inode = file_inode(filp); + struct rpc_pipe *pipe; int len; switch (cmd) { case FIONREAD: - if (rpci->ops == NULL) + mutex_lock(&inode->i_mutex); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) { + mutex_unlock(&inode->i_mutex); return -EPIPE; - len = rpci->pipelen; + } + spin_lock(&pipe->lock); + len = pipe->pipelen; if (filp->private_data) { struct rpc_pipe_msg *msg; - msg = (struct rpc_pipe_msg *)filp->private_data; + msg = filp->private_data; len += msg->len - msg->copied; } + spin_unlock(&pipe->lock); + mutex_unlock(&inode->i_mutex); return put_user(len, (int __user *)arg); default: return -EINVAL; @@ -337,7 +394,7 @@ static const struct file_operations rpc_pipe_fops = { .read = rpc_pipe_read, .write = rpc_pipe_write, .poll = rpc_pipe_poll, - .ioctl = rpc_pipe_ioctl, + .unlocked_ioctl = rpc_pipe_ioctl, .open = rpc_pipe_open, .release = rpc_pipe_release, }; @@ -347,33 +404,38 @@ rpc_show_info(struct seq_file *m, void *v) { struct rpc_clnt *clnt = m->private; - seq_printf(m, "RPC server: %s\n", clnt->cl_server); - seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_protname, + rcu_read_lock(); + seq_printf(m, "RPC server: %s\n", + rcu_dereference(clnt->cl_xprt)->servername); + seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_program->name, clnt->cl_prog, clnt->cl_vers); seq_printf(m, "address: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR)); seq_printf(m, "protocol: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PROTO)); seq_printf(m, "port: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PORT)); + rcu_read_unlock(); return 0; } static int rpc_info_open(struct inode *inode, struct file *file) { - struct rpc_clnt *clnt; + struct rpc_clnt *clnt = NULL; int ret = single_open(file, rpc_show_info, NULL); if (!ret) { struct seq_file *m = file->private_data; - mutex_lock(&inode->i_mutex); - clnt = RPC_I(inode)->private; - if (clnt) { - kref_get(&clnt->cl_kref); + + spin_lock(&file->f_path.dentry->d_lock); + if (!d_unhashed(file->f_path.dentry)) + clnt = RPC_I(inode)->private; + if (clnt != NULL && atomic_inc_not_zero(&clnt->cl_count)) { + 
spin_unlock(&file->f_path.dentry->d_lock); m->private = clnt; } else { + spin_unlock(&file->f_path.dentry->d_lock); single_release(inode, file); ret = -EINVAL; } - mutex_unlock(&inode->i_mutex); } return ret; } @@ -407,47 +469,22 @@ struct rpc_filelist { umode_t mode; }; -struct vfsmount *rpc_get_mount(void) -{ - int err; - - err = simple_pin_fs(&rpc_pipe_fs_type, &rpc_mount, &rpc_mount_count); - if (err != 0) - return ERR_PTR(err); - return rpc_mount; -} -EXPORT_SYMBOL_GPL(rpc_get_mount); - -void rpc_put_mount(void) -{ - simple_release_fs(&rpc_mount, &rpc_mount_count); -} -EXPORT_SYMBOL_GPL(rpc_put_mount); - -static int rpc_delete_dentry(struct dentry *dentry) -{ - return 1; -} - -static const struct dentry_operations rpc_dentry_operations = { - .d_delete = rpc_delete_dentry, -}; - static struct inode * rpc_get_inode(struct super_block *sb, umode_t mode) { struct inode *inode = new_inode(sb); if (!inode) return NULL; + inode->i_ino = get_next_ino(); inode->i_mode = mode; inode->i_atime = inode->i_mtime = inode->i_ctime = CURRENT_TIME; - switch(mode & S_IFMT) { - case S_IFDIR: - inode->i_fop = &simple_dir_operations; - inode->i_op = &simple_dir_inode_operations; - inc_nlink(inode); - default: - break; + switch (mode & S_IFMT) { + case S_IFDIR: + inode->i_fop = &simple_dir_operations; + inode->i_op = &simple_dir_inode_operations; + inc_nlink(inode); + default: + break; } return inode; } @@ -459,7 +496,7 @@ static int __rpc_create_common(struct inode *dir, struct dentry *dentry, { struct inode *inode; - BUG_ON(!d_unhashed(dentry)); + d_drop(dentry); inode = rpc_get_inode(dir->i_sb, mode); if (!inode) goto out_err; @@ -471,8 +508,8 @@ static int __rpc_create_common(struct inode *dir, struct dentry *dentry, d_add(dentry, inode); return 0; out_err: - printk(KERN_WARNING "%s: %s failed to allocate inode for dentry %s\n", - __FILE__, __func__, dentry->d_name.name); + printk(KERN_WARNING "%s: %s failed to allocate inode for dentry %pd\n", + __FILE__, __func__, dentry); dput(dentry); return -ENOMEM; } @@ -506,12 +543,47 @@ static int __rpc_mkdir(struct inode *dir, struct dentry *dentry, return 0; } -static int __rpc_mkpipe(struct inode *dir, struct dentry *dentry, - umode_t mode, - const struct file_operations *i_fop, - void *private, - const struct rpc_pipe_ops *ops, - int flags) +static void +init_pipe(struct rpc_pipe *pipe) +{ + pipe->nreaders = 0; + pipe->nwriters = 0; + INIT_LIST_HEAD(&pipe->in_upcall); + INIT_LIST_HEAD(&pipe->in_downcall); + INIT_LIST_HEAD(&pipe->pipe); + pipe->pipelen = 0; + INIT_DELAYED_WORK(&pipe->queue_timeout, + rpc_timeout_upcall_queue); + pipe->ops = NULL; + spin_lock_init(&pipe->lock); + pipe->dentry = NULL; +} + +void rpc_destroy_pipe_data(struct rpc_pipe *pipe) +{ + kfree(pipe); +} +EXPORT_SYMBOL_GPL(rpc_destroy_pipe_data); + +struct rpc_pipe *rpc_mkpipe_data(const struct rpc_pipe_ops *ops, int flags) +{ + struct rpc_pipe *pipe; + + pipe = kzalloc(sizeof(struct rpc_pipe), GFP_KERNEL); + if (!pipe) + return ERR_PTR(-ENOMEM); + init_pipe(pipe); + pipe->ops = ops; + pipe->flags = flags; + return pipe; +} +EXPORT_SYMBOL_GPL(rpc_mkpipe_data); + +static int __rpc_mkpipe_dentry(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private, + struct rpc_pipe *pipe) { struct rpc_inode *rpci; int err; @@ -520,10 +592,8 @@ static int __rpc_mkpipe(struct inode *dir, struct dentry *dentry, if (err) return err; rpci = RPC_I(dentry->d_inode); - rpci->nkern_readwriters = 1; rpci->private = private; - rpci->flags = flags; - 
rpci->ops = ops; + rpci->pipe = pipe; fsnotify_create(dir, dentry); return 0; } @@ -539,6 +609,22 @@ static int __rpc_rmdir(struct inode *dir, struct dentry *dentry) return ret; } +int rpc_rmdir(struct dentry *dentry) +{ + struct dentry *parent; + struct inode *dir; + int error; + + parent = dget_parent(dentry); + dir = parent->d_inode; + mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); + error = __rpc_rmdir(dir, dentry); + mutex_unlock(&dir->i_mutex); + dput(parent); + return error; +} +EXPORT_SYMBOL_GPL(rpc_rmdir); + static int __rpc_unlink(struct inode *dir, struct dentry *dentry) { int ret; @@ -553,40 +639,21 @@ static int __rpc_unlink(struct inode *dir, struct dentry *dentry) static int __rpc_rmpipe(struct inode *dir, struct dentry *dentry) { struct inode *inode = dentry->d_inode; - struct rpc_inode *rpci = RPC_I(inode); - rpci->nkern_readwriters--; - if (rpci->nkern_readwriters != 0) - return 0; rpc_close_pipes(inode); return __rpc_unlink(dir, dentry); } -static struct dentry *__rpc_lookup_create(struct dentry *parent, - struct qstr *name) +static struct dentry *__rpc_lookup_create_exclusive(struct dentry *parent, + const char *name) { - struct dentry *dentry; - - dentry = d_lookup(parent, name); + struct qstr q = QSTR_INIT(name, strlen(name)); + struct dentry *dentry = d_hash_and_lookup(parent, &q); if (!dentry) { - dentry = d_alloc(parent, name); - if (!dentry) { - dentry = ERR_PTR(-ENOMEM); - goto out_err; - } + dentry = d_alloc(parent, &q); + if (!dentry) + return ERR_PTR(-ENOMEM); } - if (!dentry->d_inode) - dentry->d_op = &rpc_dentry_operations; -out_err: - return dentry; -} - -static struct dentry *__rpc_lookup_create_exclusive(struct dentry *parent, - struct qstr *name) -{ - struct dentry *dentry; - - dentry = __rpc_lookup_create(parent, name); if (dentry->d_inode == NULL) return dentry; dput(dentry); @@ -608,8 +675,7 @@ static void __rpc_depopulate(struct dentry *parent, for (i = start; i < eof; i++) { name.name = files[i].name; name.len = strlen(files[i].name); - name.hash = full_name_hash(name.name, name.len); - dentry = d_lookup(parent, &name); + dentry = d_hash_and_lookup(parent, &name); if (dentry == NULL) continue; @@ -651,12 +717,7 @@ static int rpc_populate(struct dentry *parent, mutex_lock(&dir->i_mutex); for (i = start; i < eof; i++) { - struct qstr q; - - q.name = files[i].name; - q.len = strlen(files[i].name); - q.hash = full_name_hash(q.name, q.len); - dentry = __rpc_lookup_create_exclusive(parent, &q); + dentry = __rpc_lookup_create_exclusive(parent, files[i].name); err = PTR_ERR(dentry); if (IS_ERR(dentry)) goto out_bad; @@ -683,13 +744,13 @@ static int rpc_populate(struct dentry *parent, out_bad: __rpc_depopulate(parent, files, start, eof); mutex_unlock(&dir->i_mutex); - printk(KERN_WARNING "%s: %s failed to populate directory %s\n", - __FILE__, __func__, parent->d_name.name); + printk(KERN_WARNING "%s: %s failed to populate directory %pd\n", + __FILE__, __func__, parent); return err; } static struct dentry *rpc_mkdir_populate(struct dentry *parent, - struct qstr *name, umode_t mode, void *private, + const char *name, umode_t mode, void *private, int (*populate)(struct dentry *, void *), void *args_populate) { struct dentry *dentry; @@ -741,9 +802,7 @@ static int rpc_rmdir_depopulate(struct dentry *dentry, * @parent: dentry of directory to create new "pipe" in * @name: name of pipe * @private: private data to associate with the pipe, for the caller's use - * @ops: operations defining the behavior of the pipe: upcall, downcall, - * release_pipe, open_pipe, 
and destroy_msg. - * @flags: rpc_inode flags + * @pipe: &rpc_pipe containing input parameters * * Data is made available for userspace to read by calls to * rpc_queue_upcall(). The actual reads will result in calls to @@ -754,46 +813,27 @@ static int rpc_rmdir_depopulate(struct dentry *dentry, * responses to upcalls. They will result in calls to @msg->downcall. * * The @private argument passed here will be available to all these methods - * from the file pointer, via RPC_I(file->f_dentry->d_inode)->private. + * from the file pointer, via RPC_I(file_inode(file))->private. */ -struct dentry *rpc_mkpipe(struct dentry *parent, const char *name, - void *private, const struct rpc_pipe_ops *ops, - int flags) +struct dentry *rpc_mkpipe_dentry(struct dentry *parent, const char *name, + void *private, struct rpc_pipe *pipe) { struct dentry *dentry; struct inode *dir = parent->d_inode; umode_t umode = S_IFIFO | S_IRUSR | S_IWUSR; - struct qstr q; int err; - if (ops->upcall == NULL) + if (pipe->ops->upcall == NULL) umode &= ~S_IRUGO; - if (ops->downcall == NULL) + if (pipe->ops->downcall == NULL) umode &= ~S_IWUGO; - q.name = name; - q.len = strlen(name); - q.hash = full_name_hash(q.name, q.len), - mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); - dentry = __rpc_lookup_create(parent, &q); + dentry = __rpc_lookup_create_exclusive(parent, name); if (IS_ERR(dentry)) goto out; - if (dentry->d_inode) { - struct rpc_inode *rpci = RPC_I(dentry->d_inode); - if (rpci->private != private || - rpci->ops != ops || - rpci->flags != flags) { - dput (dentry); - err = -EBUSY; - goto out_err; - } - rpci->nkern_readwriters++; - goto out; - } - - err = __rpc_mkpipe(dir, dentry, umode, &rpc_pipe_fops, - private, ops, flags); + err = __rpc_mkpipe_dentry(dir, dentry, umode, &rpc_pipe_fops, + private, pipe); if (err) goto out_err; out: @@ -801,12 +841,12 @@ out: return dentry; out_err: dentry = ERR_PTR(err); - printk(KERN_WARNING "%s: %s() failed to create pipe %s/%s (errno = %d)\n", - __FILE__, __func__, parent->d_name.name, name, + printk(KERN_WARNING "%s: %s() failed to create pipe %pd/%s (errno = %d)\n", + __FILE__, __func__, parent, name, err); goto out; } -EXPORT_SYMBOL_GPL(rpc_mkpipe); +EXPORT_SYMBOL_GPL(rpc_mkpipe_dentry); /** * rpc_unlink - remove a pipe @@ -833,6 +873,159 @@ rpc_unlink(struct dentry *dentry) } EXPORT_SYMBOL_GPL(rpc_unlink); +/** + * rpc_init_pipe_dir_head - initialise a struct rpc_pipe_dir_head + * @pdh: pointer to struct rpc_pipe_dir_head + */ +void rpc_init_pipe_dir_head(struct rpc_pipe_dir_head *pdh) +{ + INIT_LIST_HEAD(&pdh->pdh_entries); + pdh->pdh_dentry = NULL; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_head); + +/** + * rpc_init_pipe_dir_object - initialise a struct rpc_pipe_dir_object + * @pdo: pointer to struct rpc_pipe_dir_object + * @pdo_ops: pointer to const struct rpc_pipe_dir_object_ops + * @pdo_data: pointer to caller-defined data + */ +void rpc_init_pipe_dir_object(struct rpc_pipe_dir_object *pdo, + const struct rpc_pipe_dir_object_ops *pdo_ops, + void *pdo_data) +{ + INIT_LIST_HEAD(&pdo->pdo_head); + pdo->pdo_ops = pdo_ops; + pdo->pdo_data = pdo_data; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_object); + +static int +rpc_add_pipe_dir_object_locked(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (pdh->pdh_dentry) + ret = pdo->pdo_ops->create(pdh->pdh_dentry, pdo); + if (ret == 0) + list_add_tail(&pdo->pdo_head, &pdh->pdh_entries); + return ret; +} + +static void +rpc_remove_pipe_dir_object_locked(struct net *net, + struct 
rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + if (pdh->pdh_dentry) + pdo->pdo_ops->destroy(pdh->pdh_dentry, pdo); + list_del_init(&pdo->pdo_head); +} + +/** + * rpc_add_pipe_dir_object - associate a rpc_pipe_dir_object to a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +int +rpc_add_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + ret = rpc_add_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } + return ret; +} +EXPORT_SYMBOL_GPL(rpc_add_pipe_dir_object); + +/** + * rpc_remove_pipe_dir_object - remove a rpc_pipe_dir_object from a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +void +rpc_remove_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + if (!list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + rpc_remove_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } +} +EXPORT_SYMBOL_GPL(rpc_remove_pipe_dir_object); + +/** + * rpc_find_or_alloc_pipe_dir_object + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @match: match struct rpc_pipe_dir_object to data + * @alloc: allocate a new struct rpc_pipe_dir_object + * @data: user defined data for match() and alloc() + * + */ +struct rpc_pipe_dir_object * +rpc_find_or_alloc_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + int (*match)(struct rpc_pipe_dir_object *, void *), + struct rpc_pipe_dir_object *(*alloc)(void *), + void *data) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe_dir_object *pdo; + + mutex_lock(&sn->pipefs_sb_lock); + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) { + if (!match(pdo, data)) + continue; + goto out; + } + pdo = alloc(data); + if (!pdo) + goto out; + rpc_add_pipe_dir_object_locked(net, pdh, pdo); +out: + mutex_unlock(&sn->pipefs_sb_lock); + return pdo; +} +EXPORT_SYMBOL_GPL(rpc_find_or_alloc_pipe_dir_object); + +static void +rpc_create_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->create(dir, pdo); +} + +static void +rpc_destroy_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->destroy(dir, pdo); +} + enum { RPCAUTH_info, RPCAUTH_EOF @@ -860,8 +1053,8 @@ static void rpc_clntdir_depopulate(struct dentry *dentry) /** * rpc_create_client_dir - Create a new rpc_client directory in rpc_pipefs - * @dentry: dentry from the rpc_pipefs root to the new directory - * @name: &struct qstr for the name + * @dentry: the parent of new directory + * @name: the name of new directory * @rpc_client: rpc client to associate with this directory * * This creates a directory at the given @path associated with @@ -870,19 +1063,32 @@ static void rpc_clntdir_depopulate(struct dentry *dentry) * later be created using rpc_mkpipe(). 
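Taken together with rpc_mkpipe_data() above, creating a per-client directory and a pipe inside it now takes three calls, using the signatures introduced just below. A sketch under assumed names (demo_pipe_ops, parent, clnt and private_data are placeholders); teardown is the mirror image via rpc_unlink() and rpc_remove_client_dir():

        struct dentry *dir, *pdentry;
        struct rpc_pipe *pipe;

        pipe = rpc_mkpipe_data(&demo_pipe_ops, RPC_PIPE_WAIT_FOR_OPEN);
        if (IS_ERR(pipe))
                return PTR_ERR(pipe);

        dir = rpc_create_client_dir(parent, "demo_clnt", clnt);
        if (IS_ERR(dir)) {
                rpc_destroy_pipe_data(pipe);
                return PTR_ERR(dir);
        }

        pdentry = rpc_mkpipe_dentry(dir, "demo", private_data, pipe);
        if (IS_ERR(pdentry)) {
                rpc_remove_client_dir(clnt);
                rpc_destroy_pipe_data(pipe);
                return PTR_ERR(pdentry);
        }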
*/ struct dentry *rpc_create_client_dir(struct dentry *dentry, - struct qstr *name, + const char *name, struct rpc_clnt *rpc_client) { - return rpc_mkdir_populate(dentry, name, S_IRUGO | S_IXUGO, NULL, + struct dentry *ret; + + ret = rpc_mkdir_populate(dentry, name, S_IRUGO | S_IXUGO, NULL, rpc_clntdir_populate, rpc_client); + if (!IS_ERR(ret)) { + rpc_client->cl_pipedir_objects.pdh_dentry = ret; + rpc_create_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + } + return ret; } /** * rpc_remove_client_dir - Remove a directory created with rpc_create_client_dir() - * @dentry: directory to remove + * @rpc_client: rpc_client for the pipe */ -int rpc_remove_client_dir(struct dentry *dentry) +int rpc_remove_client_dir(struct rpc_clnt *rpc_client) { + struct dentry *dentry = rpc_client->cl_pipedir_objects.pdh_dentry; + + if (dentry == NULL) + return 0; + rpc_destroy_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + rpc_client->cl_pipedir_objects.pdh_dentry = NULL; return rpc_rmdir_depopulate(dentry, rpc_clntdir_depopulate); } @@ -916,8 +1122,8 @@ static void rpc_cachedir_depopulate(struct dentry *dentry) rpc_depopulate(dentry, cache_pipefs_files, 0, 3); } -struct dentry *rpc_create_cache_dir(struct dentry *parent, struct qstr *name, - mode_t umode, struct cache_detail *cd) +struct dentry *rpc_create_cache_dir(struct dentry *parent, const char *name, + umode_t umode, struct cache_detail *cd) { return rpc_mkdir_populate(parent, name, umode, NULL, rpc_cachedir_populate, cd); @@ -950,6 +1156,8 @@ enum { RPCAUTH_statd, RPCAUTH_nfsd4_cb, RPCAUTH_cache, + RPCAUTH_nfsd, + RPCAUTH_gssd, RPCAUTH_RootEOF }; @@ -982,46 +1190,297 @@ static const struct rpc_filelist files[] = { .name = "cache", .mode = S_IFDIR | S_IRUGO | S_IXUGO, }, + [RPCAUTH_nfsd] = { + .name = "nfsd", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, + [RPCAUTH_gssd] = { + .name = "gssd", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, +}; + +/* + * This call can be used only in RPC pipefs mount notification hooks. + */ +struct dentry *rpc_d_lookup_sb(const struct super_block *sb, + const unsigned char *dir_name) +{ + struct qstr dir = QSTR_INIT(dir_name, strlen(dir_name)); + return d_hash_and_lookup(sb->s_root, &dir); +} +EXPORT_SYMBOL_GPL(rpc_d_lookup_sb); + +int rpc_pipefs_init_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + sn->gssd_dummy = rpc_mkpipe_data(&gssd_dummy_pipe_ops, 0); + if (IS_ERR(sn->gssd_dummy)) + return PTR_ERR(sn->gssd_dummy); + + mutex_init(&sn->pipefs_sb_lock); + sn->pipe_version = -1; + return 0; +} + +void rpc_pipefs_exit_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + rpc_destroy_pipe_data(sn->gssd_dummy); +} + +/* + * This call is used for per-network-namespace operations. + * Note: the function returns with pipefs_sb_lock held if the superblock was + * found. That lock must be released by rpc_put_sb_net() once all operations + * are complete.
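The intended calling pattern for the rpc_get_sb_net()/rpc_put_sb_net() pair, sketched (the real users are the pipefs event paths in the sunrpc modules):

        struct super_block *sb;

        sb = rpc_get_sb_net(net);       /* pipefs_sb_lock now held */
        if (sb) {
                /* ... create or destroy dentries under sb->s_root ... */
                rpc_put_sb_net(net);    /* drops pipefs_sb_lock */
        }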
+ */ +struct super_block *rpc_get_sb_net(const struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + if (sn->pipefs_sb) + return sn->pipefs_sb; + mutex_unlock(&sn->pipefs_sb_lock); + return NULL; +} +EXPORT_SYMBOL_GPL(rpc_get_sb_net); + +void rpc_put_sb_net(const struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + WARN_ON(sn->pipefs_sb == NULL); + mutex_unlock(&sn->pipefs_sb_lock); +} +EXPORT_SYMBOL_GPL(rpc_put_sb_net); + +static const struct rpc_filelist gssd_dummy_clnt_dir[] = { + [0] = { + .name = "clntXX", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, }; +static ssize_t +dummy_downcall(struct file *filp, const char __user *src, size_t len) +{ + return -EINVAL; +} + +static const struct rpc_pipe_ops gssd_dummy_pipe_ops = { + .upcall = rpc_pipe_generic_upcall, + .downcall = dummy_downcall, +}; + +/* + * Here we present a bogus "info" file to keep rpc.gssd happy. We don't expect + * that it will ever use this info to handle an upcall, but rpc.gssd expects + * that this file will be there and have a certain format. + */ +static int +rpc_show_dummy_info(struct seq_file *m, void *v) +{ + seq_printf(m, "RPC server: %s\n", utsname()->nodename); + seq_printf(m, "service: foo (1) version 0\n"); + seq_printf(m, "address: 127.0.0.1\n"); + seq_printf(m, "protocol: tcp\n"); + seq_printf(m, "port: 0\n"); + return 0; +} + +static int +rpc_dummy_info_open(struct inode *inode, struct file *file) +{ + return single_open(file, rpc_show_dummy_info, NULL); +} + +static const struct file_operations rpc_dummy_info_operations = { + .owner = THIS_MODULE, + .open = rpc_dummy_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = single_release, +}; + +static const struct rpc_filelist gssd_dummy_info_file[] = { + [0] = { + .name = "info", + .i_fop = &rpc_dummy_info_operations, + .mode = S_IFREG | S_IRUSR, + }, +}; + +/** + * rpc_gssd_dummy_populate - create a dummy gssd pipe + * @root: root of the rpc_pipefs filesystem + * @pipe_data: pipe data created when netns is initialized + * + * Create a dummy set of directories and a pipe that gssd can hold open to + * indicate that it is up and running. 
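For reference, the info file rendered by rpc_show_dummy_info() reads as follows, with the nodename taken from utsname():

        RPC server: <nodename>
        service: foo (1) version 0
        address: 127.0.0.1
        protocol: tcp
        port: 0

And since rpc.gssd holds the dummy pipe open, gssd_running() (defined a little further on) reduces to a reader/writer count check. A caller might use it to fail fast rather than wait out an upcall timeout; a sketch, with illustrative warning text:

        if (!gssd_running(net))
                pr_warn_once("RPC: rpc.gssd appears not to be running\n");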
+ */ +static struct dentry * +rpc_gssd_dummy_populate(struct dentry *root, struct rpc_pipe *pipe_data) +{ + int ret = 0; + struct dentry *gssd_dentry; + struct dentry *clnt_dentry = NULL; + struct dentry *pipe_dentry = NULL; + struct qstr q = QSTR_INIT(files[RPCAUTH_gssd].name, + strlen(files[RPCAUTH_gssd].name)); + + /* We should never get this far if "gssd" doesn't exist */ + gssd_dentry = d_hash_and_lookup(root, &q); + if (!gssd_dentry) + return ERR_PTR(-ENOENT); + + ret = rpc_populate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1, NULL); + if (ret) { + pipe_dentry = ERR_PTR(ret); + goto out; + } + + q.name = gssd_dummy_clnt_dir[0].name; + q.len = strlen(gssd_dummy_clnt_dir[0].name); + clnt_dentry = d_hash_and_lookup(gssd_dentry, &q); + if (!clnt_dentry) { + pipe_dentry = ERR_PTR(-ENOENT); + goto out; + } + + ret = rpc_populate(clnt_dentry, gssd_dummy_info_file, 0, 1, NULL); + if (ret) { + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + pipe_dentry = ERR_PTR(ret); + goto out; + } + + pipe_dentry = rpc_mkpipe_dentry(clnt_dentry, "gssd", NULL, pipe_data); + if (IS_ERR(pipe_dentry)) { + __rpc_depopulate(clnt_dentry, gssd_dummy_info_file, 0, 1); + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + } +out: + dput(clnt_dentry); + dput(gssd_dentry); + return pipe_dentry; +} + +static void +rpc_gssd_dummy_depopulate(struct dentry *pipe_dentry) +{ + struct dentry *clnt_dir = pipe_dentry->d_parent; + struct dentry *gssd_dir = clnt_dir->d_parent; + + __rpc_rmpipe(clnt_dir->d_inode, pipe_dentry); + __rpc_depopulate(clnt_dir, gssd_dummy_info_file, 0, 1); + __rpc_depopulate(gssd_dir, gssd_dummy_clnt_dir, 0, 1); + dput(pipe_dentry); +} + static int rpc_fill_super(struct super_block *sb, void *data, int silent) { struct inode *inode; - struct dentry *root; + struct dentry *root, *gssd_dentry; + struct net *net = data; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int err; sb->s_blocksize = PAGE_CACHE_SIZE; sb->s_blocksize_bits = PAGE_CACHE_SHIFT; sb->s_magic = RPCAUTH_GSSMAGIC; sb->s_op = &s_ops; + sb->s_d_op = &simple_dentry_operations; sb->s_time_gran = 1; - inode = rpc_get_inode(sb, S_IFDIR | 0755); - if (!inode) + inode = rpc_get_inode(sb, S_IFDIR | S_IRUGO | S_IXUGO); + sb->s_root = root = d_make_root(inode); + if (!root) return -ENOMEM; - sb->s_root = root = d_alloc_root(inode); - if (!root) { - iput(inode); - return -ENOMEM; - } if (rpc_populate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF, NULL)) return -ENOMEM; + + gssd_dentry = rpc_gssd_dummy_populate(root, sn->gssd_dummy); + if (IS_ERR(gssd_dentry)) { + __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF); + return PTR_ERR(gssd_dentry); + } + + dprintk("RPC: sending pipefs MOUNT notification for net %p%s\n", + net, NET_NAME(net)); + mutex_lock(&sn->pipefs_sb_lock); + sn->pipefs_sb = sb; + err = blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_MOUNT, + sb); + if (err) + goto err_depopulate; + sb->s_fs_info = get_net(net); + mutex_unlock(&sn->pipefs_sb_lock); return 0; + +err_depopulate: + rpc_gssd_dummy_depopulate(gssd_dentry); + blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_UMOUNT, + sb); + sn->pipefs_sb = NULL; + __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF); + mutex_unlock(&sn->pipefs_sb_lock); + return err; } -static int -rpc_get_sb(struct file_system_type *fs_type, - int flags, const char *dev_name, void *data, struct vfsmount *mnt) +bool +gssd_running(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct 
rpc_pipe *pipe = sn->gssd_dummy; + + return pipe->nreaders || pipe->nwriters; +} +EXPORT_SYMBOL_GPL(gssd_running); + +static struct dentry * +rpc_mount(struct file_system_type *fs_type, + int flags, const char *dev_name, void *data) +{ + return mount_ns(fs_type, flags, current->nsproxy->net_ns, rpc_fill_super); +} + +static void rpc_kill_sb(struct super_block *sb) { - return get_sb_single(fs_type, flags, data, rpc_fill_super, mnt); + struct net *net = sb->s_fs_info; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + if (sn->pipefs_sb != sb) { + mutex_unlock(&sn->pipefs_sb_lock); + goto out; + } + sn->pipefs_sb = NULL; + dprintk("RPC: sending pipefs UMOUNT notification for net %p%s\n", + net, NET_NAME(net)); + blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_UMOUNT, + sb); + mutex_unlock(&sn->pipefs_sb_lock); + put_net(net); +out: + kill_litter_super(sb); } static struct file_system_type rpc_pipe_fs_type = { .owner = THIS_MODULE, .name = "rpc_pipefs", - .get_sb = rpc_get_sb, - .kill_sb = kill_litter_super, + .mount = rpc_mount, + .kill_sb = rpc_kill_sb, }; +MODULE_ALIAS_FS("rpc_pipefs"); +MODULE_ALIAS("rpc_pipefs"); static void init_once(void *foo) @@ -1030,16 +1489,8 @@ init_once(void *foo) inode_init_once(&rpci->vfs_inode); rpci->private = NULL; - rpci->nreaders = 0; - rpci->nwriters = 0; - INIT_LIST_HEAD(&rpci->in_upcall); - INIT_LIST_HEAD(&rpci->in_downcall); - INIT_LIST_HEAD(&rpci->pipe); - rpci->pipelen = 0; + rpci->pipe = NULL; init_waitqueue_head(&rpci->waitq); - INIT_DELAYED_WORK(&rpci->queue_timeout, - rpc_timeout_upcall_queue); - rpci->ops = NULL; } int register_rpc_pipefs(void) @@ -1053,17 +1504,24 @@ int register_rpc_pipefs(void) init_once); if (!rpc_inode_cachep) return -ENOMEM; + err = rpc_clients_notifier_register(); + if (err) + goto err_notifier; err = register_filesystem(&rpc_pipe_fs_type); - if (err) { - kmem_cache_destroy(rpc_inode_cachep); - return err; - } - + if (err) + goto err_register; return 0; + +err_register: + rpc_clients_notifier_unregister(); +err_notifier: + kmem_cache_destroy(rpc_inode_cachep); + return err; } void unregister_rpc_pipefs(void) { + rpc_clients_notifier_unregister(); kmem_cache_destroy(rpc_inode_cachep); unregister_filesystem(&rpc_pipe_fs_type); } diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c index 3e3772d8eb9..1891a1022c1 100644 --- a/net/sunrpc/rpcb_clnt.c +++ b/net/sunrpc/rpcb_clnt.c @@ -16,21 +16,28 @@ #include <linux/types.h> #include <linux/socket.h> +#include <linux/un.h> #include <linux/in.h> #include <linux/in6.h> #include <linux/kernel.h> #include <linux/errno.h> #include <linux/mutex.h> +#include <linux/slab.h> #include <net/ipv6.h> #include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/addr.h> #include <linux/sunrpc/sched.h> #include <linux/sunrpc/xprtsock.h> +#include "netns.h" + #ifdef RPC_DEBUG # define RPCDBG_FACILITY RPCDBG_BIND #endif +#define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock" + #define RPCBIND_PROGRAM (100000u) #define RPCBIND_PORT (111u) @@ -56,10 +63,6 @@ enum { RPCBPROC_GETSTAT, }; -#define RPCB_HIGHPROC_2 RPCBPROC_CALLIT -#define RPCB_HIGHPROC_3 RPCBPROC_TADDR2UADDR -#define RPCB_HIGHPROC_4 RPCBPROC_GETSTAT - /* * r_owner * @@ -109,10 +112,7 @@ enum { static void rpcb_getport_done(struct rpc_task *, void *); static void rpcb_map_release(void *data); -static struct rpc_program rpcb_program; - -static struct rpc_clnt * rpcb_local_clnt; -static struct rpc_clnt * rpcb_local_clnt4; +static const struct rpc_program rpcb_program; 
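Stepping back to the pipefs hunk above: rpc_fill_super() and rpc_kill_sb() broadcast RPC_PIPEFS_MOUNT/RPC_PIPEFS_UMOUNT on rpc_pipefs_notifier_list, which is how per-net consumers learn when dentries may be created under a new superblock. A sketch of a subscriber, assuming registration goes through rpc_pipefs_notifier_register() and that the my_* names are hypothetical:

        static int my_pipefs_event(struct notifier_block *nb,
                                   unsigned long event, void *ptr)
        {
                struct super_block *sb = ptr;

                switch (event) {
                case RPC_PIPEFS_MOUNT:
                        /* safe to create dentries under sb->s_root here */
                        break;
                case RPC_PIPEFS_UMOUNT:
                        /* tear them down */
                        break;
                }
                return 0;
        }

        static struct notifier_block my_nb = {
                .notifier_call  = my_pipefs_event,
        };
        /* ... rpc_pipefs_notifier_register(&my_nb); ... */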
struct rpcbind_args { struct rpc_xprt * r_xprt; @@ -137,8 +137,8 @@ struct rpcb_info { struct rpc_procinfo * rpc_proc; }; -static struct rpcb_info rpcb_next_version[]; -static struct rpcb_info rpcb_next_version6[]; +static const struct rpcb_info rpcb_next_version[]; +static const struct rpcb_info rpcb_next_version6[]; static const struct rpc_call_ops rpcb_getport_ops = { .rpc_call_done = rpcb_getport_done, @@ -161,21 +161,137 @@ static void rpcb_map_release(void *data) kfree(map); } -static const struct sockaddr_in rpcb_inaddr_loopback = { - .sin_family = AF_INET, - .sin_addr.s_addr = htonl(INADDR_LOOPBACK), - .sin_port = htons(RPCBIND_PORT), -}; +static int rpcb_get_local(struct net *net) +{ + int cnt; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpcb_clnt_lock); + if (sn->rpcb_users) + sn->rpcb_users++; + cnt = sn->rpcb_users; + spin_unlock(&sn->rpcb_clnt_lock); + + return cnt; +} -static DEFINE_MUTEX(rpcb_create_local_mutex); +void rpcb_put_local(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt = sn->rpcb_local_clnt; + struct rpc_clnt *clnt4 = sn->rpcb_local_clnt4; + int shutdown = 0; + + spin_lock(&sn->rpcb_clnt_lock); + if (sn->rpcb_users) { + if (--sn->rpcb_users == 0) { + sn->rpcb_local_clnt = NULL; + sn->rpcb_local_clnt4 = NULL; + } + shutdown = !sn->rpcb_users; + } + spin_unlock(&sn->rpcb_clnt_lock); + + if (shutdown) { + /* + * cleanup_rpcb_clnt - remove xprtsock's sysctls, unregister + */ + if (clnt4) + rpc_shutdown_client(clnt4); + if (clnt) + rpc_shutdown_client(clnt); + } +} + +static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt, + struct rpc_clnt *clnt4, + bool is_af_local) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + /* Protected by rpcb_create_local_mutex */ + sn->rpcb_local_clnt = clnt; + sn->rpcb_local_clnt4 = clnt4; + sn->rpcb_is_af_local = is_af_local ? 1 : 0; + smp_wmb(); + sn->rpcb_users = 1; + dprintk("RPC: created new rpcb local clients (rpcb_local_clnt: " + "%p, rpcb_local_clnt4: %p) for net %p%s\n", + sn->rpcb_local_clnt, sn->rpcb_local_clnt4, + net, (net == &init_net) ? " (init_net)" : ""); +} /* * Returns zero on success, otherwise a negative errno value * is returned. */ -static int rpcb_create_local(void) +static int rpcb_create_local_unix(struct net *net) { + static const struct sockaddr_un rpcb_localaddr_rpcbind = { + .sun_family = AF_LOCAL, + .sun_path = RPCBIND_SOCK_PATHNAME, + }; struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_LOCAL, + .address = (struct sockaddr *)&rpcb_localaddr_rpcbind, + .addrsize = sizeof(rpcb_localaddr_rpcbind), + .servername = "localhost", + .program = &rpcb_program, + .version = RPCBVERS_2, + .authflavor = RPC_AUTH_NULL, + /* + * We turn off the idle timeout to prevent the kernel + * from automatically disconnecting the socket. + * Otherwise, we'd have to cache the mount namespace + * of the caller and somehow pass that to the socket + * reconnect code. + */ + .flags = RPC_CLNT_CREATE_NO_IDLE_TIMEOUT, + }; + struct rpc_clnt *clnt, *clnt4; + int result = 0; + + /* + * Because we requested an RPC PING at transport creation time, + * this works only if the user space portmapper is rpcbind, and + * it's listening on AF_LOCAL on the named socket. 
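rpcb_get_local()/rpcb_put_local() above make the local rpcbind clients per-namespace and refcounted. The lifecycle, as a service startup path might drive it (a sketch; rpcb_create_local() is defined further down, and my_service_start() is hypothetical):

        static int my_service_start(struct net *net)
        {
                int err;

                err = rpcb_create_local(net);   /* first caller creates both clients */
                if (err)
                        return err;

                /* ... rpcb_register()/rpcb_v4_register() traffic ... */

                rpcb_put_local(net);            /* last put shuts both clients down */
                return 0;
        }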
+ */ + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + dprintk("RPC: failed to create AF_LOCAL rpcbind " + "client (errno %ld).\n", PTR_ERR(clnt)); + result = PTR_ERR(clnt); + goto out; + } + + clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); + if (IS_ERR(clnt4)) { + dprintk("RPC: failed to bind second program to " + "rpcbind v4 client (errno %ld).\n", + PTR_ERR(clnt4)); + clnt4 = NULL; + } + + rpcb_set_local(net, clnt, clnt4, true); + +out: + return result; +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. + */ +static int rpcb_create_local_net(struct net *net) +{ + static const struct sockaddr_in rpcb_inaddr_loopback = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_LOOPBACK), + .sin_port = htons(RPCBIND_PORT), + }; + struct rpc_create_args args = { + .net = net, .protocol = XPRT_TRANSPORT_TCP, .address = (struct sockaddr *)&rpcb_inaddr_loopback, .addrsize = sizeof(rpcb_inaddr_loopback), @@ -188,18 +304,11 @@ static int rpcb_create_local(void) struct rpc_clnt *clnt, *clnt4; int result = 0; - if (rpcb_local_clnt) - return result; - - mutex_lock(&rpcb_create_local_mutex); - if (rpcb_local_clnt) - goto out; - clnt = rpc_create(&args); if (IS_ERR(clnt)) { dprintk("RPC: failed to create local rpcbind " "client (errno %ld).\n", PTR_ERR(clnt)); - result = -PTR_ERR(clnt); + result = PTR_ERR(clnt); goto out; } @@ -210,23 +319,48 @@ static int rpcb_create_local(void) */ clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); if (IS_ERR(clnt4)) { - dprintk("RPC: failed to create local rpcbind v4 " - "cleint (errno %ld).\n", PTR_ERR(clnt4)); + dprintk("RPC: failed to bind second program to " + "rpcbind v4 client (errno %ld).\n", + PTR_ERR(clnt4)); clnt4 = NULL; } - rpcb_local_clnt = clnt; - rpcb_local_clnt4 = clnt4; + rpcb_set_local(net, clnt, clnt4, false); + +out: + return result; +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. 
+ */ +int rpcb_create_local(struct net *net) +{ + static DEFINE_MUTEX(rpcb_create_local_mutex); + int result = 0; + + if (rpcb_get_local(net)) + return result; + + mutex_lock(&rpcb_create_local_mutex); + if (rpcb_get_local(net)) + goto out; + + if (rpcb_create_local_unix(net) != 0) + result = rpcb_create_local_net(net); out: mutex_unlock(&rpcb_create_local_mutex); return result; } -static struct rpc_clnt *rpcb_create(char *hostname, struct sockaddr *srvaddr, - size_t salen, int proto, u32 version) +static struct rpc_clnt *rpcb_create(struct net *net, const char *hostname, + struct sockaddr *srvaddr, size_t salen, + int proto, u32 version) { struct rpc_create_args args = { + .net = net, .protocol = proto, .address = srvaddr, .addrsize = salen, @@ -246,19 +380,22 @@ static struct rpc_clnt *rpcb_create(char *hostname, struct sockaddr *srvaddr, ((struct sockaddr_in6 *)srvaddr)->sin6_port = htons(RPCBIND_PORT); break; default: - return NULL; + return ERR_PTR(-EAFNOSUPPORT); } return rpc_create(&args); } -static int rpcb_register_call(struct rpc_clnt *clnt, struct rpc_message *msg) +static int rpcb_register_call(struct sunrpc_net *sn, struct rpc_clnt *clnt, struct rpc_message *msg, bool is_set) { - int result, error = 0; + int flags = RPC_TASK_NOCONNECT; + int error, result = 0; + if (is_set || !sn->rpcb_is_af_local) + flags = RPC_TASK_SOFTCONN; msg->rpc_resp = &result; - error = rpc_call_sync(clnt, msg, RPC_TASK_SOFTCONN); + error = rpc_call_sync(clnt, msg, flags); if (error < 0) { dprintk("RPC: failed to contact local rpcbind " "server (errno %d).\n", -error); @@ -272,6 +409,7 @@ static int rpcb_register_call(struct rpc_clnt *clnt, struct rpc_message *msg) /** * rpcb_register - set or unset a port registration with the local rpcbind svc + * @net: target network namespace * @prog: RPC program number to bind * @vers: RPC version number to bind * @prot: transport protocol to register @@ -302,7 +440,7 @@ static int rpcb_register_call(struct rpc_clnt *clnt, struct rpc_message *msg) * IN6ADDR_ANY (ie available for all AF_INET and AF_INET6 * addresses). */ -int rpcb_register(u32 prog, u32 vers, int prot, unsigned short port) +int rpcb_register(struct net *net, u32 prog, u32 vers, int prot, unsigned short port) { struct rpcbind_args map = { .r_prog = prog, @@ -313,35 +451,36 @@ int rpcb_register(u32 prog, u32 vers, int prot, unsigned short port) struct rpc_message msg = { .rpc_argp = &map, }; - int error; - - error = rpcb_create_local(); - if (error) - return error; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + bool is_set = false; dprintk("RPC: %sregistering (%u, %u, %d, %u) with local " "rpcbind\n", (port ? 
"" : "un"), prog, vers, prot, port); msg.rpc_proc = &rpcb_procedures2[RPCBPROC_UNSET]; - if (port) + if (port != 0) { msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET]; + is_set = true; + } - return rpcb_register_call(rpcb_local_clnt, &msg); + return rpcb_register_call(sn, sn->rpcb_local_clnt, &msg, is_set); } /* * Fill in AF_INET family-specific arguments to register */ -static int rpcb_register_inet4(const struct sockaddr *sap, +static int rpcb_register_inet4(struct sunrpc_net *sn, + const struct sockaddr *sap, struct rpc_message *msg) { const struct sockaddr_in *sin = (const struct sockaddr_in *)sap; struct rpcbind_args *map = msg->rpc_argp; unsigned short port = ntohs(sin->sin_port); + bool is_set = false; int result; - map->r_addr = rpc_sockaddr2uaddr(sap); + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " "local rpcbind\n", (port ? "" : "un"), @@ -349,10 +488,12 @@ static int rpcb_register_inet4(const struct sockaddr *sap, map->r_addr, map->r_netid); msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; - if (port) + if (port != 0) { msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } - result = rpcb_register_call(rpcb_local_clnt4, msg); + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); kfree(map->r_addr); return result; } @@ -360,15 +501,17 @@ static int rpcb_register_inet4(const struct sockaddr *sap, /* * Fill in AF_INET6 family-specific arguments to register */ -static int rpcb_register_inet6(const struct sockaddr *sap, +static int rpcb_register_inet6(struct sunrpc_net *sn, + const struct sockaddr *sap, struct rpc_message *msg) { const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)sap; struct rpcbind_args *map = msg->rpc_argp; unsigned short port = ntohs(sin6->sin6_port); + bool is_set = false; int result; - map->r_addr = rpc_sockaddr2uaddr(sap); + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " "local rpcbind\n", (port ? "" : "un"), @@ -376,15 +519,18 @@ static int rpcb_register_inet6(const struct sockaddr *sap, map->r_addr, map->r_netid); msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; - if (port) + if (port != 0) { msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } - result = rpcb_register_call(rpcb_local_clnt4, msg); + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); kfree(map->r_addr); return result; } -static int rpcb_unregister_all_protofamilies(struct rpc_message *msg) +static int rpcb_unregister_all_protofamilies(struct sunrpc_net *sn, + struct rpc_message *msg) { struct rpcbind_args *map = msg->rpc_argp; @@ -395,11 +541,12 @@ static int rpcb_unregister_all_protofamilies(struct rpc_message *msg) map->r_addr = ""; msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; - return rpcb_register_call(rpcb_local_clnt4, msg); + return rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, false); } /** * rpcb_v4_register - set or unset a port registration with the local rpcbind + * @net: target network namespace * @program: RPC program number of service to (un)register * @version: RPC version number of service to (un)register * @address: address family, IP address, and port to (un)register @@ -441,7 +588,7 @@ static int rpcb_unregister_all_protofamilies(struct rpc_message *msg) * service on any IPv4 address, but not on IPv6. The latter * advertises the service on all IPv4 and IPv6 addresses. 
*/ -int rpcb_v4_register(const u32 program, const u32 version, +int rpcb_v4_register(struct net *net, const u32 program, const u32 version, const struct sockaddr *address, const char *netid) { struct rpcbind_args map = { @@ -453,78 +600,24 @@ int rpcb_v4_register(const u32 program, const u32 version, struct rpc_message msg = { .rpc_argp = &map, }; - int error; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); - error = rpcb_create_local(); - if (error) - return error; - if (rpcb_local_clnt4 == NULL) + if (sn->rpcb_local_clnt4 == NULL) return -EPROTONOSUPPORT; if (address == NULL) - return rpcb_unregister_all_protofamilies(&msg); + return rpcb_unregister_all_protofamilies(sn, &msg); switch (address->sa_family) { case AF_INET: - return rpcb_register_inet4(address, &msg); + return rpcb_register_inet4(sn, address, &msg); case AF_INET6: - return rpcb_register_inet6(address, &msg); + return rpcb_register_inet6(sn, address, &msg); } return -EAFNOSUPPORT; } -/** - * rpcb_getport_sync - obtain the port for an RPC service on a given host - * @sin: address of remote peer - * @prog: RPC program number to bind - * @vers: RPC version number to bind - * @prot: transport protocol to use to make this request - * - * Return value is the requested advertised port number, - * or a negative errno value. - * - * Called from outside the RPC client in a synchronous task context. - * Uses default timeout parameters specified by underlying transport. - * - * XXX: Needs to support IPv6 - */ -int rpcb_getport_sync(struct sockaddr_in *sin, u32 prog, u32 vers, int prot) -{ - struct rpcbind_args map = { - .r_prog = prog, - .r_vers = vers, - .r_prot = prot, - .r_port = 0, - }; - struct rpc_message msg = { - .rpc_proc = &rpcb_procedures2[RPCBPROC_GETPORT], - .rpc_argp = &map, - .rpc_resp = &map, - }; - struct rpc_clnt *rpcb_clnt; - int status; - - dprintk("RPC: %s(%pI4, %u, %u, %d)\n", - __func__, &sin->sin_addr.s_addr, prog, vers, prot); - - rpcb_clnt = rpcb_create(NULL, (struct sockaddr *)sin, - sizeof(*sin), prot, RPCBVERS_2); - if (IS_ERR(rpcb_clnt)) - return PTR_ERR(rpcb_clnt); - - status = rpc_call_sync(rpcb_clnt, &msg, 0); - rpc_shutdown_client(rpcb_clnt); - - if (status >= 0) { - if (map.r_port != 0) - return map.r_port; - status = -EACCES; - } - return status; -} -EXPORT_SYMBOL_GPL(rpcb_getport_sync); - static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, struct rpcbind_args *map, struct rpc_procinfo *proc) { struct rpc_message msg = { @@ -553,9 +646,10 @@ static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, struct rpcbi static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt) { struct rpc_clnt *parent = clnt->cl_parent; + struct rpc_xprt *xprt = rcu_dereference(clnt->cl_xprt); while (parent != clnt) { - if (parent->cl_xprt != clnt->cl_xprt) + if (rcu_dereference(parent->cl_xprt) != xprt) break; if (clnt->cl_autobind) break; @@ -579,19 +673,23 @@ void rpcb_getport_async(struct rpc_task *task) u32 bind_version; struct rpc_xprt *xprt; struct rpc_clnt *rpcb_clnt; - static struct rpcbind_args *map; + struct rpcbind_args *map; struct rpc_task *child; struct sockaddr_storage addr; struct sockaddr *sap = (struct sockaddr *)&addr; size_t salen; int status; - clnt = rpcb_find_transport_owner(task->tk_client); - xprt = clnt->cl_xprt; + rcu_read_lock(); + do { + clnt = rpcb_find_transport_owner(task->tk_client); + xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + } while (xprt == NULL); + rcu_read_unlock(); dprintk("RPC: %5u %s(%s, %u, %u, %d)\n", task->tk_pid, 
__func__, - clnt->cl_server, clnt->cl_prog, clnt->cl_vers, xprt->prot); + xprt->servername, clnt->cl_prog, clnt->cl_vers, xprt->prot); /* Put self on the wait queue to ensure we get notified if * some other task is already attempting to bind the port */ @@ -600,6 +698,7 @@ void rpcb_getport_async(struct rpc_task *task) if (xprt_test_and_set_binding(xprt)) { dprintk("RPC: %5u %s: waiting for another binder\n", task->tk_pid, __func__); + xprt_put(xprt); return; } @@ -641,8 +740,8 @@ void rpcb_getport_async(struct rpc_task *task) dprintk("RPC: %5u %s: trying rpcbind version %u\n", task->tk_pid, __func__, bind_version); - rpcb_clnt = rpcb_create(clnt->cl_server, sap, salen, xprt->prot, - bind_version); + rpcb_clnt = rpcb_create(xprt->xprt_net, xprt->servername, sap, salen, + xprt->prot, bind_version); if (IS_ERR(rpcb_clnt)) { status = PTR_ERR(rpcb_clnt); dprintk("RPC: %5u %s: rpcb_create failed, error %ld\n", @@ -661,14 +760,14 @@ void rpcb_getport_async(struct rpc_task *task) map->r_vers = clnt->cl_vers; map->r_prot = xprt->prot; map->r_port = 0; - map->r_xprt = xprt_get(xprt); + map->r_xprt = xprt; map->r_status = -EIO; switch (bind_version) { case RPCBVERS_4: case RPCBVERS_3: - map->r_netid = rpc_peeraddr2str(clnt, RPC_DISPLAY_NETID); - map->r_addr = rpc_sockaddr2uaddr(sap); + map->r_netid = xprt->address_strings[RPC_DISPLAY_NETID]; + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_ATOMIC); map->r_owner = ""; break; case RPCBVERS_2: @@ -696,6 +795,7 @@ bailout_release_client: bailout_nofree: rpcb_wake_rpcbind_waiters(xprt, status); task->tk_status = status; + xprt_put(xprt); } EXPORT_SYMBOL_GPL(rpcb_getport_async); @@ -740,144 +840,114 @@ static void rpcb_getport_done(struct rpc_task *child, void *data) * XDR functions for rpcbind */ -static int rpcb_enc_mapping(struct rpc_rqst *req, __be32 *p, - const struct rpcbind_args *rpcb) +static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr, + const struct rpcbind_args *rpcb) { - struct rpc_task *task = req->rq_task; - struct xdr_stream xdr; + __be32 *p; dprintk("RPC: %5u encoding PMAP_%s call (%u, %u, %d, %u)\n", - task->tk_pid, task->tk_msg.rpc_proc->p_name, + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, rpcb->r_prog, rpcb->r_vers, rpcb->r_prot, rpcb->r_port); - xdr_init_encode(&xdr, &req->rq_snd_buf, p); - - p = xdr_reserve_space(&xdr, sizeof(__be32) * RPCB_mappingargs_sz); - if (unlikely(p == NULL)) - return -EIO; - - *p++ = htonl(rpcb->r_prog); - *p++ = htonl(rpcb->r_vers); - *p++ = htonl(rpcb->r_prot); - *p = htonl(rpcb->r_port); - - return 0; + p = xdr_reserve_space(xdr, RPCB_mappingargs_sz << 2); + *p++ = cpu_to_be32(rpcb->r_prog); + *p++ = cpu_to_be32(rpcb->r_vers); + *p++ = cpu_to_be32(rpcb->r_prot); + *p = cpu_to_be32(rpcb->r_port); } -static int rpcb_dec_getport(struct rpc_rqst *req, __be32 *p, +static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr, struct rpcbind_args *rpcb) { - struct rpc_task *task = req->rq_task; - struct xdr_stream xdr; unsigned long port; - - xdr_init_decode(&xdr, &req->rq_rcv_buf, p); + __be32 *p; rpcb->r_port = 0; - p = xdr_inline_decode(&xdr, sizeof(__be32)); + p = xdr_inline_decode(xdr, 4); if (unlikely(p == NULL)) return -EIO; - port = ntohl(*p); - dprintk("RPC: %5u PMAP_%s result: %lu\n", task->tk_pid, - task->tk_msg.rpc_proc->p_name, port); - if (unlikely(port > USHORT_MAX)) + port = be32_to_cpup(p); + dprintk("RPC: %5u PMAP_%s result: %lu\n", req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, port); + if (unlikely(port > USHRT_MAX)) return 
-EIO; rpcb->r_port = port; return 0; } -static int rpcb_dec_set(struct rpc_rqst *req, __be32 *p, +static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr, unsigned int *boolp) { - struct rpc_task *task = req->rq_task; - struct xdr_stream xdr; - - xdr_init_decode(&xdr, &req->rq_rcv_buf, p); + __be32 *p; - p = xdr_inline_decode(&xdr, sizeof(__be32)); + p = xdr_inline_decode(xdr, 4); if (unlikely(p == NULL)) return -EIO; *boolp = 0; - if (*p) + if (*p != xdr_zero) *boolp = 1; dprintk("RPC: %5u RPCB_%s call %s\n", - task->tk_pid, task->tk_msg.rpc_proc->p_name, + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, (*boolp ? "succeeded" : "failed")); return 0; } -static int encode_rpcb_string(struct xdr_stream *xdr, const char *string, - const u32 maxstrlen) +static void encode_rpcb_string(struct xdr_stream *xdr, const char *string, + const u32 maxstrlen) { - u32 len; __be32 *p; + u32 len; - if (unlikely(string == NULL)) - return -EIO; len = strlen(string); - if (unlikely(len > maxstrlen)) - return -EIO; - - p = xdr_reserve_space(xdr, sizeof(__be32) + len); - if (unlikely(p == NULL)) - return -EIO; + WARN_ON_ONCE(len > maxstrlen); + if (len > maxstrlen) + /* truncate and hope for the best */ + len = maxstrlen; + p = xdr_reserve_space(xdr, 4 + len); xdr_encode_opaque(p, string, len); - - return 0; } -static int rpcb_enc_getaddr(struct rpc_rqst *req, __be32 *p, - const struct rpcbind_args *rpcb) +static void rpcb_enc_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, + const struct rpcbind_args *rpcb) { - struct rpc_task *task = req->rq_task; - struct xdr_stream xdr; + __be32 *p; dprintk("RPC: %5u encoding RPCB_%s call (%u, %u, '%s', '%s')\n", - task->tk_pid, task->tk_msg.rpc_proc->p_name, + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, rpcb->r_prog, rpcb->r_vers, rpcb->r_netid, rpcb->r_addr); - xdr_init_encode(&xdr, &req->rq_snd_buf, p); + p = xdr_reserve_space(xdr, (RPCB_program_sz + RPCB_version_sz) << 2); + *p++ = cpu_to_be32(rpcb->r_prog); + *p = cpu_to_be32(rpcb->r_vers); - p = xdr_reserve_space(&xdr, - sizeof(__be32) * (RPCB_program_sz + RPCB_version_sz)); - if (unlikely(p == NULL)) - return -EIO; - *p++ = htonl(rpcb->r_prog); - *p = htonl(rpcb->r_vers); - - if (encode_rpcb_string(&xdr, rpcb->r_netid, RPCBIND_MAXNETIDLEN)) - return -EIO; - if (encode_rpcb_string(&xdr, rpcb->r_addr, RPCBIND_MAXUADDRLEN)) - return -EIO; - if (encode_rpcb_string(&xdr, rpcb->r_owner, RPCB_MAXOWNERLEN)) - return -EIO; - - return 0; + encode_rpcb_string(xdr, rpcb->r_netid, RPCBIND_MAXNETIDLEN); + encode_rpcb_string(xdr, rpcb->r_addr, RPCBIND_MAXUADDRLEN); + encode_rpcb_string(xdr, rpcb->r_owner, RPCB_MAXOWNERLEN); } -static int rpcb_dec_getaddr(struct rpc_rqst *req, __be32 *p, +static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, struct rpcbind_args *rpcb) { struct sockaddr_storage address; struct sockaddr *sap = (struct sockaddr *)&address; - struct rpc_task *task = req->rq_task; - struct xdr_stream xdr; + __be32 *p; u32 len; rpcb->r_port = 0; - xdr_init_decode(&xdr, &req->rq_rcv_buf, p); - - p = xdr_inline_decode(&xdr, sizeof(__be32)); + p = xdr_inline_decode(xdr, 4); if (unlikely(p == NULL)) goto out_fail; - len = ntohl(*p); + len = be32_to_cpup(p); /* * If the returned universal address is a null string, @@ -885,20 +955,21 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, __be32 *p, */ if (len == 0) { dprintk("RPC: %5u RPCB reply: program not registered\n", - task->tk_pid); + req->rq_task->tk_pid); return 0; } if (unlikely(len > 
RPCBIND_MAXUADDRLEN)) goto out_fail; - p = xdr_inline_decode(&xdr, len); + p = xdr_inline_decode(xdr, len); if (unlikely(p == NULL)) goto out_fail; - dprintk("RPC: %5u RPCB_%s reply: %s\n", task->tk_pid, - task->tk_msg.rpc_proc->p_name, (char *)p); + dprintk("RPC: %5u RPCB_%s reply: %s\n", req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name, (char *)p); - if (rpc_uaddr2sockaddr((char *)p, len, sap, sizeof(address)) == 0) + if (rpc_uaddr2sockaddr(req->rq_xprt->xprt_net, (char *)p, len, + sap, sizeof(address)) == 0) goto out_fail; rpcb->r_port = rpc_get_port(sap); @@ -906,7 +977,8 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, __be32 *p, out_fail: dprintk("RPC: %5u malformed RPCB_%s reply\n", - task->tk_pid, task->tk_msg.rpc_proc->p_name); + req->rq_task->tk_pid, + req->rq_task->tk_msg.rpc_proc->p_name); return -EIO; } @@ -918,8 +990,8 @@ out_fail: static struct rpc_procinfo rpcb_procedures2[] = { [RPCBPROC_SET] = { .p_proc = RPCBPROC_SET, - .p_encode = (kxdrproc_t)rpcb_enc_mapping, - .p_decode = (kxdrproc_t)rpcb_dec_set, + .p_encode = (kxdreproc_t)rpcb_enc_mapping, + .p_decode = (kxdrdproc_t)rpcb_dec_set, .p_arglen = RPCB_mappingargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_SET, @@ -928,8 +1000,8 @@ static struct rpc_procinfo rpcb_procedures2[] = { }, [RPCBPROC_UNSET] = { .p_proc = RPCBPROC_UNSET, - .p_encode = (kxdrproc_t)rpcb_enc_mapping, - .p_decode = (kxdrproc_t)rpcb_dec_set, + .p_encode = (kxdreproc_t)rpcb_enc_mapping, + .p_decode = (kxdrdproc_t)rpcb_dec_set, .p_arglen = RPCB_mappingargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_UNSET, @@ -938,8 +1010,8 @@ static struct rpc_procinfo rpcb_procedures2[] = { }, [RPCBPROC_GETPORT] = { .p_proc = RPCBPROC_GETPORT, - .p_encode = (kxdrproc_t)rpcb_enc_mapping, - .p_decode = (kxdrproc_t)rpcb_dec_getport, + .p_encode = (kxdreproc_t)rpcb_enc_mapping, + .p_decode = (kxdrdproc_t)rpcb_dec_getport, .p_arglen = RPCB_mappingargs_sz, .p_replen = RPCB_getportres_sz, .p_statidx = RPCBPROC_GETPORT, @@ -951,8 +1023,8 @@ static struct rpc_procinfo rpcb_procedures2[] = { static struct rpc_procinfo rpcb_procedures3[] = { [RPCBPROC_SET] = { .p_proc = RPCBPROC_SET, - .p_encode = (kxdrproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrproc_t)rpcb_dec_set, + .p_encode = (kxdreproc_t)rpcb_enc_getaddr, + .p_decode = (kxdrdproc_t)rpcb_dec_set, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_SET, @@ -961,8 +1033,8 @@ static struct rpc_procinfo rpcb_procedures3[] = { }, [RPCBPROC_UNSET] = { .p_proc = RPCBPROC_UNSET, - .p_encode = (kxdrproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrproc_t)rpcb_dec_set, + .p_encode = (kxdreproc_t)rpcb_enc_getaddr, + .p_decode = (kxdrdproc_t)rpcb_dec_set, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_UNSET, @@ -971,8 +1043,8 @@ static struct rpc_procinfo rpcb_procedures3[] = { }, [RPCBPROC_GETADDR] = { .p_proc = RPCBPROC_GETADDR, - .p_encode = (kxdrproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrproc_t)rpcb_dec_getaddr, + .p_encode = (kxdreproc_t)rpcb_enc_getaddr, + .p_decode = (kxdrdproc_t)rpcb_dec_getaddr, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_getaddrres_sz, .p_statidx = RPCBPROC_GETADDR, @@ -984,8 +1056,8 @@ static struct rpc_procinfo rpcb_procedures3[] = { static struct rpc_procinfo rpcb_procedures4[] = { [RPCBPROC_SET] = { .p_proc = RPCBPROC_SET, - .p_encode = (kxdrproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrproc_t)rpcb_dec_set, + .p_encode = (kxdreproc_t)rpcb_enc_getaddr, + .p_decode = (kxdrdproc_t)rpcb_dec_set, .p_arglen = 
RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_SET, @@ -994,8 +1066,8 @@ static struct rpc_procinfo rpcb_procedures4[] = { }, [RPCBPROC_UNSET] = { .p_proc = RPCBPROC_UNSET, - .p_encode = (kxdrproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrproc_t)rpcb_dec_set, + .p_encode = (kxdreproc_t)rpcb_enc_getaddr, + .p_decode = (kxdrdproc_t)rpcb_dec_set, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_UNSET, @@ -1004,8 +1076,8 @@ static struct rpc_procinfo rpcb_procedures4[] = { }, [RPCBPROC_GETADDR] = { .p_proc = RPCBPROC_GETADDR, - .p_encode = (kxdrproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrproc_t)rpcb_dec_getaddr, + .p_encode = (kxdreproc_t)rpcb_enc_getaddr, + .p_decode = (kxdrdproc_t)rpcb_dec_getaddr, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_getaddrres_sz, .p_statidx = RPCBPROC_GETADDR, @@ -1014,7 +1086,7 @@ static struct rpc_procinfo rpcb_procedures4[] = { }, }; -static struct rpcb_info rpcb_next_version[] = { +static const struct rpcb_info rpcb_next_version[] = { { .rpc_vers = RPCBVERS_2, .rpc_proc = &rpcb_procedures2[RPCBPROC_GETPORT], @@ -1024,7 +1096,7 @@ static struct rpcb_info rpcb_next_version[] = { }, }; -static struct rpcb_info rpcb_next_version6[] = { +static const struct rpcb_info rpcb_next_version6[] = { { .rpc_vers = RPCBVERS_4, .rpc_proc = &rpcb_procedures4[RPCBPROC_GETADDR], @@ -1038,25 +1110,25 @@ static struct rpcb_info rpcb_next_version6[] = { }, }; -static struct rpc_version rpcb_version2 = { +static const struct rpc_version rpcb_version2 = { .number = RPCBVERS_2, - .nrprocs = RPCB_HIGHPROC_2, + .nrprocs = ARRAY_SIZE(rpcb_procedures2), .procs = rpcb_procedures2 }; -static struct rpc_version rpcb_version3 = { +static const struct rpc_version rpcb_version3 = { .number = RPCBVERS_3, - .nrprocs = RPCB_HIGHPROC_3, + .nrprocs = ARRAY_SIZE(rpcb_procedures3), .procs = rpcb_procedures3 }; -static struct rpc_version rpcb_version4 = { +static const struct rpc_version rpcb_version4 = { .number = RPCBVERS_4, - .nrprocs = RPCB_HIGHPROC_4, + .nrprocs = ARRAY_SIZE(rpcb_procedures4), .procs = rpcb_procedures4 }; -static struct rpc_version *rpcb_version[] = { +static const struct rpc_version *rpcb_version[] = { NULL, NULL, &rpcb_version2, @@ -1066,22 +1138,10 @@ static struct rpc_version *rpcb_version[] = { static struct rpc_stat rpcb_stats; -static struct rpc_program rpcb_program = { +static const struct rpc_program rpcb_program = { .name = "rpcbind", .number = RPCBIND_PROGRAM, .nrvers = ARRAY_SIZE(rpcb_version), .version = rpcb_version, .stats = &rpcb_stats, }; - -/** - * cleanup_rpcb_clnt - remove xprtsock's sysctls, unregister - * - */ -void cleanup_rpcb_clnt(void) -{ - if (rpcb_local_clnt4) - rpc_shutdown_client(rpcb_local_clnt4); - if (rpcb_local_clnt) - rpc_shutdown_client(rpcb_local_clnt); -} diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c index aae6907fd54..c0365c14b85 100644 --- a/net/sunrpc/sched.c +++ b/net/sunrpc/sched.c @@ -18,6 +18,7 @@ #include <linux/smp.h> #include <linux/spinlock.h> #include <linux/mutex.h> +#include <linux/freezer.h> #include <linux/sunrpc/clnt.h> @@ -25,9 +26,11 @@ #ifdef RPC_DEBUG #define RPCDBG_FACILITY RPCDBG_SCHED -#define RPC_TASK_MAGIC_ID 0xf00baa #endif +#define CREATE_TRACE_POINTS +#include <trace/events/sunrpc.h> + /* * RPC slabs and memory pools */ @@ -95,18 +98,55 @@ __rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task) list_add(&task->u.tk_wait.timer_list, &queue->timer_list.list); } +static void rpc_rotate_queue_owner(struct rpc_wait_queue *queue) +{ + struct 
list_head *q = &queue->tasks[queue->priority]; + struct rpc_task *task; + + if (!list_empty(q)) { + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + if (task->tk_owner == queue->owner) + list_move_tail(&task->u.tk_wait.list, q); + } +} + +static void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority) +{ + if (queue->priority != priority) { + /* Fairness: rotate the list when changing priority */ + rpc_rotate_queue_owner(queue); + queue->priority = priority; + } +} + +static void rpc_set_waitqueue_owner(struct rpc_wait_queue *queue, pid_t pid) +{ + queue->owner = pid; + queue->nr = RPC_BATCH_COUNT; +} + +static void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue) +{ + rpc_set_waitqueue_priority(queue, queue->maxpriority); + rpc_set_waitqueue_owner(queue, 0); +} + /* * Add new request to a priority queue. */ -static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, struct rpc_task *task) +static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) { struct list_head *q; struct rpc_task *t; INIT_LIST_HEAD(&task->u.tk_wait.links); - q = &queue->tasks[task->tk_priority]; - if (unlikely(task->tk_priority > queue->maxpriority)) - q = &queue->tasks[queue->maxpriority]; + if (unlikely(queue_priority > queue->maxpriority)) + queue_priority = queue->maxpriority; + if (queue_priority > queue->priority) + rpc_set_waitqueue_priority(queue, queue_priority); + q = &queue->tasks[queue_priority]; list_for_each_entry(t, q, u.tk_wait.list) { if (t->tk_owner == task->tk_owner) { list_add_tail(&task->u.tk_wait.list, &t->u.tk_wait.links); @@ -124,18 +164,24 @@ static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, struct r * improve overall performance. * Everyone else gets appended to the queue to ensure proper FIFO behavior. 
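The resulting queue shape: each priority level keeps one list entry per owner, with further same-owner tasks chained off that entry, roughly:

        tasks[prio]:  T1 (owner A) --> T3 (owner B) --> ...
                        `-- u.tk_wait.links: T2, T4 (more owner-A tasks, FIFO)

Batching comes from queue->nr (reset to RPC_BATCH_COUNT whenever ownership changes), and rpc_rotate_queue_owner() keeps one owner from monopolizing a level across a priority change by moving its head task to the tail.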
*/ -static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, struct rpc_task *task) +static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) { - BUG_ON (RPC_IS_QUEUED(task)); + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + if (RPC_IS_QUEUED(task)) + return; if (RPC_IS_PRIORITY(queue)) - __rpc_add_wait_queue_priority(queue, task); + __rpc_add_wait_queue_priority(queue, task, queue_priority); else if (RPC_IS_SWAPPER(task)) list_add(&task->u.tk_wait.list, &queue->tasks[0]); else list_add_tail(&task->u.tk_wait.list, &queue->tasks[0]); task->tk_waitqueue = queue; queue->qlen++; + /* barrier matches the read in rpc_wake_up_task_queue_locked() */ + smp_wmb(); rpc_set_queued(task); dprintk("RPC: %5u added to queue %p \"%s\"\n", @@ -171,24 +217,6 @@ static void __rpc_remove_wait_queue(struct rpc_wait_queue *queue, struct rpc_tas task->tk_pid, queue, rpc_qname(queue)); } -static inline void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority) -{ - queue->priority = priority; - queue->count = 1 << (priority * 2); -} - -static inline void rpc_set_waitqueue_owner(struct rpc_wait_queue *queue, pid_t pid) -{ - queue->owner = pid; - queue->nr = RPC_BATCH_COUNT; -} - -static inline void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue) -{ - rpc_set_waitqueue_priority(queue, queue->maxpriority); - rpc_set_waitqueue_owner(queue, 0); -} - static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname, unsigned char nr_queues) { int i; @@ -201,9 +229,7 @@ static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const c queue->qlen = 0; setup_timer(&queue->timer_list.timer, __rpc_queue_timer_fn, (unsigned long)queue); INIT_LIST_HEAD(&queue->timer_list.list); -#ifdef RPC_DEBUG - queue->name = qname; -#endif + rpc_assign_waitqueue_name(queue, qname); } void rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname) @@ -228,16 +254,15 @@ static int rpc_wait_bit_killable(void *word) { if (fatal_signal_pending(current)) return -ERESTARTSYS; - schedule(); + freezable_schedule_unsafe(); return 0; } -#ifdef RPC_DEBUG +#if defined(RPC_DEBUG) || defined(RPC_TRACEPOINTS) static void rpc_task_set_debuginfo(struct rpc_task *task) { static atomic_t rpc_pid; - task->tk_magic = RPC_TASK_MAGIC_ID; task->tk_pid = atomic_inc_return(&rpc_pid); } #else @@ -248,38 +273,47 @@ static inline void rpc_task_set_debuginfo(struct rpc_task *task) static void rpc_set_active(struct rpc_task *task) { - struct rpc_clnt *clnt; - if (test_and_set_bit(RPC_TASK_ACTIVE, &task->tk_runstate) != 0) - return; + trace_rpc_task_begin(task->tk_client, task, NULL); + rpc_task_set_debuginfo(task); - /* Add to global list of all tasks */ - clnt = task->tk_client; - if (clnt != NULL) { - spin_lock(&clnt->cl_lock); - list_add_tail(&task->tk_task, &clnt->cl_tasks); - spin_unlock(&clnt->cl_lock); - } + set_bit(RPC_TASK_ACTIVE, &task->tk_runstate); } /* * Mark an RPC call as having completed by clearing the 'active' bit + * and then waking up all tasks that were sleeping. 
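The smp_wmb() added in __rpc_add_wait_queue() pairs with an smp_rmb() in rpc_wake_up_task_queue_locked() further down; spelled out, the ordering contract is:

        /* enqueue side (__rpc_add_wait_queue) */
        task->tk_waitqueue = queue;
        queue->qlen++;
        smp_wmb();              /* publish tk_waitqueue ... */
        rpc_set_queued(task);   /* ... before setting the QUEUED bit */

        /* wake-up side (rpc_wake_up_task_queue_locked) */
        if (RPC_IS_QUEUED(task)) {
                smp_rmb();      /* pairs with the smp_wmb() above */
                if (task->tk_waitqueue == queue)
                        __rpc_do_wake_up_task(queue, task);
        }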
*/ -static void rpc_mark_complete_task(struct rpc_task *task) +static int rpc_complete_task(struct rpc_task *task) { - smp_mb__before_clear_bit(); + void *m = &task->tk_runstate; + wait_queue_head_t *wq = bit_waitqueue(m, RPC_TASK_ACTIVE); + struct wait_bit_key k = __WAIT_BIT_KEY_INITIALIZER(m, RPC_TASK_ACTIVE); + unsigned long flags; + int ret; + + trace_rpc_task_complete(task->tk_client, task, NULL); + + spin_lock_irqsave(&wq->lock, flags); clear_bit(RPC_TASK_ACTIVE, &task->tk_runstate); - smp_mb__after_clear_bit(); - wake_up_bit(&task->tk_runstate, RPC_TASK_ACTIVE); + ret = atomic_dec_and_test(&task->tk_count); + if (waitqueue_active(wq)) + __wake_up_locked_key(wq, TASK_NORMAL, &k); + spin_unlock_irqrestore(&wq->lock, flags); + return ret; } /* * Allow callers to wait for completion of an RPC call + * + * Note the use of out_of_line_wait_on_bit() rather than wait_on_bit() + * to enforce taking of the wq->lock and hence avoid races with + * rpc_complete_task(). */ int __rpc_wait_for_completion_task(struct rpc_task *task, int (*action)(void *)) { if (action == NULL) action = rpc_wait_bit_killable; - return wait_on_bit(&task->tk_runstate, RPC_TASK_ACTIVE, + return out_of_line_wait_on_bit(&task->tk_runstate, RPC_TASK_ACTIVE, action, TASK_KILLABLE); } EXPORT_SYMBOL_GPL(__rpc_wait_for_completion_task); @@ -287,24 +321,24 @@ EXPORT_SYMBOL_GPL(__rpc_wait_for_completion_task); /* * Make an RPC task runnable. * - * Note: If the task is ASYNC, this must be called with - * the spinlock held to protect the wait queue operation. + * Note: If the task is ASYNC, and is being made runnable after sitting on an + * rpc_wait_queue, this must be called with the queue spinlock held to protect + * the wait queue operation. + * Note the ordering of rpc_test_and_set_running() and rpc_clear_queued(), + * which is needed to ensure that __rpc_execute() doesn't loop (due to the + * lockless RPC_IS_QUEUED() test) before we've had a chance to test + * the RPC_TASK_RUNNING flag. */ static void rpc_make_runnable(struct rpc_task *task) { + bool need_wakeup = !rpc_test_and_set_running(task); + rpc_clear_queued(task); - if (rpc_test_and_set_running(task)) + if (!need_wakeup) return; if (RPC_IS_ASYNC(task)) { - int status; - INIT_WORK(&task->u.tk_work, rpc_async_schedule); - status = queue_work(rpciod_workqueue, &task->u.tk_work); - if (status < 0) { - printk(KERN_WARNING "RPC: failed to add task to queue: error: %d!\n", status); - task->tk_status = status; - return; - } + queue_work(rpciod_workqueue, &task->u.tk_work); } else wake_up_bit(&task->tk_runstate, RPC_TASK_QUEUED); } @@ -315,20 +349,19 @@ static void rpc_make_runnable(struct rpc_task *task) * NB: An RPC task will only receive interrupt-driven events as long * as it's on a wait queue. 
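rpc_complete_task() above performs both the bit-clear and the wake-up under the bit-waitqueue spinlock, so a waiter cannot observe the bit clear and free the task while the waker is still touching it; that is why waiters must go through out_of_line_wait_on_bit() rather than the lockless wait_on_bit(). A caller that fires an async task but needs its result then looks roughly like this (a sketch; rpc_wait_for_completion_task() is the NULL-action wrapper around the helper above):

        static int my_call_and_wait(struct rpc_task_setup *setup)
        {
                struct rpc_task *task;
                int status;

                task = rpc_run_task(setup);     /* setup->flags has RPC_TASK_ASYNC */
                if (IS_ERR(task))
                        return PTR_ERR(task);
                status = rpc_wait_for_completion_task(task); /* may be -ERESTARTSYS */
                if (status == 0)
                        status = task->tk_status;
                rpc_put_task(task);
                return status;
        }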
*/ -static void __rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, - rpc_action action) +static void __rpc_sleep_on_priority(struct rpc_wait_queue *q, + struct rpc_task *task, + rpc_action action, + unsigned char queue_priority) { dprintk("RPC: %5u sleep_on(queue \"%s\" time %lu)\n", task->tk_pid, rpc_qname(q), jiffies); - if (!RPC_IS_ASYNC(task) && !RPC_IS_ACTIVATED(task)) { - printk(KERN_ERR "RPC: Inactive synchronous task put to sleep!\n"); - return; - } + trace_rpc_task_sleep(task->tk_client, task, q); - __rpc_add_wait_queue(q, task); + __rpc_add_wait_queue(q, task, queue_priority); - BUG_ON(task->tk_callback != NULL); + WARN_ON_ONCE(task->tk_callback != NULL); task->tk_callback = action; __rpc_add_timer(q, task); } @@ -336,18 +369,43 @@ static void __rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, void rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, rpc_action action) { - /* Mark the task as being activated if so needed */ - rpc_set_active(task); + /* We shouldn't ever put an inactive task to sleep */ + WARN_ON_ONCE(!RPC_IS_ACTIVATED(task)); + if (!RPC_IS_ACTIVATED(task)) { + task->tk_status = -EIO; + rpc_put_task_async(task); + return; + } /* * Protect the queue operations. */ spin_lock_bh(&q->lock); - __rpc_sleep_on(q, task, action); + __rpc_sleep_on_priority(q, task, action, task->tk_priority); spin_unlock_bh(&q->lock); } EXPORT_SYMBOL_GPL(rpc_sleep_on); +void rpc_sleep_on_priority(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action, int priority) +{ + /* We shouldn't ever put an inactive task to sleep */ + WARN_ON_ONCE(!RPC_IS_ACTIVATED(task)); + if (!RPC_IS_ACTIVATED(task)) { + task->tk_status = -EIO; + rpc_put_task_async(task); + return; + } + + /* + * Protect the queue operations. + */ + spin_lock_bh(&q->lock); + __rpc_sleep_on_priority(q, task, action, priority - RPC_PRIORITY_LOW); + spin_unlock_bh(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on_priority); + /** * __rpc_do_wake_up_task - wake up a single rpc_task * @queue: wait queue @@ -360,15 +418,14 @@ static void __rpc_do_wake_up_task(struct rpc_wait_queue *queue, struct rpc_task dprintk("RPC: %5u __rpc_wake_up_task (now %lu)\n", task->tk_pid, jiffies); -#ifdef RPC_DEBUG - BUG_ON(task->tk_magic != RPC_TASK_MAGIC_ID); -#endif /* Has the task been executed yet? If not, we cannot wake it up! 
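Note the offset in rpc_sleep_on_priority(): callers pass an absolute RPC_PRIORITY_* value, which is normalized by RPC_PRIORITY_LOW before queueing, since tk_priority is stored 0-based. A hypothetical caller parking a task at an elevated level (assuming the RPC_PRIORITY_PRIVILEGED constant from sched.h; a NULL action is permitted):

        rpc_sleep_on_priority(queue, task, NULL, RPC_PRIORITY_PRIVILEGED);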
*/ if (!RPC_IS_ACTIVATED(task)) { printk(KERN_ERR "RPC: Inactive task (%p) being woken up!\n", task); return; } + trace_rpc_task_wakeup(task->tk_client, task, queue); + __rpc_remove_wait_queue(queue, task); rpc_make_runnable(task); @@ -381,23 +438,12 @@ static void __rpc_do_wake_up_task(struct rpc_wait_queue *queue, struct rpc_task */ static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue, struct rpc_task *task) { - if (RPC_IS_QUEUED(task) && task->tk_waitqueue == queue) - __rpc_do_wake_up_task(queue, task); -} - -/* - * Tests whether rpc queue is empty - */ -int rpc_queue_empty(struct rpc_wait_queue *queue) -{ - int res; - - spin_lock_bh(&queue->lock); - res = queue->qlen; - spin_unlock_bh(&queue->lock); - return (res == 0); + if (RPC_IS_QUEUED(task)) { + smp_rmb(); + if (task->tk_waitqueue == queue) + __rpc_do_wake_up_task(queue, task); + } } -EXPORT_SYMBOL_GPL(rpc_queue_empty); /* * Wake up a task on a specific queue @@ -411,17 +457,9 @@ void rpc_wake_up_queued_task(struct rpc_wait_queue *queue, struct rpc_task *task EXPORT_SYMBOL_GPL(rpc_wake_up_queued_task); /* - * Wake up the specified task - */ -static void rpc_wake_up_task(struct rpc_task *task) -{ - rpc_wake_up_queued_task(task->tk_waitqueue, task); -} - -/* * Wake up the next task on a priority queue. */ -static struct rpc_task * __rpc_wake_up_next_priority(struct rpc_wait_queue *queue) +static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *queue) { struct list_head *q; struct rpc_task *task; @@ -440,8 +478,7 @@ static struct rpc_task * __rpc_wake_up_next_priority(struct rpc_wait_queue *queu /* * Check if we need to switch queues. */ - if (--queue->count) - goto new_owner; + goto new_owner; } /* @@ -466,30 +503,54 @@ new_queue: new_owner: rpc_set_waitqueue_owner(queue, task->tk_owner); out: - rpc_wake_up_task_queue_locked(queue, task); return task; } +static struct rpc_task *__rpc_find_next_queued(struct rpc_wait_queue *queue) +{ + if (RPC_IS_PRIORITY(queue)) + return __rpc_find_next_queued_priority(queue); + if (!list_empty(&queue->tasks[0])) + return list_first_entry(&queue->tasks[0], struct rpc_task, u.tk_wait.list); + return NULL; +} + /* - * Wake up the next task on the wait queue. + * Wake up the first task on the wait queue. */ -struct rpc_task * rpc_wake_up_next(struct rpc_wait_queue *queue) +struct rpc_task *rpc_wake_up_first(struct rpc_wait_queue *queue, + bool (*func)(struct rpc_task *, void *), void *data) { struct rpc_task *task = NULL; - dprintk("RPC: wake_up_next(%p \"%s\")\n", + dprintk("RPC: wake_up_first(%p \"%s\")\n", queue, rpc_qname(queue)); spin_lock_bh(&queue->lock); - if (RPC_IS_PRIORITY(queue)) - task = __rpc_wake_up_next_priority(queue); - else { - task_for_first(task, &queue->tasks[0]) + task = __rpc_find_next_queued(queue); + if (task != NULL) { + if (func(task, data)) rpc_wake_up_task_queue_locked(queue, task); + else + task = NULL; } spin_unlock_bh(&queue->lock); return task; } +EXPORT_SYMBOL_GPL(rpc_wake_up_first); + +static bool rpc_wake_up_next_func(struct rpc_task *task, void *data) +{ + return true; +} + +/* + * Wake up the next task on the wait queue. 
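rpc_wake_up_first() generalizes the old wake-next logic with a caller-supplied predicate: the task found at the head is only woken when the predicate accepts it. A sketch with a hypothetical owner filter:

        static bool my_owner_match(struct rpc_task *task, void *data)
        {
                return task->tk_owner == *(pid_t *)data;
        }

        /* wake the first queued task owned by @pid, if any */
        task = rpc_wake_up_first(&queue, my_owner_match, &pid);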
+*/ +struct rpc_task *rpc_wake_up_next(struct rpc_wait_queue *queue) +{ + return rpc_wake_up_first(queue, rpc_wake_up_next_func, NULL); +} EXPORT_SYMBOL_GPL(rpc_wake_up_next); /** @@ -500,14 +561,18 @@ EXPORT_SYMBOL_GPL(rpc_wake_up_next); */ void rpc_wake_up(struct rpc_wait_queue *queue) { - struct rpc_task *task, *next; struct list_head *head; spin_lock_bh(&queue->lock); head = &queue->tasks[queue->maxpriority]; for (;;) { - list_for_each_entry_safe(task, next, head, u.tk_wait.list) + while (!list_empty(head)) { + struct rpc_task *task; + task = list_first_entry(head, + struct rpc_task, + u.tk_wait.list); rpc_wake_up_task_queue_locked(queue, task); + } if (head == &queue->tasks[0]) break; head--; @@ -525,13 +590,16 @@ EXPORT_SYMBOL_GPL(rpc_wake_up); */ void rpc_wake_up_status(struct rpc_wait_queue *queue, int status) { - struct rpc_task *task, *next; struct list_head *head; spin_lock_bh(&queue->lock); head = &queue->tasks[queue->maxpriority]; for (;;) { - list_for_each_entry_safe(task, next, head, u.tk_wait.list) { + while (!list_empty(head)) { + struct rpc_task *task; + task = list_first_entry(head, + struct rpc_task, + u.tk_wait.list); task->tk_status = status; rpc_wake_up_task_queue_locked(queue, task); } @@ -569,7 +637,8 @@ static void __rpc_queue_timer_fn(unsigned long ptr) static void __rpc_atrun(struct rpc_task *task) { - task->tk_status = 0; + if (task->tk_status == -ETIMEDOUT) + task->tk_status = 0; } /* @@ -590,6 +659,27 @@ void rpc_prepare_task(struct rpc_task *task) task->tk_ops->rpc_call_prepare(task, task->tk_calldata); } +static void +rpc_init_task_statistics(struct rpc_task *task) +{ + /* Initialize retry counters */ + task->tk_garb_retry = 2; + task->tk_cred_retry = 2; + task->tk_rebind_retry = 2; + + /* starting timestamp */ + task->tk_start = ktime_get(); +} + +static void +rpc_reset_task_statistics(struct rpc_task *task) +{ + task->tk_timeouts = 0; + task->tk_flags &= ~(RPC_CALL_MAJORSEEN|RPC_TASK_KILLED|RPC_TASK_SENT); + + rpc_init_task_statistics(task); +} + /* * Helper that calls task->tk_ops->rpc_call_done if it exists */ @@ -602,10 +692,19 @@ void rpc_exit_task(struct rpc_task *task) WARN_ON(RPC_ASSASSINATED(task)); /* Always release the RPC slot and buffer memory */ xprt_release(task); + rpc_reset_task_statistics(task); } } } -EXPORT_SYMBOL_GPL(rpc_exit_task); + +void rpc_exit(struct rpc_task *task, int status) +{ + task->tk_status = status; + task->tk_action = rpc_exit_task; + if (RPC_IS_QUEUED(task)) + rpc_wake_up_queued_task(task->tk_waitqueue, task); +} +EXPORT_SYMBOL_GPL(rpc_exit); void rpc_release_calldata(const struct rpc_call_ops *ops, void *calldata) { @@ -625,35 +724,31 @@ static void __rpc_execute(struct rpc_task *task) dprintk("RPC: %5u __rpc_execute flags=0x%x\n", task->tk_pid, task->tk_flags); - BUG_ON(RPC_IS_QUEUED(task)); + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + if (RPC_IS_QUEUED(task)) + return; for (;;) { + void (*do_action)(struct rpc_task *); /* - * Execute any pending callback. + * Execute any pending callback first. */ - if (task->tk_callback) { - void (*save_callback)(struct rpc_task *); - + do_action = task->tk_callback; + task->tk_callback = NULL; + if (do_action == NULL) { /* - * We set tk_callback to NULL before calling it, - * in case it sets the tk_callback field itself: + * Perform the next FSM step. + * tk_action may be NULL if the task has been killed. + * In particular, note that rpc_killall_tasks may + * do this at any time, so beware when dereferencing. 
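rpc_exit() is exported here and may be called on a running or sleeping task; if the task is queued it is woken so that rpc_exit_task() runs promptly. The usual kill pattern, which mirrors rpc_killall_tasks() in clnt.c:

        if (!(task->tk_flags & RPC_TASK_KILLED)) {
                task->tk_flags |= RPC_TASK_KILLED;
                rpc_exit(task, -EIO);
        }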
*/ - save_callback = task->tk_callback; - task->tk_callback = NULL; - save_callback(task); - } - - /* - * Perform the next FSM step. - * tk_action may be NULL when the task has been killed - * by someone else. - */ - if (!RPC_IS_QUEUED(task)) { - if (task->tk_action == NULL) + do_action = task->tk_action; + if (do_action == NULL) break; - task->tk_action(task); } + trace_rpc_task_run_action(task->tk_client, task, task->tk_action); + do_action(task); /* * Lockless check for whether task is sleeping or not. @@ -695,9 +790,7 @@ static void __rpc_execute(struct rpc_task *task) dprintk("RPC: %5u got signal\n", task->tk_pid); task->tk_flags |= RPC_TASK_KILLED; rpc_exit(task, -ERESTARTSYS); - rpc_wake_up_task(task); } - rpc_set_running(task); dprintk("RPC: %5u sync task resuming\n", task->tk_pid); } @@ -718,14 +811,19 @@ static void __rpc_execute(struct rpc_task *task) */ void rpc_execute(struct rpc_task *task) { + bool is_async = RPC_IS_ASYNC(task); + rpc_set_active(task); - rpc_set_running(task); - __rpc_execute(task); + rpc_make_runnable(task); + if (!is_async) + __rpc_execute(task); } static void rpc_async_schedule(struct work_struct *work) { + current->flags |= PF_FSTRANS; __rpc_execute(container_of(work, struct rpc_task, u.tk_work)); + current->flags &= ~PF_FSTRANS; } /** @@ -734,7 +832,8 @@ static void rpc_async_schedule(struct work_struct *work) * @size: requested byte size * * To prevent rpciod from hanging, this allocator never sleeps, - * returning NULL if the request cannot be serviced immediately. + * returning NULL and suppressing warning if the request cannot be serviced + * immediately. * The caller can arrange to sleep in a way that is safe for rpciod. * * Most requests are 'small' (under 2KiB) and can be serviced from a @@ -747,7 +846,10 @@ static void rpc_async_schedule(struct work_struct *work) void *rpc_malloc(struct rpc_task *task, size_t size) { struct rpc_buffer *buf; - gfp_t gfp = RPC_IS_SWAPPER(task) ? 
GFP_ATOMIC : GFP_NOWAIT; + gfp_t gfp = GFP_NOWAIT | __GFP_NOWARN; + + if (RPC_IS_SWAPPER(task)) + gfp |= __GFP_MEMALLOC; size += sizeof(struct rpc_buffer); if (size <= RPC_BUFFER_MAXSIZE) @@ -803,38 +905,16 @@ static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *ta task->tk_calldata = task_setup_data->callback_data; INIT_LIST_HEAD(&task->tk_task); - /* Initialize retry counters */ - task->tk_garb_retry = 2; - task->tk_cred_retry = 2; - task->tk_priority = task_setup_data->priority - RPC_PRIORITY_LOW; task->tk_owner = current->tgid; /* Initialize workqueue for async tasks */ task->tk_workqueue = task_setup_data->workqueue; - task->tk_client = task_setup_data->rpc_client; - if (task->tk_client != NULL) { - kref_get(&task->tk_client->cl_kref); - if (task->tk_client->cl_softrtry) - task->tk_flags |= RPC_TASK_SOFT; - } - if (task->tk_ops->rpc_call_prepare != NULL) task->tk_action = rpc_prepare_task; - if (task_setup_data->rpc_message != NULL) { - task->tk_msg.rpc_proc = task_setup_data->rpc_message->rpc_proc; - task->tk_msg.rpc_argp = task_setup_data->rpc_message->rpc_argp; - task->tk_msg.rpc_resp = task_setup_data->rpc_message->rpc_resp; - /* Bind the user cred */ - rpcauth_bindcred(task, task_setup_data->rpc_message->rpc_cred, task_setup_data->flags); - if (task->tk_action == NULL) - rpc_call_start(task); - } - - /* starting timestamp */ - task->tk_start = jiffies; + rpc_init_task_statistics(task); dprintk("RPC: new task initialized, procpid %u\n", task_pid_nr(current)); @@ -843,7 +923,7 @@ static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *ta static struct rpc_task * rpc_alloc_task(void) { - return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOFS); + return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOIO); } /* @@ -856,29 +936,49 @@ struct rpc_task *rpc_new_task(const struct rpc_task_setup *setup_data) if (task == NULL) { task = rpc_alloc_task(); - if (task == NULL) - goto out; + if (task == NULL) { + rpc_release_calldata(setup_data->callback_ops, + setup_data->callback_data); + return ERR_PTR(-ENOMEM); + } flags = RPC_TASK_DYNAMIC; } rpc_init_task(task, setup_data); - task->tk_flags |= flags; dprintk("RPC: allocated task %p\n", task); -out: return task; } +/* + * rpc_free_task - release rpc task and perform cleanups + * + * Note that we free up the rpc_task _after_ rpc_release_calldata() + * in order to work around a workqueue dependency issue. + * + * Tejun Heo states: + * "Workqueue currently considers two work items to be the same if they're + * on the same address and won't execute them concurrently - ie. it + * makes a work item which is queued again while being executed wait + * for the previous execution to complete. + * + * If a work function frees the work item, and then waits for an event + * which should be performed by another work item and *that* work item + * recycles the freed work item, it can create a false dependency loop. + * There really is no reliable way to detect this short of verifying + * every memory free." 
+ * + */ static void rpc_free_task(struct rpc_task *task) { - const struct rpc_call_ops *tk_ops = task->tk_ops; - void *calldata = task->tk_calldata; + unsigned short tk_flags = task->tk_flags; - if (task->tk_flags & RPC_TASK_DYNAMIC) { + rpc_release_calldata(task->tk_ops, task->tk_calldata); + + if (tk_flags & RPC_TASK_DYNAMIC) { dprintk("RPC: %5u freeing task\n", task->tk_pid); mempool_free(task, rpc_task_mempool); } - rpc_release_calldata(tk_ops, calldata); } static void rpc_async_release(struct work_struct *work) @@ -886,80 +986,69 @@ static void rpc_async_release(struct work_struct *work) rpc_free_task(container_of(work, struct rpc_task, u.tk_work)); } -void rpc_put_task(struct rpc_task *task) +static void rpc_release_resources_task(struct rpc_task *task) { - if (!atomic_dec_and_test(&task->tk_count)) - return; - /* Release resources */ - if (task->tk_rqstp) - xprt_release(task); - if (task->tk_msg.rpc_cred) - rpcauth_unbindcred(task); - if (task->tk_client) { - rpc_release_client(task->tk_client); - task->tk_client = NULL; + xprt_release(task); + if (task->tk_msg.rpc_cred) { + put_rpccred(task->tk_msg.rpc_cred); + task->tk_msg.rpc_cred = NULL; } - if (task->tk_workqueue != NULL) { + rpc_task_release_client(task); +} + +static void rpc_final_put_task(struct rpc_task *task, + struct workqueue_struct *q) +{ + if (q != NULL) { INIT_WORK(&task->u.tk_work, rpc_async_release); - queue_work(task->tk_workqueue, &task->u.tk_work); + queue_work(q, &task->u.tk_work); } else rpc_free_task(task); } -EXPORT_SYMBOL_GPL(rpc_put_task); -static void rpc_release_task(struct rpc_task *task) +static void rpc_do_put_task(struct rpc_task *task, struct workqueue_struct *q) { -#ifdef RPC_DEBUG - BUG_ON(task->tk_magic != RPC_TASK_MAGIC_ID); -#endif - dprintk("RPC: %5u release task\n", task->tk_pid); - - if (!list_empty(&task->tk_task)) { - struct rpc_clnt *clnt = task->tk_client; - /* Remove from client task list */ - spin_lock(&clnt->cl_lock); - list_del(&task->tk_task); - spin_unlock(&clnt->cl_lock); + if (atomic_dec_and_test(&task->tk_count)) { + rpc_release_resources_task(task); + rpc_final_put_task(task, q); } - BUG_ON (RPC_IS_QUEUED(task)); +} -#ifdef RPC_DEBUG - task->tk_magic = 0; -#endif - /* Wake up anyone who is waiting for task completion */ - rpc_mark_complete_task(task); +void rpc_put_task(struct rpc_task *task) +{ + rpc_do_put_task(task, NULL); +} +EXPORT_SYMBOL_GPL(rpc_put_task); - rpc_put_task(task); +void rpc_put_task_async(struct rpc_task *task) +{ + rpc_do_put_task(task, task->tk_workqueue); } +EXPORT_SYMBOL_GPL(rpc_put_task_async); -/* - * Kill all tasks for the given client. - * XXX: kill their descendants as well? - */ -void rpc_killall_tasks(struct rpc_clnt *clnt) +static void rpc_release_task(struct rpc_task *task) { - struct rpc_task *rovr; + dprintk("RPC: %5u release task\n", task->tk_pid); + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + + rpc_release_resources_task(task); - if (list_empty(&clnt->cl_tasks)) - return; - dprintk("RPC: killing all tasks for client %p\n", clnt); /* - * Spin lock all_tasks to prevent changes... + * Note: at this point we have been removed from rpc_clnt->cl_tasks, + * so it should be safe to use task->tk_count as a test for whether + * or not any other processes still hold references to our rpc_task. */ - spin_lock(&clnt->cl_lock); - list_for_each_entry(rovr, &clnt->cl_tasks, tk_task) { - if (! 
RPC_IS_ACTIVATED(rovr)) - continue; - if (!(rovr->tk_flags & RPC_TASK_KILLED)) { - rovr->tk_flags |= RPC_TASK_KILLED; - rpc_exit(rovr, -EIO); - rpc_wake_up_task(rovr); - } + if (atomic_read(&task->tk_count) != 1 + !RPC_IS_ASYNC(task)) { + /* Wake up anyone who may be waiting for task completion */ + if (!rpc_complete_task(task)) + return; + } else { + if (!atomic_dec_and_test(&task->tk_count)) + return; } - spin_unlock(&clnt->cl_lock); + rpc_final_put_task(task, task->tk_workqueue); } -EXPORT_SYMBOL_GPL(rpc_killall_tasks); int rpciod_up(void) { @@ -982,7 +1071,7 @@ static int rpciod_start(void) * Create the rpciod thread and wait for it to start. */ dprintk("RPC: creating workqueue rpciod\n"); - wq = create_workqueue("rpciod"); + wq = alloc_workqueue("rpciod", WQ_MEM_RECLAIM, 1); rpciod_workqueue = wq; return rpciod_workqueue != NULL; } diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c index a661a3acb37..2df87f78e51 100644 --- a/net/sunrpc/socklib.c +++ b/net/sunrpc/socklib.c @@ -8,11 +8,13 @@ #include <linux/compiler.h> #include <linux/netdevice.h> +#include <linux/gfp.h> #include <linux/skbuff.h> #include <linux/types.h> #include <linux/pagemap.h> #include <linux/udp.h> #include <linux/sunrpc/xdr.h> +#include <linux/export.h> /** @@ -112,7 +114,7 @@ ssize_t xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct } len = PAGE_CACHE_SIZE; - kaddr = kmap_atomic(*ppage, KM_SKB_SUNRPC_DATA); + kaddr = kmap_atomic(*ppage); if (base) { len -= base; if (pglen < len) @@ -125,7 +127,7 @@ ssize_t xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct ret = copy_actor(desc, kaddr, len); } flush_dcache_page(*ppage); - kunmap_atomic(kaddr, KM_SKB_SUNRPC_DATA); + kunmap_atomic(kaddr); copied += ret; if (ret != len || !desc->count) goto out; @@ -171,7 +173,8 @@ int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) return -1; if (csum_fold(desc.csum)) return -1; - if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE)) + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) netdev_rx_csum_fault(skb->dev); return 0; no_checksum: diff --git a/net/sunrpc/stats.c b/net/sunrpc/stats.c index 1b4e6791ecf..54530490944 100644 --- a/net/sunrpc/stats.c +++ b/net/sunrpc/stats.c @@ -13,6 +13,7 @@ */ #include <linux/module.h> +#include <linux/slab.h> #include <linux/init.h> #include <linux/kernel.h> @@ -21,11 +22,11 @@ #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/svcsock.h> #include <linux/sunrpc/metrics.h> -#include <net/net_namespace.h> +#include <linux/rcupdate.h> -#define RPCDBG_FACILITY RPCDBG_MISC +#include "netns.h" -struct proc_dir_entry *proc_net_rpc = NULL; +#define RPCDBG_FACILITY RPCDBG_MISC /* * Get RPC client stats @@ -63,7 +64,7 @@ static int rpc_proc_show(struct seq_file *seq, void *v) { static int rpc_proc_open(struct inode *inode, struct file *file) { - return single_open(file, rpc_proc_show, PDE(inode)->data); + return single_open(file, rpc_proc_show, PDE_DATA(inode)); } static const struct file_operations rpc_proc_fops = { @@ -115,9 +116,7 @@ EXPORT_SYMBOL_GPL(svc_seq_show); */ struct rpc_iostats *rpc_alloc_iostats(struct rpc_clnt *clnt) { - struct rpc_iostats *new; - new = kcalloc(clnt->cl_maxproc, sizeof(struct rpc_iostats), GFP_KERNEL); - return new; + return kcalloc(clnt->cl_maxproc, sizeof(struct rpc_iostats), GFP_KERNEL); } EXPORT_SYMBOL_GPL(rpc_alloc_iostats); @@ -135,44 +134,37 @@ EXPORT_SYMBOL_GPL(rpc_free_iostats); /** * rpc_count_iostats - tally up per-task stats * @task: completed rpc_task + * 
@stats: array of stat structures * * Relies on the caller for serialization. */ -void rpc_count_iostats(struct rpc_task *task) +void rpc_count_iostats(const struct rpc_task *task, struct rpc_iostats *stats) { struct rpc_rqst *req = task->tk_rqstp; - struct rpc_iostats *stats; struct rpc_iostats *op_metrics; - long rtt, execute, queue; + ktime_t delta; - if (!task->tk_client || !task->tk_client->cl_metrics || !req) + if (!stats || !req) return; - stats = task->tk_client->cl_metrics; op_metrics = &stats[task->tk_msg.rpc_proc->p_statidx]; op_metrics->om_ops++; op_metrics->om_ntrans += req->rq_ntrans; op_metrics->om_timeouts += task->tk_timeouts; - op_metrics->om_bytes_sent += task->tk_bytes_sent; + op_metrics->om_bytes_sent += req->rq_xmit_bytes_sent; op_metrics->om_bytes_recv += req->rq_reply_bytes_recvd; - queue = (long)req->rq_xtime - task->tk_start; - if (queue < 0) - queue = -queue; - op_metrics->om_queue += queue; + delta = ktime_sub(req->rq_xtime, task->tk_start); + op_metrics->om_queue = ktime_add(op_metrics->om_queue, delta); - rtt = task->tk_rtt; - if (rtt < 0) - rtt = -rtt; - op_metrics->om_rtt += rtt; + op_metrics->om_rtt = ktime_add(op_metrics->om_rtt, req->rq_rtt); - execute = (long)jiffies - task->tk_start; - if (execute < 0) - execute = -execute; - op_metrics->om_execute += execute; + delta = ktime_sub(ktime_get(), task->tk_start); + op_metrics->om_execute = ktime_add(op_metrics->om_execute, delta); } +EXPORT_SYMBOL_GPL(rpc_count_iostats); static void _print_name(struct seq_file *seq, unsigned int op, struct rpc_procinfo *procs) @@ -185,12 +177,10 @@ static void _print_name(struct seq_file *seq, unsigned int op, seq_printf(seq, "\t%12u: ", op); } -#define MILLISECS_PER_JIFFY (1000 / HZ) - void rpc_print_iostats(struct seq_file *seq, struct rpc_clnt *clnt) { struct rpc_iostats *stats = clnt->cl_metrics; - struct rpc_xprt *xprt = clnt->cl_xprt; + struct rpc_xprt *xprt; unsigned int op, maxproc = clnt->cl_maxproc; if (!stats) @@ -198,10 +188,13 @@ void rpc_print_iostats(struct seq_file *seq, struct rpc_clnt *clnt) seq_printf(seq, "\tRPC iostats version: %s ", RPC_IOSTATS_VERS); seq_printf(seq, "p/v: %u/%u (%s)\n", - clnt->cl_prog, clnt->cl_vers, clnt->cl_protname); + clnt->cl_prog, clnt->cl_vers, clnt->cl_program->name); + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); if (xprt) xprt->ops->print_stats(xprt, seq); + rcu_read_unlock(); seq_printf(seq, "\tper-op statistics\n"); for (op = 0; op < maxproc; op++) { @@ -213,9 +206,9 @@ void rpc_print_iostats(struct seq_file *seq, struct rpc_clnt *clnt) metrics->om_timeouts, metrics->om_bytes_sent, metrics->om_bytes_recv, - metrics->om_queue * MILLISECS_PER_JIFFY, - metrics->om_rtt * MILLISECS_PER_JIFFY, - metrics->om_execute * MILLISECS_PER_JIFFY); + ktime_to_ms(metrics->om_queue), + ktime_to_ms(metrics->om_rtt), + ktime_to_ms(metrics->om_execute)); } } EXPORT_SYMBOL_GPL(rpc_print_iostats); @@ -224,57 +217,66 @@ EXPORT_SYMBOL_GPL(rpc_print_iostats); * Register/unregister RPC proc files */ static inline struct proc_dir_entry * -do_register(const char *name, void *data, const struct file_operations *fops) +do_register(struct net *net, const char *name, void *data, + const struct file_operations *fops) { - rpc_proc_init(); - dprintk("RPC: registering /proc/net/rpc/%s\n", name); + struct sunrpc_net *sn; - return proc_create_data(name, 0, proc_net_rpc, fops, data); + dprintk("RPC: registering /proc/net/rpc/%s\n", name); + sn = net_generic(net, sunrpc_net_id); + return proc_create_data(name, 0, sn->proc_net_rpc, fops, data); } 
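The rpc_count_iostats() hunks above replace signed jiffies arithmetic, with its absolute-value fix-ups, by ktime_t intervals that are converted to milliseconds only at print time. A minimal sketch of that accounting pattern; the names op_metrics, account_op and report_op are illustrative, not part of the SUNRPC API:

#include <linux/ktime.h>
#include <linux/seq_file.h>

struct op_metrics {
	ktime_t om_execute;	/* accumulated wall-clock time */
	unsigned long om_ops;	/* number of operations counted */
};

static void account_op(struct op_metrics *m, ktime_t start)
{
	/* Elapsed time since 'start', kept in ktime_t. */
	ktime_t delta = ktime_sub(ktime_get(), start);

	m->om_execute = ktime_add(m->om_execute, delta);
	m->om_ops++;
}

static void report_op(struct seq_file *seq, const struct op_metrics *m)
{
	/* Convert to a human unit only when reporting. */
	seq_printf(seq, "%lu ops, %lld ms\n",
		   m->om_ops, ktime_to_ms(m->om_execute));
}

Keeping the totals in ktime_t until the seq_file handler runs avoids the HZ-dependent rounding that the removed MILLISECS_PER_JIFFY conversion suffered from.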
struct proc_dir_entry * -rpc_proc_register(struct rpc_stat *statp) +rpc_proc_register(struct net *net, struct rpc_stat *statp) { - return do_register(statp->program->name, statp, &rpc_proc_fops); + return do_register(net, statp->program->name, statp, &rpc_proc_fops); } EXPORT_SYMBOL_GPL(rpc_proc_register); void -rpc_proc_unregister(const char *name) +rpc_proc_unregister(struct net *net, const char *name) { - remove_proc_entry(name, proc_net_rpc); + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + remove_proc_entry(name, sn->proc_net_rpc); } EXPORT_SYMBOL_GPL(rpc_proc_unregister); struct proc_dir_entry * -svc_proc_register(struct svc_stat *statp, const struct file_operations *fops) +svc_proc_register(struct net *net, struct svc_stat *statp, const struct file_operations *fops) { - return do_register(statp->program->pg_name, statp, fops); + return do_register(net, statp->program->pg_name, statp, fops); } EXPORT_SYMBOL_GPL(svc_proc_register); void -svc_proc_unregister(const char *name) +svc_proc_unregister(struct net *net, const char *name) { - remove_proc_entry(name, proc_net_rpc); + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + remove_proc_entry(name, sn->proc_net_rpc); } EXPORT_SYMBOL_GPL(svc_proc_unregister); -void -rpc_proc_init(void) +int rpc_proc_init(struct net *net) { + struct sunrpc_net *sn; + dprintk("RPC: registering /proc/net/rpc\n"); - if (!proc_net_rpc) - proc_net_rpc = proc_mkdir("rpc", init_net.proc_net); + sn = net_generic(net, sunrpc_net_id); + sn->proc_net_rpc = proc_mkdir("rpc", net->proc_net); + if (sn->proc_net_rpc == NULL) + return -ENOMEM; + + return 0; } -void -rpc_proc_exit(void) +void rpc_proc_exit(struct net *net) { dprintk("RPC: unregistering /proc/net/rpc\n"); - if (proc_net_rpc) { - proc_net_rpc = NULL; - remove_proc_entry("rpc", init_net.proc_net); - } + remove_proc_entry("rpc", net->proc_net); } diff --git a/net/sunrpc/sunrpc.h b/net/sunrpc/sunrpc.h index 90c292e2738..f2b7cb540e6 100644 --- a/net/sunrpc/sunrpc.h +++ b/net/sunrpc/sunrpc.h @@ -43,9 +43,24 @@ static inline int rpc_reply_expected(struct rpc_task *task) (task->tk_msg.rpc_proc->p_decode != NULL); } +static inline int sock_is_loopback(struct sock *sk) +{ + struct dst_entry *dst; + int loopback = 0; + rcu_read_lock(); + dst = rcu_dereference(sk->sk_dst_cache); + if (dst && dst->dev && + (dst->dev->features & NETIF_F_LOOPBACK)) + loopback = 1; + rcu_read_unlock(); + return loopback; +} + int svc_send_common(struct socket *sock, struct xdr_buf *xdr, struct page *headpage, unsigned long headoffset, struct page *tailpage, unsigned long tailoffset); +int rpc_clients_notifier_register(void); +void rpc_clients_notifier_unregister(void); #endif /* _NET_SUNRPC_SUNRPC_H */ diff --git a/net/sunrpc/sunrpc_syms.c b/net/sunrpc/sunrpc_syms.c index f438347d817..cd30120de9e 100644 --- a/net/sunrpc/sunrpc_syms.c +++ b/net/sunrpc/sunrpc_syms.c @@ -22,32 +22,94 @@ #include <linux/sunrpc/rpc_pipe_fs.h> #include <linux/sunrpc/xprtsock.h> -extern struct cache_detail ip_map_cache, unix_gid_cache; +#include "netns.h" -extern void cleanup_rpcb_clnt(void); +int sunrpc_net_id; +EXPORT_SYMBOL_GPL(sunrpc_net_id); + +static __net_init int sunrpc_init_net(struct net *net) +{ + int err; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + err = rpc_proc_init(net); + if (err) + goto err_proc; + + err = ip_map_cache_create(net); + if (err) + goto err_ipmap; + + err = unix_gid_cache_create(net); + if (err) + goto err_unixgid; + + err = rpc_pipefs_init_net(net); + if (err) + goto 
err_pipefs; + + INIT_LIST_HEAD(&sn->all_clients); + spin_lock_init(&sn->rpc_client_lock); + spin_lock_init(&sn->rpcb_clnt_lock); + return 0; + +err_pipefs: + unix_gid_cache_destroy(net); +err_unixgid: + ip_map_cache_destroy(net); +err_ipmap: + rpc_proc_exit(net); +err_proc: + return err; +} + +static __net_exit void sunrpc_exit_net(struct net *net) +{ + rpc_pipefs_exit_net(net); + unix_gid_cache_destroy(net); + ip_map_cache_destroy(net); + rpc_proc_exit(net); +} + +static struct pernet_operations sunrpc_net_ops = { + .init = sunrpc_init_net, + .exit = sunrpc_exit_net, + .id = &sunrpc_net_id, + .size = sizeof(struct sunrpc_net), +}; static int __init init_sunrpc(void) { - int err = register_rpc_pipefs(); + int err = rpc_init_mempool(); if (err) goto out; - err = rpc_init_mempool(); - if (err) { - unregister_rpc_pipefs(); - goto out; - } + err = rpcauth_init_module(); + if (err) + goto out2; + + cache_initialize(); + + err = register_pernet_subsys(&sunrpc_net_ops); + if (err) + goto out3; + + err = register_rpc_pipefs(); + if (err) + goto out4; #ifdef RPC_DEBUG rpc_register_sysctl(); #endif -#ifdef CONFIG_PROC_FS - rpc_proc_init(); -#endif - cache_register(&ip_map_cache); - cache_register(&unix_gid_cache); svc_init_xprt_sock(); /* svc sock transport */ init_socket_xprt(); /* clnt sock transport */ - rpcauth_init_module(); + return 0; + +out4: + unregister_pernet_subsys(&sunrpc_net_ops); +out3: + rpcauth_remove_module(); +out2: + rpc_destroy_mempool(); out: return err; } @@ -55,20 +117,15 @@ out: static void __exit cleanup_sunrpc(void) { - cleanup_rpcb_clnt(); rpcauth_remove_module(); cleanup_socket_xprt(); svc_cleanup_xprt_sock(); unregister_rpc_pipefs(); rpc_destroy_mempool(); - cache_unregister(&ip_map_cache); - cache_unregister(&unix_gid_cache); + unregister_pernet_subsys(&sunrpc_net_ops); #ifdef RPC_DEBUG rpc_unregister_sysctl(); #endif -#ifdef CONFIG_PROC_FS - rpc_proc_exit(); -#endif rcu_barrier(); /* Wait for completion of call_rcu()'s */ } MODULE_LICENSE("GPL"); diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c index 8420a4205b7..5de6801cd92 100644 --- a/net/sunrpc/svc.c +++ b/net/sunrpc/svc.c @@ -19,6 +19,7 @@ #include <linux/interrupt.h> #include <linux/module.h> #include <linux/kthread.h> +#include <linux/slab.h> #include <linux/sunrpc/types.h> #include <linux/sunrpc/xdr.h> @@ -29,7 +30,7 @@ #define RPCDBG_FACILITY RPCDBG_SVCDSP -static void svc_unregister(const struct svc_serv *serv); +static void svc_unregister(const struct svc_serv *serv, struct net *net); #define svc_serv_is_pooled(serv) ((serv)->sv_function) @@ -166,6 +167,7 @@ svc_pool_map_alloc_arrays(struct svc_pool_map *m, unsigned int maxpools) fail_free: kfree(m->to_pool); + m->to_pool = NULL; fail: return -ENOMEM; } @@ -284,9 +286,10 @@ svc_pool_map_put(void) mutex_lock(&svc_pool_map_mutex); if (!--m->count) { - m->mode = SVC_POOL_DEFAULT; kfree(m->to_pool); + m->to_pool = NULL; kfree(m->pool_to); + m->pool_to = NULL; m->npools = 0; } @@ -294,6 +297,18 @@ svc_pool_map_put(void) } +static int svc_pool_map_get_node(unsigned int pidx) +{ + const struct svc_pool_map *m = &svc_pool_map; + + if (m->count) { + if (m->mode == SVC_POOL_PERCPU) + return cpu_to_node(m->pool_to[pidx]); + if (m->mode == SVC_POOL_PERNODE) + return m->pool_to[pidx]; + } + return NUMA_NO_NODE; +} /* * Set the given thread's cpus_allowed mask so that it * will only run on cpus in the given pool. 
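The sunrpc_net_ops block above is the kernel's standard per-network-namespace pattern: a pernet_operations with .id and .size, and net_generic() to reach the per-namespace state. A minimal self-contained sketch under hypothetical names (example_net, example_net_ops); the real structure here is struct sunrpc_net:

#include <linux/module.h>
#include <net/net_namespace.h>
#include <net/netns/generic.h>

struct example_net {
	int counter;		/* whatever per-namespace state is needed */
};

static int example_net_id;

static __net_init int example_init_net(struct net *net)
{
	/* .size below means the core allocated this for us already. */
	struct example_net *en = net_generic(net, example_net_id);

	en->counter = 0;
	return 0;
}

static __net_exit void example_exit_net(struct net *net)
{
	/* Undo whatever example_init_net() set up. */
}

static struct pernet_operations example_net_ops = {
	.init = example_init_net,
	.exit = example_exit_net,
	.id   = &example_net_id,
	.size = sizeof(struct example_net),
};

static int __init example_init(void)
{
	/* Also runs .init for every namespace that already exists. */
	return register_pernet_subsys(&example_net_ops);
}

static void __exit example_exit(void)
{
	unregister_pernet_subsys(&example_net_ops);
}

module_init(example_init);
module_exit(example_exit);
MODULE_LICENSE("GPL");

Because register_pernet_subsys() invokes .init for current and future namespaces alike, init_sunrpc() no longer has to call rpc_proc_init() or register the caches by hand; the per-net hooks do that for each namespace.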
@@ -308,7 +323,9 @@ svc_pool_map_set_cpumask(struct task_struct *task, unsigned int pidx) * The caller checks for sv_nrpools > 1, which * implies that we've been initialized. */ - BUG_ON(m->count == 0); + WARN_ON_ONCE(m->count == 0); + if (m->count == 0) + return; switch (m->mode) { case SVC_POOL_PERCPU: @@ -353,13 +370,58 @@ svc_pool_for_cpu(struct svc_serv *serv, int cpu) return &serv->sv_pools[pidx % serv->sv_nrpools]; } +int svc_rpcb_setup(struct svc_serv *serv, struct net *net) +{ + int err; + + err = rpcb_create_local(net); + if (err) + return err; + + /* Remove any stale portmap registrations */ + svc_unregister(serv, net); + return 0; +} +EXPORT_SYMBOL_GPL(svc_rpcb_setup); + +void svc_rpcb_cleanup(struct svc_serv *serv, struct net *net) +{ + svc_unregister(serv, net); + rpcb_put_local(net); +} +EXPORT_SYMBOL_GPL(svc_rpcb_cleanup); + +static int svc_uses_rpcbind(struct svc_serv *serv) +{ + struct svc_program *progp; + unsigned int i; + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + if (progp->pg_vers[i] == NULL) + continue; + if (progp->pg_vers[i]->vs_hidden == 0) + return 1; + } + } + + return 0; +} + +int svc_bind(struct svc_serv *serv, struct net *net) +{ + if (!svc_uses_rpcbind(serv)) + return 0; + return svc_rpcb_setup(serv, net); +} +EXPORT_SYMBOL_GPL(svc_bind); /* * Create an RPC service */ static struct svc_serv * __svc_create(struct svc_program *prog, unsigned int bufsize, int npools, - void (*shutdown)(struct svc_serv *serv)) + void (*shutdown)(struct svc_serv *serv, struct net *net)) { struct svc_serv *serv; unsigned int vers; @@ -418,15 +480,15 @@ __svc_create(struct svc_program *prog, unsigned int bufsize, int npools, spin_lock_init(&pool->sp_lock); } - /* Remove any stale portmap registrations */ - svc_unregister(serv); + if (svc_uses_rpcbind(serv) && (!serv->sv_shutdown)) + serv->sv_shutdown = svc_rpcb_cleanup; return serv; } struct svc_serv * svc_create(struct svc_program *prog, unsigned int bufsize, - void (*shutdown)(struct svc_serv *serv)) + void (*shutdown)(struct svc_serv *serv, struct net *net)) { return __svc_create(prog, bufsize, /*npools*/1, shutdown); } @@ -434,7 +496,7 @@ EXPORT_SYMBOL_GPL(svc_create); struct svc_serv * svc_create_pooled(struct svc_program *prog, unsigned int bufsize, - void (*shutdown)(struct svc_serv *serv), + void (*shutdown)(struct svc_serv *serv, struct net *net), svc_thread_fn func, struct module *mod) { struct svc_serv *serv; @@ -451,6 +513,15 @@ svc_create_pooled(struct svc_program *prog, unsigned int bufsize, } EXPORT_SYMBOL_GPL(svc_create_pooled); +void svc_shutdown_net(struct svc_serv *serv, struct net *net) +{ + svc_close_net(serv, net); + + if (serv->sv_shutdown) + serv->sv_shutdown(serv, net); +} +EXPORT_SYMBOL_GPL(svc_shutdown_net); + /* * Destroy an RPC service. Should be called with appropriate locking to * protect the sv_nrthreads, sv_permsocks and sv_tempsocks. @@ -472,13 +543,10 @@ svc_destroy(struct svc_serv *serv) del_timer_sync(&serv->sv_temptimer); - svc_close_all(&serv->sv_tempsocks); - - if (serv->sv_shutdown) - serv->sv_shutdown(serv); - - svc_close_all(&serv->sv_permsocks); - + /* + * The last user is gone, so all sockets should already have been + * closed by this point. Check this.
+ */ BUG_ON(!list_empty(&serv->sv_permsocks)); BUG_ON(!list_empty(&serv->sv_tempsocks)); @@ -487,11 +555,6 @@ svc_destroy(struct svc_serv *serv) if (svc_serv_is_pooled(serv)) svc_pool_map_put(); -#if defined(CONFIG_NFS_V4_1) - svc_sock_destroy(serv->bc_xprt); -#endif /* CONFIG_NFS_V4_1 */ - - svc_unregister(serv); kfree(serv->sv_pools); kfree(serv); } @@ -502,7 +565,7 @@ EXPORT_SYMBOL_GPL(svc_destroy); * We allocate pages and place them in rq_argpages. */ static int -svc_init_buffer(struct svc_rqst *rqstp, unsigned int size) +svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node) { unsigned int pages, arghi; @@ -514,9 +577,11 @@ svc_init_buffer(struct svc_rqst *rqstp, unsigned int size) * We assume one is at most one page */ arghi = 0; - BUG_ON(pages > RPCSVC_MAXPAGES); + WARN_ON_ONCE(pages > RPCSVC_MAXPAGES); + if (pages > RPCSVC_MAXPAGES) + pages = RPCSVC_MAXPAGES; while (pages) { - struct page *p = alloc_page(GFP_KERNEL); + struct page *p = alloc_pages_node(node, GFP_KERNEL, 0); if (!p) break; rqstp->rq_pages[arghi++] = p; @@ -539,11 +604,11 @@ svc_release_buffer(struct svc_rqst *rqstp) } struct svc_rqst * -svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool) +svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) { struct svc_rqst *rqstp; - rqstp = kzalloc(sizeof(*rqstp), GFP_KERNEL); + rqstp = kzalloc_node(sizeof(*rqstp), GFP_KERNEL, node); if (!rqstp) goto out_enomem; @@ -557,15 +622,15 @@ svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool) rqstp->rq_server = serv; rqstp->rq_pool = pool; - rqstp->rq_argp = kmalloc(serv->sv_xdrsize, GFP_KERNEL); + rqstp->rq_argp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); if (!rqstp->rq_argp) goto out_thread; - rqstp->rq_resp = kmalloc(serv->sv_xdrsize, GFP_KERNEL); + rqstp->rq_resp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); if (!rqstp->rq_resp) goto out_thread; - if (!svc_init_buffer(rqstp, serv->sv_max_mesg)) + if (!svc_init_buffer(rqstp, serv->sv_max_mesg, node)) goto out_thread; return rqstp; @@ -632,8 +697,8 @@ found_pool: * Create or destroy enough new threads to make the number * of threads the given number. If `pool' is non-NULL, applies * only to threads in that pool, otherwise round-robins between - * all pools. Must be called with a svc_get() reference and - * the BKL or another lock to protect access to svc_serv fields. + * all pools. Caller must ensure mutual exclusion between this and + * server startup or shutdown. * Destroying threads relies on the service threads filling in * rqstp->rq_task, which only the nfs ones do.
Assumes the serv @@ -650,6 +715,7 @@ svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) struct svc_pool *chosen_pool; int error = 0; unsigned int state = serv->sv_nrthreads-1; + int node; if (pool == NULL) { /* The -1 assumes caller has done a svc_get() */ @@ -665,14 +731,16 @@ svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) nrservs--; chosen_pool = choose_pool(serv, pool, &state); - rqstp = svc_prepare_thread(serv, chosen_pool); + node = svc_pool_map_get_node(chosen_pool->sp_id); + rqstp = svc_prepare_thread(serv, chosen_pool, node); if (IS_ERR(rqstp)) { error = PTR_ERR(rqstp); break; } __module_get(serv->sv_module); - task = kthread_create(serv->sv_function, rqstp, serv->sv_name); + task = kthread_create_on_node(serv->sv_function, rqstp, + node, "%s", serv->sv_name); if (IS_ERR(task)) { error = PTR_ERR(task); module_put(serv->sv_module); @@ -736,7 +804,8 @@ EXPORT_SYMBOL_GPL(svc_exit_thread); * Returns zero on success; a negative errno value is returned * if any error occurs. */ -static int __svc_rpcb_register4(const u32 program, const u32 version, +static int __svc_rpcb_register4(struct net *net, const u32 program, + const u32 version, const unsigned short protocol, const unsigned short port) { @@ -759,7 +828,7 @@ static int __svc_rpcb_register4(const u32 program, const u32 version, return -ENOPROTOOPT; } - error = rpcb_v4_register(program, version, + error = rpcb_v4_register(net, program, version, (const struct sockaddr *)&sin, netid); /* @@ -767,12 +836,12 @@ static int __svc_rpcb_register4(const u32 program, const u32 version, * registration request with the legacy rpcbind v2 protocol. */ if (error == -EPROTONOSUPPORT) - error = rpcb_register(program, version, protocol, port); + error = rpcb_register(net, program, version, protocol, port); return error; } -#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) /* * Register an "inet6" protocol family netid with the local * rpcbind daemon via an rpcbind v4 SET request. @@ -783,7 +852,8 @@ static int __svc_rpcb_register4(const u32 program, const u32 version, * Returns zero on success; a negative errno value is returned * if any error occurs. */ -static int __svc_rpcb_register6(const u32 program, const u32 version, +static int __svc_rpcb_register6(struct net *net, const u32 program, + const u32 version, const unsigned short protocol, const unsigned short port) { @@ -806,7 +876,7 @@ static int __svc_rpcb_register6(const u32 program, const u32 version, return -ENOPROTOOPT; } - error = rpcb_v4_register(program, version, + error = rpcb_v4_register(net, program, version, (const struct sockaddr *)&sin6, netid); /* @@ -818,7 +888,7 @@ static int __svc_rpcb_register6(const u32 program, const u32 version, return error; } -#endif /* defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) */ +#endif /* IS_ENABLED(CONFIG_IPV6) */ /* * Register a kernel RPC service via rpcbind version 4. @@ -826,7 +896,7 @@ static int __svc_rpcb_register6(const u32 program, const u32 version, * Returns zero on success; a negative errno value is returned * if any error occurs. 
*/ -static int __svc_register(const char *progname, +static int __svc_register(struct net *net, const char *progname, const u32 program, const u32 version, const int family, const unsigned short protocol, @@ -836,43 +906,46 @@ static int __svc_register(const char *progname, switch (family) { case PF_INET: - error = __svc_rpcb_register4(program, version, + error = __svc_rpcb_register4(net, program, version, protocol, port); break; -#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) case PF_INET6: - error = __svc_rpcb_register6(program, version, + error = __svc_rpcb_register6(net, program, version, protocol, port); -#endif /* defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) */ +#endif } - if (error < 0) - printk(KERN_WARNING "svc: failed to register %sv%u RPC " - "service (errno %d).\n", progname, version, -error); return error; } /** * svc_register - register an RPC service with the local portmapper * @serv: svc_serv struct for the service to register + * @net: net namespace for the service to register * @family: protocol family of service's listener socket * @proto: transport protocol number to advertise * @port: port to advertise * * Service is registered for any address in the passed-in protocol family */ -int svc_register(const struct svc_serv *serv, const int family, - const unsigned short proto, const unsigned short port) +int svc_register(const struct svc_serv *serv, struct net *net, + const int family, const unsigned short proto, + const unsigned short port) { struct svc_program *progp; + struct svc_version *vers; unsigned int i; int error = 0; - BUG_ON(proto == 0 && port == 0); + WARN_ON_ONCE(proto == 0 && port == 0); + if (proto == 0 && port == 0) + return -EINVAL; for (progp = serv->sv_program; progp; progp = progp->pg_next) { for (i = 0; i < progp->pg_nvers; i++) { - if (progp->pg_vers[i] == NULL) + vers = progp->pg_vers[i]; + if (vers == NULL) continue; dprintk("svc: svc_register(%sv%d, %s, %u, %u)%s\n", @@ -881,16 +954,26 @@ int svc_register(const struct svc_serv *serv, const int family, proto == IPPROTO_UDP? "udp" : "tcp", port, family, - progp->pg_vers[i]->vs_hidden? - " (but not telling portmap)" : ""); + vers->vs_hidden ? + " (but not telling portmap)" : ""); - if (progp->pg_vers[i]->vs_hidden) + if (vers->vs_hidden) continue; - error = __svc_register(progp->pg_name, progp->pg_prog, + error = __svc_register(net, progp->pg_name, progp->pg_prog, i, family, proto, port); - if (error < 0) + + if (vers->vs_rpcb_optnl) { + error = 0; + continue; + } + + if (error < 0) { + printk(KERN_WARNING "svc: failed to register " + "%sv%u RPC service (errno %d).\n", + progp->pg_name, i, -error); break; + } } } @@ -904,19 +987,19 @@ int svc_register(const struct svc_serv *serv, const int family, * any "inet6" entries anyway. So a PMAP_UNSET should be sufficient * in this case to clear all existing entries for [program, version]. */ -static void __svc_unregister(const u32 program, const u32 version, +static void __svc_unregister(struct net *net, const u32 program, const u32 version, const char *progname) { int error; - error = rpcb_v4_register(program, version, NULL, ""); + error = rpcb_v4_register(net, program, version, NULL, ""); /* * User space didn't support rpcbind v4, so retry this * request with the legacy rpcbind v2 protocol. 
*/ if (error == -EPROTONOSUPPORT) - error = rpcb_register(program, version, 0, 0); + error = rpcb_register(net, program, version, 0, 0); dprintk("svc: %s(%sv%u), error %d\n", __func__, progname, version, error); @@ -930,7 +1013,7 @@ static void __svc_unregister(const u32 program, const u32 version, * The result of unregistration is reported via dprintk for those who want * verification of the result, but is otherwise not important. */ -static void svc_unregister(const struct svc_serv *serv) +static void svc_unregister(const struct svc_serv *serv, struct net *net) { struct svc_program *progp; unsigned long flags; @@ -945,7 +1028,9 @@ static void svc_unregister(const struct svc_serv *serv) if (progp->pg_vers[i]->vs_hidden) continue; - __svc_unregister(progp->pg_prog, i, progp->pg_name); + dprintk("svc: attempting to unregister %sv%u\n", + progp->pg_name, i); + __svc_unregister(net, progp->pg_prog, i, progp->pg_name); } } @@ -955,28 +1040,28 @@ static void svc_unregister(const struct svc_serv *serv) } /* - * Printk the given error with the address of the client that caused it. + * dprintk the given error with the address of the client that caused it. */ -static int -__attribute__ ((format (printf, 2, 3))) -svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) +#ifdef RPC_DEBUG +static __printf(2, 3) +void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) { + struct va_format vaf; va_list args; - int r; char buf[RPC_MAX_ADDRBUFLEN]; - if (!net_ratelimit()) - return 0; + va_start(args, fmt); - printk(KERN_WARNING "svc: %s: ", - svc_print_addr(rqstp, buf, sizeof(buf))); + vaf.fmt = fmt; + vaf.va = &args; - va_start(args, fmt); - r = vprintk(fmt, args); - va_end(args); + dprintk("svc: %s: %pV", svc_print_addr(rqstp, buf, sizeof(buf)), &vaf); - return r; + va_end(args); } +#else +static __printf(2,3) void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) {} +#endif /* * Common routine for processing the RPC request. 
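The svc_printk() rewrite a few hunks up leans on the printk extension %pV: the caller's format string and va_list are wrapped in a struct va_format and expanded inside a single printk() call. A minimal sketch of the idiom; example_printk is illustrative, not a SUNRPC function:

#include <linux/kernel.h>
#include <linux/printk.h>

static __printf(2, 3)
void example_printk(const char *prefix, const char *fmt, ...)
{
	struct va_format vaf;
	va_list args;

	va_start(args, fmt);
	vaf.fmt = fmt;
	vaf.va = &args;
	/* One call: prefix and message cannot be split up. */
	printk(KERN_WARNING "%s: %pV", prefix, &vaf);
	va_end(args);
}

Unlike the removed two-step sequence (printk() for the address prefix, then vprintk() for the body), a single call cannot interleave with messages from other CPUs.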
@@ -1004,6 +1089,7 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) rqstp->rq_splice_ok = 1; /* Will be turned off only when NFSv4 Sessions are used */ rqstp->rq_usedeferral = 1; + rqstp->rq_dropme = false; /* Setup reply header */ rqstp->rq_xprt->xpt_ops->xpo_prep_reply_hdr(rqstp); @@ -1027,8 +1113,6 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) rqstp->rq_vers = vers = svc_getnl(argv); /* version number */ rqstp->rq_proc = proc = svc_getnl(argv); /* procedure number */ - progp = serv->sv_program; - for (progp = serv->sv_program; progp; progp = progp->pg_next) if (prog == progp->pg_prog) break; @@ -1054,6 +1138,9 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) goto err_bad; case SVC_DENIED: goto err_bad_auth; + case SVC_CLOSE: + if (test_bit(XPT_TEMP, &rqstp->rq_xprt->xpt_flags)) + svc_close_xprt(rqstp->rq_xprt); case SVC_DROP: goto dropit; case SVC_COMPLETE: @@ -1102,7 +1189,7 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) *statp = procp->pc_func(rqstp, rqstp->rq_argp, rqstp->rq_resp); /* Encode reply */ - if (*statp == rpc_drop_reply) { + if (rqstp->rq_dropme) { if (procp->pc_release) procp->pc_release(rqstp, NULL, rqstp->rq_resp); goto dropit; @@ -1143,7 +1230,6 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) dropit: svc_authorise(rqstp); /* doesn't hurt to call this twice */ dprintk("svc: svc_process dropit\n"); - svc_drop(rqstp); return 0; err_short_len: @@ -1214,13 +1300,12 @@ svc_process(struct svc_rqst *rqstp) struct kvec *resv = &rqstp->rq_res.head[0]; struct svc_serv *serv = rqstp->rq_server; u32 dir; - int error; /* * Setup response xdr_buf. * Initially it has just one page */ - rqstp->rq_resused = 1; + rqstp->rq_next_page = &rqstp->rq_respages[1]; resv->iov_base = page_address(rqstp->rq_respages[0]); resv->iov_len = 0; rqstp->rq_res.pages = rqstp->rq_respages + 1; @@ -1242,14 +1327,16 @@ svc_process(struct svc_rqst *rqstp) return 0; } - error = svc_process_common(rqstp, argv, resv); - if (error <= 0) - return error; - - return svc_send(rqstp); + /* Returns 1 for send, 0 for drop */ + if (svc_process_common(rqstp, argv, resv)) + return svc_send(rqstp); + else { + svc_drop(rqstp); + return 0; + } } -#if defined(CONFIG_NFS_V4_1) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) /* * Process a backchannel RPC request that arrived over an existing * outbound connection @@ -1260,10 +1347,9 @@ bc_svc_process(struct svc_serv *serv, struct rpc_rqst *req, { struct kvec *argv = &rqstp->rq_arg.head[0]; struct kvec *resv = &rqstp->rq_res.head[0]; - int error; /* Build the svc_rqst used by the common processing routine */ - rqstp->rq_xprt = serv->bc_xprt; + rqstp->rq_xprt = serv->sv_bc_xprt; rqstp->rq_xid = req->rq_xid; rqstp->rq_prot = req->rq_xprt->prot; rqstp->rq_server = serv; @@ -1288,15 +1374,19 @@ bc_svc_process(struct svc_serv *serv, struct rpc_rqst *req, svc_getu32(argv); /* XID */ svc_getnl(argv); /* CALLDIR */ - error = svc_process_common(rqstp, argv, resv); - if (error <= 0) - return error; - - memcpy(&req->rq_snd_buf, &rqstp->rq_res, sizeof(req->rq_snd_buf)); - return bc_send(req); + /* Returns 1 for send, 0 for drop */ + if (svc_process_common(rqstp, argv, resv)) { + memcpy(&req->rq_snd_buf, &rqstp->rq_res, + sizeof(req->rq_snd_buf)); + return bc_send(req); + } else { + /* drop request */ + xprt_free_bc_request(req); + return 0; + } } -EXPORT_SYMBOL(bc_svc_process); -#endif /* CONFIG_NFS_V4_1 */ 
+EXPORT_SYMBOL_GPL(bc_svc_process); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ /* * Return (transport-specific) limit on the rpc payload. diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c index 8f0f1fb3dc5..b4737fbdec1 100644 --- a/net/sunrpc/svc_xprt.c +++ b/net/sunrpc/svc_xprt.c @@ -5,14 +5,16 @@ */ #include <linux/sched.h> -#include <linux/smp_lock.h> #include <linux/errno.h> #include <linux/freezer.h> #include <linux/kthread.h> +#include <linux/slab.h> #include <net/sock.h> #include <linux/sunrpc/stats.h> #include <linux/sunrpc/svc_xprt.h> #include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/xprt.h> +#include <linux/module.h> #define RPCDBG_FACILITY RPCDBG_SVCXPRT @@ -20,6 +22,7 @@ static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt); static int svc_deferred_recv(struct svc_rqst *rqstp); static struct cache_deferred_req *svc_defer(struct cache_req *req); static void svc_age_temp_xprts(unsigned long closure); +static void svc_delete_xprt(struct svc_xprt *xprt); /* apparently the "standard" is that clients close * idle connections after 5 minutes, servers after @@ -99,16 +102,14 @@ EXPORT_SYMBOL_GPL(svc_unreg_xprt_class); */ int svc_print_xprts(char *buf, int maxlen) { - struct list_head *le; + struct svc_xprt_class *xcl; char tmpstr[80]; int len = 0; buf[0] = '\0'; spin_lock(&svc_xprt_class_lock); - list_for_each(le, &svc_xprt_class_list) { + list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { int slen; - struct svc_xprt_class *xcl = - list_entry(le, struct svc_xprt_class, xcl_list); sprintf(tmpstr, "%s %d\n", xcl->xcl_name, xcl->xcl_max_payload); slen = strlen(tmpstr); @@ -127,9 +128,12 @@ static void svc_xprt_free(struct kref *kref) struct svc_xprt *xprt = container_of(kref, struct svc_xprt, xpt_ref); struct module *owner = xprt->xpt_class->xcl_owner; - if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags) && - xprt->xpt_auth_cache != NULL) - svcauth_unix_info_release(xprt->xpt_auth_cache); + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) + svcauth_unix_info_release(xprt); + put_net(xprt->xpt_net); + /* See comment on corresponding get in xs_setup_bc_tcp(): */ + if (xprt->xpt_bc_xprt) + xprt_put(xprt->xpt_bc_xprt); xprt->xpt_ops->xpo_free(xprt); module_put(owner); } @@ -144,8 +148,8 @@ EXPORT_SYMBOL_GPL(svc_xprt_put); * Called by transport drivers to initialize the transport independent * portion of the transport instance. 
*/ -void svc_xprt_init(struct svc_xprt_class *xcl, struct svc_xprt *xprt, - struct svc_serv *serv) +void svc_xprt_init(struct net *net, struct svc_xprt_class *xcl, + struct svc_xprt *xprt, struct svc_serv *serv) { memset(xprt, 0, sizeof(*xprt)); xprt->xpt_class = xcl; @@ -155,15 +159,18 @@ void svc_xprt_init(struct svc_xprt_class *xcl, struct svc_xprt *xprt, INIT_LIST_HEAD(&xprt->xpt_list); INIT_LIST_HEAD(&xprt->xpt_ready); INIT_LIST_HEAD(&xprt->xpt_deferred); + INIT_LIST_HEAD(&xprt->xpt_users); mutex_init(&xprt->xpt_mutex); spin_lock_init(&xprt->xpt_lock); set_bit(XPT_BUSY, &xprt->xpt_flags); rpc_init_wait_queue(&xprt->xpt_bc_pending, "xpt_bc_pending"); + xprt->xpt_net = get_net(net); } EXPORT_SYMBOL_GPL(svc_xprt_init); static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, struct svc_serv *serv, + struct net *net, const int family, const unsigned short port, int flags) @@ -173,13 +180,13 @@ static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, .sin_addr.s_addr = htonl(INADDR_ANY), .sin_port = htons(port), }; -#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) struct sockaddr_in6 sin6 = { .sin6_family = AF_INET6, .sin6_addr = IN6ADDR_ANY_INIT, .sin6_port = htons(port), }; -#endif /* defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) */ +#endif struct sockaddr *sap; size_t len; @@ -188,22 +195,53 @@ static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, sap = (struct sockaddr *)&sin; len = sizeof(sin); break; -#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) case PF_INET6: sap = (struct sockaddr *)&sin6; len = sizeof(sin6); break; -#endif /* defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) */ +#endif default: return ERR_PTR(-EAFNOSUPPORT); } - return xcl->xcl_ops->xpo_create(serv, sap, len, flags); + return xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); +} + +/* + * svc_xprt_received conditionally queues the transport for processing + * by another thread. The caller must hold the XPT_BUSY bit and must + * not thereafter touch transport data. + * + * Note: XPT_DATA only gets cleared when a read-attempt finds no (or + * insufficient) data. 
+ */ +static void svc_xprt_received(struct svc_xprt *xprt) +{ + WARN_ON_ONCE(!test_bit(XPT_BUSY, &xprt->xpt_flags)); + if (!test_bit(XPT_BUSY, &xprt->xpt_flags)) + return; + /* As soon as we clear busy, the xprt could be closed and + * 'put', so we need a reference to call svc_xprt_enqueue with: + */ + svc_xprt_get(xprt); + clear_bit(XPT_BUSY, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); +} + +void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new) +{ + clear_bit(XPT_TEMP, &new->xpt_flags); + spin_lock_bh(&serv->sv_lock); + list_add(&new->xpt_list, &serv->sv_permsocks); + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(new); } int svc_create_xprt(struct svc_serv *serv, const char *xprt_name, - const int family, const unsigned short port, - int flags) + struct net *net, const int family, + const unsigned short port, int flags) { struct svc_xprt_class *xcl; @@ -211,6 +249,7 @@ int svc_create_xprt(struct svc_serv *serv, const char *xprt_name, spin_lock(&svc_xprt_class_lock); list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { struct svc_xprt *newxprt; + unsigned short newport; if (strcmp(xprt_name, xcl->xcl_name)) continue; @@ -219,18 +258,14 @@ int svc_create_xprt(struct svc_serv *serv, const char *xprt_name, goto err; spin_unlock(&svc_xprt_class_lock); - newxprt = __svc_xpo_create(xcl, serv, family, port, flags); + newxprt = __svc_xpo_create(xcl, serv, net, family, port, flags); if (IS_ERR(newxprt)) { module_put(xcl->xcl_owner); return PTR_ERR(newxprt); } - - clear_bit(XPT_TEMP, &newxprt->xpt_flags); - spin_lock_bh(&serv->sv_lock); - list_add(&newxprt->xpt_list, &serv->sv_permsocks); - spin_unlock_bh(&serv->sv_lock); - clear_bit(XPT_BUSY, &newxprt->xpt_flags); - return svc_xprt_local_port(newxprt); + svc_add_new_perm_xprt(serv, newxprt); + newport = svc_xprt_local_port(newxprt); + return newport; } err: spin_unlock(&svc_xprt_class_lock); @@ -247,8 +282,6 @@ EXPORT_SYMBOL_GPL(svc_create_xprt); */ void svc_xprt_copy_addrs(struct svc_rqst *rqstp, struct svc_xprt *xprt) { - struct sockaddr *sin; - memcpy(&rqstp->rq_addr, &xprt->xpt_remote, xprt->xpt_remotelen); rqstp->rq_addrlen = xprt->xpt_remotelen; @@ -256,15 +289,8 @@ void svc_xprt_copy_addrs(struct svc_rqst *rqstp, struct svc_xprt *xprt) * Destination address in request is needed for binding the * source address in RPC replies/callbacks later. */ - sin = (struct sockaddr *)&xprt->xpt_local; - switch (sin->sa_family) { - case AF_INET: - rqstp->rq_daddr.addr = ((struct sockaddr_in *)sin)->sin_addr; - break; - case AF_INET6: - rqstp->rq_daddr.addr6 = ((struct sockaddr_in6 *)sin)->sin6_addr; - break; - } + memcpy(&rqstp->rq_daddr, &xprt->xpt_local, xprt->xpt_locallen); + rqstp->rq_daddrlen = xprt->xpt_locallen; } EXPORT_SYMBOL_GPL(svc_xprt_copy_addrs); @@ -300,6 +326,15 @@ static void svc_thread_dequeue(struct svc_pool *pool, struct svc_rqst *rqstp) list_del(&rqstp->rq_list); } +static bool svc_xprt_has_something_to_do(struct svc_xprt *xprt) +{ + if (xprt->xpt_flags & ((1<<XPT_CONN)|(1<<XPT_CLOSE))) + return true; + if (xprt->xpt_flags & ((1<<XPT_DATA)|(1<<XPT_DEFERRED))) + return xprt->xpt_ops->xpo_has_wspace(xprt); + return false; +} + /* * Queue up a transport with data pending. If there are idle nfsd * processes, wake 'em up. 
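The comment in the new svc_xprt_received() above is compact, so here is the ordering rule it encodes as a standalone sketch. Once the busy bit is cleared, another CPU may close the object and drop the last reference, so a temporary reference must be taken first. The type and helpers (example_obj, example_enqueue) are hypothetical stand-ins for struct svc_xprt and svc_xprt_enqueue():

#include <linux/bitops.h>
#include <linux/kref.h>
#include <linux/slab.h>

#define OBJ_BUSY	0

struct example_obj {
	unsigned long flags;
	struct kref ref;
};

static void example_free(struct kref *ref)
{
	kfree(container_of(ref, struct example_obj, ref));
}

static void example_enqueue(struct example_obj *obj)
{
	/* Hand the object to a worker; stubbed out for the sketch. */
}

static void example_received(struct example_obj *obj)
{
	kref_get(&obj->ref);			/* pin the object first */
	clear_bit(OBJ_BUSY, &obj->flags);	/* others may now free it */
	example_enqueue(obj);			/* safe: we hold a reference */
	kref_put(&obj->ref, example_free);	/* may be the final put */
}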
@@ -307,13 +342,11 @@ static void svc_thread_dequeue(struct svc_pool *pool, struct svc_rqst *rqstp) */ void svc_xprt_enqueue(struct svc_xprt *xprt) { - struct svc_serv *serv = xprt->xpt_server; struct svc_pool *pool; struct svc_rqst *rqstp; int cpu; - if (!(xprt->xpt_flags & - ((1<<XPT_CONN)|(1<<XPT_DATA)|(1<<XPT_CLOSE)|(1<<XPT_DEFERRED)))) + if (!svc_xprt_has_something_to_do(xprt)) return; cpu = get_cpu(); @@ -328,12 +361,6 @@ void svc_xprt_enqueue(struct svc_xprt *xprt) "svc_xprt_enqueue: " "threads and transports both waiting??\n"); - if (test_bit(XPT_DEAD, &xprt->xpt_flags)) { - /* Don't enqueue dead transports */ - dprintk("svc: transport %p is dead, not enqueued\n", xprt); - goto out_unlock; - } - pool->sp_stats.packets++; /* Mark transport as busy. It will remain in this state until @@ -346,28 +373,7 @@ void svc_xprt_enqueue(struct svc_xprt *xprt) dprintk("svc: transport %p busy, not enqueued\n", xprt); goto out_unlock; } - BUG_ON(xprt->xpt_pool != NULL); - xprt->xpt_pool = pool; - - /* Handle pending connection */ - if (test_bit(XPT_CONN, &xprt->xpt_flags)) - goto process; - - /* Handle close in-progress */ - if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) - goto process; - - /* Check if we have space to reply to a request */ - if (!xprt->xpt_ops->xpo_has_wspace(xprt)) { - /* Don't enqueue while not enough space for reply */ - dprintk("svc: no write space, transport %p not enqueued\n", - xprt); - xprt->xpt_pool = NULL; - clear_bit(XPT_BUSY, &xprt->xpt_flags); - goto out_unlock; - } - process: if (!list_empty(&pool->sp_threads)) { rqstp = list_entry(pool->sp_threads.next, struct svc_rqst, @@ -381,16 +387,12 @@ void svc_xprt_enqueue(struct svc_xprt *xprt) rqstp, rqstp->rq_xprt); rqstp->rq_xprt = xprt; svc_xprt_get(xprt); - rqstp->rq_reserved = serv->sv_max_mesg; - atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); pool->sp_stats.threads_woken++; - BUG_ON(xprt->xpt_pool != pool); wake_up(&rqstp->rq_wait); } else { dprintk("svc: transport %p put into queue\n", xprt); list_add_tail(&xprt->xpt_ready, &pool->sp_sockets); pool->sp_stats.sockets_queued++; - BUG_ON(xprt->xpt_pool != pool); } out_unlock: @@ -418,23 +420,6 @@ static struct svc_xprt *svc_xprt_dequeue(struct svc_pool *pool) return xprt; } -/* - * svc_xprt_received conditionally queues the transport for processing - * by another thread. The caller must hold the XPT_BUSY bit and must - * not thereafter touch transport data. - * - * Note: XPT_DATA only gets cleared when a read-attempt finds no (or - * insufficient) data. - */ -void svc_xprt_received(struct svc_xprt *xprt) -{ - BUG_ON(!test_bit(XPT_BUSY, &xprt->xpt_flags)); - xprt->xpt_pool = NULL; - clear_bit(XPT_BUSY, &xprt->xpt_flags); - svc_xprt_enqueue(xprt); -} -EXPORT_SYMBOL_GPL(svc_xprt_received); - /** * svc_reserve - change the space reserved for the reply to a request. * @rqstp: The request in question @@ -514,7 +499,8 @@ void svc_wake_up(struct svc_serv *serv) rqstp->rq_xprt = NULL; */ wake_up(&rqstp->rq_wait); - } + } else + pool->sp_task_pending = 1; spin_unlock_bh(&pool->sp_lock); } } @@ -561,14 +547,11 @@ static void svc_check_conn_limits(struct svc_serv *serv) struct svc_xprt *xprt = NULL; spin_lock_bh(&serv->sv_lock); if (!list_empty(&serv->sv_tempsocks)) { - if (net_ratelimit()) { - /* Try to help the admin */ - printk(KERN_NOTICE "%s: too many open " - "connections, consider increasing %s\n", - serv->sv_name, serv->sv_maxconn ? - "the max number of connections." 
: - "the number of threads."); - } + /* Try to help the admin */ + net_notice_ratelimited("%s: too many open connections, consider increasing the %s\n", + serv->sv_name, serv->sv_maxconn ? + "max number of connections" : + "number of threads"); /* * Always select the oldest connection. It's not fair, * but so is life @@ -588,36 +571,19 @@ static void svc_check_conn_limits(struct svc_serv *serv) } } -/* - * Receive the next request on any transport. This code is carefully - * organised not to touch any cachelines in the shared svc_serv - * structure, only cachelines in the local svc_pool. - */ -int svc_recv(struct svc_rqst *rqstp, long timeout) +static int svc_alloc_arg(struct svc_rqst *rqstp) { - struct svc_xprt *xprt = NULL; - struct svc_serv *serv = rqstp->rq_server; - struct svc_pool *pool = rqstp->rq_pool; - int len, i; - int pages; - struct xdr_buf *arg; - DECLARE_WAITQUEUE(wait, current); - long time_left; - - dprintk("svc: server %p waiting for data (to = %ld)\n", - rqstp, timeout); - - if (rqstp->rq_xprt) - printk(KERN_ERR - "svc_recv: service %p, transport not NULL!\n", - rqstp); - if (waitqueue_active(&rqstp->rq_wait)) - printk(KERN_ERR - "svc_recv: service %p, wait queue active!\n", - rqstp); + struct svc_serv *serv = rqstp->rq_server; + struct xdr_buf *arg; + int pages; + int i; /* now allocate needed pages. If we get a failure, sleep briefly */ pages = (serv->sv_max_mesg + PAGE_SIZE) / PAGE_SIZE; + WARN_ON_ONCE(pages >= RPCSVC_MAXPAGES); + if (pages >= RPCSVC_MAXPAGES) + /* use as many pages as possible */ + pages = RPCSVC_MAXPAGES - 1; for (i = 0; i < pages ; i++) while (rqstp->rq_pages[i] == NULL) { struct page *p = alloc_page(GFP_KERNEL); @@ -631,8 +597,8 @@ int svc_recv(struct svc_rqst *rqstp, long timeout) } rqstp->rq_pages[i] = p; } + rqstp->rq_page_end = &rqstp->rq_pages[i]; rqstp->rq_pages[i++] = NULL; /* this might be seen in nfs_read_actor */ - BUG_ON(pages >= RPCSVC_MAXPAGES); /* Make arg->head point to first page and arg->pages point to rest */ arg = &rqstp->rq_arg; @@ -644,20 +610,39 @@ int svc_recv(struct svc_rqst *rqstp, long timeout) arg->page_len = (pages-2)*PAGE_SIZE; arg->len = (pages-1)*PAGE_SIZE; arg->tail[0].iov_len = 0; + return 0; +} - try_to_freeze(); - cond_resched(); - if (signalled() || kthread_should_stop()) - return -EINTR; +static struct svc_xprt *svc_get_next_xprt(struct svc_rqst *rqstp, long timeout) +{ + struct svc_xprt *xprt; + struct svc_pool *pool = rqstp->rq_pool; + DECLARE_WAITQUEUE(wait, current); + long time_left; + + /* Normally we will wait up to 5 seconds for any required + * cache information to be provided. + */ + rqstp->rq_chandle.thread_wait = 5*HZ; spin_lock_bh(&pool->sp_lock); xprt = svc_xprt_dequeue(pool); if (xprt) { rqstp->rq_xprt = xprt; svc_xprt_get(xprt); - rqstp->rq_reserved = serv->sv_max_mesg; - atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); + + /* As there is a shortage of threads and this request + * had to be queued, don't allow the thread to wait so + * long for cache updates. + */ + rqstp->rq_chandle.thread_wait = 1*HZ; + pool->sp_task_pending = 0; } else { + if (pool->sp_task_pending) { + pool->sp_task_pending = 0; + spin_unlock_bh(&pool->sp_lock); + return ERR_PTR(-EAGAIN); + } /* No data pending. 
Go to sleep */ svc_thread_enqueue(pool, rqstp); @@ -677,7 +662,7 @@ int svc_recv(struct svc_rqst *rqstp, long timeout) if (kthread_should_stop()) { set_current_state(TASK_RUNNING); spin_unlock_bh(&pool->sp_lock); - return -EINTR; + return ERR_PTR(-EINTR); } add_wait_queue(&rqstp->rq_wait, &wait); @@ -698,70 +683,129 @@ int svc_recv(struct svc_rqst *rqstp, long timeout) spin_unlock_bh(&pool->sp_lock); dprintk("svc: server %p, no data yet\n", rqstp); if (signalled() || kthread_should_stop()) - return -EINTR; + return ERR_PTR(-EINTR); else - return -EAGAIN; + return ERR_PTR(-EAGAIN); } } spin_unlock_bh(&pool->sp_lock); + return xprt; +} + +static void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt) +{ + spin_lock_bh(&serv->sv_lock); + set_bit(XPT_TEMP, &newxpt->xpt_flags); + list_add(&newxpt->xpt_list, &serv->sv_tempsocks); + serv->sv_tmpcnt++; + if (serv->sv_temptimer.function == NULL) { + /* setup timer to age temp transports */ + setup_timer(&serv->sv_temptimer, svc_age_temp_xprts, + (unsigned long)serv); + mod_timer(&serv->sv_temptimer, + jiffies + svc_conn_age_period * HZ); + } + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(newxpt); +} + +static int svc_handle_xprt(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + struct svc_serv *serv = rqstp->rq_server; + int len = 0; - len = 0; if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) { dprintk("svc_recv: found XPT_CLOSE\n"); svc_delete_xprt(xprt); - } else if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) { + /* Leave XPT_BUSY set on the dead xprt: */ + return 0; + } + if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) { struct svc_xprt *newxpt; + /* + * We know this module_get will succeed because the + * listener holds a reference too + */ + __module_get(xprt->xpt_class->xcl_owner); + svc_check_conn_limits(xprt->xpt_server); newxpt = xprt->xpt_ops->xpo_accept(xprt); - if (newxpt) { - /* - * We know this module_get will succeed because the - * listener holds a reference too - */ - __module_get(newxpt->xpt_class->xcl_owner); - svc_check_conn_limits(xprt->xpt_server); - spin_lock_bh(&serv->sv_lock); - set_bit(XPT_TEMP, &newxpt->xpt_flags); - list_add(&newxpt->xpt_list, &serv->sv_tempsocks); - serv->sv_tmpcnt++; - if (serv->sv_temptimer.function == NULL) { - /* setup timer to age temp transports */ - setup_timer(&serv->sv_temptimer, - svc_age_temp_xprts, - (unsigned long)serv); - mod_timer(&serv->sv_temptimer, - jiffies + svc_conn_age_period * HZ); - } - spin_unlock_bh(&serv->sv_lock); - svc_xprt_received(newxpt); - } - svc_xprt_received(xprt); - } else { + if (newxpt) + svc_add_new_temp_xprt(serv, newxpt); + else + module_put(xprt->xpt_class->xcl_owner); + } else if (xprt->xpt_ops->xpo_has_wspace(xprt)) { + /* XPT_DATA|XPT_DEFERRED case: */ dprintk("svc: server %p, pool %u, transport %p, inuse=%d\n", - rqstp, pool->sp_id, xprt, + rqstp, rqstp->rq_pool->sp_id, xprt, atomic_read(&xprt->xpt_ref.refcount)); rqstp->rq_deferred = svc_deferred_dequeue(xprt); - if (rqstp->rq_deferred) { - svc_xprt_received(xprt); + if (rqstp->rq_deferred) len = svc_deferred_recv(rqstp); - } else + else len = xprt->xpt_ops->xpo_recvfrom(rqstp); dprintk("svc: got len=%d\n", len); + rqstp->rq_reserved = serv->sv_max_mesg; + atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); } + /* clear XPT_BUSY: */ + svc_xprt_received(xprt); + return len; +} + +/* + * Receive the next request on any transport. This code is carefully + * organised not to touch any cachelines in the shared svc_serv + * structure, only cachelines in the local svc_pool. 
+ */ +int svc_recv(struct svc_rqst *rqstp, long timeout) +{ + struct svc_xprt *xprt = NULL; + struct svc_serv *serv = rqstp->rq_server; + int len, err; + + dprintk("svc: server %p waiting for data (to = %ld)\n", + rqstp, timeout); + + if (rqstp->rq_xprt) + printk(KERN_ERR + "svc_recv: service %p, transport not NULL!\n", + rqstp); + if (waitqueue_active(&rqstp->rq_wait)) + printk(KERN_ERR + "svc_recv: service %p, wait queue active!\n", + rqstp); + + err = svc_alloc_arg(rqstp); + if (err) + return err; + + try_to_freeze(); + cond_resched(); + if (signalled() || kthread_should_stop()) + return -EINTR; + + xprt = svc_get_next_xprt(rqstp, timeout); + if (IS_ERR(xprt)) + return PTR_ERR(xprt); + + len = svc_handle_xprt(rqstp, xprt); /* No data, incomplete (TCP) read, or accept() */ - if (len == 0 || len == -EAGAIN) { - rqstp->rq_res.len = 0; - svc_xprt_release(rqstp); - return -EAGAIN; - } + if (len <= 0) + goto out; + clear_bit(XPT_OLD, &xprt->xpt_flags); - rqstp->rq_secure = svc_port_is_privileged(svc_addr(rqstp)); + rqstp->rq_secure = xprt->xpt_ops->xpo_secure_port(rqstp); rqstp->rq_chandle.defer = svc_defer; if (serv->sv_stats) serv->sv_stats->netcnt++; return len; +out: + rqstp->rq_res.len = 0; + svc_xprt_release(rqstp); + return -EAGAIN; } EXPORT_SYMBOL_GPL(svc_recv); @@ -799,7 +843,8 @@ int svc_send(struct svc_rqst *rqstp) /* Grab mutex to serialize outgoing data. */ mutex_lock(&xprt->xpt_mutex); - if (test_bit(XPT_DEAD, &xprt->xpt_flags)) + if (test_bit(XPT_DEAD, &xprt->xpt_flags) + || test_bit(XPT_CLOSE, &xprt->xpt_flags)) len = -ENOTCONN; else len = xprt->xpt_ops->xpo_sendto(rqstp); @@ -821,7 +866,6 @@ static void svc_age_temp_xprts(unsigned long closure) struct svc_serv *serv = (struct svc_serv *)closure; struct svc_xprt *xprt; struct list_head *le, *next; - LIST_HEAD(to_be_aged); dprintk("svc_age_temp_xprts\n"); @@ -842,40 +886,43 @@ static void svc_age_temp_xprts(unsigned long closure) if (atomic_read(&xprt->xpt_ref.refcount) > 1 || test_bit(XPT_BUSY, &xprt->xpt_flags)) continue; - svc_xprt_get(xprt); - list_move(le, &to_be_aged); + list_del_init(le); set_bit(XPT_CLOSE, &xprt->xpt_flags); set_bit(XPT_DETACHED, &xprt->xpt_flags); - } - spin_unlock_bh(&serv->sv_lock); - - while (!list_empty(&to_be_aged)) { - le = to_be_aged.next; - /* fiddling the xpt_list node is safe 'cos we're XPT_DETACHED */ - list_del_init(le); - xprt = list_entry(le, struct svc_xprt, xpt_list); - dprintk("queuing xprt %p for closing\n", xprt); /* a thread will dequeue and close it soon */ svc_xprt_enqueue(xprt); - svc_xprt_put(xprt); } + spin_unlock_bh(&serv->sv_lock); mod_timer(&serv->sv_temptimer, jiffies + svc_conn_age_period * HZ); } +static void call_xpt_users(struct svc_xprt *xprt) +{ + struct svc_xpt_user *u; + + spin_lock(&xprt->xpt_lock); + while (!list_empty(&xprt->xpt_users)) { + u = list_first_entry(&xprt->xpt_users, struct svc_xpt_user, list); + list_del(&u->list); + u->callback(u); + } + spin_unlock(&xprt->xpt_lock); +} + /* * Remove a dead transport */ -void svc_delete_xprt(struct svc_xprt *xprt) +static void svc_delete_xprt(struct svc_xprt *xprt) { struct svc_serv *serv = xprt->xpt_server; struct svc_deferred_req *dr; /* Only do this once */ if (test_and_set_bit(XPT_DEAD, &xprt->xpt_flags)) - return; + BUG(); dprintk("svc: svc_delete_xprt(%p)\n", xprt); xprt->xpt_ops->xpo_detach(xprt); @@ -883,21 +930,16 @@ void svc_delete_xprt(struct svc_xprt *xprt) spin_lock_bh(&serv->sv_lock); if (!test_and_set_bit(XPT_DETACHED, &xprt->xpt_flags)) list_del_init(&xprt->xpt_list); - /* - * We used to delete the 
transport from whichever list - * it's sk_xprt.xpt_ready node was on, but we don't actually - * need to. This is because the only time we're called - * while still attached to a queue, the queue itself - * is about to be destroyed (in svc_destroy). - */ + WARN_ON_ONCE(!list_empty(&xprt->xpt_ready)); if (test_bit(XPT_TEMP, &xprt->xpt_flags)) serv->sv_tmpcnt--; + spin_unlock_bh(&serv->sv_lock); while ((dr = svc_deferred_dequeue(xprt)) != NULL) kfree(dr); + call_xpt_users(xprt); svc_xprt_put(xprt); - spin_unlock_bh(&serv->sv_lock); } void svc_close_xprt(struct svc_xprt *xprt) @@ -906,29 +948,87 @@ void svc_close_xprt(struct svc_xprt *xprt) if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) /* someone else will have to effect the close */ return; - - svc_xprt_get(xprt); + /* + * We expect svc_close_xprt() to work even when no threads are + * running (e.g., while configuring the server before starting + * any threads), so if the transport isn't busy, we delete + * it ourselves: + */ svc_delete_xprt(xprt); - clear_bit(XPT_BUSY, &xprt->xpt_flags); - svc_xprt_put(xprt); } EXPORT_SYMBOL_GPL(svc_close_xprt); -void svc_close_all(struct list_head *xprt_list) +static int svc_close_list(struct svc_serv *serv, struct list_head *xprt_list, struct net *net) { struct svc_xprt *xprt; - struct svc_xprt *tmp; + int ret = 0; - list_for_each_entry_safe(xprt, tmp, xprt_list, xpt_list) { + spin_lock(&serv->sv_lock); + list_for_each_entry(xprt, xprt_list, xpt_list) { + if (xprt->xpt_net != net) + continue; + ret++; set_bit(XPT_CLOSE, &xprt->xpt_flags); - if (test_bit(XPT_BUSY, &xprt->xpt_flags)) { - /* Waiting to be processed, but no threads left, - * So just remove it from the waiting list - */ + svc_xprt_enqueue(xprt); + } + spin_unlock(&serv->sv_lock); + return ret; +} + +static struct svc_xprt *svc_dequeue_net(struct svc_serv *serv, struct net *net) +{ + struct svc_pool *pool; + struct svc_xprt *xprt; + struct svc_xprt *tmp; + int i; + + for (i = 0; i < serv->sv_nrpools; i++) { + pool = &serv->sv_pools[i]; + + spin_lock_bh(&pool->sp_lock); + list_for_each_entry_safe(xprt, tmp, &pool->sp_sockets, xpt_ready) { + if (xprt->xpt_net != net) + continue; list_del_init(&xprt->xpt_ready); - clear_bit(XPT_BUSY, &xprt->xpt_flags); + spin_unlock_bh(&pool->sp_lock); + return xprt; } - svc_close_xprt(xprt); + spin_unlock_bh(&pool->sp_lock); + } + return NULL; +} + +static void svc_clean_up_xprts(struct svc_serv *serv, struct net *net) +{ + struct svc_xprt *xprt; + + while ((xprt = svc_dequeue_net(serv, net))) { + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_delete_xprt(xprt); + } +} + +/* + * Server threads may still be running (especially in the case where the + * service is still running in other network namespaces). + * + * So we shut down sockets the same way we would on a running server, by + * setting XPT_CLOSE, enqueuing, and letting a thread pick it up to do + * the close. In the case where there are no such threads + * running, svc_clean_up_xprts() does a simple version of a + * server's main event loop, and in the case where there are other + * threads, we may need to wait a little while and then check again to
+ */ +void svc_close_net(struct svc_serv *serv, struct net *net) +{ + int delay = 0; + + while (svc_close_list(serv, &serv->sv_permsocks, net) + + svc_close_list(serv, &serv->sv_tempsocks, net)) { + + svc_clean_up_xprts(serv, net); + msleep(delay++); } } @@ -1002,6 +1102,7 @@ static struct cache_deferred_req *svc_defer(struct cache_req *req) } svc_xprt_get(rqstp->rq_xprt); dr->xprt = rqstp->rq_xprt; + rqstp->rq_dropme = true; dr->handle.revisit = svc_revisit; return &dr->handle; @@ -1039,14 +1140,13 @@ static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) if (!test_bit(XPT_DEFERRED, &xprt->xpt_flags)) return NULL; spin_lock(&xprt->xpt_lock); - clear_bit(XPT_DEFERRED, &xprt->xpt_flags); if (!list_empty(&xprt->xpt_deferred)) { dr = list_entry(xprt->xpt_deferred.next, struct svc_deferred_req, handle.recent); list_del_init(&dr->handle.recent); - set_bit(XPT_DEFERRED, &xprt->xpt_flags); - } + } else + clear_bit(XPT_DEFERRED, &xprt->xpt_flags); spin_unlock(&xprt->xpt_lock); return dr; } @@ -1055,6 +1155,7 @@ static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) * svc_find_xprt - find an RPC transport instance * @serv: pointer to svc_serv to search * @xcl_name: C string containing transport's class name + * @net: owner net pointer * @af: Address family of transport's local address * @port: transport's IP port number * @@ -1067,7 +1168,8 @@ static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) * service's list that has a matching class name. */ struct svc_xprt *svc_find_xprt(struct svc_serv *serv, const char *xcl_name, - const sa_family_t af, const unsigned short port) + struct net *net, const sa_family_t af, + const unsigned short port) { struct svc_xprt *xprt; struct svc_xprt *found = NULL; @@ -1078,6 +1180,8 @@ struct svc_xprt *svc_find_xprt(struct svc_serv *serv, const char *xcl_name, spin_lock_bh(&serv->sv_lock); list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + if (xprt->xpt_net != net) + continue; if (strcmp(xprt->xpt_class->xcl_name, xcl_name)) continue; if (af != AF_UNSPEC && af != xprt->xpt_local.ss_family) diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c index 4e9393c2468..79c0f3459b5 100644 --- a/net/sunrpc/svcauth.c +++ b/net/sunrpc/svcauth.c @@ -54,6 +54,8 @@ svc_authenticate(struct svc_rqst *rqstp, __be32 *authp) } spin_unlock(&authtab_lock); + rqstp->rq_auth_slack = 0; + rqstp->rq_authop = aops; return aops->accept(rqstp, authp); } @@ -118,7 +120,6 @@ EXPORT_SYMBOL_GPL(svc_auth_unregister); #define DN_HASHBITS 6 #define DN_HASHMAX (1<<DN_HASHBITS) -#define DN_HASHMASK (DN_HASHMAX-1) static struct hlist_head auth_domain_table[DN_HASHMAX]; static spinlock_t auth_domain_lock = @@ -139,13 +140,12 @@ auth_domain_lookup(char *name, struct auth_domain *new) { struct auth_domain *hp; struct hlist_head *head; - struct hlist_node *np; head = &auth_domain_table[hash_str(name, DN_HASHBITS)]; spin_lock(&auth_domain_lock); - hlist_for_each_entry(hp, np, head, hash) { + hlist_for_each_entry(hp, head, hash) { if (strcmp(hp->name, name)==0) { kref_get(&hp->ref); spin_unlock(&auth_domain_lock); diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c index afdcb0459a8..621ca7b4a15 100644 --- a/net/sunrpc/svcauth_unix.c +++ b/net/sunrpc/svcauth_unix.c @@ -6,16 +6,20 @@ #include <linux/sunrpc/svcsock.h> #include <linux/sunrpc/svcauth.h> #include <linux/sunrpc/gss_api.h> +#include <linux/sunrpc/addr.h> #include <linux/err.h> #include <linux/seq_file.h> #include <linux/hash.h> #include <linux/string.h> 
+#include <linux/slab.h> #include <net/sock.h> #include <net/ipv6.h> #include <linux/kernel.h> +#include <linux/user_namespace.h> #define RPCDBG_FACILITY RPCDBG_AUTH -#include <linux/sunrpc/clnt.h> + +#include "netns.h" /* * AUTHUNIX and AUTHNULL credentials are both handled here. @@ -27,12 +31,20 @@ struct unix_domain { struct auth_domain h; - int addr_changes; /* other stuff later */ }; +extern struct auth_ops svcauth_null; extern struct auth_ops svcauth_unix; +static void svcauth_unix_domain_release(struct auth_domain *dom) +{ + struct unix_domain *ud = container_of(dom, struct unix_domain, h); + + kfree(dom->name); + kfree(ud); +} + struct auth_domain *unix_domain_find(char *name) { struct auth_domain *rv; @@ -42,7 +54,7 @@ struct auth_domain *unix_domain_find(char *name) while(1) { if (rv) { if (new && rv != &new->h) - auth_domain_put(&new->h); + svcauth_unix_domain_release(&new->h); if (rv->flavour != &svcauth_unix) { auth_domain_put(rv); @@ -61,20 +73,11 @@ struct auth_domain *unix_domain_find(char *name) return NULL; } new->h.flavour = &svcauth_unix; - new->addr_changes = 0; rv = auth_domain_lookup(name, &new->h); } } EXPORT_SYMBOL_GPL(unix_domain_find); -static void svcauth_unix_domain_release(struct auth_domain *dom) -{ - struct unix_domain *ud = container_of(dom, struct unix_domain, h); - - kfree(dom->name); - kfree(ud); -} - /************************************************** * cache for IP address to unix_domain @@ -82,16 +85,13 @@ static void svcauth_unix_domain_release(struct auth_domain *dom) */ #define IP_HASHBITS 8 #define IP_HASHMAX (1<<IP_HASHBITS) -#define IP_HASHMASK (IP_HASHMAX-1) struct ip_map { struct cache_head h; char m_class[8]; /* e.g. "nfsd" */ struct in6_addr m_addr; struct unix_domain *m_client; - int m_add_change; }; -static struct cache_head *ip_table[IP_HASHMAX]; static void ip_map_put(struct kref *kref) { @@ -104,23 +104,9 @@ static void ip_map_put(struct kref *kref) kfree(im); } -#if IP_HASHBITS == 8 -/* hash_long on a 64 bit machine is currently REALLY BAD for - * IP addresses in reverse-endian (i.e. on a little-endian machine). 
- * So use a trivial but reliable hash instead - */ -static inline int hash_ip(__be32 ip) -{ - int hash = (__force u32)ip ^ ((__force u32)ip>>16); - return (hash ^ (hash>>8)) & 0xff; -} -#endif -static inline int hash_ip6(struct in6_addr ip) +static inline int hash_ip6(const struct in6_addr *ip) { - return (hash_ip(ip.s6_addr32[0]) ^ - hash_ip(ip.s6_addr32[1]) ^ - hash_ip(ip.s6_addr32[2]) ^ - hash_ip(ip.s6_addr32[3])); + return hash_32(ipv6_addr_hash(ip), IP_HASHBITS); } static int ip_map_match(struct cache_head *corig, struct cache_head *cnew) { @@ -135,7 +121,7 @@ static void ip_map_init(struct cache_head *cnew, struct cache_head *citem) struct ip_map *item = container_of(citem, struct ip_map, h); strcpy(new->m_class, item->m_class); - ipv6_addr_copy(&new->m_addr, &item->m_addr); + new->m_addr = item->m_addr; } static void update(struct cache_head *cnew, struct cache_head *citem) { @@ -144,7 +130,6 @@ static void update(struct cache_head *cnew, struct cache_head *citem) kref_get(&item->m_client->h.ref); new->m_client = item->m_client; - new->m_add_change = item->m_add_change; } static struct cache_head *ip_map_alloc(void) { @@ -172,13 +157,8 @@ static void ip_map_request(struct cache_detail *cd, (*bpp)[-1] = '\n'; } -static int ip_map_upcall(struct cache_detail *cd, struct cache_head *h) -{ - return sunrpc_cache_pipe_upcall(cd, h, ip_map_request); -} - -static struct ip_map *ip_map_lookup(char *class, struct in6_addr *addr); -static int ip_map_update(struct ip_map *ipm, struct unix_domain *udom, time_t expiry); +static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, struct in6_addr *addr); +static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time_t expiry); static int ip_map_parse(struct cache_detail *cd, char *mesg, int mlen) @@ -213,17 +193,16 @@ static int ip_map_parse(struct cache_detail *cd, len = qword_get(&mesg, buf, mlen); if (len <= 0) return -EINVAL; - if (rpc_pton(buf, len, &address.sa, sizeof(address)) == 0) + if (rpc_pton(cd->net, buf, len, &address.sa, sizeof(address)) == 0) return -EINVAL; switch (address.sa.sa_family) { case AF_INET: /* Form a mapped IPv4 address in sin6 */ - memset(&sin6, 0, sizeof(sin6)); sin6.sin6_family = AF_INET6; - sin6.sin6_addr.s6_addr32[2] = htonl(0xffff); - sin6.sin6_addr.s6_addr32[3] = address.s4.sin_addr.s_addr; + ipv6_addr_set_v4mapped(address.s4.sin_addr.s_addr, + &sin6.sin6_addr); break; -#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) case AF_INET6: memcpy(&sin6, &address.s6, sizeof(sin6)); break; @@ -248,9 +227,9 @@ static int ip_map_parse(struct cache_detail *cd, dom = NULL; /* IPv6 scope IDs are ignored for now */ - ipmp = ip_map_lookup(class, &sin6.sin6_addr); + ipmp = __ip_map_lookup(cd, class, &sin6.sin6_addr); if (ipmp) { - err = ip_map_update(ipmp, + err = __ip_map_update(cd, ipmp, container_of(dom, struct unix_domain, h), expiry); } else @@ -277,7 +256,7 @@ static int ip_map_show(struct seq_file *m, } im = container_of(h, struct ip_map, h); /* class addr domain */ - ipv6_addr_copy(&addr, &im->m_addr); + addr = im->m_addr; if (test_bit(CACHE_VALID, &h->flags) && !test_bit(CACHE_NEGATIVE, &h->flags)) @@ -293,31 +272,17 @@ static int ip_map_show(struct seq_file *m, } -struct cache_detail ip_map_cache = { - .owner = THIS_MODULE, - .hash_size = IP_HASHMAX, - .hash_table = ip_table, - .name = "auth.unix.ip", - .cache_put = ip_map_put, - .cache_upcall = ip_map_upcall, - .cache_parse = ip_map_parse, - .cache_show = ip_map_show, - .match 
= ip_map_match, - .init = ip_map_init, - .update = update, - .alloc = ip_map_alloc, -}; - -static struct ip_map *ip_map_lookup(char *class, struct in6_addr *addr) +static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, + struct in6_addr *addr) { struct ip_map ip; struct cache_head *ch; strcpy(ip.m_class, class); - ipv6_addr_copy(&ip.m_addr, addr); - ch = sunrpc_cache_lookup(&ip_map_cache, &ip.h, + ip.m_addr = *addr; + ch = sunrpc_cache_lookup(cd, &ip.h, hash_str(class, IP_HASHBITS) ^ - hash_ip6(*addr)); + hash_ip6(addr)); if (ch) return container_of(ch, struct ip_map, h); @@ -325,7 +290,17 @@ static struct ip_map *ip_map_lookup(char *class, struct in6_addr *addr) return NULL; } -static int ip_map_update(struct ip_map *ipm, struct unix_domain *udom, time_t expiry) +static inline struct ip_map *ip_map_lookup(struct net *net, char *class, + struct in6_addr *addr) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + return __ip_map_lookup(sn->ip_map_cache, class, addr); +} + +static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, + struct unix_domain *udom, time_t expiry) { struct ip_map ip; struct cache_head *ch; @@ -334,96 +309,46 @@ static int ip_map_update(struct ip_map *ipm, struct unix_domain *udom, time_t ex ip.h.flags = 0; if (!udom) set_bit(CACHE_NEGATIVE, &ip.h.flags); - else { - ip.m_add_change = udom->addr_changes; - /* if this is from the legacy set_client system call, - * we need m_add_change to be one higher - */ - if (expiry == NEVER) - ip.m_add_change++; - } ip.h.expiry_time = expiry; - ch = sunrpc_cache_update(&ip_map_cache, - &ip.h, &ipm->h, + ch = sunrpc_cache_update(cd, &ip.h, &ipm->h, hash_str(ipm->m_class, IP_HASHBITS) ^ - hash_ip6(ipm->m_addr)); + hash_ip6(&ipm->m_addr)); if (!ch) return -ENOMEM; - cache_put(ch, &ip_map_cache); + cache_put(ch, cd); return 0; } -int auth_unix_add_addr(struct in6_addr *addr, struct auth_domain *dom) -{ - struct unix_domain *udom; - struct ip_map *ipmp; - - if (dom->flavour != &svcauth_unix) - return -EINVAL; - udom = container_of(dom, struct unix_domain, h); - ipmp = ip_map_lookup("nfsd", addr); - - if (ipmp) - return ip_map_update(ipmp, udom, NEVER); - else - return -ENOMEM; -} -EXPORT_SYMBOL_GPL(auth_unix_add_addr); - -int auth_unix_forget_old(struct auth_domain *dom) +static inline int ip_map_update(struct net *net, struct ip_map *ipm, + struct unix_domain *udom, time_t expiry) { - struct unix_domain *udom; + struct sunrpc_net *sn; - if (dom->flavour != &svcauth_unix) - return -EINVAL; - udom = container_of(dom, struct unix_domain, h); - udom->addr_changes++; - return 0; + sn = net_generic(net, sunrpc_net_id); + return __ip_map_update(sn->ip_map_cache, ipm, udom, expiry); } -EXPORT_SYMBOL_GPL(auth_unix_forget_old); -struct auth_domain *auth_unix_lookup(struct in6_addr *addr) +void svcauth_unix_purge(struct net *net) { - struct ip_map *ipm; - struct auth_domain *rv; - - ipm = ip_map_lookup("nfsd", addr); + struct sunrpc_net *sn; - if (!ipm) - return NULL; - if (cache_check(&ip_map_cache, &ipm->h, NULL)) - return NULL; - - if ((ipm->m_client->addr_changes - ipm->m_add_change) >0) { - if (test_and_set_bit(CACHE_NEGATIVE, &ipm->h.flags) == 0) - auth_domain_put(&ipm->m_client->h); - rv = NULL; - } else { - rv = &ipm->m_client->h; - kref_get(&rv->ref); - } - cache_put(&ipm->h, &ip_map_cache); - return rv; -} -EXPORT_SYMBOL_GPL(auth_unix_lookup); - -void svcauth_unix_purge(void) -{ - cache_purge(&ip_map_cache); + sn = net_generic(net, sunrpc_net_id); + cache_purge(sn->ip_map_cache); } 
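The hunks above replace every use of the old ip_map_cache global with a per-namespace lookup. A minimal sketch of the accessor pattern they rely on, assuming only the sunrpc_net/sunrpc_net_id definitions from the patch's "netns.h"; the helper name here is hypothetical:

#include <net/net_namespace.h>
#include <net/netns/generic.h>
#include "netns.h"

/* Hypothetical helper: fetch this namespace's "auth.unix.ip" cache,
 * i.e. the cache_detail that ip_map_cache_create() (later in this
 * patch) allocates and stores in the per-net private area. */
static struct cache_detail *example_ip_map_cache(struct net *net)
{
	struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);

	return sn->ip_map_cache;
}

Callers choose whichever struct net scopes the request — xprt->xpt_net on the receive path, cd->net when parsing an upcall — which is why ip_map_lookup()/ip_map_update() split into __ip_map_lookup()/__ip_map_update() variants that take the cache_detail explicitly.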
EXPORT_SYMBOL_GPL(svcauth_unix_purge); static inline struct ip_map * -ip_map_cached_get(struct svc_rqst *rqstp) +ip_map_cached_get(struct svc_xprt *xprt) { struct ip_map *ipm = NULL; - struct svc_xprt *xprt = rqstp->rq_xprt; + struct sunrpc_net *sn; if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { spin_lock(&xprt->xpt_lock); ipm = xprt->xpt_auth_cache; if (ipm != NULL) { - if (!cache_valid(&ipm->h)) { + sn = net_generic(xprt->xpt_net, sunrpc_net_id); + if (cache_is_expired(sn->ip_map_cache, &ipm->h)) { /* * The entry has been invalidated since it was * remembered, e.g. by a second mount from the @@ -431,7 +356,7 @@ ip_map_cached_get(struct svc_rqst *rqstp) */ xprt->xpt_auth_cache = NULL; spin_unlock(&xprt->xpt_lock); - cache_put(&ipm->h, &ip_map_cache); + cache_put(&ipm->h, sn->ip_map_cache); return NULL; } cache_get(&ipm->h); @@ -442,10 +367,8 @@ ip_map_cached_get(struct svc_rqst *rqstp) } static inline void -ip_map_cached_put(struct svc_rqst *rqstp, struct ip_map *ipm) +ip_map_cached_put(struct svc_xprt *xprt, struct ip_map *ipm) { - struct svc_xprt *xprt = rqstp->rq_xprt; - if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { spin_lock(&xprt->xpt_lock); if (xprt->xpt_auth_cache == NULL) { @@ -455,15 +378,26 @@ ip_map_cached_put(struct svc_rqst *rqstp, struct ip_map *ipm) } spin_unlock(&xprt->xpt_lock); } - if (ipm) - cache_put(&ipm->h, &ip_map_cache); + if (ipm) { + struct sunrpc_net *sn; + + sn = net_generic(xprt->xpt_net, sunrpc_net_id); + cache_put(&ipm->h, sn->ip_map_cache); + } } void -svcauth_unix_info_release(void *info) +svcauth_unix_info_release(struct svc_xprt *xpt) { - struct ip_map *ipm = info; - cache_put(&ipm->h, &ip_map_cache); + struct ip_map *ipm; + + ipm = xpt->xpt_auth_cache; + if (ipm != NULL) { + struct sunrpc_net *sn; + + sn = net_generic(xpt->xpt_net, sunrpc_net_id); + cache_put(&ipm->h, sn->ip_map_cache); + } } /**************************************************************************** @@ -473,14 +407,17 @@ svcauth_unix_info_release(void *info) */ #define GID_HASHBITS 8 #define GID_HASHMAX (1<<GID_HASHBITS) -#define GID_HASHMASK (GID_HASHMAX - 1) struct unix_gid { struct cache_head h; - uid_t uid; + kuid_t uid; struct group_info *gi; }; -static struct cache_head *gid_table[GID_HASHMAX]; + +static int unix_gid_hash(kuid_t uid) +{ + return hash_long(from_kuid(&init_user_ns, uid), GID_HASHBITS); +} static void unix_gid_put(struct kref *kref) { @@ -496,7 +433,7 @@ static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew) { struct unix_gid *orig = container_of(corig, struct unix_gid, h); struct unix_gid *new = container_of(cnew, struct unix_gid, h); - return orig->uid == new->uid; + return uid_eq(orig->uid, new->uid); } static void unix_gid_init(struct cache_head *cnew, struct cache_head *citem) { @@ -528,24 +465,19 @@ static void unix_gid_request(struct cache_detail *cd, char tuid[20]; struct unix_gid *ug = container_of(h, struct unix_gid, h); - snprintf(tuid, 20, "%u", ug->uid); + snprintf(tuid, 20, "%u", from_kuid(&init_user_ns, ug->uid)); qword_add(bpp, blen, tuid); (*bpp)[-1] = '\n'; } -static int unix_gid_upcall(struct cache_detail *cd, struct cache_head *h) -{ - return sunrpc_cache_pipe_upcall(cd, h, unix_gid_request); -} - -static struct unix_gid *unix_gid_lookup(uid_t uid); -extern struct cache_detail unix_gid_cache; +static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid); static int unix_gid_parse(struct cache_detail *cd, char *mesg, int mlen) { /* uid expiry Ngid gid0 gid1 ... 
gidN-1 */ - int uid; + int id; + kuid_t uid; int gids; int rv; int i; @@ -553,13 +485,14 @@ static int unix_gid_parse(struct cache_detail *cd, time_t expiry; struct unix_gid ug, *ugp; - if (mlen <= 0 || mesg[mlen-1] != '\n') + if (mesg[mlen - 1] != '\n') return -EINVAL; mesg[mlen-1] = 0; - rv = get_int(&mesg, &uid); + rv = get_int(&mesg, &id); if (rv) return -EINVAL; + uid = make_kuid(&init_user_ns, id); ug.uid = uid; expiry = get_expiry(&mesg); @@ -576,26 +509,30 @@ static int unix_gid_parse(struct cache_detail *cd, for (i = 0 ; i < gids ; i++) { int gid; + kgid_t kgid; rv = get_int(&mesg, &gid); err = -EINVAL; if (rv) goto out; - GROUP_AT(ug.gi, i) = gid; + kgid = make_kgid(&init_user_ns, gid); + if (!gid_valid(kgid)) + goto out; + GROUP_AT(ug.gi, i) = kgid; } - ugp = unix_gid_lookup(uid); + ugp = unix_gid_lookup(cd, uid); if (ugp) { struct cache_head *ch; ug.h.flags = 0; ug.h.expiry_time = expiry; - ch = sunrpc_cache_update(&unix_gid_cache, + ch = sunrpc_cache_update(cd, &ug.h, &ugp->h, - hash_long(uid, GID_HASHBITS)); + unix_gid_hash(uid)); if (!ch) err = -ENOMEM; else { err = 0; - cache_put(ch, &unix_gid_cache); + cache_put(ch, cd); } } else err = -ENOMEM; @@ -609,6 +546,7 @@ static int unix_gid_show(struct seq_file *m, struct cache_detail *cd, struct cache_head *h) { + struct user_namespace *user_ns = &init_user_ns; struct unix_gid *ug; int i; int glen; @@ -624,20 +562,19 @@ static int unix_gid_show(struct seq_file *m, else glen = 0; - seq_printf(m, "%u %d:", ug->uid, glen); + seq_printf(m, "%u %d:", from_kuid_munged(user_ns, ug->uid), glen); for (i = 0; i < glen; i++) - seq_printf(m, " %d", GROUP_AT(ug->gi, i)); + seq_printf(m, " %d", from_kgid_munged(user_ns, GROUP_AT(ug->gi, i))); seq_printf(m, "\n"); return 0; } -struct cache_detail unix_gid_cache = { +static struct cache_detail unix_gid_cache_template = { .owner = THIS_MODULE, .hash_size = GID_HASHMAX, - .hash_table = gid_table, .name = "auth.unix.gid", .cache_put = unix_gid_put, - .cache_upcall = unix_gid_upcall, + .cache_request = unix_gid_request, .cache_parse = unix_gid_parse, .cache_show = unix_gid_show, .match = unix_gid_match, @@ -646,36 +583,68 @@ struct cache_detail unix_gid_cache = { .alloc = unix_gid_alloc, }; -static struct unix_gid *unix_gid_lookup(uid_t uid) +int unix_gid_cache_create(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&unix_gid_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->unix_gid_cache = cd; + return 0; +} + +void unix_gid_cache_destroy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->unix_gid_cache; + + sn->unix_gid_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid) { struct unix_gid ug; struct cache_head *ch; ug.uid = uid; - ch = sunrpc_cache_lookup(&unix_gid_cache, &ug.h, - hash_long(uid, GID_HASHBITS)); + ch = sunrpc_cache_lookup(cd, &ug.h, unix_gid_hash(uid)); if (ch) return container_of(ch, struct unix_gid, h); else return NULL; } -static struct group_info *unix_gid_find(uid_t uid, struct svc_rqst *rqstp) +static struct group_info *unix_gid_find(kuid_t uid, struct svc_rqst *rqstp) { struct unix_gid *ug; struct group_info *gi; int ret; + struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, 
+ sunrpc_net_id); - ug = unix_gid_lookup(uid); + ug = unix_gid_lookup(sn->unix_gid_cache, uid); if (!ug) return ERR_PTR(-EAGAIN); - ret = cache_check(&unix_gid_cache, &ug->h, &rqstp->rq_chandle); + ret = cache_check(sn->unix_gid_cache, &ug->h, &rqstp->rq_chandle); switch (ret) { case -ENOENT: return ERR_PTR(-ENOENT); + case -ETIMEDOUT: + return ERR_PTR(-ESHUTDOWN); case 0: gi = get_group_info(ug->gi); - cache_put(&ug->h, &unix_gid_cache); + cache_put(&ug->h, sn->unix_gid_cache); return gi; default: return ERR_PTR(-EAGAIN); @@ -690,6 +659,9 @@ svcauth_unix_set_client(struct svc_rqst *rqstp) struct ip_map *ipm; struct group_info *gi; struct svc_cred *cred = &rqstp->rq_cred; + struct svc_xprt *xprt = rqstp->rq_xprt; + struct net *net = xprt->xpt_net; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); switch (rqstp->rq_addr.ss_family) { case AF_INET: @@ -708,26 +680,27 @@ svcauth_unix_set_client(struct svc_rqst *rqstp) if (rqstp->rq_proc == 0) return SVC_OK; - ipm = ip_map_cached_get(rqstp); + ipm = ip_map_cached_get(xprt); if (ipm == NULL) - ipm = ip_map_lookup(rqstp->rq_server->sv_program->pg_class, + ipm = __ip_map_lookup(sn->ip_map_cache, rqstp->rq_server->sv_program->pg_class, &sin6->sin6_addr); if (ipm == NULL) return SVC_DENIED; - switch (cache_check(&ip_map_cache, &ipm->h, &rqstp->rq_chandle)) { + switch (cache_check(sn->ip_map_cache, &ipm->h, &rqstp->rq_chandle)) { default: BUG(); - case -EAGAIN: case -ETIMEDOUT: + return SVC_CLOSE; + case -EAGAIN: return SVC_DROP; case -ENOENT: return SVC_DENIED; case 0: rqstp->rq_client = &ipm->m_client->h; kref_get(&rqstp->rq_client->ref); - ip_map_cached_put(rqstp, ipm); + ip_map_cached_put(xprt, ipm); break; } @@ -735,6 +708,8 @@ svcauth_unix_set_client(struct svc_rqst *rqstp) switch (PTR_ERR(gi)) { case -EAGAIN: return SVC_DROP; + case -ESHUTDOWN: + return SVC_CLOSE; case -ENOENT: break; default: @@ -754,6 +729,7 @@ svcauth_null_accept(struct svc_rqst *rqstp, __be32 *authp) struct svc_cred *cred = &rqstp->rq_cred; cred->cr_group_info = NULL; + cred->cr_principal = NULL; rqstp->rq_client = NULL; if (argv->iov_len < 3*4) @@ -771,17 +747,17 @@ svcauth_null_accept(struct svc_rqst *rqstp, __be32 *authp) } /* Signal that mapping to nobody uid/gid is required */ - cred->cr_uid = (uid_t) -1; - cred->cr_gid = (gid_t) -1; + cred->cr_uid = INVALID_UID; + cred->cr_gid = INVALID_GID; cred->cr_group_info = groups_alloc(0); if (cred->cr_group_info == NULL) - return SVC_DROP; /* kmalloc failure - client must retry */ + return SVC_CLOSE; /* kmalloc failure - client must retry */ /* Put NULL verifier */ svc_putnl(resv, RPC_AUTH_NULL); svc_putnl(resv, 0); - rqstp->rq_flavor = RPC_AUTH_NULL; + rqstp->rq_cred.cr_flavor = RPC_AUTH_NULL; return SVC_OK; } @@ -819,6 +795,7 @@ svcauth_unix_accept(struct svc_rqst *rqstp, __be32 *authp) int len = argv->iov_len; cred->cr_group_info = NULL; + cred->cr_principal = NULL; rqstp->rq_client = NULL; if ((len -= 3*4) < 0) @@ -831,17 +808,25 @@ svcauth_unix_accept(struct svc_rqst *rqstp, __be32 *authp) goto badcred; argv->iov_base = (void*)((__be32*)argv->iov_base + slen); /* skip machname */ argv->iov_len -= slen*4; - - cred->cr_uid = svc_getnl(argv); /* uid */ - cred->cr_gid = svc_getnl(argv); /* gid */ + /* + * Note: we skip uid_valid()/gid_valid() checks here for + * backwards compatibility with clients that use -1 ids. + * Instead, -1 uid or gid is later mapped to the + * (export-specific) anonymous id by nfsd_setuser. + * Supplementary gids will be left alone.
+ */ + cred->cr_uid = make_kuid(&init_user_ns, svc_getnl(argv)); /* uid */ + cred->cr_gid = make_kgid(&init_user_ns, svc_getnl(argv)); /* gid */ slen = svc_getnl(argv); /* gids length */ if (slen > 16 || (len -= (slen + 2)*4) < 0) goto badcred; cred->cr_group_info = groups_alloc(slen); if (cred->cr_group_info == NULL) - return SVC_DROP; - for (i = 0; i < slen; i++) - GROUP_AT(cred->cr_group_info, i) = svc_getnl(argv); + return SVC_CLOSE; + for (i = 0; i < slen; i++) { + kgid_t kgid = make_kgid(&init_user_ns, svc_getnl(argv)); + GROUP_AT(cred->cr_group_info, i) = kgid; + } if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { *authp = rpc_autherr_badverf; return SVC_DENIED; @@ -851,7 +836,7 @@ svcauth_unix_accept(struct svc_rqst *rqstp, __be32 *authp) svc_putnl(resv, RPC_AUTH_NULL); svc_putnl(resv, 0); - rqstp->rq_flavor = RPC_AUTH_UNIX; + rqstp->rq_cred.cr_flavor = RPC_AUTH_UNIX; return SVC_OK; badcred: @@ -885,3 +870,45 @@ struct auth_ops svcauth_unix = { .set_client = svcauth_unix_set_client, }; +static struct cache_detail ip_map_cache_template = { + .owner = THIS_MODULE, + .hash_size = IP_HASHMAX, + .name = "auth.unix.ip", + .cache_put = ip_map_put, + .cache_request = ip_map_request, + .cache_parse = ip_map_parse, + .cache_show = ip_map_show, + .match = ip_map_match, + .init = ip_map_init, + .update = update, + .alloc = ip_map_alloc, +}; + +int ip_map_cache_create(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&ip_map_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->ip_map_cache = cd; + return 0; +} + +void ip_map_cache_destroy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->ip_map_cache; + + sn->ip_map_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index a29f259204e..b507cd327d9 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -21,6 +21,7 @@ #include <linux/kernel.h> #include <linux/sched.h> +#include <linux/module.h> #include <linux/errno.h> #include <linux/fcntl.h> #include <linux/net.h> @@ -42,6 +43,7 @@ #include <net/tcp_states.h> #include <asm/uaccess.h> #include <asm/ioctls.h> +#include <trace/events/skb.h> #include <linux/sunrpc/types.h> #include <linux/sunrpc/clnt.h> @@ -51,12 +53,14 @@ #include <linux/sunrpc/stats.h> #include <linux/sunrpc/xprt.h> +#include "sunrpc.h" + #define RPCDBG_FACILITY RPCDBG_SVCXPRT static struct svc_sock *svc_setup_socket(struct svc_serv *, struct socket *, - int *errp, int flags); -static void svc_udp_data_ready(struct sock *, int); + int flags); +static void svc_udp_data_ready(struct sock *); static int svc_udp_recvfrom(struct svc_rqst *); static int svc_udp_sendto(struct svc_rqst *); static void svc_sock_detach(struct svc_xprt *); @@ -64,7 +68,15 @@ static void svc_tcp_sock_detach(struct svc_xprt *); static void svc_sock_free(struct svc_xprt *); static struct svc_xprt *svc_create_socket(struct svc_serv *, int, - struct sockaddr *, int, int); + struct net *, struct sockaddr *, + int, int); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static struct svc_xprt *svc_bc_create_socket(struct svc_serv *, int, + struct net *, struct sockaddr *, + int, int); +static void svc_bc_sock_free(struct svc_xprt *xprt); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ 
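The svc_set_cmsg_data() hunk just below pins each UDP reply's source address to the destination of the request it answers by attaching an IP_PKTINFO (or IPV6_PKTINFO) control message to the outgoing message. A rough userspace sketch of the same IPv4 cmsg layout — not the kernel path, and every name in it is illustrative:

#define _GNU_SOURCE
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>

/* Reply to 'peer' on UDP socket 'fd', forcing the source address to
 * 'saved_dst' — the address the request was originally sent to, as
 * reported by an earlier recvmsg() with the IP_PKTINFO sockopt set. */
static ssize_t sendmsg_pinned_source(int fd, const void *buf, size_t len,
				     const struct sockaddr_in *peer,
				     struct in_addr saved_dst)
{
	union {
		char buf[CMSG_SPACE(sizeof(struct in_pktinfo))];
		struct cmsghdr align;		/* force cmsg alignment */
	} u;
	struct iovec iov = { .iov_base = (void *)buf, .iov_len = len };
	struct msghdr msg = {
		.msg_name	= (void *)peer,
		.msg_namelen	= sizeof(*peer),
		.msg_iov	= &iov,
		.msg_iovlen	= 1,
		.msg_control	= u.buf,
		.msg_controllen	= sizeof(u.buf),
	};
	struct cmsghdr *cmh = CMSG_FIRSTHDR(&msg);
	struct in_pktinfo *pki = (struct in_pktinfo *)CMSG_DATA(cmh);

	cmh->cmsg_level = SOL_IP;
	cmh->cmsg_type = IP_PKTINFO;
	cmh->cmsg_len = CMSG_LEN(sizeof(*pki));
	memset(pki, 0, sizeof(*pki));
	pki->ipi_spec_dst = saved_dst;		/* becomes the reply's source */

	return sendmsg(fd, &msg, 0);
}

The matching parse on the receive side is what svc_udp_get_dest_address4()/svc_udp_get_dest_address6() do with the request's control data; note that the IPv6 variant now also round-trips the scope id (ipi6_ifindex into sin6_scope_id and back).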
+ #ifdef CONFIG_DEBUG_LOCK_ALLOC static struct lock_class_key svc_key[2]; static struct lock_class_key svc_slock_key[2]; @@ -72,7 +84,11 @@ static struct lock_class_key svc_slock_key[2]; static void svc_reclassify_socket(struct socket *sock) { struct sock *sk = sock->sk; - BUG_ON(sock_owned_by_user(sk)); + + WARN_ON_ONCE(sock_owned_by_user(sk)); + if (sock_owned_by_user(sk)) + return; + switch (sk->sk_family) { case AF_INET: sock_lock_init_class_and_name(sk, "slock-AF_INET-NFSD", @@ -133,24 +149,24 @@ static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh) cmh->cmsg_level = SOL_IP; cmh->cmsg_type = IP_PKTINFO; pki->ipi_ifindex = 0; - pki->ipi_spec_dst.s_addr = rqstp->rq_daddr.addr.s_addr; + pki->ipi_spec_dst.s_addr = + svc_daddr_in(rqstp)->sin_addr.s_addr; cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); } break; case AF_INET6: { struct in6_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp); cmh->cmsg_level = SOL_IPV6; cmh->cmsg_type = IPV6_PKTINFO; - pki->ipi6_ifindex = 0; - ipv6_addr_copy(&pki->ipi6_addr, - &rqstp->rq_daddr.addr6); + pki->ipi6_ifindex = daddr->sin6_scope_id; + pki->ipi6_addr = daddr->sin6_addr; cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); } break; } - return; } /* @@ -275,12 +291,14 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) &inet_sk(sk)->inet_rcv_saddr, inet_sk(sk)->inet_num); break; +#if IS_ENABLED(CONFIG_IPV6) case PF_INET6: len = snprintf(buf, remaining, "ipv6 %s %pI6 %d\n", proto_name, - &inet6_sk(sk)->rcv_saddr, + &sk->sk_v6_rcv_saddr, inet_sk(sk)->inet_num); break; +#endif default: len = snprintf(buf, remaining, "*unknown-%d*\n", sk->sk_family); @@ -293,55 +311,6 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) return len; } -/** - * svc_sock_names - construct a list of listener names in a string - * @serv: pointer to RPC service - * @buf: pointer to a buffer to fill in with socket names - * @buflen: size of the buffer to be filled - * @toclose: pointer to '\0'-terminated C string containing the name - * of a listener to be closed - * - * Fills in @buf with a '\n'-separated list of names of listener - * sockets. If @toclose is not NULL, the socket named by @toclose - * is closed, and is not included in the output list. - * - * Returns positive length of the socket name string, or a negative - * errno value on error. - */ -int svc_sock_names(struct svc_serv *serv, char *buf, const size_t buflen, - const char *toclose) -{ - struct svc_sock *svsk, *closesk = NULL; - int len = 0; - - if (!serv) - return 0; - - spin_lock_bh(&serv->sv_lock); - list_for_each_entry(svsk, &serv->sv_permsocks, sk_xprt.xpt_list) { - int onelen = svc_one_sock_name(svsk, buf + len, buflen - len); - if (onelen < 0) { - len = onelen; - break; - } - if (toclose && strcmp(toclose, buf + len) == 0) - closesk = svsk; - else - len += onelen; - } - spin_unlock_bh(&serv->sv_lock); - - if (closesk) - /* Should unregister with portmap, but you cannot - * unregister just one protocol... 
- */ - svc_close_xprt(&closesk->sk_xprt); - else if (toclose) - return -ENOENT; - return len; -} -EXPORT_SYMBOL_GPL(svc_sock_names); - /* * Check input queue length */ @@ -378,6 +347,33 @@ static int svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, int nr, return len; } +static int svc_partial_recvfrom(struct svc_rqst *rqstp, + struct kvec *iov, int nr, + int buflen, unsigned int base) +{ + size_t save_iovlen; + void *save_iovbase; + unsigned int i; + int ret; + + if (base == 0) + return svc_recvfrom(rqstp, iov, nr, buflen); + + for (i = 0; i < nr; i++) { + if (iov[i].iov_len > base) + break; + base -= iov[i].iov_len; + } + save_iovlen = iov[i].iov_len; + save_iovbase = iov[i].iov_base; + iov[i].iov_len -= base; + iov[i].iov_base += base; + ret = svc_recvfrom(rqstp, &iov[i], nr - i, buflen); + iov[i].iov_len = save_iovlen; + iov[i].iov_base = save_iovbase; + return ret; +} + /* * Set socket snd and rcv buffer lengths */ @@ -400,27 +396,33 @@ static void svc_sock_setbufsize(struct socket *sock, unsigned int snd, lock_sock(sock->sk); sock->sk->sk_sndbuf = snd * 2; sock->sk->sk_rcvbuf = rcv * 2; - sock->sk->sk_userlocks |= SOCK_SNDBUF_LOCK|SOCK_RCVBUF_LOCK; sock->sk->sk_write_space(sock->sk); release_sock(sock->sk); #endif } + +static int svc_sock_secure_port(struct svc_rqst *rqstp) +{ + return svc_port_is_privileged(svc_addr(rqstp)); +} + /* * INET callback when data has been received on the socket. */ -static void svc_udp_data_ready(struct sock *sk, int count) +static void svc_udp_data_ready(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + wait_queue_head_t *wq = sk_sleep(sk); if (svsk) { - dprintk("svc: socket %p(inet %p), count=%d, busy=%d\n", - svsk, sk, count, + dprintk("svc: socket %p(inet %p), busy=%d\n", + svsk, sk, test_bit(XPT_BUSY, &svsk->sk_xprt.xpt_flags)); set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); svc_xprt_enqueue(&svsk->sk_xprt); } - if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) - wake_up_interruptible(sk->sk_sleep); + if (wq && waitqueue_active(wq)) + wake_up_interruptible(wq); } /* @@ -429,6 +431,7 @@ static void svc_udp_data_ready(struct sock *sk, int count) static void svc_write_space(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)(sk->sk_user_data); + wait_queue_head_t *wq = sk_sleep(sk); if (svsk) { dprintk("svc: socket %p(inet %p), write_space busy=%d\n", @@ -436,10 +439,10 @@ static void svc_write_space(struct sock *sk) svc_xprt_enqueue(&svsk->sk_xprt); } - if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) { + if (wq && waitqueue_active(wq)) { dprintk("RPC svc_write_space: someone sleeping on %p\n", svsk); - wake_up_interruptible(sk->sk_sleep); + wake_up_interruptible(wq); } } @@ -447,7 +450,7 @@ static void svc_tcp_write_space(struct sock *sk) { struct socket *sock = sk->sk_socket; - if (sk_stream_wspace(sk) >= sk_stream_min_wspace(sk) && sock) + if (sk_stream_is_writeable(sk) && sock) clear_bit(SOCK_NOSPACE, &sock->flags); svc_write_space(sk); } @@ -459,22 +462,31 @@ static int svc_udp_get_dest_address4(struct svc_rqst *rqstp, struct cmsghdr *cmh) { struct in_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in *daddr = svc_daddr_in(rqstp); + if (cmh->cmsg_type != IP_PKTINFO) return 0; - rqstp->rq_daddr.addr.s_addr = pki->ipi_spec_dst.s_addr; + + daddr->sin_family = AF_INET; + daddr->sin_addr.s_addr = pki->ipi_spec_dst.s_addr; return 1; } /* - * See net/ipv6/datagram.c : datagram_recv_ctl + * See net/ipv6/datagram.c : ip6_datagram_recv_ctl */ static int svc_udp_get_dest_address6(struct svc_rqst *rqstp, 
struct cmsghdr *cmh) { struct in6_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp); + if (cmh->cmsg_type != IPV6_PKTINFO) return 0; - ipv6_addr_copy(&rqstp->rq_daddr.addr6, &pki->ipi6_addr); + + daddr->sin6_family = AF_INET6; + daddr->sin6_addr = pki->ipi6_addr; + daddr->sin6_scope_id = pki->ipi6_ifindex; return 1; } @@ -547,12 +559,9 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) dprintk("svc: recvfrom returned error %d\n", -err); set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); } - svc_xprt_received(&svsk->sk_xprt); - return -EAGAIN; + return 0; } len = svc_addr_len(svc_addr(rqstp)); - if (len == 0) - return -EAFNOSUPPORT; rqstp->rq_addrlen = len; if (skb->tstamp.tv64 == 0) { skb->tstamp = ktime_get_real(); @@ -562,25 +571,17 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) svsk->sk_sk->sk_stamp = skb->tstamp; set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); /* there may be more data... */ - /* - * Maybe more packets - kick another thread ASAP. - */ - svc_xprt_received(&svsk->sk_xprt); - len = skb->len - sizeof(struct udphdr); rqstp->rq_arg.len = len; rqstp->rq_prot = IPPROTO_UDP; if (!svc_udp_get_dest_address(rqstp, cmh)) { - if (net_ratelimit()) - printk(KERN_WARNING - "svc: received unknown control message %d/%d; " - "dropping RPC reply datagram\n", - cmh->cmsg_level, cmh->cmsg_type); - skb_free_datagram_locked(svsk->sk_sk, skb); - return 0; + net_warn_ratelimited("svc: received unknown control message %d/%d; dropping RPC reply datagram\n", + cmh->cmsg_level, cmh->cmsg_type); + goto out_free; } + rqstp->rq_daddrlen = svc_addr_len(svc_daddr(rqstp)); if (skb_is_nonlinear(skb)) { /* we have to copy */ @@ -588,8 +589,7 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) if (csum_partial_copy_to_xdr(&rqstp->rq_arg, skb)) { local_bh_enable(); /* checksum error */ - skb_free_datagram_locked(svsk->sk_sk, skb); - return 0; + goto out_free; } local_bh_enable(); skb_free_datagram_locked(svsk->sk_sk, skb); @@ -598,10 +598,8 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) rqstp->rq_arg.head[0].iov_base = skb->data + sizeof(struct udphdr); rqstp->rq_arg.head[0].iov_len = len; - if (skb_checksum_complete(skb)) { - skb_free_datagram_locked(svsk->sk_sk, skb); - return 0; - } + if (skb_checksum_complete(skb)) + goto out_free; rqstp->rq_xprt_ctxt = skb; } @@ -615,11 +613,16 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) rqstp->rq_respages = rqstp->rq_pages + 1 + DIV_ROUND_UP(rqstp->rq_arg.page_len, PAGE_SIZE); } + rqstp->rq_next_page = rqstp->rq_respages+1; if (serv->sv_stats) serv->sv_stats->netudpcnt++; return len; +out_free: + trace_kfree_skb(skb, svc_udp_recvfrom); + skb_free_datagram_locked(svsk->sk_sk, skb); + return 0; } static int @@ -664,10 +667,11 @@ static struct svc_xprt *svc_udp_accept(struct svc_xprt *xprt) } static struct svc_xprt *svc_udp_create(struct svc_serv *serv, + struct net *net, struct sockaddr *sa, int salen, int flags) { - return svc_create_socket(serv, IPPROTO_UDP, sa, salen, flags); + return svc_create_socket(serv, IPPROTO_UDP, net, sa, salen, flags); } static struct svc_xprt_ops svc_udp_ops = { @@ -680,6 +684,7 @@ static struct svc_xprt_ops svc_udp_ops = { .xpo_prep_reply_hdr = svc_udp_prep_reply_hdr, .xpo_has_wspace = svc_udp_has_wspace, .xpo_accept = svc_udp_accept, + .xpo_secure_port = svc_sock_secure_port, }; static struct svc_xprt_class svc_udp_class = { @@ -693,7 +698,8 @@ static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv) { int err, level, optname, one = 1; - svc_xprt_init(&svc_udp_class, 
&svsk->sk_xprt, serv); + svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_udp_class, + &svsk->sk_xprt, serv); clear_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); svsk->sk_sk->sk_data_ready = svc_udp_data_ready; svsk->sk_sk->sk_write_space = svc_write_space; @@ -732,9 +738,10 @@ static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv) * A data_ready event on a listening socket means there's a connection * pending. Do not use state_change as a substitute for it. */ -static void svc_tcp_listen_data_ready(struct sock *sk, int count_unused) +static void svc_tcp_listen_data_ready(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + wait_queue_head_t *wq; dprintk("svc: socket %p TCP (listen) state change %d\n", sk, sk->sk_state); @@ -757,8 +764,9 @@ static void svc_tcp_listen_data_ready(struct sock *sk, int count_unused) printk("svc: socket %p: no user data\n", sk); } - if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) - wake_up_interruptible_all(sk->sk_sleep); + wq = sk_sleep(sk); + if (wq && waitqueue_active(wq)) + wake_up_interruptible_all(wq); } /* @@ -767,6 +775,7 @@ static void svc_tcp_listen_data_ready(struct sock *sk, int count_unused) static void svc_tcp_state_change(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + wait_queue_head_t *wq = sk_sleep(sk); dprintk("svc: socket %p TCP (connected) state change %d (svsk %p)\n", sk, sk->sk_state, sk->sk_user_data); @@ -777,13 +786,14 @@ static void svc_tcp_state_change(struct sock *sk) set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); svc_xprt_enqueue(&svsk->sk_xprt); } - if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) - wake_up_interruptible_all(sk->sk_sleep); + if (wq && waitqueue_active(wq)) + wake_up_interruptible_all(wq); } -static void svc_tcp_data_ready(struct sock *sk, int count) +static void svc_tcp_data_ready(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + wait_queue_head_t *wq = sk_sleep(sk); dprintk("svc: socket %p TCP data ready (svsk %p)\n", sk, sk->sk_user_data); @@ -791,8 +801,8 @@ static void svc_tcp_data_ready(struct sock *sk, int count) set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); svc_xprt_enqueue(&svsk->sk_xprt); } - if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) - wake_up_interruptible(sk->sk_sleep); + if (wq && waitqueue_active(wq)) + wake_up_interruptible(wq); } /* @@ -820,18 +830,17 @@ static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) if (err == -ENOMEM) printk(KERN_WARNING "%s: no more sockets!\n", serv->sv_name); - else if (err != -EAGAIN && net_ratelimit()) - printk(KERN_WARNING "%s: accept failed (err %d)!\n", - serv->sv_name, -err); + else if (err != -EAGAIN) + net_warn_ratelimited("%s: accept failed (err %d)!\n", + serv->sv_name, -err); return NULL; } set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); err = kernel_getpeername(newsock, sin, &slen); if (err < 0) { - if (net_ratelimit()) - printk(KERN_WARNING "%s: peername failed (err %d)!\n", - serv->sv_name, -err); + net_warn_ratelimited("%s: peername failed (err %d)!\n", + serv->sv_name, -err); goto failed; /* aborted connection or whatever */ } @@ -840,8 +849,7 @@ static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) * tell us anything. For now just warn about unpriv connections. 
*/ if (!svc_port_is_privileged(sin)) { - dprintk(KERN_WARNING - "%s: connect from unprivileged port: %s\n", + dprintk("%s: connect from unprivileged port: %s\n", serv->sv_name, __svc_print_addr(sin, buf, sizeof(buf))); } @@ -853,8 +861,9 @@ static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) */ newsock->sk->sk_sndtimeo = HZ*30; - if (!(newsvsk = svc_setup_socket(serv, newsock, &err, - (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY)))) + newsvsk = svc_setup_socket(serv, newsock, + (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY)); + if (IS_ERR(newsvsk)) goto failed; svc_xprt_set_remote(&newsvsk->sk_xprt, sin, slen); err = kernel_getsockname(newsock, sin, &slen); @@ -864,6 +873,10 @@ static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) } svc_xprt_set_local(&newsvsk->sk_xprt, sin, slen); + if (sock_is_loopback(newsock->sk)) + set_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); + else + clear_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); if (serv->sv_stats) serv->sv_stats->nettcpconn++; @@ -874,40 +887,76 @@ failed: return NULL; } +static unsigned int svc_tcp_restore_pages(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + return 0; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (rqstp->rq_pages[i] != NULL) + put_page(rqstp->rq_pages[i]); + BUG_ON(svsk->sk_pages[i] == NULL); + rqstp->rq_pages[i] = svsk->sk_pages[i]; + svsk->sk_pages[i] = NULL; + } + rqstp->rq_arg.head[0].iov_base = page_address(rqstp->rq_pages[0]); + return len; +} + +static void svc_tcp_save_pages(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + return; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + svsk->sk_pages[i] = rqstp->rq_pages[i]; + rqstp->rq_pages[i] = NULL; + } +} + +static void svc_tcp_clear_pages(struct svc_sock *svsk) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + goto out; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (svsk->sk_pages[i] == NULL) { + WARN_ON_ONCE(1); + continue; + } + put_page(svsk->sk_pages[i]); + svsk->sk_pages[i] = NULL; + } +out: + svsk->sk_tcplen = 0; + svsk->sk_datalen = 0; +} + /* - * Receive data. + * Receive fragment record header. * If we haven't gotten the record length yet, get the next four bytes. - * Otherwise try to gobble up as much as possible up to the complete - * record length. */ static int svc_tcp_recv_record(struct svc_sock *svsk, struct svc_rqst *rqstp) { struct svc_serv *serv = svsk->sk_xprt.xpt_server; + unsigned int want; int len; - if (test_and_clear_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags)) - /* sndbuf needs to have room for one request - * per thread, otherwise we can stall even when the - * network isn't a bottleneck. - * - * We count all threads rather than threads in a - * particular pool, which provides an upper bound - * on the number of threads which will access the socket. - * - * rcvbuf just needs to be able to hold a few requests. - * Normally they will be removed from the queue - * as soon a a complete request arrives. 
- */ - svc_sock_setbufsize(svsk->sk_sock, - (serv->sv_nrthreads+3) * serv->sv_max_mesg, - 3 * serv->sv_max_mesg); - clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); if (svsk->sk_tcplen < sizeof(rpc_fraghdr)) { - int want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; struct kvec iov; + want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; iov.iov_base = ((char *) &svsk->sk_reclen) + svsk->sk_tcplen; iov.iov_len = want; if ((len = svc_recvfrom(rqstp, &iov, 1, want)) < 0) @@ -917,111 +966,91 @@ static int svc_tcp_recv_record(struct svc_sock *svsk, struct svc_rqst *rqstp) if (len < want) { dprintk("svc: short recvfrom while reading record " "length (%d of %d)\n", len, want); - svc_xprt_received(&svsk->sk_xprt); - goto err_again; /* record header not complete */ + return -EAGAIN; } - svsk->sk_reclen = ntohl(svsk->sk_reclen); - if (!(svsk->sk_reclen & RPC_LAST_STREAM_FRAGMENT)) { - /* FIXME: technically, a record can be fragmented, - * and non-terminal fragments will not have the top - * bit set in the fragment length header. - * But apparently no known nfs clients send fragmented - * records. */ - if (net_ratelimit()) - printk(KERN_NOTICE "RPC: multiple fragments " - "per record not supported\n"); + dprintk("svc: TCP record, %d bytes\n", svc_sock_reclen(svsk)); + if (svc_sock_reclen(svsk) + svsk->sk_datalen > + serv->sv_max_mesg) { + net_notice_ratelimited("RPC: fragment too large: %d\n", + svc_sock_reclen(svsk)); goto err_delete; } - - svsk->sk_reclen &= RPC_FRAGMENT_SIZE_MASK; - dprintk("svc: TCP record, %d bytes\n", svsk->sk_reclen); - if (svsk->sk_reclen > serv->sv_max_mesg) { - if (net_ratelimit()) - printk(KERN_NOTICE "RPC: " - "fragment too large: 0x%08lx\n", - (unsigned long)svsk->sk_reclen); - goto err_delete; - } - } - - /* Check whether enough data is available */ - len = svc_recv_available(svsk); - if (len < 0) - goto error; - - if (len < svsk->sk_reclen) { - dprintk("svc: incomplete TCP record (%d of %d)\n", - len, svsk->sk_reclen); - svc_xprt_received(&svsk->sk_xprt); - goto err_again; /* record not complete */ } - len = svsk->sk_reclen; - set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + return svc_sock_reclen(svsk); +error: + dprintk("RPC: TCP recv_record got %d\n", len); return len; - error: - if (len == -EAGAIN) { - dprintk("RPC: TCP recv_record got EAGAIN\n"); - svc_xprt_received(&svsk->sk_xprt); - } - return len; - err_delete: +err_delete: set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); - svc_xprt_received(&svsk->sk_xprt); - err_again: return -EAGAIN; } -static int svc_process_calldir(struct svc_sock *svsk, struct svc_rqst *rqstp, - struct rpc_rqst **reqpp, struct kvec *vec) +static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp) { + struct rpc_xprt *bc_xprt = svsk->sk_xprt.xpt_bc_xprt; struct rpc_rqst *req = NULL; - u32 *p; - u32 xid; - u32 calldir; - int len; + struct kvec *src, *dst; + __be32 *p = (__be32 *)rqstp->rq_arg.head[0].iov_base; + __be32 xid; + __be32 calldir; - len = svc_recvfrom(rqstp, vec, 1, 8); - if (len < 0) - goto error; - - p = (u32 *)rqstp->rq_arg.head[0].iov_base; xid = *p++; calldir = *p; - if (calldir == 0) { - /* REQUEST is the most common case */ - vec[0] = rqstp->rq_arg.head[0]; - } else { - /* REPLY */ - if (svsk->sk_bc_xprt) - req = xprt_lookup_rqst(svsk->sk_bc_xprt, xid); - - if (!req) { - printk(KERN_NOTICE - "%s: Got unrecognized reply: " - "calldir 0x%x sk_bc_xprt %p xid %08x\n", - __func__, ntohl(calldir), - svsk->sk_bc_xprt, xid); - vec[0] = rqstp->rq_arg.head[0]; - goto out; - } + if (bc_xprt) + req = xprt_lookup_rqst(bc_xprt, xid); - 
memcpy(&req->rq_private_buf, &req->rq_rcv_buf, - sizeof(struct xdr_buf)); - /* copy the xid and call direction */ - memcpy(req->rq_private_buf.head[0].iov_base, - rqstp->rq_arg.head[0].iov_base, 8); - vec[0] = req->rq_private_buf.head[0]; + if (!req) { + printk(KERN_NOTICE + "%s: Got unrecognized reply: " + "calldir 0x%x xpt_bc_xprt %p xid %08x\n", + __func__, ntohl(calldir), + bc_xprt, xid); + return -EAGAIN; } - out: - vec[0].iov_base += 8; - vec[0].iov_len -= 8; - len = svsk->sk_reclen - 8; - error: - *reqpp = req; - return len; + + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf)); + /* + * XXX!: cheating for now! Only copying HEAD. + * But we know this is good enough for now (in fact, for any + * callback reply in the foreseeable future). + */ + dst = &req->rq_private_buf.head[0]; + src = &rqstp->rq_arg.head[0]; + if (dst->iov_len < src->iov_len) + return -EAGAIN; /* whatever; just giving up. */ + memcpy(dst->iov_base, src->iov_base, src->iov_len); + xprt_complete_rqst(req->rq_task, rqstp->rq_arg.len); + rqstp->rq_arg.len = 0; + return 0; +} + +static int copy_pages_to_kvecs(struct kvec *vec, struct page **pages, int len) +{ + int i = 0; + int t = 0; + + while (t < len) { + vec[i].iov_base = page_address(pages[i]); + vec[i].iov_len = PAGE_SIZE; + i++; + t += PAGE_SIZE; + } + return i; +} + +static void svc_tcp_fragment_received(struct svc_sock *svsk) +{ + /* If we have more data, signal svc_xprt_enqueue() to try again */ + if (svc_recv_available(svsk) > sizeof(rpc_fraghdr)) + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + dprintk("svc: TCP %s record (%d bytes)\n", + svc_sock_final_rec(svsk) ? "final" : "nonfinal", + svc_sock_reclen(svsk)); + svsk->sk_tcplen = 0; + svsk->sk_reclen = 0; } /* @@ -1034,8 +1063,10 @@ static int svc_tcp_recvfrom(struct svc_rqst *rqstp) struct svc_serv *serv = svsk->sk_xprt.xpt_server; int len; struct kvec *vec; - int pnum, vlen; - struct rpc_rqst *req = NULL; + unsigned int want, base; + __be32 *p; + __be32 calldir; + int pnum; dprintk("svc: tcp_recv %p data %d conn %d close %d\n", svsk, test_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags), @@ -1046,89 +1077,82 @@ static int svc_tcp_recvfrom(struct svc_rqst *rqstp) if (len < 0) goto error; + base = svc_tcp_restore_pages(svsk, rqstp); + want = svc_sock_reclen(svsk) - (svsk->sk_tcplen - sizeof(rpc_fraghdr)); + vec = rqstp->rq_vec; - vec[0] = rqstp->rq_arg.head[0]; - vlen = PAGE_SIZE; - /* - * We have enough data for the whole tcp record. Let's try and read the - * first 8 bytes to get the xid and the call direction. We can use this - * to figure out if this is a call or a reply to a callback. If - * sk_reclen is < 8 (xid and calldir), then this is a malformed packet. - * In that case, don't bother with the calldir and just read the data. - * It will be rejected in svc_process. - */ - if (len >= 8) { - len = svc_process_calldir(svsk, rqstp, &req, vec); - if (len < 0) - goto err_again; - vlen -= 8; - } + pnum = copy_pages_to_kvecs(&vec[0], &rqstp->rq_pages[0], + svsk->sk_datalen + want); - pnum = 1; - while (vlen < len) { - vec[pnum].iov_base = (req) ?
- page_address(req->rq_private_buf.pages[pnum - 1]) : - page_address(rqstp->rq_pages[pnum]); - vec[pnum].iov_len = PAGE_SIZE; - pnum++; - vlen += PAGE_SIZE; - } rqstp->rq_respages = &rqstp->rq_pages[pnum]; + rqstp->rq_next_page = rqstp->rq_respages + 1; /* Now receive data */ - len = svc_recvfrom(rqstp, vec, pnum, len); - if (len < 0) - goto err_again; - - /* - * Account for the 8 bytes we read earlier - */ - len += 8; + len = svc_partial_recvfrom(rqstp, vec, pnum, want, base); + if (len >= 0) { + svsk->sk_tcplen += len; + svsk->sk_datalen += len; + } + if (len != want || !svc_sock_final_rec(svsk)) { + svc_tcp_save_pages(svsk, rqstp); + if (len < 0 && len != -EAGAIN) + goto err_delete; + if (len == want) + svc_tcp_fragment_received(svsk); + else + dprintk("svc: incomplete TCP record (%d of %d)\n", + (int)(svsk->sk_tcplen - sizeof(rpc_fraghdr)), + svc_sock_reclen(svsk)); + goto err_noclose; + } - if (req) { - xprt_complete_rqst(req->rq_task, len); - len = 0; - goto out; + if (svsk->sk_datalen < 8) { + svsk->sk_datalen = 0; + goto err_delete; /* client is nuts. */ } - dprintk("svc: TCP complete record (%d bytes)\n", len); - rqstp->rq_arg.len = len; + + rqstp->rq_arg.len = svsk->sk_datalen; rqstp->rq_arg.page_base = 0; - if (len <= rqstp->rq_arg.head[0].iov_len) { - rqstp->rq_arg.head[0].iov_len = len; + if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len; rqstp->rq_arg.page_len = 0; - } else { - rqstp->rq_arg.page_len = len - rqstp->rq_arg.head[0].iov_len; - } + } else + rqstp->rq_arg.page_len = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len; rqstp->rq_xprt_ctxt = NULL; rqstp->rq_prot = IPPROTO_TCP; + rqstp->rq_local = !!test_bit(XPT_LOCAL, &svsk->sk_xprt.xpt_flags); + + p = (__be32 *)rqstp->rq_arg.head[0].iov_base; + calldir = p[1]; + if (calldir) + len = receive_cb_reply(svsk, rqstp); -out: /* Reset TCP read info */ - svsk->sk_reclen = 0; - svsk->sk_tcplen = 0; + svsk->sk_datalen = 0; + svc_tcp_fragment_received(svsk); + + if (len < 0) + goto error; svc_xprt_copy_addrs(rqstp, &svsk->sk_xprt); - svc_xprt_received(&svsk->sk_xprt); if (serv->sv_stats) serv->sv_stats->nettcpcnt++; - return len; + return rqstp->rq_arg.len; -err_again: - if (len == -EAGAIN) { - dprintk("RPC: TCP recvfrom got EAGAIN\n"); - svc_xprt_received(&svsk->sk_xprt); - return len; - } error: - if (len != -EAGAIN) { - printk(KERN_NOTICE "%s: recvfrom returned errno %d\n", - svsk->sk_xprt.xpt_server->sv_name, -len); - set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); - } - return -EAGAIN; + if (len != -EAGAIN) + goto err_delete; + dprintk("RPC: TCP recvfrom got EAGAIN\n"); + return 0; +err_delete: + printk(KERN_NOTICE "%s: recvfrom returned errno %d\n", + svsk->sk_xprt.xpt_server->sv_name, -len); + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); +err_noclose: + return 0; /* record not complete */ } /* @@ -1147,9 +1171,6 @@ static int svc_tcp_sendto(struct svc_rqst *rqstp) reclen = htonl(0x80000000|((xbufp->len ) - 4)); memcpy(xbufp->head[0].iov_base, &reclen, 4); - if (test_bit(XPT_DEAD, &rqstp->rq_xprt->xpt_flags)) - return -ENOTCONN; - sent = svc_sendto(rqstp, &rqstp->rq_res); if (sent != xbufp->len) { printk(KERN_NOTICE @@ -1185,18 +1206,73 @@ static int svc_tcp_has_wspace(struct svc_xprt *xprt) if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) return 1; required = atomic_read(&xprt->xpt_reserved) + serv->sv_max_mesg; - if (sk_stream_wspace(svsk->sk_sk) >= required) + if (sk_stream_wspace(svsk->sk_sk) >= required || + (sk_stream_min_wspace(svsk->sk_sk) == 0 && + 
atomic_read(&xprt->xpt_reserved) == 0)) return 1; set_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); return 0; } static struct svc_xprt *svc_tcp_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static struct svc_xprt *svc_bc_create_socket(struct svc_serv *, int, + struct net *, struct sockaddr *, + int, int); +static void svc_bc_sock_free(struct svc_xprt *xprt); + +static struct svc_xprt *svc_bc_tcp_create(struct svc_serv *serv, + struct net *net, struct sockaddr *sa, int salen, int flags) { - return svc_create_socket(serv, IPPROTO_TCP, sa, salen, flags); + return svc_bc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags); +} + +static void svc_bc_tcp_sock_detach(struct svc_xprt *xprt) +{ +} + +static struct svc_xprt_ops svc_tcp_bc_ops = { + .xpo_create = svc_bc_tcp_create, + .xpo_detach = svc_bc_tcp_sock_detach, + .xpo_free = svc_bc_sock_free, + .xpo_prep_reply_hdr = svc_tcp_prep_reply_hdr, + .xpo_secure_port = svc_sock_secure_port, +}; + +static struct svc_xprt_class svc_tcp_bc_class = { + .xcl_name = "tcp-bc", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_tcp_bc_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, +}; + +static void svc_init_bc_xprt_sock(void) +{ + svc_reg_xprt_class(&svc_tcp_bc_class); +} + +static void svc_cleanup_bc_xprt_sock(void) +{ + svc_unreg_xprt_class(&svc_tcp_bc_class); +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +static void svc_init_bc_xprt_sock(void) +{ +} + +static void svc_cleanup_bc_xprt_sock(void) +{ } +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ static struct svc_xprt_ops svc_tcp_ops = { .xpo_create = svc_tcp_create, @@ -1208,6 +1284,7 @@ static struct svc_xprt_ops svc_tcp_ops = { .xpo_prep_reply_hdr = svc_tcp_prep_reply_hdr, .xpo_has_wspace = svc_tcp_has_wspace, .xpo_accept = svc_tcp_accept, + .xpo_secure_port = svc_sock_secure_port, }; static struct svc_xprt_class svc_tcp_class = { @@ -1221,19 +1298,22 @@ void svc_init_xprt_sock(void) { svc_reg_xprt_class(&svc_tcp_class); svc_reg_xprt_class(&svc_udp_class); + svc_init_bc_xprt_sock(); } void svc_cleanup_xprt_sock(void) { svc_unreg_xprt_class(&svc_tcp_class); svc_unreg_xprt_class(&svc_udp_class); + svc_cleanup_bc_xprt_sock(); } static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) { struct sock *sk = svsk->sk_sk; - svc_xprt_init(&svc_tcp_class, &svsk->sk_xprt, serv); + svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_tcp_class, + &svsk->sk_xprt, serv); set_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); if (sk->sk_state == TCP_LISTEN) { dprintk("setting up TCP socket for listening\n"); @@ -1248,18 +1328,11 @@ static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) svsk->sk_reclen = 0; svsk->sk_tcplen = 0; + svsk->sk_datalen = 0; + memset(&svsk->sk_pages[0], 0, sizeof(svsk->sk_pages)); tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF; - /* initialise setting must have enough space to - * receive and respond to one request. - * svc_tcp_recvfrom will re-adjust if necessary - */ - svc_sock_setbufsize(svsk->sk_sock, - 3 * svsk->sk_xprt.xpt_server->sv_max_mesg, - 3 * svsk->sk_xprt.xpt_server->sv_max_mesg); - - set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); if (sk->sk_state != TCP_ESTABLISHED) set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); @@ -1272,19 +1345,11 @@ void svc_sock_update_bufs(struct svc_serv *serv) * The number of server threads has changed. 
Update * rcvbuf and sndbuf accordingly on all sockets */ - struct list_head *le; + struct svc_sock *svsk; spin_lock_bh(&serv->sv_lock); - list_for_each(le, &serv->sv_permsocks) { - struct svc_sock *svsk = - list_entry(le, struct svc_sock, sk_xprt.xpt_list); + list_for_each_entry(svsk, &serv->sv_permsocks, sk_xprt.xpt_list) set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); - } - list_for_each(le, &serv->sv_tempsocks) { - struct svc_sock *svsk = - list_entry(le, struct svc_sock, sk_xprt.xpt_list); - set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); - } spin_unlock_bh(&serv->sv_lock); } EXPORT_SYMBOL_GPL(svc_sock_update_bufs); @@ -1295,28 +1360,29 @@ EXPORT_SYMBOL_GPL(svc_sock_update_bufs); */ static struct svc_sock *svc_setup_socket(struct svc_serv *serv, struct socket *sock, - int *errp, int flags) + int flags) { struct svc_sock *svsk; struct sock *inet; int pmap_register = !(flags & SVC_SOCK_ANONYMOUS); + int err = 0; dprintk("svc: svc_setup_socket %p\n", sock); - if (!(svsk = kzalloc(sizeof(*svsk), GFP_KERNEL))) { - *errp = -ENOMEM; - return NULL; - } + svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); + if (!svsk) + return ERR_PTR(-ENOMEM); inet = sock->sk; /* Register socket with portmapper */ - if (*errp >= 0 && pmap_register) - *errp = svc_register(serv, inet->sk_family, inet->sk_protocol, + if (pmap_register) + err = svc_register(serv, sock_net(sock->sk), inet->sk_family, + inet->sk_protocol, ntohs(inet_sk(inet)->inet_sport)); - if (*errp < 0) { + if (err < 0) { kfree(svsk); - return NULL; + return ERR_PTR(err); } inet->sk_user_data = svsk; @@ -1329,8 +1395,14 @@ static struct svc_sock *svc_setup_socket(struct svc_serv *serv, /* Initialize the socket */ if (sock->type == SOCK_DGRAM) svc_udp_init(svsk, serv); - else + else { + /* initialise setting must have enough space to + * receive and respond to one request. 
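svc_setup_socket() now reports failure through the kernel's ERR_PTR convention instead of the old int *errp out-parameter. A caller-side sketch of the pattern (the cleanup label is illustrative):

        svsk = svc_setup_socket(serv, sock, SVC_SOCK_DEFAULTS);
        if (IS_ERR(svsk)) {
                err = PTR_ERR(svsk);    /* e.g. -ENOMEM, or a svc_register() error */
                goto out_release;       /* hypothetical caller cleanup */
        }

svc_addsock() and svc_create_socket() below follow exactly this shape.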
+ */ + svc_sock_setbufsize(svsk->sk_sock, 4 * serv->sv_max_mesg, + 4 * serv->sv_max_mesg); svc_tcp_init(svsk, serv); + } dprintk("svc: svc_setup_socket created %p (inet %p)\n", svsk, svsk->sk_sk); @@ -1338,6 +1410,22 @@ static struct svc_sock *svc_setup_socket(struct svc_serv *serv, return svsk; } +bool svc_alien_sock(struct net *net, int fd) +{ + int err; + struct socket *sock = sockfd_lookup(fd, &err); + bool ret = false; + + if (!sock) + goto out; + if (sock_net(sock->sk) != net) + ret = true; + sockfd_put(sock); +out: + return ret; +} +EXPORT_SYMBOL_GPL(svc_alien_sock); + /** * svc_addsock - add a listener socket to an RPC service * @serv: pointer to RPC service to which to add a new listener @@ -1355,42 +1443,38 @@ int svc_addsock(struct svc_serv *serv, const int fd, char *name_return, int err = 0; struct socket *so = sockfd_lookup(fd, &err); struct svc_sock *svsk = NULL; + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *)&addr; + int salen; if (!so) return err; + err = -EAFNOSUPPORT; if ((so->sk->sk_family != PF_INET) && (so->sk->sk_family != PF_INET6)) - err = -EAFNOSUPPORT; - else if (so->sk->sk_protocol != IPPROTO_TCP && + goto out; + err = -EPROTONOSUPPORT; + if (so->sk->sk_protocol != IPPROTO_TCP && so->sk->sk_protocol != IPPROTO_UDP) - err = -EPROTONOSUPPORT; - else if (so->state > SS_UNCONNECTED) - err = -EISCONN; - else { - if (!try_module_get(THIS_MODULE)) - err = -ENOENT; - else - svsk = svc_setup_socket(serv, so, &err, - SVC_SOCK_DEFAULTS); - if (svsk) { - struct sockaddr_storage addr; - struct sockaddr *sin = (struct sockaddr *)&addr; - int salen; - if (kernel_getsockname(svsk->sk_sock, sin, &salen) == 0) - svc_xprt_set_local(&svsk->sk_xprt, sin, salen); - clear_bit(XPT_TEMP, &svsk->sk_xprt.xpt_flags); - spin_lock_bh(&serv->sv_lock); - list_add(&svsk->sk_xprt.xpt_list, &serv->sv_permsocks); - spin_unlock_bh(&serv->sv_lock); - svc_xprt_received(&svsk->sk_xprt); - err = 0; - } else - module_put(THIS_MODULE); - } - if (err) { - sockfd_put(so); - return err; + goto out; + err = -EISCONN; + if (so->state > SS_UNCONNECTED) + goto out; + err = -ENOENT; + if (!try_module_get(THIS_MODULE)) + goto out; + svsk = svc_setup_socket(serv, so, SVC_SOCK_DEFAULTS); + if (IS_ERR(svsk)) { + module_put(THIS_MODULE); + err = PTR_ERR(svsk); + goto out; } + if (kernel_getsockname(svsk->sk_sock, sin, &salen) == 0) + svc_xprt_set_local(&svsk->sk_xprt, sin, salen); + svc_add_new_perm_xprt(serv, &svsk->sk_xprt); return svc_one_sock_name(svsk, name_return, len); +out: + sockfd_put(so); + return err; } EXPORT_SYMBOL_GPL(svc_addsock); @@ -1399,6 +1483,7 @@ EXPORT_SYMBOL_GPL(svc_addsock); */ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, int protocol, + struct net *net, struct sockaddr *sin, int len, int flags) { @@ -1435,7 +1520,7 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, return ERR_PTR(-EINVAL); } - error = sock_create_kern(family, type, protocol, &sock); + error = __sock_create(net, family, type, protocol, &sock, 1); if (error < 0) return ERR_PTR(error); @@ -1452,7 +1537,7 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, (char *)&val, sizeof(val)); if (type == SOCK_STREAM) - sock->sk->sk_reuse = 1; /* allow address reuse */ + sock->sk->sk_reuse = SK_CAN_REUSE; /* allow address reuse */ error = kernel_bind(sock, sin, len); if (error < 0) goto bummer; @@ -1467,11 +1552,13 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, goto bummer; } - if ((svsk = svc_setup_socket(serv, sock, &error, flags)) != 
NULL) { - svc_xprt_set_local(&svsk->sk_xprt, newsin, newlen); - return (struct svc_xprt *)svsk; + svsk = svc_setup_socket(serv, sock, flags); + if (IS_ERR(svsk)) { + error = PTR_ERR(svsk); + goto bummer; } - + svc_xprt_set_local(&svsk->sk_xprt, newsin, newlen); + return (struct svc_xprt *)svsk; bummer: dprintk("svc: svc_create_socket error = %d\n", -error); sock_release(sock); @@ -1486,6 +1573,7 @@ static void svc_sock_detach(struct svc_xprt *xprt) { struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); struct sock *sk = svsk->sk_sk; + wait_queue_head_t *wq; dprintk("svc: svc_sock_detach(%p)\n", svsk); @@ -1494,8 +1582,9 @@ static void svc_sock_detach(struct svc_xprt *xprt) sk->sk_data_ready = svsk->sk_odata; sk->sk_write_space = svsk->sk_owspace; - if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) - wake_up_interruptible(sk->sk_sleep); + wq = sk_sleep(sk); + if (wq && waitqueue_active(wq)) + wake_up_interruptible(wq); } /* @@ -1509,8 +1598,10 @@ static void svc_tcp_sock_detach(struct svc_xprt *xprt) svc_sock_detach(xprt); - if (!test_bit(XPT_LISTENER, &xprt->xpt_flags)) + if (!test_bit(XPT_LISTENER, &xprt->xpt_flags)) { + svc_tcp_clear_pages(svsk); kernel_sock_shutdown(svsk->sk_sock, SHUT_RDWR); + } } /* @@ -1528,41 +1619,43 @@ static void svc_sock_free(struct svc_xprt *xprt) kfree(svsk); } +#if defined(CONFIG_SUNRPC_BACKCHANNEL) /* - * Create a svc_xprt. - * - * For internal use only (e.g. nfsv4.1 backchannel). - * Callers should typically use the xpo_create() method. + * Create a back channel svc_xprt which shares the fore channel socket. */ -struct svc_xprt *svc_sock_create(struct svc_serv *serv, int prot) +static struct svc_xprt *svc_bc_create_socket(struct svc_serv *serv, + int protocol, + struct net *net, + struct sockaddr *sin, int len, + int flags) { struct svc_sock *svsk; - struct svc_xprt *xprt = NULL; + struct svc_xprt *xprt; + + if (protocol != IPPROTO_TCP) { + printk(KERN_WARNING "svc: only TCP sockets" + " supported on shared back channel\n"); + return ERR_PTR(-EINVAL); + } - dprintk("svc: %s\n", __func__); svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); if (!svsk) - goto out; + return ERR_PTR(-ENOMEM); xprt = &svsk->sk_xprt; - if (prot == IPPROTO_TCP) - svc_xprt_init(&svc_tcp_class, xprt, serv); - else if (prot == IPPROTO_UDP) - svc_xprt_init(&svc_udp_class, xprt, serv); - else - BUG(); -out: - dprintk("svc: %s return %p\n", __func__, xprt); + svc_xprt_init(net, &svc_tcp_bc_class, xprt, serv); + + serv->sv_bc_xprt = xprt; + return xprt; } -EXPORT_SYMBOL_GPL(svc_sock_create); /* - * Destroy a svc_sock. + * Free a back channel svc_sock. 
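Note that svc_bc_create_socket() allocates only the svc_sock wrapper: an NFSv4.1 backchannel shares the fore-channel TCP socket, so there is no accept path and nothing to shut down. Teardown is correspondingly just the container_of() recovery step, roughly:

        struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt);

        kfree(svsk);    /* no sock_release(): the fore channel owns the socket */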
*/ -void svc_sock_destroy(struct svc_xprt *xprt) +static void svc_bc_sock_free(struct svc_xprt *xprt) { if (xprt) kfree(container_of(xprt, struct svc_sock, sk_xprt)); } -EXPORT_SYMBOL_GPL(svc_sock_destroy); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ diff --git a/net/sunrpc/sysctl.c b/net/sunrpc/sysctl.c index e65dcc61333..c99c58e2ee6 100644 --- a/net/sunrpc/sysctl.c +++ b/net/sunrpc/sysctl.c @@ -20,6 +20,8 @@ #include <linux/sunrpc/stats.h> #include <linux/sunrpc/svc_xprt.h> +#include "netns.h" + /* * Declare the debug flags here */ @@ -38,7 +40,7 @@ EXPORT_SYMBOL_GPL(nlm_debug); #ifdef RPC_DEBUG static struct ctl_table_header *sunrpc_table_header; -static ctl_table sunrpc_table[]; +static struct ctl_table sunrpc_table[]; void rpc_register_sysctl(void) @@ -56,7 +58,7 @@ rpc_unregister_sysctl(void) } } -static int proc_do_xprt(ctl_table *table, int write, +static int proc_do_xprt(struct ctl_table *table, int write, void __user *buffer, size_t *lenp, loff_t *ppos) { char tmpbuf[256]; @@ -71,7 +73,7 @@ static int proc_do_xprt(ctl_table *table, int write, } static int -proc_dodebug(ctl_table *table, int write, +proc_dodebug(struct ctl_table *table, int write, void __user *buffer, size_t *lenp, loff_t *ppos) { char tmpbuf[20], c, *s; @@ -110,7 +112,7 @@ proc_dodebug(ctl_table *table, int write, *(unsigned int *) table->data = value; /* Display the RPC tasks on writing to rpc_debug */ if (strcmp(table->procname, "rpc_debug") == 0) - rpc_show_tasks(); + rpc_show_tasks(&init_net); } else { if (!access_ok(VERIFY_WRITE, buffer, left)) return -EFAULT; @@ -133,7 +135,7 @@ done: } -static ctl_table debug_table[] = { +static struct ctl_table debug_table[] = { { .procname = "rpc_debug", .data = &rpc_debug, @@ -171,7 +173,7 @@ static ctl_table debug_table[] = { { } }; -static ctl_table sunrpc_table[] = { +static struct ctl_table sunrpc_table[] = { { .procname = "sunrpc", .mode = 0555, diff --git a/net/sunrpc/timer.c b/net/sunrpc/timer.c index dd824341c34..08881d0c967 100644 --- a/net/sunrpc/timer.c +++ b/net/sunrpc/timer.c @@ -34,7 +34,7 @@ void rpc_init_rtt(struct rpc_rtt *rt, unsigned long timeo) { unsigned long init = 0; - unsigned i; + unsigned int i; rt->timeo = timeo; @@ -57,7 +57,7 @@ EXPORT_SYMBOL_GPL(rpc_init_rtt); * NB: When computing the smoothed RTT and standard deviation, * be careful not to produce negative intermediate results. 
*/ -void rpc_update_rtt(struct rpc_rtt *rt, unsigned timer, long m) +void rpc_update_rtt(struct rpc_rtt *rt, unsigned int timer, long m) { long *srtt, *sdrtt; @@ -106,7 +106,7 @@ EXPORT_SYMBOL_GPL(rpc_update_rtt); * read, write, commit - A+4D * other - timeo */ -unsigned long rpc_calc_rto(struct rpc_rtt *rt, unsigned timer) +unsigned long rpc_calc_rto(struct rpc_rtt *rt, unsigned int timer) { unsigned long res; diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c index 8bd690c48b6..23fb4e75e24 100644 --- a/net/sunrpc/xdr.c +++ b/net/sunrpc/xdr.c @@ -7,6 +7,7 @@ */ #include <linux/module.h> +#include <linux/slab.h> #include <linux/types.h> #include <linux/string.h> #include <linux/kernel.h> @@ -110,33 +111,22 @@ xdr_decode_string_inplace(__be32 *p, char **sp, } EXPORT_SYMBOL_GPL(xdr_decode_string_inplace); +/** + * xdr_terminate_string - '\0'-terminate a string residing in an xdr_buf + * @buf: XDR buffer where string resides + * @len: length of string, in bytes + * + */ void -xdr_encode_pages(struct xdr_buf *xdr, struct page **pages, unsigned int base, - unsigned int len) +xdr_terminate_string(struct xdr_buf *buf, const u32 len) { - struct kvec *tail = xdr->tail; - u32 *p; - - xdr->pages = pages; - xdr->page_base = base; - xdr->page_len = len; - - p = (u32 *)xdr->head[0].iov_base + XDR_QUADLEN(xdr->head[0].iov_len); - tail->iov_base = p; - tail->iov_len = 0; - - if (len & 3) { - unsigned int pad = 4 - (len & 3); + char *kaddr; - *p = 0; - tail->iov_base = (char *)p + (len & 3); - tail->iov_len = pad; - len += pad; - } - xdr->buflen += len; - xdr->len += len; + kaddr = kmap_atomic(buf->pages[0]); + kaddr[buf->page_base + len] = '\0'; + kunmap_atomic(kaddr); } -EXPORT_SYMBOL_GPL(xdr_encode_pages); +EXPORT_SYMBOL_GPL(xdr_terminate_string); void xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, @@ -162,7 +152,9 @@ EXPORT_SYMBOL_GPL(xdr_inline_pages); /* * Helper routines for doing 'memmove' like operations on a struct xdr_buf - * + */ + +/** * _shift_data_right_pages * @pages: vector of pages containing both the source and dest memory area. 
* @pgto_base: page vector address of destination @@ -214,17 +206,20 @@ _shift_data_right_pages(struct page **pages, size_t pgto_base, pgto_base -= copy; pgfrom_base -= copy; - vto = kmap_atomic(*pgto, KM_USER0); - vfrom = kmap_atomic(*pgfrom, KM_USER1); - memmove(vto + pgto_base, vfrom + pgfrom_base, copy); + vto = kmap_atomic(*pgto); + if (*pgto != *pgfrom) { + vfrom = kmap_atomic(*pgfrom); + memcpy(vto + pgto_base, vfrom + pgfrom_base, copy); + kunmap_atomic(vfrom); + } else + memmove(vto + pgto_base, vto + pgfrom_base, copy); flush_dcache_page(*pgto); - kunmap_atomic(vfrom, KM_USER1); - kunmap_atomic(vto, KM_USER0); + kunmap_atomic(vto); } while ((len -= copy) != 0); } -/* +/** * _copy_to_pages * @pages: array of pages * @pgbase: page vector address of destination @@ -249,9 +244,9 @@ _copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) if (copy > len) copy = len; - vto = kmap_atomic(*pgto, KM_USER0); + vto = kmap_atomic(*pgto); memcpy(vto + pgbase, p, copy); - kunmap_atomic(vto, KM_USER0); + kunmap_atomic(vto); len -= copy; if (len == 0) @@ -268,7 +263,7 @@ _copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) flush_dcache_page(*pgto); } -/* +/** * _copy_from_pages * @p: pointer to destination * @pages: array of pages @@ -278,7 +273,7 @@ _copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) * Copies data into an arbitrary memory location from an array of pages * The copy is assumed to be non-overlapping. */ -static void +void _copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) { struct page **pgfrom; @@ -293,9 +288,9 @@ _copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) if (copy > len) copy = len; - vfrom = kmap_atomic(*pgfrom, KM_USER0); + vfrom = kmap_atomic(*pgfrom); memcpy(p, vfrom + pgbase, copy); - kunmap_atomic(vfrom, KM_USER0); + kunmap_atomic(vfrom); pgbase += copy; if (pgbase == PAGE_CACHE_SIZE) { @@ -306,8 +301,9 @@ _copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) } while ((len -= copy) != 0); } +EXPORT_SYMBOL_GPL(_copy_from_pages); -/* +/** * xdr_shrink_bufhead * @buf: xdr_buf * @len: bytes to remove from buf->head[0] @@ -325,7 +321,10 @@ xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) tail = buf->tail; head = buf->head; - BUG_ON (len > head->iov_len); + + WARN_ON_ONCE(len > head->iov_len); + if (len > head->iov_len) + len = head->iov_len; /* Shift the tail first */ if (tail->iov_len != 0) { @@ -380,7 +379,7 @@ xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) buf->len = buf->buflen; } -/* +/** * xdr_shrink_pagelen * @buf: xdr_buf * @len: bytes to remove from buf->pages @@ -394,24 +393,29 @@ xdr_shrink_pagelen(struct xdr_buf *buf, size_t len) { struct kvec *tail; size_t copy; - char *p; unsigned int pglen = buf->page_len; + unsigned int tailbuf_len; tail = buf->tail; BUG_ON (len > pglen); + tailbuf_len = buf->buflen - buf->head->iov_len - buf->page_len; + /* Shift the tail first */ - if (tail->iov_len != 0) { - p = (char *)tail->iov_base + len; + if (tailbuf_len != 0) { + unsigned int free_space = tailbuf_len - tail->iov_len; + + if (len < free_space) + free_space = len; + tail->iov_len += free_space; + + copy = len; if (tail->iov_len > len) { - copy = tail->iov_len - len; - memmove(p, tail->iov_base, copy); + char *p = (char *)tail->iov_base + len; + memmove(p, tail->iov_base, tail->iov_len - len); } else - buf->buflen -= len; - /* Copy from the inlined pages into the tail */ - copy = len; - if (copy > tail->iov_len) copy = 
tail->iov_len; + /* Copy from the inlined pages into the tail */ _copy_from_pages((char *)tail->iov_base, buf->pages, buf->page_base + pglen - len, copy); @@ -431,6 +435,16 @@ xdr_shift_buf(struct xdr_buf *buf, size_t len) EXPORT_SYMBOL_GPL(xdr_shift_buf); /** + * xdr_stream_pos - Return the current offset from the start of the xdr_stream + * @xdr: pointer to struct xdr_stream + */ +unsigned int xdr_stream_pos(const struct xdr_stream *xdr) +{ + return (unsigned int)(XDR_QUADLEN(xdr->buf->len) - xdr->nwords) << 2; +} +EXPORT_SYMBOL_GPL(xdr_stream_pos); + +/** * xdr_init_encode - Initialize a struct xdr_stream for sending data. * @xdr: pointer to xdr_stream struct * @buf: pointer to XDR buffer in which to encode data @@ -448,6 +462,7 @@ void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) struct kvec *iov = buf->head; int scratch_len = buf->buflen - buf->page_len - buf->tail[0].iov_len; + xdr_set_scratch_buffer(xdr, NULL, 0); BUG_ON(scratch_len < 0); xdr->buf = buf; xdr->iov = iov; @@ -468,6 +483,73 @@ void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) EXPORT_SYMBOL_GPL(xdr_init_encode); /** + * xdr_commit_encode - Ensure all data is written to buffer + * @xdr: pointer to xdr_stream + * + * We handle encoding across page boundaries by giving the caller a + * temporary location to write to, then later copying the data into + * place; xdr_commit_encode does that copying. + * + * Normally the caller doesn't need to call this directly, as the + * following xdr_reserve_space will do it. But an explicit call may be + * required at the end of encoding, or any other time when the xdr_buf + * data might be read. + */ +void xdr_commit_encode(struct xdr_stream *xdr) +{ + int shift = xdr->scratch.iov_len; + void *page; + + if (shift == 0) + return; + page = page_address(*xdr->page_ptr); + memcpy(xdr->scratch.iov_base, page, shift); + memmove(page, page + shift, (void *)xdr->p - page); + xdr->scratch.iov_len = 0; +} +EXPORT_SYMBOL_GPL(xdr_commit_encode); + +__be32 *xdr_get_next_encode_buffer(struct xdr_stream *xdr, size_t nbytes) +{ + static __be32 *p; + int space_left; + int frag1bytes, frag2bytes; + + if (nbytes > PAGE_SIZE) + return NULL; /* Bigger buffers require special handling */ + if (xdr->buf->len + nbytes > xdr->buf->buflen) + return NULL; /* Sorry, we're totally out of space */ + frag1bytes = (xdr->end - xdr->p) << 2; + frag2bytes = nbytes - frag1bytes; + if (xdr->iov) + xdr->iov->iov_len += frag1bytes; + else + xdr->buf->page_len += frag1bytes; + xdr->page_ptr++; + xdr->iov = NULL; + /* + * If the last encode didn't end exactly on a page boundary, the + * next one will straddle boundaries. Encode into the next + * page, then copy it back later in xdr_commit_encode. 
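In caller terms the contract is: a reserve that would overrun the current page hands back a staging area, and the copy-back happens on the next xdr_reserve_space() or on an explicit commit. A usage sketch, where OP_EXAMPLE, verf and the error value are illustrative:

        __be32 *p = xdr_reserve_space(xdr, 4 + 8);  /* may span a page boundary */

        if (!p)
                return -ENOSPC;                 /* illustrative error path */
        *p++ = cpu_to_be32(OP_EXAMPLE);
        p = xdr_encode_opaque_fixed(p, verf, 8);
        xdr_commit_encode(xdr); /* settle staged bytes before xdr->buf is read */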
We use + * the "scratch" iov to track any temporarily unused fragment of + * space at the end of the previous buffer: + */ + xdr->scratch.iov_base = xdr->p; + xdr->scratch.iov_len = frag1bytes; + p = page_address(*xdr->page_ptr); + /* + * Note this is where the next encode will start after we've + * shifted this one back: + */ + xdr->p = (void *)p + frag2bytes; + space_left = xdr->buf->buflen - xdr->buf->len; + xdr->end = (void *)p + min_t(int, space_left, PAGE_SIZE); + xdr->buf->page_len += frag2bytes; + xdr->buf->len += nbytes; + return p; +} + +/** * xdr_reserve_space - Reserve buffer space for sending * @xdr: pointer to xdr_stream * @nbytes: number of bytes to reserve @@ -481,20 +563,122 @@ __be32 * xdr_reserve_space(struct xdr_stream *xdr, size_t nbytes) __be32 *p = xdr->p; __be32 *q; + xdr_commit_encode(xdr); /* align nbytes on the next 32-bit boundary */ nbytes += 3; nbytes &= ~3; q = p + (nbytes >> 2); if (unlikely(q > xdr->end || q < p)) - return NULL; + return xdr_get_next_encode_buffer(xdr, nbytes); xdr->p = q; - xdr->iov->iov_len += nbytes; + if (xdr->iov) + xdr->iov->iov_len += nbytes; + else + xdr->buf->page_len += nbytes; xdr->buf->len += nbytes; return p; } EXPORT_SYMBOL_GPL(xdr_reserve_space); /** + * xdr_truncate_encode - truncate an encode buffer + * @xdr: pointer to xdr_stream + * @len: new length of buffer + * + * Truncates the xdr stream, so that xdr->buf->len == len, + * and xdr->p points at offset len from the start of the buffer, and + * head, tail, and page lengths are adjusted to correspond. + * + * If this means moving xdr->p to a different buffer, we assume that + * that the end pointer should be set to the end of the current page, + * except in the case of the head buffer when we assume the head + * buffer's current length represents the end of the available buffer. + * + * This is *not* safe to use on a buffer that already has inlined page + * cache pages (as in a zero-copy server read reply), except for the + * simple case of truncating from one position in the tail to another. 
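The intended caller pattern is a checkpoint-and-rollback around a tentatively encoded item; a hedged sketch in which entry_len, entry_fits() and resp are hypothetical:

        unsigned int start = xdr_stream_pos(xdr);       /* checkpoint */
        __be32 *p = xdr_reserve_space(xdr, entry_len);

        if (!p || !entry_fits(resp))
                xdr_truncate_encode(xdr, start);        /* drop the partial entry */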
+ * + */ +void xdr_truncate_encode(struct xdr_stream *xdr, size_t len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *head = buf->head; + struct kvec *tail = buf->tail; + int fraglen; + int new, old; + + if (len > buf->len) { + WARN_ON_ONCE(1); + return; + } + xdr_commit_encode(xdr); + + fraglen = min_t(int, buf->len - len, tail->iov_len); + tail->iov_len -= fraglen; + buf->len -= fraglen; + if (tail->iov_len && buf->len == len) { + xdr->p = tail->iov_base + tail->iov_len; + /* xdr->end, xdr->iov should be set already */ + return; + } + WARN_ON_ONCE(fraglen); + fraglen = min_t(int, buf->len - len, buf->page_len); + buf->page_len -= fraglen; + buf->len -= fraglen; + + new = buf->page_base + buf->page_len; + old = new + fraglen; + xdr->page_ptr -= (old >> PAGE_SHIFT) - (new >> PAGE_SHIFT); + + if (buf->page_len && buf->len == len) { + xdr->p = page_address(*xdr->page_ptr); + xdr->end = (void *)xdr->p + PAGE_SIZE; + xdr->p = (void *)xdr->p + (new % PAGE_SIZE); + /* xdr->iov should already be NULL */ + return; + } + if (fraglen) { + xdr->end = head->iov_base + head->iov_len; + xdr->page_ptr--; + } + /* (otherwise assume xdr->end is already set) */ + head->iov_len = len; + buf->len = len; + xdr->p = head->iov_base + head->iov_len; + xdr->iov = buf->head; +} +EXPORT_SYMBOL(xdr_truncate_encode); + +/** + * xdr_restrict_buflen - decrease available buffer space + * @xdr: pointer to xdr_stream + * @newbuflen: new maximum number of bytes available + * + * Adjust our idea of how much space is available in the buffer. + * If we've already used too much space in the buffer, returns -1. + * If the available space is already smaller than newbuflen, returns 0 + * and does nothing. Otherwise, adjusts xdr->buf->buflen to newbuflen + * and ensures xdr->end is set at most offset newbuflen from the start + * of the buffer. 
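A worked example with illustrative numbers: suppose buf->len == 100, buf->buflen == 150, and 20 bytes remain between xdr->p and xdr->end, so end_offset == 120.

        xdr_restrict_buflen(xdr, 90);   /* 90 < buf->len: returns -1 */
        xdr_restrict_buflen(xdr, 200);  /* 200 > buf->buflen: returns 0, no change */
        xdr_restrict_buflen(xdr, 110);  /* buflen becomes 110 and xdr->end is
                                         * pulled back 120 - 110 = 10 bytes */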
+ */ +int xdr_restrict_buflen(struct xdr_stream *xdr, int newbuflen) +{ + struct xdr_buf *buf = xdr->buf; + int left_in_this_buf = (void *)xdr->end - (void *)xdr->p; + int end_offset = buf->len + left_in_this_buf; + + if (newbuflen < 0 || newbuflen < buf->len) + return -1; + if (newbuflen > buf->buflen) + return 0; + if (newbuflen < end_offset) + xdr->end = (void *)xdr->end + newbuflen - end_offset; + buf->buflen = newbuflen; + return 0; +} +EXPORT_SYMBOL(xdr_restrict_buflen); + +/** * xdr_write_pages - Insert a list of pages into an XDR buffer for sending * @xdr: pointer to xdr_stream * @pages: list of pages @@ -529,6 +713,72 @@ void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int b } EXPORT_SYMBOL_GPL(xdr_write_pages); +static void xdr_set_iov(struct xdr_stream *xdr, struct kvec *iov, + unsigned int len) +{ + if (len > iov->iov_len) + len = iov->iov_len; + xdr->p = (__be32*)iov->iov_base; + xdr->end = (__be32*)(iov->iov_base + len); + xdr->iov = iov; + xdr->page_ptr = NULL; +} + +static int xdr_set_page_base(struct xdr_stream *xdr, + unsigned int base, unsigned int len) +{ + unsigned int pgnr; + unsigned int maxlen; + unsigned int pgoff; + unsigned int pgend; + void *kaddr; + + maxlen = xdr->buf->page_len; + if (base >= maxlen) + return -EINVAL; + maxlen -= base; + if (len > maxlen) + len = maxlen; + + base += xdr->buf->page_base; + + pgnr = base >> PAGE_SHIFT; + xdr->page_ptr = &xdr->buf->pages[pgnr]; + kaddr = page_address(*xdr->page_ptr); + + pgoff = base & ~PAGE_MASK; + xdr->p = (__be32*)(kaddr + pgoff); + + pgend = pgoff + len; + if (pgend > PAGE_SIZE) + pgend = PAGE_SIZE; + xdr->end = (__be32*)(kaddr + pgend); + xdr->iov = NULL; + return 0; +} + +static void xdr_set_next_page(struct xdr_stream *xdr) +{ + unsigned int newbase; + + newbase = (1 + xdr->page_ptr - xdr->buf->pages) << PAGE_SHIFT; + newbase -= xdr->buf->page_base; + + if (xdr_set_page_base(xdr, newbase, PAGE_SIZE) < 0) + xdr_set_iov(xdr, xdr->buf->tail, xdr->buf->len); +} + +static bool xdr_set_next_buffer(struct xdr_stream *xdr) +{ + if (xdr->page_ptr != NULL) + xdr_set_next_page(xdr); + else if (xdr->iov == xdr->buf->head) { + if (xdr_set_page_base(xdr, 0, PAGE_SIZE) < 0) + xdr_set_iov(xdr, xdr->buf->tail, xdr->buf->len); + } + return xdr->p != xdr->end; +} + /** * xdr_init_decode - Initialize an xdr_stream for decoding data. * @xdr: pointer to xdr_stream struct @@ -537,20 +787,93 @@ EXPORT_SYMBOL_GPL(xdr_write_pages); */ void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) { - struct kvec *iov = buf->head; - unsigned int len = iov->iov_len; - - if (len > buf->len) - len = buf->len; xdr->buf = buf; - xdr->iov = iov; - xdr->p = p; - xdr->end = (__be32 *)((char *)iov->iov_base + len); + xdr->scratch.iov_base = NULL; + xdr->scratch.iov_len = 0; + xdr->nwords = XDR_QUADLEN(buf->len); + if (buf->head[0].iov_len != 0) + xdr_set_iov(xdr, buf->head, buf->len); + else if (buf->page_len != 0) + xdr_set_page_base(xdr, 0, buf->len); + if (p != NULL && p > xdr->p && xdr->end >= p) { + xdr->nwords -= p - xdr->p; + xdr->p = p; + } } EXPORT_SYMBOL_GPL(xdr_init_decode); /** - * xdr_inline_decode - Retrieve non-page XDR data to decode + * xdr_init_decode - Initialize an xdr_stream for decoding data. 
+ * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer from which to decode data + * @pages: list of pages to decode into + * @len: length in bytes of buffer in pages + */ +void xdr_init_decode_pages(struct xdr_stream *xdr, struct xdr_buf *buf, + struct page **pages, unsigned int len) +{ + memset(buf, 0, sizeof(*buf)); + buf->pages = pages; + buf->page_len = len; + buf->buflen = len; + buf->len = len; + xdr_init_decode(xdr, buf, NULL); +} +EXPORT_SYMBOL_GPL(xdr_init_decode_pages); + +static __be32 * __xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) +{ + unsigned int nwords = XDR_QUADLEN(nbytes); + __be32 *p = xdr->p; + __be32 *q = p + nwords; + + if (unlikely(nwords > xdr->nwords || q > xdr->end || q < p)) + return NULL; + xdr->p = q; + xdr->nwords -= nwords; + return p; +} + +/** + * xdr_set_scratch_buffer - Attach a scratch buffer for decoding data. + * @xdr: pointer to xdr_stream struct + * @buf: pointer to an empty buffer + * @buflen: size of 'buf' + * + * The scratch buffer is used when decoding from an array of pages. + * If an xdr_inline_decode() call spans across page boundaries, then + * we copy the data into the scratch buffer in order to allow linear + * access. + */ +void xdr_set_scratch_buffer(struct xdr_stream *xdr, void *buf, size_t buflen) +{ + xdr->scratch.iov_base = buf; + xdr->scratch.iov_len = buflen; +} +EXPORT_SYMBOL_GPL(xdr_set_scratch_buffer); + +static __be32 *xdr_copy_to_scratch(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p; + void *cpdest = xdr->scratch.iov_base; + size_t cplen = (char *)xdr->end - (char *)xdr->p; + + if (nbytes > xdr->scratch.iov_len) + return NULL; + memcpy(cpdest, xdr->p, cplen); + cpdest += cplen; + nbytes -= cplen; + if (!xdr_set_next_buffer(xdr)) + return NULL; + p = __xdr_inline_decode(xdr, nbytes); + if (p == NULL) + return NULL; + memcpy(cpdest, p, nbytes); + return xdr->scratch.iov_base; +} + +/** + * xdr_inline_decode - Retrieve XDR data to decode * @xdr: pointer to xdr_stream struct * @nbytes: number of bytes of data to decode * @@ -561,16 +884,49 @@ EXPORT_SYMBOL_GPL(xdr_init_decode); */ __be32 * xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) { - __be32 *p = xdr->p; - __be32 *q = p + XDR_QUADLEN(nbytes); + __be32 *p; - if (unlikely(q > xdr->end || q < p)) + if (nbytes == 0) + return xdr->p; + if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr)) return NULL; - xdr->p = q; - return p; + p = __xdr_inline_decode(xdr, nbytes); + if (p != NULL) + return p; + return xdr_copy_to_scratch(xdr, nbytes); } EXPORT_SYMBOL_GPL(xdr_inline_decode); +static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov; + unsigned int nwords = XDR_QUADLEN(len); + unsigned int cur = xdr_stream_pos(xdr); + + if (xdr->nwords == 0) + return 0; + /* Realign pages to current pointer position */ + iov = buf->head; + if (iov->iov_len > cur) { + xdr_shrink_bufhead(buf, iov->iov_len - cur); + xdr->nwords = XDR_QUADLEN(buf->len - cur); + } + + if (nwords > xdr->nwords) { + nwords = xdr->nwords; + len = nwords << 2; + } + if (buf->page_len <= len) + len = buf->page_len; + else if (nwords < xdr->nwords) { + /* Truncate page data and move it into the tail */ + xdr_shrink_pagelen(buf, buf->page_len - len); + xdr->nwords = XDR_QUADLEN(buf->len - cur); + } + return len; +} + /** * xdr_read_pages - Ensure page-based XDR data to decode is aligned at current pointer position * @xdr: pointer to xdr_stream struct @@ -579,39 +935,37 @@ 
EXPORT_SYMBOL_GPL(xdr_inline_decode); * Moves data beyond the current pointer position from the XDR head[] buffer * into the page list. Any data that lies beyond current position + "len" * bytes is moved into the XDR tail[]. + * + * Returns the number of XDR encoded bytes now contained in the pages */ -void xdr_read_pages(struct xdr_stream *xdr, unsigned int len) +unsigned int xdr_read_pages(struct xdr_stream *xdr, unsigned int len) { struct xdr_buf *buf = xdr->buf; struct kvec *iov; - ssize_t shift; + unsigned int nwords; unsigned int end; - int padding; - - /* Realign pages to current pointer position */ - iov = buf->head; - shift = iov->iov_len + (char *)iov->iov_base - (char *)xdr->p; - if (shift > 0) - xdr_shrink_bufhead(buf, shift); + unsigned int padding; - /* Truncate page data and move it into the tail */ - if (buf->page_len > len) - xdr_shrink_pagelen(buf, buf->page_len - len); - padding = (XDR_QUADLEN(len) << 2) - len; + len = xdr_align_pages(xdr, len); + if (len == 0) + return 0; + nwords = XDR_QUADLEN(len); + padding = (nwords << 2) - len; xdr->iov = iov = buf->tail; /* Compute remaining message length. */ - end = iov->iov_len; - shift = buf->buflen - buf->len; - if (shift < end) - end -= shift; - else if (shift > 0) - end = 0; + end = ((xdr->nwords - nwords) << 2) + padding; + if (end > iov->iov_len) + end = iov->iov_len; + /* * Position current pointer at beginning of tail, and * set remaining message length. */ xdr->p = (__be32 *)((char *)iov->iov_base + padding); xdr->end = (__be32 *)((char *)iov->iov_base + end); + xdr->page_ptr = NULL; + xdr->nwords = XDR_QUADLEN(end - padding); + return len; } EXPORT_SYMBOL_GPL(xdr_read_pages); @@ -627,16 +981,13 @@ EXPORT_SYMBOL_GPL(xdr_read_pages); */ void xdr_enter_page(struct xdr_stream *xdr, unsigned int len) { - char * kaddr = page_address(xdr->buf->pages[0]); - xdr_read_pages(xdr, len); + len = xdr_align_pages(xdr, len); /* * Position current pointer at beginning of tail, and * set remaining message length. */ - if (len > PAGE_CACHE_SIZE - xdr->buf->page_base) - len = PAGE_CACHE_SIZE - xdr->buf->page_base; - xdr->p = (__be32 *)(kaddr + xdr->buf->page_base); - xdr->end = (__be32 *)((char *)xdr->p + len); + if (len != 0) + xdr_set_page_base(xdr, 0, len); } EXPORT_SYMBOL_GPL(xdr_enter_page); @@ -652,8 +1003,20 @@ xdr_buf_from_iov(struct kvec *iov, struct xdr_buf *buf) } EXPORT_SYMBOL_GPL(xdr_buf_from_iov); -/* Sets subbuf to the portion of buf of length len beginning base bytes - * from the start of buf. Returns -1 if base of length are out of bounds. */ +/** + * xdr_buf_subsegment - set subbuf to a portion of buf + * @buf: an xdr buffer + * @subbuf: the result buffer + * @base: beginning of range in bytes + * @len: length of range in bytes + * + * sets @subbuf to an xdr buffer representing the portion of @buf of + * length @len starting at offset @base. + * + * @buf and @subbuf may be pointers to the same struct xdr_buf. + * + * Returns -1 if base of length are out of bounds. 
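A usage sketch (offset and count are illustrative): the subsegment aliases the selected byte range across head, pages and tail without copying; read_bytes_from_xdr_buf() layers a flattening copy on top of the same call.

        struct xdr_buf sub;

        if (xdr_buf_subsegment(&rqstp->rq_arg, &sub, offset, count) < 0)
                return -EINVAL; /* offset/count fall outside the buffer */
        /* sub.head/sub.pages/sub.tail now point into rq_arg's storage */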
+ */ int xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, unsigned int base, unsigned int len) @@ -666,9 +1029,8 @@ xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, len -= subbuf->head[0].iov_len; base = 0; } else { - subbuf->head[0].iov_base = NULL; - subbuf->head[0].iov_len = 0; base -= buf->head[0].iov_len; + subbuf->head[0].iov_len = 0; } if (base < buf->page_len) { @@ -690,9 +1052,8 @@ xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, len -= subbuf->tail[0].iov_len; base = 0; } else { - subbuf->tail[0].iov_base = NULL; - subbuf->tail[0].iov_len = 0; base -= buf->tail[0].iov_len; + subbuf->tail[0].iov_len = 0; } if (base || len) @@ -701,6 +1062,47 @@ xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, } EXPORT_SYMBOL_GPL(xdr_buf_subsegment); +/** + * xdr_buf_trim - lop at most "len" bytes off the end of "buf" + * @buf: buf to be trimmed + * @len: number of bytes to reduce "buf" by + * + * Trim an xdr_buf by the given number of bytes by fixing up the lengths. Note + * that it's possible that we'll trim less than that amount if the xdr_buf is + * too small, or if (for instance) it's all in the head and the parser has + * already read too far into it. + */ +void xdr_buf_trim(struct xdr_buf *buf, unsigned int len) +{ + size_t cur; + unsigned int trim = len; + + if (buf->tail[0].iov_len) { + cur = min_t(size_t, buf->tail[0].iov_len, trim); + buf->tail[0].iov_len -= cur; + trim -= cur; + if (!trim) + goto fix_len; + } + + if (buf->page_len) { + cur = min_t(unsigned int, buf->page_len, trim); + buf->page_len -= cur; + trim -= cur; + if (!trim) + goto fix_len; + } + + if (buf->head[0].iov_len) { + cur = min_t(size_t, buf->head[0].iov_len, trim); + buf->head[0].iov_len -= cur; + trim -= cur; + } +fix_len: + buf->len -= (len - trim); +} +EXPORT_SYMBOL_GPL(xdr_buf_trim); + static void __read_bytes_from_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) { unsigned int this_len; @@ -761,6 +1163,7 @@ int write_bytes_to_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, un __write_bytes_to_xdr_buf(&subbuf, obj, len); return 0; } +EXPORT_SYMBOL_GPL(write_bytes_to_xdr_buf); int xdr_decode_word(struct xdr_buf *buf, unsigned int base, u32 *obj) @@ -1046,7 +1449,7 @@ xdr_process_buf(struct xdr_buf *buf, unsigned int offset, unsigned int len, int (*actor)(struct scatterlist *, void *), void *data) { int i, ret = 0; - unsigned page_len, thislen, page_offset; + unsigned int page_len, thislen, page_offset; struct scatterlist sg[1]; sg_init_table(sg, 1); diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c index 469de292c23..c3b2b3369e5 100644 --- a/net/sunrpc/xprt.c +++ b/net/sunrpc/xprt.c @@ -43,9 +43,11 @@ #include <linux/interrupt.h> #include <linux/workqueue.h> #include <linux/net.h> +#include <linux/ktime.h> #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/metrics.h> +#include <linux/sunrpc/bc_xprt.h> #include "sunrpc.h" @@ -60,32 +62,15 @@ /* * Local functions */ +static void xprt_init(struct rpc_xprt *xprt, struct net *net); static void xprt_request_init(struct rpc_task *, struct rpc_xprt *); -static inline void do_xprt_reserve(struct rpc_task *); static void xprt_connect_status(struct rpc_task *task); static int __xprt_get_cong(struct rpc_xprt *, struct rpc_task *); +static void xprt_destroy(struct rpc_xprt *xprt); static DEFINE_SPINLOCK(xprt_list_lock); static LIST_HEAD(xprt_list); -/* - * The transport code maintains an estimate on the maximum number of out- - * standing RPC requests, using a smoothed version of the 
congestion - * avoidance implemented in 44BSD. This is basically the Van Jacobson - * congestion algorithm: If a retransmit occurs, the congestion window is - * halved; otherwise, it is incremented by 1/cwnd when - * - * - a reply is received and - * - a full number of requests are outstanding and - * - the congestion window hasn't been updated recently. - */ -#define RPC_CWNDSHIFT (8U) -#define RPC_CWNDSCALE (1U << RPC_CWNDSHIFT) -#define RPC_INITCWND RPC_CWNDSCALE -#define RPC_MAXCWND(xprt) ((xprt)->max_reqs << RPC_CWNDSHIFT) - -#define RPCXPRT_CONGESTED(xprt) ((xprt)->cong >= (xprt)->cwnd) - /** * xprt_register_transport - register a transport implementation * @transport: transport to register @@ -165,7 +150,6 @@ EXPORT_SYMBOL_GPL(xprt_unregister_transport); int xprt_load_transport(const char *transport_name) { struct xprt_class *t; - char module_name[sizeof t->name + 5]; int result; result = 0; @@ -177,9 +161,7 @@ int xprt_load_transport(const char *transport_name) } } spin_unlock(&xprt_list_lock); - strcpy(module_name, "xprt"); - strncat(module_name, transport_name, sizeof t->name); - result = request_module(module_name); + result = request_module("xprt%s", transport_name); out: return result; } @@ -188,28 +170,26 @@ EXPORT_SYMBOL_GPL(xprt_load_transport); /** * xprt_reserve_xprt - serialize write access to transports * @task: task that is requesting access to the transport + * @xprt: pointer to the target transport * * This prevents mixing the payload of separate requests, and prevents * transport connects from colliding with writes. No congestion control * is provided. */ -int xprt_reserve_xprt(struct rpc_task *task) +int xprt_reserve_xprt(struct rpc_xprt *xprt, struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = req->rq_xprt; + int priority; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { if (task == xprt->snd_task) return 1; - if (task == NULL) - return 0; goto out_sleep; } xprt->snd_task = task; - if (req) { - req->rq_bytes_sent = 0; + if (req != NULL) req->rq_ntrans++; - } + return 1; out_sleep: @@ -217,10 +197,13 @@ out_sleep: task->tk_pid, xprt); task->tk_timeout = 0; task->tk_status = -EAGAIN; - if (req && req->rq_ntrans) - rpc_sleep_on(&xprt->resend, task, NULL); + if (req == NULL) + priority = RPC_PRIORITY_LOW; + else if (!req->rq_ntrans) + priority = RPC_PRIORITY_NORMAL; else - rpc_sleep_on(&xprt->sending, task, NULL); + priority = RPC_PRIORITY_HIGH; + rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); return 0; } EXPORT_SYMBOL_GPL(xprt_reserve_xprt); @@ -228,10 +211,10 @@ EXPORT_SYMBOL_GPL(xprt_reserve_xprt); static void xprt_clear_locked(struct rpc_xprt *xprt) { xprt->snd_task = NULL; - if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state) || xprt->shutdown) { - smp_mb__before_clear_bit(); + if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state)) { + smp_mb__before_atomic(); clear_bit(XPRT_LOCKED, &xprt->state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); } else queue_work(rpciod_workqueue, &xprt->task_cleanup); } @@ -244,22 +227,23 @@ static void xprt_clear_locked(struct rpc_xprt *xprt) * integrated into the decision of whether a request is allowed to be * woken up and given access to the transport. 
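The congestion state is fixed point: cwnd and cong are scaled by RPC_CWNDSCALE (256), so one outstanding request costs one scale unit and RPCXPRT_CONGESTED() is simply cong >= cwnd. A condensed sketch of the adjustment rule that xprt_adjust_cwnd() (further down) applies, not the kernel function itself:

        #define RPC_CWNDSHIFT   (8U)
        #define RPC_CWNDSCALE   (1U << RPC_CWNDSHIFT)   /* one request == 256 units */

        static unsigned long cwnd_update(unsigned long cwnd, unsigned long maxcwnd,
                                         bool retransmitted)
        {
                if (retransmitted) {
                        cwnd >>= 1;                     /* halve on a timeout */
                        if (cwnd < RPC_CWNDSCALE)
                                cwnd = RPC_CWNDSCALE;   /* floor: one slot */
                } else {
                        /* additive increase of ~1/cwnd; "+ (cwnd >> 1)" rounds */
                        cwnd += (RPC_CWNDSCALE * RPC_CWNDSCALE + (cwnd >> 1)) / cwnd;
                        if (cwnd > maxcwnd)
                                cwnd = maxcwnd;         /* RPC_MAXCWND: max_reqs slots */
                }
                return cwnd;
        }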
*/ -int xprt_reserve_xprt_cong(struct rpc_task *task) +int xprt_reserve_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; struct rpc_rqst *req = task->tk_rqstp; + int priority; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { if (task == xprt->snd_task) return 1; goto out_sleep; } + if (req == NULL) { + xprt->snd_task = task; + return 1; + } if (__xprt_get_cong(xprt, task)) { xprt->snd_task = task; - if (req) { - req->rq_bytes_sent = 0; - req->rq_ntrans++; - } + req->rq_ntrans++; return 1; } xprt_clear_locked(xprt); @@ -267,10 +251,13 @@ out_sleep: dprintk("RPC: %5u failed to lock transport %p\n", task->tk_pid, xprt); task->tk_timeout = 0; task->tk_status = -EAGAIN; - if (req && req->rq_ntrans) - rpc_sleep_on(&xprt->resend, task, NULL); + if (req == NULL) + priority = RPC_PRIORITY_LOW; + else if (!req->rq_ntrans) + priority = RPC_PRIORITY_NORMAL; else - rpc_sleep_on(&xprt->sending, task, NULL); + priority = RPC_PRIORITY_HIGH; + rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); return 0; } EXPORT_SYMBOL_GPL(xprt_reserve_xprt_cong); @@ -280,61 +267,59 @@ static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task) int retval; spin_lock_bh(&xprt->transport_lock); - retval = xprt->ops->reserve_xprt(task); + retval = xprt->ops->reserve_xprt(xprt, task); spin_unlock_bh(&xprt->transport_lock); return retval; } -static void __xprt_lock_write_next(struct rpc_xprt *xprt) +static bool __xprt_lock_write_func(struct rpc_task *task, void *data) { - struct rpc_task *task; + struct rpc_xprt *xprt = data; struct rpc_rqst *req; + req = task->tk_rqstp; + xprt->snd_task = task; + if (req) + req->rq_ntrans++; + return true; +} + +static void __xprt_lock_write_next(struct rpc_xprt *xprt) +{ if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) return; - task = rpc_wake_up_next(&xprt->resend); - if (!task) { - task = rpc_wake_up_next(&xprt->sending); - if (!task) - goto out_unlock; - } + if (rpc_wake_up_first(&xprt->sending, __xprt_lock_write_func, xprt)) + return; + xprt_clear_locked(xprt); +} + +static bool __xprt_lock_write_cong_func(struct rpc_task *task, void *data) +{ + struct rpc_xprt *xprt = data; + struct rpc_rqst *req; req = task->tk_rqstp; - xprt->snd_task = task; - if (req) { - req->rq_bytes_sent = 0; + if (req == NULL) { + xprt->snd_task = task; + return true; + } + if (__xprt_get_cong(xprt, task)) { + xprt->snd_task = task; req->rq_ntrans++; + return true; } - return; - -out_unlock: - xprt_clear_locked(xprt); + return false; } static void __xprt_lock_write_next_cong(struct rpc_xprt *xprt) { - struct rpc_task *task; - if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) return; if (RPCXPRT_CONGESTED(xprt)) goto out_unlock; - task = rpc_wake_up_next(&xprt->resend); - if (!task) { - task = rpc_wake_up_next(&xprt->sending); - if (!task) - goto out_unlock; - } - if (__xprt_get_cong(xprt, task)) { - struct rpc_rqst *req = task->tk_rqstp; - xprt->snd_task = task; - if (req) { - req->rq_bytes_sent = 0; - req->rq_ntrans++; - } + if (rpc_wake_up_first(&xprt->sending, __xprt_lock_write_cong_func, xprt)) return; - } out_unlock: xprt_clear_locked(xprt); } @@ -349,6 +334,11 @@ out_unlock: void xprt_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) { if (xprt->snd_task == task) { + if (task != NULL) { + struct rpc_rqst *req = task->tk_rqstp; + if (req != NULL) + req->rq_bytes_sent = 0; + } xprt_clear_locked(xprt); __xprt_lock_write_next(xprt); } @@ -366,6 +356,11 @@ EXPORT_SYMBOL_GPL(xprt_release_xprt); void 
xprt_release_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) { if (xprt->snd_task == task) { + if (task != NULL) { + struct rpc_rqst *req = task->tk_rqstp; + if (req != NULL) + req->rq_bytes_sent = 0; + } xprt_clear_locked(xprt); __xprt_lock_write_next_cong(xprt); } @@ -421,21 +416,31 @@ __xprt_put_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) */ void xprt_release_rqst_cong(struct rpc_task *task) { - __xprt_put_cong(task->tk_xprt, task->tk_rqstp); + struct rpc_rqst *req = task->tk_rqstp; + + __xprt_put_cong(req->rq_xprt, req); } EXPORT_SYMBOL_GPL(xprt_release_rqst_cong); /** * xprt_adjust_cwnd - adjust transport congestion window + * @xprt: pointer to xprt * @task: recently completed RPC request used to adjust window * @result: result code of completed RPC request * - * We use a time-smoothed congestion estimator to avoid heavy oscillation. + * The transport code maintains an estimate on the maximum number of out- + * standing RPC requests, using a smoothed version of the congestion + * avoidance implemented in 44BSD. This is basically the Van Jacobson + * congestion algorithm: If a retransmit occurs, the congestion window is + * halved; otherwise, it is incremented by 1/cwnd when + * + * - a reply is received and + * - a full number of requests are outstanding and + * - the congestion window hasn't been updated recently. */ -void xprt_adjust_cwnd(struct rpc_task *task, int result) +void xprt_adjust_cwnd(struct rpc_xprt *xprt, struct rpc_task *task, int result) { struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = task->tk_xprt; unsigned long cwnd = xprt->cwnd; if (result >= 0 && cwnd <= xprt->cong) { @@ -476,13 +481,17 @@ EXPORT_SYMBOL_GPL(xprt_wake_pending_tasks); * xprt_wait_for_buffer_space - wait for transport output buffer to clear * @task: task to be put to sleep * @action: function pointer to be executed after wait + * + * Note that we only set the timer for the case of RPC_IS_SOFT(), since + * we don't in general want to force a socket disconnection due to + * an incomplete RPC call transmission. */ void xprt_wait_for_buffer_space(struct rpc_task *task, rpc_action action) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; - task->tk_timeout = req->rq_timeout; + task->tk_timeout = RPC_IS_SOFT(task) ? req->rq_timeout : 0; rpc_sleep_on(&xprt->pending, task, action); } EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space); @@ -495,9 +504,6 @@ EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space); */ void xprt_write_space(struct rpc_xprt *xprt) { - if (unlikely(xprt->shutdown)) - return; - spin_lock_bh(&xprt->transport_lock); if (xprt->snd_task) { dprintk("RPC: write space: waking waiting task on " @@ -522,7 +528,7 @@ void xprt_set_retrans_timeout_def(struct rpc_task *task) } EXPORT_SYMBOL_GPL(xprt_set_retrans_timeout_def); -/* +/** * xprt_set_retrans_timeout_rtt - set a request's retransmit timeout * @task: task whose timeout is to be set * @@ -670,7 +676,7 @@ xprt_init_autodisconnect(unsigned long data) struct rpc_xprt *xprt = (struct rpc_xprt *)data; spin_lock(&xprt->transport_lock); - if (!list_empty(&xprt->recv) || xprt->shutdown) + if (!list_empty(&xprt->recv)) goto out_abort; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) goto out_abort; @@ -689,7 +695,7 @@ out_abort: */ void xprt_connect(struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; dprintk("RPC: %5u xprt_connect xprt %p %s connected\n", task->tk_pid, xprt, (xprt_connected(xprt) ? 
"is" : "is not")); @@ -707,20 +713,22 @@ void xprt_connect(struct rpc_task *task) if (xprt_connected(xprt)) xprt_release_write(xprt, task); else { - if (task->tk_rqstp) - task->tk_rqstp->rq_bytes_sent = 0; - - task->tk_timeout = xprt->connect_timeout; + task->tk_rqstp->rq_bytes_sent = 0; + task->tk_timeout = task->tk_rqstp->rq_timeout; rpc_sleep_on(&xprt->pending, task, xprt_connect_status); + + if (test_bit(XPRT_CLOSING, &xprt->state)) + return; + if (xprt_test_and_set_connecting(xprt)) + return; xprt->stat.connect_start = jiffies; - xprt->ops->connect(task); + xprt->ops->connect(xprt, task); } - return; } static void xprt_connect_status(struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; if (task->tk_status == 0) { xprt->stat.connect_count++; @@ -731,6 +739,11 @@ static void xprt_connect_status(struct rpc_task *task) } switch (task->tk_status) { + case -ECONNREFUSED: + case -ECONNRESET: + case -ECONNABORTED: + case -ENETUNREACH: + case -EHOSTUNREACH: case -EAGAIN: dprintk("RPC: %5u xprt_connect_status: retrying\n", task->tk_pid); break; @@ -741,7 +754,7 @@ static void xprt_connect_status(struct rpc_task *task) default: dprintk("RPC: %5u xprt_connect_status: error %d connecting to " "server %s\n", task->tk_pid, -task->tk_status, - task->tk_client->cl_server); + xprt->servername); xprt_release_write(xprt, task); task->tk_status = -EIO; } @@ -755,13 +768,11 @@ static void xprt_connect_status(struct rpc_task *task) */ struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) { - struct list_head *pos; + struct rpc_rqst *entry; - list_for_each(pos, &xprt->recv) { - struct rpc_rqst *entry = list_entry(pos, struct rpc_rqst, rq_list); + list_for_each_entry(entry, &xprt->recv, rq_list) if (entry->rq_xid == xid) return entry; - } dprintk("RPC: xprt_lookup_rqst did not find xid %08x\n", ntohl(xid)); @@ -770,25 +781,19 @@ struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) } EXPORT_SYMBOL_GPL(xprt_lookup_rqst); -/** - * xprt_update_rtt - update an RPC client's RTT state after receiving a reply - * @task: RPC request that recently completed - * - */ -void xprt_update_rtt(struct rpc_task *task) +static void xprt_update_rtt(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_rtt *rtt = task->tk_client->cl_rtt; - unsigned timer = task->tk_msg.rpc_proc->p_timer; + unsigned int timer = task->tk_msg.rpc_proc->p_timer; + long m = usecs_to_jiffies(ktime_to_us(req->rq_rtt)); if (timer) { if (req->rq_ntrans == 1) - rpc_update_rtt(rtt, timer, - (long)jiffies - req->rq_xtime); + rpc_update_rtt(rtt, timer, m); rpc_set_timeo(rtt, timer, req->rq_ntrans - 1); } } -EXPORT_SYMBOL_GPL(xprt_update_rtt); /** * xprt_complete_rqst - called when reply processing is complete @@ -806,7 +811,9 @@ void xprt_complete_rqst(struct rpc_task *task, int copied) task->tk_pid, ntohl(req->rq_xid), copied); xprt->stat.recvs++; - task->tk_rtt = (long)jiffies - req->rq_xtime; + req->rq_rtt = ktime_sub(ktime_get(), req->rq_xtime); + if (xprt->ops->timer != NULL) + xprt_update_rtt(task); list_del_init(&req->rq_list); req->rq_private_buf.len = copied; @@ -830,7 +837,7 @@ static void xprt_timer(struct rpc_task *task) spin_lock_bh(&xprt->transport_lock); if (!req->rq_reply_bytes_recvd) { if (xprt->ops->timer) - xprt->ops->timer(task); + xprt->ops->timer(xprt, task); } else task->tk_status = 0; spin_unlock_bh(&xprt->transport_lock); @@ -846,24 +853,36 @@ static inline int xprt_has_timer(struct rpc_xprt *xprt) * @task: RPC task 
about to send a request * */ -int xprt_prepare_transmit(struct rpc_task *task) +bool xprt_prepare_transmit(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; - int err = 0; + bool ret = false; dprintk("RPC: %5u xprt_prepare_transmit\n", task->tk_pid); spin_lock_bh(&xprt->transport_lock); - if (req->rq_reply_bytes_recvd && !req->rq_bytes_sent) { - err = req->rq_reply_bytes_recvd; + if (!req->rq_bytes_sent) { + if (req->rq_reply_bytes_recvd) { + task->tk_status = req->rq_reply_bytes_recvd; + goto out_unlock; + } + if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) + && xprt_connected(xprt) + && req->rq_connect_cookie == xprt->connect_cookie) { + xprt->ops->set_retrans_timeout(task); + rpc_sleep_on(&xprt->pending, task, xprt_timer); + goto out_unlock; + } + } + if (!xprt->ops->reserve_xprt(xprt, task)) { + task->tk_status = -EAGAIN; goto out_unlock; } - if (!xprt->ops->reserve_xprt(task)) - err = -EAGAIN; + ret = true; out_unlock: spin_unlock_bh(&xprt->transport_lock); - return err; + return ret; } void xprt_end_transmit(struct rpc_task *task) @@ -881,7 +900,7 @@ void xprt_transmit(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; - int status; + int status, numreqs; dprintk("RPC: %5u xprt_transmit(%u)\n", task->tk_pid, req->rq_slen); @@ -904,8 +923,7 @@ void xprt_transmit(struct rpc_task *task) } else if (!req->rq_bytes_sent) return; - req->rq_connect_cookie = xprt->connect_cookie; - req->rq_xtime = jiffies; + req->rq_xtime = ktime_get(); status = xprt->ops->send_request(task); if (status != 0) { task->tk_status = status; @@ -913,83 +931,277 @@ void xprt_transmit(struct rpc_task *task) } dprintk("RPC: %5u xmit complete\n", task->tk_pid); + task->tk_flags |= RPC_TASK_SENT; spin_lock_bh(&xprt->transport_lock); xprt->ops->set_retrans_timeout(task); + numreqs = atomic_read(&xprt->num_reqs); + if (numreqs > xprt->stat.max_slots) + xprt->stat.max_slots = numreqs; xprt->stat.sends++; xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs; xprt->stat.bklog_u += xprt->backlog.qlen; + xprt->stat.sending_u += xprt->sending.qlen; + xprt->stat.pending_u += xprt->pending.qlen; /* Don't race with disconnect */ if (!xprt_connected(xprt)) task->tk_status = -ENOTCONN; - else if (!req->rq_reply_bytes_recvd && rpc_reply_expected(task)) { + else { /* * Sleep on the pending queue since * we're expecting a reply. 
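Timestamps here move from jiffies to ktime_t: rq_xtime is sampled just before ->send_request(), and the reply path in xprt_complete_rqst() derives the sample that xprt_update_rtt() feeds to the estimator. In outline:

        req->rq_xtime = ktime_get();                    /* transmit side */
        /* ... reply matched via xprt_lookup_rqst() ... */
        req->rq_rtt = ktime_sub(ktime_get(), req->rq_xtime);
        long m = usecs_to_jiffies(ktime_to_us(req->rq_rtt));   /* for rpc_update_rtt() */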
*/ - rpc_sleep_on(&xprt->pending, task, xprt_timer); + if (!req->rq_reply_bytes_recvd && rpc_reply_expected(task)) + rpc_sleep_on(&xprt->pending, task, xprt_timer); + req->rq_connect_cookie = xprt->connect_cookie; } spin_unlock_bh(&xprt->transport_lock); } -static inline void do_xprt_reserve(struct rpc_task *task) +static void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; + set_bit(XPRT_CONGESTED, &xprt->state); + rpc_sleep_on(&xprt->backlog, task, NULL); +} - task->tk_status = 0; - if (task->tk_rqstp) - return; +static void xprt_wake_up_backlog(struct rpc_xprt *xprt) +{ + if (rpc_wake_up_next(&xprt->backlog) == NULL) + clear_bit(XPRT_CONGESTED, &xprt->state); +} + +static bool xprt_throttle_congested(struct rpc_xprt *xprt, struct rpc_task *task) +{ + bool ret = false; + + if (!test_bit(XPRT_CONGESTED, &xprt->state)) + goto out; + spin_lock(&xprt->reserve_lock); + if (test_bit(XPRT_CONGESTED, &xprt->state)) { + rpc_sleep_on(&xprt->backlog, task, NULL); + ret = true; + } + spin_unlock(&xprt->reserve_lock); +out: + return ret; +} + +static struct rpc_rqst *xprt_dynamic_alloc_slot(struct rpc_xprt *xprt, gfp_t gfp_flags) +{ + struct rpc_rqst *req = ERR_PTR(-EAGAIN); + + if (!atomic_add_unless(&xprt->num_reqs, 1, xprt->max_reqs)) + goto out; + req = kzalloc(sizeof(struct rpc_rqst), gfp_flags); + if (req != NULL) + goto out; + atomic_dec(&xprt->num_reqs); + req = ERR_PTR(-ENOMEM); +out: + return req; +} + +static bool xprt_dynamic_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (atomic_add_unless(&xprt->num_reqs, -1, xprt->min_reqs)) { + kfree(req); + return true; + } + return false; +} + +void xprt_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req; + + spin_lock(&xprt->reserve_lock); if (!list_empty(&xprt->free)) { - struct rpc_rqst *req = list_entry(xprt->free.next, struct rpc_rqst, rq_list); - list_del_init(&req->rq_list); - task->tk_rqstp = req; - xprt_request_init(task, xprt); - return; + req = list_entry(xprt->free.next, struct rpc_rqst, rq_list); + list_del(&req->rq_list); + goto out_init_req; } - dprintk("RPC: waiting for request slot\n"); - task->tk_status = -EAGAIN; - task->tk_timeout = 0; - rpc_sleep_on(&xprt->backlog, task, NULL); + req = xprt_dynamic_alloc_slot(xprt, GFP_NOWAIT|__GFP_NOWARN); + if (!IS_ERR(req)) + goto out_init_req; + switch (PTR_ERR(req)) { + case -ENOMEM: + dprintk("RPC: dynamic allocation of request slot " + "failed! Retrying\n"); + task->tk_status = -ENOMEM; + break; + case -EAGAIN: + xprt_add_backlog(xprt, task); + dprintk("RPC: waiting for request slot\n"); + default: + task->tk_status = -EAGAIN; + } + spin_unlock(&xprt->reserve_lock); + return; +out_init_req: + task->tk_status = 0; + task->tk_rqstp = req; + xprt_request_init(task, xprt); + spin_unlock(&xprt->reserve_lock); +} +EXPORT_SYMBOL_GPL(xprt_alloc_slot); + +void xprt_lock_and_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) +{ + /* Note: grabbing the xprt_lock_write() ensures that we throttle + * new slot allocation if the transport is congested (i.e. when + * reconnecting a stream transport or when out of socket write + * buffer space). 
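xprt_dynamic_alloc_slot() and xprt_dynamic_free_slot() above grow and shrink the slot table between a preallocated floor and a hard ceiling, relying on atomic_add_unless() so the counter can never cross either bound under concurrency. A self-contained C11 sketch of the same bounded-pool idea; POOL_MIN, POOL_MAX and the 64-byte stand-in allocation are illustrative values, not taken from the patch:

#include <stdatomic.h>
#include <stdlib.h>

#define POOL_MIN 2        /* preallocated floor, like xprt->min_reqs */
#define POOL_MAX 16       /* hard ceiling, like xprt->max_reqs */

static atomic_int num_reqs = POOL_MIN;

/* Equivalent of the kernel's atomic_add_unless(v, delta, unless):
 * add delta unless the counter currently equals 'unless'. */
static int add_unless(atomic_int *v, int delta, int unless)
{
        int old = atomic_load(v);

        do {
                if (old == unless)
                        return 0;
        } while (!atomic_compare_exchange_weak(v, &old, old + delta));
        return 1;
}

static void *slot_alloc(void)
{
        void *req;

        if (!add_unless(&num_reqs, 1, POOL_MAX))
                return NULL;        /* table full: caller backlogs, -EAGAIN */
        req = calloc(1, 64);        /* stand-in for kzalloc(sizeof(rqst)) */
        if (!req)
                atomic_fetch_sub(&num_reqs, 1);  /* undo: caller sees -ENOMEM */
        return req;
}

static int slot_free(void *req)
{
        /* Shrink only down to the floor; at the floor the slot would be
         * recycled onto the transport's free list instead. */
        if (add_unless(&num_reqs, -1, POOL_MIN)) {
                free(req);
                return 1;
        }
        return 0;
}

The failure modes mirror the patch: hitting the ceiling maps to -EAGAIN and a trip to the backlog queue, while an allocation failure maps to -ENOMEM and a retry.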
+ */ + if (xprt_lock_write(xprt, task)) { + xprt_alloc_slot(xprt, task); + xprt_release_write(xprt, task); + } +} +EXPORT_SYMBOL_GPL(xprt_lock_and_alloc_slot); + +static void xprt_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + spin_lock(&xprt->reserve_lock); + if (!xprt_dynamic_free_slot(xprt, req)) { + memset(req, 0, sizeof(*req)); /* mark unused */ + list_add(&req->rq_list, &xprt->free); + } + xprt_wake_up_backlog(xprt); + spin_unlock(&xprt->reserve_lock); +} + +static void xprt_free_all_slots(struct rpc_xprt *xprt) +{ + struct rpc_rqst *req; + while (!list_empty(&xprt->free)) { + req = list_first_entry(&xprt->free, struct rpc_rqst, rq_list); + list_del(&req->rq_list); + kfree(req); + } +} + +struct rpc_xprt *xprt_alloc(struct net *net, size_t size, + unsigned int num_prealloc, + unsigned int max_alloc) +{ + struct rpc_xprt *xprt; + struct rpc_rqst *req; + int i; + + xprt = kzalloc(size, GFP_KERNEL); + if (xprt == NULL) + goto out; + + xprt_init(xprt, net); + + for (i = 0; i < num_prealloc; i++) { + req = kzalloc(sizeof(struct rpc_rqst), GFP_KERNEL); + if (!req) + goto out_free; + list_add(&req->rq_list, &xprt->free); + } + if (max_alloc > num_prealloc) + xprt->max_reqs = max_alloc; + else + xprt->max_reqs = num_prealloc; + xprt->min_reqs = num_prealloc; + atomic_set(&xprt->num_reqs, num_prealloc); + + return xprt; + +out_free: + xprt_free(xprt); +out: + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_alloc); + +void xprt_free(struct rpc_xprt *xprt) +{ + put_net(xprt->xprt_net); + xprt_free_all_slots(xprt); + kfree(xprt); } +EXPORT_SYMBOL_GPL(xprt_free); /** * xprt_reserve - allocate an RPC request slot * @task: RPC task requesting a slot allocation * - * If no more slots are available, place the task on the transport's + * If the transport is marked as being congested, or if no more + * slots are available, place the task on the transport's * backlog queue. */ void xprt_reserve(struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt; - task->tk_status = -EIO; - spin_lock(&xprt->reserve_lock); - do_xprt_reserve(task); - spin_unlock(&xprt->reserve_lock); + task->tk_status = 0; + if (task->tk_rqstp != NULL) + return; + + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + rcu_read_lock(); + xprt = rcu_dereference(task->tk_client->cl_xprt); + if (!xprt_throttle_congested(xprt, task)) + xprt->ops->alloc_slot(xprt, task); + rcu_read_unlock(); +} + +/** + * xprt_retry_reserve - allocate an RPC request slot + * @task: RPC task requesting a slot allocation + * + * If no more slots are available, place the task on the transport's + * backlog queue. + * Note that the only difference with xprt_reserve is that we now + * ignore the value of the XPRT_CONGESTED flag. 
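With xprt_alloc() and xprt_free() factored out as above, a transport implementation only has to say how large its private structure is and how many slots to preallocate. A hedged sketch of a caller, assuming only the xprt_alloc() signature added by this patch; struct my_xprt and my_xprt_setup() are invented for illustration:

#include <linux/err.h>
#include <linux/sunrpc/xprt.h>

struct my_xprt {
        struct rpc_xprt xprt;        /* generic part first, as usual */
        int my_private_state;
};

static struct rpc_xprt *my_xprt_setup(struct net *net)
{
        struct rpc_xprt *xprt;

        /* two preallocated slots, growable to sixteen under load */
        xprt = xprt_alloc(net, sizeof(struct my_xprt), 2, 16);
        if (xprt == NULL)
                return ERR_PTR(-ENOMEM);
        /* ... fill in ops, timeouts and addresses before returning ... */
        return xprt;
}

Note that xprt_alloc() itself returns NULL on failure, having already unwound the partially built free list through xprt_free(); converting that to an ERR_PTR is the caller's choice.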
+ */ +void xprt_retry_reserve(struct rpc_task *task) +{ + struct rpc_xprt *xprt; + + task->tk_status = 0; + if (task->tk_rqstp != NULL) + return; + + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + rcu_read_lock(); + xprt = rcu_dereference(task->tk_client->cl_xprt); + xprt->ops->alloc_slot(xprt, task); + rcu_read_unlock(); } static inline __be32 xprt_alloc_xid(struct rpc_xprt *xprt) { - return xprt->xid++; + return (__force __be32)xprt->xid++; } static inline void xprt_init_xid(struct rpc_xprt *xprt) { - xprt->xid = net_random(); + xprt->xid = prandom_u32(); } static void xprt_request_init(struct rpc_task *task, struct rpc_xprt *xprt) { struct rpc_rqst *req = task->tk_rqstp; + INIT_LIST_HEAD(&req->rq_list); req->rq_timeout = task->tk_client->cl_timeout->to_initval; req->rq_task = task; req->rq_xprt = xprt; req->rq_buffer = NULL; req->rq_xid = xprt_alloc_xid(xprt); + req->rq_connect_cookie = xprt->connect_cookie - 1; + req->rq_bytes_sent = 0; + req->rq_snd_buf.len = 0; + req->rq_snd_buf.buflen = 0; + req->rq_rcv_buf.len = 0; + req->rq_rcv_buf.buflen = 0; req->rq_release_snd_buf = NULL; xprt_reset_majortimeo(req); dprintk("RPC: %5u reserved req %p xid %08x\n", task->tk_pid, @@ -1004,17 +1216,24 @@ static void xprt_request_init(struct rpc_task *task, struct rpc_xprt *xprt) void xprt_release(struct rpc_task *task) { struct rpc_xprt *xprt; - struct rpc_rqst *req; - int is_bc_request; + struct rpc_rqst *req = task->tk_rqstp; - if (!(req = task->tk_rqstp)) + if (req == NULL) { + if (task->tk_client) { + rcu_read_lock(); + xprt = rcu_dereference(task->tk_client->cl_xprt); + if (xprt->snd_task == task) + xprt_release_write(xprt, task); + rcu_read_unlock(); + } return; - - /* Preallocated backchannel request? */ - is_bc_request = bc_prealloc(req); + } xprt = req->rq_xprt; - rpc_count_iostats(task); + if (task->tk_ops->rpc_count_stats != NULL) + task->tk_ops->rpc_count_stats(task, task->tk_calldata); + else if (task->tk_client) + rpc_count_iostats(task, task->tk_client->cl_metrics); spin_lock_bh(&xprt->transport_lock); xprt->ops->release_xprt(xprt, task); if (xprt->ops->release_request) @@ -1026,27 +1245,47 @@ void xprt_release(struct rpc_task *task) mod_timer(&xprt->timer, xprt->last_used + xprt->idle_timeout); spin_unlock_bh(&xprt->transport_lock); - if (!bc_prealloc(req)) + if (req->rq_buffer) xprt->ops->buf_free(req->rq_buffer); + if (req->rq_cred != NULL) + put_rpccred(req->rq_cred); task->tk_rqstp = NULL; if (req->rq_release_snd_buf) req->rq_release_snd_buf(req); - /* - * Early exit if this is a backchannel preallocated request. - * There is no need to have it added to the RPC slot list. 
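xprt_init_xid() above now seeds the XID stream with prandom_u32() and xprt_alloc_xid() hands out sequential values from there; the (__force __be32) cast only placates sparse, since the XID is an opaque token and its byte order never matters. A userspace sketch of the scheme, with srandom()/random() standing in for prandom_u32():

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

static uint32_t xid;

/* Random starting point, so XIDs from different client instances are
 * unlikely to collide in a server's duplicate-request cache. */
static void xid_init(void)
{
        srandom((unsigned int)time(NULL));
        xid = (uint32_t)random();
}

/* Then a plain increment per request: uniqueness within one transport
 * is all the protocol needs. */
static uint32_t xid_next(void)
{
        return xid++;
}

int main(void)
{
        xid_init();
        printf("xid %08x, then %08x\n", xid_next(), xid_next());
        return 0;
}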
- */ - if (is_bc_request) - return; + dprintk("RPC: %5u release request %p\n", task->tk_pid, req); + if (likely(!bc_prealloc(req))) + xprt_free_slot(xprt, req); + else + xprt_free_bc_request(req); +} - memset(req, 0, sizeof(*req)); /* mark unused */ +static void xprt_init(struct rpc_xprt *xprt, struct net *net) +{ + atomic_set(&xprt->count, 1); - dprintk("RPC: %5u release request %p\n", task->tk_pid, req); + spin_lock_init(&xprt->transport_lock); + spin_lock_init(&xprt->reserve_lock); - spin_lock(&xprt->reserve_lock); - list_add(&req->rq_list, &xprt->free); - rpc_wake_up_next(&xprt->backlog); - spin_unlock(&xprt->reserve_lock); + INIT_LIST_HEAD(&xprt->free); + INIT_LIST_HEAD(&xprt->recv); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + spin_lock_init(&xprt->bc_pa_lock); + INIT_LIST_HEAD(&xprt->bc_pa_list); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + + xprt->last_used = jiffies; + xprt->cwnd = RPC_INITCWND; + xprt->bind_index = 0; + + rpc_init_wait_queue(&xprt->binding, "xprt_binding"); + rpc_init_wait_queue(&xprt->pending, "xprt_pending"); + rpc_init_priority_wait_queue(&xprt->sending, "xprt_sending"); + rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog"); + + xprt_init_xid(xprt); + + xprt->xprt_net = get_net(net); } /** @@ -1057,7 +1296,6 @@ void xprt_release(struct rpc_task *task) struct rpc_xprt *xprt_create_transport(struct xprt_create *args) { struct rpc_xprt *xprt; - struct rpc_rqst *req; struct xprt_class *t; spin_lock(&xprt_list_lock); @@ -1076,65 +1314,49 @@ found: if (IS_ERR(xprt)) { dprintk("RPC: xprt_create_transport: failed, %ld\n", -PTR_ERR(xprt)); - return xprt; + goto out; } - - kref_init(&xprt->kref); - spin_lock_init(&xprt->transport_lock); - spin_lock_init(&xprt->reserve_lock); - - INIT_LIST_HEAD(&xprt->free); - INIT_LIST_HEAD(&xprt->recv); -#if defined(CONFIG_NFS_V4_1) - spin_lock_init(&xprt->bc_pa_lock); - INIT_LIST_HEAD(&xprt->bc_pa_list); -#endif /* CONFIG_NFS_V4_1 */ - + if (args->flags & XPRT_CREATE_NO_IDLE_TIMEOUT) + xprt->idle_timeout = 0; INIT_WORK(&xprt->task_cleanup, xprt_autoclose); if (xprt_has_timer(xprt)) setup_timer(&xprt->timer, xprt_init_autodisconnect, (unsigned long)xprt); else init_timer(&xprt->timer); - xprt->last_used = jiffies; - xprt->cwnd = RPC_INITCWND; - xprt->bind_index = 0; - - rpc_init_wait_queue(&xprt->binding, "xprt_binding"); - rpc_init_wait_queue(&xprt->pending, "xprt_pending"); - rpc_init_wait_queue(&xprt->sending, "xprt_sending"); - rpc_init_wait_queue(&xprt->resend, "xprt_resend"); - rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog"); - - /* initialize free list */ - for (req = &xprt->slot[xprt->max_reqs-1]; req >= &xprt->slot[0]; req--) - list_add(&req->rq_list, &xprt->free); - xprt_init_xid(xprt); + if (strlen(args->servername) > RPC_MAXNETNAMELEN) { + xprt_destroy(xprt); + return ERR_PTR(-EINVAL); + } + xprt->servername = kstrdup(args->servername, GFP_KERNEL); + if (xprt->servername == NULL) { + xprt_destroy(xprt); + return ERR_PTR(-ENOMEM); + } dprintk("RPC: created transport %p with %u slots\n", xprt, xprt->max_reqs); +out: return xprt; } /** * xprt_destroy - destroy an RPC transport, killing off all requests. 
- * @kref: kref for the transport to destroy + * @xprt: transport to destroy * */ -static void xprt_destroy(struct kref *kref) +static void xprt_destroy(struct rpc_xprt *xprt) { - struct rpc_xprt *xprt = container_of(kref, struct rpc_xprt, kref); - dprintk("RPC: destroying transport %p\n", xprt); - xprt->shutdown = 1; del_timer_sync(&xprt->timer); rpc_destroy_wait_queue(&xprt->binding); rpc_destroy_wait_queue(&xprt->pending); rpc_destroy_wait_queue(&xprt->sending); - rpc_destroy_wait_queue(&xprt->resend); rpc_destroy_wait_queue(&xprt->backlog); + cancel_work_sync(&xprt->task_cleanup); + kfree(xprt->servername); /* * Tear down transport state and free the rpc_xprt */ @@ -1148,16 +1370,6 @@ static void xprt_destroy(struct kref *kref) */ void xprt_put(struct rpc_xprt *xprt) { - kref_put(&xprt->kref, xprt_destroy); -} - -/** - * xprt_get - return a reference to an RPC transport. - * @xprt: pointer to the transport - * - */ -struct rpc_xprt *xprt_get(struct rpc_xprt *xprt) -{ - kref_get(&xprt->kref); - return xprt; + if (atomic_dec_and_test(&xprt->count)) + xprt_destroy(xprt); } diff --git a/net/sunrpc/xprtrdma/Makefile b/net/sunrpc/xprtrdma/Makefile index 5a8f268bdd3..da5136fd569 100644 --- a/net/sunrpc/xprtrdma/Makefile +++ b/net/sunrpc/xprtrdma/Makefile @@ -1,8 +1,8 @@ -obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma.o +obj-$(CONFIG_SUNRPC_XPRT_RDMA_CLIENT) += xprtrdma.o xprtrdma-y := transport.o rpc_rdma.o verbs.o -obj-$(CONFIG_SUNRPC_XPRT_RDMA) += svcrdma.o +obj-$(CONFIG_SUNRPC_XPRT_RDMA_SERVER) += svcrdma.o svcrdma-y := svc_rdma.o svc_rdma_transport.o \ svc_rdma_marshal.o svc_rdma_sendto.o svc_rdma_recvfrom.o diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c index e5e28d1946a..693966d3f33 100644 --- a/net/sunrpc/xprtrdma/rpc_rdma.c +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -78,8 +78,7 @@ static const char transfertypes[][12] = { * elements. Segments are then coalesced when registered, if possible * within the selected memreg mode. * - * Note, this routine is never called if the connection's memory - * registration strategy is 0 (bounce buffers). + * Returns positive number of segments converted, or a negative errno. 
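The kref in struct rpc_xprt is replaced here by a bare atomic count: xprt_init() starts it at 1 for the creator and xprt_put() tears the transport down on the final atomic_dec_and_test(). A minimal C11 sketch of that get/put discipline (struct obj is illustrative):

#include <stdatomic.h>
#include <stdlib.h>

struct obj {
        atomic_int count;
        /* ... payload ... */
};

static struct obj *obj_new(void)
{
        struct obj *o = calloc(1, sizeof(*o));

        if (o)
                atomic_store(&o->count, 1);     /* creator holds one ref */
        return o;
}

static struct obj *obj_get(struct obj *o)
{
        atomic_fetch_add(&o->count, 1);
        return o;
}

static void obj_put(struct obj *o)
{
        /* atomic_dec_and_test() equivalent: destroy on the 1 -> 0 drop */
        if (atomic_fetch_sub(&o->count, 1) == 1) {
                /* tear down timers, wait queues, ... then free */
                free(o);
        }
}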
*/ static int @@ -87,6 +86,8 @@ rpcrdma_convert_iovs(struct xdr_buf *xdrbuf, unsigned int pos, enum rpcrdma_chunktype type, struct rpcrdma_mr_seg *seg, int nsegs) { int len, n = 0, p; + int page_base; + struct page **ppages; if (pos == 0 && xdrbuf->head[0].iov_len) { seg[n].mr_page = NULL; @@ -95,35 +96,40 @@ rpcrdma_convert_iovs(struct xdr_buf *xdrbuf, unsigned int pos, ++n; } - if (xdrbuf->page_len && (xdrbuf->pages[0] != NULL)) { - if (n == nsegs) - return 0; - seg[n].mr_page = xdrbuf->pages[0]; - seg[n].mr_offset = (void *)(unsigned long) xdrbuf->page_base; - seg[n].mr_len = min_t(u32, - PAGE_SIZE - xdrbuf->page_base, xdrbuf->page_len); - len = xdrbuf->page_len - seg[n].mr_len; - ++n; - p = 1; - while (len > 0) { - if (n == nsegs) - return 0; - seg[n].mr_page = xdrbuf->pages[p]; - seg[n].mr_offset = NULL; - seg[n].mr_len = min_t(u32, PAGE_SIZE, len); - len -= seg[n].mr_len; - ++n; - ++p; + len = xdrbuf->page_len; + ppages = xdrbuf->pages + (xdrbuf->page_base >> PAGE_SHIFT); + page_base = xdrbuf->page_base & ~PAGE_MASK; + p = 0; + while (len && n < nsegs) { + if (!ppages[p]) { + /* alloc the pagelist for receiving buffer */ + ppages[p] = alloc_page(GFP_ATOMIC); + if (!ppages[p]) + return -ENOMEM; } + seg[n].mr_page = ppages[p]; + seg[n].mr_offset = (void *)(unsigned long) page_base; + seg[n].mr_len = min_t(u32, PAGE_SIZE - page_base, len); + if (seg[n].mr_len > PAGE_SIZE) + return -EIO; + len -= seg[n].mr_len; + ++n; + ++p; + page_base = 0; /* page offset only applies to first page */ } + /* Message overflows the seg array */ + if (len && n == nsegs) + return -EIO; + if (xdrbuf->tail[0].iov_len) { /* the rpcrdma protocol allows us to omit any trailing * xdr pad bytes, saving the server an RDMA operation. */ if (xdrbuf->tail[0].iov_len < 4 && xprt_rdma_pad_optimize) return n; if (n == nsegs) - return 0; + /* Tail remains, but we're out of segments */ + return -EIO; seg[n].mr_page = NULL; seg[n].mr_offset = xdrbuf->tail[0].iov_base; seg[n].mr_len = xdrbuf->tail[0].iov_len; @@ -164,15 +170,17 @@ rpcrdma_convert_iovs(struct xdr_buf *xdrbuf, unsigned int pos, * Reply chunk (a counted array): * N elements: * 1 - N - HLOO - HLOO - ... - HLOO + * + * Returns positive RPC/RDMA header size, or negative errno. */ -static unsigned int +static ssize_t rpcrdma_create_chunks(struct rpc_rqst *rqst, struct xdr_buf *target, struct rpcrdma_msg *headerp, enum rpcrdma_chunktype type) { struct rpcrdma_req *req = rpcr_to_rdmar(rqst); - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_task->tk_xprt); - int nsegs, nchunks = 0; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + int n, nsegs, nchunks = 0; unsigned int pos; struct rpcrdma_mr_seg *seg = req->rl_segments; struct rpcrdma_read_chunk *cur_rchunk = NULL; @@ -198,12 +206,11 @@ rpcrdma_create_chunks(struct rpc_rqst *rqst, struct xdr_buf *target, pos = target->head[0].iov_len; nsegs = rpcrdma_convert_iovs(target, pos, type, seg, RPCRDMA_MAX_SEGS); - if (nsegs == 0) - return 0; + if (nsegs < 0) + return nsegs; do { - /* bind/register the memory, then build chunk from result. */ - int n = rpcrdma_register_external(seg, nsegs, + n = rpcrdma_register_external(seg, nsegs, cur_wchunk != NULL, r_xprt); if (n <= 0) goto out; @@ -248,8 +255,6 @@ rpcrdma_create_chunks(struct rpc_rqst *rqst, struct xdr_buf *target, /* success. all failures return above */ req->rl_nchunks = nchunks; - BUG_ON(nchunks == 0); - /* * finish off header. If write, marshal discrim and nchunks. 
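The rewritten rpcrdma_convert_iovs() above no longer assumes the page list starts at offset zero: it derives the first page and the intra-page offset from page_base, then zeroes the offset after the first page. The arithmetic in isolation, as a runnable sketch with an assumed 4 KiB page:

#include <stdio.h>

#define PAGE_SHIFT 12
#define PAGE_SIZE  (1UL << PAGE_SHIFT)
#define PAGE_MASK  (~(PAGE_SIZE - 1))

/* Walk a paged buffer of 'len' bytes that begins 'page_base' bytes into
 * a page array, emitting (page index, offset, length) segments exactly
 * as the chunk-building loop above does. */
static void walk(unsigned long page_base, unsigned long len)
{
        unsigned long p = page_base >> PAGE_SHIFT;      /* first page */
        unsigned long off = page_base & ~PAGE_MASK;     /* offset in it */

        while (len) {
                unsigned long seg = PAGE_SIZE - off;

                if (seg > len)
                        seg = len;
                printf("page %lu, offset %lu, len %lu\n", p, off, seg);
                len -= seg;
                p++;
                off = 0;        /* offset only applies to the first page */
        }
}

int main(void)
{
        walk(5000, 10000);      /* starts 904 bytes into page 1 */
        return 0;
}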
*/ @@ -276,8 +281,8 @@ rpcrdma_create_chunks(struct rpc_rqst *rqst, struct xdr_buf *target, out: for (pos = 0; nchunks--;) pos += rpcrdma_deregister_external( - &req->rl_segments[pos], r_xprt, NULL); - return 0; + &req->rl_segments[pos], r_xprt); + return n; } /* @@ -294,6 +299,8 @@ rpcrdma_inline_pullup(struct rpc_rqst *rqst, int pad) int copy_len; unsigned char *srcp, *destp; struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + int page_base; + struct page **ppages; destp = rqst->rq_svec[0].iov_base; curlen = rqst->rq_svec[0].iov_len; @@ -322,28 +329,25 @@ rpcrdma_inline_pullup(struct rpc_rqst *rqst, int pad) __func__, destp + copy_len, curlen); rqst->rq_svec[0].iov_len += curlen; } - r_xprt->rx_stats.pullup_copy_count += copy_len; - npages = PAGE_ALIGN(rqst->rq_snd_buf.page_base+copy_len) >> PAGE_SHIFT; + + page_base = rqst->rq_snd_buf.page_base; + ppages = rqst->rq_snd_buf.pages + (page_base >> PAGE_SHIFT); + page_base &= ~PAGE_MASK; + npages = PAGE_ALIGN(page_base+copy_len) >> PAGE_SHIFT; for (i = 0; copy_len && i < npages; i++) { - if (i == 0) - curlen = PAGE_SIZE - rqst->rq_snd_buf.page_base; - else - curlen = PAGE_SIZE; + curlen = PAGE_SIZE - page_base; if (curlen > copy_len) curlen = copy_len; dprintk("RPC: %s: page %d destp 0x%p len %d curlen %d\n", __func__, i, destp, copy_len, curlen); - srcp = kmap_atomic(rqst->rq_snd_buf.pages[i], - KM_SKB_SUNRPC_DATA); - if (i == 0) - memcpy(destp, srcp+rqst->rq_snd_buf.page_base, curlen); - else - memcpy(destp, srcp, curlen); - kunmap_atomic(srcp, KM_SKB_SUNRPC_DATA); + srcp = kmap_atomic(ppages[i]); + memcpy(destp, srcp+page_base, curlen); + kunmap_atomic(srcp); rqst->rq_svec[0].iov_len += curlen; destp += curlen; copy_len -= curlen; + page_base = 0; } /* header now contains entire send message */ return pad; @@ -360,16 +364,19 @@ rpcrdma_inline_pullup(struct rpc_rqst *rqst, int pad) * [1] -- the RPC header/data, marshaled by RPC and the NFS protocol. * [2] -- optional padding. * [3] -- if padded, header only in [1] and data here. + * + * Returns zero on success, otherwise a negative errno. */ int rpcrdma_marshal_req(struct rpc_rqst *rqst) { - struct rpc_xprt *xprt = rqst->rq_task->tk_xprt; + struct rpc_xprt *xprt = rqst->rq_xprt; struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); struct rpcrdma_req *req = rpcr_to_rdmar(rqst); char *base; - size_t hdrlen, rpclen, padlen; + size_t rpclen, padlen; + ssize_t hdrlen; enum rpcrdma_chunktype rtype, wtype; struct rpcrdma_msg *headerp; @@ -440,14 +447,10 @@ rpcrdma_marshal_req(struct rpc_rqst *rqst) /* The following simplification is not true forever */ if (rtype != rpcrdma_noch && wtype == rpcrdma_replych) wtype = rpcrdma_noch; - BUG_ON(rtype != rpcrdma_noch && wtype != rpcrdma_noch); - - if (r_xprt->rx_ia.ri_memreg_strategy == RPCRDMA_BOUNCEBUFFERS && - (rtype != rpcrdma_noch || wtype != rpcrdma_noch)) { - /* forced to "pure inline"? 
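rpcrdma_inline_pullup() above also migrates from the old two-argument kmap_atomic(page, KM_SKB_SUNRPC_DATA) to the slot-less API. A kernel-context sketch of the copy idiom it now uses; copy_from_page() is an invented helper, and the caller must keep offset + len within one page:

#include <linux/highmem.h>
#include <linux/string.h>

static void copy_from_page(void *dest, struct page *page,
                           unsigned int offset, unsigned int len)
{
        char *src = kmap_atomic(page);  /* no KM_* slot argument anymore */

        memcpy(dest, src + offset, len);
        kunmap_atomic(src);
}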
*/ - dprintk("RPC: %s: too much data (%d/%d) for inline\n", - __func__, rqst->rq_rcv_buf.len, rqst->rq_snd_buf.len); - return -1; + if (rtype != rpcrdma_noch && wtype != rpcrdma_noch) { + dprintk("RPC: %s: cannot marshal multiple chunk lists\n", + __func__); + return -EIO; } hdrlen = 28; /*sizeof *headerp;*/ @@ -473,8 +476,11 @@ rpcrdma_marshal_req(struct rpc_rqst *rqst) headerp->rm_body.rm_padded.rm_pempty[1] = xdr_zero; headerp->rm_body.rm_padded.rm_pempty[2] = xdr_zero; hdrlen += 2 * sizeof(u32); /* extra words in padhdr */ - BUG_ON(wtype != rpcrdma_noch); - + if (wtype != rpcrdma_noch) { + dprintk("RPC: %s: invalid chunk list\n", + __func__); + return -EIO; + } } else { headerp->rm_body.rm_nochunks.rm_empty[0] = xdr_zero; headerp->rm_body.rm_nochunks.rm_empty[1] = xdr_zero; @@ -491,8 +497,7 @@ rpcrdma_marshal_req(struct rpc_rqst *rqst) * on receive. Therefore, we request a reply chunk * for non-writes wherever feasible and efficient. */ - if (wtype == rpcrdma_noch && - r_xprt->rx_ia.ri_memreg_strategy > RPCRDMA_REGISTER) + if (wtype == rpcrdma_noch) wtype = rpcrdma_replych; } } @@ -510,9 +515,8 @@ rpcrdma_marshal_req(struct rpc_rqst *rqst) hdrlen = rpcrdma_create_chunks(rqst, &rqst->rq_rcv_buf, headerp, wtype); } - - if (hdrlen == 0) - return -1; + if (hdrlen < 0) + return hdrlen; dprintk("RPC: %s: %s: hdrlen %zd rpclen %zd padlen %zd" " headerp 0x%p base 0x%p lkey 0x%x\n", @@ -604,6 +608,8 @@ rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad) { int i, npages, curlen, olen; char *destp; + struct page **ppages; + int page_base; curlen = rqst->rq_rcv_buf.head[0].iov_len; if (curlen > copy_len) { /* write chunk header fixup */ @@ -622,36 +628,31 @@ rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad) olen = copy_len; i = 0; rpcx_to_rdmax(rqst->rq_xprt)->rx_stats.fixup_copy_count += olen; + page_base = rqst->rq_rcv_buf.page_base; + ppages = rqst->rq_rcv_buf.pages + (page_base >> PAGE_SHIFT); + page_base &= ~PAGE_MASK; + if (copy_len && rqst->rq_rcv_buf.page_len) { - npages = PAGE_ALIGN(rqst->rq_rcv_buf.page_base + + npages = PAGE_ALIGN(page_base + rqst->rq_rcv_buf.page_len) >> PAGE_SHIFT; for (; i < npages; i++) { - if (i == 0) - curlen = PAGE_SIZE - rqst->rq_rcv_buf.page_base; - else - curlen = PAGE_SIZE; + curlen = PAGE_SIZE - page_base; if (curlen > copy_len) curlen = copy_len; dprintk("RPC: %s: page %d" " srcp 0x%p len %d curlen %d\n", __func__, i, srcp, copy_len, curlen); - destp = kmap_atomic(rqst->rq_rcv_buf.pages[i], - KM_SKB_SUNRPC_DATA); - if (i == 0) - memcpy(destp + rqst->rq_rcv_buf.page_base, - srcp, curlen); - else - memcpy(destp, srcp, curlen); - flush_dcache_page(rqst->rq_rcv_buf.pages[i]); - kunmap_atomic(destp, KM_SKB_SUNRPC_DATA); + destp = kmap_atomic(ppages[i]); + memcpy(destp + page_base, srcp, curlen); + flush_dcache_page(ppages[i]); + kunmap_atomic(destp); srcp += curlen; copy_len -= curlen; if (copy_len == 0) break; + page_base = 0; } - rqst->rq_rcv_buf.page_len = olen - copy_len; - } else - rqst->rq_rcv_buf.page_len = 0; + } if (copy_len && rqst->rq_rcv_buf.tail[0].iov_len) { curlen = copy_len; @@ -682,15 +683,11 @@ rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad) rqst->rq_private_buf = rqst->rq_rcv_buf; } -/* - * This function is called when an async event is posted to - * the connection which changes the connection state. All it - * does at this point is mark the connection up/down, the rpc - * timers do the rest. 
- */ void -rpcrdma_conn_func(struct rpcrdma_ep *ep) +rpcrdma_connect_worker(struct work_struct *work) { + struct rpcrdma_ep *ep = + container_of(work, struct rpcrdma_ep, rep_connect_worker.work); struct rpc_xprt *xprt = ep->rep_xprt; spin_lock_bh(&xprt->transport_lock); @@ -707,13 +704,15 @@ rpcrdma_conn_func(struct rpcrdma_ep *ep) } /* - * This function is called when memory window unbind which we are waiting - * for completes. Just use rr_func (zeroed by upcall) to signal completion. + * This function is called when an async event is posted to + * the connection which changes the connection state. All it + * does at this point is mark the connection up/down, the rpc + * timers do the rest. */ -static void -rpcrdma_unbind_func(struct rpcrdma_rep *rep) +void +rpcrdma_conn_func(struct rpcrdma_ep *ep) { - wake_up(&rep->rr_unbind); + schedule_delayed_work(&ep->rep_connect_worker, 0); } /* @@ -730,7 +729,8 @@ rpcrdma_reply_handler(struct rpcrdma_rep *rep) struct rpc_xprt *xprt = rep->rr_xprt; struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); __be32 *iptr; - int i, rdmalen, status; + int rdmalen, status; + unsigned long cwnd; /* Check status. If bad, signal disconnect and return rep to pool */ if (rep->rr_len == ~0U) { @@ -771,15 +771,21 @@ repost: /* get request object */ req = rpcr_to_rdmar(rqst); + if (req->rl_reply) { + spin_unlock(&xprt->transport_lock); + dprintk("RPC: %s: duplicate reply 0x%p to RPC " + "request 0x%p: xid 0x%08x\n", __func__, rep, req, + headerp->rm_xid); + goto repost; + } dprintk("RPC: %s: reply 0x%p completes request 0x%p\n" " RPC request 0x%p xid 0x%08x\n", __func__, rep, req, rqst, headerp->rm_xid); - BUG_ON(!req || req->rl_reply); - /* from here on, the reply is no longer an orphan */ req->rl_reply = rep; + xprt->reestablish_timeout = 0; /* check for expected message types */ /* The order of some of these tests is important. */ @@ -854,26 +860,10 @@ badheader: break; } - /* If using mw bind, start the deregister process now. 
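rpcrdma_conn_func() used to fix up connection state directly from the completion path; above it now merely schedules rpcrdma_connect_worker(), so the real work runs in process context where sleeping is allowed. A kernel-context sketch of that deferral pattern; struct my_ep and its handlers are invented names:

#include <linux/workqueue.h>

struct my_ep {
        struct delayed_work connect_worker;
        int connect_status;
};

static void my_connect_worker(struct work_struct *work)
{
        struct my_ep *ep = container_of(work, struct my_ep,
                                        connect_worker.work);

        /* Process context: safe to sleep, take mutexes, wake RPC tasks. */
        (void)ep;
}

static void my_ep_init(struct my_ep *ep)
{
        INIT_DELAYED_WORK(&ep->connect_worker, my_connect_worker);
}

/* Called from the event handler in atomic context: just punt. */
static void my_conn_event(struct my_ep *ep)
{
        schedule_delayed_work(&ep->connect_worker, 0);
}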
*/ - /* (Note: if mr_free(), cannot perform it here, in tasklet context) */ - if (req->rl_nchunks) switch (r_xprt->rx_ia.ri_memreg_strategy) { - case RPCRDMA_MEMWINDOWS: - for (i = 0; req->rl_nchunks-- > 1;) - i += rpcrdma_deregister_external( - &req->rl_segments[i], r_xprt, NULL); - /* Optionally wait (not here) for unbinds to complete */ - rep->rr_func = rpcrdma_unbind_func; - (void) rpcrdma_deregister_external(&req->rl_segments[i], - r_xprt, rep); - break; - case RPCRDMA_MEMWINDOWS_ASYNC: - for (i = 0; req->rl_nchunks--;) - i += rpcrdma_deregister_external(&req->rl_segments[i], - r_xprt, NULL); - break; - default: - break; - } + cwnd = xprt->cwnd; + xprt->cwnd = atomic_read(&r_xprt->rx_buf.rb_credits) << RPC_CWNDSHIFT; + if (xprt->cwnd > cwnd) + xprt_release_rqst_cong(rqst->rq_task); dprintk("RPC: %s: xprt_complete_rqst(0x%p, 0x%p, %d)\n", __func__, xprt, rqst, status); diff --git a/net/sunrpc/xprtrdma/svc_rdma.c b/net/sunrpc/xprtrdma/svc_rdma.c index 5b8a8ff93a2..c1b6270262c 100644 --- a/net/sunrpc/xprtrdma/svc_rdma.c +++ b/net/sunrpc/xprtrdma/svc_rdma.c @@ -40,11 +40,14 @@ */ #include <linux/module.h> #include <linux/init.h> +#include <linux/slab.h> #include <linux/fs.h> #include <linux/sysctl.h> +#include <linux/workqueue.h> #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/sched.h> #include <linux/sunrpc/svc_rdma.h> +#include "xprt_rdma.h" #define RPCDBG_FACILITY RPCDBG_SVCXPRT @@ -73,13 +76,15 @@ atomic_t rdma_stat_sq_prod; struct kmem_cache *svc_rdma_map_cachep; struct kmem_cache *svc_rdma_ctxt_cachep; +struct workqueue_struct *svc_rdma_wq; + /* * This function implements reading and resetting an atomic_t stat * variable through read/write to a proc file. Any write to the file * resets the associated statistic to zero. Any read returns it's * current value. 
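The reply-handler hunk above swaps the memory-window deregistration switch for a credit-driven congestion window: the server-granted credit count is shifted by RPC_CWNDSHIFT, and a congestion-blocked task is released only when the window grows. A toy calculation of that update; the shift value of 8 is an assumption borrowed from the generic RPC congestion code, not stated in this patch:

#include <stdio.h>

#define RPC_CWNDSHIFT 8

int main(void)
{
        unsigned long cwnd = 2UL << RPC_CWNDSHIFT;      /* current window */
        unsigned long credits = 32;         /* granted by the server's reply */
        unsigned long newcwnd = credits << RPC_CWNDSHIFT;

        if (newcwnd > cwnd)
                printf("window grew: release a congestion-blocked task\n");
        cwnd = newcwnd;
        printf("cwnd now admits %lu concurrent requests\n",
               cwnd >> RPC_CWNDSHIFT);
        return 0;
}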
*/ -static int read_reset_stat(ctl_table *table, int write, +static int read_reset_stat(struct ctl_table *table, int write, void __user *buffer, size_t *lenp, loff_t *ppos) { @@ -114,7 +119,7 @@ static int read_reset_stat(ctl_table *table, int write, } static struct ctl_table_header *svcrdma_table_header; -static ctl_table svcrdma_parm_table[] = { +static struct ctl_table svcrdma_parm_table[] = { { .procname = "max_requests", .data = &svcrdma_max_requests, @@ -209,7 +214,7 @@ static ctl_table svcrdma_parm_table[] = { { }, }; -static ctl_table svcrdma_table[] = { +static struct ctl_table svcrdma_table[] = { { .procname = "svc_rdma", .mode = 0555, @@ -218,7 +223,7 @@ static ctl_table svcrdma_table[] = { { }, }; -static ctl_table svcrdma_root_table[] = { +static struct ctl_table svcrdma_root_table[] = { { .procname = "sunrpc", .mode = 0555, @@ -230,7 +235,7 @@ static ctl_table svcrdma_root_table[] = { void svc_rdma_cleanup(void) { dprintk("SVCRDMA Module Removed, deregister RPC RDMA transport\n"); - flush_scheduled_work(); + destroy_workqueue(svc_rdma_wq); if (svcrdma_table_header) { unregister_sysctl_table(svcrdma_table_header); svcrdma_table_header = NULL; @@ -248,6 +253,11 @@ int svc_rdma_init(void) dprintk("\tsq_depth : %d\n", svcrdma_max_requests * RPCRDMA_SQ_DEPTH_MULT); dprintk("\tmax_inline : %d\n", svcrdma_max_req_size); + + svc_rdma_wq = alloc_workqueue("svc_rdma", 0, 0); + if (!svc_rdma_wq) + return -ENOMEM; + if (!svcrdma_table_header) svcrdma_table_header = register_sysctl_table(svcrdma_root_table); @@ -282,6 +292,7 @@ int svc_rdma_init(void) kmem_cache_destroy(svc_rdma_map_cachep); err0: unregister_sysctl_table(svcrdma_table_header); + destroy_workqueue(svc_rdma_wq); return -ENOMEM; } MODULE_AUTHOR("Tom Tucker <tom@opengridcomputing.com>"); diff --git a/net/sunrpc/xprtrdma/svc_rdma_marshal.c b/net/sunrpc/xprtrdma/svc_rdma_marshal.c index 9530ef2d40d..65b146297f5 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_marshal.c +++ b/net/sunrpc/xprtrdma/svc_rdma_marshal.c @@ -60,21 +60,11 @@ static u32 *decode_read_list(u32 *va, u32 *vaend) struct rpcrdma_read_chunk *ch = (struct rpcrdma_read_chunk *)va; while (ch->rc_discrim != xdr_zero) { - u64 ch_offset; - if (((unsigned long)ch + sizeof(struct rpcrdma_read_chunk)) > (unsigned long)vaend) { dprintk("svcrdma: vaend=%p, ch=%p\n", vaend, ch); return NULL; } - - ch->rc_discrim = ntohl(ch->rc_discrim); - ch->rc_position = ntohl(ch->rc_position); - ch->rc_target.rs_handle = ntohl(ch->rc_target.rs_handle); - ch->rc_target.rs_length = ntohl(ch->rc_target.rs_length); - va = (u32 *)&ch->rc_target.rs_offset; - xdr_decode_hyper(va, &ch_offset); - put_unaligned(ch_offset, (u64 *)va); ch++; } return (u32 *)&ch->rc_position; @@ -91,7 +81,7 @@ void svc_rdma_rcl_chunk_counts(struct rpcrdma_read_chunk *ch, *byte_count = 0; *ch_count = 0; for (; ch->rc_discrim != 0; ch++) { - *byte_count = *byte_count + ch->rc_target.rs_length; + *byte_count = *byte_count + ntohl(ch->rc_target.rs_length); *ch_count = *ch_count + 1; } } @@ -108,7 +98,9 @@ void svc_rdma_rcl_chunk_counts(struct rpcrdma_read_chunk *ch, */ static u32 *decode_write_list(u32 *va, u32 *vaend) { - int ch_no; + unsigned long start, end; + int nchunks; + struct rpcrdma_write_array *ary = (struct rpcrdma_write_array *)va; @@ -121,37 +113,28 @@ static u32 *decode_write_list(u32 *va, u32 *vaend) dprintk("svcrdma: ary=%p, vaend=%p\n", ary, vaend); return NULL; } - ary->wc_discrim = ntohl(ary->wc_discrim); - ary->wc_nchunks = ntohl(ary->wc_nchunks); - if (((unsigned long)&ary->wc_array[0] + - (sizeof(struct 
rpcrdma_write_chunk) * ary->wc_nchunks)) > - (unsigned long)vaend) { + nchunks = ntohl(ary->wc_nchunks); + + start = (unsigned long)&ary->wc_array[0]; + end = (unsigned long)vaend; + if (nchunks < 0 || + nchunks > (SIZE_MAX - start) / sizeof(struct rpcrdma_write_chunk) || + (start + (sizeof(struct rpcrdma_write_chunk) * nchunks)) > end) { dprintk("svcrdma: ary=%p, wc_nchunks=%d, vaend=%p\n", - ary, ary->wc_nchunks, vaend); + ary, nchunks, vaend); return NULL; } - for (ch_no = 0; ch_no < ary->wc_nchunks; ch_no++) { - u64 ch_offset; - - ary->wc_array[ch_no].wc_target.rs_handle = - ntohl(ary->wc_array[ch_no].wc_target.rs_handle); - ary->wc_array[ch_no].wc_target.rs_length = - ntohl(ary->wc_array[ch_no].wc_target.rs_length); - va = (u32 *)&ary->wc_array[ch_no].wc_target.rs_offset; - xdr_decode_hyper(va, &ch_offset); - put_unaligned(ch_offset, (u64 *)va); - } - /* * rs_length is the 2nd 4B field in wc_target and taking its * address skips the list terminator */ - return (u32 *)&ary->wc_array[ch_no].wc_target.rs_length; + return (u32 *)&ary->wc_array[nchunks].wc_target.rs_length; } static u32 *decode_reply_array(u32 *va, u32 *vaend) { - int ch_no; + unsigned long start, end; + int nchunks; struct rpcrdma_write_array *ary = (struct rpcrdma_write_array *)va; @@ -164,28 +147,18 @@ static u32 *decode_reply_array(u32 *va, u32 *vaend) dprintk("svcrdma: ary=%p, vaend=%p\n", ary, vaend); return NULL; } - ary->wc_discrim = ntohl(ary->wc_discrim); - ary->wc_nchunks = ntohl(ary->wc_nchunks); - if (((unsigned long)&ary->wc_array[0] + - (sizeof(struct rpcrdma_write_chunk) * ary->wc_nchunks)) > - (unsigned long)vaend) { + nchunks = ntohl(ary->wc_nchunks); + + start = (unsigned long)&ary->wc_array[0]; + end = (unsigned long)vaend; + if (nchunks < 0 || + nchunks > (SIZE_MAX - start) / sizeof(struct rpcrdma_write_chunk) || + (start + (sizeof(struct rpcrdma_write_chunk) * nchunks)) > end) { dprintk("svcrdma: ary=%p, wc_nchunks=%d, vaend=%p\n", - ary, ary->wc_nchunks, vaend); + ary, nchunks, vaend); return NULL; } - for (ch_no = 0; ch_no < ary->wc_nchunks; ch_no++) { - u64 ch_offset; - - ary->wc_array[ch_no].wc_target.rs_handle = - ntohl(ary->wc_array[ch_no].wc_target.rs_handle); - ary->wc_array[ch_no].wc_target.rs_length = - ntohl(ary->wc_array[ch_no].wc_target.rs_length); - va = (u32 *)&ary->wc_array[ch_no].wc_target.rs_offset; - xdr_decode_hyper(va, &ch_offset); - put_unaligned(ch_offset, (u64 *)va); - } - - return (u32 *)&ary->wc_array[ch_no]; + return (u32 *)&ary->wc_array[nchunks]; } int svc_rdma_xdr_decode_req(struct rpcrdma_msg **rdma_req, @@ -386,13 +359,14 @@ void svc_rdma_xdr_encode_reply_array(struct rpcrdma_write_array *ary, void svc_rdma_xdr_encode_array_chunk(struct rpcrdma_write_array *ary, int chunk_no, - u32 rs_handle, u64 rs_offset, + __be32 rs_handle, + __be64 rs_offset, u32 write_len) { struct rpcrdma_segment *seg = &ary->wc_array[chunk_no].wc_target; - seg->rs_handle = htonl(rs_handle); + seg->rs_handle = rs_handle; + seg->rs_offset = rs_offset; seg->rs_length = htonl(write_len); - xdr_encode_hyper((u32 *) &seg->rs_offset, rs_offset); } void svc_rdma_xdr_encode_reply_header(struct svcxprt_rdma *xprt, diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c index f92e37eb413..8f92a61ee2d 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c +++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c @@ -1,4 +1,5 @@ /* + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. 
* * This software is available to you under a choice of one of two @@ -69,7 +70,8 @@ static void rdma_build_arg_xdr(struct svc_rqst *rqstp, /* Set up the XDR head */ rqstp->rq_arg.head[0].iov_base = page_address(page); - rqstp->rq_arg.head[0].iov_len = min(byte_count, ctxt->sge[0].length); + rqstp->rq_arg.head[0].iov_len = + min_t(size_t, byte_count, ctxt->sge[0].length); rqstp->rq_arg.len = byte_count; rqstp->rq_arg.buflen = byte_count; @@ -85,11 +87,12 @@ static void rdma_build_arg_xdr(struct svc_rqst *rqstp, page = ctxt->pages[sge_no]; put_page(rqstp->rq_pages[sge_no]); rqstp->rq_pages[sge_no] = page; - bc -= min(bc, ctxt->sge[sge_no].length); + bc -= min_t(u32, bc, ctxt->sge[sge_no].length); rqstp->rq_arg.buflen += ctxt->sge[sge_no].length; sge_no++; } rqstp->rq_respages = &rqstp->rq_pages[sge_no]; + rqstp->rq_next_page = rqstp->rq_respages + 1; /* We should never run out of SGE because the limit is defined to * support the max allowed RPC data length @@ -112,284 +115,265 @@ static void rdma_build_arg_xdr(struct svc_rqst *rqstp, rqstp->rq_arg.tail[0].iov_len = 0; } -/* Encode a read-chunk-list as an array of IB SGE - * - * Assumptions: - * - chunk[0]->position points to pages[0] at an offset of 0 - * - pages[] is not physically or virtually contiguous and consists of - * PAGE_SIZE elements. - * - * Output: - * - sge array pointing into pages[] array. - * - chunk_sge array specifying sge index and count for each - * chunk in the read list - * - */ -static int map_read_chunks(struct svcxprt_rdma *xprt, - struct svc_rqst *rqstp, - struct svc_rdma_op_ctxt *head, - struct rpcrdma_msg *rmsgp, - struct svc_rdma_req_map *rpl_map, - struct svc_rdma_req_map *chl_map, - int ch_count, - int byte_count) +static int rdma_read_max_sge(struct svcxprt_rdma *xprt, int sge_count) { - int sge_no; - int sge_bytes; - int page_off; - int page_no; - int ch_bytes; - int ch_no; - struct rpcrdma_read_chunk *ch; + if (rdma_node_get_transport(xprt->sc_cm_id->device->node_type) == + RDMA_TRANSPORT_IWARP) + return 1; + else + return min_t(int, sge_count, xprt->sc_max_sge); +} - sge_no = 0; - page_no = 0; - page_off = 0; - ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; - ch_no = 0; - ch_bytes = ch->rc_target.rs_length; - head->arg.head[0] = rqstp->rq_arg.head[0]; - head->arg.tail[0] = rqstp->rq_arg.tail[0]; - head->arg.pages = &head->pages[head->count]; - head->hdr_count = head->count; /* save count of hdr pages */ - head->arg.page_base = 0; - head->arg.page_len = ch_bytes; - head->arg.len = rqstp->rq_arg.len + ch_bytes; - head->arg.buflen = rqstp->rq_arg.buflen + ch_bytes; - head->count++; - chl_map->ch[0].start = 0; - while (byte_count) { - rpl_map->sge[sge_no].iov_base = - page_address(rqstp->rq_arg.pages[page_no]) + page_off; - sge_bytes = min_t(int, PAGE_SIZE-page_off, ch_bytes); - rpl_map->sge[sge_no].iov_len = sge_bytes; - /* - * Don't bump head->count here because the same page - * may be used by multiple SGE. 
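The decode_write_list()/decode_reply_array() rework a few hunks earlier guards the chunk-array extent against both negative counts and multiplication overflow by dividing before multiplying. The check distilled into a standalone sketch; struct chunk stands in for struct rpcrdma_write_chunk, and uintptr_t and size_t are assumed to share a width, as on the kernel's targets:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

struct chunk { uint32_t handle, length; uint64_t offset; };

/* Validate that 'nchunks' elements of struct chunk starting at 'start'
 * fit entirely below 'end', without the multiplication wrapping. */
static bool chunks_fit(int32_t nchunks, uintptr_t start, uintptr_t end)
{
        if (nchunks < 0)
                return false;
        if ((uintptr_t)nchunks > (SIZE_MAX - start) / sizeof(struct chunk))
                return false;   /* start + n * size would wrap */
        return start + sizeof(struct chunk) * (uintptr_t)nchunks <= end;
}

int main(void)
{
        printf("fits: %d\n", chunks_fit(4, 4096, 4096 + 1024));  /* yes */
        printf("fits: %d\n", chunks_fit(-1, 4096, 65536));       /* no  */
        return 0;
}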
- */ - head->arg.pages[page_no] = rqstp->rq_arg.pages[page_no]; - rqstp->rq_respages = &rqstp->rq_arg.pages[page_no+1]; +typedef int (*rdma_reader_fn)(struct svcxprt_rdma *xprt, + struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *head, + int *page_no, + u32 *page_offset, + u32 rs_handle, + u32 rs_length, + u64 rs_offset, + int last); + +/* Issue an RDMA_READ using the local lkey to map the data sink */ +static int rdma_read_chunk_lcl(struct svcxprt_rdma *xprt, + struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *head, + int *page_no, + u32 *page_offset, + u32 rs_handle, + u32 rs_length, + u64 rs_offset, + int last) +{ + struct ib_send_wr read_wr; + int pages_needed = PAGE_ALIGN(*page_offset + rs_length) >> PAGE_SHIFT; + struct svc_rdma_op_ctxt *ctxt = svc_rdma_get_context(xprt); + int ret, read, pno; + u32 pg_off = *page_offset; + u32 pg_no = *page_no; - byte_count -= sge_bytes; - ch_bytes -= sge_bytes; - sge_no++; - /* - * If all bytes for this chunk have been mapped to an - * SGE, move to the next SGE - */ - if (ch_bytes == 0) { - chl_map->ch[ch_no].count = - sge_no - chl_map->ch[ch_no].start; - ch_no++; - ch++; - chl_map->ch[ch_no].start = sge_no; - ch_bytes = ch->rc_target.rs_length; - /* If bytes remaining account for next chunk */ - if (byte_count) { - head->arg.page_len += ch_bytes; - head->arg.len += ch_bytes; - head->arg.buflen += ch_bytes; - } + ctxt->direction = DMA_FROM_DEVICE; + ctxt->read_hdr = head; + pages_needed = + min_t(int, pages_needed, rdma_read_max_sge(xprt, pages_needed)); + read = min_t(int, pages_needed << PAGE_SHIFT, rs_length); + + for (pno = 0; pno < pages_needed; pno++) { + int len = min_t(int, rs_length, PAGE_SIZE - pg_off); + + head->arg.pages[pg_no] = rqstp->rq_arg.pages[pg_no]; + head->arg.page_len += len; + head->arg.len += len; + if (!pg_off) + head->count++; + rqstp->rq_respages = &rqstp->rq_arg.pages[pg_no+1]; + rqstp->rq_next_page = rqstp->rq_respages + 1; + ctxt->sge[pno].addr = + ib_dma_map_page(xprt->sc_cm_id->device, + head->arg.pages[pg_no], pg_off, + PAGE_SIZE - pg_off, + DMA_FROM_DEVICE); + ret = ib_dma_mapping_error(xprt->sc_cm_id->device, + ctxt->sge[pno].addr); + if (ret) + goto err; + atomic_inc(&xprt->sc_dma_used); + + /* The lkey here is either a local dma lkey or a dma_mr lkey */ + ctxt->sge[pno].lkey = xprt->sc_dma_lkey; + ctxt->sge[pno].length = len; + ctxt->count++; + + /* adjust offset and wrap to next page if needed */ + pg_off += len; + if (pg_off == PAGE_SIZE) { + pg_off = 0; + pg_no++; } - /* - * If this SGE consumed all of the page, move to the - * next page - */ - if ((sge_bytes + page_off) == PAGE_SIZE) { - page_no++; - page_off = 0; - /* - * If there are still bytes left to map, bump - * the page count - */ - if (byte_count) - head->count++; - } else - page_off += sge_bytes; + rs_length -= len; } - BUG_ON(byte_count != 0); - return sge_no; + + if (last && rs_length == 0) + set_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); + else + clear_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); + + memset(&read_wr, 0, sizeof(read_wr)); + read_wr.wr_id = (unsigned long)ctxt; + read_wr.opcode = IB_WR_RDMA_READ; + ctxt->wr_op = read_wr.opcode; + read_wr.send_flags = IB_SEND_SIGNALED; + read_wr.wr.rdma.rkey = rs_handle; + read_wr.wr.rdma.remote_addr = rs_offset; + read_wr.sg_list = ctxt->sge; + read_wr.num_sge = pages_needed; + + ret = svc_rdma_send(xprt, &read_wr); + if (ret) { + pr_err("svcrdma: Error %d posting RDMA_READ\n", ret); + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + goto err; + } + + /* return current location in page array */ + 
*page_no = pg_no; + *page_offset = pg_off; + ret = read; + atomic_inc(&rdma_stat_read); + return ret; + err: + svc_rdma_unmap_dma(ctxt); + svc_rdma_put_context(ctxt, 0); + return ret; } -/* Map a read-chunk-list to an XDR and fast register the page-list. - * - * Assumptions: - * - chunk[0] position points to pages[0] at an offset of 0 - * - pages[] will be made physically contiguous by creating a one-off memory - * region using the fastreg verb. - * - byte_count is # of bytes in read-chunk-list - * - ch_count is # of chunks in read-chunk-list - * - * Output: - * - sge array pointing into pages[] array. - * - chunk_sge array specifying sge index and count for each - * chunk in the read list - */ -static int fast_reg_read_chunks(struct svcxprt_rdma *xprt, +/* Issue an RDMA_READ using an FRMR to map the data sink */ +static int rdma_read_chunk_frmr(struct svcxprt_rdma *xprt, struct svc_rqst *rqstp, struct svc_rdma_op_ctxt *head, - struct rpcrdma_msg *rmsgp, - struct svc_rdma_req_map *rpl_map, - struct svc_rdma_req_map *chl_map, - int ch_count, - int byte_count) + int *page_no, + u32 *page_offset, + u32 rs_handle, + u32 rs_length, + u64 rs_offset, + int last) { - int page_no; - int ch_no; - u32 offset; - struct rpcrdma_read_chunk *ch; - struct svc_rdma_fastreg_mr *frmr; - int ret = 0; + struct ib_send_wr read_wr; + struct ib_send_wr inv_wr; + struct ib_send_wr fastreg_wr; + u8 key; + int pages_needed = PAGE_ALIGN(*page_offset + rs_length) >> PAGE_SHIFT; + struct svc_rdma_op_ctxt *ctxt = svc_rdma_get_context(xprt); + struct svc_rdma_fastreg_mr *frmr = svc_rdma_get_frmr(xprt); + int ret, read, pno; + u32 pg_off = *page_offset; + u32 pg_no = *page_no; - frmr = svc_rdma_get_frmr(xprt); if (IS_ERR(frmr)) return -ENOMEM; - head->frmr = frmr; - head->arg.head[0] = rqstp->rq_arg.head[0]; - head->arg.tail[0] = rqstp->rq_arg.tail[0]; - head->arg.pages = &head->pages[head->count]; - head->hdr_count = head->count; /* save count of hdr pages */ - head->arg.page_base = 0; - head->arg.page_len = byte_count; - head->arg.len = rqstp->rq_arg.len + byte_count; - head->arg.buflen = rqstp->rq_arg.buflen + byte_count; + ctxt->direction = DMA_FROM_DEVICE; + ctxt->frmr = frmr; + pages_needed = min_t(int, pages_needed, xprt->sc_frmr_pg_list_len); + read = min_t(int, pages_needed << PAGE_SHIFT, rs_length); - /* Fast register the page list */ - frmr->kva = page_address(rqstp->rq_arg.pages[0]); + frmr->kva = page_address(rqstp->rq_arg.pages[pg_no]); frmr->direction = DMA_FROM_DEVICE; frmr->access_flags = (IB_ACCESS_LOCAL_WRITE|IB_ACCESS_REMOTE_WRITE); - frmr->map_len = byte_count; - frmr->page_list_len = PAGE_ALIGN(byte_count) >> PAGE_SHIFT; - for (page_no = 0; page_no < frmr->page_list_len; page_no++) { - frmr->page_list->page_list[page_no] = - ib_dma_map_single(xprt->sc_cm_id->device, - page_address(rqstp->rq_arg.pages[page_no]), - PAGE_SIZE, DMA_FROM_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; + frmr->map_len = pages_needed << PAGE_SHIFT; + frmr->page_list_len = pages_needed; + + for (pno = 0; pno < pages_needed; pno++) { + int len = min_t(int, rs_length, PAGE_SIZE - pg_off); + + head->arg.pages[pg_no] = rqstp->rq_arg.pages[pg_no]; + head->arg.page_len += len; + head->arg.len += len; + if (!pg_off) + head->count++; + rqstp->rq_respages = &rqstp->rq_arg.pages[pg_no+1]; + rqstp->rq_next_page = rqstp->rq_respages + 1; + frmr->page_list->page_list[pno] = + ib_dma_map_page(xprt->sc_cm_id->device, + head->arg.pages[pg_no], 0, + PAGE_SIZE, 
DMA_FROM_DEVICE); + ret = ib_dma_mapping_error(xprt->sc_cm_id->device, + frmr->page_list->page_list[pno]); + if (ret) + goto err; atomic_inc(&xprt->sc_dma_used); - head->arg.pages[page_no] = rqstp->rq_arg.pages[page_no]; - } - head->count += page_no; - - /* rq_respages points one past arg pages */ - rqstp->rq_respages = &rqstp->rq_arg.pages[page_no]; - /* Create the reply and chunk maps */ - offset = 0; - ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; - for (ch_no = 0; ch_no < ch_count; ch_no++) { - rpl_map->sge[ch_no].iov_base = frmr->kva + offset; - rpl_map->sge[ch_no].iov_len = ch->rc_target.rs_length; - chl_map->ch[ch_no].count = 1; - chl_map->ch[ch_no].start = ch_no; - offset += ch->rc_target.rs_length; - ch++; + /* adjust offset and wrap to next page if needed */ + pg_off += len; + if (pg_off == PAGE_SIZE) { + pg_off = 0; + pg_no++; + } + rs_length -= len; } - ret = svc_rdma_fastreg(xprt, frmr); - if (ret) - goto fatal_err; - - return ch_no; - - fatal_err: - printk("svcrdma: error fast registering xdr for xprt %p", xprt); - svc_rdma_put_frmr(xprt, frmr); - return -EIO; -} - -static int rdma_set_ctxt_sge(struct svcxprt_rdma *xprt, - struct svc_rdma_op_ctxt *ctxt, - struct svc_rdma_fastreg_mr *frmr, - struct kvec *vec, - u64 *sgl_offset, - int count) -{ - int i; + if (last && rs_length == 0) + set_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); + else + clear_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); - ctxt->count = count; - ctxt->direction = DMA_FROM_DEVICE; - for (i = 0; i < count; i++) { - ctxt->sge[i].length = 0; /* in case map fails */ - if (!frmr) { - ctxt->sge[i].addr = - ib_dma_map_single(xprt->sc_cm_id->device, - vec[i].iov_base, - vec[i].iov_len, - DMA_FROM_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - ctxt->sge[i].addr)) - return -EINVAL; - ctxt->sge[i].lkey = xprt->sc_dma_lkey; - atomic_inc(&xprt->sc_dma_used); - } else { - ctxt->sge[i].addr = (unsigned long)vec[i].iov_base; - ctxt->sge[i].lkey = frmr->mr->lkey; - } - ctxt->sge[i].length = vec[i].iov_len; - *sgl_offset = *sgl_offset + vec[i].iov_len; + /* Bump the key */ + key = (u8)(frmr->mr->lkey & 0x000000FF); + ib_update_fast_reg_key(frmr->mr, ++key); + + ctxt->sge[0].addr = (unsigned long)frmr->kva + *page_offset; + ctxt->sge[0].lkey = frmr->mr->lkey; + ctxt->sge[0].length = read; + ctxt->count = 1; + ctxt->read_hdr = head; + + /* Prepare FASTREG WR */ + memset(&fastreg_wr, 0, sizeof(fastreg_wr)); + fastreg_wr.opcode = IB_WR_FAST_REG_MR; + fastreg_wr.send_flags = IB_SEND_SIGNALED; + fastreg_wr.wr.fast_reg.iova_start = (unsigned long)frmr->kva; + fastreg_wr.wr.fast_reg.page_list = frmr->page_list; + fastreg_wr.wr.fast_reg.page_list_len = frmr->page_list_len; + fastreg_wr.wr.fast_reg.page_shift = PAGE_SHIFT; + fastreg_wr.wr.fast_reg.length = frmr->map_len; + fastreg_wr.wr.fast_reg.access_flags = frmr->access_flags; + fastreg_wr.wr.fast_reg.rkey = frmr->mr->lkey; + fastreg_wr.next = &read_wr; + + /* Prepare RDMA_READ */ + memset(&read_wr, 0, sizeof(read_wr)); + read_wr.send_flags = IB_SEND_SIGNALED; + read_wr.wr.rdma.rkey = rs_handle; + read_wr.wr.rdma.remote_addr = rs_offset; + read_wr.sg_list = ctxt->sge; + read_wr.num_sge = 1; + if (xprt->sc_dev_caps & SVCRDMA_DEVCAP_READ_W_INV) { + read_wr.opcode = IB_WR_RDMA_READ_WITH_INV; + read_wr.wr_id = (unsigned long)ctxt; + read_wr.ex.invalidate_rkey = ctxt->frmr->mr->lkey; + } else { + read_wr.opcode = IB_WR_RDMA_READ; + read_wr.next = &inv_wr; + /* Prepare invalidate */ + memset(&inv_wr, 0, sizeof(inv_wr)); + inv_wr.wr_id = (unsigned long)ctxt; + 
inv_wr.opcode = IB_WR_LOCAL_INV; + inv_wr.send_flags = IB_SEND_SIGNALED | IB_SEND_FENCE; + inv_wr.ex.invalidate_rkey = frmr->mr->lkey; + } + ctxt->wr_op = read_wr.opcode; + + /* Post the chain */ + ret = svc_rdma_send(xprt, &fastreg_wr); + if (ret) { + pr_err("svcrdma: Error %d posting RDMA_READ\n", ret); + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + goto err; } - return 0; -} -static int rdma_read_max_sge(struct svcxprt_rdma *xprt, int sge_count) -{ - if ((rdma_node_get_transport(xprt->sc_cm_id->device->node_type) == - RDMA_TRANSPORT_IWARP) && - sge_count > 1) - return 1; - else - return min_t(int, sge_count, xprt->sc_max_sge); + /* return current location in page array */ + *page_no = pg_no; + *page_offset = pg_off; + ret = read; + atomic_inc(&rdma_stat_read); + return ret; + err: + svc_rdma_unmap_dma(ctxt); + svc_rdma_put_context(ctxt, 0); + svc_rdma_put_frmr(xprt, frmr); + return ret; } -/* - * Use RDMA_READ to read data from the advertised client buffer into the - * XDR stream starting at rq_arg.head[0].iov_base. - * Each chunk in the array - * contains the following fields: - * discrim - '1', This isn't used for data placement - * position - The xdr stream offset (the same for every chunk) - * handle - RMR for client memory region - * length - data transfer length - * offset - 64 bit tagged offset in remote memory region - * - * On our side, we need to read into a pagelist. The first page immediately - * follows the RPC header. - * - * This function returns: - * 0 - No error and no read-list found. - * - * 1 - Successful read-list processing. The data is not yet in - * the pagelist and therefore the RPC request must be deferred. The - * I/O completion will enqueue the transport again and - * svc_rdma_recvfrom will complete the request. - * - * <0 - Error processing/posting read-list. - * - * NOTE: The ctxt must not be touched after the last WR has been posted - * because the I/O completion processing may occur on another - * processor and free / modify the context. Ne touche pas! - */ -static int rdma_read_xdr(struct svcxprt_rdma *xprt, - struct rpcrdma_msg *rmsgp, - struct svc_rqst *rqstp, - struct svc_rdma_op_ctxt *hdr_ctxt) +static int rdma_read_chunks(struct svcxprt_rdma *xprt, + struct rpcrdma_msg *rmsgp, + struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *head) { - struct ib_send_wr read_wr; - struct ib_send_wr inv_wr; - int err = 0; - int ch_no; - int ch_count; - int byte_count; - int sge_count; - u64 sgl_offset; + int page_no, ch_count, ret; struct rpcrdma_read_chunk *ch; - struct svc_rdma_op_ctxt *ctxt = NULL; - struct svc_rdma_req_map *rpl_map; - struct svc_rdma_req_map *chl_map; + u32 page_offset, byte_count; + u64 rs_offset; + rdma_reader_fn reader; /* If no read list is present, return 0 */ ch = svc_rdma_get_read_chunk(rmsgp); @@ -400,127 +384,55 @@ static int rdma_read_xdr(struct svcxprt_rdma *xprt, if (ch_count > RPCSVC_MAXPAGES) return -EINVAL; - /* Allocate temporary reply and chunk maps */ - rpl_map = svc_rdma_get_req_map(); - chl_map = svc_rdma_get_req_map(); + /* The request is completed when the RDMA_READs complete. The + * head context keeps all the pages that comprise the + * request. 
+ */ + head->arg.head[0] = rqstp->rq_arg.head[0]; + head->arg.tail[0] = rqstp->rq_arg.tail[0]; + head->arg.pages = &head->pages[head->count]; + head->hdr_count = head->count; + head->arg.page_base = 0; + head->arg.page_len = 0; + head->arg.len = rqstp->rq_arg.len; + head->arg.buflen = rqstp->rq_arg.buflen; - if (!xprt->sc_frmr_pg_list_len) - sge_count = map_read_chunks(xprt, rqstp, hdr_ctxt, rmsgp, - rpl_map, chl_map, ch_count, - byte_count); + /* Use FRMR if supported */ + if (xprt->sc_dev_caps & SVCRDMA_DEVCAP_FAST_REG) + reader = rdma_read_chunk_frmr; else - sge_count = fast_reg_read_chunks(xprt, rqstp, hdr_ctxt, rmsgp, - rpl_map, chl_map, ch_count, - byte_count); - if (sge_count < 0) { - err = -EIO; - goto out; - } - - sgl_offset = 0; - ch_no = 0; + reader = rdma_read_chunk_lcl; + page_no = 0; page_offset = 0; for (ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; - ch->rc_discrim != 0; ch++, ch_no++) { -next_sge: - ctxt = svc_rdma_get_context(xprt); - ctxt->direction = DMA_FROM_DEVICE; - ctxt->frmr = hdr_ctxt->frmr; - ctxt->read_hdr = NULL; - clear_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); - clear_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); - - /* Prepare READ WR */ - memset(&read_wr, 0, sizeof read_wr); - read_wr.wr_id = (unsigned long)ctxt; - read_wr.opcode = IB_WR_RDMA_READ; - ctxt->wr_op = read_wr.opcode; - read_wr.send_flags = IB_SEND_SIGNALED; - read_wr.wr.rdma.rkey = ch->rc_target.rs_handle; - read_wr.wr.rdma.remote_addr = - get_unaligned(&(ch->rc_target.rs_offset)) + - sgl_offset; - read_wr.sg_list = ctxt->sge; - read_wr.num_sge = - rdma_read_max_sge(xprt, chl_map->ch[ch_no].count); - err = rdma_set_ctxt_sge(xprt, ctxt, hdr_ctxt->frmr, - &rpl_map->sge[chl_map->ch[ch_no].start], - &sgl_offset, - read_wr.num_sge); - if (err) { - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_context(ctxt, 0); - goto out; - } - if (((ch+1)->rc_discrim == 0) && - (read_wr.num_sge == chl_map->ch[ch_no].count)) { - /* - * Mark the last RDMA_READ with a bit to - * indicate all RPC data has been fetched from - * the client and the RPC needs to be enqueued. - */ - set_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); - if (hdr_ctxt->frmr) { - set_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); - /* - * Invalidate the local MR used to map the data - * sink. 
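rdma_read_chunks() above picks a reader once per request, based on whether the device advertises fast memory registration, and then drives every chunk segment through the same rdma_reader_fn pointer. A miniature of that dispatch; the two stub readers only print what they would post:

#include <stdio.h>

typedef int (*reader_fn)(unsigned int rkey, unsigned long len);

static int read_lcl(unsigned int rkey, unsigned long len)
{
        printf("RDMA_READ via local lkey: rkey=%u len=%lu\n", rkey, len);
        return (int)len;        /* bytes issued */
}

static int read_frmr(unsigned int rkey, unsigned long len)
{
        printf("FASTREG + RDMA_READ: rkey=%u len=%lu\n", rkey, len);
        return (int)len;
}

int main(void)
{
        int has_fast_reg = 1;   /* would come from device capabilities */
        reader_fn reader = has_fast_reg ? read_frmr : read_lcl;

        reader(0x1234, 8192);
        return 0;
}

Selecting the implementation once and looping over a single calling convention is what lets the old map/fast-reg duplication collapse into the shared loop above.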
- */ - if (xprt->sc_dev_caps & - SVCRDMA_DEVCAP_READ_W_INV) { - read_wr.opcode = - IB_WR_RDMA_READ_WITH_INV; - ctxt->wr_op = read_wr.opcode; - read_wr.ex.invalidate_rkey = - ctxt->frmr->mr->lkey; - } else { - /* Prepare INVALIDATE WR */ - memset(&inv_wr, 0, sizeof inv_wr); - inv_wr.opcode = IB_WR_LOCAL_INV; - inv_wr.send_flags = IB_SEND_SIGNALED; - inv_wr.ex.invalidate_rkey = - hdr_ctxt->frmr->mr->lkey; - read_wr.next = &inv_wr; - } - } - ctxt->read_hdr = hdr_ctxt; + ch->rc_discrim != 0; ch++) { + + xdr_decode_hyper((__be32 *)&ch->rc_target.rs_offset, + &rs_offset); + byte_count = ntohl(ch->rc_target.rs_length); + + while (byte_count > 0) { + ret = reader(xprt, rqstp, head, + &page_no, &page_offset, + ntohl(ch->rc_target.rs_handle), + byte_count, rs_offset, + ((ch+1)->rc_discrim == 0) /* last */ + ); + if (ret < 0) + goto err; + byte_count -= ret; + rs_offset += ret; + head->arg.buflen += ret; } - /* Post the read */ - err = svc_rdma_send(xprt, &read_wr); - if (err) { - printk(KERN_ERR "svcrdma: Error %d posting RDMA_READ\n", - err); - set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); - svc_rdma_put_context(ctxt, 0); - goto out; - } - atomic_inc(&rdma_stat_read); - - if (read_wr.num_sge < chl_map->ch[ch_no].count) { - chl_map->ch[ch_no].count -= read_wr.num_sge; - chl_map->ch[ch_no].start += read_wr.num_sge; - goto next_sge; - } - sgl_offset = 0; - err = 1; } - - out: - svc_rdma_put_req_map(rpl_map); - svc_rdma_put_req_map(chl_map); - + ret = 1; + err: /* Detach arg pages. svc_recv will replenish them */ - for (ch_no = 0; &rqstp->rq_pages[ch_no] < rqstp->rq_respages; ch_no++) - rqstp->rq_pages[ch_no] = NULL; - - /* - * Detach res pages. svc_release must see a resused count of - * zero or it will attempt to put them. - */ - while (rqstp->rq_resused) - rqstp->rq_respages[--rqstp->rq_resused] = NULL; + for (page_no = 0; + &rqstp->rq_pages[page_no] < rqstp->rq_respages; page_no++) + rqstp->rq_pages[page_no] = NULL; - return err; + return ret; } static int rdma_read_complete(struct svc_rqst *rqstp, @@ -543,7 +455,7 @@ static int rdma_read_complete(struct svc_rqst *rqstp, /* rq_respages starts after the last arg page */ rqstp->rq_respages = &rqstp->rq_arg.pages[page_no]; - rqstp->rq_resused = 0; + rqstp->rq_next_page = rqstp->rq_respages + 1; /* Rebuild rq_arg head and tail. */ rqstp->rq_arg.head[0] = head->arg.head[0]; @@ -566,7 +478,6 @@ static int rdma_read_complete(struct svc_rqst *rqstp, ret, rqstp->rq_arg.len, rqstp->rq_arg.head[0].iov_base, rqstp->rq_arg.head[0].iov_len); - svc_xprt_received(rqstp->rq_xprt); return ret; } @@ -593,13 +504,9 @@ int svc_rdma_recvfrom(struct svc_rqst *rqstp) struct svc_rdma_op_ctxt, dto_q); list_del_init(&ctxt->dto_q); - } - if (ctxt) { spin_unlock_bh(&rdma_xprt->sc_rq_dto_lock); return rdma_read_complete(rqstp, ctxt); - } - - if (!list_empty(&rdma_xprt->sc_rq_dto_q)) { + } else if (!list_empty(&rdma_xprt->sc_rq_dto_q)) { ctxt = list_entry(rdma_xprt->sc_rq_dto_q.next, struct svc_rdma_op_ctxt, dto_q); @@ -619,7 +526,6 @@ int svc_rdma_recvfrom(struct svc_rqst *rqstp) if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) goto close_out; - BUG_ON(ret); goto out; } dprintk("svcrdma: processing ctxt=%p on xprt=%p, rqstp=%p, status=%d\n", @@ -642,12 +548,11 @@ int svc_rdma_recvfrom(struct svc_rqst *rqstp) } /* Read read-list data. */ - ret = rdma_read_xdr(rdma_xprt, rmsgp, rqstp, ctxt); + ret = rdma_read_chunks(rdma_xprt, rmsgp, rqstp, ctxt); if (ret > 0) { /* read-list posted, defer until data received from client. 
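Each read chunk's remote offset arrives as an XDR hyper, i.e. two big-endian 32-bit words, which the loop above recombines with xdr_decode_hyper() before handing rs_offset to the reader. A userspace equivalent of the decode; decode_hyper() is an invented name:

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>

/* Combine two big-endian 32-bit words into a host-order 64-bit value,
 * mirroring what xdr_decode_hyper() does with rs_offset. */
static uint64_t decode_hyper(const uint32_t *p)
{
        return ((uint64_t)ntohl(p[0]) << 32) | ntohl(p[1]);
}

int main(void)
{
        uint32_t wire[2] = { htonl(0x00000001), htonl(0x00020000) };

        printf("offset=0x%llx\n",
               (unsigned long long)decode_hyper(wire));
        return 0;
}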
*/ goto defer; - } - if (ret < 0) { + } else if (ret < 0) { /* Post of read-list failed, free context. */ svc_rdma_put_context(ctxt, 1); return 0; @@ -665,7 +570,6 @@ int svc_rdma_recvfrom(struct svc_rqst *rqstp) rqstp->rq_arg.head[0].iov_len); rqstp->rq_prot = IPPROTO_MAX; svc_xprt_copy_addrs(rqstp, xprt); - svc_xprt_received(xprt); return ret; close_out: @@ -678,6 +582,5 @@ int svc_rdma_recvfrom(struct svc_rqst *rqstp) */ set_bit(XPT_CLOSE, &xprt->xpt_flags); defer: - svc_xprt_received(xprt); return 0; } diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c index b15e1ebb2bf..49fd21a5c21 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_sendto.c +++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c @@ -1,4 +1,5 @@ /* + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -49,146 +50,6 @@ #define RPCDBG_FACILITY RPCDBG_SVCXPRT -/* Encode an XDR as an array of IB SGE - * - * Assumptions: - * - head[0] is physically contiguous. - * - tail[0] is physically contiguous. - * - pages[] is not physically or virtually contiguous and consists of - * PAGE_SIZE elements. - * - * Output: - * SGE[0] reserved for RCPRDMA header - * SGE[1] data from xdr->head[] - * SGE[2..sge_count-2] data from xdr->pages[] - * SGE[sge_count-1] data from xdr->tail. - * - * The max SGE we need is the length of the XDR / pagesize + one for - * head + one for tail + one for RPCRDMA header. Since RPCSVC_MAXPAGES - * reserves a page for both the request and the reply header, and this - * array is only concerned with the reply we are assured that we have - * on extra page for the RPCRMDA header. - */ -static int fast_reg_xdr(struct svcxprt_rdma *xprt, - struct xdr_buf *xdr, - struct svc_rdma_req_map *vec) -{ - int sge_no; - u32 sge_bytes; - u32 page_bytes; - u32 page_off; - int page_no = 0; - u8 *frva; - struct svc_rdma_fastreg_mr *frmr; - - frmr = svc_rdma_get_frmr(xprt); - if (IS_ERR(frmr)) - return -ENOMEM; - vec->frmr = frmr; - - /* Skip the RPCRDMA header */ - sge_no = 1; - - /* Map the head. 
*/ - frva = (void *)((unsigned long)(xdr->head[0].iov_base) & PAGE_MASK); - vec->sge[sge_no].iov_base = xdr->head[0].iov_base; - vec->sge[sge_no].iov_len = xdr->head[0].iov_len; - vec->count = 2; - sge_no++; - - /* Build the FRMR */ - frmr->kva = frva; - frmr->direction = DMA_TO_DEVICE; - frmr->access_flags = 0; - frmr->map_len = PAGE_SIZE; - frmr->page_list_len = 1; - frmr->page_list->page_list[page_no] = - ib_dma_map_single(xprt->sc_cm_id->device, - (void *)xdr->head[0].iov_base, - PAGE_SIZE, DMA_TO_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; - atomic_inc(&xprt->sc_dma_used); - - page_off = xdr->page_base; - page_bytes = xdr->page_len + page_off; - if (!page_bytes) - goto encode_tail; - - /* Map the pages */ - vec->sge[sge_no].iov_base = frva + frmr->map_len + page_off; - vec->sge[sge_no].iov_len = page_bytes; - sge_no++; - while (page_bytes) { - struct page *page; - - page = xdr->pages[page_no++]; - sge_bytes = min_t(u32, page_bytes, (PAGE_SIZE - page_off)); - page_bytes -= sge_bytes; - - frmr->page_list->page_list[page_no] = - ib_dma_map_single(xprt->sc_cm_id->device, - page_address(page), - PAGE_SIZE, DMA_TO_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; - - atomic_inc(&xprt->sc_dma_used); - page_off = 0; /* reset for next time through loop */ - frmr->map_len += PAGE_SIZE; - frmr->page_list_len++; - } - vec->count++; - - encode_tail: - /* Map tail */ - if (0 == xdr->tail[0].iov_len) - goto done; - - vec->count++; - vec->sge[sge_no].iov_len = xdr->tail[0].iov_len; - - if (((unsigned long)xdr->tail[0].iov_base & PAGE_MASK) == - ((unsigned long)xdr->head[0].iov_base & PAGE_MASK)) { - /* - * If head and tail use the same page, we don't need - * to map it again. 
- */ - vec->sge[sge_no].iov_base = xdr->tail[0].iov_base; - } else { - void *va; - - /* Map another page for the tail */ - page_off = (unsigned long)xdr->tail[0].iov_base & ~PAGE_MASK; - va = (void *)((unsigned long)xdr->tail[0].iov_base & PAGE_MASK); - vec->sge[sge_no].iov_base = frva + frmr->map_len + page_off; - - frmr->page_list->page_list[page_no] = - ib_dma_map_single(xprt->sc_cm_id->device, va, PAGE_SIZE, - DMA_TO_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; - atomic_inc(&xprt->sc_dma_used); - frmr->map_len += PAGE_SIZE; - frmr->page_list_len++; - } - - done: - if (svc_rdma_fastreg(xprt, frmr)) - goto fatal_err; - - return 0; - - fatal_err: - printk("svcrdma: Error fast registering memory for xprt %p\n", xprt); - vec->frmr = NULL; - svc_rdma_put_frmr(xprt, frmr); - return -EIO; -} - static int map_xdr(struct svcxprt_rdma *xprt, struct xdr_buf *xdr, struct svc_rdma_req_map *vec) @@ -202,9 +63,6 @@ static int map_xdr(struct svcxprt_rdma *xprt, BUG_ON(xdr->len != (xdr->head[0].iov_len + xdr->page_len + xdr->tail[0].iov_len)); - if (xprt->sc_frmr_pg_list_len) - return fast_reg_xdr(xprt, xdr, vec); - /* Skip the first sge, this is for the RPCRDMA header */ sge_no = 1; @@ -245,9 +103,37 @@ static int map_xdr(struct svcxprt_rdma *xprt, return 0; } +static dma_addr_t dma_map_xdr(struct svcxprt_rdma *xprt, + struct xdr_buf *xdr, + u32 xdr_off, size_t len, int dir) +{ + struct page *page; + dma_addr_t dma_addr; + if (xdr_off < xdr->head[0].iov_len) { + /* This offset is in the head */ + xdr_off += (unsigned long)xdr->head[0].iov_base & ~PAGE_MASK; + page = virt_to_page(xdr->head[0].iov_base); + } else { + xdr_off -= xdr->head[0].iov_len; + if (xdr_off < xdr->page_len) { + /* This offset is in the page list */ + xdr_off += xdr->page_base; + page = xdr->pages[xdr_off >> PAGE_SHIFT]; + xdr_off &= ~PAGE_MASK; + } else { + /* This offset is in the tail */ + xdr_off -= xdr->page_len; + xdr_off += (unsigned long) + xdr->tail[0].iov_base & ~PAGE_MASK; + page = virt_to_page(xdr->tail[0].iov_base); + } + } + dma_addr = ib_dma_map_page(xprt->sc_cm_id->device, page, xdr_off, + min_t(size_t, PAGE_SIZE, len), dir); + return dma_addr; +} + /* Assumptions: - * - We are using FRMR - * - or - * - The specified write_len can be represented in sc_max_sge * PAGE_SIZE */ static int send_write(struct svcxprt_rdma *xprt, struct svc_rqst *rqstp, @@ -291,24 +177,16 @@ static int send_write(struct svcxprt_rdma *xprt, struct svc_rqst *rqstp, sge_bytes = min_t(size_t, bc, vec->sge[xdr_sge_no].iov_len-sge_off); sge[sge_no].length = sge_bytes; - if (!vec->frmr) { - sge[sge_no].addr = - ib_dma_map_single(xprt->sc_cm_id->device, - (void *) - vec->sge[xdr_sge_no].iov_base + sge_off, - sge_bytes, DMA_TO_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - sge[sge_no].addr)) - goto err; - atomic_inc(&xprt->sc_dma_used); - sge[sge_no].lkey = xprt->sc_dma_lkey; - } else { - sge[sge_no].addr = (unsigned long) - vec->sge[xdr_sge_no].iov_base + sge_off; - sge[sge_no].lkey = vec->frmr->mr->lkey; - } + sge[sge_no].addr = + dma_map_xdr(xprt, &rqstp->rq_res, xdr_off, + sge_bytes, DMA_TO_DEVICE); + xdr_off += sge_bytes; + if (ib_dma_mapping_error(xprt->sc_cm_id->device, + sge[sge_no].addr)) + goto err; + atomic_inc(&xprt->sc_dma_used); + sge[sge_no].lkey = xprt->sc_dma_lkey; ctxt->count++; - ctxt->frmr = vec->frmr; sge_off = 0; sge_no++; xdr_sge_no++; @@ -333,6 +211,7 @@ static int send_write(struct svcxprt_rdma *xprt, struct svc_rqst *rqstp, goto err; 
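dma_map_xdr(), added above, turns a logical offset into an xdr_buf into a (page, page-offset) pair by checking the head, then the page list, then the tail, so send_write() no longer needs the removed FRMR mapping path. A self-contained sketch of the same lookup, with xdr_buf reduced to the three lengths the search needs (the sizes in main() are made up):

#include <stdint.h>
#include <stdio.h>

#define PAGE_SIZE  4096u
#define PAGE_SHIFT 12

struct xdr_layout {
    uint32_t head_len;    /* bytes in head[0] */
    uint32_t page_base;   /* offset into the first page */
    uint32_t page_len;    /* bytes in the page list */
};

static void locate(const struct xdr_layout *x, uint32_t off)
{
    if (off < x->head_len) {                 /* offset is in the head */
        printf("offset %u: head, byte %u\n", off, off);
        return;
    }
    off -= x->head_len;
    if (off < x->page_len) {                 /* offset is in the page list */
        off += x->page_base;
        printf("offset: pagelist, page %u, byte %u\n",
               off >> PAGE_SHIFT, off & (PAGE_SIZE - 1));
        return;
    }
    printf("offset: tail, byte %u\n", off - x->page_len);
}

int main(void)
{
    struct xdr_layout x = { .head_len = 200,
                            .page_base = 100, .page_len = 8192 };
    locate(&x, 150);     /* head */
    locate(&x, 5000);    /* page 1 of the page list, byte 804 */
    locate(&x, 8500);    /* tail, byte 108 */
    return 0;
}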
return 0; err: + svc_rdma_unmap_dma(ctxt); svc_rdma_put_context(ctxt, 0); /* Fatal error, close transport */ return -EIO; @@ -360,10 +239,7 @@ static int send_write_chunks(struct svcxprt_rdma *xprt, res_ary = (struct rpcrdma_write_array *) &rdma_resp->rm_body.rm_chunks[1]; - if (vec->frmr) - max_write = vec->frmr->map_len; - else - max_write = xprt->sc_max_sge * PAGE_SIZE; + max_write = xprt->sc_max_sge * PAGE_SIZE; /* Write chunks start at the pagelist */ for (xdr_off = rqstp->rq_res.head[0].iov_len, chunk_no = 0; @@ -373,21 +249,21 @@ static int send_write_chunks(struct svcxprt_rdma *xprt, u64 rs_offset; arg_ch = &arg_ary->wc_array[chunk_no].wc_target; - write_len = min(xfer_len, arg_ch->rs_length); + write_len = min(xfer_len, ntohl(arg_ch->rs_length)); /* Prepare the response chunk given the length actually * written */ - rs_offset = get_unaligned(&(arg_ch->rs_offset)); + xdr_decode_hyper((__be32 *)&arg_ch->rs_offset, &rs_offset); svc_rdma_xdr_encode_array_chunk(res_ary, chunk_no, - arg_ch->rs_handle, - rs_offset, - write_len); + arg_ch->rs_handle, + arg_ch->rs_offset, + write_len); chunk_off = 0; while (write_len) { int this_write; this_write = min(write_len, max_write); ret = send_write(xprt, rqstp, - arg_ch->rs_handle, + ntohl(arg_ch->rs_handle), rs_offset + chunk_off, xdr_off, this_write, @@ -421,6 +297,7 @@ static int send_reply_chunks(struct svcxprt_rdma *xprt, u32 xdr_off; int chunk_no; int chunk_off; + int nchunks; struct rpcrdma_segment *ch; struct rpcrdma_write_array *arg_ary; struct rpcrdma_write_array *res_ary; @@ -434,32 +311,30 @@ static int send_reply_chunks(struct svcxprt_rdma *xprt, res_ary = (struct rpcrdma_write_array *) &rdma_resp->rm_body.rm_chunks[2]; - if (vec->frmr) - max_write = vec->frmr->map_len; - else - max_write = xprt->sc_max_sge * PAGE_SIZE; + max_write = xprt->sc_max_sge * PAGE_SIZE; /* xdr offset starts at RPC message */ + nchunks = ntohl(arg_ary->wc_nchunks); for (xdr_off = 0, chunk_no = 0; - xfer_len && chunk_no < arg_ary->wc_nchunks; + xfer_len && chunk_no < nchunks; chunk_no++) { u64 rs_offset; ch = &arg_ary->wc_array[chunk_no].wc_target; - write_len = min(xfer_len, ch->rs_length); + write_len = min(xfer_len, htonl(ch->rs_length)); /* Prepare the reply chunk given the length actually * written */ - rs_offset = get_unaligned(&(ch->rs_offset)); + xdr_decode_hyper((__be32 *)&ch->rs_offset, &rs_offset); svc_rdma_xdr_encode_array_chunk(res_ary, chunk_no, - ch->rs_handle, rs_offset, - write_len); + ch->rs_handle, ch->rs_offset, + write_len); chunk_off = 0; while (write_len) { int this_write; this_write = min(write_len, max_write); ret = send_write(xprt, rqstp, - ch->rs_handle, + ntohl(ch->rs_handle), rs_offset + chunk_off, xdr_off, this_write, @@ -494,7 +369,8 @@ static int send_reply_chunks(struct svcxprt_rdma *xprt, * In all three cases, this function prepares the RPCRDMA header in * sge[0], the 'type' parameter indicates the type to place in the * RPCRDMA header, and the 'byte_count' field indicates how much of - * the XDR to include in this RDMA_SEND. + * the XDR to include in this RDMA_SEND. NB: The offset of the payload + * to send is zero in the XDR. */ static int send_reply(struct svcxprt_rdma *rdma, struct svc_rqst *rqstp, @@ -505,10 +381,10 @@ static int send_reply(struct svcxprt_rdma *rdma, int byte_count) { struct ib_send_wr send_wr; - struct ib_send_wr inv_wr; int sge_no; int sge_bytes; int page_no; + int pages; int ret; /* Post a recv buffer to handle another request. 
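With the FRMR path gone, send_write_chunks() above caps every RDMA_WRITE at max_write = sc_max_sge * PAGE_SIZE rather than the FRMR map length. A sketch of the resulting splitting arithmetic, with made-up sizes:

#include <stdio.h>

int main(void)
{
    unsigned int max_write = 4 * 4096;   /* e.g. 4 SGEs of one page each */
    unsigned int write_len = 70000;      /* chunk bytes left to send */
    unsigned int chunk_off = 0;

    while (write_len) {
        unsigned int this_write =
            write_len < max_write ? write_len : max_write;
        printf("RDMA_WRITE %u bytes at chunk offset %u\n",
               this_write, chunk_off);
        chunk_off += this_write;
        write_len -= this_write;
    }
    return 0;    /* 70000 -> 16384 x 4 + 4464 */
}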
*/ @@ -518,7 +394,6 @@ static int send_reply(struct svcxprt_rdma *rdma, "svcrdma: could not post a receive buffer, err=%d." "Closing transport %p.\n", ret, rdma); set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); - svc_rdma_put_frmr(rdma, vec->frmr); svc_rdma_put_context(ctxt, 0); return -ENOTCONN; } @@ -526,43 +401,33 @@ static int send_reply(struct svcxprt_rdma *rdma, /* Prepare the context */ ctxt->pages[0] = page; ctxt->count = 1; - ctxt->frmr = vec->frmr; - if (vec->frmr) - set_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); - else - clear_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); /* Prepare the SGE for the RPCRDMA Header */ ctxt->sge[0].lkey = rdma->sc_dma_lkey; ctxt->sge[0].length = svc_rdma_xdr_get_reply_hdr_len(rdma_resp); ctxt->sge[0].addr = - ib_dma_map_single(rdma->sc_cm_id->device, page_address(page), - ctxt->sge[0].length, DMA_TO_DEVICE); + ib_dma_map_page(rdma->sc_cm_id->device, page, 0, + ctxt->sge[0].length, DMA_TO_DEVICE); if (ib_dma_mapping_error(rdma->sc_cm_id->device, ctxt->sge[0].addr)) goto err; atomic_inc(&rdma->sc_dma_used); ctxt->direction = DMA_TO_DEVICE; - /* Determine how many of our SGE are to be transmitted */ + /* Map the payload indicated by 'byte_count' */ for (sge_no = 1; byte_count && sge_no < vec->count; sge_no++) { + int xdr_off = 0; sge_bytes = min_t(size_t, vec->sge[sge_no].iov_len, byte_count); byte_count -= sge_bytes; - if (!vec->frmr) { - ctxt->sge[sge_no].addr = - ib_dma_map_single(rdma->sc_cm_id->device, - vec->sge[sge_no].iov_base, - sge_bytes, DMA_TO_DEVICE); - if (ib_dma_mapping_error(rdma->sc_cm_id->device, - ctxt->sge[sge_no].addr)) - goto err; - atomic_inc(&rdma->sc_dma_used); - ctxt->sge[sge_no].lkey = rdma->sc_dma_lkey; - } else { - ctxt->sge[sge_no].addr = (unsigned long) - vec->sge[sge_no].iov_base; - ctxt->sge[sge_no].lkey = vec->frmr->mr->lkey; - } + ctxt->sge[sge_no].addr = + dma_map_xdr(rdma, &rqstp->rq_res, xdr_off, + sge_bytes, DMA_TO_DEVICE); + xdr_off += sge_bytes; + if (ib_dma_mapping_error(rdma->sc_cm_id->device, + ctxt->sge[sge_no].addr)) + goto err; + atomic_inc(&rdma->sc_dma_used); + ctxt->sge[sge_no].lkey = rdma->sc_dma_lkey; ctxt->sge[sge_no].length = sge_bytes; } BUG_ON(byte_count != 0); @@ -571,7 +436,8 @@ static int send_reply(struct svcxprt_rdma *rdma, * respages array. They are our pages until the I/O * completes. 
*/ - for (page_no = 0; page_no < rqstp->rq_resused; page_no++) { + pages = rqstp->rq_next_page - rqstp->rq_respages; + for (page_no = 0; page_no < pages; page_no++) { ctxt->pages[page_no+1] = rqstp->rq_respages[page_no]; ctxt->count++; rqstp->rq_respages[page_no] = NULL; @@ -583,6 +449,8 @@ static int send_reply(struct svcxprt_rdma *rdma, if (page_no+1 >= sge_no) ctxt->sge[page_no+1].length = 0; } + rqstp->rq_next_page = rqstp->rq_respages + 1; + BUG_ON(sge_no > rdma->sc_max_sge); memset(&send_wr, 0, sizeof send_wr); ctxt->wr_op = IB_WR_SEND; @@ -591,15 +459,6 @@ static int send_reply(struct svcxprt_rdma *rdma, send_wr.num_sge = sge_no; send_wr.opcode = IB_WR_SEND; send_wr.send_flags = IB_SEND_SIGNALED; - if (vec->frmr) { - /* Prepare INVALIDATE WR */ - memset(&inv_wr, 0, sizeof inv_wr); - inv_wr.opcode = IB_WR_LOCAL_INV; - inv_wr.send_flags = IB_SEND_SIGNALED; - inv_wr.ex.invalidate_rkey = - vec->frmr->mr->lkey; - send_wr.next = &inv_wr; - } ret = svc_rdma_send(rdma, &send_wr); if (ret) @@ -609,7 +468,6 @@ static int send_reply(struct svcxprt_rdma *rdma, err: svc_rdma_unmap_dma(ctxt); - svc_rdma_put_frmr(rdma, vec->frmr); svc_rdma_put_context(ctxt, 1); return -EIO; } diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c index 3fa5751af0e..e7323fbbd34 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_transport.c +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -1,4 +1,5 @@ /* + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. * Copyright (c) 2005-2007 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -42,15 +43,21 @@ #include <linux/sunrpc/svc_xprt.h> #include <linux/sunrpc/debug.h> #include <linux/sunrpc/rpc_rdma.h> +#include <linux/interrupt.h> #include <linux/sched.h> +#include <linux/slab.h> #include <linux/spinlock.h> +#include <linux/workqueue.h> #include <rdma/ib_verbs.h> #include <rdma/rdma_cm.h> #include <linux/sunrpc/svc_rdma.h> +#include <linux/export.h> +#include "xprt_rdma.h" #define RPCDBG_FACILITY RPCDBG_SVCXPRT static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct net *net, struct sockaddr *sa, int salen, int flags); static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt); @@ -59,6 +66,7 @@ static void dto_tasklet_func(unsigned long data); static void svc_rdma_detach(struct svc_xprt *xprt); static void svc_rdma_free(struct svc_xprt *xprt); static int svc_rdma_has_wspace(struct svc_xprt *xprt); +static int svc_rdma_secure_port(struct svc_rqst *); static void rq_cq_reap(struct svcxprt_rdma *xprt); static void sq_cq_reap(struct svcxprt_rdma *xprt); @@ -76,6 +84,7 @@ static struct svc_xprt_ops svc_rdma_ops = { .xpo_prep_reply_hdr = svc_rdma_prep_reply_hdr, .xpo_has_wspace = svc_rdma_has_wspace, .xpo_accept = svc_rdma_accept, + .xpo_secure_port = svc_rdma_secure_port, }; struct svc_xprt_class svc_rdma_class = { @@ -85,9 +94,6 @@ struct svc_xprt_class svc_rdma_class = { .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, }; -/* WR context cache. 
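With rq_resused removed, send_reply() derives the number of reply pages from the rq_next_page cursor, which points one past the last response page. A toy illustration of that pointer arithmetic (the array and indices are invented):

#include <stdio.h>

int main(void)
{
    void *pages[8] = { 0 };
    void **respages  = &pages[2];   /* first reply page */
    void **next_page = &pages[5];   /* one past the last reply page */

    int n = (int)(next_page - respages);
    printf("%d reply pages travel with the send WR\n", n);   /* 3 */
    return 0;
}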
Created in svc_rdma.c */ -extern struct kmem_cache *svc_rdma_ctxt_cachep; - struct svc_rdma_op_ctxt *svc_rdma_get_context(struct svcxprt_rdma *xprt) { struct svc_rdma_op_ctxt *ctxt; @@ -119,7 +125,7 @@ void svc_rdma_unmap_dma(struct svc_rdma_op_ctxt *ctxt) */ if (ctxt->sge[i].lkey == xprt->sc_dma_lkey) { atomic_dec(&xprt->sc_dma_used); - ib_dma_unmap_single(xprt->sc_cm_id->device, + ib_dma_unmap_page(xprt->sc_cm_id->device, ctxt->sge[i].addr, ctxt->sge[i].length, ctxt->direction); @@ -142,9 +148,6 @@ void svc_rdma_put_context(struct svc_rdma_op_ctxt *ctxt, int free_pages) atomic_dec(&xprt->sc_ctxt_used); } -/* Temporary NFS request map cache. Created in svc_rdma.c */ -extern struct kmem_cache *svc_rdma_map_cachep; - /* * Temporary NFS req mappings are shared across all transport * instances. These are short lived and should be bounded by the number @@ -160,7 +163,6 @@ struct svc_rdma_req_map *svc_rdma_get_req_map(void) schedule_timeout_uninterruptible(msecs_to_jiffies(500)); } map->count = 0; - map->frmr = NULL; return map; } @@ -327,7 +329,7 @@ static void rq_cq_reap(struct svcxprt_rdma *xprt) } /* - * Processs a completion context + * Process a completion context */ static void process_context(struct svcxprt_rdma *xprt, struct svc_rdma_op_ctxt *ctxt) @@ -336,22 +338,21 @@ static void process_context(struct svcxprt_rdma *xprt, switch (ctxt->wr_op) { case IB_WR_SEND: - if (test_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags)) - svc_rdma_put_frmr(xprt, ctxt->frmr); + BUG_ON(ctxt->frmr); svc_rdma_put_context(ctxt, 1); break; case IB_WR_RDMA_WRITE: + BUG_ON(ctxt->frmr); svc_rdma_put_context(ctxt, 0); break; case IB_WR_RDMA_READ: case IB_WR_RDMA_READ_WITH_INV: + svc_rdma_put_frmr(xprt, ctxt->frmr); if (test_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags)) { struct svc_rdma_op_ctxt *read_hdr = ctxt->read_hdr; BUG_ON(!read_hdr); - if (test_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags)) - svc_rdma_put_frmr(xprt, ctxt->frmr); spin_lock_bh(&xprt->sc_rq_dto_lock); set_bit(XPT_DATA, &xprt->sc_xprt.xpt_flags); list_add_tail(&read_hdr->dto_q, @@ -363,6 +364,7 @@ static void process_context(struct svcxprt_rdma *xprt, break; default: + BUG_ON(1); printk(KERN_ERR "svcrdma: unexpected completion type, " "opcode=%d\n", ctxt->wr_op); @@ -378,29 +380,42 @@ static void process_context(struct svcxprt_rdma *xprt, static void sq_cq_reap(struct svcxprt_rdma *xprt) { struct svc_rdma_op_ctxt *ctxt = NULL; - struct ib_wc wc; + struct ib_wc wc_a[6]; + struct ib_wc *wc; struct ib_cq *cq = xprt->sc_sq_cq; int ret; + memset(wc_a, 0, sizeof(wc_a)); + if (!test_and_clear_bit(RDMAXPRT_SQ_PENDING, &xprt->sc_flags)) return; ib_req_notify_cq(xprt->sc_sq_cq, IB_CQ_NEXT_COMP); atomic_inc(&rdma_stat_sq_poll); - while ((ret = ib_poll_cq(cq, 1, &wc)) > 0) { - if (wc.status != IB_WC_SUCCESS) - /* Close the transport */ - set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + while ((ret = ib_poll_cq(cq, ARRAY_SIZE(wc_a), wc_a)) > 0) { + int i; - /* Decrement used SQ WR count */ - atomic_dec(&xprt->sc_sq_count); - wake_up(&xprt->sc_send_wait); + for (i = 0; i < ret; i++) { + wc = &wc_a[i]; + if (wc->status != IB_WC_SUCCESS) { + dprintk("svcrdma: sq wc err status %d\n", + wc->status); - ctxt = (struct svc_rdma_op_ctxt *)(unsigned long)wc.wr_id; - if (ctxt) - process_context(xprt, ctxt); + /* Close the transport */ + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + } - svc_xprt_put(&xprt->sc_xprt); + /* Decrement used SQ WR count */ + atomic_dec(&xprt->sc_sq_count); + wake_up(&xprt->sc_send_wait); + + ctxt = (struct svc_rdma_op_ctxt *) + (unsigned long)wc->wr_id; 
+ if (ctxt) + process_context(xprt, ctxt); + + svc_xprt_put(&xprt->sc_xprt); + } } if (ctxt) @@ -445,7 +460,7 @@ static struct svcxprt_rdma *rdma_create_xprt(struct svc_serv *serv, if (!cma_xprt) return NULL; - svc_xprt_init(&svc_rdma_class, &cma_xprt->sc_xprt, serv); + svc_xprt_init(&init_net, &svc_rdma_class, &cma_xprt->sc_xprt, serv); INIT_LIST_HEAD(&cma_xprt->sc_accept_q); INIT_LIST_HEAD(&cma_xprt->sc_dto_q); INIT_LIST_HEAD(&cma_xprt->sc_rq_dto_q); @@ -477,8 +492,7 @@ struct page *svc_rdma_get_page(void) while ((page = alloc_page(GFP_KERNEL)) == NULL) { /* If we can't get memory, wait a bit and try again */ - printk(KERN_INFO "svcrdma: out of memory...retrying in 1000 " - "jiffies.\n"); + printk(KERN_INFO "svcrdma: out of memory...retrying in 1s\n"); schedule_timeout_uninterruptible(msecs_to_jiffies(1000)); } return page; @@ -501,8 +515,8 @@ int svc_rdma_post_recv(struct svcxprt_rdma *xprt) BUG_ON(sge_no >= xprt->sc_max_sge); page = svc_rdma_get_page(); ctxt->pages[sge_no] = page; - pa = ib_dma_map_single(xprt->sc_cm_id->device, - page_address(page), PAGE_SIZE, + pa = ib_dma_map_page(xprt->sc_cm_id->device, + page, 0, PAGE_SIZE, DMA_FROM_DEVICE); if (ib_dma_mapping_error(xprt->sc_cm_id->device, pa)) goto err_put_ctxt; @@ -510,9 +524,9 @@ int svc_rdma_post_recv(struct svcxprt_rdma *xprt) ctxt->sge[sge_no].addr = pa; ctxt->sge[sge_no].length = PAGE_SIZE; ctxt->sge[sge_no].lkey = xprt->sc_dma_lkey; + ctxt->count = sge_no + 1; buflen += PAGE_SIZE; } - ctxt->count = sge_no; recv_wr.next = NULL; recv_wr.sg_list = &ctxt->sge[0]; recv_wr.num_sge = ctxt->count; @@ -528,6 +542,7 @@ int svc_rdma_post_recv(struct svcxprt_rdma *xprt) return ret; err_put_ctxt: + svc_rdma_unmap_dma(ctxt); svc_rdma_put_context(ctxt, 1); return -ENOMEM; } @@ -577,10 +592,6 @@ static void handle_connect_req(struct rdma_cm_id *new_cma_id, size_t client_ird) list_add_tail(&newxprt->sc_accept_q, &listen_xprt->sc_accept_q); spin_unlock_bh(&listen_xprt->sc_lock); - /* - * Can't use svc_xprt_received here because we are not on a - * rqstp thread - */ set_bit(XPT_CONN, &listen_xprt->sc_xprt.xpt_flags); svc_xprt_enqueue(&listen_xprt->sc_xprt); } @@ -669,6 +680,7 @@ static int rdma_cma_handler(struct rdma_cm_id *cma_id, * Create a listening RDMA service endpoint. 
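sq_cq_reap() above now drains the send CQ in batches of up to six work completions per ib_poll_cq() call instead of one at a time, which cuts per-completion polling overhead. A user-space sketch of the batching loop against a faked completion queue:

#include <stdio.h>

struct wc { int status; };

static int pending = 14;    /* pretend 14 completions are queued */

/* Stand-in for ib_poll_cq(): fills up to nwc entries, returns count. */
static int poll_cq(int nwc, struct wc *wcs)
{
    int n = pending < nwc ? pending : nwc;
    for (int i = 0; i < n; i++)
        wcs[i].status = 0;   /* IB_WC_SUCCESS */
    pending -= n;
    return n;
}

int main(void)
{
    struct wc wc_a[6];
    int ret, calls = 0;

    while ((ret = poll_cq(6, wc_a)) > 0)
        calls++;    /* each poll handles up to 6 completions */
    printf("drained in %d polls\n", calls);    /* 3, not 14 */
    return 0;
}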
*/ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct net *net, struct sockaddr *sa, int salen, int flags) { @@ -678,13 +690,17 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, int ret; dprintk("svcrdma: Creating RDMA socket\n"); - + if (sa->sa_family != AF_INET) { + dprintk("svcrdma: Address family %d is not supported.\n", sa->sa_family); + return ERR_PTR(-EAFNOSUPPORT); + } cma_xprt = rdma_create_xprt(serv, 1); if (!cma_xprt) return ERR_PTR(-ENOMEM); xprt = &cma_xprt->sc_xprt; - listen_id = rdma_create_id(rdma_listen_handler, cma_xprt, RDMA_PS_TCP); + listen_id = rdma_create_id(rdma_listen_handler, cma_xprt, RDMA_PS_TCP, + IB_QPT_RC); if (IS_ERR(listen_id)) { ret = PTR_ERR(listen_id); dprintk("svcrdma: rdma_create_id failed = %d\n", ret); @@ -794,8 +810,8 @@ static void frmr_unmap_dma(struct svcxprt_rdma *xprt, if (ib_dma_mapping_error(frmr->mr->device, addr)) continue; atomic_dec(&xprt->sc_dma_used); - ib_dma_unmap_single(frmr->mr->device, addr, PAGE_SIZE, - frmr->direction); + ib_dma_unmap_page(frmr->mr->device, addr, PAGE_SIZE, + frmr->direction); } } @@ -992,7 +1008,11 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) need_dma_mr = 0; break; case RDMA_TRANSPORT_IB: - if (!(devattr.device_cap_flags & IB_DEVICE_LOCAL_DMA_LKEY)) { + if (!(newxprt->sc_dev_caps & SVCRDMA_DEVCAP_FAST_REG)) { + need_dma_mr = 1; + dma_mr_acc = IB_ACCESS_LOCAL_WRITE; + } else if (!(devattr.device_cap_flags & + IB_DEVICE_LOCAL_DMA_LKEY)) { need_dma_mr = 1; dma_mr_acc = IB_ACCESS_LOCAL_WRITE; } else @@ -1180,7 +1200,7 @@ static void svc_rdma_free(struct svc_xprt *xprt) struct svcxprt_rdma *rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); INIT_WORK(&rdma->sc_work, __svc_rdma_free); - schedule_work(&rdma->sc_work); + queue_work(svc_rdma_wq, &rdma->sc_work); } static int svc_rdma_has_wspace(struct svc_xprt *xprt) @@ -1189,14 +1209,7 @@ static int svc_rdma_has_wspace(struct svc_xprt *xprt) container_of(xprt, struct svcxprt_rdma, sc_xprt); /* - * If there are fewer SQ WR available than required to send a - * simple response, return false. - */ - if ((rdma->sc_sq_depth - atomic_read(&rdma->sc_sq_count) < 3)) - return 0; - - /* - * ...or there are already waiters on the SQ, + * If there are already waiters on the SQ, * return false. */ if (waitqueue_active(&rdma->sc_send_wait)) @@ -1206,6 +1219,11 @@ static int svc_rdma_has_wspace(struct svc_xprt *xprt) return 1; } +static int svc_rdma_secure_port(struct svc_rqst *rqstp) +{ + return 1; +} + /* * Attempt to register the kvec representing the RPC memory with the * device. 
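svc_rdma_create() above now rejects non-AF_INET addresses and passes the new IB_QPT_RC argument to rdma_create_id(). For reference, the analogous listen sequence in user space with librdmacm looks roughly like this; the port and backlog are illustrative, and unlike the kernel API, the user-space rdma_create_id() takes an event channel rather than a handler callback:

#include <stdio.h>
#include <string.h>
#include <netinet/in.h>
#include <rdma/rdma_cma.h>

int main(void)
{
    struct rdma_event_channel *ec = rdma_create_event_channel();
    struct rdma_cm_id *id;
    struct sockaddr_in sin;

    if (!ec || rdma_create_id(ec, &id, NULL, RDMA_PS_TCP))
        return 1;

    memset(&sin, 0, sizeof(sin));
    sin.sin_family = AF_INET;        /* the patch rejects !AF_INET */
    sin.sin_port = htons(20049);     /* conventional NFS/RDMA port */

    if (rdma_bind_addr(id, (struct sockaddr *)&sin) ||
        rdma_listen(id, 10))
        return 1;
    printf("listening for RPC/RDMA connections\n");
    return 0;    /* compile with -lrdmacm */
}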
@@ -1270,7 +1288,7 @@ int svc_rdma_send(struct svcxprt_rdma *xprt, struct ib_send_wr *wr) atomic_read(&xprt->sc_sq_count) < xprt->sc_sq_depth); if (test_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags)) - return 0; + return -ENOTCONN; continue; } /* Take a transport ref for each WR posted */ @@ -1302,7 +1320,6 @@ void svc_rdma_send_error(struct svcxprt_rdma *xprt, struct rpcrdma_msg *rmsgp, enum rpcrdma_errcode err) { struct ib_send_wr err_wr; - struct ib_sge sge; struct page *p; struct svc_rdma_op_ctxt *ctxt; u32 *va; @@ -1315,26 +1332,28 @@ void svc_rdma_send_error(struct svcxprt_rdma *xprt, struct rpcrdma_msg *rmsgp, /* XDR encode error */ length = svc_rdma_xdr_encode_error(xprt, rmsgp, err, va); + ctxt = svc_rdma_get_context(xprt); + ctxt->direction = DMA_FROM_DEVICE; + ctxt->count = 1; + ctxt->pages[0] = p; + /* Prepare SGE for local address */ - sge.addr = ib_dma_map_single(xprt->sc_cm_id->device, - page_address(p), PAGE_SIZE, DMA_FROM_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, sge.addr)) { + ctxt->sge[0].addr = ib_dma_map_page(xprt->sc_cm_id->device, + p, 0, length, DMA_FROM_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, ctxt->sge[0].addr)) { put_page(p); + svc_rdma_put_context(ctxt, 1); return; } atomic_inc(&xprt->sc_dma_used); - sge.lkey = xprt->sc_dma_lkey; - sge.length = length; - - ctxt = svc_rdma_get_context(xprt); - ctxt->count = 1; - ctxt->pages[0] = p; + ctxt->sge[0].lkey = xprt->sc_dma_lkey; + ctxt->sge[0].length = length; /* Prepare SEND WR */ memset(&err_wr, 0, sizeof err_wr); ctxt->wr_op = IB_WR_SEND; err_wr.wr_id = (unsigned long)ctxt; - err_wr.sg_list = &sge; + err_wr.sg_list = ctxt->sge; err_wr.num_sge = 1; err_wr.opcode = IB_WR_SEND; err_wr.send_flags = IB_SEND_SIGNALED; @@ -1344,9 +1363,7 @@ void svc_rdma_send_error(struct svcxprt_rdma *xprt, struct rpcrdma_msg *rmsgp, if (ret) { dprintk("svcrdma: Error %d posting send for protocol error\n", ret); - ib_dma_unmap_single(xprt->sc_cm_id->device, - sge.addr, PAGE_SIZE, - DMA_FROM_DEVICE); + svc_rdma_unmap_dma(ctxt); svc_rdma_put_context(ctxt, 1); } } diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c index f96c2fe6137..66f91f0d071 100644 --- a/net/sunrpc/xprtrdma/transport.c +++ b/net/sunrpc/xprtrdma/transport.c @@ -49,7 +49,9 @@ #include <linux/module.h> #include <linux/init.h> +#include <linux/slab.h> #include <linux/seq_file.h> +#include <linux/sunrpc/addr.h> #include "xprt_rdma.h" @@ -84,7 +86,7 @@ static unsigned int max_memreg = RPCRDMA_LAST - 1; static struct ctl_table_header *sunrpc_table_header; -static ctl_table xr_tunables_table[] = { +static struct ctl_table xr_tunables_table[] = { { .procname = "rdma_slot_table_entries", .data = &xprt_rdma_slot_table_entries, @@ -136,7 +138,7 @@ static ctl_table xr_tunables_table[] = { { }, }; -static ctl_table sunrpc_table[] = { +static struct ctl_table sunrpc_table[] = { { .procname = "sunrpc", .mode = 0555, @@ -147,6 +149,11 @@ static ctl_table sunrpc_table[] = { #endif +#define RPCRDMA_BIND_TO (60U * HZ) +#define RPCRDMA_INIT_REEST_TO (5U * HZ) +#define RPCRDMA_MAX_REEST_TO (30U * HZ) +#define RPCRDMA_IDLE_DISC_TO (5U * 60 * HZ) + static struct rpc_xprt_ops xprt_rdma_procs; /* forward reference */ static void @@ -198,23 +205,18 @@ xprt_rdma_connect_worker(struct work_struct *work) struct rpc_xprt *xprt = &r_xprt->xprt; int rc = 0; - if (!xprt->shutdown) { - xprt_clear_connected(xprt); - - dprintk("RPC: %s: %sconnect\n", __func__, - r_xprt->rx_ep.rep_connected != 0 ? 
"re" : ""); - rc = rpcrdma_ep_connect(&r_xprt->rx_ep, &r_xprt->rx_ia); - if (rc) - goto out; - } - goto out_clear; + current->flags |= PF_FSTRANS; + xprt_clear_connected(xprt); -out: - xprt_wake_pending_tasks(xprt, rc); + dprintk("RPC: %s: %sconnect\n", __func__, + r_xprt->rx_ep.rep_connected != 0 ? "re" : ""); + rc = rpcrdma_ep_connect(&r_xprt->rx_ep, &r_xprt->rx_ia); + if (rc) + xprt_wake_pending_tasks(xprt, rc); -out_clear: dprintk("RPC: %s: exit\n", __func__); xprt_clear_connecting(xprt); + current->flags &= ~PF_FSTRANS; } /* @@ -232,27 +234,20 @@ static void xprt_rdma_destroy(struct rpc_xprt *xprt) { struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - int rc; dprintk("RPC: %s: called\n", __func__); - cancel_delayed_work(&r_xprt->rdma_connect); - flush_scheduled_work(); + cancel_delayed_work_sync(&r_xprt->rdma_connect); xprt_clear_connected(xprt); rpcrdma_buffer_destroy(&r_xprt->rx_buf); - rc = rpcrdma_ep_destroy(&r_xprt->rx_ep, &r_xprt->rx_ia); - if (rc) - dprintk("RPC: %s: rpcrdma_ep_destroy returned %i\n", - __func__, rc); + rpcrdma_ep_destroy(&r_xprt->rx_ep, &r_xprt->rx_ia); rpcrdma_ia_close(&r_xprt->rx_ia); xprt_rdma_free_addresses(xprt); - kfree(xprt->slot); - xprt->slot = NULL; - kfree(xprt); + xprt_free(xprt); dprintk("RPC: %s: returning\n", __func__); @@ -284,29 +279,20 @@ xprt_setup_rdma(struct xprt_create *args) return ERR_PTR(-EBADF); } - xprt = kzalloc(sizeof(struct rpcrdma_xprt), GFP_KERNEL); + xprt = xprt_alloc(args->net, sizeof(struct rpcrdma_xprt), + xprt_rdma_slot_table_entries, + xprt_rdma_slot_table_entries); if (xprt == NULL) { dprintk("RPC: %s: couldn't allocate rpcrdma_xprt\n", __func__); return ERR_PTR(-ENOMEM); } - xprt->max_reqs = xprt_rdma_slot_table_entries; - xprt->slot = kcalloc(xprt->max_reqs, - sizeof(struct rpc_rqst), GFP_KERNEL); - if (xprt->slot == NULL) { - dprintk("RPC: %s: couldn't allocate %d slots\n", - __func__, xprt->max_reqs); - kfree(xprt); - return ERR_PTR(-ENOMEM); - } - /* 60 second timeout, no retries */ xprt->timeout = &xprt_rdma_default_timeout; - xprt->bind_timeout = (60U * HZ); - xprt->connect_timeout = (60U * HZ); - xprt->reestablish_timeout = (5U * HZ); - xprt->idle_timeout = (5U * 60 * HZ); + xprt->bind_timeout = RPCRDMA_BIND_TO; + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + xprt->idle_timeout = RPCRDMA_IDLE_DISC_TO; xprt->resvport = 0; /* privileged port not needed */ xprt->tsh_size = 0; /* RPC-RDMA handles framing */ @@ -406,12 +392,11 @@ out4: xprt_rdma_free_addresses(xprt); rc = -EINVAL; out3: - (void) rpcrdma_ep_destroy(new_ep, &new_xprt->rx_ia); + rpcrdma_ep_destroy(new_ep, &new_xprt->rx_ia); out2: rpcrdma_ia_close(&new_xprt->rx_ia); out1: - kfree(xprt->slot); - kfree(xprt); + xprt_free(xprt); return ERR_PTR(rc); } @@ -443,45 +428,24 @@ xprt_rdma_set_port(struct rpc_xprt *xprt, u16 port) } static void -xprt_rdma_connect(struct rpc_task *task) +xprt_rdma_connect(struct rpc_xprt *xprt, struct rpc_task *task) { - struct rpc_xprt *xprt = (struct rpc_xprt *)task->tk_xprt; struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - if (!xprt_test_and_set_connecting(xprt)) { - if (r_xprt->rx_ep.rep_connected != 0) { - /* Reconnect */ - schedule_delayed_work(&r_xprt->rdma_connect, - xprt->reestablish_timeout); - xprt->reestablish_timeout <<= 1; - if (xprt->reestablish_timeout > (30 * HZ)) - xprt->reestablish_timeout = (30 * HZ); - else if (xprt->reestablish_timeout < (5 * HZ)) - xprt->reestablish_timeout = (5 * HZ); - } else { - schedule_delayed_work(&r_xprt->rdma_connect, 0); - if (!RPC_IS_ASYNC(task)) - flush_scheduled_work(); - 
} - } -} - -static int -xprt_rdma_reserve_xprt(struct rpc_task *task) -{ - struct rpc_xprt *xprt = task->tk_xprt; - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - int credits = atomic_read(&r_xprt->rx_buf.rb_credits); - - /* == RPC_CWNDSCALE @ init, but *after* setup */ - if (r_xprt->rx_buf.rb_cwndscale == 0UL) { - r_xprt->rx_buf.rb_cwndscale = xprt->cwnd; - dprintk("RPC: %s: cwndscale %lu\n", __func__, - r_xprt->rx_buf.rb_cwndscale); - BUG_ON(r_xprt->rx_buf.rb_cwndscale <= 0); + if (r_xprt->rx_ep.rep_connected != 0) { + /* Reconnect */ + schedule_delayed_work(&r_xprt->rdma_connect, + xprt->reestablish_timeout); + xprt->reestablish_timeout <<= 1; + if (xprt->reestablish_timeout > RPCRDMA_MAX_REEST_TO) + xprt->reestablish_timeout = RPCRDMA_MAX_REEST_TO; + else if (xprt->reestablish_timeout < RPCRDMA_INIT_REEST_TO) + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + } else { + schedule_delayed_work(&r_xprt->rdma_connect, 0); + if (!RPC_IS_ASYNC(task)) + flush_delayed_work(&r_xprt->rdma_connect); } - xprt->cwnd = credits * r_xprt->rx_buf.rb_cwndscale; - return xprt_reserve_xprt_cong(task); } /* @@ -495,11 +459,12 @@ xprt_rdma_reserve_xprt(struct rpc_task *task) static void * xprt_rdma_allocate(struct rpc_task *task, size_t size) { - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; struct rpcrdma_req *req, *nreq; req = rpcrdma_buffer_get(&rpcx_to_rdmax(xprt)->rx_buf); - BUG_ON(NULL == req); + if (req == NULL) + return NULL; if (size > req->rl_size) { dprintk("RPC: %s: size %zd too large for buffer[%zd]: " @@ -523,18 +488,6 @@ xprt_rdma_allocate(struct rpc_task *task, size_t size) * If the allocation or registration fails, the RPC framework * will (doggedly) retry. */ - if (rpcx_to_rdmax(xprt)->rx_ia.ri_memreg_strategy == - RPCRDMA_BOUNCEBUFFERS) { - /* forced to "pure inline" */ - dprintk("RPC: %s: too much data (%zd) for inline " - "(r/w max %d/%d)\n", __func__, size, - rpcx_to_rdmad(xprt).inline_rsize, - rpcx_to_rdmad(xprt).inline_wsize); - size = req->rl_size; - rpc_exit(task, -EIO); /* fail the operation */ - rpcx_to_rdmax(xprt)->rx_stats.failed_marshal_count++; - goto out; - } if (task->tk_flags & RPC_TASK_SWAPPER) nreq = kmalloc(sizeof *req + size, GFP_ATOMIC); else @@ -563,7 +516,6 @@ xprt_rdma_allocate(struct rpc_task *task, size_t size) req = nreq; } dprintk("RPC: %s: size %zd, request 0x%p\n", __func__, size, req); -out: req->rl_connect_cookie = 0; /* our reserved value */ return req->rl_xdr_buf; @@ -599,9 +551,7 @@ xprt_rdma_free(void *buffer) __func__, rep, (rep && rep->rr_func) ? " (with waiter)" : ""); /* - * Finish the deregistration. When using mw bind, this was - * begun in rpcrdma_reply_handler(). In all other modes, we - * do it here, in thread context. The process is considered + * Finish the deregistration. The process is considered * complete when the rr_func vector becomes NULL - this * was put in place during rpcrdma_reply_handler() - the wait * call below will not block if the dereg is "done". 
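The reworked xprt_rdma_connect() keeps the existing backoff policy but moves the bounds into the RPCRDMA_*_REEST_TO macros: the reestablish timeout doubles on each retry and is clamped to the [5s, 30s] window. A worked example of the progression (HZ is assumed to be 1000 here; it varies by kernel config):

#include <stdio.h>

#define HZ            1000
#define INIT_REEST_TO (5U * HZ)
#define MAX_REEST_TO  (30U * HZ)

int main(void)
{
    unsigned int to = INIT_REEST_TO;

    for (int try = 0; try < 5; try++) {
        printf("retry %d: wait %u ticks\n", try, to);
        to <<= 1;                    /* exponential backoff */
        if (to > MAX_REEST_TO)
            to = MAX_REEST_TO;       /* clamp high */
        else if (to < INIT_REEST_TO)
            to = INIT_REEST_TO;      /* clamp low */
    }
    return 0;    /* 5000, 10000, 20000, 30000, 30000 */
}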
If @@ -610,12 +560,7 @@ xprt_rdma_free(void *buffer) for (i = 0; req->rl_nchunks;) { --req->rl_nchunks; i += rpcrdma_deregister_external( - &req->rl_segments[i], r_xprt, NULL); - } - - if (rep && wait_event_interruptible(rep->rr_unbind, !rep->rr_func)) { - rep->rr_func = NULL; /* abandon the callback */ - req->rl_reply = NULL; + &req->rl_segments[i], r_xprt); } if (req->rl_iov.length == 0) { /* see allocate above */ @@ -647,16 +592,15 @@ static int xprt_rdma_send_request(struct rpc_task *task) { struct rpc_rqst *rqst = task->tk_rqstp; - struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_xprt *xprt = rqst->rq_xprt; struct rpcrdma_req *req = rpcr_to_rdmar(rqst); struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + int rc; - /* marshal the send itself */ - if (req->rl_niovs == 0 && rpcrdma_marshal_req(rqst) != 0) { - r_xprt->rx_stats.failed_marshal_count++; - dprintk("RPC: %s: rpcrdma_marshal_req failed\n", - __func__); - return -EIO; + if (req->rl_niovs == 0) { + rc = rpcrdma_marshal_req(rqst); + if (rc < 0) + goto failed_marshal; } if (req->rl_reply == NULL) /* e.g. reconnection */ @@ -676,10 +620,16 @@ xprt_rdma_send_request(struct rpc_task *task) if (rpcrdma_ep_post(&r_xprt->rx_ia, &r_xprt->rx_ep, req)) goto drop_connection; - task->tk_bytes_sent += rqst->rq_snd_buf.len; + rqst->rq_xmit_bytes_sent += rqst->rq_snd_buf.len; rqst->rq_bytes_sent = 0; return 0; +failed_marshal: + r_xprt->rx_stats.failed_marshal_count++; + dprintk("RPC: %s: rpcrdma_marshal_req failed, status %i\n", + __func__, rc); + if (rc == -EIO) + return -EIO; drop_connection: xprt_disconnect_done(xprt); return -ENOTCONN; /* implies disconnect */ @@ -725,8 +675,9 @@ static void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) */ static struct rpc_xprt_ops xprt_rdma_procs = { - .reserve_xprt = xprt_rdma_reserve_xprt, + .reserve_xprt = xprt_reserve_xprt_cong, .release_xprt = xprt_release_xprt_cong, /* sunrpc/xprt.c */ + .alloc_slot = xprt_alloc_slot, .release_request = xprt_release_rqst_cong, /* ditto */ .set_retrans_timeout = xprt_set_retrans_timeout_def, /* ditto */ .rpcbind = rpcb_getport_async, /* sunrpc/rpcb_clnt.c */ @@ -752,7 +703,7 @@ static void __exit xprt_rdma_cleanup(void) { int rc; - dprintk(KERN_INFO "RPCRDMA Module Removed, deregister RPC RDMA transport\n"); + dprintk("RPCRDMA Module Removed, deregister RPC RDMA transport\n"); #ifdef RPC_DEBUG if (sunrpc_table_header) { unregister_sysctl_table(sunrpc_table_header); @@ -774,14 +725,14 @@ static int __init xprt_rdma_init(void) if (rc) return rc; - dprintk(KERN_INFO "RPCRDMA Module Init, register RPC RDMA transport\n"); + dprintk("RPCRDMA Module Init, register RPC RDMA transport\n"); - dprintk(KERN_INFO "Defaults:\n"); - dprintk(KERN_INFO "\tSlots %d\n" + dprintk("Defaults:\n"); + dprintk("\tSlots %d\n" "\tMaxInlineRead %d\n\tMaxInlineWrite %d\n", xprt_rdma_slot_table_entries, xprt_rdma_max_inline_read, xprt_rdma_max_inline_write); - dprintk(KERN_INFO "\tPadding %d\n\tMemreg %d\n", + dprintk("\tPadding %d\n\tMemreg %d\n", xprt_rdma_inline_write_padding, xprt_rdma_memreg_strategy); #ifdef RPC_DEBUG diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c index 2209aa87d89..13dbd1c389f 100644 --- a/net/sunrpc/xprtrdma/verbs.c +++ b/net/sunrpc/xprtrdma/verbs.c @@ -47,7 +47,9 @@ * o buffer memory */ -#include <linux/pci.h> /* for Tavor hack below */ +#include <linux/interrupt.h> +#include <linux/slab.h> +#include <asm/bitops.h> #include "xprt_rdma.h" @@ -140,89 +142,139 @@ rpcrdma_cq_async_error_upcall(struct ib_event *event, void 
*context) } } -static inline -void rpcrdma_event_process(struct ib_wc *wc) +static void +rpcrdma_sendcq_process_wc(struct ib_wc *wc) { - struct rpcrdma_rep *rep = - (struct rpcrdma_rep *)(unsigned long) wc->wr_id; + struct rpcrdma_mw *frmr = (struct rpcrdma_mw *)(unsigned long)wc->wr_id; - dprintk("RPC: %s: event rep %p status %X opcode %X length %u\n", - __func__, rep, wc->status, wc->opcode, wc->byte_len); + dprintk("RPC: %s: frmr %p status %X opcode %d\n", + __func__, frmr, wc->status, wc->opcode); - if (!rep) /* send or bind completion that we don't care about */ + if (wc->wr_id == 0ULL) + return; + if (wc->status != IB_WC_SUCCESS) return; - if (IB_WC_SUCCESS != wc->status) { - dprintk("RPC: %s: %s WC status %X, connection lost\n", - __func__, (wc->opcode & IB_WC_RECV) ? "recv" : "send", - wc->status); - rep->rr_len = ~0U; - rpcrdma_schedule_tasklet(rep); + if (wc->opcode == IB_WC_FAST_REG_MR) + frmr->r.frmr.state = FRMR_IS_VALID; + else if (wc->opcode == IB_WC_LOCAL_INV) + frmr->r.frmr.state = FRMR_IS_INVALID; +} + +static int +rpcrdma_sendcq_poll(struct ib_cq *cq, struct rpcrdma_ep *ep) +{ + struct ib_wc *wcs; + int budget, count, rc; + + budget = RPCRDMA_WC_BUDGET / RPCRDMA_POLLSIZE; + do { + wcs = ep->rep_send_wcs; + + rc = ib_poll_cq(cq, RPCRDMA_POLLSIZE, wcs); + if (rc <= 0) + return rc; + + count = rc; + while (count-- > 0) + rpcrdma_sendcq_process_wc(wcs++); + } while (rc == RPCRDMA_POLLSIZE && --budget); + return 0; +} + +/* + * Handle send, fast_reg_mr, and local_inv completions. + * + * Send events are typically suppressed and thus do not result + * in an upcall. Occasionally one is signaled, however. This + * prevents the provider's completion queue from wrapping and + * losing a completion. + */ +static void +rpcrdma_sendcq_upcall(struct ib_cq *cq, void *cq_context) +{ + struct rpcrdma_ep *ep = (struct rpcrdma_ep *)cq_context; + int rc; + + rc = rpcrdma_sendcq_poll(cq, ep); + if (rc) { + dprintk("RPC: %s: ib_poll_cq failed: %i\n", + __func__, rc); return; } - switch (wc->opcode) { - case IB_WC_RECV: - rep->rr_len = wc->byte_len; - ib_dma_sync_single_for_cpu( - rdmab_to_ia(rep->rr_buffer)->ri_id->device, - rep->rr_iov.addr, rep->rr_len, DMA_FROM_DEVICE); - /* Keep (only) the most recent credits, after check validity */ - if (rep->rr_len >= 16) { - struct rpcrdma_msg *p = - (struct rpcrdma_msg *) rep->rr_base; - unsigned int credits = ntohl(p->rm_credit); - if (credits == 0) { - dprintk("RPC: %s: server" - " dropped credits to 0!\n", __func__); - /* don't deadlock */ - credits = 1; - } else if (credits > rep->rr_buffer->rb_max_requests) { - dprintk("RPC: %s: server" - " over-crediting: %d (%d)\n", - __func__, credits, - rep->rr_buffer->rb_max_requests); - credits = rep->rr_buffer->rb_max_requests; - } - atomic_set(&rep->rr_buffer->rb_credits, credits); - } - /* fall through */ - case IB_WC_BIND_MW: - rpcrdma_schedule_tasklet(rep); - break; - default: - dprintk("RPC: %s: unexpected WC event %X\n", - __func__, wc->opcode); - break; + rc = ib_req_notify_cq(cq, + IB_CQ_NEXT_COMP | IB_CQ_REPORT_MISSED_EVENTS); + if (rc == 0) + return; + if (rc < 0) { + dprintk("RPC: %s: ib_req_notify_cq failed: %i\n", + __func__, rc); + return; } + + rpcrdma_sendcq_poll(cq, ep); } -static inline int -rpcrdma_cq_poll(struct ib_cq *cq) +static void +rpcrdma_recvcq_process_wc(struct ib_wc *wc) { - struct ib_wc wc; - int rc; + struct rpcrdma_rep *rep = + (struct rpcrdma_rep *)(unsigned long)wc->wr_id; - for (;;) { - rc = ib_poll_cq(cq, 1, &wc); - if (rc < 0) { - dprintk("RPC: %s: ib_poll_cq failed 
%i\n", - __func__, rc); - return rc; - } - if (rc == 0) - break; + dprintk("RPC: %s: rep %p status %X opcode %X length %u\n", + __func__, rep, wc->status, wc->opcode, wc->byte_len); - rpcrdma_event_process(&wc); + if (wc->status != IB_WC_SUCCESS) { + rep->rr_len = ~0U; + goto out_schedule; } + if (wc->opcode != IB_WC_RECV) + return; + rep->rr_len = wc->byte_len; + ib_dma_sync_single_for_cpu(rdmab_to_ia(rep->rr_buffer)->ri_id->device, + rep->rr_iov.addr, rep->rr_len, DMA_FROM_DEVICE); + + if (rep->rr_len >= 16) { + struct rpcrdma_msg *p = (struct rpcrdma_msg *)rep->rr_base; + unsigned int credits = ntohl(p->rm_credit); + + if (credits == 0) + credits = 1; /* don't deadlock */ + else if (credits > rep->rr_buffer->rb_max_requests) + credits = rep->rr_buffer->rb_max_requests; + atomic_set(&rep->rr_buffer->rb_credits, credits); + } + +out_schedule: + rpcrdma_schedule_tasklet(rep); +} + +static int +rpcrdma_recvcq_poll(struct ib_cq *cq, struct rpcrdma_ep *ep) +{ + struct ib_wc *wcs; + int budget, count, rc; + + budget = RPCRDMA_WC_BUDGET / RPCRDMA_POLLSIZE; + do { + wcs = ep->rep_recv_wcs; + + rc = ib_poll_cq(cq, RPCRDMA_POLLSIZE, wcs); + if (rc <= 0) + return rc; + + count = rc; + while (count-- > 0) + rpcrdma_recvcq_process_wc(wcs++); + } while (rc == RPCRDMA_POLLSIZE && --budget); return 0; } /* - * rpcrdma_cq_event_upcall + * Handle receive completions. * - * This upcall handles recv, send, bind and unbind events. * It is reentrant but processes single events in order to maintain * ordering of receives to keep server credits. * @@ -231,26 +283,31 @@ rpcrdma_cq_poll(struct ib_cq *cq) * connection shutdown. That is, the structures required for * the completion of the reply handler must remain intact until * all memory has been reclaimed. - * - * Note that send events are suppressed and do not result in an upcall. 
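rpcrdma_recvcq_process_wc() above keeps the credit sanitizing from the old event handler: a server grant of zero would deadlock the transport and a grant above rb_max_requests would over-commit it, so the decoded value is clamped before being stored. The same clamp in isolation:

#include <stdio.h>

static unsigned int clamp_credits(unsigned int credits,
                                  unsigned int max_requests)
{
    if (credits == 0)
        credits = 1;             /* don't deadlock */
    else if (credits > max_requests)
        credits = max_requests;  /* don't over-credit */
    return credits;
}

int main(void)
{
    printf("%u %u %u\n",
           clamp_credits(0, 32),     /* 1  */
           clamp_credits(16, 32),    /* 16 */
           clamp_credits(100, 32));  /* 32 */
    return 0;
}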
*/ static void -rpcrdma_cq_event_upcall(struct ib_cq *cq, void *context) +rpcrdma_recvcq_upcall(struct ib_cq *cq, void *cq_context) { + struct rpcrdma_ep *ep = (struct rpcrdma_ep *)cq_context; int rc; - rc = rpcrdma_cq_poll(cq); - if (rc) + rc = rpcrdma_recvcq_poll(cq, ep); + if (rc) { + dprintk("RPC: %s: ib_poll_cq failed: %i\n", + __func__, rc); return; + } - rc = ib_req_notify_cq(cq, IB_CQ_NEXT_COMP); - if (rc) { - dprintk("RPC: %s: ib_req_notify_cq failed %i\n", + rc = ib_req_notify_cq(cq, + IB_CQ_NEXT_COMP | IB_CQ_REPORT_MISSED_EVENTS); + if (rc == 0) + return; + if (rc < 0) { + dprintk("RPC: %s: ib_req_notify_cq failed: %i\n", __func__, rc); return; } - rpcrdma_cq_poll(cq); + rpcrdma_recvcq_poll(cq, ep); } #ifdef RPC_DEBUG @@ -377,7 +434,7 @@ rpcrdma_create_id(struct rpcrdma_xprt *xprt, init_completion(&ia->ri_done); - id = rdma_create_id(rpcrdma_conn_upcall, xprt, RDMA_PS_TCP); + id = rdma_create_id(rpcrdma_conn_upcall, xprt, RDMA_PS_TCP, IB_QPT_RC); if (IS_ERR(id)) { rc = PTR_ERR(id); dprintk("RPC: %s: rdma_create_id() failed %i\n", @@ -482,54 +539,32 @@ rpcrdma_ia_open(struct rpcrdma_xprt *xprt, struct sockaddr *addr, int memreg) ia->ri_dma_lkey = ia->ri_id->device->local_dma_lkey; } - switch (memreg) { - case RPCRDMA_MEMWINDOWS: - case RPCRDMA_MEMWINDOWS_ASYNC: - if (!(devattr.device_cap_flags & IB_DEVICE_MEM_WINDOW)) { - dprintk("RPC: %s: MEMWINDOWS registration " - "specified but not supported by adapter, " - "using slower RPCRDMA_REGISTER\n", - __func__); - memreg = RPCRDMA_REGISTER; - } - break; - case RPCRDMA_MTHCAFMR: - if (!ia->ri_id->device->alloc_fmr) { -#if RPCRDMA_PERSISTENT_REGISTRATION - dprintk("RPC: %s: MTHCAFMR registration " - "specified but not supported by adapter, " - "using riskier RPCRDMA_ALLPHYSICAL\n", - __func__); - memreg = RPCRDMA_ALLPHYSICAL; -#else - dprintk("RPC: %s: MTHCAFMR registration " - "specified but not supported by adapter, " - "using slower RPCRDMA_REGISTER\n", - __func__); - memreg = RPCRDMA_REGISTER; -#endif - } - break; - case RPCRDMA_FRMR: + if (memreg == RPCRDMA_FRMR) { /* Requires both frmr reg and local dma lkey */ if ((devattr.device_cap_flags & (IB_DEVICE_MEM_MGT_EXTENSIONS|IB_DEVICE_LOCAL_DMA_LKEY)) != (IB_DEVICE_MEM_MGT_EXTENSIONS|IB_DEVICE_LOCAL_DMA_LKEY)) { -#if RPCRDMA_PERSISTENT_REGISTRATION dprintk("RPC: %s: FRMR registration " - "specified but not supported by adapter, " - "using riskier RPCRDMA_ALLPHYSICAL\n", - __func__); + "not supported by HCA\n", __func__); + memreg = RPCRDMA_MTHCAFMR; + } else { + /* Mind the ia limit on FRMR page list depth */ + ia->ri_max_frmr_depth = min_t(unsigned int, + RPCRDMA_MAX_DATA_SEGS, + devattr.max_fast_reg_page_list_len); + } + } + if (memreg == RPCRDMA_MTHCAFMR) { + if (!ia->ri_id->device->alloc_fmr) { + dprintk("RPC: %s: MTHCAFMR registration " + "not supported by HCA\n", __func__); +#if RPCRDMA_PERSISTENT_REGISTRATION memreg = RPCRDMA_ALLPHYSICAL; #else - dprintk("RPC: %s: FRMR registration " - "specified but not supported by adapter, " - "using slower RPCRDMA_REGISTER\n", - __func__); - memreg = RPCRDMA_REGISTER; + rc = -ENOMEM; + goto out2; #endif } - break; } /* @@ -541,8 +576,6 @@ rpcrdma_ia_open(struct rpcrdma_xprt *xprt, struct sockaddr *addr, int memreg) * adapter. 
*/ switch (memreg) { - case RPCRDMA_BOUNCEBUFFERS: - case RPCRDMA_REGISTER: case RPCRDMA_FRMR: break; #if RPCRDMA_PERSISTENT_REGISTRATION @@ -552,30 +585,26 @@ rpcrdma_ia_open(struct rpcrdma_xprt *xprt, struct sockaddr *addr, int memreg) IB_ACCESS_REMOTE_READ; goto register_setup; #endif - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - mem_priv = IB_ACCESS_LOCAL_WRITE | - IB_ACCESS_MW_BIND; - goto register_setup; case RPCRDMA_MTHCAFMR: if (ia->ri_have_dma_lkey) break; mem_priv = IB_ACCESS_LOCAL_WRITE; +#if RPCRDMA_PERSISTENT_REGISTRATION register_setup: +#endif ia->ri_bind_mem = ib_get_dma_mr(ia->ri_pd, mem_priv); if (IS_ERR(ia->ri_bind_mem)) { printk(KERN_ALERT "%s: ib_get_dma_mr for " - "phys register failed with %lX\n\t" - "Will continue with degraded performance\n", + "phys register failed with %lX\n", __func__, PTR_ERR(ia->ri_bind_mem)); - memreg = RPCRDMA_REGISTER; - ia->ri_bind_mem = NULL; + rc = -ENOMEM; + goto out2; } break; default: - printk(KERN_ERR "%s: invalid memory registration mode %d\n", - __func__, memreg); - rc = -EINVAL; + printk(KERN_ERR "RPC: Unsupported memory " + "registration mode: %d\n", memreg); + rc = -ENOMEM; goto out2; } dprintk("RPC: %s: memory registration strategy is %d\n", @@ -629,6 +658,7 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, struct rpcrdma_create_data_internal *cdata) { struct ib_device_attr devattr; + struct ib_cq *sendcq, *recvcq; int rc, err; rc = ib_query_device(ia->ri_id->device, &devattr); @@ -648,20 +678,42 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, ep->rep_attr.srq = NULL; ep->rep_attr.cap.max_send_wr = cdata->max_requests; switch (ia->ri_memreg_strategy) { - case RPCRDMA_FRMR: - /* Add room for frmr register and invalidate WRs */ - ep->rep_attr.cap.max_send_wr *= 3; - if (ep->rep_attr.cap.max_send_wr > devattr.max_qp_wr) - return -EINVAL; - break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - /* Add room for mw_binds+unbinds - overkill! */ - ep->rep_attr.cap.max_send_wr++; - ep->rep_attr.cap.max_send_wr *= (2 * RPCRDMA_MAX_SEGS); - if (ep->rep_attr.cap.max_send_wr > devattr.max_qp_wr) - return -EINVAL; + case RPCRDMA_FRMR: { + int depth = 7; + + /* Add room for frmr register and invalidate WRs. + * 1. FRMR reg WR for head + * 2. FRMR invalidate WR for head + * 3. N FRMR reg WRs for pagelist + * 4. N FRMR invalidate WRs for pagelist + * 5. FRMR reg WR for tail + * 6. FRMR invalidate WR for tail + * 7. The RDMA_SEND WR + */ + + /* Calculate N if the device max FRMR depth is smaller than + * RPCRDMA_MAX_DATA_SEGS. 
+ */ + if (ia->ri_max_frmr_depth < RPCRDMA_MAX_DATA_SEGS) { + int delta = RPCRDMA_MAX_DATA_SEGS - + ia->ri_max_frmr_depth; + + do { + depth += 2; /* FRMR reg + invalidate */ + delta -= ia->ri_max_frmr_depth; + } while (delta > 0); + + } + ep->rep_attr.cap.max_send_wr *= depth; + if (ep->rep_attr.cap.max_send_wr > devattr.max_qp_wr) { + cdata->max_requests = devattr.max_qp_wr / depth; + if (!cdata->max_requests) + return -EINVAL; + ep->rep_attr.cap.max_send_wr = cdata->max_requests * + depth; + } break; + } default: break; } @@ -682,46 +734,51 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, ep->rep_attr.cap.max_recv_sge); /* set trigger for requesting send completion */ - ep->rep_cqinit = ep->rep_attr.cap.max_send_wr/2 /* - 1*/; - switch (ia->ri_memreg_strategy) { - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - ep->rep_cqinit -= RPCRDMA_MAX_SEGS; - break; - default: - break; - } + ep->rep_cqinit = ep->rep_attr.cap.max_send_wr/2 - 1; if (ep->rep_cqinit <= 2) ep->rep_cqinit = 0; INIT_CQCOUNT(ep); ep->rep_ia = ia; init_waitqueue_head(&ep->rep_connect_wait); + INIT_DELAYED_WORK(&ep->rep_connect_worker, rpcrdma_connect_worker); - /* - * Create a single cq for receive dto and mw_bind (only ever - * care about unbind, really). Send completions are suppressed. - * Use single threaded tasklet upcalls to maintain ordering. - */ - ep->rep_cq = ib_create_cq(ia->ri_id->device, rpcrdma_cq_event_upcall, - rpcrdma_cq_async_error_upcall, NULL, - ep->rep_attr.cap.max_recv_wr + + sendcq = ib_create_cq(ia->ri_id->device, rpcrdma_sendcq_upcall, + rpcrdma_cq_async_error_upcall, ep, ep->rep_attr.cap.max_send_wr + 1, 0); - if (IS_ERR(ep->rep_cq)) { - rc = PTR_ERR(ep->rep_cq); - dprintk("RPC: %s: ib_create_cq failed: %i\n", + if (IS_ERR(sendcq)) { + rc = PTR_ERR(sendcq); + dprintk("RPC: %s: failed to create send CQ: %i\n", __func__, rc); goto out1; } - rc = ib_req_notify_cq(ep->rep_cq, IB_CQ_NEXT_COMP); + rc = ib_req_notify_cq(sendcq, IB_CQ_NEXT_COMP); if (rc) { dprintk("RPC: %s: ib_req_notify_cq failed: %i\n", __func__, rc); goto out2; } - ep->rep_attr.send_cq = ep->rep_cq; - ep->rep_attr.recv_cq = ep->rep_cq; + recvcq = ib_create_cq(ia->ri_id->device, rpcrdma_recvcq_upcall, + rpcrdma_cq_async_error_upcall, ep, + ep->rep_attr.cap.max_recv_wr + 1, 0); + if (IS_ERR(recvcq)) { + rc = PTR_ERR(recvcq); + dprintk("RPC: %s: failed to create recv CQ: %i\n", + __func__, rc); + goto out2; + } + + rc = ib_req_notify_cq(recvcq, IB_CQ_NEXT_COMP); + if (rc) { + dprintk("RPC: %s: ib_req_notify_cq failed: %i\n", + __func__, rc); + ib_destroy_cq(recvcq); + goto out2; + } + + ep->rep_attr.send_cq = sendcq; + ep->rep_attr.recv_cq = recvcq; /* Initialize cma parameters */ @@ -731,9 +788,7 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, /* Client offers RDMA Read but does not initiate */ ep->rep_remote_cma.initiator_depth = 0; - if (ia->ri_memreg_strategy == RPCRDMA_BOUNCEBUFFERS) - ep->rep_remote_cma.responder_resources = 0; - else if (devattr.max_qp_rd_atom > 32) /* arbitrary but <= 255 */ + if (devattr.max_qp_rd_atom > 32) /* arbitrary but <= 255 */ ep->rep_remote_cma.responder_resources = 32; else ep->rep_remote_cma.responder_resources = devattr.max_qp_rd_atom; @@ -745,7 +800,7 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, return 0; out2: - err = ib_destroy_cq(ep->rep_cq); + err = ib_destroy_cq(sendcq); if (err) dprintk("RPC: %s: ib_destroy_cq returned %i\n", __func__, err); @@ -759,11 +814,8 @@ out1: * Disconnect and destroy endpoint. 
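The FRMR branch of rpcrdma_ep_create() above sizes the send queue as seven WRs per request, plus a register/invalidate pair for every additional FRMR needed when the device's fast-register page-list depth is below RPCRDMA_MAX_DATA_SEGS, then shrinks max_requests if the product would exceed the device's max_qp_wr. A worked example with assumed device limits:

#include <stdio.h>

int main(void)
{
    int max_data_segs  = 64;     /* stand-in for RPCRDMA_MAX_DATA_SEGS */
    int max_frmr_depth = 16;     /* device limit from ib_query_device() */
    int max_requests   = 32;
    int max_qp_wr      = 16384;  /* device limit */
    int depth = 7;               /* reg+inv for head, tail, pagelist,
                                  * plus the RDMA_SEND itself */

    if (max_frmr_depth < max_data_segs) {
        int delta = max_data_segs - max_frmr_depth;
        do {
            depth += 2;                  /* one extra reg + invalidate */
            delta -= max_frmr_depth;
        } while (delta > 0);
    }

    int send_wr = max_requests * depth;
    if (send_wr > max_qp_wr) {
        max_requests = max_qp_wr / depth;
        send_wr = max_requests * depth;
    }
    printf("depth=%d send_wr=%d\n", depth, send_wr);
    /* depth = 7 + 2*3 = 13; send_wr = 32 * 13 = 416 */
    return 0;
}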
After this, the only * valid operations on the ep are to free it (if dynamically * allocated) or re-create it. - * - * The caller's error handling must be sure to not leak the endpoint - * if this function fails. */ -int +void rpcrdma_ep_destroy(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) { int rc; @@ -771,6 +823,8 @@ rpcrdma_ep_destroy(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) dprintk("RPC: %s: entering, connected is %d\n", __func__, ep->rep_connected); + cancel_delayed_work_sync(&ep->rep_connect_worker); + if (ia->ri_id->qp) { rc = rpcrdma_ep_disconnect(ep, ia); if (rc) @@ -786,13 +840,17 @@ rpcrdma_ep_destroy(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) ep->rep_pad_mr = NULL; } - rpcrdma_clean_cq(ep->rep_cq); - rc = ib_destroy_cq(ep->rep_cq); + rpcrdma_clean_cq(ep->rep_attr.recv_cq); + rc = ib_destroy_cq(ep->rep_attr.recv_cq); if (rc) dprintk("RPC: %s: ib_destroy_cq returned %i\n", __func__, rc); - return rc; + rpcrdma_clean_cq(ep->rep_attr.send_cq); + rc = ib_destroy_cq(ep->rep_attr.send_cq); + if (rc) + dprintk("RPC: %s: ib_destroy_cq returned %i\n", + __func__, rc); } /* @@ -808,17 +866,20 @@ rpcrdma_ep_connect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) if (ep->rep_connected != 0) { struct rpcrdma_xprt *xprt; retry: + dprintk("RPC: %s: reconnecting...\n", __func__); rc = rpcrdma_ep_disconnect(ep, ia); if (rc && rc != -ENOTCONN) dprintk("RPC: %s: rpcrdma_ep_disconnect" " status %i\n", __func__, rc); - rpcrdma_clean_cq(ep->rep_cq); + + rpcrdma_clean_cq(ep->rep_attr.recv_cq); + rpcrdma_clean_cq(ep->rep_attr.send_cq); xprt = container_of(ia, struct rpcrdma_xprt, rx_ia); id = rpcrdma_create_id(xprt, ia, (struct sockaddr *)&xprt->rx_data.addr); if (IS_ERR(id)) { - rc = PTR_ERR(id); + rc = -EHOSTUNREACH; goto out; } /* TEMP TEMP TEMP - fail if new device: @@ -832,35 +893,32 @@ retry: printk("RPC: %s: can't reconnect on " "different device!\n", __func__); rdma_destroy_id(id); - rc = -ENETDOWN; + rc = -ENETUNREACH; goto out; } /* END TEMP */ + rc = rdma_create_qp(id, ia->ri_pd, &ep->rep_attr); + if (rc) { + dprintk("RPC: %s: rdma_create_qp failed %i\n", + __func__, rc); + rdma_destroy_id(id); + rc = -ENETUNREACH; + goto out; + } rdma_destroy_qp(ia->ri_id); rdma_destroy_id(ia->ri_id); ia->ri_id = id; + } else { + dprintk("RPC: %s: connecting...\n", __func__); + rc = rdma_create_qp(ia->ri_id, ia->ri_pd, &ep->rep_attr); + if (rc) { + dprintk("RPC: %s: rdma_create_qp failed %i\n", + __func__, rc); + /* do not update ep->rep_connected */ + return -ENETUNREACH; + } } - rc = rdma_create_qp(ia->ri_id, ia->ri_pd, &ep->rep_attr); - if (rc) { - dprintk("RPC: %s: rdma_create_qp failed %i\n", - __func__, rc); - goto out; - } - -/* XXX Tavor device performs badly with 2K MTU! 
*/ -if (strnicmp(ia->ri_id->device->dma_device->bus->name, "pci", 3) == 0) { - struct pci_dev *pcid = to_pci_dev(ia->ri_id->device->dma_device); - if (pcid->device == PCI_DEVICE_ID_MELLANOX_TAVOR && - (pcid->vendor == PCI_VENDOR_ID_MELLANOX || - pcid->vendor == PCI_VENDOR_ID_TOPSPIN)) { - struct ib_qp_attr attr = { - .path_mtu = IB_MTU_1024 - }; - rc = ib_modify_qp(ia->ri_id->qp, &attr, IB_QP_PATH_MTU); - } -} - ep->rep_connected = 0; rc = rdma_connect(ia->ri_id, &ep->rep_remote_cma); @@ -921,7 +979,8 @@ rpcrdma_ep_disconnect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) { int rc; - rpcrdma_clean_cq(ep->rep_cq); + rpcrdma_clean_cq(ep->rep_attr.recv_cq); + rpcrdma_clean_cq(ep->rep_attr.send_cq); rc = rdma_disconnect(ia->ri_id); if (!rc) { /* returns without wait if not connected */ @@ -944,7 +1003,7 @@ rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, struct rpcrdma_create_data_internal *cdata) { char *p; - size_t len; + size_t len, rlen, wlen; int i, rc; struct rpcrdma_mw *r; @@ -974,11 +1033,6 @@ rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, len += (buf->rb_max_requests + 1) * RPCRDMA_MAX_SEGS * sizeof(struct rpcrdma_mw); break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - len += (buf->rb_max_requests + 1) * RPCRDMA_MAX_SEGS * - sizeof(struct rpcrdma_mw); - break; default: break; } @@ -1009,32 +1063,29 @@ rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, } p += cdata->padding; - /* - * Allocate the fmr's, or mw's for mw_bind chunk registration. - * We "cycle" the mw's in order to minimize rkey reuse, - * and also reduce unbind-to-bind collision. - */ INIT_LIST_HEAD(&buf->rb_mws); r = (struct rpcrdma_mw *)p; switch (ia->ri_memreg_strategy) { case RPCRDMA_FRMR: for (i = buf->rb_max_requests * RPCRDMA_MAX_SEGS; i; i--) { r->r.frmr.fr_mr = ib_alloc_fast_reg_mr(ia->ri_pd, - RPCRDMA_MAX_SEGS); + ia->ri_max_frmr_depth); if (IS_ERR(r->r.frmr.fr_mr)) { rc = PTR_ERR(r->r.frmr.fr_mr); dprintk("RPC: %s: ib_alloc_fast_reg_mr" " failed %i\n", __func__, rc); goto out; } - r->r.frmr.fr_pgl = - ib_alloc_fast_reg_page_list(ia->ri_id->device, - RPCRDMA_MAX_SEGS); + r->r.frmr.fr_pgl = ib_alloc_fast_reg_page_list( + ia->ri_id->device, + ia->ri_max_frmr_depth); if (IS_ERR(r->r.frmr.fr_pgl)) { rc = PTR_ERR(r->r.frmr.fr_pgl); dprintk("RPC: %s: " "ib_alloc_fast_reg_page_list " "failed %i\n", __func__, rc); + + ib_dereg_mr(r->r.frmr.fr_mr); goto out; } list_add(&r->mw_list, &buf->rb_mws); @@ -1059,21 +1110,6 @@ rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, ++r; } break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - /* Allocate one extra request's worth, for full cycling */ - for (i = (buf->rb_max_requests+1) * RPCRDMA_MAX_SEGS; i; i--) { - r->r.mw = ib_alloc_mw(ia->ri_pd); - if (IS_ERR(r->r.mw)) { - rc = PTR_ERR(r->r.mw); - dprintk("RPC: %s: ib_alloc_mw" - " failed %i\n", __func__, rc); - goto out; - } - list_add(&r->mw_list, &buf->rb_mws); - ++r; - } - break; default: break; } @@ -1082,16 +1118,16 @@ rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, * Allocate/init the request/reply buffers. Doing this * using kmalloc for now -- one for each buf. 
*/ + wlen = 1 << fls(cdata->inline_wsize + sizeof(struct rpcrdma_req)); + rlen = 1 << fls(cdata->inline_rsize + sizeof(struct rpcrdma_rep)); + dprintk("RPC: %s: wlen = %zu, rlen = %zu\n", + __func__, wlen, rlen); + for (i = 0; i < buf->rb_max_requests; i++) { struct rpcrdma_req *req; struct rpcrdma_rep *rep; - len = cdata->inline_wsize + sizeof(struct rpcrdma_req); - /* RPC layer requests *double* size + 1K RPC_SLACK_SPACE! */ - /* Typical ~2400b, so rounding up saves work later */ - if (len < 4096) - len = 4096; - req = kmalloc(len, GFP_KERNEL); + req = kmalloc(wlen, GFP_KERNEL); if (req == NULL) { dprintk("RPC: %s: request buffer %d alloc" " failed\n", __func__, i); @@ -1103,16 +1139,16 @@ rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, buf->rb_send_bufs[i]->rl_buffer = buf; rc = rpcrdma_register_internal(ia, req->rl_base, - len - offsetof(struct rpcrdma_req, rl_base), + wlen - offsetof(struct rpcrdma_req, rl_base), &buf->rb_send_bufs[i]->rl_handle, &buf->rb_send_bufs[i]->rl_iov); if (rc) goto out; - buf->rb_send_bufs[i]->rl_size = len-sizeof(struct rpcrdma_req); + buf->rb_send_bufs[i]->rl_size = wlen - + sizeof(struct rpcrdma_req); - len = cdata->inline_rsize + sizeof(struct rpcrdma_rep); - rep = kmalloc(len, GFP_KERNEL); + rep = kmalloc(rlen, GFP_KERNEL); if (rep == NULL) { dprintk("RPC: %s: reply buffer %d alloc failed\n", __func__, i); @@ -1122,10 +1158,9 @@ rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, memset(rep, 0, sizeof(struct rpcrdma_rep)); buf->rb_recv_bufs[i] = rep; buf->rb_recv_bufs[i]->rr_buffer = buf; - init_waitqueue_head(&rep->rr_unbind); rc = rpcrdma_register_internal(ia, rep->rr_base, - len - offsetof(struct rpcrdma_rep, rr_base), + rlen - offsetof(struct rpcrdma_rep, rr_base), &buf->rb_recv_bufs[i]->rr_handle, &buf->rb_recv_bufs[i]->rr_iov); if (rc) @@ -1156,7 +1191,6 @@ rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) /* clean up in reverse order from create * 1. recv mr memory (mr free, then kfree) - * 1a. bind mw memory * 2. send mr memory (mr free, then kfree) * 3. padding (if any) [moved to rpcrdma_ep_destroy] * 4. 
arrays @@ -1171,41 +1205,6 @@ rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) kfree(buf->rb_recv_bufs[i]); } if (buf->rb_send_bufs && buf->rb_send_bufs[i]) { - while (!list_empty(&buf->rb_mws)) { - r = list_entry(buf->rb_mws.next, - struct rpcrdma_mw, mw_list); - list_del(&r->mw_list); - switch (ia->ri_memreg_strategy) { - case RPCRDMA_FRMR: - rc = ib_dereg_mr(r->r.frmr.fr_mr); - if (rc) - dprintk("RPC: %s:" - " ib_dereg_mr" - " failed %i\n", - __func__, rc); - ib_free_fast_reg_page_list(r->r.frmr.fr_pgl); - break; - case RPCRDMA_MTHCAFMR: - rc = ib_dealloc_fmr(r->r.fmr); - if (rc) - dprintk("RPC: %s:" - " ib_dealloc_fmr" - " failed %i\n", - __func__, rc); - break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - rc = ib_dealloc_mw(r->r.mw); - if (rc) - dprintk("RPC: %s:" - " ib_dealloc_mw" - " failed %i\n", - __func__, rc); - break; - default: - break; - } - } rpcrdma_deregister_internal(ia, buf->rb_send_bufs[i]->rl_handle, &buf->rb_send_bufs[i]->rl_iov); @@ -1213,6 +1212,33 @@ rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) } } + while (!list_empty(&buf->rb_mws)) { + r = list_entry(buf->rb_mws.next, + struct rpcrdma_mw, mw_list); + list_del(&r->mw_list); + switch (ia->ri_memreg_strategy) { + case RPCRDMA_FRMR: + rc = ib_dereg_mr(r->r.frmr.fr_mr); + if (rc) + dprintk("RPC: %s:" + " ib_dereg_mr" + " failed %i\n", + __func__, rc); + ib_free_fast_reg_page_list(r->r.frmr.fr_pgl); + break; + case RPCRDMA_MTHCAFMR: + rc = ib_dealloc_fmr(r->r.fmr); + if (rc) + dprintk("RPC: %s:" + " ib_dealloc_fmr" + " failed %i\n", + __func__, rc); + break; + default: + break; + } + } + kfree(buf->rb_pool); } @@ -1276,21 +1302,17 @@ rpcrdma_buffer_put(struct rpcrdma_req *req) int i; unsigned long flags; - BUG_ON(req->rl_nchunks != 0); spin_lock_irqsave(&buffers->rb_lock, flags); buffers->rb_send_bufs[--buffers->rb_send_index] = req; req->rl_niovs = 0; if (req->rl_reply) { buffers->rb_recv_bufs[--buffers->rb_recv_index] = req->rl_reply; - init_waitqueue_head(&req->rl_reply->rr_unbind); req->rl_reply->rr_func = NULL; req->rl_reply = NULL; } switch (ia->ri_memreg_strategy) { case RPCRDMA_FRMR: case RPCRDMA_MTHCAFMR: - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: /* * Cycle mw's back in reverse order, and "spin" them. * This delays and scrambles reuse as much as possible. @@ -1335,8 +1357,7 @@ rpcrdma_recv_buffer_get(struct rpcrdma_req *req) /* * Put reply buffers back into pool when not attached to - * request. This happens in error conditions, and when - * aborting unbinds. Pre-decrement counter/array index. + * request. This happens in error conditions. 
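Note the structural fix above in rpcrdma_buffer_destroy(): the rb_mws drain used to sit inside the per-request loop, nested under the rb_send_bufs[i] check, so MW teardown was entangled with, and conditional on, send-buffer teardown; it now runs exactly once after the loop. A self-contained sketch of that drain idiom, with a minimal reimplementation of the kernel's list_head for illustration only:

#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

/* minimal reimplementation of the kernel's circular doubly-linked list */
struct list_head { struct list_head *next, *prev; };

static void INIT_LIST_HEAD(struct list_head *h) { h->next = h->prev = h; }
static int list_empty(const struct list_head *h) { return h->next == h; }

static void list_add(struct list_head *n, struct list_head *h)
{
	n->next = h->next;
	n->prev = h;
	h->next->prev = n;
	h->next = n;
}

static void list_del(struct list_head *e)
{
	e->prev->next = e->next;
	e->next->prev = e->prev;
}

#define list_entry(ptr, type, member) \
	((type *)((char *)(ptr) - offsetof(type, member)))

struct mw { int id; struct list_head mw_list; };

int main(void)
{
	struct list_head rb_mws;
	int i;

	INIT_LIST_HEAD(&rb_mws);
	for (i = 0; i < 3; i++) {
		struct mw *r = malloc(sizeof(*r));	/* sketch: no NULL check */
		r->id = i;
		list_add(&r->mw_list, &rb_mws);
	}
	/* the drain idiom: always detach the first entry until empty */
	while (!list_empty(&rb_mws)) {
		struct mw *r = list_entry(rb_mws.next, struct mw, mw_list);
		list_del(&r->mw_list);
		printf("releasing mw %d\n", r->id);
		free(r);
	}
	return 0;
}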
*/ void rpcrdma_recv_buffer_put(struct rpcrdma_rep *rep) @@ -1437,6 +1458,12 @@ rpcrdma_map_one(struct rpcrdma_ia *ia, struct rpcrdma_mr_seg *seg, int writing) seg->mr_dma = ib_dma_map_single(ia->ri_id->device, seg->mr_offset, seg->mr_dmalen, seg->mr_dir); + if (ib_dma_mapping_error(ia->ri_id->device, seg->mr_dma)) { + dprintk("RPC: %s: mr_dma %llx mr_offset %p mr_dma_len %zu\n", + __func__, + (unsigned long long)seg->mr_dma, + seg->mr_offset, seg->mr_dmalen); + } } static void @@ -1456,20 +1483,29 @@ rpcrdma_register_frmr_external(struct rpcrdma_mr_seg *seg, struct rpcrdma_xprt *r_xprt) { struct rpcrdma_mr_seg *seg1 = seg; - struct ib_send_wr frmr_wr, *bad_wr; + struct ib_send_wr invalidate_wr, frmr_wr, *bad_wr, *post_wr; + u8 key; int len, pageoff; int i, rc; + int seg_len; + u64 pa; + int page_no; pageoff = offset_in_page(seg1->mr_offset); seg1->mr_offset -= pageoff; /* start of page */ seg1->mr_len += pageoff; len = -pageoff; - if (*nsegs > RPCRDMA_MAX_DATA_SEGS) - *nsegs = RPCRDMA_MAX_DATA_SEGS; - for (i = 0; i < *nsegs;) { + if (*nsegs > ia->ri_max_frmr_depth) + *nsegs = ia->ri_max_frmr_depth; + for (page_no = i = 0; i < *nsegs;) { rpcrdma_map_one(ia, seg, writing); - seg1->mr_chunk.rl_mw->r.frmr.fr_pgl->page_list[i] = seg->mr_dma; + pa = seg->mr_dma; + for (seg_len = seg->mr_len; seg_len > 0; seg_len -= PAGE_SIZE) { + seg1->mr_chunk.rl_mw->r.frmr.fr_pgl-> + page_list[page_no++] = pa; + pa += PAGE_SIZE; + } len += seg->mr_len; ++seg; ++i; @@ -1481,26 +1517,50 @@ rpcrdma_register_frmr_external(struct rpcrdma_mr_seg *seg, dprintk("RPC: %s: Using frmr %p to map %d segments\n", __func__, seg1->mr_chunk.rl_mw, i); - /* Bump the key */ - key = (u8)(seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey & 0x000000FF); - ib_update_fast_reg_key(seg1->mr_chunk.rl_mw->r.frmr.fr_mr, ++key); + if (unlikely(seg1->mr_chunk.rl_mw->r.frmr.state == FRMR_IS_VALID)) { + dprintk("RPC: %s: frmr %x left valid, posting invalidate.\n", + __func__, + seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey); + /* Invalidate before using. */ + memset(&invalidate_wr, 0, sizeof invalidate_wr); + invalidate_wr.wr_id = (unsigned long)(void *)seg1->mr_chunk.rl_mw; + invalidate_wr.next = &frmr_wr; + invalidate_wr.opcode = IB_WR_LOCAL_INV; + invalidate_wr.send_flags = IB_SEND_SIGNALED; + invalidate_wr.ex.invalidate_rkey = + seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; + DECR_CQCOUNT(&r_xprt->rx_ep); + post_wr = &invalidate_wr; + } else + post_wr = &frmr_wr; /* Prepare FRMR WR */ memset(&frmr_wr, 0, sizeof frmr_wr); + frmr_wr.wr_id = (unsigned long)(void *)seg1->mr_chunk.rl_mw; frmr_wr.opcode = IB_WR_FAST_REG_MR; - frmr_wr.send_flags = 0; /* unsignaled */ - frmr_wr.wr.fast_reg.iova_start = (unsigned long)seg1->mr_dma; + frmr_wr.send_flags = IB_SEND_SIGNALED; + frmr_wr.wr.fast_reg.iova_start = seg1->mr_dma; frmr_wr.wr.fast_reg.page_list = seg1->mr_chunk.rl_mw->r.frmr.fr_pgl; - frmr_wr.wr.fast_reg.page_list_len = i; + frmr_wr.wr.fast_reg.page_list_len = page_no; frmr_wr.wr.fast_reg.page_shift = PAGE_SHIFT; - frmr_wr.wr.fast_reg.length = i << PAGE_SHIFT; + frmr_wr.wr.fast_reg.length = page_no << PAGE_SHIFT; + if (frmr_wr.wr.fast_reg.length < len) { + while (seg1->mr_nsegs--) + rpcrdma_unmap_one(ia, seg++); + return -EIO; + } + + /* Bump the key */ + key = (u8)(seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey & 0x000000FF); + ib_update_fast_reg_key(seg1->mr_chunk.rl_mw->r.frmr.fr_mr, ++key); + frmr_wr.wr.fast_reg.access_flags = (writing ? 
IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE : IB_ACCESS_REMOTE_READ); frmr_wr.wr.fast_reg.rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; DECR_CQCOUNT(&r_xprt->rx_ep); - rc = ib_post_send(ia->ri_id->qp, &frmr_wr, &bad_wr); + rc = ib_post_send(ia->ri_id->qp, post_wr, &bad_wr); if (rc) { dprintk("RPC: %s: failed ib_post_send for register," @@ -1529,8 +1589,9 @@ rpcrdma_deregister_frmr_external(struct rpcrdma_mr_seg *seg, rpcrdma_unmap_one(ia, seg++); memset(&invalidate_wr, 0, sizeof invalidate_wr); + invalidate_wr.wr_id = (unsigned long)(void *)seg1->mr_chunk.rl_mw; invalidate_wr.opcode = IB_WR_LOCAL_INV; - invalidate_wr.send_flags = 0; /* unsignaled */ + invalidate_wr.send_flags = IB_SEND_SIGNALED; invalidate_wr.ex.invalidate_rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; DECR_CQCOUNT(&r_xprt->rx_ep); @@ -1603,135 +1664,6 @@ rpcrdma_deregister_fmr_external(struct rpcrdma_mr_seg *seg, return rc; } -static int -rpcrdma_register_memwin_external(struct rpcrdma_mr_seg *seg, - int *nsegs, int writing, struct rpcrdma_ia *ia, - struct rpcrdma_xprt *r_xprt) -{ - int mem_priv = (writing ? IB_ACCESS_REMOTE_WRITE : - IB_ACCESS_REMOTE_READ); - struct ib_mw_bind param; - int rc; - - *nsegs = 1; - rpcrdma_map_one(ia, seg, writing); - param.mr = ia->ri_bind_mem; - param.wr_id = 0ULL; /* no send cookie */ - param.addr = seg->mr_dma; - param.length = seg->mr_len; - param.send_flags = 0; - param.mw_access_flags = mem_priv; - - DECR_CQCOUNT(&r_xprt->rx_ep); - rc = ib_bind_mw(ia->ri_id->qp, seg->mr_chunk.rl_mw->r.mw, ¶m); - if (rc) { - dprintk("RPC: %s: failed ib_bind_mw " - "%u@0x%llx status %i\n", - __func__, seg->mr_len, - (unsigned long long)seg->mr_dma, rc); - rpcrdma_unmap_one(ia, seg); - } else { - seg->mr_rkey = seg->mr_chunk.rl_mw->r.mw->rkey; - seg->mr_base = param.addr; - seg->mr_nsegs = 1; - } - return rc; -} - -static int -rpcrdma_deregister_memwin_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_ia *ia, - struct rpcrdma_xprt *r_xprt, void **r) -{ - struct ib_mw_bind param; - LIST_HEAD(l); - int rc; - - BUG_ON(seg->mr_nsegs != 1); - param.mr = ia->ri_bind_mem; - param.addr = 0ULL; /* unbind */ - param.length = 0; - param.mw_access_flags = 0; - if (*r) { - param.wr_id = (u64) (unsigned long) *r; - param.send_flags = IB_SEND_SIGNALED; - INIT_CQCOUNT(&r_xprt->rx_ep); - } else { - param.wr_id = 0ULL; - param.send_flags = 0; - DECR_CQCOUNT(&r_xprt->rx_ep); - } - rc = ib_bind_mw(ia->ri_id->qp, seg->mr_chunk.rl_mw->r.mw, ¶m); - rpcrdma_unmap_one(ia, seg); - if (rc) - dprintk("RPC: %s: failed ib_(un)bind_mw," - " status %i\n", __func__, rc); - else - *r = NULL; /* will upcall on completion */ - return rc; -} - -static int -rpcrdma_register_default_external(struct rpcrdma_mr_seg *seg, - int *nsegs, int writing, struct rpcrdma_ia *ia) -{ - int mem_priv = (writing ? 
IB_ACCESS_REMOTE_WRITE : - IB_ACCESS_REMOTE_READ); - struct rpcrdma_mr_seg *seg1 = seg; - struct ib_phys_buf ipb[RPCRDMA_MAX_DATA_SEGS]; - int len, i, rc = 0; - - if (*nsegs > RPCRDMA_MAX_DATA_SEGS) - *nsegs = RPCRDMA_MAX_DATA_SEGS; - for (len = 0, i = 0; i < *nsegs;) { - rpcrdma_map_one(ia, seg, writing); - ipb[i].addr = seg->mr_dma; - ipb[i].size = seg->mr_len; - len += seg->mr_len; - ++seg; - ++i; - /* Check for holes */ - if ((i < *nsegs && offset_in_page(seg->mr_offset)) || - offset_in_page((seg-1)->mr_offset+(seg-1)->mr_len)) - break; - } - seg1->mr_base = seg1->mr_dma; - seg1->mr_chunk.rl_mr = ib_reg_phys_mr(ia->ri_pd, - ipb, i, mem_priv, &seg1->mr_base); - if (IS_ERR(seg1->mr_chunk.rl_mr)) { - rc = PTR_ERR(seg1->mr_chunk.rl_mr); - dprintk("RPC: %s: failed ib_reg_phys_mr " - "%u@0x%llx (%d)... status %i\n", - __func__, len, - (unsigned long long)seg1->mr_dma, i, rc); - while (i--) - rpcrdma_unmap_one(ia, --seg); - } else { - seg1->mr_rkey = seg1->mr_chunk.rl_mr->rkey; - seg1->mr_nsegs = i; - seg1->mr_len = len; - } - *nsegs = i; - return rc; -} - -static int -rpcrdma_deregister_default_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_ia *ia) -{ - struct rpcrdma_mr_seg *seg1 = seg; - int rc; - - rc = ib_dereg_mr(seg1->mr_chunk.rl_mr); - seg1->mr_chunk.rl_mr = NULL; - while (seg1->mr_nsegs--) - rpcrdma_unmap_one(ia, seg++); - if (rc) - dprintk("RPC: %s: failed ib_dereg_mr," - " status %i\n", __func__, rc); - return rc; -} - int rpcrdma_register_external(struct rpcrdma_mr_seg *seg, int nsegs, int writing, struct rpcrdma_xprt *r_xprt) @@ -1761,16 +1693,8 @@ rpcrdma_register_external(struct rpcrdma_mr_seg *seg, rc = rpcrdma_register_fmr_external(seg, &nsegs, writing, ia); break; - /* Registration using memory windows */ - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - rc = rpcrdma_register_memwin_external(seg, &nsegs, writing, ia, r_xprt); - break; - - /* Default registration each time */ default: - rc = rpcrdma_register_default_external(seg, &nsegs, writing, ia); - break; + return -1; } if (rc) return -1; @@ -1780,7 +1704,7 @@ rpcrdma_register_external(struct rpcrdma_mr_seg *seg, int rpcrdma_deregister_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_xprt *r_xprt, void *r) + struct rpcrdma_xprt *r_xprt) { struct rpcrdma_ia *ia = &r_xprt->rx_ia; int nsegs = seg->mr_nsegs, rc; @@ -1789,9 +1713,7 @@ rpcrdma_deregister_external(struct rpcrdma_mr_seg *seg, #if RPCRDMA_PERSISTENT_REGISTRATION case RPCRDMA_ALLPHYSICAL: - BUG_ON(nsegs != 1); rpcrdma_unmap_one(ia, seg); - rc = 0; break; #endif @@ -1803,21 +1725,9 @@ rpcrdma_deregister_external(struct rpcrdma_mr_seg *seg, rc = rpcrdma_deregister_fmr_external(seg, ia); break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - rc = rpcrdma_deregister_memwin_external(seg, ia, r_xprt, &r); - break; - default: - rc = rpcrdma_deregister_default_external(seg, ia); break; } - if (r) { - struct rpcrdma_rep *rep = r; - void (*func)(struct rpcrdma_rep *) = rep->rr_func; - rep->rr_func = NULL; - func(rep); /* dereg done, callback now */ - } return nsegs; } @@ -1892,7 +1802,6 @@ rpcrdma_ep_post_recv(struct rpcrdma_ia *ia, ib_dma_sync_single_for_cpu(ia->ri_id->device, rep->rr_iov.addr, rep->rr_iov.length, DMA_BIDIRECTIONAL); - DECR_CQCOUNT(ep); rc = ib_post_recv(ia->ri_id->qp, &recv_wr, &recv_wr_fail); if (rc) diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h index c7a7eba991b..89e7cd47970 100644 --- a/net/sunrpc/xprtrdma/xprt_rdma.h +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -42,7 +42,8 @@ #include 
<linux/wait.h> /* wait_queue_head_t, etc */ #include <linux/spinlock.h> /* spinlock_t, etc */ -#include <asm/atomic.h> /* atomic_t, etc */ +#include <linux/atomic.h> /* atomic_t, etc */ +#include <linux/workqueue.h> /* struct work_struct */ #include <rdma/rdma_cm.h> /* RDMA connection api */ #include <rdma/ib_verbs.h> /* RDMA verbs api */ @@ -66,18 +67,21 @@ struct rpcrdma_ia { struct completion ri_done; int ri_async_rc; enum rpcrdma_memreg ri_memreg_strategy; + unsigned int ri_max_frmr_depth; }; /* * RDMA Endpoint -- one per transport instance */ +#define RPCRDMA_WC_BUDGET (128) +#define RPCRDMA_POLLSIZE (16) + struct rpcrdma_ep { atomic_t rep_cqcount; int rep_cqinit; int rep_connected; struct rpcrdma_ia *rep_ia; - struct ib_cq *rep_cq; struct ib_qp_init_attr rep_attr; wait_queue_head_t rep_connect_wait; struct ib_sge rep_pad; /* holds zeroed pad */ @@ -86,6 +90,9 @@ struct rpcrdma_ep { struct rpc_xprt *rep_xprt; /* for rep_func */ struct rdma_conn_param rep_remote_cma; struct sockaddr_storage rep_remote_addr; + struct delayed_work rep_connect_worker; + struct ib_wc rep_send_wcs[RPCRDMA_POLLSIZE]; + struct ib_wc rep_recv_wcs[RPCRDMA_POLLSIZE]; }; #define INIT_CQCOUNT(ep) atomic_set(&(ep)->rep_cqcount, (ep)->rep_cqinit) @@ -109,7 +116,7 @@ struct rpcrdma_ep { */ /* temporary static scatter/gather max */ -#define RPCRDMA_MAX_DATA_SEGS (8) /* max scatter/gather */ +#define RPCRDMA_MAX_DATA_SEGS (64) /* max scatter/gather */ #define RPCRDMA_MAX_SEGS (RPCRDMA_MAX_DATA_SEGS + 2) /* head+tail = 2 */ #define MAX_RPCRDMAHDR (\ /* max supported RPC/RDMA header */ \ @@ -124,7 +131,6 @@ struct rpcrdma_rep { struct rpc_xprt *rr_xprt; /* needed for request/reply matching */ void (*rr_func)(struct rpcrdma_rep *);/* called by tasklet in softint */ struct list_head rr_list; /* tasklet list */ - wait_queue_head_t rr_unbind; /* optional unbind wait */ struct ib_sge rr_iov; /* for posting */ struct ib_mr *rr_handle; /* handle for mem in rr_iov */ char rr_base[MAX_RPCRDMAHDR]; /* minimal inline receive buffer */ @@ -159,11 +165,11 @@ struct rpcrdma_mr_seg { /* chunk descriptors */ struct ib_mr *rl_mr; /* if registered directly */ struct rpcrdma_mw { /* if registered from region */ union { - struct ib_mw *mw; struct ib_fmr *fmr; struct { struct ib_fast_reg_page_list *fr_pgl; struct ib_mr *fr_mr; + enum { FRMR_IS_INVALID, FRMR_IS_VALID } state; } frmr; } r; struct list_head mw_list; @@ -206,7 +212,6 @@ struct rpcrdma_req { struct rpcrdma_buffer { spinlock_t rb_lock; /* protects indexes */ atomic_t rb_credits; /* most recent server credits */ - unsigned long rb_cwndscale; /* cached framework rpc_cwndscale */ int rb_max_requests;/* client max requests */ struct list_head rb_mws; /* optional memory windows/fmrs/frmrs */ int rb_send_index; @@ -234,13 +239,13 @@ struct rpcrdma_create_data_internal { }; #define RPCRDMA_INLINE_READ_THRESHOLD(rq) \ - (rpcx_to_rdmad(rq->rq_task->tk_xprt).inline_rsize) + (rpcx_to_rdmad(rq->rq_xprt).inline_rsize) #define RPCRDMA_INLINE_WRITE_THRESHOLD(rq)\ - (rpcx_to_rdmad(rq->rq_task->tk_xprt).inline_wsize) + (rpcx_to_rdmad(rq->rq_xprt).inline_wsize) #define RPCRDMA_INLINE_PAD_VALUE(rq)\ - rpcx_to_rdmad(rq->rq_task->tk_xprt).padding + rpcx_to_rdmad(rq->rq_xprt).padding /* * Statistics for RPCRDMA @@ -299,7 +304,7 @@ void rpcrdma_ia_close(struct rpcrdma_ia *); */ int rpcrdma_ep_create(struct rpcrdma_ep *, struct rpcrdma_ia *, struct rpcrdma_create_data_internal *); -int rpcrdma_ep_destroy(struct rpcrdma_ep *, struct rpcrdma_ia *); +void rpcrdma_ep_destroy(struct rpcrdma_ep *, struct 
rpcrdma_ia *); int rpcrdma_ep_connect(struct rpcrdma_ep *, struct rpcrdma_ia *); int rpcrdma_ep_disconnect(struct rpcrdma_ep *, struct rpcrdma_ia *); @@ -329,11 +334,12 @@ int rpcrdma_deregister_internal(struct rpcrdma_ia *, int rpcrdma_register_external(struct rpcrdma_mr_seg *, int, int, struct rpcrdma_xprt *); int rpcrdma_deregister_external(struct rpcrdma_mr_seg *, - struct rpcrdma_xprt *, void *); + struct rpcrdma_xprt *); /* * RPC/RDMA connection management calls - xprtrdma/rpc_rdma.c */ +void rpcrdma_connect_worker(struct work_struct *); void rpcrdma_conn_func(struct rpcrdma_ep *); void rpcrdma_reply_handler(struct rpcrdma_rep *); @@ -342,4 +348,11 @@ void rpcrdma_reply_handler(struct rpcrdma_rep *); */ int rpcrdma_marshal_req(struct rpc_rqst *); +/* Temporary NFS request map cache. Created in svc_rdma.c */ +extern struct kmem_cache *svc_rdma_map_cachep; +/* WR context cache. Created in svc_rdma.c */ +extern struct kmem_cache *svc_rdma_ctxt_cachep; +/* Workqueue created in svc_rdma.c */ +extern struct workqueue_struct *svc_rdma_wq; + #endif /* _LINUX_SUNRPC_XPRT_RDMA_H */ diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index e4839c07c91..be8bbd5d65e 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -19,6 +19,7 @@ */ #include <linux/types.h> +#include <linux/string.h> #include <linux/slab.h> #include <linux/module.h> #include <linux/capability.h> @@ -28,14 +29,16 @@ #include <linux/in.h> #include <linux/net.h> #include <linux/mm.h> +#include <linux/un.h> #include <linux/udp.h> #include <linux/tcp.h> #include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/addr.h> #include <linux/sunrpc/sched.h> #include <linux/sunrpc/svcsock.h> #include <linux/sunrpc/xprtsock.h> #include <linux/file.h> -#ifdef CONFIG_NFS_V4_1 +#ifdef CONFIG_SUNRPC_BACKCHANNEL #include <linux/sunrpc/bc_xprt.h> #endif @@ -44,15 +47,21 @@ #include <net/udp.h> #include <net/tcp.h> +#include <trace/events/sunrpc.h> + #include "sunrpc.h" + +static void xs_close(struct rpc_xprt *xprt); + /* * xprtsock tunables */ -unsigned int xprt_udp_slot_table_entries = RPC_DEF_SLOT_TABLE; -unsigned int xprt_tcp_slot_table_entries = RPC_DEF_SLOT_TABLE; +static unsigned int xprt_udp_slot_table_entries = RPC_DEF_SLOT_TABLE; +static unsigned int xprt_tcp_slot_table_entries = RPC_MIN_SLOT_TABLE; +static unsigned int xprt_max_tcp_slot_table_entries = RPC_MAX_SLOT_TABLE; -unsigned int xprt_min_resvport = RPC_DEF_MIN_RESVPORT; -unsigned int xprt_max_resvport = RPC_DEF_MAX_RESVPORT; +static unsigned int xprt_min_resvport = RPC_DEF_MIN_RESVPORT; +static unsigned int xprt_max_resvport = RPC_DEF_MAX_RESVPORT; #define XS_TCP_LINGER_TO (15U * HZ) static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO; @@ -70,6 +79,7 @@ static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO; static unsigned int min_slot_table_size = RPC_MIN_SLOT_TABLE; static unsigned int max_slot_table_size = RPC_MAX_SLOT_TABLE; +static unsigned int max_tcp_slot_table_limit = RPC_MAX_SLOT_TABLE_LIMIT; static unsigned int xprt_min_resvport_limit = RPC_MIN_RESVPORT; static unsigned int xprt_max_resvport_limit = RPC_MAX_RESVPORT; @@ -79,7 +89,7 @@ static struct ctl_table_header *sunrpc_table_header; * FIXME: changing the UDP slot table size should also resize the UDP * socket buffers for existing UDP transports */ -static ctl_table xs_tunables_table[] = { +static struct ctl_table xs_tunables_table[] = { { .procname = "udp_slot_table_entries", .data = &xprt_udp_slot_table_entries, @@ -99,6 +109,15 @@ static ctl_table 
xs_tunables_table[] = { .extra2 = &max_slot_table_size }, { + .procname = "tcp_max_slot_table_entries", + .data = &xprt_max_tcp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_tcp_slot_table_limit + }, + { .procname = "min_resvport", .data = &xprt_min_resvport, .maxlen = sizeof(unsigned int), @@ -126,7 +145,7 @@ static ctl_table xs_tunables_table[] = { { }, }; -static ctl_table sunrpc_table[] = { +static struct ctl_table sunrpc_table[] = { { .procname = "sunrpc", .mode = 0555, @@ -138,20 +157,6 @@ static ctl_table sunrpc_table[] = { #endif /* - * Time out for an RPC UDP socket connect. UDP socket connects are - * synchronous, but we set a timeout anyway in case of resource - * exhaustion on the local host. - */ -#define XS_UDP_CONN_TO (5U * HZ) - -/* - * Wait duration for an RPC TCP connection to be established. Solaris - * NFS over TCP uses 60 seconds, for example, which is in line with how - * long a server takes to reboot. - */ -#define XS_TCP_CONN_TO (60U * HZ) - -/* * Wait duration for a reply from the RPC portmapper. */ #define XS_BIND_TO (60U * HZ) @@ -224,7 +229,8 @@ struct sock_xprt { * State of TCP reply receive */ __be32 tcp_fraghdr, - tcp_xid; + tcp_xid, + tcp_calldir; u32 tcp_offset, tcp_reclen; @@ -248,7 +254,7 @@ struct sock_xprt { /* * Saved socket callback addresses */ - void (*old_data_ready)(struct sock *, int); + void (*old_data_ready)(struct sock *); void (*old_state_change)(struct sock *); void (*old_write_space)(struct sock *); void (*old_error_report)(struct sock *); @@ -269,11 +275,21 @@ struct sock_xprt { */ #define TCP_RPC_REPLY (1UL << 6) +static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) +{ + return (struct rpc_xprt *) sk->sk_user_data; +} + static inline struct sockaddr *xs_addr(struct rpc_xprt *xprt) { return (struct sockaddr *) &xprt->addr; } +static inline struct sockaddr_un *xs_addr_un(struct rpc_xprt *xprt) +{ + return (struct sockaddr_un *) &xprt->addr; +} + static inline struct sockaddr_in *xs_addr_in(struct rpc_xprt *xprt) { return (struct sockaddr_in *) &xprt->addr; @@ -289,23 +305,34 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt) struct sockaddr *sap = xs_addr(xprt); struct sockaddr_in6 *sin6; struct sockaddr_in *sin; + struct sockaddr_un *sun; char buf[128]; - (void)rpc_ntop(sap, buf, sizeof(buf)); - xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); - switch (sap->sa_family) { + case AF_LOCAL: + sun = xs_addr_un(xprt); + strlcpy(buf, sun->sun_path, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); + break; case AF_INET: + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); sin = xs_addr_in(xprt); snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr)); break; case AF_INET6: + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); sin6 = xs_addr_in6(xprt); snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr); break; default: BUG(); } + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); } @@ -372,8 +399,10 @@ static int xs_send_kvec(struct socket *sock, struct sockaddr *addr, int addrlen, return kernel_sendmsg(sock, &msg, NULL, 0, 0); } -static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned int base, int more) +static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, 
unsigned int base, int more, bool zerocopy) { + ssize_t (*do_sendpage)(struct socket *sock, struct page *page, + int offset, size_t size, int flags); struct page **ppage; unsigned int remainder; int err, sent = 0; @@ -382,6 +411,9 @@ static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned i base += xdr->page_base; ppage = xdr->pages + (base >> PAGE_SHIFT); base &= ~PAGE_MASK; + do_sendpage = sock->ops->sendpage; + if (!zerocopy) + do_sendpage = sock_no_sendpage; for(;;) { unsigned int len = min_t(unsigned int, PAGE_SIZE - base, remainder); int flags = XS_SENDMSG_FLAGS; @@ -389,7 +421,7 @@ static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned i remainder -= len; if (remainder != 0 || more) flags |= MSG_MORE; - err = sock->ops->sendpage(sock, *ppage, base, len, flags); + err = do_sendpage(sock, *ppage, base, len, flags); if (remainder == 0 || err != len) break; sent += err; @@ -410,9 +442,10 @@ static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned i * @addrlen: UDP only -- length of destination address * @xdr: buffer containing this request * @base: starting position in the buffer + * @zerocopy: true if it is safe to use sendpage() * */ -static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base) +static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base, bool zerocopy) { unsigned int remainder = xdr->len - base; int err, sent = 0; @@ -440,7 +473,7 @@ static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, if (base < xdr->page_len) { unsigned int len = xdr->page_len - base; remainder -= len; - err = xs_send_pagedata(sock, xdr, base, remainder != 0); + err = xs_send_pagedata(sock, xdr, base, remainder != 0, zerocopy); if (remainder == 0 || err != len) goto out; sent += err; @@ -477,7 +510,8 @@ static int xs_nospace(struct rpc_task *task) struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - int ret = 0; + struct sock *sk = transport->inet; + int ret = -EAGAIN; dprintk("RPC: %5u xmit incomplete (%u left of %u)\n", task->tk_pid, req->rq_slen - req->rq_bytes_sent, @@ -489,13 +523,12 @@ static int xs_nospace(struct rpc_task *task) /* Don't race with disconnect */ if (xprt_connected(xprt)) { if (test_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags)) { - ret = -EAGAIN; /* * Notify TCP that we're limited by the application * window size */ set_bit(SOCK_NOSPACE, &transport->sock->flags); - transport->inet->sk_write_pending++; + sk->sk_write_pending++; /* ...and wait for more buffer space */ xprt_wait_for_buffer_space(task, xs_nospace_callback); } @@ -505,9 +538,76 @@ static int xs_nospace(struct rpc_task *task) } spin_unlock_bh(&xprt->transport_lock); + + /* Race breaker in case memory is freed before above code is called */ + sk->sk_write_space(sk); return ret; } +/* + * Construct a stream transport record marker in @buf. 
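Earlier in this hunk, the new zerocopy parameter lets xs_send_pagedata() pick the page-send implementation once, up front: sock->ops->sendpage when handing page references to the stack is safe, sock_no_sendpage (which copies) when it is not. A userspace model of that select-the-function-pointer-once pattern; the names and the bounce buffer are illustrative, not kernel API:

#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <string.h>

typedef long (*sendpage_fn)(const char *page, size_t off, size_t len);

static long sendpage_zerocopy(const char *page, size_t off, size_t len)
{
	/* model: the device would take a reference on the page */
	printf("zerocopy send: %zu bytes at offset %zu\n", len, off);
	return (long)len;
}

static long sendpage_copy(const char *page, size_t off, size_t len)
{
	char bounce[4096];

	/* model of sock_no_sendpage(): copy into a private buffer first */
	memcpy(bounce, page + off, len);
	printf("copied send: %zu bytes at offset %zu\n", len, off);
	return (long)len;
}

int main(void)
{
	static const char page[4096] = "request payload";
	bool zerocopy = false;	/* e.g. this request is a resend */
	sendpage_fn do_sendpage;

	do_sendpage = zerocopy ? sendpage_zerocopy : sendpage_copy;
	return do_sendpage(page, 0, 15) == 15 ? 0 : 1;
}

Selecting the callee once keeps the per-page send loop free of repeated branching on the same condition.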
+ */ +static inline void xs_encode_stream_record_marker(struct xdr_buf *buf) +{ + u32 reclen = buf->len - sizeof(rpc_fraghdr); + rpc_fraghdr *base = buf->head[0].iov_base; + *base = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | reclen); +} + +/** + * xs_local_send_request - write an RPC request to an AF_LOCAL socket + * @task: RPC task that manages the state of an RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occurred, the request was not sent + */ +static int xs_local_send_request(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + int status; + + xs_encode_stream_record_marker(&req->rq_snd_buf); + + xs_pktdump("packet data:", + req->rq_svec->iov_base, req->rq_svec->iov_len); + + status = xs_sendpages(transport->sock, NULL, 0, + xdr, req->rq_bytes_sent, true); + dprintk("RPC: %s(%u) = %d\n", + __func__, xdr->len - req->rq_bytes_sent, status); + if (likely(status >= 0)) { + req->rq_bytes_sent += status; + req->rq_xmit_bytes_sent += status; + if (likely(req->rq_bytes_sent >= req->rq_slen)) { + req->rq_bytes_sent = 0; + return 0; + } + status = -EAGAIN; + } + + switch (status) { + case -EAGAIN: + status = xs_nospace(task); + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + case -EPIPE: + xs_close(xprt); + status = -ENOTCONN; + } + + return status; +} + /** * xs_udp_send_request - write an RPC request to a UDP socket * @task: address of RPC task that manages the state of an RPC request @@ -517,7 +617,7 @@ static int xs_nospace(struct rpc_task *task) * EAGAIN: The socket was blocked, please call again later to * complete the request * ENOTCONN: Caller needs to invoke connect logic then call again - * other: Some other error occured, the request was not sent + * other: Some other error occurred, the request was not sent */ static int xs_udp_send_request(struct rpc_task *task) { @@ -536,13 +636,13 @@ static int xs_udp_send_request(struct rpc_task *task) status = xs_sendpages(transport->sock, xs_addr(xprt), xprt->addrlen, xdr, - req->rq_bytes_sent); + req->rq_bytes_sent, true); dprintk("RPC: xs_udp_send_request(%u) = %d\n", xdr->len - req->rq_bytes_sent, status); if (status >= 0) { - task->tk_bytes_sent += status; + req->rq_xmit_bytes_sent += status; if (status >= req->rq_slen) return 0; /* Still some bytes left; set up for a retry later.
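xs_encode_stream_record_marker() above writes the RFC 5531 record mark used on stream transports: one big-endian 32-bit word whose top bit flags the last fragment and whose low 31 bits carry the fragment length, excluding the 4-byte mark itself. A userspace rendering of the same arithmetic:

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>

#define RPC_LAST_STREAM_FRAGMENT (1U << 31)

int main(void)
{
	uint32_t buf_len = 132;			/* 4-byte marker + 128-byte payload */
	uint32_t reclen = buf_len - sizeof(uint32_t);
	uint32_t marker = htonl(RPC_LAST_STREAM_FRAGMENT | reclen);

	/* prints 0x80000080: last-fragment bit set, length 128 */
	printf("on-the-wire marker: 0x%08x\n", ntohl(marker));
	return 0;
}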
*/ @@ -583,15 +683,10 @@ static void xs_tcp_shutdown(struct rpc_xprt *xprt) struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct socket *sock = transport->sock; - if (sock != NULL) + if (sock != NULL) { kernel_sock_shutdown(sock, SHUT_WR); -} - -static inline void xs_encode_tcp_record_marker(struct xdr_buf *buf) -{ - u32 reclen = buf->len - sizeof(rpc_fraghdr); - rpc_fraghdr *base = buf->head[0].iov_base; - *base = htonl(RPC_LAST_STREAM_FRAGMENT | reclen); + trace_rpc_socket_shutdown(xprt, sock); + } } /** @@ -603,7 +698,7 @@ static inline void xs_encode_tcp_record_marker(struct xdr_buf *buf) * EAGAIN: The socket was blocked, please call again later to * complete the request * ENOTCONN: Caller needs to invoke connect logic then call again - * other: Some other error occured, the request was not sent + * other: Some other error occurred, the request was not sent * * XXX: In the case of soft timeouts, should we eventually give up * if sendmsg is not able to make progress? @@ -614,20 +709,28 @@ static int xs_tcp_send_request(struct rpc_task *task) struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct xdr_buf *xdr = &req->rq_snd_buf; + bool zerocopy = true; int status; - xs_encode_tcp_record_marker(&req->rq_snd_buf); + xs_encode_stream_record_marker(&req->rq_snd_buf); xs_pktdump("packet data:", req->rq_svec->iov_base, req->rq_svec->iov_len); + /* Don't use zero copy if this is a resend. If the RPC call + * completes while the socket holds a reference to the pages, + * then we may end up resending corrupted data. + */ + if (task->tk_flags & RPC_TASK_SENT) + zerocopy = false; /* Continue transmitting the packet/record. We must be careful * to cope with writespace callbacks arriving _after_ we have * called sendmsg(). */ while (1) { status = xs_sendpages(transport->sock, - NULL, 0, xdr, req->rq_bytes_sent); + NULL, 0, xdr, req->rq_bytes_sent, + zerocopy); dprintk("RPC: xs_tcp_send_request(%u) = %d\n", xdr->len - req->rq_bytes_sent, status); @@ -638,7 +741,7 @@ static int xs_tcp_send_request(struct rpc_task *task) /* If we've sent the entire packet, immediately * reset the count of bytes sent. */ req->rq_bytes_sent += status; - task->tk_bytes_sent += status; + req->rq_xmit_bytes_sent += status; if (likely(req->rq_bytes_sent >= req->rq_slen)) { req->rq_bytes_sent = 0; return 0; @@ -662,10 +765,10 @@ static int xs_tcp_send_request(struct rpc_task *task) dprintk("RPC: sendmsg returned unrecognized error %d\n", -status); case -ECONNRESET: - case -EPIPE: xs_tcp_shutdown(xprt); case -ECONNREFUSED: case -ENOTCONN: + case -EPIPE: clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); } @@ -690,11 +793,13 @@ static void xs_tcp_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) if (task == NULL) goto out_release; req = task->tk_rqstp; + if (req == NULL) + goto out_release; if (req->rq_bytes_sent == 0) goto out_release; if (req->rq_bytes_sent == req->rq_snd_buf.len) goto out_release; - set_bit(XPRT_CLOSE_WAIT, &task->tk_xprt->state); + set_bit(XPRT_CLOSE_WAIT, &xprt->state); out_release: xprt_release_xprt(xprt, task); } @@ -715,6 +820,33 @@ static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *s sk->sk_error_report = transport->old_error_report; } +/** + * xs_error_report - callback to handle TCP socket state errors + * @sk: socket + * + * Note: we don't call sock_error() since there may be a rpc_task + * using the socket, and so we don't want to clear sk->sk_err. 
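The comment above introduces the relocated xs_error_report(), whose body follows: it deliberately reads sk->sk_err without clearing it, because a sock_error()-style consume-on-read would hide the error from the rpc_task that still owns the socket. Userspace sees the same consume-on-read behavior through getsockopt(SO_ERROR), as this small sketch notes:

#include <stdio.h>
#include <sys/socket.h>

/* getsockopt(SO_ERROR) is userspace's sock_error(): it returns the
 * pending error *and clears it*, which is precisely the behavior the
 * in-kernel xs_error_report() avoids by reading sk->sk_err directly. */
static int take_socket_error(int fd)
{
	int err = 0;
	socklen_t len = sizeof(err);

	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &len) < 0)
		return -1;
	return err;	/* a second call now reports 0 */
}

int main(void)
{
	int fd = socket(AF_INET, SOCK_DGRAM, 0);

	printf("pending error: %d\n", take_socket_error(fd));
	return 0;
}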
+ */ +static void xs_error_report(struct sock *sk) +{ + struct rpc_xprt *xprt; + int err; + + read_lock_bh(&sk->sk_callback_lock); + if (!(xprt = xprt_from_sock(sk))) + goto out; + + err = -sk->sk_err; + if (err == 0) + goto out; + dprintk("RPC: xs_error_report client %p, error=%d...\n", + xprt, -err); + trace_rpc_socket_error(xprt, sk->sk_socket, err); + xprt_wake_pending_tasks(xprt, err); + out: + read_unlock_bh(&sk->sk_callback_lock); +} + static void xs_reset_transport(struct sock_xprt *transport) { struct socket *sock = transport->sock; @@ -723,6 +855,8 @@ static void xs_reset_transport(struct sock_xprt *transport) if (sk == NULL) return; + transport->srcport = 0; + write_lock_bh(&sk->sk_callback_lock); transport->inet = NULL; transport->sock = NULL; @@ -732,8 +866,7 @@ static void xs_reset_transport(struct sock_xprt *transport) xs_restore_old_callbacks(transport, sk); write_unlock_bh(&sk->sk_callback_lock); - sk->sk_no_check = 0; - + trace_rpc_socket_close(&transport->xprt, sock); sock_release(sock); } @@ -753,14 +886,16 @@ static void xs_close(struct rpc_xprt *xprt) dprintk("RPC: xs_close xprt %p\n", xprt); + cancel_delayed_work_sync(&transport->connect_worker); + xs_reset_transport(transport); xprt->reestablish_timeout = 0; - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); clear_bit(XPRT_CLOSE_WAIT, &xprt->state); clear_bit(XPRT_CLOSING, &xprt->state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); xprt_disconnect_done(xprt); } @@ -772,6 +907,12 @@ static void xs_tcp_close(struct rpc_xprt *xprt) xs_tcp_shutdown(xprt); } +static void xs_xprt_free(struct rpc_xprt *xprt) +{ + xs_free_peer_addresses(xprt); + xprt_free(xprt); +} + /** * xs_destroy - prepare to shutdown a transport * @xprt: doomed transport @@ -779,22 +920,90 @@ static void xs_tcp_close(struct rpc_xprt *xprt) */ static void xs_destroy(struct rpc_xprt *xprt) { - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - dprintk("RPC: xs_destroy xprt %p\n", xprt); - cancel_rearming_delayed_work(&transport->connect_worker); - xs_close(xprt); - xs_free_peer_addresses(xprt); - kfree(xprt->slot); - kfree(xprt); + xs_xprt_free(xprt); module_put(THIS_MODULE); } -static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) +static int xs_local_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) { - return (struct rpc_xprt *) sk->sk_user_data; + struct xdr_skb_reader desc = { + .skb = skb, + .offset = sizeof(rpc_fraghdr), + .count = skb->len - sizeof(rpc_fraghdr), + }; + + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) + return -1; + if (desc.count) + return -1; + return 0; +} + +/** + * xs_local_data_ready - "data ready" callback for AF_LOCAL sockets + * @sk: socket with data to read + * + * Currently this assumes we can read the whole reply in a single gulp. + */ +static void xs_local_data_ready(struct sock *sk) +{ + struct rpc_task *task; + struct rpc_xprt *xprt; + struct rpc_rqst *rovr; + struct sk_buff *skb; + int err, repsize, copied; + u32 _xid; + __be32 *xp; + + read_lock_bh(&sk->sk_callback_lock); + dprintk("RPC: %s...\n", __func__); + xprt = xprt_from_sock(sk); + if (xprt == NULL) + goto out; + + skb = skb_recv_datagram(sk, 0, 1, &err); + if (skb == NULL) + goto out; + + repsize = skb->len - sizeof(rpc_fraghdr); + if (repsize < 4) { + dprintk("RPC: impossible RPC reply size %d\n", repsize); + goto dropit; + } + + /* Copy the XID from the skb...
*/ + xp = skb_header_pointer(skb, sizeof(rpc_fraghdr), sizeof(_xid), &_xid); + if (xp == NULL) + goto dropit; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->transport_lock); + rovr = xprt_lookup_rqst(xprt, *xp); + if (!rovr) + goto out_unlock; + task = rovr->rq_task; + + copied = rovr->rq_private_buf.buflen; + if (copied > repsize) + copied = repsize; + + if (xs_local_copy_to_xdr(&rovr->rq_private_buf, skb)) { + dprintk("RPC: sk_buff copy failed\n"); + goto out_unlock; + } + + xprt_complete_rqst(task, copied); + + out_unlock: + spin_unlock(&xprt->transport_lock); + dropit: + skb_free_datagram(sk, skb); + out: + read_unlock_bh(&sk->sk_callback_lock); } /** @@ -803,7 +1012,7 @@ static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) * @len: how much data to read * */ -static void xs_udp_data_ready(struct sock *sk, int len) +static void xs_udp_data_ready(struct sock *sk) { struct rpc_task *task; struct rpc_xprt *xprt; @@ -813,7 +1022,7 @@ static void xs_udp_data_ready(struct sock *sk, int len) u32 _xid; __be32 *xp; - read_lock(&sk->sk_callback_lock); + read_lock_bh(&sk->sk_callback_lock); dprintk("RPC: xs_udp_data_ready...\n"); if (!(xprt = xprt_from_sock(sk))) goto out; @@ -821,9 +1030,6 @@ static void xs_udp_data_ready(struct sock *sk, int len) if ((skb = skb_recv_datagram(sk, 0, 1, &err)) == NULL) goto out; - if (xprt->shutdown) - goto dropit; - repsize = skb->len - sizeof(struct udphdr); if (repsize < 4) { dprintk("RPC: impossible RPC reply size %d!\n", repsize); @@ -854,11 +1060,7 @@ static void xs_udp_data_ready(struct sock *sk, int len) UDPX_INC_STATS_BH(sk, UDP_MIB_INDATAGRAMS); - /* Something worked... */ - dst_confirm(skb_dst(skb)); - - xprt_adjust_cwnd(task, copied); - xprt_update_rtt(task); + xprt_adjust_cwnd(xprt, task, copied); xprt_complete_rqst(task, copied); out_unlock: @@ -866,7 +1068,17 @@ static void xs_udp_data_ready(struct sock *sk, int len) dropit: skb_free_datagram(sk, skb); out: - read_unlock(&sk->sk_callback_lock); + read_unlock_bh(&sk->sk_callback_lock); +} + +/* + * Helper function to force a TCP close if the server is sending + * junk and/or it has put us in CLOSE_WAIT + */ +static void xs_tcp_force_close(struct rpc_xprt *xprt) +{ + set_bit(XPRT_CONNECTION_CLOSE, &xprt->state); + xprt_force_disconnect(xprt); } static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc) @@ -895,7 +1107,7 @@ static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_rea /* Sanity check of the record length */ if (unlikely(transport->tcp_reclen < 8)) { dprintk("RPC: invalid TCP record fragment length\n"); - xprt_force_disconnect(xprt); + xs_tcp_force_close(xprt); return; } dprintk("RPC: reading TCP record fragment of length %d\n", @@ -942,7 +1154,7 @@ static inline void xs_tcp_read_calldir(struct sock_xprt *transport, { size_t len, used; u32 offset; - __be32 calldir; + char *p; /* * We want transport->tcp_offset to be 8 at the end of this routine @@ -951,26 +1163,33 @@ static inline void xs_tcp_read_calldir(struct sock_xprt *transport, * transport->tcp_offset is 4 (after having already read the xid). 
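The calldir word being accumulated here is the msg_type field of the RPC header (RFC 5531: CALL is 0, REPLY is 1); the switch that follows treats any other value as a protocol violation and forces the connection closed. A compact userspace version of that classification:

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>

enum rpc_msg_type { RPC_CALL = 0, RPC_REPLY = 1 };

/* returns 1 for a reply, 0 for a (backchannel) call, -1 for garbage */
static int classify_calldir(uint32_t calldir_on_wire)
{
	switch (ntohl(calldir_on_wire)) {
	case RPC_REPLY:
		return 1;
	case RPC_CALL:
		return 0;
	default:
		return -1;	/* caller should force-close the transport */
	}
}

int main(void)
{
	printf("%d %d %d\n",
	       classify_calldir(htonl(RPC_REPLY)),
	       classify_calldir(htonl(RPC_CALL)),
	       classify_calldir(htonl(7)));
	return 0;
}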
*/ offset = transport->tcp_offset - sizeof(transport->tcp_xid); - len = sizeof(calldir) - offset; + len = sizeof(transport->tcp_calldir) - offset; dprintk("RPC: reading CALL/REPLY flag (%Zu bytes)\n", len); - used = xdr_skb_read_bits(desc, &calldir, len); + p = ((char *) &transport->tcp_calldir) + offset; + used = xdr_skb_read_bits(desc, p, len); transport->tcp_offset += used; if (used != len) return; transport->tcp_flags &= ~TCP_RCV_READ_CALLDIR; - transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; - transport->tcp_flags |= TCP_RCV_COPY_DATA; /* * We don't yet have the XDR buffer, so we will write the calldir * out after we get the buffer from the 'struct rpc_rqst' */ - if (ntohl(calldir) == RPC_REPLY) + switch (ntohl(transport->tcp_calldir)) { + case RPC_REPLY: + transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; + transport->tcp_flags |= TCP_RCV_COPY_DATA; transport->tcp_flags |= TCP_RPC_REPLY; - else + break; + case RPC_CALL: + transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; + transport->tcp_flags |= TCP_RCV_COPY_DATA; transport->tcp_flags &= ~TCP_RPC_REPLY; - dprintk("RPC: reading %s CALL/REPLY flag %08x\n", - (transport->tcp_flags & TCP_RPC_REPLY) ? - "reply for" : "request with", calldir); + break; + default: + dprintk("RPC: invalid request message type\n"); + xs_tcp_force_close(&transport->xprt); + } xs_tcp_check_fraghdr(transport); } @@ -990,12 +1209,10 @@ static inline void xs_tcp_read_common(struct rpc_xprt *xprt, /* * Save the RPC direction in the XDR buffer */ - __be32 calldir = transport->tcp_flags & TCP_RPC_REPLY ? - htonl(RPC_REPLY) : 0; - memcpy(rcvbuf->head[0].iov_base + transport->tcp_copied, - &calldir, sizeof(calldir)); - transport->tcp_copied += sizeof(calldir); + &transport->tcp_calldir, + sizeof(transport->tcp_calldir)); + transport->tcp_copied += sizeof(transport->tcp_calldir); transport->tcp_flags &= ~TCP_RCV_COPY_CALLDIR; } @@ -1050,8 +1267,6 @@ static inline void xs_tcp_read_common(struct rpc_xprt *xprt, if (transport->tcp_flags & TCP_RCV_LAST_FRAG) transport->tcp_flags &= ~TCP_RCV_COPY_DATA; } - - return; } /* @@ -1086,7 +1301,7 @@ static inline int xs_tcp_read_reply(struct rpc_xprt *xprt, return 0; } -#if defined(CONFIG_NFS_V4_1) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) /* * Obtains an rpc_rqst previously allocated and invokes the common * tcp read code to read the data. The result is placed in the callback @@ -1094,41 +1309,29 @@ static inline int xs_tcp_read_reply(struct rpc_xprt *xprt, * If we're unable to obtain the rpc_rqst we schedule the closing of the * connection and return -1. */ -static inline int xs_tcp_read_callback(struct rpc_xprt *xprt, +static int xs_tcp_read_callback(struct rpc_xprt *xprt, struct xdr_skb_reader *desc) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct rpc_rqst *req; - req = xprt_alloc_bc_request(xprt); + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->transport_lock); + req = xprt_lookup_bc_request(xprt, transport->tcp_xid); if (req == NULL) { + spin_unlock(&xprt->transport_lock); printk(KERN_WARNING "Callback slot table overflowed\n"); xprt_force_disconnect(xprt); return -1; } - req->rq_xid = transport->tcp_xid; dprintk("RPC: read callback XID %08x\n", ntohl(req->rq_xid)); xs_tcp_read_common(xprt, desc, req); - if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) { - struct svc_serv *bc_serv = xprt->bc_serv; - - /* - * Add callback request to callback list. The callback - * service sleeps on the sv_cb_waitq waiting for new - * requests. 
Wake it up after adding enqueing the - * request. - */ - dprintk("RPC: add callback request to list\n"); - spin_lock(&bc_serv->sv_cb_lock); - list_add(&req->rq_bc_list, &bc_serv->sv_cb_list); - spin_unlock(&bc_serv->sv_cb_lock); - wake_up(&bc_serv->sv_cb_waitq); - } - - req->rq_private_buf.len = transport->tcp_copied; + if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) + xprt_complete_bc_request(req, transport->tcp_copied); + spin_unlock(&xprt->transport_lock); return 0; } @@ -1149,7 +1352,7 @@ static inline int _xs_tcp_read_data(struct rpc_xprt *xprt, { return xs_tcp_read_reply(xprt, desc); } -#endif /* CONFIG_NFS_V4_1 */ +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ /* * Read data off the transport. This can be either an RPC_CALL or an @@ -1232,7 +1435,7 @@ static int xs_tcp_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, uns * @bytes: how much data to read * */ -static void xs_tcp_data_ready(struct sock *sk, int bytes) +static void xs_tcp_data_ready(struct sock *sk) { struct rpc_xprt *xprt; read_descriptor_t rd_desc; @@ -1240,12 +1443,9 @@ static void xs_tcp_data_ready(struct sock *sk, int bytes) dprintk("RPC: xs_tcp_data_ready...\n"); - read_lock(&sk->sk_callback_lock); + read_lock_bh(&sk->sk_callback_lock); if (!(xprt = xprt_from_sock(sk))) goto out; - if (xprt->shutdown) - goto out; - /* Any data means we had a useful conversation, so * the we don't need to delay the next reconnect */ @@ -1259,7 +1459,7 @@ static void xs_tcp_data_ready(struct sock *sk, int bytes) read = tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv); } while (read > 0); out: - read_unlock(&sk->sk_callback_lock); + read_unlock_bh(&sk->sk_callback_lock); } /* @@ -1293,12 +1493,19 @@ static void xs_tcp_cancel_linger_timeout(struct rpc_xprt *xprt) xprt_clear_connecting(xprt); } -static void xs_sock_mark_closed(struct rpc_xprt *xprt) +static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt) { - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); + clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); + clear_bit(XPRT_CONNECTION_CLOSE, &xprt->state); clear_bit(XPRT_CLOSE_WAIT, &xprt->state); clear_bit(XPRT_CLOSING, &xprt->state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); +} + +static void xs_sock_mark_closed(struct rpc_xprt *xprt) +{ + xs_sock_reset_connection_flags(xprt); /* Mark transport as closed and wake up all pending tasks */ xprt_disconnect_done(xprt); } @@ -1312,18 +1519,20 @@ static void xs_tcp_state_change(struct sock *sk) { struct rpc_xprt *xprt; - read_lock(&sk->sk_callback_lock); + read_lock_bh(&sk->sk_callback_lock); if (!(xprt = xprt_from_sock(sk))) goto out; dprintk("RPC: xs_tcp_state_change client %p...\n", xprt); - dprintk("RPC: state %x conn %d dead %d zapped %d\n", + dprintk("RPC: state %x conn %d dead %d zapped %d sk_shutdown %d\n", sk->sk_state, xprt_connected(xprt), sock_flag(sk, SOCK_DEAD), - sock_flag(sk, SOCK_ZAPPED)); + sock_flag(sk, SOCK_ZAPPED), + sk->sk_shutdown); + trace_rpc_socket_state_change(xprt, sk->sk_socket); switch (sk->sk_state) { case TCP_ESTABLISHED: - spin_lock_bh(&xprt->transport_lock); + spin_lock(&xprt->transport_lock); if (!xprt_test_and_set_connected(xprt)) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); @@ -1334,27 +1543,28 @@ static void xs_tcp_state_change(struct sock *sk) transport->tcp_copied = 0; transport->tcp_flags = TCP_RCV_COPY_FRAGHDR | TCP_RCV_COPY_XID; + xprt->connect_cookie++; xprt_wake_pending_tasks(xprt, -EAGAIN); } - spin_unlock_bh(&xprt->transport_lock); + spin_unlock(&xprt->transport_lock); break; case 
TCP_FIN_WAIT1: /* The client initiated a shutdown of the socket */ xprt->connect_cookie++; xprt->reestablish_timeout = 0; set_bit(XPRT_CLOSING, &xprt->state); - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(XPRT_CONNECTED, &xprt->state); clear_bit(XPRT_CLOSE_WAIT, &xprt->state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); xs_tcp_schedule_linger_timeout(xprt, xs_tcp_fin_timeout); break; case TCP_CLOSE_WAIT: /* The server initiated a shutdown of the socket */ - xprt_force_disconnect(xprt); - case TCP_SYN_SENT: xprt->connect_cookie++; + clear_bit(XPRT_CONNECTED, &xprt->state); + xs_tcp_force_close(xprt); case TCP_CLOSING: /* * If the server closed down the connection, make sure that @@ -1366,35 +1576,16 @@ static void xs_tcp_state_change(struct sock *sk) case TCP_LAST_ACK: set_bit(XPRT_CLOSING, &xprt->state); xs_tcp_schedule_linger_timeout(xprt, xs_tcp_fin_timeout); - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(XPRT_CONNECTED, &xprt->state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); break; case TCP_CLOSE: xs_tcp_cancel_linger_timeout(xprt); xs_sock_mark_closed(xprt); } out: - read_unlock(&sk->sk_callback_lock); -} - -/** - * xs_error_report - callback mainly for catching socket errors - * @sk: socket - */ -static void xs_error_report(struct sock *sk) -{ - struct rpc_xprt *xprt; - - read_lock(&sk->sk_callback_lock); - if (!(xprt = xprt_from_sock(sk))) - goto out; - dprintk("RPC: %s client %p...\n" - "RPC: error %d\n", - __func__, xprt, sk->sk_err); - xprt_wake_pending_tasks(xprt, -EAGAIN); -out: - read_unlock(&sk->sk_callback_lock); + read_unlock_bh(&sk->sk_callback_lock); } static void xs_write_space(struct sock *sk) @@ -1426,13 +1617,13 @@ static void xs_write_space(struct sock *sk) */ static void xs_udp_write_space(struct sock *sk) { - read_lock(&sk->sk_callback_lock); + read_lock_bh(&sk->sk_callback_lock); /* from net/core/sock.c:sock_def_write_space */ if (sock_writeable(sk)) xs_write_space(sk); - read_unlock(&sk->sk_callback_lock); + read_unlock_bh(&sk->sk_callback_lock); } /** @@ -1447,13 +1638,13 @@ static void xs_udp_write_space(struct sock *sk) */ static void xs_tcp_write_space(struct sock *sk) { - read_lock(&sk->sk_callback_lock); + read_lock_bh(&sk->sk_callback_lock); /* from net/core/stream.c:sk_stream_write_space */ - if (sk_stream_wspace(sk) >= sk_stream_min_wspace(sk)) + if (sk_stream_is_writeable(sk)) xs_write_space(sk); - read_unlock(&sk->sk_callback_lock); + read_unlock_bh(&sk->sk_callback_lock); } static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt) @@ -1500,15 +1691,15 @@ static void xs_udp_set_buffer_size(struct rpc_xprt *xprt, size_t sndsize, size_t * * Adjust the congestion window after a retransmit timeout has occurred. 
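Just below, xs_get_random_port() swaps net_random() for prandom_u32(). The arithmetic deserves a note: with range = xprt_max_resvport - xprt_min_resvport, the modulo lands in [0, range - 1], so the chosen port lies in [min, max - 1] and the top reserved port is never selected. A userspace sketch with libc random() standing in for prandom_u32(), using the RPC default bounds of 665 and 1023:

#include <stdio.h>
#include <stdlib.h>

int main(void)
{
	unsigned short min_port = 665;	/* RPC_DEF_MIN_RESVPORT */
	unsigned short max_port = 1023;	/* RPC_DEF_MAX_RESVPORT */
	unsigned short range = max_port - min_port;
	unsigned short rand;

	srandom(1);
	rand = (unsigned short)(random() % range);
	/* result is in [665, 1022]; 1023 itself can never be picked */
	printf("source port: %u\n", rand + min_port);
	return 0;
}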
*/ -static void xs_udp_timer(struct rpc_task *task) +static void xs_udp_timer(struct rpc_xprt *xprt, struct rpc_task *task) { - xprt_adjust_cwnd(task, -ETIMEDOUT); + xprt_adjust_cwnd(xprt, task, -ETIMEDOUT); } static unsigned short xs_get_random_port(void) { unsigned short range = xprt_max_resvport - xprt_min_resvport; - unsigned short rand = (unsigned short) net_random() % range; + unsigned short rand = (unsigned short) prandom_u32() % range; return rand + xprt_min_resvport; } @@ -1526,7 +1717,7 @@ static void xs_set_port(struct rpc_xprt *xprt, unsigned short port) xs_update_peer_port(xprt); } -static unsigned short xs_get_srcport(struct sock_xprt *transport, struct socket *sock) +static unsigned short xs_get_srcport(struct sock_xprt *transport) { unsigned short port = transport->srcport; @@ -1535,7 +1726,7 @@ static unsigned short xs_get_srcport(struct sock_xprt *transport, struct socket return port; } -static unsigned short xs_next_srcport(struct sock_xprt *transport, struct socket *sock, unsigned short port) +static unsigned short xs_next_srcport(struct sock_xprt *transport, unsigned short port) { if (transport->srcport != 0) transport->srcport = 0; @@ -1545,23 +1736,18 @@ static unsigned short xs_next_srcport(struct sock_xprt *transport, struct socket return xprt_max_resvport; return --port; } - -static int xs_bind4(struct sock_xprt *transport, struct socket *sock) +static int xs_bind(struct sock_xprt *transport, struct socket *sock) { - struct sockaddr_in myaddr = { - .sin_family = AF_INET, - }; - struct sockaddr_in *sa; + struct sockaddr_storage myaddr; int err, nloop = 0; - unsigned short port = xs_get_srcport(transport, sock); + unsigned short port = xs_get_srcport(transport); unsigned short last; - sa = (struct sockaddr_in *)&transport->srcaddr; - myaddr.sin_addr = sa->sin_addr; + memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen); do { - myaddr.sin_port = htons(port); - err = kernel_bind(sock, (struct sockaddr *) &myaddr, - sizeof(myaddr)); + rpc_set_port((struct sockaddr *)&myaddr, port); + err = kernel_bind(sock, (struct sockaddr *)&myaddr, + transport->xprt.addrlen); if (port == 0) break; if (err == 0) { @@ -1569,57 +1755,52 @@ static int xs_bind4(struct sock_xprt *transport, struct socket *sock) break; } last = port; - port = xs_next_srcport(transport, sock, port); + port = xs_next_srcport(transport, port); if (port > last) nloop++; } while (err == -EADDRINUSE && nloop != 2); - dprintk("RPC: %s %pI4:%u: %s (%d)\n", - __func__, &myaddr.sin_addr, - port, err ? "failed" : "ok", err); + + if (myaddr.ss_family == AF_INET) + dprintk("RPC: %s %pI4:%u: %s (%d)\n", __func__, + &((struct sockaddr_in *)&myaddr)->sin_addr, + port, err ? "failed" : "ok", err); + else + dprintk("RPC: %s %pI6:%u: %s (%d)\n", __func__, + &((struct sockaddr_in6 *)&myaddr)->sin6_addr, + port, err ? 
"failed" : "ok", err); return err; } -static int xs_bind6(struct sock_xprt *transport, struct socket *sock) +/* + * We don't support autobind on AF_LOCAL sockets + */ +static void xs_local_rpcbind(struct rpc_task *task) { - struct sockaddr_in6 myaddr = { - .sin6_family = AF_INET6, - }; - struct sockaddr_in6 *sa; - int err, nloop = 0; - unsigned short port = xs_get_srcport(transport, sock); - unsigned short last; + rcu_read_lock(); + xprt_set_bound(rcu_dereference(task->tk_client->cl_xprt)); + rcu_read_unlock(); +} - sa = (struct sockaddr_in6 *)&transport->srcaddr; - myaddr.sin6_addr = sa->sin6_addr; - do { - myaddr.sin6_port = htons(port); - err = kernel_bind(sock, (struct sockaddr *) &myaddr, - sizeof(myaddr)); - if (port == 0) - break; - if (err == 0) { - transport->srcport = port; - break; - } - last = port; - port = xs_next_srcport(transport, sock, port); - if (port > last) - nloop++; - } while (err == -EADDRINUSE && nloop != 2); - dprintk("RPC: xs_bind6 %pI6:%u: %s (%d)\n", - &myaddr.sin6_addr, port, err ? "failed" : "ok", err); - return err; +static void xs_local_set_port(struct rpc_xprt *xprt, unsigned short port) +{ } #ifdef CONFIG_DEBUG_LOCK_ALLOC static struct lock_class_key xs_key[2]; static struct lock_class_key xs_slock_key[2]; +static inline void xs_reclassify_socketu(struct socket *sock) +{ + struct sock *sk = sock->sk; + + sock_lock_init_class_and_name(sk, "slock-AF_LOCAL-RPC", + &xs_slock_key[1], "sk_lock-AF_LOCAL-RPC", &xs_key[1]); +} + static inline void xs_reclassify_socket4(struct socket *sock) { struct sock *sk = sock->sk; - BUG_ON(sock_owned_by_user(sk)); sock_lock_init_class_and_name(sk, "slock-AF_INET-RPC", &xs_slock_key[0], "sk_lock-AF_INET-RPC", &xs_key[0]); } @@ -1628,11 +1809,33 @@ static inline void xs_reclassify_socket6(struct socket *sock) { struct sock *sk = sock->sk; - BUG_ON(sock_owned_by_user(sk)); sock_lock_init_class_and_name(sk, "slock-AF_INET6-RPC", &xs_slock_key[1], "sk_lock-AF_INET6-RPC", &xs_key[1]); } + +static inline void xs_reclassify_socket(int family, struct socket *sock) +{ + WARN_ON_ONCE(sock_owned_by_user(sock->sk)); + if (sock_owned_by_user(sock->sk)) + return; + + switch (family) { + case AF_LOCAL: + xs_reclassify_socketu(sock); + break; + case AF_INET: + xs_reclassify_socket4(sock); + break; + case AF_INET6: + xs_reclassify_socket6(sock); + break; + } +} #else +static inline void xs_reclassify_socketu(struct socket *sock) +{ +} + static inline void xs_reclassify_socket4(struct socket *sock) { } @@ -1640,11 +1843,46 @@ static inline void xs_reclassify_socket4(struct socket *sock) static inline void xs_reclassify_socket6(struct socket *sock) { } + +static inline void xs_reclassify_socket(int family, struct socket *sock) +{ +} #endif -static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +static void xs_dummy_setup_socket(struct work_struct *work) { - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); +} + +static struct socket *xs_create_sock(struct rpc_xprt *xprt, + struct sock_xprt *transport, int family, int type, int protocol) +{ + struct socket *sock; + int err; + + err = __sock_create(xprt->xprt_net, family, type, protocol, &sock, 1); + if (err < 0) { + dprintk("RPC: can't create %d transport socket (%d).\n", + protocol, -err); + goto out; + } + xs_reclassify_socket(family, sock); + + err = xs_bind(transport, sock); + if (err) { + sock_release(sock); + goto out; + } + + return sock; +out: + return ERR_PTR(err); +} + +static int xs_local_finish_connecting(struct rpc_xprt *xprt, + 
struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, + xprt); if (!transport->inet) { struct sock *sk = sock->sk; @@ -1654,13 +1892,12 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) xs_save_old_callbacks(transport, sk); sk->sk_user_data = xprt; - sk->sk_data_ready = xs_udp_data_ready; + sk->sk_data_ready = xs_local_data_ready; sk->sk_write_space = xs_udp_write_space; sk->sk_error_report = xs_error_report; - sk->sk_no_check = UDP_CSUM_NORCV; sk->sk_allocation = GFP_ATOMIC; - xprt_set_connected(xprt); + xprt_clear_connected(xprt); /* Reset to new socket */ transport->sock = sock; @@ -1668,85 +1905,176 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) write_unlock_bh(&sk->sk_callback_lock); } - xs_udp_do_set_buffer_size(xprt); + + /* Tell the socket layer to start connecting... */ + xprt->stat.connect_count++; + xprt->stat.connect_start = jiffies; + return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0); } /** - * xs_udp_connect_worker4 - set up a UDP socket - * @work: RPC transport to connect - * - * Invoked by a work queue tasklet. + * xs_local_setup_socket - create AF_LOCAL socket, connect to a local endpoint + * @xprt: RPC transport to connect + * @transport: socket transport to connect + * @create_sock: function to create a socket of the correct type */ -static void xs_udp_connect_worker4(struct work_struct *work) +static int xs_local_setup_socket(struct sock_xprt *transport) { - struct sock_xprt *transport = - container_of(work, struct sock_xprt, connect_worker.work); struct rpc_xprt *xprt = &transport->xprt; - struct socket *sock = transport->sock; - int err, status = -EIO; - - if (xprt->shutdown) - goto out; + struct socket *sock; + int status = -EIO; - /* Start by resetting any existing state */ - xs_reset_transport(transport); + current->flags |= PF_FSTRANS; - err = sock_create_kern(PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock); - if (err < 0) { - dprintk("RPC: can't create UDP transport socket (%d).\n", -err); + clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); + status = __sock_create(xprt->xprt_net, AF_LOCAL, + SOCK_STREAM, 0, &sock, 1); + if (status < 0) { + dprintk("RPC: can't create AF_LOCAL " + "transport socket (%d).\n", -status); goto out; } - xs_reclassify_socket4(sock); + xs_reclassify_socketu(sock); - if (xs_bind4(transport, sock)) { - sock_release(sock); - goto out; - } + dprintk("RPC: worker connecting xprt %p via AF_LOCAL to %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); - dprintk("RPC: worker connecting xprt %p via %s to " - "%s (port %s)\n", xprt, - xprt->address_strings[RPC_DISPLAY_PROTO], - xprt->address_strings[RPC_DISPLAY_ADDR], - xprt->address_strings[RPC_DISPLAY_PORT]); + status = xs_local_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); + switch (status) { + case 0: + dprintk("RPC: xprt %p connected to %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + xprt_set_connected(xprt); + break; + case -ENOENT: + dprintk("RPC: xprt %p: socket %s does not exist\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + break; + case -ECONNREFUSED: + dprintk("RPC: xprt %p: connection refused for %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + break; + default: + printk(KERN_ERR "%s: unhandled error (%d) connecting to %s\n", + __func__, -status, + xprt->address_strings[RPC_DISPLAY_ADDR]); + } - xs_udp_finish_connecting(xprt, sock); - status = 0; out: xprt_clear_connecting(xprt); xprt_wake_pending_tasks(xprt, 
status); + current->flags &= ~PF_FSTRANS; + return status; +} + +static void xs_local_connect(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + int ret; + + if (RPC_IS_ASYNC(task)) { + /* + * We want the AF_LOCAL connect to be resolved in the + * filesystem namespace of the process making the rpc + * call. Thus we connect synchronously. + * + * If we want to support asynchronous AF_LOCAL calls, + * we'll need to figure out how to pass a namespace to + * connect. + */ + rpc_exit(task, -ENOTCONN); + return; + } + ret = xs_local_setup_socket(transport); + if (ret && !RPC_IS_SOFTCONN(task)) + msleep_interruptible(15000); +} + +#ifdef CONFIG_SUNRPC_SWAP +static void xs_set_memalloc(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, + xprt); + + if (xprt->swapper) + sk_set_memalloc(transport->inet); } /** - * xs_udp_connect_worker6 - set up a UDP socket - * @work: RPC transport to connect + * xs_swapper - Tag this transport as being used for swap. + * @xprt: transport to tag + * @enable: enable/disable * - * Invoked by a work queue tasklet. */ -static void xs_udp_connect_worker6(struct work_struct *work) +int xs_swapper(struct rpc_xprt *xprt, int enable) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, + xprt); + int err = 0; + + if (enable) { + xprt->swapper++; + xs_set_memalloc(xprt); + } else if (xprt->swapper) { + xprt->swapper--; + sk_clear_memalloc(transport->inet); + } + + return err; +} +EXPORT_SYMBOL_GPL(xs_swapper); +#else +static void xs_set_memalloc(struct rpc_xprt *xprt) +{ +} +#endif + +static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + write_lock_bh(&sk->sk_callback_lock); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_udp_data_ready; + sk->sk_write_space = xs_udp_write_space; + sk->sk_allocation = GFP_ATOMIC; + + xprt_set_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + xs_set_memalloc(xprt); + + write_unlock_bh(&sk->sk_callback_lock); + } + xs_udp_do_set_buffer_size(xprt); +} + +static void xs_udp_setup_socket(struct work_struct *work) { struct sock_xprt *transport = container_of(work, struct sock_xprt, connect_worker.work); struct rpc_xprt *xprt = &transport->xprt; struct socket *sock = transport->sock; - int err, status = -EIO; + int status = -EIO; - if (xprt->shutdown) - goto out; + current->flags |= PF_FSTRANS; /* Start by resetting any existing state */ xs_reset_transport(transport); - - err = sock_create_kern(PF_INET6, SOCK_DGRAM, IPPROTO_UDP, &sock); - if (err < 0) { - dprintk("RPC: can't create UDP transport socket (%d).\n", -err); + sock = xs_create_sock(xprt, transport, + xs_addr(xprt)->sa_family, SOCK_DGRAM, IPPROTO_UDP); + if (IS_ERR(sock)) goto out; - } - xs_reclassify_socket6(sock); - - if (xs_bind6(transport, sock) < 0) { - sock_release(sock); - goto out; - } dprintk("RPC: worker connecting xprt %p via %s to " "%s (port %s)\n", xprt, @@ -1755,22 +2083,24 @@ static void xs_udp_connect_worker6(struct work_struct *work) xprt->address_strings[RPC_DISPLAY_PORT]); xs_udp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, 0); status = 0; out: xprt_clear_connecting(xprt); xprt_wake_pending_tasks(xprt, status); + current->flags &= ~PF_FSTRANS; } 
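The CONFIG_SUNRPC_SWAP hunk above lets a transport be tagged for swap traffic: xs_swapper() counts swap users on the xprt, and xs_set_memalloc() marks the connected socket so its allocations may dip into the PF_MEMALLOC emergency reserves instead of deadlocking in reclaim. A minimal sketch of that enable/disable pairing follows; the caller name example_swap_toggle is hypothetical, and it assumes the sk_set_memalloc()/sk_clear_memalloc() helpers of this kernel era.

```c
#include <linux/types.h>
#include <linux/sunrpc/xprt.h>
#include <net/sock.h>

/* Hypothetical caller: tag a transport's socket for swap I/O, or untag it.
 * Mirrors the xs_swapper() logic above: keep a count of swap users and
 * toggle the socket's access to the PF_MEMALLOC emergency reserves. */
static void example_swap_toggle(struct rpc_xprt *xprt, struct sock *sk,
				bool enable)
{
	if (enable) {
		xprt->swapper++;
		sk_set_memalloc(sk);	/* may allocate from reserves */
	} else if (xprt->swapper) {
		xprt->swapper--;
		sk_clear_memalloc(sk);	/* back to normal allocation */
	}
}
```

The count matters because several swap files may share one transport; the socket only drops back to normal allocation once the last user disables swap.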
/* * We need to preserve the port number so the reply cache on the server can * find our cached RPC replies when we get around to reconnecting. */ -static void xs_abort_connection(struct rpc_xprt *xprt, struct sock_xprt *transport) +static void xs_abort_connection(struct sock_xprt *transport) { int result; struct sockaddr any; - dprintk("RPC: disconnecting xprt %p to reuse port\n", xprt); + dprintk("RPC: disconnecting xprt %p to reuse port\n", transport); /* * Disconnect the transport socket by doing a connect operation @@ -1779,30 +2109,59 @@ static void xs_abort_connection(struct rpc_xprt *xprt, struct sock_xprt *transpo memset(&any, 0, sizeof(any)); any.sa_family = AF_UNSPEC; result = kernel_connect(transport->sock, &any, sizeof(any), 0); + trace_rpc_socket_reset_connection(&transport->xprt, + transport->sock, result); if (!result) - xs_sock_mark_closed(xprt); - else - dprintk("RPC: AF_UNSPEC connect return code %d\n", - result); + xs_sock_reset_connection_flags(&transport->xprt); + dprintk("RPC: AF_UNSPEC connect return code %d\n", result); } -static void xs_tcp_reuse_connection(struct rpc_xprt *xprt, struct sock_xprt *transport) +static void xs_tcp_reuse_connection(struct sock_xprt *transport) { unsigned int state = transport->inet->sk_state; - if (state == TCP_CLOSE && transport->sock->state == SS_UNCONNECTED) - return; - if ((1 << state) & (TCPF_ESTABLISHED|TCPF_SYN_SENT)) - return; - xs_abort_connection(xprt, transport); + if (state == TCP_CLOSE && transport->sock->state == SS_UNCONNECTED) { + /* we don't need to abort the connection if the socket + * hasn't undergone a shutdown + */ + if (transport->inet->sk_shutdown == 0) + return; + dprintk("RPC: %s: TCP_CLOSEd and sk_shutdown set to %d\n", + __func__, transport->inet->sk_shutdown); + } + if ((1 << state) & (TCPF_ESTABLISHED|TCPF_SYN_SENT)) { + /* we don't need to abort the connection if the socket + * hasn't undergone a shutdown + */ + if (transport->inet->sk_shutdown == 0) + return; + dprintk("RPC: %s: ESTABLISHED/SYN_SENT " + "sk_shutdown set to %d\n", + __func__, transport->inet->sk_shutdown); + } + xs_abort_connection(transport); } static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + int ret = -ENOTCONN; if (!transport->inet) { struct sock *sk = sock->sk; + unsigned int keepidle = xprt->timeout->to_initval / HZ; + unsigned int keepcnt = xprt->timeout->to_retries + 1; + unsigned int opt_on = 1; + + /* TCP Keepalive options */ + kernel_setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, + (char *)&opt_on, sizeof(opt_on)); + kernel_setsockopt(sock, SOL_TCP, TCP_KEEPIDLE, + (char *)&keepidle, sizeof(keepidle)); + kernel_setsockopt(sock, SOL_TCP, TCP_KEEPINTVL, + (char *)&keepidle, sizeof(keepidle)); + kernel_setsockopt(sock, SOL_TCP, TCP_KEEPCNT, + (char *)&keepcnt, sizeof(keepcnt)); write_lock_bh(&sk->sk_callback_lock); @@ -1831,12 +2190,23 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) } if (!xprt_bound(xprt)) - return -ENOTCONN; + goto out; + + xs_set_memalloc(xprt); /* Tell the socket layer to start connecting... */ xprt->stat.connect_count++; xprt->stat.connect_start = jiffies; - return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); + ret = kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); + switch (ret) { + case 0: + case -EINPROGRESS: + /* SYN_SENT! 
*/ + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + } +out: + return ret; } /** @@ -1847,20 +2217,20 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) * * Invoked by a work queue tasklet. */ -static void xs_tcp_setup_socket(struct rpc_xprt *xprt, - struct sock_xprt *transport, - struct socket *(*create_sock)(struct rpc_xprt *, - struct sock_xprt *)) +static void xs_tcp_setup_socket(struct work_struct *work) { + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); struct socket *sock = transport->sock; + struct rpc_xprt *xprt = &transport->xprt; int status = -EIO; - if (xprt->shutdown) - goto out; + current->flags |= PF_FSTRANS; if (!sock) { clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); - sock = create_sock(xprt, transport); + sock = xs_create_sock(xprt, transport, + xs_addr(xprt)->sa_family, SOCK_STREAM, IPPROTO_TCP); if (IS_ERR(sock)) { status = PTR_ERR(sock); goto out; @@ -1871,7 +2241,7 @@ static void xs_tcp_setup_socket(struct rpc_xprt *xprt, abort_and_exit = test_and_clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); /* "close" the socket, preserving the local port */ - xs_tcp_reuse_connection(xprt, transport); + xs_tcp_reuse_connection(transport); if (abort_and_exit) goto out_eagain; @@ -1884,6 +2254,7 @@ static void xs_tcp_setup_socket(struct rpc_xprt *xprt, xprt->address_strings[RPC_DISPLAY_PORT]); status = xs_tcp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); dprintk("RPC: %p connect status %d connected %d sock state %d\n", xprt, -status, xprt_connected(xprt), sock->sk->sk_state); @@ -1895,22 +2266,22 @@ static void xs_tcp_setup_socket(struct rpc_xprt *xprt, /* We're probably in TIME_WAIT. Get rid of existing socket, * and retry */ - set_bit(XPRT_CONNECTION_CLOSE, &xprt->state); - xprt_force_disconnect(xprt); + xs_tcp_force_close(xprt); break; - case -ECONNREFUSED: - case -ECONNRESET: - case -ENETUNREACH: - /* retry with existing socket, after a delay */ case 0: case -EINPROGRESS: case -EALREADY: xprt_clear_connecting(xprt); + current->flags &= ~PF_FSTRANS; return; case -EINVAL: /* Happens, for instance, if the user specified a link * local IPv6 address without a scope-id. */ + case -ECONNREFUSED: + case -ECONNRESET: + case -ENETUNREACH: + /* retry with existing socket, after a delay */ goto out; } out_eagain: @@ -1918,88 +2289,12 @@ out_eagain: out: xprt_clear_connecting(xprt); xprt_wake_pending_tasks(xprt, status); -} - -static struct socket *xs_create_tcp_sock4(struct rpc_xprt *xprt, - struct sock_xprt *transport) -{ - struct socket *sock; - int err; - - /* start from scratch */ - err = sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock); - if (err < 0) { - dprintk("RPC: can't create TCP transport socket (%d).\n", - -err); - goto out_err; - } - xs_reclassify_socket4(sock); - - if (xs_bind4(transport, sock) < 0) { - sock_release(sock); - goto out_err; - } - return sock; -out_err: - return ERR_PTR(-EIO); -} - -/** - * xs_tcp_connect_worker4 - connect a TCP socket to a remote endpoint - * @work: RPC transport to connect - * - * Invoked by a work queue tasklet. 
- */ -static void xs_tcp_connect_worker4(struct work_struct *work) -{ - struct sock_xprt *transport = - container_of(work, struct sock_xprt, connect_worker.work); - struct rpc_xprt *xprt = &transport->xprt; - - xs_tcp_setup_socket(xprt, transport, xs_create_tcp_sock4); -} - -static struct socket *xs_create_tcp_sock6(struct rpc_xprt *xprt, - struct sock_xprt *transport) -{ - struct socket *sock; - int err; - - /* start from scratch */ - err = sock_create_kern(PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock); - if (err < 0) { - dprintk("RPC: can't create TCP transport socket (%d).\n", - -err); - goto out_err; - } - xs_reclassify_socket6(sock); - - if (xs_bind6(transport, sock) < 0) { - sock_release(sock); - goto out_err; - } - return sock; -out_err: - return ERR_PTR(-EIO); -} - -/** - * xs_tcp_connect_worker6 - connect a TCP socket to a remote endpoint - * @work: RPC transport to connect - * - * Invoked by a work queue tasklet. - */ -static void xs_tcp_connect_worker6(struct work_struct *work) -{ - struct sock_xprt *transport = - container_of(work, struct sock_xprt, connect_worker.work); - struct rpc_xprt *xprt = &transport->xprt; - - xs_tcp_setup_socket(xprt, transport, xs_create_tcp_sock6); + current->flags &= ~PF_FSTRANS; } /** * xs_connect - connect a socket to a remote endpoint + * @xprt: pointer to transport structure * @task: address of RPC task that manages state of connect request * * TCP: If the remote end dropped the connection, delay reconnecting. @@ -2011,14 +2306,10 @@ static void xs_tcp_connect_worker6(struct work_struct *work) * If a UDP socket connect fails, the delay behavior here prevents * retry floods (hard mounts). */ -static void xs_connect(struct rpc_task *task) +static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - if (xprt_test_and_set_connecting(xprt)) - return; - if (transport->sock != NULL && !RPC_IS_SOFTCONN(task)) { dprintk("RPC: xs_connect delayed xprt %p for %lu " "seconds\n", @@ -2038,14 +2329,33 @@ static void xs_connect(struct rpc_task *task) } } -static void xs_tcp_connect(struct rpc_task *task) +/** + * xs_local_print_stats - display AF_LOCAL socket-specific stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_local_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) { - struct rpc_xprt *xprt = task->tk_xprt; + long idle_time = 0; - /* Exit if we need to wait for socket shutdown to complete */ - if (test_bit(XPRT_CLOSING, &xprt->state)) - return; - xs_connect(task); + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_printf(seq, "\txprt:\tlocal %lu %lu %lu %ld %lu %lu %lu " + "%llu %llu %lu %llu %llu\n", + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); } /** @@ -2058,14 +2368,18 @@ static void xs_udp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - seq_printf(seq, "\txprt:\tudp %u %lu %lu %lu %lu %Lu %Lu\n", + seq_printf(seq, "\txprt:\tudp %u %lu %lu %lu %lu %llu %llu " + "%lu %llu %llu\n", transport->srcport, xprt->stat.bind_count, xprt->stat.sends, xprt->stat.recvs, xprt->stat.bad_xids, xprt->stat.req_u, - xprt->stat.bklog_u); +
xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); } /** @@ -2082,7 +2396,8 @@ static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) if (xprt_connected(xprt)) idle_time = (long)(jiffies - xprt->last_used) / HZ; - seq_printf(seq, "\txprt:\ttcp %u %lu %lu %lu %ld %lu %lu %lu %Lu %Lu\n", + seq_printf(seq, "\txprt:\ttcp %u %lu %lu %lu %ld %lu %lu %lu " + "%llu %llu %lu %llu %llu\n", transport->srcport, xprt->stat.bind_count, xprt->stat.connect_count, @@ -2092,7 +2407,10 @@ static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) xprt->stat.recvs, xprt->stat.bad_xids, xprt->stat.req_u, - xprt->stat.bklog_u); + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); } /* @@ -2105,9 +2423,11 @@ static void *bc_malloc(struct rpc_task *task, size_t size) struct page *page; struct rpc_buffer *buf; - BUG_ON(size > PAGE_SIZE - sizeof(struct rpc_buffer)); - page = alloc_page(GFP_KERNEL); + WARN_ON_ONCE(size > PAGE_SIZE - sizeof(struct rpc_buffer)); + if (size > PAGE_SIZE - sizeof(struct rpc_buffer)) + return NULL; + page = alloc_page(GFP_KERNEL); if (!page) return NULL; @@ -2146,10 +2466,7 @@ static int bc_sendto(struct rpc_rqst *req) unsigned long headoff; unsigned long tailoff; - /* - * Set up the rpc header and record marker stuff - */ - xs_encode_tcp_record_marker(xbufp); + xs_encode_stream_record_marker(xbufp); tailoff = (unsigned long)xbufp->tail[0].iov_base & ~PAGE_MASK; headoff = (unsigned long)xbufp->head[0].iov_base & ~PAGE_MASK; @@ -2172,7 +2489,6 @@ static int bc_send_request(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; struct svc_xprt *xprt; - struct svc_sock *svsk; u32 len; dprintk("sending request with xid: %08x\n", ntohl(req->rq_xid)); @@ -2180,7 +2496,6 @@ static int bc_send_request(struct rpc_task *task) * Get the server socket associated with this callback xprt */ xprt = req->rq_xprt->bc_xprt; - svsk = container_of(xprt, struct svc_sock, sk_xprt); /* * Grab the mutex to serialize data as the connection is shared @@ -2210,7 +2525,6 @@ static int bc_send_request(struct rpc_task *task) static void bc_close(struct rpc_xprt *xprt) { - return; } /* @@ -2220,13 +2534,33 @@ static void bc_close(struct rpc_xprt *xprt) static void bc_destroy(struct rpc_xprt *xprt) { - return; + dprintk("RPC: bc_destroy xprt %p\n", xprt); + + xs_xprt_free(xprt); + module_put(THIS_MODULE); } +static struct rpc_xprt_ops xs_local_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xs_tcp_release_xprt, + .alloc_slot = xprt_alloc_slot, + .rpcbind = xs_local_rpcbind, + .set_port = xs_local_set_port, + .connect = xs_local_connect, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .send_request = xs_local_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_def, + .close = xs_close, + .destroy = xs_destroy, + .print_stats = xs_local_print_stats, +}; + static struct rpc_xprt_ops xs_udp_ops = { .set_buffer_size = xs_udp_set_buffer_size, .reserve_xprt = xprt_reserve_xprt_cong, .release_xprt = xprt_release_xprt_cong, + .alloc_slot = xprt_alloc_slot, .rpcbind = rpcb_getport_async, .set_port = xs_set_port, .connect = xs_connect, @@ -2244,16 +2578,14 @@ static struct rpc_xprt_ops xs_udp_ops = { static struct rpc_xprt_ops xs_tcp_ops = { .reserve_xprt = xprt_reserve_xprt, .release_xprt = xs_tcp_release_xprt, + .alloc_slot = xprt_lock_and_alloc_slot, .rpcbind = rpcb_getport_async, .set_port = xs_set_port, - .connect = xs_tcp_connect, + .connect = xs_connect, 
.buf_alloc = rpc_malloc, .buf_free = rpc_free, .send_request = xs_tcp_send_request, .set_retrans_timeout = xprt_set_retrans_timeout_def, -#if defined(CONFIG_NFS_V4_1) - .release_request = bc_release_request, -#endif /* CONFIG_NFS_V4_1 */ .close = xs_tcp_close, .destroy = xs_destroy, .print_stats = xs_tcp_print_stats, @@ -2266,6 +2598,7 @@ static struct rpc_xprt_ops xs_tcp_ops = { static struct rpc_xprt_ops bc_tcp_ops = { .reserve_xprt = xprt_reserve_xprt, .release_xprt = xprt_release_xprt, + .alloc_slot = xprt_alloc_slot, .buf_alloc = bc_malloc, .buf_free = bc_free, .send_request = bc_send_request, @@ -2275,8 +2608,36 @@ static struct rpc_xprt_ops bc_tcp_ops = { .print_stats = xs_tcp_print_stats, }; +static int xs_init_anyaddr(const int family, struct sockaddr *sap) +{ + static const struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + }; + static const struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + }; + + switch (family) { + case AF_LOCAL: + break; + case AF_INET: + memcpy(sap, &sin, sizeof(sin)); + break; + case AF_INET6: + memcpy(sap, &sin6, sizeof(sin6)); + break; + default: + dprintk("RPC: %s: Bad address family\n", __func__); + return -EAFNOSUPPORT; + } + return 0; +} + static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args, - unsigned int slot_table_size) + unsigned int slot_table_size, + unsigned int max_slot_table_size) { struct rpc_xprt *xprt; struct sock_xprt *new; @@ -2286,31 +2647,101 @@ static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args, return ERR_PTR(-EBADF); } - new = kzalloc(sizeof(*new), GFP_KERNEL); - if (new == NULL) { + xprt = xprt_alloc(args->net, sizeof(*new), slot_table_size, + max_slot_table_size); + if (xprt == NULL) { dprintk("RPC: xs_setup_xprt: couldn't allocate " "rpc_xprt\n"); return ERR_PTR(-ENOMEM); } - xprt = &new->xprt; - - xprt->max_reqs = slot_table_size; - xprt->slot = kcalloc(xprt->max_reqs, sizeof(struct rpc_rqst), GFP_KERNEL); - if (xprt->slot == NULL) { - kfree(xprt); - dprintk("RPC: xs_setup_xprt: couldn't allocate slot " - "table\n"); - return ERR_PTR(-ENOMEM); - } + new = container_of(xprt, struct sock_xprt, xprt); memcpy(&xprt->addr, args->dstaddr, args->addrlen); xprt->addrlen = args->addrlen; if (args->srcaddr) memcpy(&new->srcaddr, args->srcaddr, args->addrlen); + else { + int err; + err = xs_init_anyaddr(args->dstaddr->sa_family, + (struct sockaddr *)&new->srcaddr); + if (err != 0) { + xprt_free(xprt); + return ERR_PTR(err); + } + } return xprt; } +static const struct rpc_timeout xs_local_default_timeout = { + .to_initval = 10 * HZ, + .to_maxval = 10 * HZ, + .to_retries = 2, +}; + +/** + * xs_setup_local - Set up transport to use an AF_LOCAL socket + * @args: rpc transport creation arguments + * + * AF_LOCAL is a "tpi_cots_ord" transport, just like TCP + */ +static struct rpc_xprt *xs_setup_local(struct xprt_create *args) +{ + struct sockaddr_un *sun = (struct sockaddr_un *)args->dstaddr; + struct sock_xprt *transport; + struct rpc_xprt *xprt; + struct rpc_xprt *ret; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + xprt_max_tcp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = 0; + xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_local_ops; + 
xprt->timeout = &xs_local_default_timeout; + + INIT_DELAYED_WORK(&transport->connect_worker, + xs_dummy_setup_socket); + + switch (sun->sun_family) { + case AF_LOCAL: + if (sun->sun_path[0] != '/') { + dprintk("RPC: bad AF_LOCAL address: %s\n", + sun->sun_path); + ret = ERR_PTR(-EINVAL); + goto out_err; + } + xprt_set_bound(xprt); + xs_format_peer_addresses(xprt, "local", RPCBIND_NETID_LOCAL); + ret = ERR_PTR(xs_local_setup_socket(transport)); + if (ret) + goto out_err; + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + dprintk("RPC: set up xprt to %s via AF_LOCAL\n", + xprt->address_strings[RPC_DISPLAY_ADDR]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + static const struct rpc_timeout xs_udp_default_timeout = { .to_initval = 5 * HZ, .to_maxval = 30 * HZ, @@ -2328,8 +2759,10 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) struct sockaddr *addr = args->dstaddr; struct rpc_xprt *xprt; struct sock_xprt *transport; + struct rpc_xprt *ret; - xprt = xs_setup_xprt(args, xprt_udp_slot_table_entries); + xprt = xs_setup_xprt(args, xprt_udp_slot_table_entries, + xprt_udp_slot_table_entries); if (IS_ERR(xprt)) return xprt; transport = container_of(xprt, struct sock_xprt, xprt); @@ -2340,7 +2773,6 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) xprt->max_payload = (1U << 16) - (MAX_HEADER << 3); xprt->bind_timeout = XS_BIND_TO; - xprt->connect_timeout = XS_UDP_CONN_TO; xprt->reestablish_timeout = XS_UDP_REEST_TO; xprt->idle_timeout = XS_IDLE_DISC_TO; @@ -2354,7 +2786,7 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) xprt_set_bound(xprt); INIT_DELAYED_WORK(&transport->connect_worker, - xs_udp_connect_worker4); + xs_udp_setup_socket); xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP); break; case AF_INET6: @@ -2362,12 +2794,12 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) xprt_set_bound(xprt); INIT_DELAYED_WORK(&transport->connect_worker, - xs_udp_connect_worker6); + xs_udp_setup_socket); xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP6); break; default: - kfree(xprt); - return ERR_PTR(-EAFNOSUPPORT); + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; } if (xprt_bound(xprt)) @@ -2382,10 +2814,10 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) if (try_module_get(THIS_MODULE)) return xprt; - - kfree(xprt->slot); - kfree(xprt); - return ERR_PTR(-EINVAL); + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; } static const struct rpc_timeout xs_tcp_default_timeout = { @@ -2404,8 +2836,14 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) struct sockaddr *addr = args->dstaddr; struct rpc_xprt *xprt; struct sock_xprt *transport; + struct rpc_xprt *ret; + unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries; + + if (args->flags & XPRT_CREATE_INFINITE_SLOTS) + max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT; - xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries); + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + max_slot_table_size); if (IS_ERR(xprt)) return xprt; transport = container_of(xprt, struct sock_xprt, xprt); @@ -2415,7 +2853,6 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; xprt->bind_timeout = XS_BIND_TO; - xprt->connect_timeout = XS_TCP_CONN_TO; xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; xprt->idle_timeout = XS_IDLE_DISC_TO; @@ -2428,7 +2865,7 @@ static struct 
rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt_set_bound(xprt); INIT_DELAYED_WORK(&transport->connect_worker, - xs_tcp_connect_worker4); + xs_tcp_setup_socket); xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); break; case AF_INET6: @@ -2436,12 +2873,12 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt_set_bound(xprt); INIT_DELAYED_WORK(&transport->connect_worker, - xs_tcp_connect_worker6); + xs_tcp_setup_socket); xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); break; default: - kfree(xprt); - return ERR_PTR(-EAFNOSUPPORT); + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; } if (xprt_bound(xprt)) @@ -2454,13 +2891,12 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt->address_strings[RPC_DISPLAY_ADDR], xprt->address_strings[RPC_DISPLAY_PROTO]); - if (try_module_get(THIS_MODULE)) return xprt; - - kfree(xprt->slot); - kfree(xprt); - return ERR_PTR(-EINVAL); + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; } /** @@ -2474,11 +2910,10 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) struct rpc_xprt *xprt; struct sock_xprt *transport; struct svc_sock *bc_sock; + struct rpc_xprt *ret; - if (!args->bc_xprt) - ERR_PTR(-EINVAL); - - xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries); + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + xprt_tcp_slot_table_entries); if (IS_ERR(xprt)) return xprt; transport = container_of(xprt, struct sock_xprt, xprt); @@ -2491,20 +2926,9 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) /* backchannel */ xprt_set_bound(xprt); xprt->bind_timeout = 0; - xprt->connect_timeout = 0; xprt->reestablish_timeout = 0; xprt->idle_timeout = 0; - /* - * The backchannel uses the same socket connection as the - * forechannel - */ - xprt->bc_xprt = args->bc_xprt; - bc_sock = container_of(args->bc_xprt, struct svc_sock, sk_xprt); - bc_sock->sk_bc_xprt = xprt; - transport->sock = bc_sock->sk_sock; - transport->inet = bc_sock->sk_sk; - xprt->ops = &bc_tcp_ops; switch (addr->sa_family) { @@ -2517,19 +2941,27 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) RPCBIND_NETID_TCP6); break; default: - kfree(xprt); - return ERR_PTR(-EAFNOSUPPORT); + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; } - if (xprt_bound(xprt)) - dprintk("RPC: set up xprt to %s (port %s) via %s\n", - xprt->address_strings[RPC_DISPLAY_ADDR], - xprt->address_strings[RPC_DISPLAY_PORT], - xprt->address_strings[RPC_DISPLAY_PROTO]); - else - dprintk("RPC: set up xprt to %s (autobind) via %s\n", - xprt->address_strings[RPC_DISPLAY_ADDR], - xprt->address_strings[RPC_DISPLAY_PROTO]); + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + /* + * Once we've associated a backchannel xprt with a connection, + * we want to keep it around as long as the connection lasts, + * in case we need to start using it for a backchannel again; + * this reference won't be dropped until bc_xprt is destroyed. 
+ */ + xprt_get(xprt); + args->bc_xprt->xpt_bc_xprt = xprt; + xprt->bc_xprt = args->bc_xprt; + bc_sock = container_of(args->bc_xprt, struct svc_sock, sk_xprt); + transport->sock = bc_sock->sk_sock; + transport->inet = bc_sock->sk_sk; /* * Since we don't want connections for the backchannel, we set @@ -2537,14 +2969,25 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) */ xprt_set_connected(xprt); - if (try_module_get(THIS_MODULE)) return xprt; - kfree(xprt->slot); - kfree(xprt); - return ERR_PTR(-EINVAL); + + args->bc_xprt->xpt_bc_xprt = NULL; + xprt_put(xprt); + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; } +static struct xprt_class xs_local_transport = { + .list = LIST_HEAD_INIT(xs_local_transport.list), + .name = "named UNIX socket", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_LOCAL, + .setup = xs_setup_local, +}; + static struct xprt_class xs_udp_transport = { .list = LIST_HEAD_INIT(xs_udp_transport.list), .name = "udp", @@ -2580,6 +3023,7 @@ int init_socket_xprt(void) sunrpc_table_header = register_sysctl_table(sunrpc_table); #endif + xprt_register_transport(&xs_local_transport); xprt_register_transport(&xs_udp_transport); xprt_register_transport(&xs_tcp_transport); xprt_register_transport(&xs_bc_tcp_transport); @@ -2600,12 +3044,14 @@ void cleanup_socket_xprt(void) } #endif + xprt_unregister_transport(&xs_local_transport); xprt_unregister_transport(&xs_udp_transport); xprt_unregister_transport(&xs_tcp_transport); xprt_unregister_transport(&xs_bc_tcp_transport); } -static int param_set_uint_minmax(const char *val, struct kernel_param *kp, +static int param_set_uint_minmax(const char *val, + const struct kernel_param *kp, unsigned int min, unsigned int max) { unsigned long num; @@ -2620,39 +3066,60 @@ static int param_set_uint_minmax(const char *val, struct kernel_param *kp, return 0; } -static int param_set_portnr(const char *val, struct kernel_param *kp) +static int param_set_portnr(const char *val, const struct kernel_param *kp) { return param_set_uint_minmax(val, kp, RPC_MIN_RESVPORT, RPC_MAX_RESVPORT); } -static int param_get_portnr(char *buffer, struct kernel_param *kp) -{ - return param_get_uint(buffer, kp); -} +static struct kernel_param_ops param_ops_portnr = { + .set = param_set_portnr, + .get = param_get_uint, +}; + #define param_check_portnr(name, p) \ __param_check(name, p, unsigned int); module_param_named(min_resvport, xprt_min_resvport, portnr, 0644); module_param_named(max_resvport, xprt_max_resvport, portnr, 0644); -static int param_set_slot_table_size(const char *val, struct kernel_param *kp) +static int param_set_slot_table_size(const char *val, + const struct kernel_param *kp) { return param_set_uint_minmax(val, kp, RPC_MIN_SLOT_TABLE, RPC_MAX_SLOT_TABLE); } -static int param_get_slot_table_size(char *buffer, struct kernel_param *kp) +static struct kernel_param_ops param_ops_slot_table_size = { + .set = param_set_slot_table_size, + .get = param_get_uint, +}; + +#define param_check_slot_table_size(name, p) \ + __param_check(name, p, unsigned int); + +static int param_set_max_slot_table_size(const char *val, + const struct kernel_param *kp) { - return param_get_uint(buffer, kp); + return param_set_uint_minmax(val, kp, + RPC_MIN_SLOT_TABLE, + RPC_MAX_SLOT_TABLE_LIMIT); } -#define param_check_slot_table_size(name, p) \ + +static struct kernel_param_ops param_ops_max_slot_table_size = { + .set = param_set_max_slot_table_size, + .get = param_get_uint, +}; + +#define param_check_max_slot_table_size(name, p) \ 
__param_check(name, p, unsigned int); module_param_named(tcp_slot_table_entries, xprt_tcp_slot_table_entries, slot_table_size, 0644); +module_param_named(tcp_max_slot_table_entries, xprt_max_tcp_slot_table_entries, + max_slot_table_size, 0644); module_param_named(udp_slot_table_entries, xprt_udp_slot_table_entries, slot_table_size, 0644);
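The tail of this diff converts the module parameters from bare get/set callbacks to struct kernel_param_ops, with param_set_uint_minmax() rejecting out-of-range writes. Below is a self-contained sketch of the same bounded-parameter pattern; the parameter name example_entries and the [2, 128] bounds are hypothetical, while param_get_uint, kstrtoul(), __param_check(), and module_param_named() are the stock kernel interfaces the patch itself relies on.

```c
#include <linux/kernel.h>
#include <linux/moduleparam.h>

/* Hypothetical bounded parameter, writable via sysfs (0644). */
static unsigned int example_entries = 16;

static int param_set_example(const char *val, const struct kernel_param *kp)
{
	unsigned long num;
	int ret;

	if (!val)
		return -EINVAL;
	ret = kstrtoul(val, 0, &num);
	if (ret || num < 2 || num > 128)	/* enforce the bounds */
		return -EINVAL;
	*((unsigned int *)kp->arg) = num;
	return 0;
}

/* module_param_named(..., example, ...) looks up param_ops_example. */
static struct kernel_param_ops param_ops_example = {
	.set = param_set_example,
	.get = param_get_uint,		/* reuse the stock getter */
};

#define param_check_example(name, p) \
	__param_check(name, p, unsigned int);

module_param_named(example_entries, example_entries, example, 0644);
MODULE_PARM_DESC(example_entries, "Illustrative bounded parameter");
```

Defining a custom "type" this way is what lets min_resvport, max_resvport, and the slot-table sizes above share one clamped setter while keeping the ordinary unsigned-int read path.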

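A closing note on the TCP keepalive hunk in xs_tcp_finish_connecting() above: it derives the probe timings from the RPC timeout (to_initval for the idle time, to_retries + 1 for the probe count) and pushes them into the socket with kernel_setsockopt(). The sketch below isolates that sequence; the helper name example_enable_keepalive and the explicit idle/cnt arguments are illustrative, not from the patch.

```c
#include <linux/net.h>
#include <linux/socket.h>
#include <linux/tcp.h>

/* Hypothetical helper: probe an idle kernel socket every @idle seconds,
 * giving up after @cnt unanswered probes. As in the hunk above, the
 * idle time is reused as the inter-probe interval (TCP_KEEPINTVL). */
static void example_enable_keepalive(struct socket *sock,
				     unsigned int idle, unsigned int cnt)
{
	unsigned int opt_on = 1;

	kernel_setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE,
			  (char *)&opt_on, sizeof(opt_on));
	kernel_setsockopt(sock, SOL_TCP, TCP_KEEPIDLE,
			  (char *)&idle, sizeof(idle));
	kernel_setsockopt(sock, SOL_TCP, TCP_KEEPINTVL,
			  (char *)&idle, sizeof(idle));
	kernel_setsockopt(sock, SOL_TCP, TCP_KEEPCNT,
			  (char *)&cnt, sizeof(cnt));
}
```

Tying the keepalive settings to the transport's own retransmit timeout means a dead server is detected on roughly the same schedule whether or not RPC traffic is in flight.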