diff --git a/easytier/src/connector/udp_hole_punch/both_easy_sym.rs b/easytier/src/connector/udp_hole_punch/both_easy_sym.rs index fc84406..39d80cd 100644 --- a/easytier/src/connector/udp_hole_punch/both_easy_sym.rs +++ b/easytier/src/connector/udp_hole_punch/both_easy_sym.rs @@ -12,6 +12,7 @@ use crate::{ connector::udp_hole_punch::common::{ try_connect_with_socket, UdpHolePunchListener, HOLE_PUNCH_PACKET_BODY_LEN, }, + connector::udp_hole_punch::handle_rpc_result, peers::peer_manager::PeerManager, proto::{ peer_rpc::{ @@ -171,11 +172,18 @@ impl PunchBothEasySymHoleServer { #[derive(Debug)] pub(crate) struct PunchBothEasySymHoleClient { peer_mgr: Arc, + blacklist: Arc>, } impl PunchBothEasySymHoleClient { - pub(crate) fn new(peer_mgr: Arc) -> Self { - Self { peer_mgr } + pub(crate) fn new( + peer_mgr: Arc, + blacklist: Arc>, + ) -> Self { + Self { + peer_mgr, + blacklist, + } } #[tracing::instrument(ret)] @@ -186,6 +194,12 @@ impl PunchBothEasySymHoleClient { peer_nat_info: UdpNatType, is_busy: &mut bool, ) -> Result>, anyhow::Error> { + // Check if peer is blacklisted + if self.blacklist.contains(&dst_peer_id) { + tracing::debug!(?dst_peer_id, "peer is blacklisted, skipping hole punching"); + return Ok(None); + } + *is_busy = false; let udp_array = UdpSocketArray::new( @@ -244,7 +258,10 @@ impl PunchBothEasySymHoleClient { wait_time_ms: REMOTE_WAIT_TIME_MS as u32, }, ) - .await?; + .await; + + let remote_ret = handle_rpc_result(remote_ret, dst_peer_id, self.blacklist.clone())?; + if remote_ret.is_busy { *is_busy = true; anyhow::bail!("remote is busy"); diff --git a/easytier/src/connector/udp_hole_punch/cone.rs b/easytier/src/connector/udp_hole_punch/cone.rs index 7e11e80..70cffb8 100644 --- a/easytier/src/connector/udp_hole_punch/cone.rs +++ b/easytier/src/connector/udp_hole_punch/cone.rs @@ -11,6 +11,7 @@ use crate::{ connector::udp_hole_punch::common::{ try_connect_with_socket, UdpSocketArray, HOLE_PUNCH_PACKET_BODY_LEN, }, + connector::udp_hole_punch::handle_rpc_result, peers::peer_manager::PeerManager, proto::{ common::Void, @@ -83,11 +84,18 @@ impl PunchConeHoleServer { pub(crate) struct PunchConeHoleClient { peer_mgr: Arc, + blacklist: Arc>, } impl PunchConeHoleClient { - pub(crate) fn new(peer_mgr: Arc) -> Self { - Self { peer_mgr } + pub(crate) fn new( + peer_mgr: Arc, + blacklist: Arc>, + ) -> Self { + Self { + peer_mgr, + blacklist, + } } #[tracing::instrument(skip(self))] @@ -95,6 +103,12 @@ impl PunchConeHoleClient { &self, dst_peer_id: PeerId, ) -> Result>, anyhow::Error> { + // Check if peer is blacklisted + if self.blacklist.contains(&dst_peer_id) { + tracing::debug!(?dst_peer_id, "peer is blacklisted, skipping hole punching"); + return Ok(None); + } + tracing::info!(?dst_peer_id, "start hole punching"); let tid = rand::random(); @@ -138,8 +152,10 @@ impl PunchConeHoleClient { BaseController::default(), SelectPunchListenerRequest { force_new: false }, ) - .await - .with_context(|| "failed to select punch listener")?; + .await; + + let resp = handle_rpc_result(resp, dst_peer_id, self.blacklist.clone())?; + let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!( "select_punch_listener response missing listener_mapped_addr" ))?; diff --git a/easytier/src/connector/udp_hole_punch/mod.rs b/easytier/src/connector/udp_hole_punch/mod.rs index a068a58..261820a 100644 --- a/easytier/src/connector/udp_hole_punch/mod.rs +++ b/easytier/src/connector/udp_hole_punch/mod.rs @@ -1,4 +1,7 @@ -use std::sync::{atomic::AtomicBool, Arc}; +use std::{ + sync::{atomic::AtomicBool, Arc}, + time::Duration, +}; use anyhow::{Context, Error}; use both_easy_sym::{PunchBothEasySymHoleClient, PunchBothEasySymHoleServer}; @@ -37,7 +40,10 @@ pub(crate) mod sym_to_cone; // sym punch should be serialized static SYM_PUNCH_LOCK: Lazy>>> = Lazy::new(|| DashMap::new()); -static RUN_TESTING: Lazy = Lazy::new(|| AtomicBool::new(false)); +pub static RUN_TESTING: Lazy = Lazy::new(|| AtomicBool::new(false)); + +// Blacklist timeout in seconds +pub const BLACKLIST_TIMEOUT_SEC: u64 = 3600; fn get_sym_punch_lock(peer_id: PeerId) -> Arc> { SYM_PUNCH_LOCK @@ -174,24 +180,44 @@ impl BackOff { } } +pub fn handle_rpc_result( + ret: Result, + dst_peer_id: PeerId, + blacklist: Arc>, +) -> Result { + match ret { + Ok(ret) => Ok(ret), + Err(e) => { + if matches!(e, rpc_types::error::Error::InvalidServiceKey(_, _)) { + blacklist.insert(dst_peer_id, (), Duration::from_secs(BLACKLIST_TIMEOUT_SEC)); + } + Err(e) + } + } +} + struct UdpHoePunchConnectorData { cone_client: PunchConeHoleClient, sym_to_cone_client: PunchSymToConeHoleClient, both_easy_sym_client: PunchBothEasySymHoleClient, peer_mgr: Arc, + blacklist: Arc>, } impl UdpHoePunchConnectorData { pub fn new(peer_mgr: Arc) -> Arc { - let cone_client = PunchConeHoleClient::new(peer_mgr.clone()); - let sym_to_cone_client = PunchSymToConeHoleClient::new(peer_mgr.clone()); - let both_easy_sym_client = PunchBothEasySymHoleClient::new(peer_mgr.clone()); + let blacklist = Arc::new(timedmap::TimedMap::new()); + let cone_client = PunchConeHoleClient::new(peer_mgr.clone(), blacklist.clone()); + let sym_to_cone_client = PunchSymToConeHoleClient::new(peer_mgr.clone(), blacklist.clone()); + let both_easy_sym_client = + PunchBothEasySymHoleClient::new(peer_mgr.clone(), blacklist.clone()); Arc::new(Self { cone_client, sym_to_cone_client, both_easy_sym_client, peer_mgr, + blacklist, }) } @@ -402,9 +428,12 @@ impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher { let my_peer_id = data.peer_mgr.my_peer_id(); + data.blacklist.cleanup(); + // collect peer list from peer manager and do some filter: // 1. peers without direct conns; // 2. peers is full cone (any restricted type); + // 3. peers not in blacklist; for route in data.peer_mgr.list_routes().await.iter() { if route .feature_flag @@ -425,6 +454,13 @@ impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher { let peer_nat_type = peer_nat_type.into(); let peer_id: PeerId = route.peer_id; + + // Check if peer is blacklisted + if data.blacklist.contains(&peer_id) { + tracing::debug!(?peer_id, "peer is blacklisted, skipping"); + continue; + } + let conns = data.peer_mgr.list_peer_conns(peer_id).await; if conns.is_some() && conns.unwrap().len() > 0 { continue; @@ -536,11 +572,17 @@ impl UdpHolePunchConnector { pub mod tests { use std::sync::Arc; + use std::time::Duration; use crate::common::stun::MockStunInfoCollector; + use crate::peers::{ + peer_manager::PeerManager, + tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear}, + }; use crate::proto::common::NatType; + use crate::tunnel::common::tests::wait_for_condition; - use crate::peers::{peer_manager::PeerManager, tests::create_mock_peer_manager}; + use super::{UdpHolePunchConnector, RUN_TESTING}; pub fn replace_stun_info_collector(peer_mgr: Arc, udp_nat_type: NatType) { let collector = Box::new(MockStunInfoCollector { udp_nat_type }); @@ -556,4 +598,37 @@ pub mod tests { replace_stun_info_collector(p_a.clone(), udp_nat_type); p_a } + + #[rstest::rstest] + #[tokio::test] + pub async fn test_hole_punching_blacklist( + #[values(NatType::Symmetric, NatType::PortRestricted, NatType::Unknown)] nat_type: NatType, + ) { + RUN_TESTING.store(true, std::sync::atomic::Ordering::Relaxed); + + let p_a = create_mock_peer_manager_with_mock_stun(nat_type).await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + + let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone()); + + hole_punching_a.run().await.unwrap(); + + hole_punching_a.client.run_immediately().await; + + wait_for_condition( + || async { + hole_punching_a + .client + .data() + .blacklist + .contains(&p_c.my_peer_id()) + }, + Duration::from_secs(10), + ) + .await; + } } diff --git a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs index e3b609e..c265898 100644 --- a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs +++ b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs @@ -18,6 +18,7 @@ use crate::{ connector::udp_hole_punch::common::{ send_symmetric_hole_punch_packet, try_connect_with_socket, HOLE_PUNCH_PACKET_BODY_LEN, }, + connector::udp_hole_punch::handle_rpc_result, defer, peers::peer_manager::PeerManager, proto::{ @@ -199,16 +200,21 @@ pub(crate) struct PunchSymToConeHoleClient { try_direct_connect: AtomicBool, punch_predicablely: AtomicBool, punch_randomly: AtomicBool, + blacklist: Arc>, } impl PunchSymToConeHoleClient { - pub(crate) fn new(peer_mgr: Arc) -> Self { + pub(crate) fn new( + peer_mgr: Arc, + blacklist: Arc>, + ) -> Self { Self { peer_mgr, udp_array: RwLock::new(None), try_direct_connect: AtomicBool::new(true), punch_predicablely: AtomicBool::new(true), punch_randomly: AtomicBool::new(true), + blacklist, } } @@ -394,6 +400,12 @@ impl PunchSymToConeHoleClient { last_port_idx: &mut usize, my_nat_info: UdpNatType, ) -> Result>, anyhow::Error> { + // Check if peer is blacklisted + if self.blacklist.contains(&dst_peer_id) { + tracing::debug!(?dst_peer_id, "peer is blacklisted, skipping hole punching"); + return Ok(None); + } + let udp_array = self.prepare_udp_array().await?; let global_ctx = self.peer_mgr.get_global_ctx(); @@ -412,8 +424,10 @@ impl PunchSymToConeHoleClient { BaseController::default(), SelectPunchListenerRequest { force_new: false }, ) - .await - .with_context(|| "failed to select punch listener")?; + .await; + + let resp = handle_rpc_result(resp, dst_peer_id, self.blacklist.clone())?; + let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!( "select_punch_listener response missing listener_mapped_addr" ))?; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index d3ea94b..f93b082 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -361,8 +361,8 @@ impl PeerManager { is_directly_connected: bool, ) -> Result<(), Error> { tracing::info!("add tunnel as server start"); - let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel); - peer.do_handshake_as_server_ext(|peer, msg| { + let mut conn = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel); + conn.do_handshake_as_server_ext(|peer, msg| { if msg.network_name == self.global_ctx.get_network_identity().network_name { @@ -396,13 +396,14 @@ impl PeerManager { }) .await?; - let peer_network_name = peer.get_network_identity().network_name.clone(); + let peer_network_name = conn.get_network_identity().network_name.clone(); + + conn.set_is_hole_punched(!is_directly_connected); if peer_network_name == self.global_ctx.get_network_identity().network_name { - peer.set_is_hole_punched(!is_directly_connected); - self.add_new_peer_conn(peer).await?; + self.add_new_peer_conn(conn).await?; } else { - self.foreign_network_manager.add_peer_conn(peer).await?; + self.foreign_network_manager.add_peer_conn(conn).await?; } self.reserved_my_peer_id_map.remove(&peer_network_name);