aboutsummaryrefslogtreecommitdiff
path: root/net/sunrpc/rpcb_clnt.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/sunrpc/rpcb_clnt.c')
-rw-r--r--net/sunrpc/rpcb_clnt.c121
1 files changed, 78 insertions, 43 deletions
diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c
index b1f08bd6788..1891a1022c1 100644
--- a/net/sunrpc/rpcb_clnt.c
+++ b/net/sunrpc/rpcb_clnt.c
@@ -23,10 +23,10 @@
#include <linux/errno.h>
#include <linux/mutex.h>
#include <linux/slab.h>
-#include <linux/nsproxy.h>
#include <net/ipv6.h>
#include <linux/sunrpc/clnt.h>
+#include <linux/sunrpc/addr.h>
#include <linux/sunrpc/sched.h>
#include <linux/sunrpc/xprtsock.h>
@@ -180,14 +180,16 @@ void rpcb_put_local(struct net *net)
struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
struct rpc_clnt *clnt = sn->rpcb_local_clnt;
struct rpc_clnt *clnt4 = sn->rpcb_local_clnt4;
- int shutdown;
+ int shutdown = 0;
spin_lock(&sn->rpcb_clnt_lock);
- if (--sn->rpcb_users == 0) {
- sn->rpcb_local_clnt = NULL;
- sn->rpcb_local_clnt4 = NULL;
+ if (sn->rpcb_users) {
+ if (--sn->rpcb_users == 0) {
+ sn->rpcb_local_clnt = NULL;
+ sn->rpcb_local_clnt4 = NULL;
+ }
+ shutdown = !sn->rpcb_users;
}
- shutdown = !sn->rpcb_users;
spin_unlock(&sn->rpcb_clnt_lock);
if (shutdown) {
@@ -202,13 +204,15 @@ void rpcb_put_local(struct net *net)
}
static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt,
- struct rpc_clnt *clnt4)
+ struct rpc_clnt *clnt4,
+ bool is_af_local)
{
struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
/* Protected by rpcb_create_local_mutex */
sn->rpcb_local_clnt = clnt;
sn->rpcb_local_clnt4 = clnt4;
+ sn->rpcb_is_af_local = is_af_local ? 1 : 0;
smp_wmb();
sn->rpcb_users = 1;
dprintk("RPC: created new rpcb local clients (rpcb_local_clnt: "
@@ -236,6 +240,14 @@ static int rpcb_create_local_unix(struct net *net)
.program = &rpcb_program,
.version = RPCBVERS_2,
.authflavor = RPC_AUTH_NULL,
+ /*
+ * We turn off the idle timeout to prevent the kernel
+ * from automatically disconnecting the socket.
+ * Otherwise, we'd have to cache the mount namespace
+ * of the caller and somehow pass that to the socket
+ * reconnect code.
+ */
+ .flags = RPC_CLNT_CREATE_NO_IDLE_TIMEOUT,
};
struct rpc_clnt *clnt, *clnt4;
int result = 0;
@@ -249,7 +261,7 @@ static int rpcb_create_local_unix(struct net *net)
if (IS_ERR(clnt)) {
dprintk("RPC: failed to create AF_LOCAL rpcbind "
"client (errno %ld).\n", PTR_ERR(clnt));
- result = -PTR_ERR(clnt);
+ result = PTR_ERR(clnt);
goto out;
}
@@ -261,7 +273,7 @@ static int rpcb_create_local_unix(struct net *net)
clnt4 = NULL;
}
- rpcb_set_local(net, clnt, clnt4);
+ rpcb_set_local(net, clnt, clnt4, true);
out:
return result;
@@ -296,7 +308,7 @@ static int rpcb_create_local_net(struct net *net)
if (IS_ERR(clnt)) {
dprintk("RPC: failed to create local rpcbind "
"client (errno %ld).\n", PTR_ERR(clnt));
- result = -PTR_ERR(clnt);
+ result = PTR_ERR(clnt);
goto out;
}
@@ -313,7 +325,7 @@ static int rpcb_create_local_net(struct net *net)
clnt4 = NULL;
}
- rpcb_set_local(net, clnt, clnt4);
+ rpcb_set_local(net, clnt, clnt4, false);
out:
return result;
@@ -374,13 +386,16 @@ static struct rpc_clnt *rpcb_create(struct net *net, const char *hostname,
return rpc_create(&args);
}
-static int rpcb_register_call(struct rpc_clnt *clnt, struct rpc_message *msg)
+static int rpcb_register_call(struct sunrpc_net *sn, struct rpc_clnt *clnt, struct rpc_message *msg, bool is_set)
{
- int result, error = 0;
+ int flags = RPC_TASK_NOCONNECT;
+ int error, result = 0;
+ if (is_set || !sn->rpcb_is_af_local)
+ flags = RPC_TASK_SOFTCONN;
msg->rpc_resp = &result;
- error = rpc_call_sync(clnt, msg, RPC_TASK_SOFTCONN);
+ error = rpc_call_sync(clnt, msg, flags);
if (error < 0) {
dprintk("RPC: failed to contact local rpcbind "
"server (errno %d).\n", -error);
@@ -394,6 +409,7 @@ static int rpcb_register_call(struct rpc_clnt *clnt, struct rpc_message *msg)
/**
* rpcb_register - set or unset a port registration with the local rpcbind svc
+ * @net: target network namespace
* @prog: RPC program number to bind
* @vers: RPC version number to bind
* @prot: transport protocol to register
@@ -436,16 +452,19 @@ int rpcb_register(struct net *net, u32 prog, u32 vers, int prot, unsigned short
.rpc_argp = &map,
};
struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+ bool is_set = false;
dprintk("RPC: %sregistering (%u, %u, %d, %u) with local "
"rpcbind\n", (port ? "" : "un"),
prog, vers, prot, port);
msg.rpc_proc = &rpcb_procedures2[RPCBPROC_UNSET];
- if (port)
+ if (port != 0) {
msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET];
+ is_set = true;
+ }
- return rpcb_register_call(sn->rpcb_local_clnt, &msg);
+ return rpcb_register_call(sn, sn->rpcb_local_clnt, &msg, is_set);
}
/*
@@ -458,6 +477,7 @@ static int rpcb_register_inet4(struct sunrpc_net *sn,
const struct sockaddr_in *sin = (const struct sockaddr_in *)sap;
struct rpcbind_args *map = msg->rpc_argp;
unsigned short port = ntohs(sin->sin_port);
+ bool is_set = false;
int result;
map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL);
@@ -468,10 +488,12 @@ static int rpcb_register_inet4(struct sunrpc_net *sn,
map->r_addr, map->r_netid);
msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET];
- if (port)
+ if (port != 0) {
msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET];
+ is_set = true;
+ }
- result = rpcb_register_call(sn->rpcb_local_clnt4, msg);
+ result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set);
kfree(map->r_addr);
return result;
}
@@ -486,6 +508,7 @@ static int rpcb_register_inet6(struct sunrpc_net *sn,
const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)sap;
struct rpcbind_args *map = msg->rpc_argp;
unsigned short port = ntohs(sin6->sin6_port);
+ bool is_set = false;
int result;
map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL);
@@ -496,10 +519,12 @@ static int rpcb_register_inet6(struct sunrpc_net *sn,
map->r_addr, map->r_netid);
msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET];
- if (port)
+ if (port != 0) {
msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET];
+ is_set = true;
+ }
- result = rpcb_register_call(sn->rpcb_local_clnt4, msg);
+ result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set);
kfree(map->r_addr);
return result;
}
@@ -516,11 +541,12 @@ static int rpcb_unregister_all_protofamilies(struct sunrpc_net *sn,
map->r_addr = "";
msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET];
- return rpcb_register_call(sn->rpcb_local_clnt4, msg);
+ return rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, false);
}
/**
* rpcb_v4_register - set or unset a port registration with the local rpcbind
+ * @net: target network namespace
* @program: RPC program number of service to (un)register
* @version: RPC version number of service to (un)register
* @address: address family, IP address, and port to (un)register
@@ -620,9 +646,10 @@ static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, struct rpcbi
static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt)
{
struct rpc_clnt *parent = clnt->cl_parent;
+ struct rpc_xprt *xprt = rcu_dereference(clnt->cl_xprt);
while (parent != clnt) {
- if (parent->cl_xprt != clnt->cl_xprt)
+ if (rcu_dereference(parent->cl_xprt) != xprt)
break;
if (clnt->cl_autobind)
break;
@@ -653,12 +680,16 @@ void rpcb_getport_async(struct rpc_task *task)
size_t salen;
int status;
- clnt = rpcb_find_transport_owner(task->tk_client);
- xprt = clnt->cl_xprt;
+ rcu_read_lock();
+ do {
+ clnt = rpcb_find_transport_owner(task->tk_client);
+ xprt = xprt_get(rcu_dereference(clnt->cl_xprt));
+ } while (xprt == NULL);
+ rcu_read_unlock();
dprintk("RPC: %5u %s(%s, %u, %u, %d)\n",
task->tk_pid, __func__,
- clnt->cl_server, clnt->cl_prog, clnt->cl_vers, xprt->prot);
+ xprt->servername, clnt->cl_prog, clnt->cl_vers, xprt->prot);
/* Put self on the wait queue to ensure we get notified if
* some other task is already attempting to bind the port */
@@ -667,6 +698,7 @@ void rpcb_getport_async(struct rpc_task *task)
if (xprt_test_and_set_binding(xprt)) {
dprintk("RPC: %5u %s: waiting for another binder\n",
task->tk_pid, __func__);
+ xprt_put(xprt);
return;
}
@@ -708,7 +740,7 @@ void rpcb_getport_async(struct rpc_task *task)
dprintk("RPC: %5u %s: trying rpcbind version %u\n",
task->tk_pid, __func__, bind_version);
- rpcb_clnt = rpcb_create(xprt->xprt_net, clnt->cl_server, sap, salen,
+ rpcb_clnt = rpcb_create(xprt->xprt_net, xprt->servername, sap, salen,
xprt->prot, bind_version);
if (IS_ERR(rpcb_clnt)) {
status = PTR_ERR(rpcb_clnt);
@@ -728,13 +760,13 @@ void rpcb_getport_async(struct rpc_task *task)
map->r_vers = clnt->cl_vers;
map->r_prot = xprt->prot;
map->r_port = 0;
- map->r_xprt = xprt_get(xprt);
+ map->r_xprt = xprt;
map->r_status = -EIO;
switch (bind_version) {
case RPCBVERS_4:
case RPCBVERS_3:
- map->r_netid = rpc_peeraddr2str(clnt, RPC_DISPLAY_NETID);
+ map->r_netid = xprt->address_strings[RPC_DISPLAY_NETID];
map->r_addr = rpc_sockaddr2uaddr(sap, GFP_ATOMIC);
map->r_owner = "";
break;
@@ -763,6 +795,7 @@ bailout_release_client:
bailout_nofree:
rpcb_wake_rpcbind_waiters(xprt, status);
task->tk_status = status;
+ xprt_put(xprt);
}
EXPORT_SYMBOL_GPL(rpcb_getport_async);
@@ -810,11 +843,11 @@ static void rpcb_getport_done(struct rpc_task *child, void *data)
static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr,
const struct rpcbind_args *rpcb)
{
- struct rpc_task *task = req->rq_task;
__be32 *p;
dprintk("RPC: %5u encoding PMAP_%s call (%u, %u, %d, %u)\n",
- task->tk_pid, task->tk_msg.rpc_proc->p_name,
+ req->rq_task->tk_pid,
+ req->rq_task->tk_msg.rpc_proc->p_name,
rpcb->r_prog, rpcb->r_vers, rpcb->r_prot, rpcb->r_port);
p = xdr_reserve_space(xdr, RPCB_mappingargs_sz << 2);
@@ -827,7 +860,6 @@ static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr,
static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr,
struct rpcbind_args *rpcb)
{
- struct rpc_task *task = req->rq_task;
unsigned long port;
__be32 *p;
@@ -838,8 +870,8 @@ static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr,
return -EIO;
port = be32_to_cpup(p);
- dprintk("RPC: %5u PMAP_%s result: %lu\n", task->tk_pid,
- task->tk_msg.rpc_proc->p_name, port);
+ dprintk("RPC: %5u PMAP_%s result: %lu\n", req->rq_task->tk_pid,
+ req->rq_task->tk_msg.rpc_proc->p_name, port);
if (unlikely(port > USHRT_MAX))
return -EIO;
@@ -850,7 +882,6 @@ static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr,
static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr,
unsigned int *boolp)
{
- struct rpc_task *task = req->rq_task;
__be32 *p;
p = xdr_inline_decode(xdr, 4);
@@ -862,7 +893,8 @@ static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr,
*boolp = 1;
dprintk("RPC: %5u RPCB_%s call %s\n",
- task->tk_pid, task->tk_msg.rpc_proc->p_name,
+ req->rq_task->tk_pid,
+ req->rq_task->tk_msg.rpc_proc->p_name,
(*boolp ? "succeeded" : "failed"));
return 0;
}
@@ -874,7 +906,10 @@ static void encode_rpcb_string(struct xdr_stream *xdr, const char *string,
u32 len;
len = strlen(string);
- BUG_ON(len > maxstrlen);
+ WARN_ON_ONCE(len > maxstrlen);
+ if (len > maxstrlen)
+ /* truncate and hope for the best */
+ len = maxstrlen;
p = xdr_reserve_space(xdr, 4 + len);
xdr_encode_opaque(p, string, len);
}
@@ -882,11 +917,11 @@ static void encode_rpcb_string(struct xdr_stream *xdr, const char *string,
static void rpcb_enc_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr,
const struct rpcbind_args *rpcb)
{
- struct rpc_task *task = req->rq_task;
__be32 *p;
dprintk("RPC: %5u encoding RPCB_%s call (%u, %u, '%s', '%s')\n",
- task->tk_pid, task->tk_msg.rpc_proc->p_name,
+ req->rq_task->tk_pid,
+ req->rq_task->tk_msg.rpc_proc->p_name,
rpcb->r_prog, rpcb->r_vers,
rpcb->r_netid, rpcb->r_addr);
@@ -904,7 +939,6 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr,
{
struct sockaddr_storage address;
struct sockaddr *sap = (struct sockaddr *)&address;
- struct rpc_task *task = req->rq_task;
__be32 *p;
u32 len;
@@ -921,7 +955,7 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr,
*/
if (len == 0) {
dprintk("RPC: %5u RPCB reply: program not registered\n",
- task->tk_pid);
+ req->rq_task->tk_pid);
return 0;
}
@@ -931,8 +965,8 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr,
p = xdr_inline_decode(xdr, len);
if (unlikely(p == NULL))
goto out_fail;
- dprintk("RPC: %5u RPCB_%s reply: %s\n", task->tk_pid,
- task->tk_msg.rpc_proc->p_name, (char *)p);
+ dprintk("RPC: %5u RPCB_%s reply: %s\n", req->rq_task->tk_pid,
+ req->rq_task->tk_msg.rpc_proc->p_name, (char *)p);
if (rpc_uaddr2sockaddr(req->rq_xprt->xprt_net, (char *)p, len,
sap, sizeof(address)) == 0)
@@ -943,7 +977,8 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr,
out_fail:
dprintk("RPC: %5u malformed RPCB_%s reply\n",
- task->tk_pid, task->tk_msg.rpc_proc->p_name);
+ req->rq_task->tk_pid,
+ req->rq_task->tk_msg.rpc_proc->p_name);
return -EIO;
}