summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls_device.c6
-rw-r--r--net/tls/tls_main.c2
-rw-r--r--net/tls/tls_sw.c21
3 files changed, 15 insertions, 14 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 5a3715ddc592..683d00837693 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -523,8 +523,10 @@ last_record:
int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
unsigned char record_type = TLS_RECORD_TYPE_DATA;
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
int rc;
+ mutex_lock(&tls_ctx->tx_lock);
lock_sock(sk);
if (unlikely(msg->msg_controllen)) {
@@ -538,12 +540,14 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
out:
release_sock(sk);
+ mutex_unlock(&tls_ctx->tx_lock);
return rc;
}
int tls_device_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags)
{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
struct iov_iter msg_iter;
char *kaddr = kmap(page);
struct kvec iov;
@@ -552,6 +556,7 @@ int tls_device_sendpage(struct sock *sk, struct page *page,
if (flags & MSG_SENDPAGE_NOTLAST)
flags |= MSG_MORE;
+ mutex_lock(&tls_ctx->tx_lock);
lock_sock(sk);
if (flags & MSG_OOB) {
@@ -568,6 +573,7 @@ int tls_device_sendpage(struct sock *sk, struct page *page,
out:
release_sock(sk);
+ mutex_unlock(&tls_ctx->tx_lock);
return rc;
}
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index ac88877dcade..0775ae40fcfb 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -267,6 +267,7 @@ void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
+ mutex_destroy(&ctx->tx_lock);
if (sk)
kfree_rcu(ctx, rcu);
@@ -612,6 +613,7 @@ static struct tls_context *create_ctx(struct sock *sk)
if (!ctx)
return NULL;
+ mutex_init(&ctx->tx_lock);
rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
ctx->sk_proto = sk->sk_prot;
return ctx;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index e155b792df0b..446f23c1f3ce 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -897,15 +897,9 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
return -ENOTSUPP;
+ mutex_lock(&tls_ctx->tx_lock);
lock_sock(sk);
- /* Wait till there is any pending write on socket */
- if (unlikely(sk->sk_write_pending)) {
- ret = wait_on_pending_writer(sk, &timeo);
- if (unlikely(ret))
- goto send_end;
- }
-
if (unlikely(msg->msg_controllen)) {
ret = tls_proccess_cmsg(sk, msg, &record_type);
if (ret) {
@@ -1091,6 +1085,7 @@ send_end:
ret = sk_stream_error(sk, msg->msg_flags, ret);
release_sock(sk);
+ mutex_unlock(&tls_ctx->tx_lock);
return copied ? copied : ret;
}
@@ -1114,13 +1109,6 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
- /* Wait till there is any pending write on socket */
- if (unlikely(sk->sk_write_pending)) {
- ret = wait_on_pending_writer(sk, &timeo);
- if (unlikely(ret))
- goto sendpage_end;
- }
-
/* Call the sk_stream functions to manage the sndbuf mem. */
while (size > 0) {
size_t copy, required_size;
@@ -1219,15 +1207,18 @@ sendpage_end:
int tls_sw_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags)
{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
int ret;
if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
return -ENOTSUPP;
+ mutex_lock(&tls_ctx->tx_lock);
lock_sock(sk);
ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
release_sock(sk);
+ mutex_unlock(&tls_ctx->tx_lock);
return ret;
}
@@ -2170,9 +2161,11 @@ static void tx_work_handler(struct work_struct *work)
if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
return;
+ mutex_lock(&tls_ctx->tx_lock);
lock_sock(sk);
tls_tx_records(sk, -1);
release_sock(sk);
+ mutex_unlock(&tls_ctx->tx_lock);
}
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)