diff options
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 489 |
1 files changed, 489 insertions, 0 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c new file mode 100644 index 000000000000..a03130a47b85 --- /dev/null +++ b/net/tls/tls_main.c @@ -0,0 +1,489 @@ +/* + * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. + * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include <linux/module.h> + +#include <net/tcp.h> +#include <net/inet_common.h> +#include <linux/highmem.h> +#include <linux/netdevice.h> +#include <linux/sched/signal.h> + +#include <net/tls.h> + +MODULE_AUTHOR("Mellanox Technologies"); +MODULE_DESCRIPTION("Transport Layer Security Support"); +MODULE_LICENSE("Dual BSD/GPL"); + +static struct proto tls_base_prot; +static struct proto tls_sw_prot; + +int wait_on_pending_writer(struct sock *sk, long *timeo) +{ + int rc = 0; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + if (!*timeo) { + rc = -EAGAIN; + break; + } + + if (signal_pending(current)) { + rc = sock_intr_errno(*timeo); + break; + } + + if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait)) + break; + } + remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} + +int tls_push_sg(struct sock *sk, + struct tls_context *ctx, + struct scatterlist *sg, + u16 first_offset, + int flags) +{ + int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST; + int ret = 0; + struct page *p; + size_t size; + int offset = first_offset; + + size = sg->length - offset; + offset += sg->offset; + + while (1) { + if (sg_is_last(sg)) + sendpage_flags = flags; + + /* is sending application-limited? */ + tcp_rate_check_app_limited(sk); + p = sg_page(sg); +retry: + ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags); + + if (ret != size) { + if (ret > 0) { + offset += ret; + size -= ret; + goto retry; + } + + offset -= sg->offset; + ctx->partially_sent_offset = offset; + ctx->partially_sent_record = (void *)sg; + return ret; + } + + put_page(p); + sk_mem_uncharge(sk, sg->length); + sg = sg_next(sg); + if (!sg) + break; + + offset = sg->offset; + size = sg->length; + } + + clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); + + return 0; +} + +static int tls_handle_open_record(struct sock *sk, int flags) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (tls_is_pending_open_record(ctx)) + return ctx->push_pending_record(sk, flags); + + return 0; +} + +int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, + unsigned char *record_type) +{ + struct cmsghdr *cmsg; + int rc = -EINVAL; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + if (cmsg->cmsg_level != SOL_TLS) + continue; + + switch (cmsg->cmsg_type) { + case TLS_SET_RECORD_TYPE: + if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) + return -EINVAL; + + if (msg->msg_flags & MSG_MORE) + return -EINVAL; + + rc = tls_handle_open_record(sk, msg->msg_flags); + if (rc) + return rc; + + *record_type = *(unsigned char *)CMSG_DATA(cmsg); + rc = 0; + break; + default: + return -EINVAL; + } + } + + return rc; +} + +int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, + int flags, long *timeo) +{ + struct scatterlist *sg; + u16 offset; + + if (!tls_is_partially_sent_record(ctx)) + return ctx->push_pending_record(sk, flags); + + sg = ctx->partially_sent_record; + offset = ctx->partially_sent_offset; + + ctx->partially_sent_record = NULL; + return tls_push_sg(sk, ctx, sg, offset, flags); +} + +static void tls_write_space(struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) { + gfp_t sk_allocation = sk->sk_allocation; + int rc; + long timeo = 0; + + sk->sk_allocation = GFP_ATOMIC; + rc = tls_push_pending_closed_record(sk, ctx, + MSG_DONTWAIT | + MSG_NOSIGNAL, + &timeo); + sk->sk_allocation = sk_allocation; + + if (rc < 0) + return; + } + + ctx->sk_write_space(sk); +} + +static void tls_sk_proto_close(struct sock *sk, long timeout) +{ + struct tls_context *ctx = tls_get_ctx(sk); + long timeo = sock_sndtimeo(sk, 0); + void (*sk_proto_close)(struct sock *sk, long timeout); + + lock_sock(sk); + + if (!tls_complete_pending_work(sk, ctx, 0, &timeo)) + tls_handle_open_record(sk, 0); + + if (ctx->partially_sent_record) { + struct scatterlist *sg = ctx->partially_sent_record; + + while (1) { + put_page(sg_page(sg)); + sk_mem_uncharge(sk, sg->length); + + if (sg_is_last(sg)) + break; + sg++; + } + } + ctx->free_resources(sk); + kfree(ctx->rec_seq); + kfree(ctx->iv); + + sk_proto_close = ctx->sk_proto_close; + kfree(ctx); + + release_sock(sk); + sk_proto_close(sk, timeout); +} + +static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, + int __user *optlen) +{ + int rc = 0; + struct tls_context *ctx = tls_get_ctx(sk); + struct tls_crypto_info *crypto_info; + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + if (!optval || (len < sizeof(*crypto_info))) { + rc = -EINVAL; + goto out; + } + + if (!ctx) { + rc = -EBUSY; + goto out; + } + + /* get user crypto info */ + crypto_info = &ctx->crypto_send; + + if (!TLS_CRYPTO_INFO_READY(crypto_info)) { + rc = -EBUSY; + goto out; + } + + if (len == sizeof(crypto_info)) { + if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) + rc = -EFAULT; + goto out; + } + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: { + struct tls12_crypto_info_aes_gcm_128 * + crypto_info_aes_gcm_128 = + container_of(crypto_info, + struct tls12_crypto_info_aes_gcm_128, + info); + + if (len != sizeof(*crypto_info_aes_gcm_128)) { + rc = -EINVAL; + goto out; + } + lock_sock(sk); + memcpy(crypto_info_aes_gcm_128->iv, ctx->iv, + TLS_CIPHER_AES_GCM_128_IV_SIZE); + release_sock(sk); + if (copy_to_user(optval, + crypto_info_aes_gcm_128, + sizeof(*crypto_info_aes_gcm_128))) + rc = -EFAULT; + break; + } + default: + rc = -EINVAL; + } + +out: + return rc; +} + +static int do_tls_getsockopt(struct sock *sk, int optname, + char __user *optval, int __user *optlen) +{ + int rc = 0; + + switch (optname) { + case TLS_TX: + rc = do_tls_getsockopt_tx(sk, optval, optlen); + break; + default: + rc = -ENOPROTOOPT; + break; + } + return rc; +} + +static int tls_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (level != SOL_TLS) + return ctx->getsockopt(sk, level, optname, optval, optlen); + + return do_tls_getsockopt(sk, optname, optval, optlen); +} + +static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, + unsigned int optlen) +{ + struct tls_crypto_info *crypto_info, tmp_crypto_info; + struct tls_context *ctx = tls_get_ctx(sk); + struct proto *prot = NULL; + int rc = 0; + + if (!optval || (optlen < sizeof(*crypto_info))) { + rc = -EINVAL; + goto out; + } + + rc = copy_from_user(&tmp_crypto_info, optval, sizeof(*crypto_info)); + if (rc) { + rc = -EFAULT; + goto out; + } + + /* check version */ + if (tmp_crypto_info.version != TLS_1_2_VERSION) { + rc = -ENOTSUPP; + goto out; + } + + /* get user crypto info */ + crypto_info = &ctx->crypto_send; + + /* Currently we don't support set crypto info more than one time */ + if (TLS_CRYPTO_INFO_READY(crypto_info)) + goto out; + + switch (tmp_crypto_info.cipher_type) { + case TLS_CIPHER_AES_GCM_128: { + if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) { + rc = -EINVAL; + goto out; + } + rc = copy_from_user( + crypto_info, + optval, + sizeof(struct tls12_crypto_info_aes_gcm_128)); + + if (rc) { + rc = -EFAULT; + goto err_crypto_info; + } + break; + } + default: + rc = -EINVAL; + goto out; + } + + ctx->sk_write_space = sk->sk_write_space; + sk->sk_write_space = tls_write_space; + + ctx->sk_proto_close = sk->sk_prot->close; + + /* currently SW is default, we will have ethtool in future */ + rc = tls_set_sw_offload(sk, ctx); + prot = &tls_sw_prot; + if (rc) + goto err_crypto_info; + + sk->sk_prot = prot; + goto out; + +err_crypto_info: + memset(crypto_info, 0, sizeof(*crypto_info)); +out: + return rc; +} + +static int do_tls_setsockopt(struct sock *sk, int optname, + char __user *optval, unsigned int optlen) +{ + int rc = 0; + + switch (optname) { + case TLS_TX: + lock_sock(sk); + rc = do_tls_setsockopt_tx(sk, optval, optlen); + release_sock(sk); + break; + default: + rc = -ENOPROTOOPT; + break; + } + return rc; +} + +static int tls_setsockopt(struct sock *sk, int level, int optname, + char __user *optval, unsigned int optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (level != SOL_TLS) + return ctx->setsockopt(sk, level, optname, optval, optlen); + + return do_tls_setsockopt(sk, optname, optval, optlen); +} + +static int tls_init(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tls_context *ctx; + int rc = 0; + + /* allocate tls context */ + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); + if (!ctx) { + rc = -ENOMEM; + goto out; + } + icsk->icsk_ulp_data = ctx; + ctx->setsockopt = sk->sk_prot->setsockopt; + ctx->getsockopt = sk->sk_prot->getsockopt; + sk->sk_prot = &tls_base_prot; +out: + return rc; +} + +static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { + .name = "tls", + .owner = THIS_MODULE, + .init = tls_init, +}; + +static int __init tls_register(void) +{ + tls_base_prot = tcp_prot; + tls_base_prot.setsockopt = tls_setsockopt; + tls_base_prot.getsockopt = tls_getsockopt; + + tls_sw_prot = tls_base_prot; + tls_sw_prot.sendmsg = tls_sw_sendmsg; + tls_sw_prot.sendpage = tls_sw_sendpage; + tls_sw_prot.close = tls_sk_proto_close; + + tcp_register_ulp(&tcp_tls_ulp_ops); + + return 0; +} + +static void __exit tls_unregister(void) +{ + tcp_unregister_ulp(&tcp_tls_ulp_ops); +} + +module_init(tls_register); +module_exit(tls_unregister); |