Commit 750ac63e authored by jan.koester's avatar jan.koester
Browse files

added missing files

parent fd43a4ab
Loading
Loading
Loading
Loading

net/mctp/sched.c

0 → 100644
+173 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2022, SUSE.
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/rculist.h>
#include <linux/spinlock.h>
#include "protocol.h"

static DEFINE_SPINLOCK(mptcp_sched_list_lock);
static LIST_HEAD(mptcp_sched_list);

static int mptcp_sched_default_get_subflow(struct mptcp_sock *msk,
					   struct mptcp_sched_data *data)
{
	struct sock *ssk;

	ssk = data->reinject ? mptcp_subflow_get_retrans(msk) :
			       mptcp_subflow_get_send(msk);
	if (!ssk)
		return -EINVAL;

	mptcp_subflow_set_scheduled(mptcp_subflow_ctx(ssk), true);
	return 0;
}

static struct mptcp_sched_ops mptcp_sched_default = {
	.get_subflow	= mptcp_sched_default_get_subflow,
	.name		= "default",
	.owner		= THIS_MODULE,
};

/* Must be called with rcu read lock held */
struct mptcp_sched_ops *mptcp_sched_find(const char *name)
{
	struct mptcp_sched_ops *sched, *ret = NULL;

	list_for_each_entry_rcu(sched, &mptcp_sched_list, list) {
		if (!strcmp(sched->name, name)) {
			ret = sched;
			break;
		}
	}

	return ret;
}

int mptcp_register_scheduler(struct mptcp_sched_ops *sched)
{
	if (!sched->get_subflow)
		return -EINVAL;

	spin_lock(&mptcp_sched_list_lock);
	if (mptcp_sched_find(sched->name)) {
		spin_unlock(&mptcp_sched_list_lock);
		return -EEXIST;
	}
	list_add_tail_rcu(&sched->list, &mptcp_sched_list);
	spin_unlock(&mptcp_sched_list_lock);

	pr_debug("%s registered", sched->name);
	return 0;
}

void mptcp_unregister_scheduler(struct mptcp_sched_ops *sched)
{
	if (sched == &mptcp_sched_default)
		return;

	spin_lock(&mptcp_sched_list_lock);
	list_del_rcu(&sched->list);
	spin_unlock(&mptcp_sched_list_lock);
}

void mptcp_sched_init(void)
{
	mptcp_register_scheduler(&mptcp_sched_default);
}

int mptcp_init_sched(struct mptcp_sock *msk,
		     struct mptcp_sched_ops *sched)
{
	if (!sched)
		sched = &mptcp_sched_default;

	if (!bpf_try_module_get(sched, sched->owner))
		return -EBUSY;

	msk->sched = sched;
	if (msk->sched->init)
		msk->sched->init(msk);

	pr_debug("sched=%s", msk->sched->name);

	return 0;
}

void mptcp_release_sched(struct mptcp_sock *msk)
{
	struct mptcp_sched_ops *sched = msk->sched;

	if (!sched)
		return;

	msk->sched = NULL;
	if (sched->release)
		sched->release(msk);

	bpf_module_put(sched, sched->owner);
}

void mptcp_subflow_set_scheduled(struct mptcp_subflow_context *subflow,
				 bool scheduled)
{
	WRITE_ONCE(subflow->scheduled, scheduled);
}

int mptcp_sched_get_send(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_sched_data data;

	msk_owned_by_me(msk);

	/* the following check is moved out of mptcp_subflow_get_send */
	if (__mptcp_check_fallback(msk)) {
		if (msk->first &&
		    __tcp_can_send(msk->first) &&
		    sk_stream_memory_free(msk->first)) {
			mptcp_subflow_set_scheduled(mptcp_subflow_ctx(msk->first), true);
			return 0;
		}
		return -EINVAL;
	}

	mptcp_for_each_subflow(msk, subflow) {
		if (READ_ONCE(subflow->scheduled))
			return 0;
	}

	data.reinject = false;
	if (msk->sched == &mptcp_sched_default || !msk->sched)
		return mptcp_sched_default_get_subflow(msk, &data);
	return msk->sched->get_subflow(msk, &data);
}

int mptcp_sched_get_retrans(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_sched_data data;

	msk_owned_by_me(msk);

	/* the following check is moved out of mptcp_subflow_get_retrans */
	if (__mptcp_check_fallback(msk))
		return -EINVAL;

	mptcp_for_each_subflow(msk, subflow) {
		if (READ_ONCE(subflow->scheduled))
			return 0;
	}

	data.reinject = true;
	if (msk->sched == &mptcp_sched_default || !msk->sched)
		return mptcp_sched_default_get_subflow(msk, &data);
	return msk->sched->get_subflow(msk, &data);
}

net/mctp/sockopt.c

0 → 100644
+1486 −0

File added.

Preview size limit exceeded, changes collapsed.

net/mctp/syncookies.c

0 → 100644
+133 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
#include <linux/skbuff.h>

#include "protocol.h"

/* Syncookies do not work for JOIN requests.
 *
 * Unlike MP_CAPABLE, where the ACK cookie contains the needed MPTCP
 * options to reconstruct the initial syn state, MP_JOIN does not contain
 * the token to obtain the mptcp socket nor the server-generated nonce
 * that was used in the cookie SYN/ACK response.
 *
 * Keep a small best effort state table to store the syn/synack data,
 * indexed by skb hash.
 *
 * A MP_JOIN SYN packet handled by syn cookies is only stored if the 32bit
 * token matches a known mptcp connection that can still accept more subflows.
 *
 * There is no timeout handling -- state is only re-constructed
 * when the TCP ACK passed the cookie validation check.
 */

struct join_entry {
	u32 token;
	u32 remote_nonce;
	u32 local_nonce;
	u8 join_id;
	u8 local_id;
	u8 backup;
	u8 valid;
};

#define COOKIE_JOIN_SLOTS	1024

static struct join_entry join_entries[COOKIE_JOIN_SLOTS] __cacheline_aligned_in_smp;
static spinlock_t join_entry_locks[COOKIE_JOIN_SLOTS] __cacheline_aligned_in_smp;

static u32 mptcp_join_entry_hash(struct sk_buff *skb, struct net *net)
{
	static u32 mptcp_join_hash_secret __read_mostly;
	struct tcphdr *th = tcp_hdr(skb);
	u32 seq, i;

	net_get_random_once(&mptcp_join_hash_secret,
			    sizeof(mptcp_join_hash_secret));

	if (th->syn)
		seq = TCP_SKB_CB(skb)->seq;
	else
		seq = TCP_SKB_CB(skb)->seq - 1;

	i = jhash_3words(seq, net_hash_mix(net),
			 (__force __u32)th->source << 16 | (__force __u32)th->dest,
			 mptcp_join_hash_secret);

	return i % ARRAY_SIZE(join_entries);
}

static void mptcp_join_store_state(struct join_entry *entry,
				   const struct mptcp_subflow_request_sock *subflow_req)
{
	entry->token = subflow_req->token;
	entry->remote_nonce = subflow_req->remote_nonce;
	entry->local_nonce = subflow_req->local_nonce;
	entry->backup = subflow_req->backup;
	entry->join_id = subflow_req->remote_id;
	entry->local_id = subflow_req->local_id;
	entry->valid = 1;
}

void subflow_init_req_cookie_join_save(const struct mptcp_subflow_request_sock *subflow_req,
				       struct sk_buff *skb)
{
	struct net *net = read_pnet(&subflow_req->sk.req.ireq_net);
	u32 i = mptcp_join_entry_hash(skb, net);

	/* No use in waiting if other cpu is already using this slot --
	 * would overwrite the data that got stored.
	 */
	spin_lock_bh(&join_entry_locks[i]);
	mptcp_join_store_state(&join_entries[i], subflow_req);
	spin_unlock_bh(&join_entry_locks[i]);
}

/* Called for a cookie-ack with MP_JOIN option present.
 * Look up the saved state based on skb hash & check token matches msk
 * in same netns.
 *
 * Caller will check msk can still accept another subflow.  The hmac
 * present in the cookie ACK mptcp option space will be checked later.
 */
bool mptcp_token_join_cookie_init_state(struct mptcp_subflow_request_sock *subflow_req,
					struct sk_buff *skb)
{
	struct net *net = read_pnet(&subflow_req->sk.req.ireq_net);
	u32 i = mptcp_join_entry_hash(skb, net);
	struct mptcp_sock *msk;
	struct join_entry *e;

	e = &join_entries[i];

	spin_lock_bh(&join_entry_locks[i]);

	if (e->valid == 0) {
		spin_unlock_bh(&join_entry_locks[i]);
		return false;
	}

	e->valid = 0;

	msk = mptcp_token_get_sock(net, e->token);
	if (!msk) {
		spin_unlock_bh(&join_entry_locks[i]);
		return false;
	}

	subflow_req->remote_nonce = e->remote_nonce;
	subflow_req->local_nonce = e->local_nonce;
	subflow_req->backup = e->backup;
	subflow_req->remote_id = e->join_id;
	subflow_req->token = e->token;
	subflow_req->msk = msk;
	spin_unlock_bh(&join_entry_locks[i]);
	return true;
}

void __init mptcp_join_cookie_init(void)
{
	int i;

	for (i = 0; i < COOKIE_JOIN_SLOTS; i++)
		spin_lock_init(&join_entry_locks[i]);
}

net/mctp/token.c

0 → 100644
+422 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP token management
 * Copyright (c) 2017 - 2019, Intel Corporation.
 *
 * Note: This code is based on mptcp_ctrl.c from multipath-tcp.org,
 *       authored by:
 *
 *       Sébastien Barré <sebastien.barre@uclouvain.be>
 *       Christoph Paasch <christoph.paasch@uclouvain.be>
 *       Jaakko Korkeaniemi <jaakko.korkeaniemi@aalto.fi>
 *       Gregory Detal <gregory.detal@uclouvain.be>
 *       Fabien Duchêne <fabien.duchene@uclouvain.be>
 *       Andreas Seelinger <Andreas.Seelinger@rwth-aachen.de>
 *       Lavkesh Lahngir <lavkesh51@gmail.com>
 *       Andreas Ripke <ripke@neclab.eu>
 *       Vlad Dogaru <vlad.dogaru@intel.com>
 *       Octavian Purdila <octavian.purdila@intel.com>
 *       John Ronan <jronan@tssg.org>
 *       Catalin Nicutar <catalin.nicutar@gmail.com>
 *       Brandon Heller <brandonh@stanford.edu>
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/memblock.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <net/sock.h>
#include <net/inet_common.h>
#include <net/protocol.h>
#include <net/mptcp.h>
#include "protocol.h"

#define TOKEN_MAX_CHAIN_LEN	4

struct token_bucket {
	spinlock_t		lock;
	int			chain_len;
	struct hlist_nulls_head	req_chain;
	struct hlist_nulls_head	msk_chain;
};

static struct token_bucket *token_hash __read_mostly;
static unsigned int token_mask __read_mostly;

static struct token_bucket *token_bucket(u32 token)
{
	return &token_hash[token & token_mask];
}

/* called with bucket lock held */
static struct mptcp_subflow_request_sock *
__token_lookup_req(struct token_bucket *t, u32 token)
{
	struct mptcp_subflow_request_sock *req;
	struct hlist_nulls_node *pos;

	hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node)
		if (req->token == token)
			return req;
	return NULL;
}

/* called with bucket lock held */
static struct mptcp_sock *
__token_lookup_msk(struct token_bucket *t, u32 token)
{
	struct hlist_nulls_node *pos;
	struct sock *sk;

	sk_nulls_for_each_rcu(sk, pos, &t->msk_chain)
		if (mptcp_sk(sk)->token == token)
			return mptcp_sk(sk);
	return NULL;
}

static bool __token_bucket_busy(struct token_bucket *t, u32 token)
{
	return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN ||
	       __token_lookup_req(t, token) || __token_lookup_msk(t, token);
}

static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn)
{
	/* we might consider a faster version that computes the key as a
	 * hash of some information available in the MPTCP socket. Use
	 * random data at the moment, as it's probably the safest option
	 * in case multiple sockets are opened in different namespaces at
	 * the same time.
	 */
	get_random_bytes(key, sizeof(u64));
	mptcp_crypto_key_sha(*key, token, idsn);
}

/**
 * mptcp_token_new_request - create new key/idsn/token for subflow_request
 * @req: the request socket
 *
 * This function is called when a new mptcp connection is coming in.
 *
 * It creates a unique token to identify the new mptcp connection,
 * a secret local key and the initial data sequence number (idsn).
 *
 * Returns 0 on success.
 */
int mptcp_token_new_request(struct request_sock *req)
{
	struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
	struct token_bucket *bucket;
	u32 token;

	mptcp_crypto_key_sha(subflow_req->local_key,
			     &subflow_req->token,
			     &subflow_req->idsn);
	pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n",
		 req, subflow_req->local_key, subflow_req->token,
		 subflow_req->idsn);

	token = subflow_req->token;
	bucket = token_bucket(token);
	spin_lock_bh(&bucket->lock);
	if (__token_bucket_busy(bucket, token)) {
		spin_unlock_bh(&bucket->lock);
		return -EBUSY;
	}

	hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain);
	bucket->chain_len++;
	spin_unlock_bh(&bucket->lock);
	return 0;
}

/**
 * mptcp_token_new_connect - create new key/idsn/token for subflow
 * @ssk: the socket that will initiate a connection
 *
 * This function is called when a new outgoing mptcp connection is
 * initiated.
 *
 * It creates a unique token to identify the new mptcp connection,
 * a secret local key and the initial data sequence number (idsn).
 *
 * On success, the mptcp connection can be found again using
 * the computed token at a later time, this is needed to process
 * join requests.
 *
 * returns 0 on success.
 */
int mptcp_token_new_connect(struct sock *ssk)
{
	struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
	struct mptcp_sock *msk = mptcp_sk(subflow->conn);
	int retries = MPTCP_TOKEN_MAX_RETRIES;
	struct sock *sk = subflow->conn;
	struct token_bucket *bucket;

again:
	mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token,
				 &subflow->idsn);

	bucket = token_bucket(subflow->token);
	spin_lock_bh(&bucket->lock);
	if (__token_bucket_busy(bucket, subflow->token)) {
		spin_unlock_bh(&bucket->lock);
		if (!--retries)
			return -EBUSY;
		goto again;
	}

	pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n",
		 ssk, subflow->local_key, subflow->token, subflow->idsn);

	WRITE_ONCE(msk->token, subflow->token);
	__sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
	bucket->chain_len++;
	spin_unlock_bh(&bucket->lock);
	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
	return 0;
}

/**
 * mptcp_token_accept - replace a req sk with full sock in token hash
 * @req: the request socket to be removed
 * @msk: the just cloned socket linked to the new connection
 *
 * Called when a SYN packet creates a new logical connection, i.e.
 * is not a join request.
 */
void mptcp_token_accept(struct mptcp_subflow_request_sock *req,
			struct mptcp_sock *msk)
{
	struct mptcp_subflow_request_sock *pos;
	struct sock *sk = (struct sock *)msk;
	struct token_bucket *bucket;

	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
	bucket = token_bucket(req->token);
	spin_lock_bh(&bucket->lock);

	/* pedantic lookup check for the moved token */
	pos = __token_lookup_req(bucket, req->token);
	if (!WARN_ON_ONCE(pos != req))
		hlist_nulls_del_init_rcu(&req->token_node);
	__sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
	spin_unlock_bh(&bucket->lock);
}

bool mptcp_token_exists(u32 token)
{
	struct hlist_nulls_node *pos;
	struct token_bucket *bucket;
	struct mptcp_sock *msk;
	struct sock *sk;

	rcu_read_lock();
	bucket = token_bucket(token);

again:
	sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
		msk = mptcp_sk(sk);
		if (READ_ONCE(msk->token) == token)
			goto found;
	}
	if (get_nulls_value(pos) != (token & token_mask))
		goto again;

	rcu_read_unlock();
	return false;
found:
	rcu_read_unlock();
	return true;
}

/**
 * mptcp_token_get_sock - retrieve mptcp connection sock using its token
 * @net: restrict to this namespace
 * @token: token of the mptcp connection to retrieve
 *
 * This function returns the mptcp connection structure with the given token.
 * A reference count on the mptcp socket returned is taken.
 *
 * returns NULL if no connection with the given token value exists.
 */
struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token)
{
	struct hlist_nulls_node *pos;
	struct token_bucket *bucket;
	struct mptcp_sock *msk;
	struct sock *sk;

	rcu_read_lock();
	bucket = token_bucket(token);

again:
	sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
		msk = mptcp_sk(sk);
		if (READ_ONCE(msk->token) != token ||
		    !net_eq(sock_net(sk), net))
			continue;

		if (!refcount_inc_not_zero(&sk->sk_refcnt))
			goto not_found;

		if (READ_ONCE(msk->token) != token ||
		    !net_eq(sock_net(sk), net)) {
			sock_put(sk);
			goto again;
		}
		goto found;
	}
	if (get_nulls_value(pos) != (token & token_mask))
		goto again;

not_found:
	msk = NULL;

found:
	rcu_read_unlock();
	return msk;
}
EXPORT_SYMBOL_GPL(mptcp_token_get_sock);

/**
 * mptcp_token_iter_next - iterate over the token container from given pos
 * @net: namespace to be iterated
 * @s_slot: start slot number
 * @s_num: start number inside the given lock
 *
 * This function returns the first mptcp connection structure found inside the
 * token container starting from the specified position, or NULL.
 *
 * On successful iteration, the iterator is moved to the next position and
 * a reference to the returned socket is acquired.
 */
struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot,
					 long *s_num)
{
	struct mptcp_sock *ret = NULL;
	struct hlist_nulls_node *pos;
	int slot, num = 0;

	for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) {
		struct token_bucket *bucket = &token_hash[slot];
		struct sock *sk;

		num = 0;

		if (hlist_nulls_empty(&bucket->msk_chain))
			continue;

		rcu_read_lock();
		sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
			++num;
			if (!net_eq(sock_net(sk), net))
				continue;

			if (num <= *s_num)
				continue;

			if (!refcount_inc_not_zero(&sk->sk_refcnt))
				continue;

			if (!net_eq(sock_net(sk), net)) {
				sock_put(sk);
				continue;
			}

			ret = mptcp_sk(sk);
			rcu_read_unlock();
			goto out;
		}
		rcu_read_unlock();
	}

out:
	*s_slot = slot;
	*s_num = num;
	return ret;
}
EXPORT_SYMBOL_GPL(mptcp_token_iter_next);

/**
 * mptcp_token_destroy_request - remove mptcp connection/token
 * @req: mptcp request socket dropping the token
 *
 * Remove the token associated to @req.
 */
void mptcp_token_destroy_request(struct request_sock *req)
{
	struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
	struct mptcp_subflow_request_sock *pos;
	struct token_bucket *bucket;

	if (hlist_nulls_unhashed(&subflow_req->token_node))
		return;

	bucket = token_bucket(subflow_req->token);
	spin_lock_bh(&bucket->lock);
	pos = __token_lookup_req(bucket, subflow_req->token);
	if (!WARN_ON_ONCE(pos != subflow_req)) {
		hlist_nulls_del_init_rcu(&pos->token_node);
		bucket->chain_len--;
	}
	spin_unlock_bh(&bucket->lock);
}

/**
 * mptcp_token_destroy - remove mptcp connection/token
 * @msk: mptcp connection dropping the token
 *
 * Remove the token associated to @msk
 */
void mptcp_token_destroy(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	struct token_bucket *bucket;
	struct mptcp_sock *pos;

	if (sk_unhashed((struct sock *)msk))
		return;

	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
	bucket = token_bucket(msk->token);
	spin_lock_bh(&bucket->lock);
	pos = __token_lookup_msk(bucket, msk->token);
	if (!WARN_ON_ONCE(pos != msk)) {
		__sk_nulls_del_node_init_rcu((struct sock *)pos);
		bucket->chain_len--;
	}
	spin_unlock_bh(&bucket->lock);
	WRITE_ONCE(msk->token, 0);
}

void __init mptcp_token_init(void)
{
	int i;

	token_hash = alloc_large_system_hash("MPTCP token",
					     sizeof(struct token_bucket),
					     0,
					     20,/* one slot per 1MB of memory */
					     HASH_ZERO,
					     NULL,
					     &token_mask,
					     0,
					     64 * 1024);
	for (i = 0; i < token_mask + 1; ++i) {
		INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i);
		INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i);
		spin_lock_init(&token_hash[i].lock);
	}
}

#if IS_MODULE(CONFIG_MPTCP_KUNIT_TEST)
EXPORT_SYMBOL_GPL(mptcp_token_new_request);
EXPORT_SYMBOL_GPL(mptcp_token_new_connect);
EXPORT_SYMBOL_GPL(mptcp_token_accept);
EXPORT_SYMBOL_GPL(mptcp_token_destroy_request);
EXPORT_SYMBOL_GPL(mptcp_token_destroy);
#endif

net/mctp/token_test.c

0 → 100644
+145 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
#include <kunit/test.h>

#include "protocol.h"

static struct mptcp_subflow_request_sock *build_req_sock(struct kunit *test)
{
	struct mptcp_subflow_request_sock *req;

	req = kunit_kzalloc(test, sizeof(struct mptcp_subflow_request_sock),
			    GFP_USER);
	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, req);
	mptcp_token_init_request((struct request_sock *)req);
	sock_net_set((struct sock *)req, &init_net);
	return req;
}

static void mptcp_token_test_req_basic(struct kunit *test)
{
	struct mptcp_subflow_request_sock *req = build_req_sock(test);
	struct mptcp_sock *null_msk = NULL;

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_request((struct request_sock *)req));
	KUNIT_EXPECT_NE(test, 0, (int)req->token);
	KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, req->token));

	/* cleanup */
	mptcp_token_destroy_request((struct request_sock *)req);
}

static struct inet_connection_sock *build_icsk(struct kunit *test)
{
	struct inet_connection_sock *icsk;

	icsk = kunit_kzalloc(test, sizeof(struct inet_connection_sock),
			     GFP_USER);
	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, icsk);
	return icsk;
}

static struct mptcp_subflow_context *build_ctx(struct kunit *test)
{
	struct mptcp_subflow_context *ctx;

	ctx = kunit_kzalloc(test, sizeof(struct mptcp_subflow_context),
			    GFP_USER);
	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, ctx);
	return ctx;
}

static struct mptcp_sock *build_msk(struct kunit *test)
{
	struct mptcp_sock *msk;

	msk = kunit_kzalloc(test, sizeof(struct mptcp_sock), GFP_USER);
	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk);
	refcount_set(&((struct sock *)msk)->sk_refcnt, 1);
	sock_net_set((struct sock *)msk, &init_net);

	/* be sure the token helpers can dereference sk->sk_prot */
	((struct sock *)msk)->sk_prot = &tcp_prot;
	return msk;
}

static void mptcp_token_test_msk_basic(struct kunit *test)
{
	struct inet_connection_sock *icsk = build_icsk(test);
	struct mptcp_subflow_context *ctx = build_ctx(test);
	struct mptcp_sock *msk = build_msk(test);
	struct mptcp_sock *null_msk = NULL;
	struct sock *sk;

	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->conn = (struct sock *)msk;
	sk = (struct sock *)msk;

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_connect((struct sock *)icsk));
	KUNIT_EXPECT_NE(test, 0, (int)ctx->token);
	KUNIT_EXPECT_EQ(test, ctx->token, msk->token);
	KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, ctx->token));
	KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt));

	mptcp_token_destroy(msk);
	KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, ctx->token));
}

static void mptcp_token_test_accept(struct kunit *test)
{
	struct mptcp_subflow_request_sock *req = build_req_sock(test);
	struct mptcp_sock *msk = build_msk(test);

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_request((struct request_sock *)req));
	msk->token = req->token;
	mptcp_token_accept(req, msk);
	KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token));

	/* this is now a no-op */
	mptcp_token_destroy_request((struct request_sock *)req);
	KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token));

	/* cleanup */
	mptcp_token_destroy(msk);
}

static void mptcp_token_test_destroyed(struct kunit *test)
{
	struct mptcp_subflow_request_sock *req = build_req_sock(test);
	struct mptcp_sock *msk = build_msk(test);
	struct mptcp_sock *null_msk = NULL;
	struct sock *sk;

	sk = (struct sock *)msk;

	KUNIT_ASSERT_EQ(test, 0,
			mptcp_token_new_request((struct request_sock *)req));
	msk->token = req->token;
	mptcp_token_accept(req, msk);

	/* simulate race on removal */
	refcount_set(&sk->sk_refcnt, 0);
	KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, msk->token));

	/* cleanup */
	mptcp_token_destroy(msk);
}

static struct kunit_case mptcp_token_test_cases[] = {
	KUNIT_CASE(mptcp_token_test_req_basic),
	KUNIT_CASE(mptcp_token_test_msk_basic),
	KUNIT_CASE(mptcp_token_test_accept),
	KUNIT_CASE(mptcp_token_test_destroyed),
	{}
};

static struct kunit_suite mptcp_token_suite = {
	.name = "mptcp-token",
	.test_cases = mptcp_token_test_cases,
};

kunit_test_suite(mptcp_token_suite);

MODULE_LICENSE("GPL");