diff options
Diffstat (limited to 'net/phonet/socket.c')
| -rw-r--r-- | net/phonet/socket.c | 227 | 
1 files changed, 99 insertions, 128 deletions
diff --git a/net/phonet/socket.c b/net/phonet/socket.c index 25f746d20c1..008214a3d5e 100644 --- a/net/phonet/socket.c +++ b/net/phonet/socket.c @@ -5,8 +5,8 @@   *   * Copyright (C) 2008 Nokia Corporation.   * - * Contact: Remi Denis-Courmont <remi.denis-courmont@nokia.com> - * Original author: Sakari Ailus <sakari.ailus@nokia.com> + * Authors: Sakari Ailus <sakari.ailus@nokia.com> + *          Rémi Denis-Courmont   *   * This program is free software; you can redistribute it and/or   * modify it under the terms of the GNU General Public License @@ -31,6 +31,7 @@  #include <net/tcp_states.h>  #include <linux/phonet.h> +#include <linux/export.h>  #include <net/phonet/phonet.h>  #include <net/phonet/pep.h>  #include <net/phonet/pn_dev.h> @@ -52,16 +53,16 @@ static int pn_socket_release(struct socket *sock)  static struct  {  	struct hlist_head hlist[PN_HASHSIZE]; -	spinlock_t lock; +	struct mutex lock;  } pnsocks;  void __init pn_sock_init(void)  { -	unsigned i; +	unsigned int i;  	for (i = 0; i < PN_HASHSIZE; i++)  		INIT_HLIST_HEAD(pnsocks.hlist + i); -	spin_lock_init(&pnsocks.lock); +	mutex_init(&pnsocks.lock);  }  static struct hlist_head *pn_hash_list(u16 obj) @@ -75,16 +76,14 @@ static struct hlist_head *pn_hash_list(u16 obj)   */  struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *spn)  { -	struct hlist_node *node;  	struct sock *sknode;  	struct sock *rval = NULL;  	u16 obj = pn_sockaddr_get_object(spn);  	u8 res = spn->spn_resource;  	struct hlist_head *hlist = pn_hash_list(obj); -	spin_lock_bh(&pnsocks.lock); - -	sk_for_each(sknode, node, hlist) { +	rcu_read_lock(); +	sk_for_each_rcu(sknode, hlist) {  		struct pn_sock *pn = pn_sk(sknode);  		BUG_ON(!pn->sobject); /* unbound socket */ @@ -107,8 +106,7 @@ struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *spn)  		sock_hold(sknode);  		break;  	} - -	spin_unlock_bh(&pnsocks.lock); +	rcu_read_unlock();  	return rval;  } @@ -117,14 +115,13 @@ struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *spn)  void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb)  {  	struct hlist_head *hlist = pnsocks.hlist; -	unsigned h; +	unsigned int h; -	spin_lock(&pnsocks.lock); +	rcu_read_lock();  	for (h = 0; h < PN_HASHSIZE; h++) { -		struct hlist_node *node;  		struct sock *sknode; -		sk_for_each(sknode, node, hlist) { +		sk_for_each(sknode, hlist) {  			struct sk_buff *clone;  			if (!net_eq(sock_net(sknode), net)) @@ -140,25 +137,26 @@ void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb)  		}  		hlist++;  	} -	spin_unlock(&pnsocks.lock); +	rcu_read_unlock();  }  void pn_sock_hash(struct sock *sk)  {  	struct hlist_head *hlist = pn_hash_list(pn_sk(sk)->sobject); -	spin_lock_bh(&pnsocks.lock); -	sk_add_node(sk, hlist); -	spin_unlock_bh(&pnsocks.lock); +	mutex_lock(&pnsocks.lock); +	sk_add_node_rcu(sk, hlist); +	mutex_unlock(&pnsocks.lock);  }  EXPORT_SYMBOL(pn_sock_hash);  void pn_sock_unhash(struct sock *sk)  { -	spin_lock_bh(&pnsocks.lock); -	sk_del_node_init(sk); -	spin_unlock_bh(&pnsocks.lock); +	mutex_lock(&pnsocks.lock); +	sk_del_node_init_rcu(sk); +	mutex_unlock(&pnsocks.lock);  	pn_sock_unbind_all_res(sk); +	synchronize_rcu();  }  EXPORT_SYMBOL(pn_sock_unhash); @@ -225,15 +223,18 @@ static int pn_socket_autobind(struct socket *sock)  	return 0; /* socket was already bound */  } -#ifdef CONFIG_PHONET_PIPECTRLR  static int pn_socket_connect(struct socket *sock, struct sockaddr *addr,  		int len, int flags)  {  	struct sock *sk = sock->sk; +	struct pn_sock *pn = pn_sk(sk);  	struct sockaddr_pn *spn = (struct sockaddr_pn *)addr; -	long timeo; +	struct task_struct *tsk = current; +	long timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK);  	int err; +	if (pn_socket_autobind(sock)) +		return -ENOBUFS;  	if (len < sizeof(struct sockaddr_pn))  		return -EINVAL;  	if (spn->spn_family != AF_PHONET) @@ -243,82 +244,61 @@ static int pn_socket_connect(struct socket *sock, struct sockaddr *addr,  	switch (sock->state) {  	case SS_UNCONNECTED: -		sk->sk_state = TCP_CLOSE; -		break; -	case SS_CONNECTING: -		switch (sk->sk_state) { -		case TCP_SYN_RECV: -			sock->state = SS_CONNECTED; -			err = -EISCONN; -			goto out; -		case TCP_CLOSE: -			err = -EALREADY; -			if (flags & O_NONBLOCK) -				goto out; -			goto wait_connect; -		} -		break; -	case SS_CONNECTED: -		switch (sk->sk_state) { -		case TCP_SYN_RECV: +		if (sk->sk_state != TCP_CLOSE) {  			err = -EISCONN;  			goto out; -		case TCP_CLOSE: -			sock->state = SS_UNCONNECTED; -			break;  		}  		break; -	case SS_DISCONNECTING: -	case SS_FREE: -		break; +	case SS_CONNECTING: +		err = -EALREADY; +		goto out; +	default: +		err = -EISCONN; +		goto out;  	} -	sk->sk_state = TCP_CLOSE; -	sk_stream_kill_queues(sk); +	pn->dobject = pn_sockaddr_get_object(spn); +	pn->resource = pn_sockaddr_get_resource(spn);  	sock->state = SS_CONNECTING; +  	err = sk->sk_prot->connect(sk, addr, len); -	if (err < 0) { +	if (err) {  		sock->state = SS_UNCONNECTED; -		sk->sk_state = TCP_CLOSE; +		pn->dobject = 0;  		goto out;  	} -	err = -EINPROGRESS; -wait_connect: -	if (sk->sk_state != TCP_SYN_RECV && (flags & O_NONBLOCK)) -		goto out; - -	timeo = sock_sndtimeo(sk, flags & O_NONBLOCK); -	release_sock(sk); - -	err = -ERESTARTSYS; -	timeo = wait_event_interruptible_timeout(*sk_sleep(sk), -			sk->sk_state != TCP_CLOSE, -			timeo); - -	lock_sock(sk); -	if (timeo < 0) -		goto out; /* -ERESTARTSYS */ +	while (sk->sk_state == TCP_SYN_SENT) { +		DEFINE_WAIT(wait); -	err = -ETIMEDOUT; -	if (timeo == 0 && sk->sk_state != TCP_SYN_RECV) -		goto out; +		if (!timeo) { +			err = -EINPROGRESS; +			goto out; +		} +		if (signal_pending(tsk)) { +			err = sock_intr_errno(timeo); +			goto out; +		} -	if (sk->sk_state != TCP_SYN_RECV) { -		sock->state = SS_UNCONNECTED; -		err = sock_error(sk); -		if (!err) -			err = -ECONNREFUSED; -		goto out; +		prepare_to_wait_exclusive(sk_sleep(sk), &wait, +						TASK_INTERRUPTIBLE); +		release_sock(sk); +		timeo = schedule_timeout(timeo); +		lock_sock(sk); +		finish_wait(sk_sleep(sk), &wait);  	} -	sock->state = SS_CONNECTED; -	err = 0; +	if ((1 << sk->sk_state) & (TCPF_SYN_RECV|TCPF_ESTABLISHED)) +		err = 0; +	else if (sk->sk_state == TCP_CLOSE_WAIT) +		err = -ECONNRESET; +	else +		err = -ECONNREFUSED; +	sock->state = err ? SS_UNCONNECTED : SS_CONNECTED;  out:  	release_sock(sk);  	return err;  } -#endif  static int pn_socket_accept(struct socket *sock, struct socket *newsock,  				int flags) @@ -327,6 +307,9 @@ static int pn_socket_accept(struct socket *sock, struct socket *newsock,  	struct sock *newsk;  	int err; +	if (unlikely(sk->sk_state != TCP_LISTEN)) +		return -EINVAL; +  	newsk = sk->sk_prot->accept(sk, flags, &err);  	if (!newsk)  		return err; @@ -363,13 +346,8 @@ static unsigned int pn_socket_poll(struct file *file, struct socket *sock,  	poll_wait(file, sk_sleep(sk), wait); -	switch (sk->sk_state) { -	case TCP_LISTEN: -		return hlist_empty(&pn->ackq) ? 0 : POLLIN; -	case TCP_CLOSE: +	if (sk->sk_state == TCP_CLOSE)  		return POLLERR; -	} -  	if (!skb_queue_empty(&sk->sk_receive_queue))  		mask |= POLLIN | POLLRDNORM;  	if (!skb_queue_empty(&pn->ctrlreq_queue)) @@ -428,19 +406,19 @@ static int pn_socket_listen(struct socket *sock, int backlog)  	struct sock *sk = sock->sk;  	int err = 0; -	if (sock->state != SS_UNCONNECTED) -		return -EINVAL;  	if (pn_socket_autobind(sock))  		return -ENOBUFS;  	lock_sock(sk); -	if (sk->sk_state != TCP_CLOSE) { +	if (sock->state != SS_UNCONNECTED) {  		err = -EINVAL;  		goto out;  	} -	sk->sk_state = TCP_LISTEN; -	sk->sk_ack_backlog = 0; +	if (sk->sk_state != TCP_LISTEN) { +		sk->sk_state = TCP_LISTEN; +		sk->sk_ack_backlog = 0; +	}  	sk->sk_max_ack_backlog = backlog;  out:  	release_sock(sk); @@ -488,11 +466,7 @@ const struct proto_ops phonet_stream_ops = {  	.owner		= THIS_MODULE,  	.release	= pn_socket_release,  	.bind		= pn_socket_bind, -#ifdef CONFIG_PHONET_PIPECTRLR  	.connect	= pn_socket_connect, -#else -	.connect	= sock_no_connect, -#endif  	.socketpair	= sock_no_socketpair,  	.accept		= pn_socket_accept,  	.getname	= pn_socket_getname, @@ -567,12 +541,11 @@ static struct sock *pn_sock_get_idx(struct seq_file *seq, loff_t pos)  {  	struct net *net = seq_file_net(seq);  	struct hlist_head *hlist = pnsocks.hlist; -	struct hlist_node *node;  	struct sock *sknode; -	unsigned h; +	unsigned int h;  	for (h = 0; h < PN_HASHSIZE; h++) { -		sk_for_each(sknode, node, hlist) { +		sk_for_each_rcu(sknode, hlist) {  			if (!net_eq(net, sock_net(sknode)))  				continue;  			if (!pos) @@ -596,9 +569,9 @@ static struct sock *pn_sock_get_next(struct seq_file *seq, struct sock *sk)  }  static void *pn_sock_seq_start(struct seq_file *seq, loff_t *pos) -	__acquires(pnsocks.lock) +	__acquires(rcu)  { -	spin_lock_bh(&pnsocks.lock); +	rcu_read_lock();  	return *pos ? pn_sock_get_idx(seq, *pos - 1) : SEQ_START_TOKEN;  } @@ -615,32 +588,32 @@ static void *pn_sock_seq_next(struct seq_file *seq, void *v, loff_t *pos)  }  static void pn_sock_seq_stop(struct seq_file *seq, void *v) -	__releases(pnsocks.lock) +	__releases(rcu)  { -	spin_unlock_bh(&pnsocks.lock); +	rcu_read_unlock();  }  static int pn_sock_seq_show(struct seq_file *seq, void *v)  { -	int len; - +	seq_setwidth(seq, 127);  	if (v == SEQ_START_TOKEN) -		seq_printf(seq, "%s%n", "pt  loc  rem rs st tx_queue rx_queue " -			"  uid inode ref pointer drops", &len); +		seq_puts(seq, "pt  loc  rem rs st tx_queue rx_queue " +			"  uid inode ref pointer drops");  	else {  		struct sock *sk = v;  		struct pn_sock *pn = pn_sk(sk);  		seq_printf(seq, "%2d %04X:%04X:%02X %02X %08X:%08X %5d %lu " -			"%d %p %d%n", -			sk->sk_protocol, pn->sobject, 0, pn->resource, -			sk->sk_state, +			"%d %pK %d", +			sk->sk_protocol, pn->sobject, pn->dobject, +			pn->resource, sk->sk_state,  			sk_wmem_alloc_get(sk), sk_rmem_alloc_get(sk), -			sock_i_uid(sk), sock_i_ino(sk), +			from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)), +			sock_i_ino(sk),  			atomic_read(&sk->sk_refcnt), sk, -			atomic_read(&sk->sk_drops), &len); +			atomic_read(&sk->sk_drops));  	} -	seq_printf(seq, "%*s\n", 127 - len, ""); +	seq_pad(seq, '\n');  	return 0;  } @@ -720,7 +693,7 @@ int pn_sock_unbind_res(struct sock *sk, u8 res)  	mutex_lock(&resource_mutex);  	if (pnres.sk[res] == sk) { -		rcu_assign_pointer(pnres.sk[res], NULL); +		RCU_INIT_POINTER(pnres.sk[res], NULL);  		ret = 0;  	}  	mutex_unlock(&resource_mutex); @@ -734,31 +707,29 @@ int pn_sock_unbind_res(struct sock *sk, u8 res)  void pn_sock_unbind_all_res(struct sock *sk)  { -	unsigned res, match = 0; +	unsigned int res, match = 0;  	mutex_lock(&resource_mutex);  	for (res = 0; res < 256; res++) {  		if (pnres.sk[res] == sk) { -			rcu_assign_pointer(pnres.sk[res], NULL); +			RCU_INIT_POINTER(pnres.sk[res], NULL);  			match++;  		}  	}  	mutex_unlock(&resource_mutex); -	if (match == 0) -		return; -	synchronize_rcu();  	while (match > 0) { -		sock_put(sk); +		__sock_put(sk);  		match--;  	} +	/* Caller is responsible for RCU sync before final sock_put() */  }  #ifdef CONFIG_PROC_FS  static struct sock **pn_res_get_idx(struct seq_file *seq, loff_t pos)  {  	struct net *net = seq_file_net(seq); -	unsigned i; +	unsigned int i;  	if (!net_eq(net, &init_net))  		return NULL; @@ -776,7 +747,7 @@ static struct sock **pn_res_get_idx(struct seq_file *seq, loff_t pos)  static struct sock **pn_res_get_next(struct seq_file *seq, struct sock **sk)  {  	struct net *net = seq_file_net(seq); -	unsigned i; +	unsigned int i;  	BUG_ON(!net_eq(net, &init_net)); @@ -813,19 +784,19 @@ static void pn_res_seq_stop(struct seq_file *seq, void *v)  static int pn_res_seq_show(struct seq_file *seq, void *v)  { -	int len; - +	seq_setwidth(seq, 63);  	if (v == SEQ_START_TOKEN) -		seq_printf(seq, "%s%n", "rs   uid inode", &len); +		seq_puts(seq, "rs   uid inode");  	else {  		struct sock **psk = v;  		struct sock *sk = *psk; -		seq_printf(seq, "%02X %5d %lu%n", -			   (int) (psk - pnres.sk), sock_i_uid(sk), -			   sock_i_ino(sk), &len); +		seq_printf(seq, "%02X %5u %lu", +			   (int) (psk - pnres.sk), +			   from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)), +			   sock_i_ino(sk));  	} -	seq_printf(seq, "%*s\n", 63 - len, ""); +	seq_pad(seq, '\n');  	return 0;  }  | 
