diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c index e19e1525ecbb..1f310abbf1ed 100644 --- a/net/mptcp/pm.c +++ b/net/mptcp/pm.c @@ -225,6 +225,15 @@ int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) return mptcp_pm_nl_get_local_id(msk, skc); } +bool mptcp_pm_is_backup(struct mptcp_sock *msk, struct sock_common *skc) +{ + struct mptcp_addr_info skc_local; + + mptcp_local_address((struct sock_common *)skc, &skc_local); + + return mptcp_pm_nl_is_backup(msk, &skc_local); +} + void mptcp_pm_data_init(struct mptcp_sock *msk) { msk->pm.add_addr_signaled = 0; diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 6a0079d42dc4..ca57d856d5df 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -568,6 +568,26 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) return ret; } +bool mptcp_pm_nl_is_backup(struct mptcp_sock *msk, struct mptcp_addr_info *skc) +{ + struct mptcp_pm_addr_entry *entry; + struct pm_nl_pernet *pernet; + bool backup = false; + + pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); + + rcu_read_lock(); + list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { + if (addresses_equal(&entry->addr, skc, entry->addr.port)) { + backup = !!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP); + break; + } + } + rcu_read_unlock(); + + return backup; +} + void mptcp_pm_nl_data_init(struct mptcp_sock *msk) { struct mptcp_pm_data *pm = &msk->pm; diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index ddfc7bde8c90..4348bccb982f 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -481,6 +481,7 @@ bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, unsigned int remaining, bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, u8 *rm_id); int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc); +bool mptcp_pm_is_backup(struct mptcp_sock *msk, struct sock_common *skc); void __init mptcp_pm_nl_init(void); void mptcp_pm_nl_data_init(struct mptcp_sock *msk); @@ -490,6 +491,7 @@ void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk); void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk); void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id); int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc); +bool mptcp_pm_nl_is_backup(struct mptcp_sock *msk, struct mptcp_addr_info *skc); static inline struct mptcp_ext *mptcp_get_ext(struct sk_buff *skb) { diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c index f4067484727e..ba86cb06d6d8 100644 --- a/net/mptcp/subflow.c +++ b/net/mptcp/subflow.c @@ -80,6 +80,7 @@ static struct mptcp_sock *subflow_token_join_request(struct request_sock *req, return NULL; } subflow_req->local_id = local_id; + subflow_req->request_bkup = mptcp_pm_is_backup(msk, (struct sock_common *)req); get_random_bytes(&subflow_req->local_nonce, sizeof(u32));