diff options
Diffstat (limited to 'net/ipv4/inet_diag.c')
| -rw-r--r-- | net/ipv4/inet_diag.c | 1040 |
1 files changed, 682 insertions, 358 deletions
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index 71f3c7350c6..e34dccbc4d7 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -1,8 +1,6 @@ /* * inet_diag.c Module for monitoring INET transport protocols sockets. * - * Version: $Id: inet_diag.c,v 1.3 2002/02/01 22:01:04 davem Exp $ - * * Authors: Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru> * * This program is free software; you can redistribute it and/or @@ -11,11 +9,12 @@ * 2 of the License, or (at your option) any later version. */ -#include <linux/config.h> +#include <linux/kernel.h> #include <linux/module.h> #include <linux/types.h> #include <linux/fcntl.h> #include <linux/random.h> +#include <linux/slab.h> #include <linux/cache.h> #include <linux/init.h> #include <linux/time.h> @@ -28,121 +27,144 @@ #include <net/inet_hashtables.h> #include <net/inet_timewait_sock.h> #include <net/inet6_hashtables.h> +#include <net/netlink.h> #include <linux/inet.h> #include <linux/stddef.h> #include <linux/inet_diag.h> +#include <linux/sock_diag.h> static const struct inet_diag_handler **inet_diag_table; struct inet_diag_entry { - u32 *saddr; - u32 *daddr; + __be32 *saddr; + __be32 *daddr; u16 sport; u16 dport; u16 family; u16 userlocks; +#if IS_ENABLED(CONFIG_IPV6) + struct in6_addr saddr_storage; /* for IPv4-mapped-IPv6 addresses */ + struct in6_addr daddr_storage; /* for IPv4-mapped-IPv6 addresses */ +#endif }; -static struct sock *idiagnl; +static DEFINE_MUTEX(inet_diag_table_mutex); -#define INET_DIAG_PUT(skb, attrtype, attrlen) \ - RTA_DATA(__RTA_PUT(skb, attrtype, attrlen)) +static const struct inet_diag_handler *inet_diag_lock_handler(int proto) +{ + if (!inet_diag_table[proto]) + request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK, + NETLINK_SOCK_DIAG, AF_INET, proto); -static int inet_diag_fill(struct sk_buff *skb, struct sock *sk, - int ext, u32 pid, u32 seq, u16 nlmsg_flags, - const struct nlmsghdr *unlh) + mutex_lock(&inet_diag_table_mutex); + if (!inet_diag_table[proto]) + return ERR_PTR(-ENOENT); + + return inet_diag_table[proto]; +} + +static inline void inet_diag_unlock_handler( + const struct inet_diag_handler *handler) +{ + mutex_unlock(&inet_diag_table_mutex); +} + +int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, + struct sk_buff *skb, struct inet_diag_req_v2 *req, + struct user_namespace *user_ns, + u32 portid, u32 seq, u16 nlmsg_flags, + const struct nlmsghdr *unlh) { const struct inet_sock *inet = inet_sk(sk); - const struct inet_connection_sock *icsk = inet_csk(sk); struct inet_diag_msg *r; struct nlmsghdr *nlh; + struct nlattr *attr; void *info = NULL; - struct inet_diag_meminfo *minfo = NULL; - unsigned char *b = skb->tail; const struct inet_diag_handler *handler; + int ext = req->idiag_ext; - handler = inet_diag_table[unlh->nlmsg_type]; + handler = inet_diag_table[req->sdiag_protocol]; BUG_ON(handler == NULL); - nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r)); - nlh->nlmsg_flags = nlmsg_flags; - - r = NLMSG_DATA(nlh); - if (sk->sk_state != TCP_TIME_WAIT) { - if (ext & (1 << (INET_DIAG_MEMINFO - 1))) - minfo = INET_DIAG_PUT(skb, INET_DIAG_MEMINFO, - sizeof(*minfo)); - if (ext & (1 << (INET_DIAG_INFO - 1))) - info = INET_DIAG_PUT(skb, INET_DIAG_INFO, - handler->idiag_info_size); - - if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) { - size_t len = strlen(icsk->icsk_ca_ops->name); - strcpy(INET_DIAG_PUT(skb, INET_DIAG_CONG, len + 1), - icsk->icsk_ca_ops->name); - } - } + nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r), + nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + BUG_ON(sk->sk_state == TCP_TIME_WAIT); + r->idiag_family = sk->sk_family; r->idiag_state = sk->sk_state; r->idiag_timer = 0; r->idiag_retrans = 0; r->id.idiag_if = sk->sk_bound_dev_if; - r->id.idiag_cookie[0] = (u32)(unsigned long)sk; - r->id.idiag_cookie[1] = (u32)(((unsigned long)sk >> 31) >> 1); - - if (r->idiag_state == TCP_TIME_WAIT) { - const struct inet_timewait_sock *tw = inet_twsk(sk); - long tmo = tw->tw_ttd - jiffies; - if (tmo < 0) - tmo = 0; - - r->id.idiag_sport = tw->tw_sport; - r->id.idiag_dport = tw->tw_dport; - r->id.idiag_src[0] = tw->tw_rcv_saddr; - r->id.idiag_dst[0] = tw->tw_daddr; - r->idiag_state = tw->tw_substate; - r->idiag_timer = 3; - r->idiag_expires = (tmo * 1000 + HZ - 1) / HZ; - r->idiag_rqueue = 0; - r->idiag_wqueue = 0; - r->idiag_uid = 0; - r->idiag_inode = 0; -#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) - if (r->idiag_family == AF_INET6) { - const struct tcp6_timewait_sock *tcp6tw = tcp6_twsk(sk); - - ipv6_addr_copy((struct in6_addr *)r->id.idiag_src, - &tcp6tw->tw_v6_rcv_saddr); - ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst, - &tcp6tw->tw_v6_daddr); - } -#endif - nlh->nlmsg_len = skb->tail - b; - return skb->len; - } + sock_diag_save_cookie(sk, r->id.idiag_cookie); + + r->id.idiag_sport = inet->inet_sport; + r->id.idiag_dport = inet->inet_dport; + + memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src)); + memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst)); + + r->id.idiag_src[0] = inet->inet_rcv_saddr; + r->id.idiag_dst[0] = inet->inet_daddr; + + if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown)) + goto errout; - r->id.idiag_sport = inet->sport; - r->id.idiag_dport = inet->dport; - r->id.idiag_src[0] = inet->rcv_saddr; - r->id.idiag_dst[0] = inet->daddr; + /* IPv6 dual-stack sockets use inet->tos for IPv4 connections, + * hence this needs to be included regardless of socket family. + */ + if (ext & (1 << (INET_DIAG_TOS - 1))) + if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0) + goto errout; -#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) if (r->idiag_family == AF_INET6) { - struct ipv6_pinfo *np = inet6_sk(sk); - ipv6_addr_copy((struct in6_addr *)r->id.idiag_src, - &np->rcv_saddr); - ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst, - &np->daddr); + *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr; + *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr; + + if (ext & (1 << (INET_DIAG_TCLASS - 1))) + if (nla_put_u8(skb, INET_DIAG_TCLASS, + inet6_sk(sk)->tclass) < 0) + goto errout; } #endif -#define EXPIRES_IN_MS(tmo) ((tmo - jiffies) * 1000 + HZ - 1) / HZ + r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk)); + r->idiag_inode = sock_i_ino(sk); + + if (ext & (1 << (INET_DIAG_MEMINFO - 1))) { + struct inet_diag_meminfo minfo = { + .idiag_rmem = sk_rmem_alloc_get(sk), + .idiag_wmem = sk->sk_wmem_queued, + .idiag_fmem = sk->sk_forward_alloc, + .idiag_tmem = sk_wmem_alloc_get(sk), + }; + + if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0) + goto errout; + } + + if (ext & (1 << (INET_DIAG_SKMEMINFO - 1))) + if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO)) + goto errout; + + if (icsk == NULL) { + handler->idiag_get_info(sk, r, NULL); + goto out; + } - if (icsk->icsk_pending == ICSK_TIME_RETRANS) { +#define EXPIRES_IN_MS(tmo) DIV_ROUND_UP((tmo - jiffies) * 1000, HZ) + + if (icsk->icsk_pending == ICSK_TIME_RETRANS || + icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS || + icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) { r->idiag_timer = 1; r->idiag_retrans = icsk->icsk_retransmits; r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout); @@ -160,52 +182,129 @@ static int inet_diag_fill(struct sk_buff *skb, struct sock *sk, } #undef EXPIRES_IN_MS - r->idiag_uid = sock_i_uid(sk); - r->idiag_inode = sock_i_ino(sk); + if (ext & (1 << (INET_DIAG_INFO - 1))) { + attr = nla_reserve(skb, INET_DIAG_INFO, + sizeof(struct tcp_info)); + if (!attr) + goto errout; - if (minfo) { - minfo->idiag_rmem = atomic_read(&sk->sk_rmem_alloc); - minfo->idiag_wmem = sk->sk_wmem_queued; - minfo->idiag_fmem = sk->sk_forward_alloc; - minfo->idiag_tmem = atomic_read(&sk->sk_wmem_alloc); + info = nla_data(attr); } + if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) + if (nla_put_string(skb, INET_DIAG_CONG, + icsk->icsk_ca_ops->name) < 0) + goto errout; + handler->idiag_get_info(sk, r, info); if (sk->sk_state < TCP_TIME_WAIT && icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info) icsk->icsk_ca_ops->get_info(sk, ext, skb); - nlh->nlmsg_len = skb->tail - b; - return skb->len; +out: + return nlmsg_end(skb, nlh); + +errout: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(inet_sk_diag_fill); + +static int inet_csk_diag_fill(struct sock *sk, + struct sk_buff *skb, struct inet_diag_req_v2 *req, + struct user_namespace *user_ns, + u32 portid, u32 seq, u16 nlmsg_flags, + const struct nlmsghdr *unlh) +{ + return inet_sk_diag_fill(sk, inet_csk(sk), + skb, req, user_ns, portid, seq, nlmsg_flags, unlh); +} + +static int inet_twsk_diag_fill(struct inet_timewait_sock *tw, + struct sk_buff *skb, struct inet_diag_req_v2 *req, + u32 portid, u32 seq, u16 nlmsg_flags, + const struct nlmsghdr *unlh) +{ + s32 tmo; + struct inet_diag_msg *r; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r), + nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + BUG_ON(tw->tw_state != TCP_TIME_WAIT); + + tmo = tw->tw_ttd - inet_tw_time_stamp(); + if (tmo < 0) + tmo = 0; + + r->idiag_family = tw->tw_family; + r->idiag_retrans = 0; + + r->id.idiag_if = tw->tw_bound_dev_if; + sock_diag_save_cookie(tw, r->id.idiag_cookie); + + r->id.idiag_sport = tw->tw_sport; + r->id.idiag_dport = tw->tw_dport; + + memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src)); + memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst)); + + r->id.idiag_src[0] = tw->tw_rcv_saddr; + r->id.idiag_dst[0] = tw->tw_daddr; + + r->idiag_state = tw->tw_substate; + r->idiag_timer = 3; + r->idiag_expires = jiffies_to_msecs(tmo); + r->idiag_rqueue = 0; + r->idiag_wqueue = 0; + r->idiag_uid = 0; + r->idiag_inode = 0; +#if IS_ENABLED(CONFIG_IPV6) + if (tw->tw_family == AF_INET6) { + *(struct in6_addr *)r->id.idiag_src = tw->tw_v6_rcv_saddr; + *(struct in6_addr *)r->id.idiag_dst = tw->tw_v6_daddr; + } +#endif + + return nlmsg_end(skb, nlh); +} -rtattr_failure: -nlmsg_failure: - skb_trim(skb, b - skb->data); - return -1; +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + struct inet_diag_req_v2 *r, + struct user_namespace *user_ns, + u32 portid, u32 seq, u16 nlmsg_flags, + const struct nlmsghdr *unlh) +{ + if (sk->sk_state == TCP_TIME_WAIT) + return inet_twsk_diag_fill(inet_twsk(sk), skb, r, portid, seq, + nlmsg_flags, unlh); + + return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq, + nlmsg_flags, unlh); } -static int inet_diag_get_exact(struct sk_buff *in_skb, const struct nlmsghdr *nlh) +int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_skb, + const struct nlmsghdr *nlh, struct inet_diag_req_v2 *req) { int err; struct sock *sk; - struct inet_diag_req *req = NLMSG_DATA(nlh); struct sk_buff *rep; - struct inet_hashinfo *hashinfo; - const struct inet_diag_handler *handler; - - handler = inet_diag_table[nlh->nlmsg_type]; - BUG_ON(handler == NULL); - hashinfo = handler->idiag_hashinfo; + struct net *net = sock_net(in_skb->sk); - if (req->idiag_family == AF_INET) { - sk = inet_lookup(hashinfo, req->id.idiag_dst[0], + err = -EINVAL; + if (req->sdiag_family == AF_INET) { + sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0], req->id.idiag_dport, req->id.idiag_src[0], req->id.idiag_sport, req->id.idiag_if); } -#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) - else if (req->idiag_family == AF_INET6) { - sk = inet6_lookup(hashinfo, +#if IS_ENABLED(CONFIG_IPV6) + else if (req->sdiag_family == AF_INET6) { + sk = inet6_lookup(net, hashinfo, (struct in6_addr *)req->id.idiag_dst, req->id.idiag_dport, (struct in6_addr *)req->id.idiag_src, @@ -214,48 +313,66 @@ static int inet_diag_get_exact(struct sk_buff *in_skb, const struct nlmsghdr *nl } #endif else { - return -EINVAL; + goto out_nosk; } + err = -ENOENT; if (sk == NULL) - return -ENOENT; + goto out_nosk; - err = -ESTALE; - if ((req->id.idiag_cookie[0] != INET_DIAG_NOCOOKIE || - req->id.idiag_cookie[1] != INET_DIAG_NOCOOKIE) && - ((u32)(unsigned long)sk != req->id.idiag_cookie[0] || - (u32)((((unsigned long)sk) >> 31) >> 1) != req->id.idiag_cookie[1])) + err = sock_diag_check_cookie(sk, req->id.idiag_cookie); + if (err) goto out; - err = -ENOMEM; - rep = alloc_skb(NLMSG_SPACE((sizeof(struct inet_diag_msg) + - sizeof(struct inet_diag_meminfo) + - handler->idiag_info_size + 64)), - GFP_KERNEL); - if (!rep) + rep = nlmsg_new(sizeof(struct inet_diag_msg) + + sizeof(struct inet_diag_meminfo) + + sizeof(struct tcp_info) + 64, GFP_KERNEL); + if (!rep) { + err = -ENOMEM; goto out; + } - if (inet_diag_fill(rep, sk, req->idiag_ext, - NETLINK_CB(in_skb).pid, - nlh->nlmsg_seq, 0, nlh) <= 0) - BUG(); - - err = netlink_unicast(idiagnl, rep, NETLINK_CB(in_skb).pid, + err = sk_diag_fill(sk, rep, req, + sk_user_ns(NETLINK_CB(in_skb).sk), + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, 0, nlh); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + nlmsg_free(rep); + goto out; + } + err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid, MSG_DONTWAIT); if (err > 0) err = 0; out: - if (sk) { - if (sk->sk_state == TCP_TIME_WAIT) - inet_twsk_put((struct inet_timewait_sock *)sk); - else - sock_put(sk); - } + if (sk) + sock_gen_put(sk); + +out_nosk: + return err; +} +EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk); + +static int inet_diag_get_exact(struct sk_buff *in_skb, + const struct nlmsghdr *nlh, + struct inet_diag_req_v2 *req) +{ + const struct inet_diag_handler *handler; + int err; + + handler = inet_diag_lock_handler(req->sdiag_protocol); + if (IS_ERR(handler)) + err = PTR_ERR(handler); + else + err = handler->dump_one(in_skb, nlh, req); + inet_diag_unlock_handler(handler); + return err; } -static int bitstring_match(const u32 *a1, const u32 *a2, int bits) +static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits) { int words = bits >> 5; @@ -266,8 +383,8 @@ static int bitstring_match(const u32 *a1, const u32 *a2, int bits) return 0; } if (bits) { - __u32 w1, w2; - __u32 mask; + __be32 w1, w2; + __be32 mask; w1 = a1[words]; w2 = a2[words]; @@ -282,9 +399,12 @@ static int bitstring_match(const u32 *a1, const u32 *a2, int bits) } -static int inet_diag_bc_run(const void *bc, int len, - const struct inet_diag_entry *entry) +static int inet_diag_bc_run(const struct nlattr *_bc, + const struct inet_diag_entry *entry) { + const void *bc = nla_data(_bc); + int len = nla_len(_bc); + while (len > 0) { int yes = 1; const struct inet_diag_bc_op *op = bc; @@ -299,7 +419,7 @@ static int inet_diag_bc_run(const void *bc, int len, yes = entry->sport >= op[1].no; break; case INET_DIAG_BC_S_LE: - yes = entry->dport <= op[1].no; + yes = entry->sport <= op[1].no; break; case INET_DIAG_BC_D_GE: yes = entry->dport >= op[1].no; @@ -313,7 +433,7 @@ static int inet_diag_bc_run(const void *bc, int len, case INET_DIAG_BC_S_COND: case INET_DIAG_BC_D_COND: { struct inet_diag_hostcond *cond; - u32 *addr; + __be32 *addr; cond = (struct inet_diag_hostcond *)(op + 1); if (cond->port != -1 && @@ -322,31 +442,38 @@ static int inet_diag_bc_run(const void *bc, int len, yes = 0; break; } - - if (cond->prefix_len == 0) - break; if (op->code == INET_DIAG_BC_S_COND) addr = entry->saddr; else addr = entry->daddr; - if (bitstring_match(addr, cond->addr, cond->prefix_len)) + if (cond->family != AF_UNSPEC && + cond->family != entry->family) { + if (entry->family == AF_INET6 && + cond->family == AF_INET) { + if (addr[0] == 0 && addr[1] == 0 && + addr[2] == htonl(0xffff) && + bitstring_match(addr + 3, + cond->addr, + cond->prefix_len)) + break; + } + yes = 0; break; - if (entry->family == AF_INET6 && - cond->family == AF_INET) { - if (addr[0] == 0 && addr[1] == 0 && - addr[2] == htonl(0xffff) && - bitstring_match(addr + 3, cond->addr, - cond->prefix_len)) - break; } + + if (cond->prefix_len == 0) + break; + if (bitstring_match(addr, cond->addr, + cond->prefix_len)) + break; yes = 0; break; } } - if (yes) { + if (yes) { len -= op->yes; bc += op->yes; } else { @@ -354,8 +481,36 @@ static int inet_diag_bc_run(const void *bc, int len, bc += op->no; } } - return (len == 0); + return len == 0; +} + +int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk) +{ + struct inet_diag_entry entry; + struct inet_sock *inet = inet_sk(sk); + + if (bc == NULL) + return 1; + + entry.family = sk->sk_family; +#if IS_ENABLED(CONFIG_IPV6) + if (entry.family == AF_INET6) { + + entry.saddr = sk->sk_v6_rcv_saddr.s6_addr32; + entry.daddr = sk->sk_v6_daddr.s6_addr32; + } else +#endif + { + entry.saddr = &inet->inet_rcv_saddr; + entry.daddr = &inet->inet_daddr; + } + entry.sport = inet->inet_num; + entry.dport = ntohs(inet->inet_dport); + entry.userlocks = sk->sk_userlocks; + + return inet_diag_bc_run(bc, &entry); } +EXPORT_SYMBOL_GPL(inet_diag_bc_sk); static int valid_cc(const void *bc, int len, int cc) { @@ -366,7 +521,7 @@ static int valid_cc(const void *bc, int len, int cc) return 0; if (cc == len) return 1; - if (op->yes < 4) + if (op->yes < 4 || op->yes & 3) return 0; len -= op->yes; bc += op->yes; @@ -374,143 +529,246 @@ static int valid_cc(const void *bc, int len, int cc) return 0; } +/* Validate an inet_diag_hostcond. */ +static bool valid_hostcond(const struct inet_diag_bc_op *op, int len, + int *min_len) +{ + int addr_len; + struct inet_diag_hostcond *cond; + + /* Check hostcond space. */ + *min_len += sizeof(struct inet_diag_hostcond); + if (len < *min_len) + return false; + cond = (struct inet_diag_hostcond *)(op + 1); + + /* Check address family and address length. */ + switch (cond->family) { + case AF_UNSPEC: + addr_len = 0; + break; + case AF_INET: + addr_len = sizeof(struct in_addr); + break; + case AF_INET6: + addr_len = sizeof(struct in6_addr); + break; + default: + return false; + } + *min_len += addr_len; + if (len < *min_len) + return false; + + /* Check prefix length (in bits) vs address length (in bytes). */ + if (cond->prefix_len > 8 * addr_len) + return false; + + return true; +} + +/* Validate a port comparison operator. */ +static inline bool valid_port_comparison(const struct inet_diag_bc_op *op, + int len, int *min_len) +{ + /* Port comparisons put the port in a follow-on inet_diag_bc_op. */ + *min_len += sizeof(struct inet_diag_bc_op); + if (len < *min_len) + return false; + return true; +} + static int inet_diag_bc_audit(const void *bytecode, int bytecode_len) { - const unsigned char *bc = bytecode; + const void *bc = bytecode; int len = bytecode_len; while (len > 0) { - struct inet_diag_bc_op *op = (struct inet_diag_bc_op *)bc; + const struct inet_diag_bc_op *op = bc; + int min_len = sizeof(struct inet_diag_bc_op); //printk("BC: %d %d %d {%d} / %d\n", op->code, op->yes, op->no, op[1].no, len); switch (op->code) { - case INET_DIAG_BC_AUTO: case INET_DIAG_BC_S_COND: case INET_DIAG_BC_D_COND: + if (!valid_hostcond(bc, len, &min_len)) + return -EINVAL; + break; case INET_DIAG_BC_S_GE: case INET_DIAG_BC_S_LE: case INET_DIAG_BC_D_GE: case INET_DIAG_BC_D_LE: - if (op->yes < 4 || op->yes > len + 4) - return -EINVAL; - case INET_DIAG_BC_JMP: - if (op->no < 4 || op->no > len + 4) - return -EINVAL; - if (op->no < len && - !valid_cc(bytecode, bytecode_len, len - op->no)) + if (!valid_port_comparison(bc, len, &min_len)) return -EINVAL; break; + case INET_DIAG_BC_AUTO: + case INET_DIAG_BC_JMP: case INET_DIAG_BC_NOP: - if (op->yes < 4 || op->yes > len + 4) - return -EINVAL; break; default: return -EINVAL; } - bc += op->yes; + + if (op->code != INET_DIAG_BC_NOP) { + if (op->no < min_len || op->no > len + 4 || op->no & 3) + return -EINVAL; + if (op->no < len && + !valid_cc(bytecode, bytecode_len, len - op->no)) + return -EINVAL; + } + + if (op->yes < min_len || op->yes > len + 4 || op->yes & 3) + return -EINVAL; + bc += op->yes; len -= op->yes; } return len == 0 ? 0 : -EINVAL; } -static int inet_diag_dump_sock(struct sk_buff *skb, struct sock *sk, - struct netlink_callback *cb) +static int inet_csk_diag_dump(struct sock *sk, + struct sk_buff *skb, + struct netlink_callback *cb, + struct inet_diag_req_v2 *r, + const struct nlattr *bc) { - struct inet_diag_req *r = NLMSG_DATA(cb->nlh); + if (!inet_diag_bc_sk(bc, sk)) + return 0; - if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) { - struct inet_diag_entry entry; - struct rtattr *bc = (struct rtattr *)(r + 1); - struct inet_sock *inet = inet_sk(sk); + return inet_csk_diag_fill(sk, skb, r, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh); +} + +static int inet_twsk_diag_dump(struct sock *sk, + struct sk_buff *skb, + struct netlink_callback *cb, + struct inet_diag_req_v2 *r, + const struct nlattr *bc) +{ + struct inet_timewait_sock *tw = inet_twsk(sk); - entry.family = sk->sk_family; -#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) - if (entry.family == AF_INET6) { - struct ipv6_pinfo *np = inet6_sk(sk); + if (bc != NULL) { + struct inet_diag_entry entry; - entry.saddr = np->rcv_saddr.s6_addr32; - entry.daddr = np->daddr.s6_addr32; + entry.family = tw->tw_family; +#if IS_ENABLED(CONFIG_IPV6) + if (tw->tw_family == AF_INET6) { + entry.saddr = tw->tw_v6_rcv_saddr.s6_addr32; + entry.daddr = tw->tw_v6_daddr.s6_addr32; } else #endif { - entry.saddr = &inet->rcv_saddr; - entry.daddr = &inet->daddr; + entry.saddr = &tw->tw_rcv_saddr; + entry.daddr = &tw->tw_daddr; } - entry.sport = inet->num; - entry.dport = ntohs(inet->dport); - entry.userlocks = sk->sk_userlocks; + entry.sport = tw->tw_num; + entry.dport = ntohs(tw->tw_dport); + entry.userlocks = 0; - if (!inet_diag_bc_run(RTA_DATA(bc), RTA_PAYLOAD(bc), &entry)) + if (!inet_diag_bc_run(bc, &entry)) return 0; } - return inet_diag_fill(skb, sk, r->idiag_ext, NETLINK_CB(cb->skb).pid, - cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh); + return inet_twsk_diag_fill(tw, skb, r, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh); +} + +/* Get the IPv4, IPv6, or IPv4-mapped-IPv6 local and remote addresses + * from a request_sock. For IPv4-mapped-IPv6 we must map IPv4 to IPv6. + */ +static inline void inet_diag_req_addrs(const struct sock *sk, + const struct request_sock *req, + struct inet_diag_entry *entry) +{ + struct inet_request_sock *ireq = inet_rsk(req); + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + if (req->rsk_ops->family == AF_INET6) { + entry->saddr = ireq->ir_v6_loc_addr.s6_addr32; + entry->daddr = ireq->ir_v6_rmt_addr.s6_addr32; + } else if (req->rsk_ops->family == AF_INET) { + ipv6_addr_set_v4mapped(ireq->ir_loc_addr, + &entry->saddr_storage); + ipv6_addr_set_v4mapped(ireq->ir_rmt_addr, + &entry->daddr_storage); + entry->saddr = entry->saddr_storage.s6_addr32; + entry->daddr = entry->daddr_storage.s6_addr32; + } + } else +#endif + { + entry->saddr = &ireq->ir_loc_addr; + entry->daddr = &ireq->ir_rmt_addr; + } } static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk, - struct request_sock *req, - u32 pid, u32 seq, - const struct nlmsghdr *unlh) + struct request_sock *req, + struct user_namespace *user_ns, + u32 portid, u32 seq, + const struct nlmsghdr *unlh) { const struct inet_request_sock *ireq = inet_rsk(req); struct inet_sock *inet = inet_sk(sk); - unsigned char *b = skb->tail; struct inet_diag_msg *r; struct nlmsghdr *nlh; long tmo; - nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r)); - nlh->nlmsg_flags = NLM_F_MULTI; - r = NLMSG_DATA(nlh); + nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r), + NLM_F_MULTI); + if (!nlh) + return -EMSGSIZE; + r = nlmsg_data(nlh); r->idiag_family = sk->sk_family; r->idiag_state = TCP_SYN_RECV; r->idiag_timer = 1; - r->idiag_retrans = req->retrans; + r->idiag_retrans = req->num_retrans; r->id.idiag_if = sk->sk_bound_dev_if; - r->id.idiag_cookie[0] = (u32)(unsigned long)req; - r->id.idiag_cookie[1] = (u32)(((unsigned long)req >> 31) >> 1); + sock_diag_save_cookie(req, r->id.idiag_cookie); tmo = req->expires - jiffies; if (tmo < 0) tmo = 0; - r->id.idiag_sport = inet->sport; - r->id.idiag_dport = ireq->rmt_port; - r->id.idiag_src[0] = ireq->loc_addr; - r->id.idiag_dst[0] = ireq->rmt_addr; + r->id.idiag_sport = inet->inet_sport; + r->id.idiag_dport = ireq->ir_rmt_port; + + memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src)); + memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst)); + + r->id.idiag_src[0] = ireq->ir_loc_addr; + r->id.idiag_dst[0] = ireq->ir_rmt_addr; + r->idiag_expires = jiffies_to_msecs(tmo); r->idiag_rqueue = 0; r->idiag_wqueue = 0; - r->idiag_uid = sock_i_uid(sk); + r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk)); r->idiag_inode = 0; -#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) +#if IS_ENABLED(CONFIG_IPV6) if (r->idiag_family == AF_INET6) { - ipv6_addr_copy((struct in6_addr *)r->id.idiag_src, - &tcp6_rsk(req)->loc_addr); - ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst, - &tcp6_rsk(req)->rmt_addr); + struct inet_diag_entry entry; + inet_diag_req_addrs(sk, req, &entry); + memcpy(r->id.idiag_src, entry.saddr, sizeof(struct in6_addr)); + memcpy(r->id.idiag_dst, entry.daddr, sizeof(struct in6_addr)); } #endif - nlh->nlmsg_len = skb->tail - b; - return skb->len; - -nlmsg_failure: - skb_trim(skb, b - skb->data); - return -1; + return nlmsg_end(skb, nlh); } static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk, - struct netlink_callback *cb) + struct netlink_callback *cb, + struct inet_diag_req_v2 *r, + const struct nlattr *bc) { struct inet_diag_entry entry; - struct inet_diag_req *r = NLMSG_DATA(cb->nlh); struct inet_connection_sock *icsk = inet_csk(sk); struct listen_sock *lopt; - struct rtattr *bc = NULL; struct inet_sock *inet = inet_sk(sk); int j, s_j; int reqnum, s_reqnum; @@ -530,9 +788,8 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk, if (!lopt || !lopt->qlen) goto out; - if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) { - bc = (struct rtattr *)(r + 1); - entry.sport = inet->num; + if (bc != NULL) { + entry.sport = inet->inet_num; entry.userlocks = sk->sk_userlocks; } @@ -545,32 +802,21 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk, if (reqnum < s_reqnum) continue; - if (r->id.idiag_dport != ireq->rmt_port && + if (r->id.idiag_dport != ireq->ir_rmt_port && r->id.idiag_dport) continue; if (bc) { - entry.saddr = -#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) - (entry.family == AF_INET6) ? - tcp6_rsk(req)->loc_addr.s6_addr32 : -#endif - &ireq->loc_addr; - entry.daddr = -#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE) - (entry.family == AF_INET6) ? - tcp6_rsk(req)->rmt_addr.s6_addr32 : -#endif - &ireq->rmt_addr; - entry.dport = ntohs(ireq->rmt_port); + inet_diag_req_addrs(sk, req, &entry); + entry.dport = ntohs(ireq->ir_rmt_port); - if (!inet_diag_bc_run(RTA_DATA(bc), - RTA_PAYLOAD(bc), &entry)) + if (!inet_diag_bc_run(bc, &entry)) continue; } err = inet_diag_fill_req(skb, sk, req, - NETLINK_CB(cb->skb).pid, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, cb->nlh); if (err < 0) { cb->args[3] = j + 1; @@ -588,18 +834,13 @@ out: return err; } -static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, + struct netlink_callback *cb, struct inet_diag_req_v2 *r, struct nlattr *bc) { int i, num; int s_i, s_num; - struct inet_diag_req *r = NLMSG_DATA(cb->nlh); - const struct inet_diag_handler *handler; - struct inet_hashinfo *hashinfo; + struct net *net = sock_net(skb->sk); - handler = inet_diag_table[cb->nlh->nlmsg_type]; - BUG_ON(handler == NULL); - hashinfo = handler->idiag_hashinfo; - s_i = cb->args[1]; s_num = num = cb->args[2]; @@ -607,21 +848,30 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) if (!(r->idiag_states & (TCPF_LISTEN | TCPF_SYN_RECV))) goto skip_listen_ht; - inet_listen_lock(hashinfo); for (i = s_i; i < INET_LHTABLE_SIZE; i++) { struct sock *sk; - struct hlist_node *node; + struct hlist_nulls_node *node; + struct inet_listen_hashbucket *ilb; num = 0; - sk_for_each(sk, node, &hashinfo->listening_hash[i]) { + ilb = &hashinfo->listening_hash[i]; + spin_lock_bh(&ilb->lock); + sk_nulls_for_each(sk, node, &ilb->head) { struct inet_sock *inet = inet_sk(sk); + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) { num++; continue; } - if (r->id.idiag_sport != inet->sport && + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next_listen; + + if (r->id.idiag_sport != inet->inet_sport && r->id.idiag_sport) goto next_listen; @@ -630,8 +880,8 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) cb->args[3] > 0) goto syn_recv; - if (inet_diag_dump_sock(skb, sk, cb) < 0) { - inet_listen_unlock(hashinfo); + if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) { + spin_unlock_bh(&ilb->lock); goto done; } @@ -639,8 +889,8 @@ syn_recv: if (!(r->idiag_states & TCPF_SYN_RECV)) goto next_listen; - if (inet_diag_dump_reqs(skb, sk, cb) < 0) { - inet_listen_unlock(hashinfo); + if (inet_diag_dump_reqs(skb, sk, cb, r, bc) < 0) { + spin_unlock_bh(&ilb->lock); goto done; } @@ -649,171 +899,237 @@ next_listen: cb->args[4] = 0; ++num; } + spin_unlock_bh(&ilb->lock); s_num = 0; cb->args[3] = 0; cb->args[4] = 0; } - inet_listen_unlock(hashinfo); skip_listen_ht: cb->args[0] = 1; s_i = num = s_num = 0; } if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV))) - return skb->len; + goto out; - for (i = s_i; i < hashinfo->ehash_size; i++) { + for (i = s_i; i <= hashinfo->ehash_mask; i++) { struct inet_ehash_bucket *head = &hashinfo->ehash[i]; + spinlock_t *lock = inet_ehash_lockp(hashinfo, i); struct sock *sk; - struct hlist_node *node; + struct hlist_nulls_node *node; + + num = 0; + + if (hlist_nulls_empty(&head->chain)) + continue; if (i > s_i) s_num = 0; - read_lock_bh(&head->lock); - - num = 0; - sk_for_each(sk, node, &head->chain) { - struct inet_sock *inet = inet_sk(sk); + spin_lock_bh(lock); + sk_nulls_for_each(sk, node, &head->chain) { + int res; + int state; + if (!net_eq(sock_net(sk), net)) + continue; if (num < s_num) goto next_normal; - if (!(r->idiag_states & (1 << sk->sk_state))) + state = (sk->sk_state == TCP_TIME_WAIT) ? + inet_twsk(sk)->tw_substate : sk->sk_state; + if (!(r->idiag_states & (1 << state))) + goto next_normal; + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) goto next_normal; - if (r->id.idiag_sport != inet->sport && + if (r->id.idiag_sport != htons(sk->sk_num) && r->id.idiag_sport) goto next_normal; - if (r->id.idiag_dport != inet->dport && r->id.idiag_dport) + if (r->id.idiag_dport != sk->sk_dport && + r->id.idiag_dport) goto next_normal; - if (inet_diag_dump_sock(skb, sk, cb) < 0) { - read_unlock_bh(&head->lock); + if (sk->sk_state == TCP_TIME_WAIT) + res = inet_twsk_diag_dump(sk, skb, cb, r, bc); + else + res = inet_csk_diag_dump(sk, skb, cb, r, bc); + if (res < 0) { + spin_unlock_bh(lock); goto done; } next_normal: ++num; } - if (r->idiag_states & TCPF_TIME_WAIT) { - sk_for_each(sk, node, - &hashinfo->ehash[i + hashinfo->ehash_size].chain) { - struct inet_sock *inet = inet_sk(sk); - - if (num < s_num) - goto next_dying; - if (r->id.idiag_sport != inet->sport && - r->id.idiag_sport) - goto next_dying; - if (r->id.idiag_dport != inet->dport && - r->id.idiag_dport) - goto next_dying; - if (inet_diag_dump_sock(skb, sk, cb) < 0) { - read_unlock_bh(&head->lock); - goto done; - } -next_dying: - ++num; - } - } - read_unlock_bh(&head->lock); + spin_unlock_bh(lock); } done: cb->args[1] = i; cb->args[2] = num; - return skb->len; +out: + ; } +EXPORT_SYMBOL_GPL(inet_diag_dump_icsk); -static int inet_diag_dump_done(struct netlink_callback *cb) +static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + struct inet_diag_req_v2 *r, struct nlattr *bc) { - return 0; + const struct inet_diag_handler *handler; + int err = 0; + + handler = inet_diag_lock_handler(r->sdiag_protocol); + if (!IS_ERR(handler)) + handler->dump(skb, cb, r, bc); + else + err = PTR_ERR(handler); + inet_diag_unlock_handler(handler); + + return err ? : skb->len; } +static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nlattr *bc = NULL; + int hdrlen = sizeof(struct inet_diag_req_v2); + + if (nlmsg_attrlen(cb->nlh, hdrlen)) + bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE); -static __inline__ int -inet_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh) + return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc); +} + +static inline int inet_diag_type2proto(int type) { - if (!(nlh->nlmsg_flags&NLM_F_REQUEST)) + switch (type) { + case TCPDIAG_GETSOCK: + return IPPROTO_TCP; + case DCCPDIAG_GETSOCK: + return IPPROTO_DCCP; + default: return 0; - - if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX) - goto err_inval; - - if (inet_diag_table[nlh->nlmsg_type] == NULL) - return -ENOENT; - - if (NLMSG_LENGTH(sizeof(struct inet_diag_req)) > skb->len) - goto err_inval; - - if (nlh->nlmsg_flags&NLM_F_DUMP) { - if (nlh->nlmsg_len > - (4 + NLMSG_SPACE(sizeof(struct inet_diag_req)))) { - struct rtattr *rta = (void *)(NLMSG_DATA(nlh) + - sizeof(struct inet_diag_req)); - if (rta->rta_type != INET_DIAG_REQ_BYTECODE || - rta->rta_len < 8 || - rta->rta_len > - (nlh->nlmsg_len - - NLMSG_SPACE(sizeof(struct inet_diag_req)))) - goto err_inval; - if (inet_diag_bc_audit(RTA_DATA(rta), RTA_PAYLOAD(rta))) - goto err_inval; - } - return netlink_dump_start(idiagnl, skb, nlh, - inet_diag_dump, - inet_diag_dump_done); - } else { - return inet_diag_get_exact(skb, nlh); } +} + +static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct inet_diag_req *rc = nlmsg_data(cb->nlh); + struct inet_diag_req_v2 req; + struct nlattr *bc = NULL; + int hdrlen = sizeof(struct inet_diag_req); + + req.sdiag_family = AF_UNSPEC; /* compatibility */ + req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type); + req.idiag_ext = rc->idiag_ext; + req.idiag_states = rc->idiag_states; + req.id = rc->id; + + if (nlmsg_attrlen(cb->nlh, hdrlen)) + bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE); -err_inval: - return -EINVAL; + return __inet_diag_dump(skb, cb, &req, bc); } +static int inet_diag_get_exact_compat(struct sk_buff *in_skb, + const struct nlmsghdr *nlh) +{ + struct inet_diag_req *rc = nlmsg_data(nlh); + struct inet_diag_req_v2 req; + + req.sdiag_family = rc->idiag_family; + req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type); + req.idiag_ext = rc->idiag_ext; + req.idiag_states = rc->idiag_states; + req.id = rc->id; + + return inet_diag_get_exact(in_skb, nlh, &req); +} -static inline void inet_diag_rcv_skb(struct sk_buff *skb) +static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) { - int err; - struct nlmsghdr * nlh; - - if (skb->len >= NLMSG_SPACE(0)) { - nlh = (struct nlmsghdr *)skb->data; - if (nlh->nlmsg_len < sizeof(*nlh) || skb->len < nlh->nlmsg_len) - return; - err = inet_diag_rcv_msg(skb, nlh); - if (err || nlh->nlmsg_flags & NLM_F_ACK) - netlink_ack(skb, nlh, err); + int hdrlen = sizeof(struct inet_diag_req); + struct net *net = sock_net(skb->sk); + + if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX || + nlmsg_len(nlh) < hdrlen) + return -EINVAL; + + if (nlh->nlmsg_flags & NLM_F_DUMP) { + if (nlmsg_attrlen(nlh, hdrlen)) { + struct nlattr *attr; + + attr = nlmsg_find_attr(nlh, hdrlen, + INET_DIAG_REQ_BYTECODE); + if (attr == NULL || + nla_len(attr) < sizeof(struct inet_diag_bc_op) || + inet_diag_bc_audit(nla_data(attr), nla_len(attr))) + return -EINVAL; + } + { + struct netlink_dump_control c = { + .dump = inet_diag_dump_compat, + }; + return netlink_dump_start(net->diag_nlsk, skb, nlh, &c); + } } + + return inet_diag_get_exact_compat(skb, nlh); } -static void inet_diag_rcv(struct sock *sk, int len) +static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) { - struct sk_buff *skb; - unsigned int qlen = skb_queue_len(&sk->sk_receive_queue); + int hdrlen = sizeof(struct inet_diag_req_v2); + struct net *net = sock_net(skb->sk); - while (qlen-- && (skb = skb_dequeue(&sk->sk_receive_queue))) { - inet_diag_rcv_skb(skb); - kfree_skb(skb); + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + if (nlmsg_attrlen(h, hdrlen)) { + struct nlattr *attr; + attr = nlmsg_find_attr(h, hdrlen, + INET_DIAG_REQ_BYTECODE); + if (attr == NULL || + nla_len(attr) < sizeof(struct inet_diag_bc_op) || + inet_diag_bc_audit(nla_data(attr), nla_len(attr))) + return -EINVAL; + } + { + struct netlink_dump_control c = { + .dump = inet_diag_dump, + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } } + + return inet_diag_get_exact(skb, h, nlmsg_data(h)); } -static DEFINE_SPINLOCK(inet_diag_register_lock); +static const struct sock_diag_handler inet_diag_handler = { + .family = AF_INET, + .dump = inet_diag_handler_dump, +}; + +static const struct sock_diag_handler inet6_diag_handler = { + .family = AF_INET6, + .dump = inet_diag_handler_dump, +}; int inet_diag_register(const struct inet_diag_handler *h) { const __u16 type = h->idiag_type; int err = -EINVAL; - if (type >= INET_DIAG_GETSOCK_MAX) + if (type >= IPPROTO_MAX) goto out; - spin_lock(&inet_diag_register_lock); + mutex_lock(&inet_diag_table_mutex); err = -EEXIST; if (inet_diag_table[type] == NULL) { inet_diag_table[type] = h; err = 0; } - spin_unlock(&inet_diag_register_lock); + mutex_unlock(&inet_diag_table_mutex); out: return err; } @@ -823,46 +1139,54 @@ void inet_diag_unregister(const struct inet_diag_handler *h) { const __u16 type = h->idiag_type; - if (type >= INET_DIAG_GETSOCK_MAX) + if (type >= IPPROTO_MAX) return; - spin_lock(&inet_diag_register_lock); + mutex_lock(&inet_diag_table_mutex); inet_diag_table[type] = NULL; - spin_unlock(&inet_diag_register_lock); - - synchronize_rcu(); + mutex_unlock(&inet_diag_table_mutex); } EXPORT_SYMBOL_GPL(inet_diag_unregister); static int __init inet_diag_init(void) { - const int inet_diag_table_size = (INET_DIAG_GETSOCK_MAX * + const int inet_diag_table_size = (IPPROTO_MAX * sizeof(struct inet_diag_handler *)); int err = -ENOMEM; - inet_diag_table = kmalloc(inet_diag_table_size, GFP_KERNEL); + inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL); if (!inet_diag_table) goto out; - memset(inet_diag_table, 0, inet_diag_table_size); - idiagnl = netlink_kernel_create(NETLINK_INET_DIAG, 0, inet_diag_rcv, - THIS_MODULE); - if (idiagnl == NULL) - goto out_free_table; - err = 0; + err = sock_diag_register(&inet_diag_handler); + if (err) + goto out_free_nl; + + err = sock_diag_register(&inet6_diag_handler); + if (err) + goto out_free_inet; + + sock_diag_register_inet_compat(inet_diag_rcv_msg_compat); out: return err; -out_free_table: + +out_free_inet: + sock_diag_unregister(&inet_diag_handler); +out_free_nl: kfree(inet_diag_table); goto out; } static void __exit inet_diag_exit(void) { - sock_release(idiagnl->sk_socket); + sock_diag_unregister(&inet6_diag_handler); + sock_diag_unregister(&inet_diag_handler); + sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat); kfree(inet_diag_table); } module_init(inet_diag_init); module_exit(inet_diag_exit); MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */); |
