cli for port forward and tcp whitelist (#1165)

This commit is contained in:
Sijie.Sun
2025-07-29 09:30:47 +08:00
committed by GitHub
parent 5514de1187
commit 2ec88da823
8 changed files with 828 additions and 171 deletions

View File

@@ -6,8 +6,9 @@ use std::{
time::{Duration, SystemTime, UNIX_EPOCH},
};
use crate::common::token_bucket::TokenBucket;
use crate::common::{config::ConfigLoader, global_ctx::ArcGlobalCtx, token_bucket::TokenBucket};
use crate::proto::acl::*;
use anyhow::Context as _;
use dashmap::DashMap;
use tokio::task::JoinSet;
@@ -993,6 +994,146 @@ impl AclStatKey {
}
}
pub struct AclRuleBuilder {
pub acl: Option<Acl>,
pub tcp_whitelist: Vec<String>,
pub udp_whitelist: Vec<String>,
pub whitelist_priority: Option<u32>,
}
impl AclRuleBuilder {
fn parse_port_list(port_list: &[String]) -> anyhow::Result<Vec<String>> {
let mut ports = Vec::new();
for port_spec in port_list {
if port_spec.contains('-') {
// Handle port range like "8000-9000"
let parts: Vec<&str> = port_spec.split('-').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec));
}
let start: u16 = parts[0]
.parse()
.with_context(|| format!("Invalid start port in range: {}", port_spec))?;
let end: u16 = parts[1]
.parse()
.with_context(|| format!("Invalid end port in range: {}", port_spec))?;
if start > end {
return Err(anyhow::anyhow!(
"Start port must be <= end port in range: {}",
port_spec
));
}
// acl can handle port range
ports.push(port_spec.clone());
} else {
// Handle single port
let port: u16 = port_spec
.parse()
.with_context(|| format!("Invalid port number: {}", port_spec))?;
ports.push(port.to_string());
}
}
Ok(ports)
}
fn generate_acl_from_whitelists(&mut self) -> anyhow::Result<()> {
if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() {
return Ok(());
}
// Create inbound chain for whitelist rules
let mut inbound_chain = Chain {
name: "inbound_whitelist".to_string(),
chain_type: ChainType::Inbound as i32,
description: "Auto-generated inbound whitelist from CLI".to_string(),
enabled: true,
rules: vec![],
default_action: Action::Drop as i32, // Default deny
};
let mut rule_priority = self.whitelist_priority.unwrap_or(1000u32);
// Add TCP whitelist rules
if !self.tcp_whitelist.is_empty() {
let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?;
let tcp_rule = Rule {
name: "tcp_whitelist".to_string(),
description: "Auto-generated TCP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Tcp as i32,
ports: tcp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: true,
};
inbound_chain.rules.push(tcp_rule);
rule_priority -= 1;
}
// Add UDP whitelist rules
if !self.udp_whitelist.is_empty() {
let udp_ports = Self::parse_port_list(&self.udp_whitelist)?;
let udp_rule = Rule {
name: "udp_whitelist".to_string(),
description: "Auto-generated UDP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Udp as i32,
ports: udp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: false,
};
inbound_chain.rules.push(udp_rule);
}
if self.acl.is_none() {
self.acl = Some(Acl::default());
}
let acl = self.acl.as_mut().unwrap();
if let Some(ref mut acl_v1) = acl.acl_v1 {
acl_v1.chains.push(inbound_chain);
} else {
acl.acl_v1 = Some(AclV1 {
chains: vec![inbound_chain],
});
}
Ok(())
}
fn do_build(mut self) -> anyhow::Result<Option<Acl>> {
self.generate_acl_from_whitelists()?;
Ok(self.acl.clone())
}
pub fn build(global_ctx: &ArcGlobalCtx) -> anyhow::Result<Option<Acl>> {
let builder = AclRuleBuilder {
acl: global_ctx.config.get_acl(),
tcp_whitelist: global_ctx.config.get_tcp_whitelist(),
udp_whitelist: global_ctx.config.get_udp_whitelist(),
whitelist_priority: None,
};
builder.do_build()
}
}
#[derive(Debug, Clone, Copy)]
pub enum AclStatType {
Total,

View File

@@ -122,6 +122,12 @@ pub trait ConfigLoader: Send + Sync {
fn get_acl(&self) -> Option<Acl>;
fn set_acl(&self, acl: Option<Acl>);
fn get_tcp_whitelist(&self) -> Vec<String>;
fn set_tcp_whitelist(&self, whitelist: Vec<String>);
fn get_udp_whitelist(&self) -> Vec<String>;
fn set_udp_whitelist(&self, whitelist: Vec<String>);
fn dump(&self) -> String;
}
@@ -230,7 +236,7 @@ pub struct VpnPortalConfig {
pub wireguard_listen: SocketAddr,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct PortForwardConfig {
pub bind_addr: SocketAddr,
pub dst_addr: SocketAddr,
@@ -299,6 +305,9 @@ struct Config {
flags_struct: Option<Flags>,
acl: Option<Acl>,
tcp_whitelist: Option<Vec<String>>,
udp_whitelist: Option<Vec<String>>,
}
#[derive(Debug, Clone)]
@@ -665,6 +674,32 @@ impl ConfigLoader for TomlConfigLoader {
self.config.lock().unwrap().acl = acl;
}
fn get_tcp_whitelist(&self) -> Vec<String> {
self.config
.lock()
.unwrap()
.tcp_whitelist
.clone()
.unwrap_or_default()
}
fn set_tcp_whitelist(&self, whitelist: Vec<String>) {
self.config.lock().unwrap().tcp_whitelist = Some(whitelist);
}
fn get_udp_whitelist(&self) -> Vec<String> {
self.config
.lock()
.unwrap()
.udp_whitelist
.clone()
.unwrap_or_default()
}
fn set_udp_whitelist(&self, whitelist: Vec<String>) {
self.config.lock().unwrap().udp_whitelist = Some(whitelist);
}
fn dump(&self) -> String {
let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap();
let default_flags_hashmap =

View File

@@ -22,23 +22,25 @@ use tokio::time::timeout;
use easytier::{
common::{
config::PortForwardConfig,
constants::EASYTIER_VERSION,
stun::{StunInfoCollector, StunInfoCollectorTrait},
},
proto::{
cli::{
list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, ConnectorManageRpc,
ConnectorManageRpcClientFactory, DumpRouteRequest, GetAclStatsRequest,
GetVpnPortalInfoRequest, ListConnectorRequest, ListForeignNetworkRequest,
ListGlobalForeignNetworkRequest, ListMappedListenerRequest, ListPeerRequest,
ListPeerResponse, ListRouteRequest, ListRouteResponse, ManageMappedListenerRequest,
MappedListenerManageAction, MappedListenerManageRpc,
MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc,
PeerManageRpcClientFactory, ShowNodeInfoRequest, TcpProxyEntryState,
list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, AddPortForwardRequest,
ConnectorManageRpc, ConnectorManageRpcClientFactory, DumpRouteRequest,
GetAclStatsRequest, GetVpnPortalInfoRequest, GetWhitelistRequest, ListConnectorRequest,
ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, ListMappedListenerRequest,
ListPeerRequest, ListPeerResponse, ListPortForwardRequest, ListRouteRequest,
ListRouteResponse, ManageMappedListenerRequest, MappedListenerManageAction,
MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc,
PeerManageRpcClientFactory, PortForwardManageRpc, PortForwardManageRpcClientFactory,
RemovePortForwardRequest, SetWhitelistRequest, ShowNodeInfoRequest, TcpProxyEntryState,
TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc,
VpnPortalRpcClientFactory,
},
common::NatType,
common::{NatType, SocketType},
peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory},
rpc_impl::standalone::StandAloneClient,
rpc_types::controller::BaseController,
@@ -96,6 +98,10 @@ enum SubCommand {
Proxy,
#[command(about = "show ACL rules statistics")]
Acl(AclArgs),
#[command(about = "manage port forwarding")]
PortForward(PortForwardArgs),
#[command(about = "manage TCP/UDP whitelist")]
Whitelist(WhitelistArgs),
#[command(about = t!("core_clap.generate_completions").to_string())]
GenAutocomplete { shell: Shell },
}
@@ -193,6 +199,62 @@ enum AclSubCommand {
Stats,
}
#[derive(Args, Debug)]
struct PortForwardArgs {
#[command(subcommand)]
sub_command: Option<PortForwardSubCommand>,
}
#[derive(Subcommand, Debug)]
enum PortForwardSubCommand {
/// Add port forward rule
Add {
#[arg(help = "Protocol (tcp/udp)")]
protocol: String,
#[arg(help = "Local bind address (e.g., 0.0.0.0:8080)")]
bind_addr: String,
#[arg(help = "Destination address (e.g., 10.1.1.1:80)")]
dst_addr: String,
},
/// Remove port forward rule
Remove {
#[arg(help = "Protocol (tcp/udp)")]
protocol: String,
#[arg(help = "Local bind address (e.g., 0.0.0.0:8080)")]
bind_addr: String,
#[arg(help = "Optional Destination address (e.g., 10.1.1.1:80)")]
dst_addr: Option<String>,
},
/// List port forward rules
List,
}
#[derive(Args, Debug)]
struct WhitelistArgs {
#[command(subcommand)]
sub_command: Option<WhitelistSubCommand>,
}
#[derive(Subcommand, Debug)]
enum WhitelistSubCommand {
/// Set TCP port whitelist
SetTcp {
#[arg(help = "TCP ports (e.g., 80,443,8000-9000)")]
ports: String,
},
/// Set UDP port whitelist
SetUdp {
#[arg(help = "UDP ports (e.g., 53,5000-6000)")]
ports: String,
},
/// Clear TCP whitelist
ClearTcp,
/// Clear UDP whitelist
ClearUdp,
/// Show current whitelist configuration
Show,
}
#[derive(Args, Debug)]
struct ServiceArgs {
#[arg(short, long, default_value = env!("CARGO_PKG_NAME"), help = "service name")]
@@ -340,6 +402,18 @@ impl CommandHandler<'_> {
.with_context(|| "failed to get vpn portal client")?)
}
async fn get_port_forward_manager_client(
&self,
) -> Result<Box<dyn PortForwardManageRpc<Controller = BaseController>>, Error> {
Ok(self
.client
.lock()
.unwrap()
.scoped_client::<PortForwardManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get port forward manager client")?)
}
async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
let client = self.get_peer_manager_client().await?;
let request = ListPeerRequest::default();
@@ -788,6 +862,265 @@ impl CommandHandler<'_> {
}
Ok(url)
}
async fn handle_port_forward_add(
&self,
protocol: &str,
bind_addr: &str,
dst_addr: &str,
) -> Result<(), Error> {
let bind_addr: std::net::SocketAddr = bind_addr
.parse()
.with_context(|| format!("Invalid bind address: {}", bind_addr))?;
let dst_addr: std::net::SocketAddr = dst_addr
.parse()
.with_context(|| format!("Invalid destination address: {}", dst_addr))?;
if protocol != "tcp" && protocol != "udp" {
return Err(anyhow::anyhow!("Protocol must be 'tcp' or 'udp'"));
}
let client = self.get_port_forward_manager_client().await?;
let request = AddPortForwardRequest {
cfg: Some(
PortForwardConfig {
proto: protocol.to_string(),
bind_addr: bind_addr.into(),
dst_addr: dst_addr.into(),
}
.into(),
),
};
client
.add_port_forward(BaseController::default(), request)
.await?;
println!(
"Port forward rule added: {} {} -> {}",
protocol, bind_addr, dst_addr
);
Ok(())
}
async fn handle_port_forward_remove(
&self,
protocol: &str,
bind_addr: &str,
dst_addr: Option<&str>,
) -> Result<(), Error> {
let bind_addr: std::net::SocketAddr = bind_addr
.parse()
.with_context(|| format!("Invalid bind address: {}", bind_addr))?;
if protocol != "tcp" && protocol != "udp" {
return Err(anyhow::anyhow!("Protocol must be 'tcp' or 'udp'"));
}
let client = self.get_port_forward_manager_client().await?;
let request = RemovePortForwardRequest {
cfg: Some(
PortForwardConfig {
proto: protocol.to_string(),
bind_addr: bind_addr.into(),
dst_addr: dst_addr
.map(|s| s.parse::<SocketAddr>().unwrap())
.map(Into::into)
.unwrap_or("0.0.0.0:0".parse::<SocketAddr>().unwrap().into()),
}
.into(),
),
};
client
.remove_port_forward(BaseController::default(), request)
.await?;
println!("Port forward rule removed: {} {}", protocol, bind_addr);
Ok(())
}
async fn handle_port_forward_list(&self) -> Result<(), Error> {
let client = self.get_port_forward_manager_client().await?;
let request = ListPortForwardRequest::default();
let response = client
.list_port_forward(BaseController::default(), request)
.await?;
if self.verbose || *self.output_format == OutputFormat::Json {
println!("{}", serde_json::to_string_pretty(&response)?);
return Ok(());
}
#[derive(tabled::Tabled, serde::Serialize)]
struct PortForwardTableItem {
protocol: String,
bind_addr: String,
dst_addr: String,
}
let items: Vec<PortForwardTableItem> = response
.cfgs
.into_iter()
.map(|rule| PortForwardTableItem {
protocol: format!(
"{:?}",
SocketType::try_from(rule.socket_type).unwrap_or(SocketType::Tcp)
),
bind_addr: rule
.bind_addr
.map(|addr| addr.to_string())
.unwrap_or_default(),
dst_addr: rule
.dst_addr
.map(|addr| addr.to_string())
.unwrap_or_default(),
})
.collect();
print_output(&items, self.output_format)?;
Ok(())
}
async fn handle_whitelist_set_tcp(&self, ports: &str) -> Result<(), Error> {
let tcp_ports = Self::parse_port_list(ports)?;
let client = self.get_acl_manager_client().await?;
// Get current UDP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports,
udp_ports: current.udp_ports,
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("TCP whitelist updated: {}", ports);
Ok(())
}
async fn handle_whitelist_set_udp(&self, ports: &str) -> Result<(), Error> {
let udp_ports = Self::parse_port_list(ports)?;
let client = self.get_acl_manager_client().await?;
// Get current TCP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports: current.tcp_ports,
udp_ports,
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("UDP whitelist updated: {}", ports);
Ok(())
}
async fn handle_whitelist_clear_tcp(&self) -> Result<(), Error> {
let client = self.get_acl_manager_client().await?;
// Get current UDP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports: vec![],
udp_ports: current.udp_ports,
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("TCP whitelist cleared");
Ok(())
}
async fn handle_whitelist_clear_udp(&self) -> Result<(), Error> {
let client = self.get_acl_manager_client().await?;
// Get current TCP ports to preserve them
let current = client
.get_whitelist(BaseController::default(), GetWhitelistRequest::default())
.await?;
let request = SetWhitelistRequest {
tcp_ports: current.tcp_ports,
udp_ports: vec![],
};
client
.set_whitelist(BaseController::default(), request)
.await?;
println!("UDP whitelist cleared");
Ok(())
}
async fn handle_whitelist_show(&self) -> Result<(), Error> {
let client = self.get_acl_manager_client().await?;
let request = GetWhitelistRequest::default();
let response = client
.get_whitelist(BaseController::default(), request)
.await?;
if self.verbose || *self.output_format == OutputFormat::Json {
println!("{}", serde_json::to_string_pretty(&response)?);
return Ok(());
}
println!(
"TCP Whitelist: {}",
if response.tcp_ports.is_empty() {
"None".to_string()
} else {
response.tcp_ports.join(", ")
}
);
println!(
"UDP Whitelist: {}",
if response.udp_ports.is_empty() {
"None".to_string()
} else {
response.udp_ports.join(", ")
}
);
Ok(())
}
fn parse_port_list(ports_str: &str) -> Result<Vec<String>, Error> {
let mut ports = Vec::new();
for port_spec in ports_str.split(',') {
let port_spec = port_spec.trim();
if port_spec.contains('-') {
// Handle port range
let parts: Vec<&str> = port_spec.split('-').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid port range: {}", port_spec));
}
let start: u16 = parts[0]
.parse()
.with_context(|| format!("Invalid start port: {}", parts[0]))?;
let end: u16 = parts[1]
.parse()
.with_context(|| format!("Invalid end port: {}", parts[1]))?;
if start > end {
return Err(anyhow::anyhow!("Invalid port range: start > end"));
}
ports.push(format!("{}-{}", start, end));
} else {
// Handle single port
let port: u16 = port_spec
.parse()
.with_context(|| format!("Invalid port number: {}", port_spec))?;
ports.push(port.to_string());
}
}
Ok(ports)
}
}
#[derive(Debug)]
@@ -1494,6 +1827,46 @@ async fn main() -> Result<(), Error> {
handler.handle_acl_stats().await?;
}
},
SubCommand::PortForward(port_forward_args) => match &port_forward_args.sub_command {
Some(PortForwardSubCommand::Add {
protocol,
bind_addr,
dst_addr,
}) => {
handler
.handle_port_forward_add(protocol, bind_addr, dst_addr)
.await?;
}
Some(PortForwardSubCommand::Remove {
protocol,
bind_addr,
dst_addr,
}) => {
handler
.handle_port_forward_remove(protocol, bind_addr, dst_addr.as_deref())
.await?;
}
Some(PortForwardSubCommand::List) | None => {
handler.handle_port_forward_list().await?;
}
},
SubCommand::Whitelist(whitelist_args) => match &whitelist_args.sub_command {
Some(WhitelistSubCommand::SetTcp { ports }) => {
handler.handle_whitelist_set_tcp(ports).await?;
}
Some(WhitelistSubCommand::SetUdp { ports }) => {
handler.handle_whitelist_set_udp(ports).await?;
}
Some(WhitelistSubCommand::ClearTcp) => {
handler.handle_whitelist_clear_tcp().await?;
}
Some(WhitelistSubCommand::ClearUdp) => {
handler.handle_whitelist_clear_udp().await?;
}
Some(WhitelistSubCommand::Show) | None => {
handler.handle_whitelist_show().await?;
}
},
SubCommand::GenAutocomplete { shell } => {
let mut cmd = Cli::command();
easytier::print_completions(shell, &mut cmd, "easytier-cli");

View File

@@ -29,10 +29,7 @@ use easytier::{
connector::create_connector_by_url,
instance_manager::NetworkInstanceManager,
launcher::{add_proxy_network_to_config, ConfigSource},
proto::{
acl::{Acl, AclV1, Action, Chain, ChainType, Protocol, Rule},
common::{CompressionAlgoPb, NatType},
},
proto::common::{CompressionAlgoPb, NatType},
tunnel::{IpVersion, PROTO_PORT_OFFSET},
utils::{init_logger, setup_panic_handler},
web_client,
@@ -622,115 +619,6 @@ impl NetworkOptions {
false
}
fn parse_port_list(port_list: &[String]) -> anyhow::Result<Vec<String>> {
let mut ports = Vec::new();
for port_spec in port_list {
if port_spec.contains('-') {
// Handle port range like "8000-9000"
let parts: Vec<&str> = port_spec.split('-').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec));
}
let start: u16 = parts[0]
.parse()
.with_context(|| format!("Invalid start port in range: {}", port_spec))?;
let end: u16 = parts[1]
.parse()
.with_context(|| format!("Invalid end port in range: {}", port_spec))?;
if start > end {
return Err(anyhow::anyhow!(
"Start port must be <= end port in range: {}",
port_spec
));
}
// acl can handle port range
ports.push(port_spec.clone());
} else {
// Handle single port
let port: u16 = port_spec
.parse()
.with_context(|| format!("Invalid port number: {}", port_spec))?;
ports.push(port.to_string());
}
}
Ok(ports)
}
fn generate_acl_from_whitelists(&self) -> anyhow::Result<Option<Acl>> {
if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() {
return Ok(None);
}
let mut acl = Acl {
acl_v1: Some(AclV1 { chains: vec![] }),
};
let acl_v1 = acl.acl_v1.as_mut().unwrap();
// Create inbound chain for whitelist rules
let mut inbound_chain = Chain {
name: "inbound_whitelist".to_string(),
chain_type: ChainType::Inbound as i32,
description: "Auto-generated inbound whitelist from CLI".to_string(),
enabled: true,
rules: vec![],
default_action: Action::Drop as i32, // Default deny
};
let mut rule_priority = 1000u32;
// Add TCP whitelist rules
if !self.tcp_whitelist.is_empty() {
let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?;
let tcp_rule = Rule {
name: "tcp_whitelist".to_string(),
description: "Auto-generated TCP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Tcp as i32,
ports: tcp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: true,
};
inbound_chain.rules.push(tcp_rule);
rule_priority -= 1;
}
// Add UDP whitelist rules
if !self.udp_whitelist.is_empty() {
let udp_ports = Self::parse_port_list(&self.udp_whitelist)?;
let udp_rule = Rule {
name: "udp_whitelist".to_string(),
description: "Auto-generated UDP whitelist rule".to_string(),
priority: rule_priority,
enabled: true,
protocol: Protocol::Udp as i32,
ports: udp_ports,
source_ips: vec![],
destination_ips: vec![],
source_ports: vec![],
action: Action::Allow as i32,
rate_limit: 0,
burst_limit: 0,
stateful: false,
};
inbound_chain.rules.push(udp_rule);
}
acl_v1.chains.push(inbound_chain);
Ok(Some(acl))
}
fn merge_into(&self, cfg: &mut TomlConfigLoader) -> anyhow::Result<()> {
if self.hostname.is_some() {
cfg.set_hostname(self.hostname.clone());
@@ -988,10 +876,13 @@ impl NetworkOptions {
cfg.set_exit_nodes(self.exit_nodes.clone());
}
// Handle port whitelists by generating ACL configuration
if let Some(acl) = self.generate_acl_from_whitelists()? {
cfg.set_acl(Some(acl));
}
let mut old_tcp_whitelist = cfg.get_tcp_whitelist();
old_tcp_whitelist.extend(self.tcp_whitelist.clone());
cfg.set_tcp_whitelist(old_tcp_whitelist);
let mut old_udp_whitelist = cfg.get_udp_whitelist();
old_udp_whitelist.extend(self.udp_whitelist.clone());
cfg.set_udp_whitelist(old_udp_whitelist);
Ok(())
}

View File

@@ -6,6 +6,7 @@ use std::{
use crossbeam::atomic::AtomicCell;
use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream};
use tokio_util::sync::{CancellationToken, DropGuard};
use crate::{
common::{
@@ -432,6 +433,8 @@ pub struct Socks5Server {
udp_forward_task: Arc<DashMap<UdpClientKey, ScopedTask<()>>>,
kcp_endpoint: Mutex<Option<Weak<KcpEndpoint>>>,
cancel_tokens: DashMap<PortForwardConfig, DropGuard>,
}
#[async_trait::async_trait]
@@ -531,6 +534,8 @@ impl Socks5Server {
udp_forward_task: Arc::new(DashMap::new()),
kcp_endpoint: Mutex::new(None),
cancel_tokens: DashMap::new(),
})
}
@@ -614,10 +619,9 @@ impl Socks5Server {
need_start = true;
};
for port_forward in self.global_ctx.config.get_port_forwards() {
self.add_port_forward(port_forward).await?;
need_start = true;
}
let cfgs = self.global_ctx.config.get_port_forwards();
self.reload_port_forwards(&cfgs).await?;
need_start = need_start || cfgs.len() > 0;
if need_start {
self.peer_manager
@@ -630,6 +634,26 @@ impl Socks5Server {
Ok(())
}
pub async fn reload_port_forwards(&self, cfgs: &Vec<PortForwardConfig>) -> Result<(), Error> {
// remove entries not in new cfg
self.cancel_tokens.retain(|k, _| {
cfgs.iter().any(|cfg| {
if cfg.dst_addr.ip().is_unspecified() {
k.bind_addr == cfg.bind_addr && k.proto == cfg.proto
} else {
k == cfg
}
})
});
// add new ones
for cfg in cfgs {
if !self.cancel_tokens.contains_key(cfg) {
self.add_port_forward(cfg.clone()).await?;
}
}
Ok(())
}
async fn handle_port_forward_connection(
mut incoming_socket: tokio::net::TcpStream,
connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send>,
@@ -660,12 +684,10 @@ impl Socks5Server {
pub async fn add_port_forward(&self, cfg: PortForwardConfig) -> Result<(), Error> {
match cfg.proto.to_lowercase().as_str() {
"tcp" => {
self.add_tcp_port_forward(cfg.bind_addr, cfg.dst_addr)
.await?;
self.add_tcp_port_forward(&cfg).await?;
}
"udp" => {
self.add_udp_port_forward(cfg.bind_addr, cfg.dst_addr)
.await?;
self.add_udp_port_forward(&cfg).await?;
}
_ => {
return Err(anyhow::anyhow!(
@@ -680,11 +702,12 @@ impl Socks5Server {
Ok(())
}
pub async fn add_tcp_port_forward(
&self,
bind_addr: SocketAddr,
dst_addr: SocketAddr,
) -> Result<(), Error> {
pub fn remove_port_forward(&self, cfg: PortForwardConfig) {
let _ = self.cancel_tokens.remove(&cfg);
}
pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?;
let net = self.net.clone();
@@ -693,14 +716,26 @@ impl Socks5Server {
let forward_tasks = tasks.clone();
let kcp_endpoint = self.kcp_endpoint.lock().await.clone();
let peer_mgr = Arc::downgrade(&self.peer_manager.clone());
let cancel_token = CancellationToken::new();
self.cancel_tokens
.insert(cfg.clone(), cancel_token.clone().drop_guard());
self.tasks.lock().unwrap().spawn(async move {
loop {
let (incoming_socket, addr) = match listener.accept().await {
Ok(result) => result,
Err(err) => {
tracing::error!("port forward accept error = {:?}", err);
continue;
let (incoming_socket, addr) = select! {
biased;
_ = cancel_token.cancelled() => {
tracing::info!("port forward for {:?} cancelled", bind_addr);
break;
}
res = listener.accept() => {
match res {
Ok(result) => result,
Err(err) => {
tracing::error!("port forward accept error = {:?}", err);
continue;
}
}
}
};
@@ -747,11 +782,8 @@ impl Socks5Server {
}
#[tracing::instrument(name = "add_udp_port_forward", skip(self))]
pub async fn add_udp_port_forward(
&self,
bind_addr: SocketAddr,
dst_addr: SocketAddr,
) -> Result<(), Error> {
pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> {
let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr);
let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?);
let entries = self.entries.clone();
@@ -759,16 +791,28 @@ impl Socks5Server {
let net = self.net.clone();
let udp_client_map = self.udp_client_map.clone();
let udp_forward_task = self.udp_forward_task.clone();
let cancel_token = CancellationToken::new();
self.cancel_tokens
.insert(cfg.clone(), cancel_token.clone().drop_guard());
self.tasks.lock().unwrap().spawn(async move {
loop {
// we set the max buffer size of smoltcp to 8192, so we need to use a buffer size that is less than 8192 here.
let mut buf = vec![0u8; 8192];
let (len, addr) = match socket.recv_from(&mut buf).await {
Ok(result) => result,
Err(err) => {
tracing::error!("udp port forward recv error = {:?}", err);
continue;
let (len, addr) = select! {
biased;
_ = cancel_token.cancelled() => {
tracing::info!("udp port forward for {:?} cancelled", bind_addr);
break;
}
res = socket.recv_from(&mut buf) => {
match res {
Ok(result) => result,
Err(err) => {
tracing::error!("udp port forward recv error = {:?}", err);
continue;
}
}
}
};

View File

@@ -10,6 +10,7 @@ use cidr::{IpCidr, Ipv4Inet};
use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::sync::CancellationToken;
use crate::common::acl_processor::AclRuleBuilder;
use crate::common::config::ConfigLoader;
use crate::common::error::Error;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent};
@@ -29,13 +30,15 @@ use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
use crate::peers::rpc_service::PeerManagerRpcService;
use crate::peers::{create_packet_recv_chan, recv_packet_from_chan, PacketRecvChanReceiver};
use crate::proto::cli::VpnPortalRpc;
use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::proto::cli::{
ListMappedListenerRequest, ListMappedListenerResponse, ManageMappedListenerRequest,
ManageMappedListenerResponse, MappedListener, MappedListenerManageAction,
MappedListenerManageRpc,
AddPortForwardRequest, AddPortForwardResponse, ListMappedListenerRequest,
ListMappedListenerResponse, ListPortForwardRequest, ListPortForwardResponse,
ManageMappedListenerRequest, ManageMappedListenerResponse, MappedListener,
MappedListenerManageAction, MappedListenerManageRpc, PortForwardManageRpc,
RemovePortForwardRequest, RemovePortForwardResponse,
};
use crate::proto::common::TunnelInfo;
use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::proto::common::{PortForwardConfigPb, TunnelInfo};
use crate::proto::peer_rpc::PeerCenterRpcServer;
use crate::proto::rpc_impl::standalone::{RpcServerHook, StandAloneServer};
use crate::proto::rpc_types;
@@ -609,9 +612,9 @@ impl Instance {
}
}
if let Some(acl) = self.global_ctx.config.get_acl() {
self.global_ctx.get_acl_filter().reload_rules(Some(&acl));
}
self.global_ctx
.get_acl_filter()
.reload_rules(AclRuleBuilder::build(&self.global_ctx)?.as_ref());
// run after tun device created, so listener can bind to tun device, which may be required by win 10
self.ip_proxy = Some(IpProxy::new(
@@ -790,6 +793,85 @@ impl Instance {
MappedListenerManagerRpcService(self.global_ctx.clone())
}
fn get_port_forward_manager_rpc_service(
&self,
) -> impl PortForwardManageRpc<Controller = BaseController> + Clone {
#[derive(Clone)]
pub struct PortForwardManagerRpcService {
global_ctx: ArcGlobalCtx,
socks5_server: Weak<Socks5Server>,
}
#[async_trait::async_trait]
impl PortForwardManageRpc for PortForwardManagerRpcService {
type Controller = BaseController;
async fn add_port_forward(
&self,
_: BaseController,
request: AddPortForwardRequest,
) -> Result<AddPortForwardResponse, rpc_types::error::Error> {
let Some(socks5_server) = self.socks5_server.upgrade() else {
return Err(anyhow::anyhow!("socks5 server not available").into());
};
if let Some(cfg) = request.cfg {
tracing::info!("Port forward rule added: {:?}", cfg);
let mut current_forwards = self.global_ctx.config.get_port_forwards();
current_forwards.push(cfg.into());
self.global_ctx
.config
.set_port_forwards(current_forwards.clone());
socks5_server
.reload_port_forwards(&current_forwards)
.await
.with_context(|| "Failed to reload port forwards")?;
}
Ok(AddPortForwardResponse {})
}
async fn remove_port_forward(
&self,
_: BaseController,
request: RemovePortForwardRequest,
) -> Result<RemovePortForwardResponse, rpc_types::error::Error> {
let Some(socks5_server) = self.socks5_server.upgrade() else {
return Err(anyhow::anyhow!("socks5 server not available").into());
};
let Some(cfg) = request.cfg else {
return Err(anyhow::anyhow!("port forward config is empty").into());
};
let cfg = cfg.into();
let mut current_forwards = self.global_ctx.config.get_port_forwards();
current_forwards.retain(|e| *e != cfg);
self.global_ctx
.config
.set_port_forwards(current_forwards.clone());
socks5_server
.reload_port_forwards(&current_forwards)
.await
.with_context(|| "Failed to reload port forwards")?;
tracing::info!("Port forward rule removed: {:?}", cfg);
Ok(RemovePortForwardResponse {})
}
async fn list_port_forward(
&self,
_: BaseController,
_request: ListPortForwardRequest,
) -> Result<ListPortForwardResponse, rpc_types::error::Error> {
let forwards = self.global_ctx.config.get_port_forwards();
let cfgs: Vec<PortForwardConfigPb> = forwards.into_iter().map(Into::into).collect();
Ok(ListPortForwardResponse { cfgs })
}
}
PortForwardManagerRpcService {
global_ctx: self.global_ctx.clone(),
socks5_server: Arc::downgrade(&self.socks5_server),
}
}
async fn run_rpc_server(&mut self) -> Result<(), Error> {
let Some(_) = self.global_ctx.config.get_rpc_portal() else {
tracing::info!("rpc server not enabled, because rpc_portal is not set.");
@@ -803,6 +885,7 @@ impl Instance {
let peer_center = self.peer_center.clone();
let vpn_portal_rpc = self.get_vpn_portal_rpc_service();
let mapped_listener_manager_rpc = self.get_mapped_listener_manager_rpc_service();
let port_forward_manager_rpc = self.get_port_forward_manager_rpc_service();
let s = self.rpc_server.as_mut().unwrap();
let peer_mgr_rpc_service = PeerManagerRpcService::new(peer_mgr.clone());
@@ -823,6 +906,10 @@ impl Instance {
MappedListenerManageRpcServer::new(mapped_listener_manager_rpc),
"",
);
s.registry().register(
PortForwardManageRpcServer::new(port_forward_manager_rpc),
"",
);
if let Some(ip_proxy) = self.ip_proxy.as_ref() {
s.registry().register(

View File

@@ -1,13 +1,18 @@
use std::sync::Arc;
use crate::proto::{
cli::{
AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest, GetAclStatsResponse,
ListForeignNetworkRequest, ListForeignNetworkResponse, ListGlobalForeignNetworkRequest,
ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest,
ListRouteResponse, PeerInfo, PeerManageRpc, ShowNodeInfoRequest, ShowNodeInfoResponse,
use crate::{
common::acl_processor::AclRuleBuilder,
proto::{
cli::{
AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest,
GetAclStatsResponse, GetWhitelistRequest, GetWhitelistResponse,
ListForeignNetworkRequest, ListForeignNetworkResponse, ListGlobalForeignNetworkRequest,
ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest,
ListRouteResponse, PeerInfo, PeerManageRpc, SetWhitelistRequest, SetWhitelistResponse,
ShowNodeInfoRequest, ShowNodeInfoResponse,
},
rpc_types::{self, controller::BaseController},
},
rpc_types::{self, controller::BaseController},
};
use super::peer_manager::PeerManager;
@@ -153,4 +158,45 @@ impl AclManageRpc for PeerManagerRpcService {
acl_stats: Some(acl_stats),
})
}
async fn set_whitelist(
&self,
_: BaseController,
request: SetWhitelistRequest,
) -> Result<SetWhitelistResponse, rpc_types::error::Error> {
tracing::info!(
"Setting whitelist - TCP: {:?}, UDP: {:?}",
request.tcp_ports,
request.udp_ports
);
let global_ctx = self.peer_manager.get_global_ctx();
global_ctx.config.set_tcp_whitelist(request.tcp_ports);
global_ctx.config.set_udp_whitelist(request.udp_ports);
global_ctx
.get_acl_filter()
.reload_rules(AclRuleBuilder::build(&global_ctx)?.as_ref());
Ok(SetWhitelistResponse {})
}
async fn get_whitelist(
&self,
_: BaseController,
_request: GetWhitelistRequest,
) -> Result<GetWhitelistResponse, rpc_types::error::Error> {
let global_ctx = self.peer_manager.get_global_ctx();
let tcp_ports = global_ctx.config.get_tcp_whitelist();
let udp_ports = global_ctx.config.get_udp_whitelist();
tracing::info!(
"Getting whitelist - TCP: {:?}, UDP: {:?}",
tcp_ports,
udp_ports
);
Ok(GetWhitelistResponse {
tcp_ports,
udp_ports,
})
}
}

View File

@@ -261,4 +261,44 @@ message GetAclStatsResponse {
service AclManageRpc {
rpc GetAclStats(GetAclStatsRequest) returns (GetAclStatsResponse);
rpc SetWhitelist(SetWhitelistRequest) returns (SetWhitelistResponse);
rpc GetWhitelist(GetWhitelistRequest) returns (GetWhitelistResponse);
}
message SetWhitelistRequest {
repeated string tcp_ports = 1;
repeated string udp_ports = 2;
}
message SetWhitelistResponse {}
message GetWhitelistRequest {}
message GetWhitelistResponse {
repeated string tcp_ports = 1;
repeated string udp_ports = 2;
}
message AddPortForwardRequest {
common.PortForwardConfigPb cfg = 1;
}
message AddPortForwardResponse {}
message RemovePortForwardRequest {
common.PortForwardConfigPb cfg = 1;
}
message RemovePortForwardResponse {}
message ListPortForwardRequest {}
message ListPortForwardResponse {
repeated common.PortForwardConfigPb cfgs = 1;
}
service PortForwardManageRpc {
rpc AddPortForward(AddPortForwardRequest) returns (AddPortForwardResponse);
rpc RemovePortForward(RemovePortForwardRequest) returns (RemovePortForwardResponse);
rpc ListPortForward(ListPortForwardRequest) returns (ListPortForwardResponse);
}