summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/skmsg.h25
-rw-r--r--net/core/sock_map.c11
2 files changed, 26 insertions, 10 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 22347b08e1f8..84e18863f6a4 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -270,11 +270,6 @@ static inline struct sk_psock *sk_psock(const struct sock *sk)
return rcu_dereference_sk_user_data(sk);
}
-static inline bool sk_has_psock(struct sock *sk)
-{
- return sk_psock(sk) != NULL && sk->sk_prot->recvmsg == tcp_bpf_recvmsg;
-}
-
static inline void sk_psock_queue_msg(struct sk_psock *psock,
struct sk_msg *msg)
{
@@ -374,6 +369,26 @@ static inline bool sk_psock_test_state(const struct sk_psock *psock,
return test_bit(bit, &psock->state);
}
+static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
+{
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (psock) {
+ if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
+ psock = ERR_PTR(-EBUSY);
+ goto out;
+ }
+
+ if (!refcount_inc_not_zero(&psock->refcnt))
+ psock = ERR_PTR(-EBUSY);
+ }
+out:
+ rcu_read_unlock();
+ return psock;
+}
+
static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
struct sk_psock *psock;
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 3c0e44cb811a..be6092ac69f8 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -175,12 +175,13 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
}
}
- psock = sk_psock_get(sk);
+ psock = sk_psock_get_checked(sk);
+ if (IS_ERR(psock)) {
+ ret = PTR_ERR(psock);
+ goto out_progs;
+ }
+
if (psock) {
- if (!sk_has_psock(sk)) {
- ret = -EBUSY;
- goto out_progs;
- }
if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
(skb_progs && READ_ONCE(psock->progs.skb_parser))) {
sk_psock_put(sk, psock);