diff options
Diffstat (limited to 'net/ipv4/tcp_minisocks.c')
-rw-r--r-- | net/ipv4/tcp_minisocks.c | 61 |
1 files changed, 39 insertions, 22 deletions
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index c375f603a16c..e002f2e1d4f2 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -240,6 +240,40 @@ kill: } EXPORT_SYMBOL(tcp_timewait_state_process); +static void tcp_time_wait_init(struct sock *sk, struct tcp_timewait_sock *tcptw) +{ +#ifdef CONFIG_TCP_MD5SIG + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *key; + + /* + * The timewait bucket does not have the key DB from the + * sock structure. We just make a quick copy of the + * md5 key being used (if indeed we are using one) + * so the timewait ack generating code has the key. + */ + tcptw->tw_md5_key = NULL; + if (!static_branch_unlikely(&tcp_md5_needed.key)) + return; + + key = tp->af_specific->md5_lookup(sk, sk); + if (key) { + tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); + if (!tcptw->tw_md5_key) + return; + if (!tcp_alloc_md5sig_pool()) + goto out_free; + if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) + goto out_free; + } + return; +out_free: + WARN_ON_ONCE(1); + kfree(tcptw->tw_md5_key); + tcptw->tw_md5_key = NULL; +#endif +} + /* * Move a socket to time-wait or dead fin-wait-2 state. */ @@ -282,26 +316,7 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) } #endif -#ifdef CONFIG_TCP_MD5SIG - /* - * The timewait bucket does not have the key DB from the - * sock structure. We just make a quick copy of the - * md5 key being used (if indeed we are using one) - * so the timewait ack generating code has the key. - */ - do { - tcptw->tw_md5_key = NULL; - if (static_branch_unlikely(&tcp_md5_needed)) { - struct tcp_md5sig_key *key; - - key = tp->af_specific->md5_lookup(sk, sk); - if (key) { - tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); - BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool()); - } - } - } while (0); -#endif + tcp_time_wait_init(sk, tcptw); /* Get the TIME_WAIT timeout firing. */ if (timeo < rto) @@ -337,11 +352,13 @@ EXPORT_SYMBOL(tcp_time_wait); void tcp_twsk_destructor(struct sock *sk) { #ifdef CONFIG_TCP_MD5SIG - if (static_branch_unlikely(&tcp_md5_needed)) { + if (static_branch_unlikely(&tcp_md5_needed.key)) { struct tcp_timewait_sock *twsk = tcp_twsk(sk); - if (twsk->tw_md5_key) + if (twsk->tw_md5_key) { kfree_rcu(twsk->tw_md5_key, rcu); + static_branch_slow_dec_deferred(&tcp_md5_needed); + } } #endif } |