Commit f7fa9b10 authored by Patrick McHardy's avatar Patrick McHardy Committed by David S. Miller

[NETLINK]: Support dynamic number of multicast groups per netlink family

Signed-off-by: default avatarPatrick McHardy <kaber@trash.net>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent ab33a171
...@@ -60,21 +60,24 @@ ...@@ -60,21 +60,24 @@
#include <net/scm.h> #include <net/scm.h>
#define Nprintk(a...) #define Nprintk(a...)
#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
struct netlink_sock { struct netlink_sock {
/* struct sock has to be the first member of netlink_sock */ /* struct sock has to be the first member of netlink_sock */
struct sock sk; struct sock sk;
u32 pid; u32 pid;
unsigned int groups;
u32 dst_pid; u32 dst_pid;
u32 dst_group; u32 dst_group;
u32 flags;
u32 subscriptions;
u32 ngroups;
unsigned long *groups;
unsigned long state; unsigned long state;
wait_queue_head_t wait; wait_queue_head_t wait;
struct netlink_callback *cb; struct netlink_callback *cb;
spinlock_t cb_lock; spinlock_t cb_lock;
void (*data_ready)(struct sock *sk, int bytes); void (*data_ready)(struct sock *sk, int bytes);
struct module *module; struct module *module;
u32 flags;
}; };
#define NETLINK_KERNEL_SOCKET 0x1 #define NETLINK_KERNEL_SOCKET 0x1
...@@ -101,6 +104,7 @@ struct netlink_table { ...@@ -101,6 +104,7 @@ struct netlink_table {
struct nl_pid_hash hash; struct nl_pid_hash hash;
struct hlist_head mc_list; struct hlist_head mc_list;
unsigned int nl_nonroot; unsigned int nl_nonroot;
unsigned int groups;
struct module *module; struct module *module;
int registered; int registered;
}; };
...@@ -138,6 +142,7 @@ static void netlink_sock_destruct(struct sock *sk) ...@@ -138,6 +142,7 @@ static void netlink_sock_destruct(struct sock *sk)
BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc)); BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc));
BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc)); BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc));
BUG_TRAP(!nlk_sk(sk)->cb); BUG_TRAP(!nlk_sk(sk)->cb);
BUG_TRAP(!nlk_sk(sk)->groups);
} }
/* This lock without WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on SMP. /* This lock without WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on SMP.
...@@ -333,7 +338,7 @@ static void netlink_remove(struct sock *sk) ...@@ -333,7 +338,7 @@ static void netlink_remove(struct sock *sk)
netlink_table_grab(); netlink_table_grab();
if (sk_del_node_init(sk)) if (sk_del_node_init(sk))
nl_table[sk->sk_protocol].hash.entries--; nl_table[sk->sk_protocol].hash.entries--;
if (nlk_sk(sk)->groups) if (nlk_sk(sk)->subscriptions)
__sk_del_bind_node(sk); __sk_del_bind_node(sk);
netlink_table_ungrab(); netlink_table_ungrab();
} }
...@@ -369,6 +374,8 @@ static int __netlink_create(struct socket *sock, int protocol) ...@@ -369,6 +374,8 @@ static int __netlink_create(struct socket *sock, int protocol)
static int netlink_create(struct socket *sock, int protocol) static int netlink_create(struct socket *sock, int protocol)
{ {
struct module *module = NULL; struct module *module = NULL;
struct netlink_sock *nlk;
unsigned int groups;
int err = 0; int err = 0;
sock->state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED;
...@@ -392,15 +399,23 @@ static int netlink_create(struct socket *sock, int protocol) ...@@ -392,15 +399,23 @@ static int netlink_create(struct socket *sock, int protocol)
module = nl_table[protocol].module; module = nl_table[protocol].module;
else else
err = -EPROTONOSUPPORT; err = -EPROTONOSUPPORT;
groups = nl_table[protocol].groups;
netlink_unlock_table(); netlink_unlock_table();
if (err) if (err || (err = __netlink_create(sock, protocol) < 0))
goto out; goto out_module;
if ((err = __netlink_create(sock, protocol) < 0)) nlk = nlk_sk(sock->sk);
nlk->groups = kmalloc(NLGRPSZ(groups), GFP_KERNEL);
if (nlk->groups == NULL) {
err = -ENOMEM;
goto out_module; goto out_module;
}
memset(nlk->groups, 0, NLGRPSZ(groups));
nlk->ngroups = groups;
nlk_sk(sock->sk)->module = module; nlk->module = module;
out: out:
return err; return err;
...@@ -437,7 +452,7 @@ static int netlink_release(struct socket *sock) ...@@ -437,7 +452,7 @@ static int netlink_release(struct socket *sock)
skb_queue_purge(&sk->sk_write_queue); skb_queue_purge(&sk->sk_write_queue);
if (nlk->pid && !nlk->groups) { if (nlk->pid && !nlk->subscriptions) {
struct netlink_notify n = { struct netlink_notify n = {
.protocol = sk->sk_protocol, .protocol = sk->sk_protocol,
.pid = nlk->pid, .pid = nlk->pid,
...@@ -455,6 +470,9 @@ static int netlink_release(struct socket *sock) ...@@ -455,6 +470,9 @@ static int netlink_release(struct socket *sock)
netlink_table_ungrab(); netlink_table_ungrab();
} }
kfree(nlk->groups);
nlk->groups = NULL;
sock_put(sk); sock_put(sk);
return 0; return 0;
} }
...@@ -503,6 +521,18 @@ static inline int netlink_capable(struct socket *sock, unsigned int flag) ...@@ -503,6 +521,18 @@ static inline int netlink_capable(struct socket *sock, unsigned int flag)
capable(CAP_NET_ADMIN); capable(CAP_NET_ADMIN);
} }
static void
netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
{
struct netlink_sock *nlk = nlk_sk(sk);
if (nlk->subscriptions && !subscriptions)
__sk_del_bind_node(sk);
else if (!nlk->subscriptions && subscriptions)
sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
nlk->subscriptions = subscriptions;
}
static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len) static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
...@@ -528,15 +558,14 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len ...@@ -528,15 +558,14 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len
return err; return err;
} }
if (!nladdr->nl_groups && !nlk->groups) if (!nladdr->nl_groups && !(u32)nlk->groups[0])
return 0; return 0;
netlink_table_grab(); netlink_table_grab();
if (nlk->groups && !nladdr->nl_groups) netlink_update_subscriptions(sk, nlk->subscriptions +
__sk_del_bind_node(sk); hweight32(nladdr->nl_groups) -
else if (!nlk->groups && nladdr->nl_groups) hweight32(nlk->groups[0]));
sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list); nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups;
nlk->groups = nladdr->nl_groups;
netlink_table_ungrab(); netlink_table_ungrab();
return 0; return 0;
...@@ -590,7 +619,7 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, int *addr ...@@ -590,7 +619,7 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, int *addr
nladdr->nl_groups = netlink_group_mask(nlk->dst_group); nladdr->nl_groups = netlink_group_mask(nlk->dst_group);
} else { } else {
nladdr->nl_pid = nlk->pid; nladdr->nl_pid = nlk->pid;
nladdr->nl_groups = nlk->groups; nladdr->nl_groups = nlk->groups[0];
} }
return 0; return 0;
} }
...@@ -791,7 +820,8 @@ static inline int do_one_broadcast(struct sock *sk, ...@@ -791,7 +820,8 @@ static inline int do_one_broadcast(struct sock *sk,
if (p->exclude_sk == sk) if (p->exclude_sk == sk)
goto out; goto out;
if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group))) if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
!test_bit(p->group - 1, nlk->groups))
goto out; goto out;
if (p->failure) { if (p->failure) {
...@@ -887,7 +917,8 @@ static inline int do_one_set_err(struct sock *sk, ...@@ -887,7 +917,8 @@ static inline int do_one_set_err(struct sock *sk,
if (sk == p->exclude_sk) if (sk == p->exclude_sk)
goto out; goto out;
if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group))) if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
!test_bit(p->group - 1, nlk->groups))
goto out; goto out;
sk->sk_err = p->code; sk->sk_err = p->code;
...@@ -1112,6 +1143,7 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct ...@@ -1112,6 +1143,7 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct
nlk->flags |= NETLINK_KERNEL_SOCKET; nlk->flags |= NETLINK_KERNEL_SOCKET;
netlink_table_grab(); netlink_table_grab();
nl_table[unit].groups = 32;
nl_table[unit].module = module; nl_table[unit].module = module;
nl_table[unit].registered = 1; nl_table[unit].registered = 1;
netlink_table_ungrab(); netlink_table_ungrab();
...@@ -1358,7 +1390,8 @@ static int netlink_seq_show(struct seq_file *seq, void *v) ...@@ -1358,7 +1390,8 @@ static int netlink_seq_show(struct seq_file *seq, void *v)
s, s,
s->sk_protocol, s->sk_protocol,
nlk->pid, nlk->pid,
nlk->groups, nlk->flags & NETLINK_KERNEL_SOCKET ?
0 : (unsigned int)nlk->groups[0],
atomic_read(&s->sk_rmem_alloc), atomic_read(&s->sk_rmem_alloc),
atomic_read(&s->sk_wmem_alloc), atomic_read(&s->sk_wmem_alloc),
nlk->cb, nlk->cb,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment