diff options
-rw-r--r-- | include/net/tls.h | 3 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 61 |
2 files changed, 57 insertions, 7 deletions
diff --git a/include/net/tls.h b/include/net/tls.h index 8742e13bc362..e8935cfe0cd6 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -116,11 +116,14 @@ struct tls_sw_context_rx { void (*saved_data_ready)(struct sock *sk); struct sk_buff *recv_pkt; + u8 reader_present; u8 async_capable:1; u8 zc_capable:1; + u8 reader_contended:1; atomic_t decrypt_pending; /* protect crypto_wait with decrypt_pending*/ spinlock_t decrypt_compl_lock; + struct wait_queue_head wq; }; struct tls_record_info { diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 68d79ee48a56..761a63751616 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1753,6 +1753,51 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, sk_flush_backlog(sk); } +static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, + bool nonblock) +{ + long timeo; + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, nonblock); + + while (unlikely(ctx->reader_present)) { + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + ctx->reader_contended = 1; + + add_wait_queue(&ctx->wq, &wait); + sk_wait_event(sk, &timeo, + !READ_ONCE(ctx->reader_present), &wait); + remove_wait_queue(&ctx->wq, &wait); + + if (!timeo) + return -EAGAIN; + if (signal_pending(current)) + return sock_intr_errno(timeo); + } + + WRITE_ONCE(ctx->reader_present, 1); + + return timeo; +} + +static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx) +{ + if (unlikely(ctx->reader_contended)) { + if (wq_has_sleeper(&ctx->wq)) + wake_up(&ctx->wq); + else + ctx->reader_contended = 0; + + WARN_ON_ONCE(!ctx->reader_present); + } + + WRITE_ONCE(ctx->reader_present, 0); + release_sock(sk); +} + int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, @@ -1782,7 +1827,9 @@ int tls_sw_recvmsg(struct sock *sk, return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); psock = sk_psock_get(sk); - lock_sock(sk); + timeo = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT); + if (timeo < 0) + return timeo; bpf_strp_enabled = sk_psock_strp_enabled(psock); /* If crypto failed the connection is broken */ @@ -1801,7 +1848,6 @@ int tls_sw_recvmsg(struct sock *sk, target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); len = len - copied; - timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek && ctx->zc_capable; @@ -1956,7 +2002,7 @@ recv_end: copied += decrypted; end: - release_sock(sk); + tls_rx_reader_unlock(sk, ctx); if (psock) sk_psock_put(sk, psock); return copied ? : err; @@ -1978,9 +2024,9 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, long timeo; int chunk; - lock_sock(sk); - - timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK); + timeo = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK); + if (timeo < 0) + return timeo; from_queue = !skb_queue_empty(&ctx->rx_list); if (from_queue) { @@ -2029,7 +2075,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, } splice_read_end: - release_sock(sk); + tls_rx_reader_unlock(sk, ctx); return copied ? : err; } @@ -2371,6 +2417,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) } else { crypto_init_wait(&sw_ctx_rx->async_wait); spin_lock_init(&sw_ctx_rx->decrypt_compl_lock); + init_waitqueue_head(&sw_ctx_rx->wq); crypto_info = &ctx->crypto_recv.info; cctx = &ctx->rx; skb_queue_head_init(&sw_ctx_rx->rx_list); |