feat(acl): add group-based ACL rules and related structures (#1265)

* feat(acl): add group-based ACL rules and related structures

* refactor(acl): optimize group handling with Arc and improve cache management

* refactor(acl): clippy

* feat(tests): add performance tests for generate_with_proof and verify methods

* feat: update group_trust_map to use HashMap for more secure group proofs

* refactor: refactor the logic of the trusted group getting and setting

* feat(acl): support kcp/quic use group acl

* feat(proxy): optimize group retrieval by IP in Kcp and Quic proxy handlers

* feat(tests): add group-based ACL tree node test

* always allow quic proxy traffic

---------

Co-authored-by: Sijie.Sun <sunsijie@buaa.edu.cn>
Co-authored-by: sijie.sun <sijie.sun@smartx.com>
This commit is contained in:
Mg Pig
2025-08-22 22:25:00 +08:00
committed by GitHub
parent 34560af141
commit 08a92a53c3
18 changed files with 1042 additions and 29 deletions

2
Cargo.lock generated
View File

@@ -2129,6 +2129,7 @@ dependencies = [
"hickory-proto",
"hickory-resolver",
"hickory-server",
"hmac",
"http",
"http_req",
"humansize",
@@ -2171,6 +2172,7 @@ dependencies = [
"serde_json",
"serial_test",
"service-manager",
"sha2",
"smoltcp",
"socket2",
"stun_codec",

View File

@@ -215,6 +215,8 @@ derive_builder = "0.20.2"
humantime-serde = "1.1.1"
multimap = "0.10.0"
version-compare = "0.2.0"
hmac = "0.12.1"
sha2 = "0.10.8"
[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies]
machine-uid = "0.5.3"

View File

@@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
net::{IpAddr, SocketAddr},
str::FromStr as _,
sync::Arc,
@@ -61,6 +61,8 @@ pub struct FastLookupRule {
pub dst_ip_ranges: Vec<cidr::IpCidr>,
pub src_port_ranges: Vec<(u16, u16)>,
pub dst_port_ranges: Vec<(u16, u16)>,
pub source_groups: HashSet<String>,
pub destination_groups: HashSet<String>,
pub action: Action,
pub enabled: bool,
pub stateful: bool,
@@ -78,6 +80,8 @@ pub struct AclCacheKey {
pub dst_ip: IpAddr,
pub src_port: u16,
pub dst_port: u16,
pub src_groups: Arc<Vec<String>>,
pub dst_groups: Arc<Vec<String>>,
}
impl AclCacheKey {
@@ -89,6 +93,8 @@ impl AclCacheKey {
dst_ip: packet_info.dst_ip,
src_port: packet_info.src_port.unwrap_or(0),
dst_port: packet_info.dst_port.unwrap_or(0),
src_groups: packet_info.src_groups.clone(),
dst_groups: packet_info.dst_groups.clone(),
}
}
}
@@ -116,6 +122,8 @@ pub struct PacketInfo {
pub dst_port: Option<u16>,
pub protocol: Protocol,
pub packet_size: usize,
pub src_groups: Arc<Vec<String>>,
pub dst_groups: Arc<Vec<String>>,
}
// ACL processing result
@@ -684,6 +692,28 @@ impl AclProcessor {
}
}
// Source group check
if !rule.source_groups.is_empty() {
let matches = packet_info
.src_groups
.iter()
.any(|group| rule.source_groups.contains(group));
if !matches {
return false;
}
}
// Destination group check
if !rule.destination_groups.is_empty() {
let matches = packet_info
.dst_groups
.iter()
.any(|group| rule.destination_groups.contains(group));
if !matches {
return false;
}
}
true
}
@@ -804,6 +834,8 @@ impl AclProcessor {
dst_ip_ranges,
src_port_ranges,
dst_port_ranges,
source_groups: rule.source_groups.iter().cloned().collect(),
destination_groups: rule.destination_groups.iter().cloned().collect(),
action: rule.action(),
enabled: rule.enabled,
stateful: rule.stateful,
@@ -1071,6 +1103,8 @@ impl AclRuleBuilder {
rate_limit: 0,
burst_limit: 0,
stateful: true,
source_groups: vec![],
destination_groups: vec![],
};
inbound_chain.rules.push(tcp_rule);
rule_priority -= 1;
@@ -1093,6 +1127,8 @@ impl AclRuleBuilder {
rate_limit: 0,
burst_limit: 0,
stateful: false,
source_groups: vec![],
destination_groups: vec![],
};
inbound_chain.rules.push(udp_rule);
}
@@ -1108,6 +1144,10 @@ impl AclRuleBuilder {
} else {
acl.acl_v1 = Some(AclV1 {
chains: vec![inbound_chain],
group: Some(GroupInfo {
declares: vec![],
members: vec![],
}),
});
}
@@ -1144,6 +1184,106 @@ mod tests {
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, Ipv4Addr};
#[tokio::test]
async fn test_group_based_acl_rules() {
let mut acl_config = Acl::default();
let mut acl_v1 = AclV1::default();
let mut chain = Chain {
name: "group_test_chain".to_string(),
chain_type: ChainType::Inbound as i32,
enabled: true,
default_action: Action::Drop as i32,
..Default::default()
};
// Rules
chain.rules.push(Rule {
name: "allow_admins_to_db".to_string(),
priority: 100,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
source_groups: vec!["admin".to_string()],
destination_groups: vec!["db-server".to_string()],
..Default::default()
});
chain.rules.push(Rule {
name: "allow_devs_from_anywhere".to_string(),
priority: 90,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
source_groups: vec!["dev".to_string()],
..Default::default()
});
chain.rules.push(Rule {
name: "deny_guests_to_db".to_string(),
priority: 80,
enabled: true,
action: Action::Drop as i32,
protocol: Protocol::Any as i32,
source_groups: vec!["guest".to_string()],
destination_groups: vec!["db-server".to_string()],
..Default::default()
});
chain.rules.push(Rule {
name: "allow_specific_ip".to_string(),
priority: 70,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
source_ips: vec!["1.2.3.4/32".to_string()],
..Default::default()
});
acl_v1.chains.push(chain);
acl_config.acl_v1 = Some(acl_v1);
let processor = AclProcessor::new(acl_config);
// Case 3.1: Source group match (devs from anywhere)
let mut packet_info = create_test_packet_info();
packet_info.src_groups = Arc::new(vec!["dev".to_string()]);
let result = processor.process_packet(&packet_info, ChainType::Inbound);
assert_eq!(result.action, Action::Allow);
assert_eq!(result.matched_rule, Some(RuleId::Priority(90)));
// Case 3.2: Source group no match
packet_info.src_groups = Arc::new(vec!["guest".to_string()]);
let result = processor.process_packet(&packet_info, ChainType::Inbound);
assert_eq!(result.action, Action::Drop); // Default drop
assert_eq!(result.matched_rule, Some(RuleId::Default));
// Case 3.3: Destination group match (deny guests to db)
packet_info.src_groups = Arc::new(vec!["guest".to_string()]);
packet_info.dst_groups = Arc::new(vec!["db-server".to_string()]);
let result = processor.process_packet(&packet_info, ChainType::Inbound);
assert_eq!(result.action, Action::Drop);
assert_eq!(result.matched_rule, Some(RuleId::Priority(80)));
// Case 3.4: Source and Destination groups match
packet_info.src_groups = Arc::new(vec!["admin".to_string()]);
packet_info.dst_groups = Arc::new(vec!["db-server".to_string()]);
let result = processor.process_packet(&packet_info, ChainType::Inbound);
assert_eq!(result.action, Action::Allow);
assert_eq!(result.matched_rule, Some(RuleId::Priority(100)));
// Case 3.5: Partial match (admin to web-server)
packet_info.src_groups = Arc::new(vec!["admin".to_string()]);
packet_info.dst_groups = Arc::new(vec!["web-server".to_string()]);
let result = processor.process_packet(&packet_info, ChainType::Inbound);
assert_eq!(result.action, Action::Drop); // Default drop
assert_eq!(result.matched_rule, Some(RuleId::Default));
// Case 3.6: Rule with no group definition
packet_info.src_ip = "1.2.3.4".parse().unwrap();
packet_info.src_groups = Arc::new(vec!["admin".to_string()]);
packet_info.dst_groups = Arc::new(vec![]);
let result = processor.process_packet(&packet_info, ChainType::Inbound);
assert_eq!(result.action, Action::Allow);
assert_eq!(result.matched_rule, Some(RuleId::Priority(70)));
}
fn create_test_acl_config() -> Acl {
let mut acl_config = Acl::default();
@@ -1182,6 +1322,8 @@ mod tests {
dst_port: Some(80),
protocol: Protocol::Tcp,
packet_size: 1024,
src_groups: Arc::new(vec![]),
dst_groups: Arc::new(vec![]),
}
}
@@ -1380,6 +1522,8 @@ mod tests {
dst_port: Some(53), // DNS
protocol: Protocol::Udp, // UDP
packet_size: 512,
src_groups: Arc::new(vec![]),
dst_groups: Arc::new(vec![]),
};
// Test TCP packet (should hit stateful+rate-limited rule)
@@ -1390,6 +1534,8 @@ mod tests {
dst_port: Some(80), // HTTP
protocol: Protocol::Tcp, // TCP
packet_size: 1024,
src_groups: Arc::new(vec![]),
dst_groups: Arc::new(vec![]),
};
// Process UDP packets multiple times

View File

@@ -8,8 +8,10 @@ use crate::common::config::ProxyNetworkConfig;
use crate::common::stats_manager::StatsManager;
use crate::common::token_bucket::TokenBucketManager;
use crate::peers::acl_filter::AclFilter;
use crate::proto::acl::GroupIdentity;
use crate::proto::cli::PeerConnInfo;
use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb};
use crate::proto::peer_rpc::PeerGroupInfo;
use crossbeam::atomic::AtomicCell;
use super::{
@@ -351,6 +353,7 @@ impl GlobalCtx {
}
pub fn set_quic_proxy_port(&self, port: Option<u16>) {
self.acl_filter.set_quic_udp_port(port.unwrap_or(0));
self.quic_proxy_port.store(port);
}
@@ -365,6 +368,37 @@ impl GlobalCtx {
pub fn get_acl_filter(&self) -> &Arc<AclFilter> {
&self.acl_filter
}
pub fn get_acl_groups(&self, peer_id: PeerId) -> Vec<PeerGroupInfo> {
use std::collections::HashSet;
self.config
.get_acl()
.and_then(|acl| acl.acl_v1)
.and_then(|acl_v1| acl_v1.group)
.map_or_else(Vec::new, |group| {
let memberships: HashSet<_> = group.members.iter().collect();
group
.declares
.iter()
.filter(|g| memberships.contains(&g.group_name))
.map(|g| {
PeerGroupInfo::generate_with_proof(
g.group_name.clone(),
g.group_secret.clone(),
peer_id,
)
})
.collect()
})
}
pub fn get_acl_group_declarations(&self) -> Vec<GroupIdentity> {
self.config
.get_acl()
.and_then(|acl| acl.acl_v1)
.and_then(|acl_v1| acl_v1.group)
.map_or_else(Vec::new, |group| group.declares.to_vec())
}
}
#[cfg(test)]

View File

@@ -440,12 +440,13 @@ impl KcpProxyDst {
}
}
#[tracing::instrument(ret)]
#[tracing::instrument(ret, skip(route))]
async fn handle_one_in_stream(
kcp_stream: KcpStream,
global_ctx: ArcGlobalCtx,
proxy_entries: Arc<DashMap<ConnId, TcpProxyEntry>>,
cidr_set: Arc<CidrSet>,
route: Arc<(dyn crate::peers::route_trait::Route + Send + Sync + 'static)>,
) -> Result<()> {
let mut conn_data = kcp_stream.conn_data().clone();
let parsed_conn_data = KcpConnData::decode(&mut conn_data)
@@ -481,6 +482,13 @@ impl KcpProxyDst {
proxy_entries.remove(&conn_id);
}
let src_ip = src_socket.ip();
let dst_ip = dst_socket.ip();
let (src_groups, dst_groups) = tokio::join!(
route.get_peer_groups_by_ip(&src_ip),
route.get_peer_groups_by_ip(&dst_ip)
);
let send_to_self =
Some(dst_socket.ip()) == global_ctx.get_ipv4().map(|ip| IpAddr::V4(ip.address()));
@@ -491,12 +499,14 @@ impl KcpProxyDst {
let acl_handler = ProxyAclHandler {
acl_filter: global_ctx.get_acl_filter().clone(),
packet_info: PacketInfo {
src_ip: src_socket.ip(),
dst_ip: dst_socket.ip(),
src_ip,
dst_ip,
src_port: Some(src_socket.port()),
dst_port: Some(dst_socket.port()),
protocol: Protocol::Tcp,
packet_size: conn_data.len(),
src_groups,
dst_groups,
},
chain_type: if send_to_self {
ChainType::Inbound
@@ -530,6 +540,7 @@ impl KcpProxyDst {
let global_ctx = self.peer_manager.get_global_ctx().clone();
let proxy_entries = self.proxy_entries.clone();
let cidr_set = self.cidr_set.clone();
let route = Arc::new(self.peer_manager.get_route());
self.tasks.spawn(async move {
while let Ok(conn) = kcp_endpoint.accept().await {
let stream = KcpStream::new(&kcp_endpoint, conn)
@@ -539,9 +550,16 @@ impl KcpProxyDst {
let global_ctx = global_ctx.clone();
let proxy_entries = proxy_entries.clone();
let cidr_set = cidr_set.clone();
let route = route.clone();
tokio::spawn(async move {
let _ = Self::handle_one_in_stream(stream, global_ctx, proxy_entries, cidr_set)
.await;
let _ = Self::handle_one_in_stream(
stream,
global_ctx,
proxy_entries,
cidr_set,
route,
)
.await;
});
}
});

View File

@@ -247,10 +247,14 @@ pub struct QUICProxyDst {
endpoint: Arc<quinn::Endpoint>,
proxy_entries: Arc<DashMap<SocketAddr, TcpProxyEntry>>,
tasks: Arc<Mutex<JoinSet<()>>>,
route: Arc<(dyn crate::peers::route_trait::Route + Send + Sync + 'static)>,
}
impl QUICProxyDst {
pub fn new(global_ctx: ArcGlobalCtx) -> Result<Self> {
pub fn new(
global_ctx: ArcGlobalCtx,
route: Arc<(dyn crate::peers::route_trait::Route + Send + Sync + 'static)>,
) -> Result<Self> {
let _g = global_ctx.net_ns.guard();
let (endpoint, _) = make_server_endpoint("0.0.0.0:0".parse().unwrap())
.map_err(|e| anyhow::anyhow!("failed to create QUIC endpoint: {}", e))?;
@@ -261,6 +265,7 @@ impl QUICProxyDst {
endpoint: Arc::new(endpoint),
proxy_entries: Arc::new(DashMap::new()),
tasks,
route,
})
}
@@ -270,6 +275,7 @@ impl QUICProxyDst {
let ctx = self.global_ctx.clone();
let cidr_set = Arc::new(CidrSet::new(ctx.clone()));
let proxy_entries = self.proxy_entries.clone();
let route = self.route.clone();
let task = async move {
loop {
@@ -289,6 +295,7 @@ impl QUICProxyDst {
ctx.clone(),
cidr_set.clone(),
proxy_entries.clone(),
route.clone(),
));
}
None => {
@@ -312,6 +319,7 @@ impl QUICProxyDst {
ctx: Arc<GlobalCtx>,
cidr_set: Arc<CidrSet>,
proxy_entries: Arc<DashMap<SocketAddr, TcpProxyEntry>>,
route: Arc<(dyn crate::peers::route_trait::Route + Send + Sync + 'static)>,
) {
let remote_addr = conn.remote_address();
defer!(
@@ -319,7 +327,14 @@ impl QUICProxyDst {
);
let ret = timeout(
std::time::Duration::from_secs(10),
Self::handle_connection(conn, ctx, cidr_set, remote_addr, proxy_entries.clone()),
Self::handle_connection(
conn,
ctx,
cidr_set,
remote_addr,
proxy_entries.clone(),
route,
),
)
.await;
@@ -348,6 +363,7 @@ impl QUICProxyDst {
cidr_set: Arc<CidrSet>,
proxy_entry_key: SocketAddr,
proxy_entries: Arc<DashMap<SocketAddr, TcpProxyEntry>>,
route: Arc<(dyn crate::peers::route_trait::Route + Send + Sync + 'static)>,
) -> Result<(QUICStream, TcpStream, ProxyAclHandler)> {
let conn = incoming.await.with_context(|| "accept failed")?;
let addr = conn.remote_address();
@@ -379,6 +395,13 @@ impl QUICProxyDst {
dst_socket.set_ip(real_ip);
}
let src_ip = addr.ip();
let dst_ip = *dst_socket.ip();
let (src_groups, dst_groups) = tokio::join!(
route.get_peer_groups_by_ip(&src_ip),
route.get_peer_groups_by_ipv4(&dst_ip)
);
let send_to_self = Some(*dst_socket.ip()) == ctx.get_ipv4().map(|ip| ip.address());
if send_to_self && ctx.no_tun() {
dst_socket = format!("127.0.0.1:{}", dst_socket.port()).parse().unwrap();
@@ -398,12 +421,14 @@ impl QUICProxyDst {
let acl_handler = ProxyAclHandler {
acl_filter: ctx.get_acl_filter().clone(),
packet_info: PacketInfo {
src_ip: addr.ip(),
dst_ip: (*dst_socket.ip()).into(),
src_ip,
dst_ip: dst_ip.into(),
src_port: Some(addr.port()),
dst_port: Some(dst_socket.port()),
protocol: Protocol::Tcp,
packet_size: len as usize,
src_groups,
dst_groups,
},
chain_type: if send_to_self {
ChainType::Inbound

View File

@@ -531,7 +531,8 @@ impl Instance {
return Ok(());
}
let quic_dst = QUICProxyDst::new(self.global_ctx.clone())?;
let route = Arc::new(self.peer_manager.get_route());
let quic_dst = QUICProxyDst::new(self.global_ctx.clone(), route)?;
quic_dst.start().await?;
self.global_ctx
.set_quic_proxy_port(Some(quic_dst.local_addr()?.port()));

View File

@@ -1,5 +1,5 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicU16, Ordering};
use std::{
net::IpAddr,
sync::{atomic::AtomicBool, Arc},
@@ -25,6 +25,7 @@ pub struct AclFilter {
// Use ArcSwap for lock-free atomic replacement during hot reload
acl_processor: ArcSwap<AclProcessor>,
acl_enabled: Arc<AtomicBool>,
quic_udp_port: AtomicU16,
}
impl Default for AclFilter {
@@ -38,6 +39,7 @@ impl AclFilter {
Self {
acl_processor: ArcSwap::from(Arc::new(AclProcessor::new(Acl::default()))),
acl_enabled: Arc::new(AtomicBool::new(false)),
quic_udp_port: AtomicU16::new(0),
}
}
@@ -88,7 +90,11 @@ impl AclFilter {
}
/// Extract packet information for ACL processing
fn extract_packet_info(&self, packet: &ZCPacket) -> Option<PacketInfo> {
fn extract_packet_info(
&self,
packet: &ZCPacket,
route: &(dyn super::route_trait::Route + Send + Sync + 'static),
) -> Option<PacketInfo> {
let payload = packet.payload();
let src_ip;
@@ -155,6 +161,15 @@ impl AclFilter {
_ => Protocol::Unspecified,
};
let src_groups = packet
.get_src_peer_id()
.map(|peer_id| route.get_peer_groups(peer_id))
.unwrap_or_else(|| Arc::new(Vec::new()));
let dst_groups = packet
.get_dst_peer_id()
.map(|peer_id| route.get_peer_groups(peer_id))
.unwrap_or_else(|| Arc::new(Vec::new()));
Some(PacketInfo {
src_ip,
dst_ip,
@@ -162,6 +177,8 @@ impl AclFilter {
dst_port,
protocol: acl_protocol,
packet_size: payload.len(),
src_groups,
dst_groups,
})
}
@@ -181,6 +198,8 @@ impl AclFilter {
dst_ip = %packet_info.dst_ip,
src_port = packet_info.src_port,
dst_port = packet_info.dst_port,
src_group = packet_info.src_groups.join(","),
dst_group = packet_info.dst_groups.join(","),
protocol = ?packet_info.protocol,
action = ?result.action,
rule = result.matched_rule_str().as_deref().unwrap_or("unknown"),
@@ -226,6 +245,40 @@ impl AclFilter {
processor.increment_stat(AclStatKey::PacketsTotal);
}
fn check_is_quic_packet(
&self,
packet_info: &PacketInfo,
my_ipv4: &Option<Ipv4Addr>,
my_ipv6: &Option<Ipv6Addr>,
) -> bool {
if packet_info.protocol != Protocol::Udp {
return false;
}
let quic_port = self.get_quic_udp_port();
if quic_port == 0 {
return false;
}
// quic input
if packet_info.dst_port == Some(quic_port)
&& (packet_info.dst_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED)
|| packet_info.dst_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED))
{
return true;
}
// quic output
if packet_info.src_port == Some(quic_port)
&& (packet_info.src_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED)
|| packet_info.src_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED))
{
return true;
}
false
}
/// Common ACL processing logic
pub fn process_packet_with_acl(
&self,
@@ -233,6 +286,7 @@ impl AclFilter {
is_in: bool,
my_ipv4: Option<Ipv4Addr>,
my_ipv6: Option<Ipv6Addr>,
route: &(dyn super::route_trait::Route + Send + Sync + 'static),
) -> bool {
if !self.acl_enabled.load(Ordering::Relaxed) {
return true;
@@ -243,7 +297,7 @@ impl AclFilter {
}
// Extract packet information
let packet_info = match self.extract_packet_info(packet) {
let packet_info = match self.extract_packet_info(packet, route) {
Some(info) => info,
None => {
tracing::warn!(
@@ -256,6 +310,10 @@ impl AclFilter {
}
};
if self.check_is_quic_packet(&packet_info, &my_ipv4, &my_ipv6) {
return true;
}
let chain_type = if is_in {
if packet_info.dst_ip == my_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED)
|| packet_info.dst_ip == my_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED)
@@ -292,4 +350,12 @@ impl AclFilter {
}
}
}
pub fn get_quic_udp_port(&self) -> u16 {
self.quic_udp_port.load(Ordering::Relaxed)
}
pub fn set_quic_udp_port(&self, port: u16) {
self.quic_udp_port.store(port, Ordering::Relaxed);
}
}

View File

@@ -32,7 +32,7 @@ use crate::{
peer_conn::PeerConn,
peer_rpc::PeerRpcManagerTransport,
recv_packet_from_chan,
route_trait::{ForeignNetworkRouteInfoMap, NextHopPolicy, RouteInterface},
route_trait::{ForeignNetworkRouteInfoMap, MockRoute, NextHopPolicy, RouteInterface},
PeerPacketFilter,
},
proto::{
@@ -634,6 +634,7 @@ impl PeerManager {
let acl_filter = self.global_ctx.get_acl_filter().clone();
let global_ctx = self.global_ctx.clone();
let stats_mgr = self.global_ctx.stats_manager().clone();
let route = self.get_route();
let label_set =
LabelSet::new().with_label_type(LabelType::NetworkName(global_ctx.get_network_name()));
@@ -737,6 +738,7 @@ impl PeerManager {
true,
global_ctx.get_ipv4().map(|x| x.address()),
global_ctx.get_ipv6().map(|x| x.address()),
&route,
) {
continue;
}
@@ -914,7 +916,7 @@ impl PeerManager {
pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> {
match &self.route_algo_inst {
RouteAlgoInst::Ospf(route) => Box::new(route.clone()),
RouteAlgoInst::None => panic!("no route"),
RouteAlgoInst::None => Box::new(MockRoute {}),
}
}
@@ -960,11 +962,13 @@ impl PeerManager {
}
async fn run_nic_packet_process_pipeline(&self, data: &mut ZCPacket) {
if !self
.global_ctx
.get_acl_filter()
.process_packet_with_acl(data, false, None, None)
{
if !self.global_ctx.get_acl_filter().process_packet_with_acl(
data,
false,
None,
None,
&self.get_route(),
) {
return;
}

View File

@@ -1,5 +1,7 @@
use std::{
collections::{BTreeMap, BTreeSet},
collections::{
HashMap, {BTreeMap, BTreeSet},
},
fmt::Debug,
net::{Ipv4Addr, Ipv6Addr},
sync::{
@@ -33,6 +35,7 @@ use crate::{
},
peers::route_trait::{Route, RouteInterfaceBox},
proto::{
acl::GroupIdentity,
common::{Ipv4Inet, NatType, StunInfo},
peer_rpc::{
route_foreign_network_infos, route_foreign_network_summary,
@@ -127,6 +130,7 @@ impl RoutePeerInfo {
network_length: 24,
quic_port: None,
ipv6_addr: None,
groups: Vec::new(),
}
}
@@ -168,6 +172,8 @@ impl RoutePeerInfo {
quic_port: global_ctx.get_quic_proxy_port().map(|x| x as u32),
ipv6_addr: global_ctx.get_ipv6().map(|x| x.into()),
groups: global_ctx.get_acl_groups(my_peer_id),
};
let need_update_periodically = if let Ok(Ok(d)) =
@@ -296,6 +302,8 @@ struct SyncedRouteInfo {
raw_peer_infos: DashMap<PeerId, DynamicMessage>,
conn_map: DashMap<PeerId, (BTreeSet<PeerId>, AtomicVersion)>,
foreign_network: DashMap<ForeignNetworkRouteInfoKey, ForeignNetworkRouteInfoEntry>,
group_trust_map: DashMap<PeerId, HashMap<String, Vec<u8>>>,
group_trust_map_cache: DashMap<PeerId, Arc<Vec<String>>>, // cache for group trust map, should sync with group_trust_map
version: AtomicVersion,
}
@@ -306,6 +314,7 @@ impl Debug for SyncedRouteInfo {
.field("peer_infos", &self.peer_infos)
.field("conn_map", &self.conn_map)
.field("foreign_network", &self.foreign_network)
.field("group_trust_map", &self.group_trust_map)
.field("version", &self.version.get())
.finish()
}
@@ -324,6 +333,8 @@ impl SyncedRouteInfo {
self.raw_peer_infos.remove(&peer_id);
self.conn_map.remove(&peer_id);
self.foreign_network.retain(|k, _| k.peer_id != peer_id);
self.group_trust_map.remove(&peer_id);
self.group_trust_map_cache.remove(&peer_id);
self.version.inc();
}
@@ -613,6 +624,85 @@ impl SyncedRouteInfo {
self.is_peer_bidirectly_connected(src_peer_id, dst_peer_id)
|| self.is_peer_bidirectly_connected(dst_peer_id, src_peer_id)
}
fn verify_and_update_group_trusts(
&self,
peer_infos: &[RoutePeerInfo],
local_group_declarations: &[GroupIdentity],
) {
let local_group_declarations = local_group_declarations
.iter()
.map(|g| (g.group_name.as_str(), g.group_secret.as_str()))
.collect::<std::collections::HashMap<&str, &str>>();
let verify_groups = |old_trusted_groups: Option<&HashMap<String, Vec<u8>>>,
info: &RoutePeerInfo|
-> HashMap<String, Vec<u8>> {
let mut trusted_groups_for_peer: HashMap<String, Vec<u8>> = HashMap::new();
for group_proof in &info.groups {
let name = &group_proof.group_name;
let proof_bytes = group_proof.group_proof.clone();
// If we already trusted this group and the proof hasn't changed, reuse it.
if old_trusted_groups
.and_then(|g| g.get(name))
.map(|old| old == &proof_bytes)
.unwrap_or(false)
{
trusted_groups_for_peer.insert(name.clone(), proof_bytes);
continue;
}
if let Some(&local_secret) =
local_group_declarations.get(group_proof.group_name.as_str())
{
if group_proof.verify(local_secret, info.peer_id) {
trusted_groups_for_peer.insert(name.clone(), proof_bytes);
} else {
tracing::warn!(
peer_id = info.peer_id,
group = %group_proof.group_name,
"Group proof verification failed"
);
}
}
}
trusted_groups_for_peer
};
for info in peer_infos {
match self.group_trust_map.entry(info.peer_id) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
let old_trusted_groups = entry.get().clone();
let trusted_groups_for_peer = verify_groups(Some(&old_trusted_groups), info);
if trusted_groups_for_peer.is_empty() {
entry.remove();
self.group_trust_map_cache.remove(&info.peer_id);
} else {
self.group_trust_map_cache.insert(
info.peer_id,
Arc::new(trusted_groups_for_peer.keys().cloned().collect()),
);
*entry.get_mut() = trusted_groups_for_peer;
}
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
let trusted_groups_for_peer = verify_groups(None, info);
if !trusted_groups_for_peer.is_empty() {
self.group_trust_map_cache.insert(
info.peer_id,
Arc::new(trusted_groups_for_peer.keys().cloned().collect()),
);
entry.insert(trusted_groups_for_peer);
}
}
}
}
}
}
type PeerGraph = Graph<PeerId, usize, Directed>;
@@ -1154,6 +1244,8 @@ impl PeerRouteServiceImpl {
raw_peer_infos: DashMap::new(),
conn_map: DashMap::new(),
foreign_network: DashMap::new(),
group_trust_map: DashMap::new(),
group_trust_map_cache: DashMap::new(),
version: AtomicVersion::new(),
},
cached_local_conn_map: std::sync::Mutex::new(RouteConnBitmap::new()),
@@ -1679,6 +1771,14 @@ impl PeerRouteServiceImpl {
fn get_peer_info_last_update(&self) -> std::time::Instant {
self.peer_info_last_update.load()
}
fn get_peer_groups(&self, peer_id: PeerId) -> Arc<Vec<String>> {
self.synced_route_info
.group_trust_map_cache
.get(&peer_id)
.map(|groups| groups.value().clone())
.unwrap_or_default()
}
}
impl Drop for PeerRouteServiceImpl {
@@ -2016,6 +2116,12 @@ impl RouteSessionManager {
peer_infos,
raw_peer_infos.as_ref().unwrap(),
)?;
service_impl
.synced_route_info
.verify_and_update_group_trusts(
peer_infos,
&service_impl.global_ctx.get_acl_group_declarations(),
);
session.update_dst_saved_peer_info_version(peer_infos);
need_update_route_table = true;
}
@@ -2364,6 +2470,10 @@ impl Route for PeerRoute {
async fn get_peer_info_last_update_time(&self) -> Instant {
self.service_impl.get_peer_info_last_update()
}
fn get_peer_groups(&self, peer_id: PeerId) -> Arc<Vec<String>> {
self.service_impl.get_peer_groups(peer_id)
}
}
impl PeerPacketFilter for Arc<PeerRoute> {}

View File

@@ -122,9 +122,58 @@ pub trait Route {
async fn get_peer_info_last_update_time(&self) -> std::time::Instant;
fn get_peer_groups(&self, peer_id: PeerId) -> Arc<Vec<String>>;
async fn get_peer_groups_by_ip(&self, ip: &std::net::IpAddr) -> Arc<Vec<String>> {
match self.get_peer_id_by_ip(ip).await {
Some(peer_id) => self.get_peer_groups(peer_id),
None => Arc::new(Vec::new()),
}
}
async fn get_peer_groups_by_ipv4(&self, ipv4: &Ipv4Addr) -> Arc<Vec<String>> {
match self.get_peer_id_by_ipv4(ipv4).await {
Some(peer_id) => self.get_peer_groups(peer_id),
None => Arc::new(Vec::new()),
}
}
async fn dump(&self) -> String {
"this route implementation does not support dump".to_string()
}
}
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
pub struct MockRoute {}
#[async_trait::async_trait]
impl Route for MockRoute {
async fn open(&self, _interface: RouteInterfaceBox) -> Result<u8, ()> {
panic!("mock route")
}
async fn close(&self) {
panic!("mock route")
}
async fn get_next_hop(&self, _peer_id: PeerId) -> Option<PeerId> {
panic!("mock route")
}
async fn list_routes(&self) -> Vec<crate::proto::cli::Route> {
panic!("mock route")
}
async fn get_peer_info(&self, _peer_id: PeerId) -> Option<RoutePeerInfo> {
panic!("mock route")
}
async fn get_peer_info_last_update_time(&self) -> std::time::Instant {
panic!("mock route")
}
fn get_peer_groups(&self, _peer_id: PeerId) -> Arc<Vec<String>> {
panic!("mock route")
}
}

View File

@@ -67,6 +67,10 @@ message Rule {
// Connection tracking
bool stateful = 13; // Enable connection tracking
// Group matching criteria
repeated string source_groups = 14;
repeated string destination_groups = 15;
}
// Rule chain with metadata and optimization hints
@@ -84,7 +88,20 @@ message Chain {
Action default_action = 6;
}
message AclV1 { repeated Chain chains = 1; }
message GroupInfo {
repeated GroupIdentity declares = 1;
repeated string members = 2;
}
message GroupIdentity {
string group_name = 1;
string group_secret = 2;
}
message AclV1 {
repeated Chain chains = 1;
GroupInfo group = 2;
}
enum ConnState {
New = 0;

View File

@@ -25,6 +25,8 @@ message RoutePeerInfo {
optional uint32 quic_port = 14;
optional common.Ipv6Inet ipv6_addr = 15;
repeated PeerGroupInfo groups = 16;
}
message PeerIdVersion {
@@ -70,6 +72,11 @@ message RouteForeignNetworkSummary {
map<uint32, Info> info_map = 1;
}
message PeerGroupInfo {
string group_name = 1;
bytes group_proof = 2;
}
message SyncRouteInfoRequest {
uint32 my_peer_id = 1;
uint64 my_session_id = 2;

View File

@@ -1 +1,245 @@
use hmac::{Hmac, Mac};
use sha2::Sha256;
use crate::common::PeerId;
include!(concat!(env!("OUT_DIR"), "/peer_rpc.rs"));
impl PeerGroupInfo {
pub fn generate_with_proof(group_name: String, group_secret: String, peer_id: PeerId) -> Self {
let mut mac = Hmac::<Sha256>::new_from_slice(group_secret.as_bytes())
.expect("HMAC can take key of any size");
let mut data_to_sign = group_name.as_bytes().to_vec();
data_to_sign.push(0x00); // Add a delimiter byte
data_to_sign.extend_from_slice(&peer_id.to_be_bytes());
mac.update(&data_to_sign);
let proof = mac.finalize().into_bytes().to_vec();
PeerGroupInfo {
group_name,
group_proof: proof,
}
}
pub fn verify(&self, group_secret: &str, peer_id: PeerId) -> bool {
let mut verifier = Hmac::<Sha256>::new_from_slice(group_secret.as_bytes())
.expect("HMAC can take key of any size");
let mut data_to_sign = self.group_name.as_bytes().to_vec();
data_to_sign.push(0x00); // Add a delimiter byte
data_to_sign.extend_from_slice(&peer_id.to_be_bytes());
verifier.update(&data_to_sign);
verifier.verify_slice(&self.group_proof).is_ok()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_peer_group_info_new() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name.clone(), group_secret, peer_id);
assert_eq!(peer_group_info.group_name, group_name);
assert!(!peer_group_info.group_proof.is_empty());
// HMAC-SHA256 produces a 32-byte output
assert_eq!(peer_group_info.group_proof.len(), 32);
}
#[test]
fn test_peer_group_info_verify_valid() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name, group_secret.clone(), peer_id);
// Verification should succeed using the same secret and peer_id
assert!(peer_group_info.verify(&group_secret, peer_id));
}
#[test]
fn test_peer_group_info_verify_invalid_secret() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let peer_group_info = PeerGroupInfo::generate_with_proof(group_name, group_secret, peer_id);
// Verification should fail with a wrong secret
assert!(!peer_group_info.verify("wrong_secret", peer_id));
}
#[test]
fn test_peer_group_info_verify_invalid_peer_id() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name, group_secret.clone(), peer_id);
// Verification should fail with a wrong peer_id
assert!(!peer_group_info.verify(&group_secret, 999u32));
}
#[test]
fn test_peer_group_info_different_groups_different_proofs() {
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let group1 =
PeerGroupInfo::generate_with_proof("group1".to_string(), group_secret.clone(), peer_id);
let group2 =
PeerGroupInfo::generate_with_proof("group2".to_string(), group_secret, peer_id);
// Different group names should produce different proofs
assert_ne!(group1.group_proof, group2.group_proof);
}
#[test]
fn test_peer_group_info_same_params_same_proof() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let group1 =
PeerGroupInfo::generate_with_proof(group_name.clone(), group_secret.clone(), peer_id);
let group2 = PeerGroupInfo::generate_with_proof(group_name, group_secret, peer_id);
// Same parameters should produce the same proof
assert_eq!(group1.group_proof, group2.group_proof);
}
#[test]
fn test_peer_group_info_empty_group_name() {
let group_name = "".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name.clone(), group_secret.clone(), peer_id);
assert_eq!(peer_group_info.group_name, group_name);
assert!(peer_group_info.verify(&group_secret, peer_id));
}
#[test]
fn test_peer_group_info_empty_secret() {
let group_name = "test_group".to_string();
let group_secret = "".to_string();
let peer_id = 42u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name, group_secret.clone(), peer_id);
assert!(peer_group_info.verify(&group_secret, peer_id));
}
#[test]
fn test_peer_group_info_unicode_group_name() {
let group_name = "测试组🚀".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name.clone(), group_secret.clone(), peer_id);
assert_eq!(peer_group_info.group_name, group_name);
assert!(peer_group_info.verify(&group_secret, peer_id));
}
#[test]
fn test_peer_group_info_unicode_secret() {
let group_name = "test_group".to_string();
let group_secret = "密码123🔐".to_string();
let peer_id = 42u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name, group_secret.clone(), peer_id);
assert!(peer_group_info.verify(&group_secret, peer_id));
}
#[test]
fn test_peer_group_info_zero_peer_id() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 0u32;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name, group_secret.clone(), peer_id);
assert!(peer_group_info.verify(&group_secret, peer_id));
}
#[test]
fn test_peer_group_info_max_peer_id() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = u32::MAX;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name, group_secret.clone(), peer_id);
assert!(peer_group_info.verify(&group_secret, peer_id));
}
#[test]
#[ignore]
fn perf_test_generate_with_proof() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let iterations = 100000;
let start = std::time::Instant::now();
for _ in 0..iterations {
let _ = PeerGroupInfo::generate_with_proof(
group_name.clone(),
group_secret.clone(),
peer_id,
);
}
let duration = start.elapsed();
println!(
"generate_with_proof took {:?} for {} iterations",
duration, iterations
);
println!("Avg time per iteration: {:?}", duration / iterations as u32);
}
#[test]
#[ignore]
fn perf_test_verify() {
let group_name = "test_group".to_string();
let group_secret = "secret123".to_string();
let peer_id = 42u32;
let iterations = 100000;
let peer_group_info =
PeerGroupInfo::generate_with_proof(group_name.clone(), group_secret.clone(), peer_id);
let start = std::time::Instant::now();
for _ in 0..iterations {
assert!(peer_group_info.verify(&group_secret, peer_id));
}
let duration = start.elapsed();
println!("verify took {:?} for {} iterations", duration, iterations);
println!("Avg time per iteration: {:?}", duration / iterations as u32);
}
}

View File

@@ -1818,3 +1818,244 @@ pub async fn acl_rule_test_subnet_proxy(
drop_insts(insts).await;
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
pub async fn acl_group_based_test(
#[values("tcp", "udp")] protocol: &str,
#[values(true, false)] enable_kcp_proxy: bool,
#[values(true, false)] enable_quic_proxy: bool,
) {
use crate::tunnel::{
common::tests::_tunnel_pingpong_netns_with_timeout,
tcp::{TcpTunnelConnector, TcpTunnelListener},
udp::{UdpTunnelConnector, UdpTunnelListener},
TunnelConnector, TunnelListener,
};
use rand::Rng;
// 构造 ACL 配置,包含组信息
use crate::proto::acl::*;
// 设置组信息
let group_declares = vec![
GroupIdentity {
group_name: "admin".to_string(),
group_secret: "admin-secret".to_string(),
},
GroupIdentity {
group_name: "user".to_string(),
group_secret: "user-secret".to_string(),
},
];
let mut chain = Chain {
name: "group_acl_test".to_string(),
chain_type: ChainType::Inbound as i32,
enabled: true,
default_action: Action::Drop as i32,
..Default::default()
};
// 规则1: 允许admin组访问所有端口
let admin_allow_rule = Rule {
name: "allow_admin_all".to_string(),
priority: 300,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
source_groups: vec!["admin".to_string()],
stateful: true,
..Default::default()
};
chain.rules.push(admin_allow_rule);
// 规则2: 允许user组访问8080端口
let user_8080_rule = Rule {
name: "allow_user_8080".to_string(),
priority: 200,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
source_groups: vec!["user".to_string()],
ports: vec!["8080".to_string()],
stateful: true,
..Default::default()
};
chain.rules.push(user_8080_rule);
let acl_admin = Acl {
acl_v1: Some(AclV1 {
group: Some(GroupInfo {
declares: group_declares.clone(),
members: vec!["admin".to_string()],
}),
..AclV1::default()
}),
};
let acl_user = Acl {
acl_v1: Some(AclV1 {
group: Some(GroupInfo {
declares: group_declares.clone(),
members: vec!["user".to_string()],
}),
..AclV1::default()
}),
};
let acl_target = Acl {
acl_v1: Some(AclV1 {
chains: vec![chain.clone()],
group: Some(GroupInfo {
declares: group_declares.clone(),
members: vec![],
}),
}),
};
let insts = init_three_node_ex(
protocol,
move |cfg| {
match cfg.get_inst_name().as_str() {
"inst1" => {
cfg.set_acl(Some(acl_admin.clone()));
}
"inst2" => {
cfg.set_acl(Some(acl_user.clone()));
}
"inst3" => {
cfg.set_acl(Some(acl_target.clone()));
}
_ => {}
}
let mut flags = cfg.get_flags();
flags.enable_kcp_proxy = enable_kcp_proxy;
flags.enable_quic_proxy = enable_quic_proxy;
cfg.set_flags(flags);
cfg
},
false,
)
.await;
println!("Testing group-based ACL rules...");
let make_listener = |port: u16| -> Box<dyn TunnelListener + Send + Sync + 'static> {
match protocol {
"tcp" => Box::new(TcpTunnelListener::new(
format!("tcp://0.0.0.0:{}", port).parse().unwrap(),
)),
"udp" => Box::new(UdpTunnelListener::new(
format!("udp://0.0.0.0:{}", port).parse().unwrap(),
)),
_ => panic!("unsupported protocol: {}", protocol),
}
};
let make_connector = |port: u16| -> Box<dyn TunnelConnector + Send + Sync + 'static> {
match protocol {
"tcp" => Box::new(TcpTunnelConnector::new(
format!("tcp://10.144.144.3:{}", port).parse().unwrap(),
)),
"udp" => Box::new(UdpTunnelConnector::new(
format!("udp://10.144.144.3:{}", port).parse().unwrap(),
)),
_ => panic!("unsupported protocol: {}", protocol),
}
};
// 构造测试数据
let mut buf = vec![0; 32];
rand::thread_rng().fill(&mut buf[..]);
// 测试1: inst1 (admin组) 访问8080 - 应该成功
let result = _tunnel_pingpong_netns_with_timeout(
make_listener(8080),
make_connector(8080),
NetNS::new(Some("net_c".into())),
NetNS::new(Some("net_a".into())),
buf.clone(),
std::time::Duration::from_millis(30000),
)
.await;
assert!(
result.is_ok(),
"Admin group access to port 8080 should be allowed (protocol={})",
protocol
);
println!(
"✓ Admin group access to port 8080 succeeded ({})\n",
protocol
);
// 测试2: inst1 (admin组) 访问8081 - 应该成功
let result = _tunnel_pingpong_netns_with_timeout(
make_listener(8081),
make_connector(8081),
NetNS::new(Some("net_c".into())),
NetNS::new(Some("net_a".into())),
buf.clone(),
std::time::Duration::from_millis(30000),
)
.await;
assert!(
result.is_ok(),
"Admin group access to port 8081 should be allowed (protocol={})",
protocol
);
println!(
"✓ Admin group access to port 8081 succeeded ({})\n",
protocol
);
// 测试3: inst2 (user组) 访问8080 - 应该成功
let result = _tunnel_pingpong_netns_with_timeout(
make_listener(8080),
make_connector(8080),
NetNS::new(Some("net_c".into())),
NetNS::new(Some("net_b".into())),
buf.clone(),
std::time::Duration::from_millis(30000),
)
.await;
assert!(
result.is_ok(),
"User group access to port 8080 should be allowed (protocol={})",
protocol
);
println!(
"✓ User group access to port 8080 succeeded ({})\n",
protocol
);
// 测试4: inst2 (user组) 访问8081 - 应该失败
let result = _tunnel_pingpong_netns_with_timeout(
make_listener(8081),
make_connector(8081),
NetNS::new(Some("net_c".into())),
NetNS::new(Some("net_b".into())),
buf.clone(),
std::time::Duration::from_millis(200),
)
.await;
assert!(
result.is_err(),
"User group access to port 8081 should be blocked (protocol={})",
protocol
);
println!(
"✓ User group access to port 8081 blocked as expected ({})\n",
protocol
);
let stats = insts[2].get_global_ctx().get_acl_filter().get_stats();
println!("ACL stats after group {} tests: {:?}", protocol, stats);
println!("✓ All group-based ACL tests completed successfully");
drop_insts(insts).await;
}

View File

@@ -560,6 +560,45 @@ pub mod tests {
}
}
pub(crate) async fn _tunnel_pingpong_netns_with_timeout<L, C>(
listener: L,
connector: C,
l_netns: NetNS,
c_netns: NetNS,
buf: Vec<u8>,
timeout: std::time::Duration,
) -> Result<(), anyhow::Error>
where
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
{
let handle = tokio::spawn(async move {
_tunnel_pingpong_netns(listener, connector, l_netns, c_netns, buf).await;
});
match tokio::time::timeout(timeout, handle).await {
Ok(join_res) => match join_res {
Ok(_) => Ok(()),
Err(join_err) => {
if join_err.is_panic() {
let payload = join_err.into_panic();
let msg = match payload.downcast::<String>() {
Ok(s) => *s,
Err(payload) => match payload.downcast::<&str>() {
Ok(s) => (*s).to_string(),
Err(_) => "non-string panic payload".to_string(),
},
};
Err(anyhow::anyhow!("task panicked: {}", msg))
} else {
Err(anyhow::anyhow!("task cancelled"))
}
}
},
Err(elapsed) => Err(elapsed.into()),
}
}
pub(crate) async fn _tunnel_bench<L, C>(listener: L, connector: C)
where
L: TunnelListener + Send + Sync + 'static,

View File

@@ -679,6 +679,14 @@ impl ZCPacket {
ZCPacketType::DummyTunnel,
)
}
pub fn get_src_peer_id(&self) -> Option<u32> {
self.peer_manager_header().map(|hdr| hdr.from_peer_id.get())
}
pub fn get_dst_peer_id(&self) -> Option<u32> {
self.peer_manager_header().map(|hdr| hdr.to_peer_id.get())
}
}
#[cfg(test)]

12
flake.lock generated
View File

@@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1753429684,
"narHash": "sha256-9h7+4/53cSfQ/uA3pSvCaBepmZaz/dLlLVJnbQ+SJjk=",
"lastModified": 1754725699,
"narHash": "sha256-iAcj9T/Y+3DBy2J0N+yF9XQQQ8IEb5swLFzs23CdP88=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "7fd36ee82c0275fb545775cc5e4d30542899511d",
"rev": "85dbfc7aaf52ecb755f87e577ddbe6dbbdbc1054",
"type": "github"
},
"original": {
@@ -48,11 +48,11 @@
]
},
"locked": {
"lastModified": 1753671061,
"narHash": "sha256-IU4eBWfe9h2QejJYST+EAlhg8a1H6mh9gbcmWgZ2/mQ=",
"lastModified": 1754966322,
"narHash": "sha256-7f/LH60DnjjQVKbXAsHIniGaU7ixVM7eWU3hyjT24YI=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "40065d17ee4dbec3ded8ca61236132aede843fab",
"rev": "7c13cec2e3828d964b9980d0ffd680bd8d4dce90",
"type": "github"
},
"original": {