From: Kuniyuki Iwashima kuniyu@amazon.com
[ Upstream commit 69421bf98482d089e50799f45e48b25ce4a8d154 ]
When we call connect() for a UDP socket in a reuseport group, we have to update sk->sk_reuseport_cb->has_conns to 1. Otherwise, the kernel could wrongly select an unconnected socket for packets sent to the connected socket.
However, the current way of setting has_conns is invalid and can itself trigger that problem. reuseport_has_conns() changes has_conns under rcu_read_lock(), which upgrades the RCU reader to an updater. It must therefore do the update under the updater's lock, reuseport_lock, but currently it does not.
For this reason, there is a race below where we fail to set has_conns resulting in the wrong socket selection. To avoid the race, let's split the reader and updater with proper locking.
  cpu1                               cpu2
+----+                             +----+

__ip[46]_datagram_connect()        reuseport_grow()
.                                  .
|- reuseport_has_conns(sk, true)   |- more_reuse = __reuseport_alloc(more_socks_size)
|  .                               |
|  |- rcu_read_lock()
|  |- reuse = rcu_dereference(sk->sk_reuseport_cb)
|  |
|  |                               |  /* reuse->has_conns == 0 here */
|  |                               |- more_reuse->has_conns = reuse->has_conns
|  |- reuse->has_conns = 1         |  /* more_reuse->has_conns SHOULD BE 1 HERE */
|  |                               |
|  |                               |- rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb,
|  |                               |                     more_reuse)
|  `- rcu_read_unlock()            `- kfree_rcu(reuse, rcu)
|
|- sk->sk_state = TCP_ESTABLISHED
Note the likely(reuse) in reuseport_has_conns_set() is always true, but we put the test there for ease of review. [0]
For the record, sk_reuseport_cb is usually changed under lock_sock(). The only exception is the reuseport_grow() & TCP reqsk migration case.
1) shutdown() TCP listener, which is moved into the latter part of reuse->socks[] to migrate reqsk.
2) New listen() overflows reuse->socks[] and calls reuseport_grow().
3) reuse->max_socks overflows u16 with the new listener.
4) reuseport_grow() pops the old shutdown()ed listener from the array and updates its sk->sk_reuseport_cb to NULL without lock_sock().
shutdown()ed TCP sk->sk_reuseport_cb can be changed without lock_sock(), but reuseport_has_conns_set() is called only for UDP under lock_sock(), so likely(reuse) can never be false in reuseport_has_conns_set().
[0]: https://lore.kernel.org/netdev/CANn89iLja=eQHbsM_Ta2sQF0tOGU8vAGrh_izRuuHjuO...
Fixes: acdcecc61285 ("udp: correct reuseport selection with connected sockets") Signed-off-by: Kuniyuki Iwashima kuniyu@amazon.com Link: https://lore.kernel.org/r/20221014182625.89913-1-kuniyu@amazon.com Signed-off-by: Paolo Abeni pabeni@redhat.com Signed-off-by: Sasha Levin sashal@kernel.org --- include/net/sock_reuseport.h | 11 +++++------ net/core/sock_reuseport.c | 16 ++++++++++++++++ net/ipv4/datagram.c | 2 +- net/ipv4/udp.c | 2 +- net/ipv6/datagram.c | 2 +- net/ipv6/udp.c | 2 +- 6 files changed, 25 insertions(+), 10 deletions(-)
diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h index 473b0b0fa4ab..efc9085c6892 100644 --- a/include/net/sock_reuseport.h +++ b/include/net/sock_reuseport.h @@ -43,21 +43,20 @@ struct sock *reuseport_migrate_sock(struct sock *sk, extern int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog); extern int reuseport_detach_prog(struct sock *sk);
-static inline bool reuseport_has_conns(struct sock *sk, bool set) +static inline bool reuseport_has_conns(struct sock *sk) { struct sock_reuseport *reuse; bool ret = false;
rcu_read_lock(); reuse = rcu_dereference(sk->sk_reuseport_cb); - if (reuse) { - if (set) - reuse->has_conns = 1; - ret = reuse->has_conns; - } + if (reuse && reuse->has_conns) + ret = true; rcu_read_unlock();
return ret; }
+void reuseport_has_conns_set(struct sock *sk); + #endif /* _SOCK_REUSEPORT_H */ diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c index 5daa1fa54249..fb90e1e00773 100644 --- a/net/core/sock_reuseport.c +++ b/net/core/sock_reuseport.c @@ -21,6 +21,22 @@ static DEFINE_IDA(reuseport_ida); static int reuseport_resurrect(struct sock *sk, struct sock_reuseport *old_reuse, struct sock_reuseport *reuse, bool bind_inany);
+void reuseport_has_conns_set(struct sock *sk) +{ + struct sock_reuseport *reuse; + + if (!rcu_access_pointer(sk->sk_reuseport_cb)) + return; + + spin_lock_bh(&reuseport_lock); + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + if (likely(reuse)) + reuse->has_conns = 1; + spin_unlock_bh(&reuseport_lock); +} +EXPORT_SYMBOL(reuseport_has_conns_set); + static int reuseport_sock_index(struct sock *sk, const struct sock_reuseport *reuse, bool closed) diff --git a/net/ipv4/datagram.c b/net/ipv4/datagram.c index 4a8550c49202..112c6e892d30 100644 --- a/net/ipv4/datagram.c +++ b/net/ipv4/datagram.c @@ -70,7 +70,7 @@ int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len } inet->inet_daddr = fl4->daddr; inet->inet_dport = usin->sin_port; - reuseport_has_conns(sk, true); + reuseport_has_conns_set(sk); sk->sk_state = TCP_ESTABLISHED; sk_set_txhash(sk); inet->inet_id = prandom_u32(); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 75d1977ecc07..79d5425bed07 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -446,7 +446,7 @@ static struct sock *udp4_lib_lookup2(struct net *net, result = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum); /* Fall back to scoring if group has connections */ - if (result && !reuseport_has_conns(sk, false)) + if (result && !reuseport_has_conns(sk)) return result;
result = result ? : sk; diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c index 206f66310a88..f4559e5bc84b 100644 --- a/net/ipv6/datagram.c +++ b/net/ipv6/datagram.c @@ -256,7 +256,7 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr, goto out; }
- reuseport_has_conns(sk, true); + reuseport_has_conns_set(sk); sk->sk_state = TCP_ESTABLISHED; sk_set_txhash(sk); out: diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 07726a51a3f0..19b6c4da0f42 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -180,7 +180,7 @@ static struct sock *udp6_lib_lookup2(struct net *net, result = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum); /* Fall back to scoring if group has connections */ - if (result && !reuseport_has_conns(sk, false)) + if (result && !reuseport_has_conns(sk)) return result;
result = result ? : sk;