diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_device.c | 6 | ||||
-rw-r--r-- | net/tls/tls_main.c | 2 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 21 |
3 files changed, 15 insertions, 14 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 5a3715ddc592..683d00837693 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -523,8 +523,10 @@ last_record: int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) { unsigned char record_type = TLS_RECORD_TYPE_DATA; + struct tls_context *tls_ctx = tls_get_ctx(sk); int rc; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); if (unlikely(msg->msg_controllen)) { @@ -538,12 +540,14 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) out: release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); return rc; } int tls_device_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags) { + struct tls_context *tls_ctx = tls_get_ctx(sk); struct iov_iter msg_iter; char *kaddr = kmap(page); struct kvec iov; @@ -552,6 +556,7 @@ int tls_device_sendpage(struct sock *sk, struct page *page, if (flags & MSG_SENDPAGE_NOTLAST) flags |= MSG_MORE; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); if (flags & MSG_OOB) { @@ -568,6 +573,7 @@ int tls_device_sendpage(struct sock *sk, struct page *page, out: release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); return rc; } diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index ac88877dcade..0775ae40fcfb 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -267,6 +267,7 @@ void tls_ctx_free(struct sock *sk, struct tls_context *ctx) memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); + mutex_destroy(&ctx->tx_lock); if (sk) kfree_rcu(ctx, rcu); @@ -612,6 +613,7 @@ static struct tls_context *create_ctx(struct sock *sk) if (!ctx) return NULL; + mutex_init(&ctx->tx_lock); rcu_assign_pointer(icsk->icsk_ulp_data, ctx); ctx->sk_proto = sk->sk_prot; return ctx; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index e155b792df0b..446f23c1f3ce 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -897,15 +897,9 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) return -ENOTSUPP; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); - /* Wait till there is any pending write on socket */ - if (unlikely(sk->sk_write_pending)) { - ret = wait_on_pending_writer(sk, &timeo); - if (unlikely(ret)) - goto send_end; - } - if (unlikely(msg->msg_controllen)) { ret = tls_proccess_cmsg(sk, msg, &record_type); if (ret) { @@ -1091,6 +1085,7 @@ send_end: ret = sk_stream_error(sk, msg->msg_flags, ret); release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); return copied ? copied : ret; } @@ -1114,13 +1109,6 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); - /* Wait till there is any pending write on socket */ - if (unlikely(sk->sk_write_pending)) { - ret = wait_on_pending_writer(sk, &timeo); - if (unlikely(ret)) - goto sendpage_end; - } - /* Call the sk_stream functions to manage the sndbuf mem. */ while (size > 0) { size_t copy, required_size; @@ -1219,15 +1207,18 @@ sendpage_end: int tls_sw_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags) { + struct tls_context *tls_ctx = tls_get_ctx(sk); int ret; if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) return -ENOTSUPP; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); ret = tls_sw_do_sendpage(sk, page, offset, size, flags); release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); return ret; } @@ -2170,9 +2161,11 @@ static void tx_work_handler(struct work_struct *work) if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) return; + mutex_lock(&tls_ctx->tx_lock); lock_sock(sk); tls_tx_records(sk, -1); release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); } void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) |