clippy all codes (#1214)

1. clippy code
2. add fmt and clippy check in ci
This commit is contained in:
Sijie.Sun
2025-08-10 22:56:41 +08:00
committed by GitHub
parent 0087ac3ffc
commit e43537939a
144 changed files with 1475 additions and 1531 deletions

View File

@@ -229,8 +229,8 @@ jobs:
rustup set auto-self-update disable rustup set auto-self-update disable
rustup install 1.87 rustup install 1.89
rustup default 1.87 rustup default 1.89
export CC=clang export CC=clang
export CXX=clang++ export CXX=clang++

View File

@@ -29,7 +29,7 @@ jobs:
concurrent_skipping: 'same_content_newer' concurrent_skipping: 'same_content_newer'
skip_after_successful_duplicate: 'true' skip_after_successful_duplicate: 'true'
cancel_others: 'true' cancel_others: 'true'
paths: '["Cargo.toml", "Cargo.lock", "easytier/**", "easytier-gui/**", ".github/workflows/gui.yml", ".github/workflows/install_rust.sh"]' paths: '["Cargo.toml", "Cargo.lock", "easytier/**", "easytier-gui/**", ".github/workflows/gui.yml", ".github/workflows/install_rust.sh", ".github/workflows/install_gui_dep.sh"]'
build-gui: build-gui:
strategy: strategy:
fail-fast: false fail-fast: false
@@ -78,20 +78,11 @@ jobs:
needs: pre_job needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true' if: needs.pre_job.outputs.should_skip != 'true'
steps: steps:
- uses: actions/checkout@v3
- name: Install GUI dependencies (x86 only) - name: Install GUI dependencies (x86 only)
if: ${{ matrix.TARGET == 'x86_64-unknown-linux-musl' }} if: ${{ matrix.TARGET == 'x86_64-unknown-linux-musl' }}
run: | run: bash ./.github/workflows/install_gui_dep.sh
sudo apt update
sudo apt install -qq libwebkit2gtk-4.1-dev \
build-essential \
curl \
wget \
file \
libgtk-3-dev \
librsvg2-dev \
libxdo-dev \
libssl-dev \
patchelf
- name: Install GUI cross compile (aarch64 only) - name: Install GUI cross compile (aarch64 only)
if: ${{ matrix.TARGET == 'aarch64-unknown-linux-musl' }} if: ${{ matrix.TARGET == 'aarch64-unknown-linux-musl' }}
@@ -128,8 +119,6 @@ jobs:
echo "PKG_CONFIG_SYSROOT_DIR=/usr/aarch64-linux-gnu/" >> "$GITHUB_ENV" echo "PKG_CONFIG_SYSROOT_DIR=/usr/aarch64-linux-gnu/" >> "$GITHUB_ENV"
echo "PKG_CONFIG_PATH=/usr/lib/aarch64-linux-gnu/pkgconfig/" >> "$GITHUB_ENV" echo "PKG_CONFIG_PATH=/usr/lib/aarch64-linux-gnu/pkgconfig/" >> "$GITHUB_ENV"
- uses: actions/checkout@v3
- name: Set current ref as env variable - name: Set current ref as env variable
run: | run: |
echo "GIT_DESC=$(git log -1 --format=%cd.%h --date=format:%Y-%m-%d_%H:%M:%S)" >> $GITHUB_ENV echo "GIT_DESC=$(git log -1 --format=%cd.%h --date=format:%Y-%m-%d_%H:%M:%S)" >> $GITHUB_ENV

11
.github/workflows/install_gui_dep.sh vendored Normal file
View File

@@ -0,0 +1,11 @@
sudo apt update
sudo apt install -qq libwebkit2gtk-4.1-dev \
build-essential \
curl \
wget \
file \
libgtk-3-dev \
librsvg2-dev \
libxdo-dev \
libssl-dev \
patchelf

View File

@@ -31,8 +31,8 @@ fi
# see https://github.com/rust-lang/rustup/issues/3709 # see https://github.com/rust-lang/rustup/issues/3709
rustup set auto-self-update disable rustup set auto-self-update disable
rustup install 1.87 rustup install 1.89
rustup default 1.87 rustup default 1.89
# mips/mipsel cannot add target from rustup, need compile by ourselves # mips/mipsel cannot add target from rustup, need compile by ourselves
if [[ $OS =~ ^ubuntu.*$ && $TARGET =~ ^mips.*$ ]]; then if [[ $OS =~ ^ubuntu.*$ && $TARGET =~ ^mips.*$ ]]; then

View File

@@ -28,7 +28,7 @@ jobs:
# All of these options are optional, so you can remove them if you are happy with the defaults # All of these options are optional, so you can remove them if you are happy with the defaults
concurrent_skipping: 'never' concurrent_skipping: 'never'
skip_after_successful_duplicate: 'true' skip_after_successful_duplicate: 'true'
paths: '["Cargo.toml", "Cargo.lock", "easytier/**", ".github/workflows/test.yml"]' paths: '["Cargo.toml", "Cargo.lock", "easytier/**", ".github/workflows/test.yml", ".github/workflows/install_gui_dep.sh", ".github/workflows/install_rust.sh"]'
test: test:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
needs: pre_job needs: pre_job
@@ -89,6 +89,24 @@ jobs:
./target ./target
key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }}
- name: Install GUI dependencies (Used by clippy)
run: |
bash ./.github/workflows/install_gui_dep.sh
bash ./.github/workflows/install_rust.sh
rustup component add rustfmt
rustup component add clippy
- name: Check formatting
if: ${{ !cancelled() }}
run: cargo fmt --all -- --check
- name: Check Clippy
if: ${{ !cancelled() }}
# NOTE: tauri need `dist` dir in build.rs
run: |
mkdir -p easytier-gui/dist
cargo clippy --all-targets --all-features --all -- -D warnings
- name: Run tests - name: Run tests
run: | run: |
sudo prlimit --pid $$ --nofile=1048576:1048576 sudo prlimit --pid $$ --nofile=1048576:1048576

View File

@@ -26,7 +26,7 @@ Thank you for your interest in contributing to EasyTier! This document provides
#### Required Tools #### Required Tools
- Node.js v21 or higher - Node.js v21 or higher
- pnpm v9 or higher - pnpm v9 or higher
- Rust toolchain (version 1.87) - Rust toolchain (version 1.89)
- LLVM and Clang - LLVM and Clang
- Protoc (Protocol Buffers compiler) - Protoc (Protocol Buffers compiler)
@@ -79,8 +79,8 @@ sudo apt install -y bridge-utils
2. Install dependencies: 2. Install dependencies:
```bash ```bash
# Install Rust toolchain # Install Rust toolchain
rustup install 1.87 rustup install 1.89
rustup default 1.87 rustup default 1.89
# Install project dependencies # Install project dependencies
pnpm -r install pnpm -r install

View File

@@ -34,7 +34,7 @@
#### 必需工具 #### 必需工具
- Node.js v21 或更高版本 - Node.js v21 或更高版本
- pnpm v9 或更高版本 - pnpm v9 或更高版本
- Rust 工具链(版本 1.87 - Rust 工具链(版本 1.89
- LLVM 和 Clang - LLVM 和 Clang
- ProtocProtocol Buffers 编译器) - ProtocProtocol Buffers 编译器)
@@ -87,8 +87,8 @@ sudo apt install -y bridge-utils
2. 安装依赖: 2. 安装依赖:
```bash ```bash
# 安装 Rust 工具链 # 安装 Rust 工具链
rustup install 1.87 rustup install 1.89
rustup default 1.87 rustup default 1.89
# 安装项目依赖 # 安装项目依赖
pnpm -r install pnpm -r install

View File

@@ -29,8 +29,10 @@ fn set_error_msg(msg: &str) {
msg_buf[..len].copy_from_slice(bytes); msg_buf[..len].copy_from_slice(bytes);
} }
/// # Safety
/// Set the tun fd
#[no_mangle] #[no_mangle]
pub extern "C" fn set_tun_fd( pub unsafe extern "C" fn set_tun_fd(
inst_name: *const std::ffi::c_char, inst_name: *const std::ffi::c_char,
fd: std::ffi::c_int, fd: std::ffi::c_int,
) -> std::ffi::c_int { ) -> std::ffi::c_int {
@@ -43,18 +45,23 @@ pub extern "C" fn set_tun_fd(
if !INSTANCE_NAME_ID_MAP.contains_key(&inst_name) { if !INSTANCE_NAME_ID_MAP.contains_key(&inst_name) {
return -1; return -1;
} }
match INSTANCE_MANAGER.set_tun_fd(&INSTANCE_NAME_ID_MAP.get(&inst_name).unwrap().value(), fd) {
Ok(_) => { let inst_id = *INSTANCE_NAME_ID_MAP
0 .get(&inst_name)
} .as_ref()
Err(_) => { .unwrap()
-1 .value();
}
match INSTANCE_MANAGER.set_tun_fd(&inst_id, fd) {
Ok(_) => 0,
Err(_) => -1,
} }
} }
/// # Safety
/// Get the last error message
#[no_mangle] #[no_mangle]
pub extern "C" fn get_error_msg(out: *mut *const std::ffi::c_char) { pub unsafe extern "C" fn get_error_msg(out: *mut *const std::ffi::c_char) {
let msg_buf = ERROR_MSG.lock().unwrap(); let msg_buf = ERROR_MSG.lock().unwrap();
if msg_buf.is_empty() { if msg_buf.is_empty() {
unsafe { unsafe {
@@ -78,8 +85,10 @@ pub extern "C" fn free_string(s: *const std::ffi::c_char) {
} }
} }
/// # Safety
/// Parse the config
#[no_mangle] #[no_mangle]
pub extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int { pub unsafe extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
let cfg_str = unsafe { let cfg_str = unsafe {
assert!(!cfg_str.is_null()); assert!(!cfg_str.is_null());
std::ffi::CStr::from_ptr(cfg_str) std::ffi::CStr::from_ptr(cfg_str)
@@ -95,8 +104,10 @@ pub extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_
0 0
} }
/// # Safety
/// Run the network instance
#[no_mangle] #[no_mangle]
pub extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int { pub unsafe extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
let cfg_str = unsafe { let cfg_str = unsafe {
assert!(!cfg_str.is_null()); assert!(!cfg_str.is_null());
std::ffi::CStr::from_ptr(cfg_str) std::ffi::CStr::from_ptr(cfg_str)
@@ -131,8 +142,10 @@ pub extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std:
0 0
} }
/// # Safety
/// Retain the network instance
#[no_mangle] #[no_mangle]
pub extern "C" fn retain_network_instance( pub unsafe extern "C" fn retain_network_instance(
inst_names: *const *const std::ffi::c_char, inst_names: *const *const std::ffi::c_char,
length: usize, length: usize,
) -> std::ffi::c_int { ) -> std::ffi::c_int {
@@ -168,13 +181,15 @@ pub extern "C" fn retain_network_instance(
return -1; return -1;
} }
let _ = INSTANCE_NAME_ID_MAP.retain(|k, _| inst_names.contains(k)); INSTANCE_NAME_ID_MAP.retain(|k, _| inst_names.contains(k));
0 0
} }
/// # Safety
/// Collect the network infos
#[no_mangle] #[no_mangle]
pub extern "C" fn collect_network_infos( pub unsafe extern "C" fn collect_network_infos(
infos: *mut KeyValuePair, infos: *mut KeyValuePair,
max_length: usize, max_length: usize,
) -> std::ffi::c_int { ) -> std::ffi::c_int {
@@ -233,7 +248,9 @@ mod tests {
network = "test_network" network = "test_network"
"#; "#;
let cstr = std::ffi::CString::new(cfg_str).unwrap(); let cstr = std::ffi::CString::new(cfg_str).unwrap();
assert_eq!(parse_config(cstr.as_ptr()), 0); unsafe {
assert_eq!(parse_config(cstr.as_ptr()), 0);
}
} }
#[test] #[test]
@@ -243,6 +260,8 @@ mod tests {
network = "test_network" network = "test_network"
"#; "#;
let cstr = std::ffi::CString::new(cfg_str).unwrap(); let cstr = std::ffi::CString::new(cfg_str).unwrap();
assert_eq!(run_network_instance(cstr.as_ptr()), 0); unsafe {
assert_eq!(run_network_instance(cstr.as_ptr()), 0);
}
} }
} }

View File

@@ -16,41 +16,13 @@ impl Command {
/// Check the state the current program running /// Check the state the current program running
/// ///
/// Return `true` if the program is running as root, otherwise false /// Return `true` if the program is running as root, otherwise false
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
///
/// fn main() {
/// let is_elevated = Command::is_elevated();
///
/// }
/// ```
pub fn is_elevated() -> bool { pub fn is_elevated() -> bool {
let uid = unsafe { libc::getuid() }; let uid = unsafe { libc::getuid() };
if uid == 0 { uid == 0
true
} else {
false
}
} }
/// Prompting the user with a graphical OS dialog for the root password, /// Prompting the user with a graphical OS dialog for the root password,
/// excuting the command with escalated privileges, and return the output /// excuting the command with escalated privileges, and return the output
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let output = elevated_cmd.output().unwrap();
/// }
/// ```
pub fn output(&self) -> Result<Output> { pub fn output(&self) -> Result<Output> {
let pkexec = PathBuf::from_str("/bin/pkexec")?; let pkexec = PathBuf::from_str("/bin/pkexec")?;
let mut command = StdCommand::new(pkexec); let mut command = StdCommand::new(pkexec);
@@ -70,10 +42,8 @@ impl Command {
if let Ok(home) = home { if let Ok(home) = home {
command.arg(format!("HOME={}", home)); command.arg(format!("HOME={}", home));
} }
} else { } else if self.cmd.get_envs().any(|(_, v)| v.is_some()) {
if self.cmd.get_envs().any(|(_, v)| v.is_some()) { command.arg("env");
command.arg("env");
}
} }
for (k, v) in self.cmd.get_envs() { for (k, v) in self.cmd.get_envs() {
if let Some(value) = v { if let Some(value) = v {

View File

@@ -40,22 +40,6 @@ impl Command {
/// To pass environment variables on Windows, /// To pass environment variables on Windows,
/// to inherit environment variables from the parent process and /// to inherit environment variables from the parent process and
/// to change the working directory will be supported in later versions /// to change the working directory will be supported in later versions
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
///
/// cmd.arg("some arg");
/// cmd.env("some key", "some value");
///
/// let elevated_cmd = Command::new(cmd);
/// }
/// ```
pub fn new(cmd: StdCommand) -> Self { pub fn new(cmd: StdCommand) -> Self {
Self { Self {
cmd, cmd,
@@ -67,73 +51,21 @@ impl Command {
/// Consumes the `Take`, returning the wrapped std::process::Command /// Consumes the `Take`, returning the wrapped std::process::Command
/// ///
/// # Examples /// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let cmd = elevated_cmd.into_inner();
/// }
/// ```
pub fn into_inner(self) -> StdCommand { pub fn into_inner(self) -> StdCommand {
self.cmd self.cmd
} }
/// Gets a mutable reference to the underlying std::process::Command /// Gets a mutable reference to the underlying std::process::Command
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let cmd = elevated_cmd.get_ref();
/// }
/// ```
pub fn get_ref(&self) -> &StdCommand { pub fn get_ref(&self) -> &StdCommand {
&self.cmd &self.cmd
} }
/// Gets a reference to the underlying std::process::Command /// Gets a reference to the underlying std::process::Command
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let cmd = elevated_cmd.get_mut();
/// }
/// ```
pub fn get_mut(&mut self) -> &mut StdCommand { pub fn get_mut(&mut self) -> &mut StdCommand {
&mut self.cmd &mut self.cmd
} }
/// Set the `icon` for the pop-up graphical OS dialog /// Set the `icon` for the pop-up graphical OS dialog
///
/// This method is only applicable on `MacOS`
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// elevated_cmd.icon(include_bytes!("path to the icon").to_vec());
/// }
/// ```
pub fn icon(&mut self, icon: Vec<u8>) -> &mut Self { pub fn icon(&mut self, icon: Vec<u8>) -> &mut Self {
self.icon = Some(icon); self.icon = Some(icon);
self self
@@ -142,19 +74,6 @@ impl Command {
/// Set the name for the pop-up graphical OS dialog /// Set the name for the pop-up graphical OS dialog
/// ///
/// This method is only applicable on `MacOS` /// This method is only applicable on `MacOS`
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// elevated_cmd.name("some name".to_string());
/// }
/// ```
pub fn name(&mut self, name: String) -> &mut Self { pub fn name(&mut self, name: String) -> &mut Self {
self.name = Some(name); self.name = Some(name);
self self

View File

@@ -8,7 +8,7 @@ repository = "https://github.com/EasyTier/EasyTier"
authors = ["kkrainbow"] authors = ["kkrainbow"]
keywords = ["vpn", "p2p", "network", "easytier"] keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"] categories = ["network-programming", "command-line-utilities"]
rust-version = "1.87.0" rust-version = "1.89.0"
license-file = "LICENSE" license-file = "LICENSE"
readme = "README.md" readme = "README.md"

View File

@@ -14,18 +14,11 @@ const NAMESPACE: &str = "easytier::proto::rpc_types";
/// ///
/// See the crate-level documentation for more info. /// See the crate-level documentation for more info.
#[allow(missing_copy_implementations)] #[allow(missing_copy_implementations)]
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct ServiceGenerator { pub struct ServiceGenerator {
_private: (), _private: (),
} }
impl ServiceGenerator {
/// Create a new `ServiceGenerator` instance with the default options set.
pub fn new() -> ServiceGenerator {
ServiceGenerator { _private: () }
}
}
impl prost_build::ServiceGenerator for ServiceGenerator { impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) { fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) {
use std::fmt::Write; use std::fmt::Write;
@@ -78,7 +71,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
enum_methods, enum_methods,
" {name} = {index},", " {name} = {index},",
name = method.proto_name, name = method.proto_name,
index = format!("{}", idx + 1) index = idx + 1
) )
.unwrap(); .unwrap();
@@ -87,7 +80,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
" {index} => Ok({service_name}MethodDescriptor::{name}),", " {index} => Ok({service_name}MethodDescriptor::{name}),",
service_name = service.name, service_name = service.name,
name = method.proto_name, name = method.proto_name,
index = format!("{}", idx + 1), index = idx + 1,
) )
.unwrap(); .unwrap();
@@ -102,12 +95,12 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
writeln!( writeln!(
client_methods, client_methods,
r#" async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{ r#" async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{client_name}::{name}_inner(self.0.clone(), ctrl, input).await {client_name}Client::{name}_inner(self.0.clone(), ctrl, input).await
}}"#, }}"#,
name = method.name, name = method.name,
input_type = method.input_type, input_type = method.input_type,
output_type = method.output_type, output_type = method.output_type,
client_name = format!("{}Client", service.name), client_name = service.name,
namespace = NAMESPACE, namespace = NAMESPACE,
) )
.unwrap(); .unwrap();

View File

@@ -1,7 +1,10 @@
fn main() { fn main() {
// enable thunk-rs when target os is windows and arch is x86_64 or i686 // enable thunk-rs when target os is windows and arch is x86_64 or i686
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
if !std::env::var("TARGET").unwrap_or_default().contains("aarch64"){ if !std::env::var("TARGET")
thunk::thunk(); .unwrap_or_default()
} .contains("aarch64")
} {
thunk::thunk();
}
}

View File

@@ -25,7 +25,7 @@ fn load_geoip_db(geoip_db: Option<String>) -> Option<maxminddb::Reader<Vec<u8>>>
match maxminddb::Reader::open_readfile(&path) { match maxminddb::Reader::open_readfile(&path) {
Ok(reader) => { Ok(reader) => {
tracing::info!("Successfully loaded GeoIP2 database from {}", path); tracing::info!("Successfully loaded GeoIP2 database from {}", path);
return Some(reader); Some(reader)
} }
Err(err) => { Err(err) => {
tracing::debug!("Failed to load GeoIP2 database from {}: {}", path, err); tracing::debug!("Failed to load GeoIP2 database from {}: {}", path, err);
@@ -207,10 +207,8 @@ impl ClientManager {
let region = city.subdivisions.map(|r| { let region = city.subdivisions.map(|r| {
r.iter() r.iter()
.map(|x| x.names.as_ref()) .filter_map(|x| x.names.as_ref())
.flatten() .filter_map(|x| x.get("zh-CN").or_else(|| x.get("en")))
.map(|x| x.get("zh-CN").or_else(|| x.get("en")))
.flatten()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(",") .join(",")

View File

@@ -94,14 +94,10 @@ impl SessionRpcService {
return Ok(HeartbeatResponse {}); return Ok(HeartbeatResponse {});
}; };
let machine_id: uuid::Uuid = let machine_id: uuid::Uuid = req.machine_id.map(Into::into).ok_or(anyhow::anyhow!(
"Machine id is not set correctly, expect uuid but got: {:?}",
req.machine_id req.machine_id
.clone() ))?;
.map(Into::into)
.ok_or(anyhow::anyhow!(
"Machine id is not set correctly, expect uuid but got: {:?}",
req.machine_id
))?;
let user_id = storage let user_id = storage
.db() .db()
@@ -121,7 +117,7 @@ impl SessionRpcService {
if data.req.replace(req.clone()).is_none() { if data.req.replace(req.clone()).is_none() {
assert!(data.storage_token.is_none()); assert!(data.storage_token.is_none());
data.storage_token = Some(StorageToken { data.storage_token = Some(StorageToken {
token: req.user_token.clone().into(), token: req.user_token.clone(),
client_url: data.client_url.clone(), client_url: data.client_url.clone(),
machine_id, machine_id,
user_id, user_id,

View File

@@ -34,7 +34,7 @@ impl TryFrom<WeakRefStorage> for Storage {
type Error = (); type Error = ();
fn try_from(weak: Weak<StorageInner>) -> Result<Self, Self::Error> { fn try_from(weak: Weak<StorageInner>) -> Result<Self, Self::Error> {
weak.upgrade().map(|inner| Storage(inner)).ok_or(()) weak.upgrade().map(Storage).ok_or(())
} }
} }
@@ -51,9 +51,7 @@ impl Storage {
machine_id: &uuid::Uuid, machine_id: &uuid::Uuid,
client_url: &url::Url, client_url: &url::Url,
) { ) {
map.remove_if(&machine_id, |_, v| { map.remove_if(machine_id, |_, v| v.storage_token.client_url == *client_url);
v.storage_token.client_url == *client_url
});
} }
fn update_mid_to_client_info_map( fn update_mid_to_client_info_map(
@@ -74,11 +72,7 @@ impl Storage {
} }
pub fn update_client(&self, stoken: StorageToken, report_time: i64) { pub fn update_client(&self, stoken: StorageToken, report_time: i64) {
let inner = self let inner = self.0.user_clients_map.entry(stoken.user_id).or_default();
.0
.user_clients_map
.entry(stoken.user_id)
.or_insert_with(DashMap::new);
let client_info = ClientInfo { let client_info = ClientInfo {
storage_token: stoken.clone(), storage_token: stoken.clone(),

View File

@@ -151,7 +151,7 @@ async fn get_dual_stack_listener(
} else { } else {
None None
}; };
let v4_listener = if let Ok(_) = local_ipv4().await { let v4_listener = if local_ipv4().await.is_ok() {
get_listener_by_url(&format!("{}://0.0.0.0:{}", protocol, port).parse().unwrap()).ok() get_listener_by_url(&format!("{}://0.0.0.0:{}", protocol, port).parse().unwrap()).ok()
} else { } else {
None None

View File

@@ -137,7 +137,7 @@ mod post {
mod get { mod get {
use crate::restful::{ use crate::restful::{
captcha::{ captcha::{
captcha::spec::SpecCaptcha, builder::spec::SpecCaptcha,
extension::{axum_tower_sessions::CaptchaAxumTowerSessionExt as _, CaptchaUtil}, extension::{axum_tower_sessions::CaptchaAxumTowerSessionExt as _, CaptchaUtil},
NewCaptcha as _, NewCaptcha as _,
}, },

View File

@@ -46,22 +46,22 @@ pub(crate) struct Captcha {
/// 验证码文本类型 The character type of the captcha /// 验证码文本类型 The character type of the captcha
pub enum CaptchaType { pub enum CaptchaType {
/// 字母数字混合 /// 字母数字混合
TypeDefault = 1, Default = 1,
/// 纯数字 /// 纯数字
TypeOnlyNumber, OnlyNumber,
/// 纯字母 /// 纯字母
TypeOnlyChar, OnlyChar,
/// 纯大写字母 /// 纯大写字母
TypeOnlyUpper, OnlyUpper,
/// 纯小写字母 /// 纯小写字母
TypeOnlyLower, OnlyLower,
/// 数字大写字母 /// 数字大写字母
TypeNumAndUpper, NumAndUpper,
} }
/// 内置字体 Fonts shipped with the library /// 内置字体 Fonts shipped with the library
@@ -92,29 +92,29 @@ impl Captcha {
/// 生成随机验证码 /// 生成随机验证码
pub fn alphas(&mut self) -> Vec<char> { pub fn alphas(&mut self) -> Vec<char> {
let mut cs = vec!['\0'; self.len]; let mut cs = vec!['\0'; self.len];
for i in 0..self.len { for cs_i in cs.iter_mut() {
match self.char_type { match self.char_type {
CaptchaType::TypeDefault => cs[i] = self.randoms.alpha(), CaptchaType::Default => *cs_i = self.randoms.alpha(),
CaptchaType::TypeOnlyNumber => { CaptchaType::OnlyNumber => {
cs[i] = self.randoms.alpha_under(self.randoms.num_max_index) *cs_i = self.randoms.alpha_under(self.randoms.num_max_index)
} }
CaptchaType::TypeOnlyChar => { CaptchaType::OnlyChar => {
cs[i] = self *cs_i = self
.randoms .randoms
.alpha_between(self.randoms.char_min_index, self.randoms.char_max_index) .alpha_between(self.randoms.char_min_index, self.randoms.char_max_index)
} }
CaptchaType::TypeOnlyUpper => { CaptchaType::OnlyUpper => {
cs[i] = self *cs_i = self
.randoms .randoms
.alpha_between(self.randoms.upper_min_index, self.randoms.upper_max_index) .alpha_between(self.randoms.upper_min_index, self.randoms.upper_max_index)
} }
CaptchaType::TypeOnlyLower => { CaptchaType::OnlyLower => {
cs[i] = self *cs_i = self
.randoms .randoms
.alpha_between(self.randoms.lower_min_index, self.randoms.lower_max_index) .alpha_between(self.randoms.lower_min_index, self.randoms.lower_max_index)
} }
CaptchaType::TypeNumAndUpper => { CaptchaType::NumAndUpper => {
cs[i] = self.randoms.alpha_under(self.randoms.upper_max_index) *cs_i = self.randoms.alpha_under(self.randoms.upper_max_index)
} }
} }
} }
@@ -142,7 +142,7 @@ impl Captcha {
} }
} }
pub fn get_font(&mut self) -> Arc<Font> { pub fn get_font(&'_ mut self) -> Arc<Font<'_>> {
if let Some(font) = font::get_font(&self.font_name) { if let Some(font) = font::get_font(&self.font_name) {
font font
} else { } else {
@@ -185,6 +185,7 @@ where
/// 特别地/In particular: /// 特别地/In particular:
/// ///
/// - 对算术验证码[ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha)而言,这里的`len`是验证码中数字的数量。 /// - 对算术验证码[ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha)而言,这里的`len`是验证码中数字的数量。
///
/// For [ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha), the `len` presents the count of the digits /// For [ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha), the `len` presents the count of the digits
/// in the Captcha. /// in the Captcha.
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self; fn with_size_and_len(width: i32, height: i32, len: usize) -> Self;
@@ -226,7 +227,7 @@ impl NewCaptcha for Captcha {
let len = 5; let len = 5;
let width = 130; let width = 130;
let height = 48; let height = 48;
let char_type = CaptchaType::TypeDefault; let char_type = CaptchaType::Default;
let chars = None; let chars = None;
Self { Self {

View File

@@ -1,6 +1,4 @@
use rand::random;
use rand::{random};
/// 随机数工具类 /// 随机数工具类
pub(crate) struct Randoms { pub(crate) struct Randoms {

View File

@@ -10,7 +10,7 @@ use axum::response::Response;
use std::fmt::Debug; use std::fmt::Debug;
use tower_sessions::Session; use tower_sessions::Session;
const CAPTCHA_KEY: &'static str = "ez-captcha"; const CAPTCHA_KEY: &str = "ez-captcha";
/// Axum & Tower_Sessions /// Axum & Tower_Sessions
#[async_trait] #[async_trait]
@@ -32,7 +32,7 @@ pub trait CaptchaAxumTowerSessionStaticExt {
/// Verify the Captcha code, and return whether user's code is correct. /// Verify the Captcha code, and return whether user's code is correct.
async fn ver(code: &str, session: &Session) -> bool { async fn ver(code: &str, session: &Session) -> bool {
match session.get::<String>(CAPTCHA_KEY).await { match session.get::<String>(CAPTCHA_KEY).await {
Ok(Some(ans)) => ans.to_ascii_lowercase() == code.to_ascii_lowercase(), Ok(Some(ans)) => ans.eq_ignore_ascii_case(code),
_ => false, _ => false,
} }
} }

View File

@@ -1,7 +1,7 @@
pub mod axum_tower_sessions; pub mod axum_tower_sessions;
use super::base::captcha::AbstractCaptcha; use super::base::captcha::AbstractCaptcha;
use super::captcha::spec::SpecCaptcha; use super::builder::spec::SpecCaptcha;
use super::{CaptchaFont, NewCaptcha}; use super::{CaptchaFont, NewCaptcha};
/// 验证码工具类 - Captcha Utils /// 验证码工具类 - Captcha Utils

View File

@@ -117,7 +117,7 @@
#![allow(dead_code)] #![allow(dead_code)]
pub(crate) mod base; pub(crate) mod base;
pub mod captcha; pub mod builder;
pub mod extension; pub mod extension;
mod utils; mod utils;

View File

@@ -32,21 +32,24 @@ impl From<(u8, u8, u8)> for Color {
} }
} }
impl Into<(u8, u8, u8, u8)> for Color { impl From<Color> for (u8, u8, u8, u8) {
fn into(self) -> (u8, u8, u8, u8) { fn from(val: Color) -> Self {
( (
(self.0 * 255.0) as u8, (val.0 * 255.0) as u8,
(self.1 * 255.0) as u8, (val.1 * 255.0) as u8,
(self.2 * 255.0) as u8, (val.2 * 255.0) as u8,
(self.3 * 255.0) as u8, (val.3 * 255.0) as u8,
) )
} }
} }
impl Into<u32> for Color { impl From<Color> for u32 {
fn into(self) -> u32 { fn from(val: Color) -> Self {
let color: (u8, u8, u8, u8) = self.into(); let color: (u8, u8, u8, u8) = val.into();
(color.0 as u32) << 24 + (color.1 as u32) << 16 + (color.2 as u32) << 8 + (color.3 as u32) (color.0 as u32)
<< (24 + (color.1 as u32))
<< (16 + (color.2 as u32))
<< (8 + (color.3 as u32))
} }
} }

View File

@@ -11,7 +11,7 @@ struct FontAssets;
// pub(crate) static ref FONTS: RwLock<HashMap<String, Arc<Font>>> = Default::default(); // pub(crate) static ref FONTS: RwLock<HashMap<String, Arc<Font>>> = Default::default();
// } // }
pub fn get_font(font_name: &str) -> Option<Arc<Font>> { pub fn get_font(font_name: &'_ str) -> Option<Arc<Font<'_>>> {
// let fonts_cell = FONTS.get_or_init(|| Default::default()); // let fonts_cell = FONTS.get_or_init(|| Default::default());
// let guard = fonts_cell.read(); // let guard = fonts_cell.read();
// //
@@ -31,7 +31,7 @@ pub fn get_font(font_name: &str) -> Option<Arc<Font>> {
// } // }
} }
pub fn load_font(font_name: &str) -> Result<Option<Font>, Box<dyn Error>> { pub fn load_font(font_name: &'_ str) -> Result<Option<Font<'_>>, Box<dyn Error>> {
match FontAssets::get(font_name) { match FontAssets::get(font_name) {
Some(assets) => { Some(assets) => {
let font = Font::try_from_vec(Vec::from(assets.data)).unwrap(); let font = Font::try_from_vec(Vec::from(assets.data)).unwrap();

View File

@@ -143,7 +143,7 @@ impl RestfulServer {
return Err((StatusCode::UNAUTHORIZED, other_error("No such user").into())); return Err((StatusCode::UNAUTHORIZED, other_error("No such user").into()));
}; };
let machines = client_mgr.list_machine_by_user_id(user.id().clone()).await; let machines = client_mgr.list_machine_by_user_id(user.id()).await;
Ok(GetSummaryJsonResp { Ok(GetSummaryJsonResp {
device_count: machines.len() as u32, device_count: machines.len() as u32,

View File

@@ -8,7 +8,7 @@ use axum_login::AuthUser;
use easytier::launcher::NetworkConfig; use easytier::launcher::NetworkConfig;
use easytier::proto::common::Void; use easytier::proto::common::Void;
use easytier::proto::rpc_types::controller::BaseController; use easytier::proto::rpc_types::controller::BaseController;
use easytier::proto::web::*; use easytier::proto::{self, web::*};
use crate::client_manager::session::{Location, Session}; use crate::client_manager::session::{Location, Session};
use crate::client_manager::ClientManager; use crate::client_manager::ClientManager;
@@ -85,7 +85,7 @@ impl NetworkApi {
let Some(user_id) = auth_session.user.as_ref().map(|x| x.id()) else { let Some(user_id) = auth_session.user.as_ref().map(|x| x.id()) else {
return Err(( return Err((
StatusCode::UNAUTHORIZED, StatusCode::UNAUTHORIZED,
other_error(format!("No user id found")).into(), other_error("No user id found".to_string()).into(),
)); ));
}; };
Ok(user_id) Ok(user_id)
@@ -108,7 +108,7 @@ impl NetworkApi {
let Some(token) = result.get_token().await else { let Some(token) = result.get_token().await else {
return Err(( return Err((
StatusCode::UNAUTHORIZED, StatusCode::UNAUTHORIZED,
other_error(format!("No token reported")).into(), other_error("No token reported".to_string()).into(),
)); ));
}; };
@@ -120,7 +120,7 @@ impl NetworkApi {
{ {
return Err(( return Err((
StatusCode::FORBIDDEN, StatusCode::FORBIDDEN,
other_error(format!("Token mismatch")).into(), other_error("Token mismatch".to_string()).into(),
)); ));
} }
@@ -177,7 +177,7 @@ impl NetworkApi {
.insert_or_update_user_network_config( .insert_or_update_user_network_config(
auth_session.user.as_ref().unwrap().id(), auth_session.user.as_ref().unwrap().id(),
machine_id, machine_id,
resp.inst_id.clone().unwrap_or_default().into(), resp.inst_id.unwrap_or_default().into(),
serde_json::to_string(&config).unwrap(), serde_json::to_string(&config).unwrap(),
) )
.await .await
@@ -248,7 +248,7 @@ impl NetworkApi {
.await .await
.map_err(convert_rpc_error)?; .map_err(convert_rpc_error)?;
let running_inst_ids = ret.inst_ids.clone().into_iter().map(Into::into).collect(); let running_inst_ids = ret.inst_ids.clone().into_iter().collect();
// collect networks that are disabled // collect networks that are disabled
let disabled_inst_ids = client_mgr let disabled_inst_ids = client_mgr
@@ -261,7 +261,7 @@ impl NetworkApi {
.await .await
.map_err(convert_db_error)? .map_err(convert_db_error)?
.iter() .iter()
.filter_map(|x| x.network_instance_id.clone().try_into().ok()) .map(|x| Into::<proto::common::Uuid>::into(x.network_instance_id.clone()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Ok(ListNetworkInstanceIdsJsonResp { Ok(ListNetworkInstanceIdsJsonResp {
@@ -330,9 +330,8 @@ impl NetworkApi {
// not implement disable all // not implement disable all
return Err(( return Err((
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_IMPLEMENTED,
other_error(format!("Not implemented")).into(), other_error("Not implemented".to_string()).into(),
)) ));
.into();
}; };
let sess = Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?; let sess = Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;

View File

@@ -76,32 +76,32 @@ impl Backend {
pub async fn register_new_user(&self, new_user: &RegisterNewUser) -> anyhow::Result<()> { pub async fn register_new_user(&self, new_user: &RegisterNewUser) -> anyhow::Result<()> {
let hashed_password = password_auth::generate_hash(new_user.credentials.password.as_str()); let hashed_password = password_auth::generate_hash(new_user.credentials.password.as_str());
let mut txn = self.db.orm_db().begin().await?; let txn = self.db.orm_db().begin().await?;
entity::users::ActiveModel { entity::users::ActiveModel {
username: Set(new_user.credentials.username.clone()), username: Set(new_user.credentials.username.clone()),
password: Set(hashed_password.clone()), password: Set(hashed_password.clone()),
..Default::default() ..Default::default()
} }
.save(&mut txn) .save(&txn)
.await?; .await?;
entity::users_groups::ActiveModel { entity::users_groups::ActiveModel {
user_id: Set(entity::users::Entity::find() user_id: Set(entity::users::Entity::find()
.filter(entity::users::Column::Username.eq(new_user.credentials.username.as_str())) .filter(entity::users::Column::Username.eq(new_user.credentials.username.as_str()))
.one(&mut txn) .one(&txn)
.await? .await?
.unwrap() .unwrap()
.id), .id),
group_id: Set(entity::groups::Entity::find() group_id: Set(entity::groups::Entity::find()
.filter(entity::groups::Column::Name.eq("users")) .filter(entity::groups::Column::Name.eq("users"))
.one(&mut txn) .one(&txn)
.await? .await?
.unwrap() .unwrap()
.id), .id),
..Default::default() ..Default::default()
} }
.save(&mut txn) .save(&txn)
.await?; .await?;
txn.commit().await?; txn.commit().await?;

View File

@@ -52,9 +52,7 @@ pub fn build_router(api_host: Option<url::Url>) -> Router {
router router
}; };
let router = router.fallback_service(service); router.fallback_service(service)
router
} }
pub struct WebServer { pub struct WebServer {

View File

@@ -8,7 +8,7 @@ edition = "2021"
authors = ["kkrainbow"] authors = ["kkrainbow"]
keywords = ["vpn", "p2p", "network", "easytier"] keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"] categories = ["network-programming", "command-line-utilities"]
rust-version = "1.87.0" rust-version = "1.89.0"
license-file = "LICENSE" license-file = "LICENSE"
readme = "README.md" readme = "README.md"

View File

@@ -116,7 +116,7 @@ fn check_locale() {
if let Ok(globs) = globwalk::glob(locale_path) { if let Ok(globs) = globwalk::glob(locale_path) {
for entry in globs { for entry in globs {
if let Err(e) = entry { if let Err(e) = entry {
println!("cargo:i18n-error={}", e); println!("cargo:i18n-error={e}");
continue; continue;
} }
@@ -151,7 +151,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
]; ];
for proto_file in proto_files.iter().chain(proto_files_reflect.iter()) { for proto_file in proto_files.iter().chain(proto_files_reflect.iter()) {
println!("cargo:rerun-if-changed={}", proto_file); println!("cargo:rerun-if-changed={proto_file}");
} }
let mut config = prost_build::Config::new(); let mut config = prost_build::Config::new();
@@ -173,7 +173,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.field_attribute(".web.NetworkConfig", "#[serde(default)]") .field_attribute(".web.NetworkConfig", "#[serde(default)]")
.service_generator(Box::new(rpc_build::ServiceGenerator::new())) .service_generator(Box::new(rpc_build::ServiceGenerator::new()))
.btree_map(["."]) .btree_map(["."])
.skip_debug(&[".common.Ipv4Addr", ".common.Ipv6Addr", ".common.UUID"]); .skip_debug([".common.Ipv4Addr", ".common.Ipv6Addr", ".common.UUID"]);
config.compile_protos(&proto_files, &["src/proto/"])?; config.compile_protos(&proto_files, &["src/proto/"])?;

View File

@@ -178,6 +178,12 @@ impl AclLogContext {
} }
} }
pub type SharedState = (
Arc<DashMap<String, ConnTrackEntry>>,
Arc<DashMap<RateLimitKey, Arc<TokenBucket>>>,
Arc<DashMap<AclStatKey, u64>>,
);
// High-performance ACL processor - No more internal locks! // High-performance ACL processor - No more internal locks!
pub struct AclProcessor { pub struct AclProcessor {
// Immutable rule vectors - no locks needed since they're never modified after creation // Immutable rule vectors - no locks needed since they're never modified after creation
@@ -321,7 +327,7 @@ impl AclProcessor {
.rules .rules
.iter() .iter()
.filter(|rule| rule.enabled) .filter(|rule| rule.enabled)
.map(|rule| Self::convert_to_fast_lookup_rule(rule)) .map(Self::convert_to_fast_lookup_rule)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// Sort by priority (higher priority first) // Sort by priority (higher priority first)
@@ -422,7 +428,7 @@ impl AclProcessor {
self.inc_cache_entry_stats(cache_entry, packet_info); self.inc_cache_entry_stats(cache_entry, packet_info);
return cache_entry.acl_result.clone().unwrap(); cache_entry.acl_result.clone().unwrap()
} }
fn inc_cache_entry_stats(&self, cache_entry: &AclCacheEntry, packet_info: &PacketInfo) { fn inc_cache_entry_stats(&self, cache_entry: &AclCacheEntry, packet_info: &PacketInfo) {
@@ -539,7 +545,7 @@ impl AclProcessor {
cache_entry.rule_stats_vec.push(rule.rule_stats.clone()); cache_entry.rule_stats_vec.push(rule.rule_stats.clone());
cache_entry.matched_rule = RuleId::Priority(rule.priority); cache_entry.matched_rule = RuleId::Priority(rule.priority);
cache_entry.acl_result = Some(AclResult { cache_entry.acl_result = Some(AclResult {
action: rule.action.clone(), action: rule.action,
matched_rule: Some(RuleId::Priority(rule.priority)), matched_rule: Some(RuleId::Priority(rule.priority)),
should_log: false, should_log: false,
log_context: Some(AclLogContext::RuleMatch { log_context: Some(AclLogContext::RuleMatch {
@@ -595,13 +601,7 @@ impl AclProcessor {
} }
/// Get shared state for preserving across hot reloads /// Get shared state for preserving across hot reloads
pub fn get_shared_state( pub fn get_shared_state(&self) -> SharedState {
&self,
) -> (
Arc<DashMap<String, ConnTrackEntry>>,
Arc<DashMap<RateLimitKey, Arc<TokenBucket>>>,
Arc<DashMap<AclStatKey, u64>>,
) {
( (
self.conn_track.clone(), self.conn_track.clone(),
self.rate_limiters.clone(), self.rate_limiters.clone(),
@@ -698,9 +698,9 @@ impl AclProcessor {
} }
/// Check connection state for stateful rules /// Check connection state for stateful rules
fn check_connection_state(&self, conn_track_key: &String, packet_info: &PacketInfo) { fn check_connection_state(&self, conn_track_key: &str, packet_info: &PacketInfo) {
self.conn_track self.conn_track
.entry(conn_track_key.clone()) .entry(conn_track_key.to_string())
.and_modify(|x| { .and_modify(|x| {
x.last_seen = SystemTime::now() x.last_seen = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
@@ -764,13 +764,13 @@ impl AclProcessor {
let src_ip_ranges = rule let src_ip_ranges = rule
.source_ips .source_ips
.iter() .iter()
.filter_map(|ip_inet| Self::convert_ip_inet_to_cidr(ip_inet)) .filter_map(|x| Self::convert_ip_inet_to_cidr(x.as_str()))
.collect(); .collect();
let dst_ip_ranges = rule let dst_ip_ranges = rule
.destination_ips .destination_ips
.iter() .iter()
.filter_map(|ip_inet| Self::convert_ip_inet_to_cidr(ip_inet)) .filter_map(|x| Self::convert_ip_inet_to_cidr(x.as_str()))
.collect(); .collect();
let src_port_ranges = rule let src_port_ranges = rule
@@ -820,8 +820,8 @@ impl AclProcessor {
} }
/// Convert IpInet to CIDR for fast lookup /// Convert IpInet to CIDR for fast lookup
fn convert_ip_inet_to_cidr(input: &String) -> Option<cidr::IpCidr> { fn convert_ip_inet_to_cidr(input: &str) -> Option<cidr::IpCidr> {
cidr::IpCidr::from_str(input.as_str()).ok() cidr::IpCidr::from_str(input).ok()
} }
/// Increment statistics counter /// Increment statistics counter
@@ -898,17 +898,13 @@ impl AclProcessor {
} }
// 新增辅助函数 // 新增辅助函数
fn parse_port_start( fn parse_port_start(port_strs: &[String]) -> Option<u16> {
port_strs: &::prost::alloc::vec::Vec<::prost::alloc::string::String>,
) -> Option<u16> {
port_strs port_strs
.iter() .iter()
.filter_map(|s| parse_port_range(s).map(|(start, _)| start)) .filter_map(|s| parse_port_range(s).map(|(start, _)| start))
.min() .min()
} }
fn parse_port_end( fn parse_port_end(port_strs: &[String]) -> Option<u16> {
port_strs: &::prost::alloc::vec::Vec<::prost::alloc::string::String>,
) -> Option<u16> {
port_strs port_strs
.iter() .iter()
.filter_map(|s| parse_port_range(s).map(|(_, end)| end)) .filter_map(|s| parse_port_range(s).map(|(_, end)| end))
@@ -1154,18 +1150,22 @@ mod tests {
let mut acl_v1 = AclV1::default(); let mut acl_v1 = AclV1::default();
// Create inbound chain // Create inbound chain
let mut chain = Chain::default(); let mut chain = Chain {
chain.name = "test_inbound".to_string(); name: "test_inbound".to_string(),
chain.chain_type = ChainType::Inbound as i32; chain_type: ChainType::Inbound as i32,
chain.enabled = true; enabled: true,
..Default::default()
};
// Allow all rule // Allow all rule
let mut rule = Rule::default(); let rule = Rule {
rule.name = "allow_all".to_string(); name: "allow_all".to_string(),
rule.priority = 100; priority: 100,
rule.enabled = true; enabled: true,
rule.action = Action::Allow as i32; action: Action::Allow as i32,
rule.protocol = Protocol::Any as i32; protocol: Protocol::Any as i32,
..Default::default()
};
chain.rules.push(rule); chain.rules.push(rule);
acl_v1.chains.push(chain); acl_v1.chains.push(chain);
@@ -1278,12 +1278,14 @@ mod tests {
// 创建新配置(模拟热加载) // 创建新配置(模拟热加载)
let mut new_config = create_test_acl_config(); let mut new_config = create_test_acl_config();
if let Some(ref mut acl_v1) = new_config.acl_v1 { if let Some(ref mut acl_v1) = new_config.acl_v1 {
let mut drop_rule = Rule::default(); let drop_rule = Rule {
drop_rule.name = "drop_all".to_string(); name: "drop_all".to_string(),
drop_rule.priority = 200; priority: 200,
drop_rule.enabled = true; enabled: true,
drop_rule.action = Action::Drop as i32; action: Action::Drop as i32,
drop_rule.protocol = Protocol::Any as i32; protocol: Protocol::Any as i32,
..Default::default()
};
acl_v1.chains[0].rules.push(drop_rule); acl_v1.chains[0].rules.push(drop_rule);
} }
@@ -1321,40 +1323,48 @@ mod tests {
let mut acl_config = Acl::default(); let mut acl_config = Acl::default();
let mut acl_v1 = AclV1::default(); let mut acl_v1 = AclV1::default();
let mut chain = Chain::default(); let mut chain = Chain {
chain.name = "performance_test".to_string(); name: "performance_test".to_string(),
chain.chain_type = ChainType::Inbound as i32; chain_type: ChainType::Inbound as i32,
chain.enabled = true; enabled: true,
..Default::default()
};
// 1. High-priority simple rule for UDP (can be cached efficiently) // 1. High-priority simple rule for UDP (can be cached efficiently)
let mut simple_rule = Rule::default(); let simple_rule = Rule {
simple_rule.name = "simple_udp".to_string(); name: "simple_udp".to_string(),
simple_rule.priority = 300; priority: 300,
simple_rule.enabled = true; enabled: true,
simple_rule.action = Action::Allow as i32; action: Action::Allow as i32,
simple_rule.protocol = Protocol::Udp as i32; protocol: Protocol::Udp as i32,
..Default::default()
};
// No stateful or rate limit - can benefit from full cache optimization // No stateful or rate limit - can benefit from full cache optimization
chain.rules.push(simple_rule); chain.rules.push(simple_rule);
// 2. Medium-priority stateful + rate-limited rule for TCP (security critical) // 2. Medium-priority stateful + rate-limited rule for TCP (security critical)
let mut security_rule = Rule::default(); let security_rule = Rule {
security_rule.name = "security_tcp".to_string(); name: "security_tcp".to_string(),
security_rule.priority = 200; priority: 200,
security_rule.enabled = true; enabled: true,
security_rule.action = Action::Allow as i32; action: Action::Allow as i32,
security_rule.protocol = Protocol::Tcp as i32; protocol: Protocol::Tcp as i32,
security_rule.stateful = true; stateful: true,
security_rule.rate_limit = 100; // 100 packets/sec rate_limit: 100,
security_rule.burst_limit = 200; burst_limit: 200,
..Default::default()
};
chain.rules.push(security_rule); chain.rules.push(security_rule);
// 3. Low-priority default allow rule for Any // 3. Low-priority default allow rule for Any
let mut default_rule = Rule::default(); let default_rule = Rule {
default_rule.name = "default_allow".to_string(); name: "default_allow".to_string(),
default_rule.priority = 100; priority: 100,
default_rule.enabled = true; enabled: true,
default_rule.action = Action::Allow as i32; action: Action::Allow as i32,
default_rule.protocol = Protocol::Any as i32; protocol: Protocol::Any as i32,
..Default::default()
};
chain.rules.push(default_rule); chain.rules.push(default_rule);
acl_v1.chains.push(chain); acl_v1.chains.push(chain);
@@ -1441,15 +1451,16 @@ mod tests {
// Create a very restrictive rate-limited rule // Create a very restrictive rate-limited rule
if let Some(ref mut acl_v1) = acl_config.acl_v1 { if let Some(ref mut acl_v1) = acl_config.acl_v1 {
let mut rule = Rule::default(); let rule = Rule {
rule.name = "strict_rate_limit".to_string(); name: "strict_rate_limit".to_string(),
rule.priority = 200; priority: 200,
rule.enabled = true; enabled: true,
rule.action = Action::Allow as i32; action: Action::Allow as i32,
rule.protocol = Protocol::Any as i32; protocol: Protocol::Any as i32,
rule.rate_limit = 1; // Allow only 1 packet per second rate_limit: 1, // Allow only 1 packet per second
rule.burst_limit = 1; // Burst of 1 packet burst_limit: 1, // Burst of 1 packet
..Default::default()
};
acl_v1.chains[0].rules.push(rule); acl_v1.chains[0].rules.push(rule);
} }

View File

@@ -21,6 +21,12 @@ pub trait Compressor {
pub struct DefaultCompressor {} pub struct DefaultCompressor {}
impl Default for DefaultCompressor {
fn default() -> Self {
Self::new()
}
}
impl DefaultCompressor { impl DefaultCompressor {
pub fn new() -> Self { pub fn new() -> Self {
DefaultCompressor {} DefaultCompressor {}
@@ -195,11 +201,11 @@ pub mod tests {
packet, packet,
packet.payload_len() packet.payload_len()
); );
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true); assert!(packet.peer_manager_header().unwrap().is_compressed());
compressor.decompress(&mut packet).await.unwrap(); compressor.decompress(&mut packet).await.unwrap();
assert_eq!(packet.payload(), text); assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false); assert!(!packet.peer_manager_header().unwrap().is_compressed());
} }
#[tokio::test] #[tokio::test]
@@ -215,10 +221,10 @@ pub mod tests {
.compress(&mut packet, CompressorAlgo::ZstdDefault) .compress(&mut packet, CompressorAlgo::ZstdDefault)
.await .await
.unwrap(); .unwrap();
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false); assert!(!packet.peer_manager_header().unwrap().is_compressed());
compressor.decompress(&mut packet).await.unwrap(); compressor.decompress(&mut packet).await.unwrap();
assert_eq!(packet.payload(), text); assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false); assert!(!packet.peer_manager_header().unwrap().is_compressed());
} }
} }

View File

@@ -1,8 +1,8 @@
use std::{ use std::{
hash::Hasher,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
path::PathBuf, path::PathBuf,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
u64,
}; };
use anyhow::Context; use anyhow::Context;
@@ -48,7 +48,7 @@ pub fn gen_default_flags() -> Flags {
disable_quic_input: false, disable_quic_input: false,
foreign_relay_bps_limit: u64::MAX, foreign_relay_bps_limit: u64::MAX,
multi_thread_count: 2, multi_thread_count: 2,
encryption_algorithm: "".to_string(), // 空字符串表示使用默认的 AES-GCM encryption_algorithm: "aes-gcm".to_string(),
} }
} }
@@ -210,7 +210,7 @@ pub trait LoggingConfigLoader {
pub type NetworkSecretDigest = [u8; 32]; pub type NetworkSecretDigest = [u8; 32];
#[derive(Debug, Clone, Deserialize, Serialize, Default, Eq, Hash)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NetworkIdentity { pub struct NetworkIdentity {
pub network_name: String, pub network_name: String,
pub network_secret: Option<String>, pub network_secret: Option<String>,
@@ -218,27 +218,53 @@ pub struct NetworkIdentity {
pub network_secret_digest: Option<NetworkSecretDigest>, pub network_secret_digest: Option<NetworkSecretDigest>,
} }
#[derive(Eq, PartialEq, Hash)]
struct NetworkIdentityWithOnlyDigest {
network_name: String,
network_secret_digest: Option<NetworkSecretDigest>,
}
impl From<NetworkIdentity> for NetworkIdentityWithOnlyDigest {
fn from(identity: NetworkIdentity) -> Self {
if identity.network_secret_digest.is_some() {
Self {
network_name: identity.network_name,
network_secret_digest: identity.network_secret_digest,
}
} else if identity.network_secret.is_some() {
let mut network_secret_digest = [0u8; 32];
generate_digest_from_str(
&identity.network_name,
identity.network_secret.as_ref().unwrap(),
&mut network_secret_digest,
);
Self {
network_name: identity.network_name,
network_secret_digest: Some(network_secret_digest),
}
} else {
Self {
network_name: identity.network_name,
network_secret_digest: None,
}
}
}
}
impl PartialEq for NetworkIdentity { impl PartialEq for NetworkIdentity {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.network_name != other.network_name { let self_with_digest = NetworkIdentityWithOnlyDigest::from(self.clone());
return false; let other_with_digest = NetworkIdentityWithOnlyDigest::from(other.clone());
} self_with_digest == other_with_digest
}
}
if self.network_secret.is_some() impl Eq for NetworkIdentity {}
&& other.network_secret.is_some()
&& self.network_secret != other.network_secret
{
return false;
}
if self.network_secret_digest.is_some() impl std::hash::Hash for NetworkIdentity {
&& other.network_secret_digest.is_some() fn hash<H: Hasher>(&self, state: &mut H) {
&& self.network_secret_digest != other.network_secret_digest let self_with_digest = NetworkIdentityWithOnlyDigest::from(self.clone());
{ self_with_digest.hash(state);
return false;
}
return true;
} }
} }
@@ -253,8 +279,10 @@ impl NetworkIdentity {
network_secret_digest: Some(network_secret_digest), network_secret_digest: Some(network_secret_digest),
} }
} }
}
pub fn default() -> Self { impl Default for NetworkIdentity {
fn default() -> Self {
Self::new("default".to_string(), "".to_string()) Self::new("default".to_string(), "".to_string())
} }
} }
@@ -328,12 +356,12 @@ impl From<PortForwardConfigPb> for PortForwardConfig {
} }
} }
impl Into<PortForwardConfigPb> for PortForwardConfig { impl From<PortForwardConfig> for PortForwardConfigPb {
fn into(self) -> PortForwardConfigPb { fn from(val: PortForwardConfig) -> Self {
PortForwardConfigPb { PortForwardConfigPb {
bind_addr: Some(self.bind_addr.into()), bind_addr: Some(val.bind_addr.into()),
dst_addr: Some(self.dst_addr.into()), dst_addr: Some(val.dst_addr.into()),
socket_type: match self.proto.to_lowercase().as_str() { socket_type: match val.proto.to_lowercase().as_str() {
"tcp" => SocketType::Tcp as i32, "tcp" => SocketType::Tcp as i32,
"udp" => SocketType::Udp as i32, "udp" => SocketType::Udp as i32,
_ => SocketType::Tcp as i32, _ => SocketType::Tcp as i32,
@@ -493,8 +521,7 @@ impl ConfigLoader for TomlConfigLoader {
locked_config locked_config
.ipv4 .ipv4
.as_ref() .as_ref()
.map(|s| s.parse().ok()) .and_then(|s| s.parse().ok())
.flatten()
.map(|c: cidr::Ipv4Inet| { .map(|c: cidr::Ipv4Inet| {
if c.network_length() == 32 { if c.network_length() == 32 {
cidr::Ipv4Inet::new(c.address(), 24).unwrap() cidr::Ipv4Inet::new(c.address(), 24).unwrap()
@@ -505,28 +532,16 @@ impl ConfigLoader for TomlConfigLoader {
} }
fn set_ipv4(&self, addr: Option<cidr::Ipv4Inet>) { fn set_ipv4(&self, addr: Option<cidr::Ipv4Inet>) {
self.config.lock().unwrap().ipv4 = if let Some(addr) = addr { self.config.lock().unwrap().ipv4 = addr.map(|addr| addr.to_string());
Some(addr.to_string())
} else {
None
};
} }
fn get_ipv6(&self) -> Option<cidr::Ipv6Inet> { fn get_ipv6(&self) -> Option<cidr::Ipv6Inet> {
let locked_config = self.config.lock().unwrap(); let locked_config = self.config.lock().unwrap();
locked_config locked_config.ipv6.as_ref().and_then(|s| s.parse().ok())
.ipv6
.as_ref()
.map(|s| s.parse().ok())
.flatten()
} }
fn set_ipv6(&self, addr: Option<cidr::Ipv6Inet>) { fn set_ipv6(&self, addr: Option<cidr::Ipv6Inet>) {
self.config.lock().unwrap().ipv6 = if let Some(addr) = addr { self.config.lock().unwrap().ipv6 = addr.map(|addr| addr.to_string());
Some(addr.to_string())
} else {
None
};
} }
fn get_dhcp(&self) -> bool { fn get_dhcp(&self) -> bool {
@@ -600,7 +615,7 @@ impl ConfigLoader for TomlConfigLoader {
locked_config.instance_id = Some(id); locked_config.instance_id = Some(id);
id id
} else { } else {
locked_config.instance_id.as_ref().unwrap().clone() *locked_config.instance_id.as_ref().unwrap()
} }
} }
@@ -614,7 +629,7 @@ impl ConfigLoader for TomlConfigLoader {
.unwrap() .unwrap()
.network_identity .network_identity
.clone() .clone()
.unwrap_or_else(NetworkIdentity::default) .unwrap_or_default()
} }
fn set_network_identity(&self, identity: NetworkIdentity) { fn set_network_identity(&self, identity: NetworkIdentity) {

View File

@@ -8,14 +8,14 @@ macro_rules! define_global_var {
#[macro_export] #[macro_export]
macro_rules! use_global_var { macro_rules! use_global_var {
($name:ident) => { ($name:ident) => {
crate::common::constants::$name.lock().unwrap().to_owned() $crate::common::constants::$name.lock().unwrap().to_owned()
}; };
} }
#[macro_export] #[macro_export]
macro_rules! set_global_var { macro_rules! set_global_var {
($name:ident, $val:expr) => { ($name:ident, $val:expr) => {
*crate::common::constants::$name.lock().unwrap() = $val *$crate::common::constants::$name.lock().unwrap() = $val
}; };
} }

View File

@@ -12,7 +12,9 @@ impl<F: FnOnce()> Defer<F> {
impl<F: FnOnce()> Drop for Defer<F> { impl<F: FnOnce()> Drop for Defer<F> {
fn drop(&mut self) { fn drop(&mut self) {
self.func.take().map(|f| f()); if let Some(f) = self.func.take() {
f()
}
} }
} }

View File

@@ -48,19 +48,15 @@ pub static RESOLVER: Lazy<Arc<Resolver<GenericConnector<TokioRuntimeProvider>>>>
pub async fn resolve_txt_record(domain_name: &str) -> Result<String, Error> { pub async fn resolve_txt_record(domain_name: &str) -> Result<String, Error> {
let r = RESOLVER.clone(); let r = RESOLVER.clone();
let response = r.txt_lookup(domain_name).await.with_context(|| { let response = r
format!( .txt_lookup(domain_name)
"txt_lookup failed, domain_name: {}", .await
domain_name.to_string() .with_context(|| format!("txt_lookup failed, domain_name: {}", domain_name))?;
)
})?;
let txt_record = response.iter().next().with_context(|| { let txt_record = response
format!( .iter()
"no txt record found, domain_name: {}", .next()
domain_name.to_string() .with_context(|| format!("no txt record found, domain_name: {}", domain_name))?;
)
})?;
let txt_data = String::from_utf8_lossy(&txt_record.txt_data()[0]); let txt_data = String::from_utf8_lossy(&txt_record.txt_data()[0]);
tracing::info!(?txt_data, ?domain_name, "get txt record"); tracing::info!(?txt_data, ?domain_name, "get txt record");

View File

@@ -104,7 +104,7 @@ impl std::fmt::Debug for GlobalCtx {
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>; pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
impl GlobalCtx { impl GlobalCtx {
pub fn new(config_fs: impl ConfigLoader + 'static + Send + Sync) -> Self { pub fn new(config_fs: impl ConfigLoader + 'static) -> Self {
let id = config_fs.get_id(); let id = config_fs.get_id();
let network = config_fs.get_network_identity(); let network = config_fs.get_network_identity();
let net_ns = NetNS::new(config_fs.get_netns()); let net_ns = NetNS::new(config_fs.get_netns());
@@ -118,9 +118,11 @@ impl GlobalCtx {
let proxy_forward_by_system = config_fs.get_flags().proxy_forward_by_system; let proxy_forward_by_system = config_fs.get_flags().proxy_forward_by_system;
let no_tun = config_fs.get_flags().no_tun; let no_tun = config_fs.get_flags().no_tun;
let mut feature_flags = PeerFeatureFlag::default(); let feature_flags = PeerFeatureFlag {
feature_flags.kcp_input = !config_fs.get_flags().disable_kcp_input; kcp_input: !config_fs.get_flags().disable_kcp_input,
feature_flags.no_relay_kcp = config_fs.get_flags().disable_relay_kcp; no_relay_kcp: config_fs.get_flags().disable_relay_kcp,
..Default::default()
};
GlobalCtx { GlobalCtx {
inst_name: config_fs.get_inst_name(), inst_name: config_fs.get_inst_name(),
@@ -185,7 +187,7 @@ impl GlobalCtx {
{ {
Ok(()) Ok(())
} else { } else {
Err(anyhow::anyhow!("network {} not in whitelist", network_name).into()) Err(anyhow::anyhow!("network {} not in whitelist", network_name))
} }
} }
@@ -194,8 +196,8 @@ impl GlobalCtx {
return Some(ret); return Some(ret);
} }
let addr = self.config.get_ipv4(); let addr = self.config.get_ipv4();
self.cached_ipv4.store(addr.clone()); self.cached_ipv4.store(addr);
return addr; addr
} }
pub fn set_ipv4(&self, addr: Option<cidr::Ipv4Inet>) { pub fn set_ipv4(&self, addr: Option<cidr::Ipv4Inet>) {
@@ -208,8 +210,8 @@ impl GlobalCtx {
return Some(ret); return Some(ret);
} }
let addr = self.config.get_ipv6(); let addr = self.config.get_ipv6();
self.cached_ipv6.store(addr.clone()); self.cached_ipv6.store(addr);
return addr; addr
} }
pub fn set_ipv6(&self, addr: Option<cidr::Ipv6Inet>) { pub fn set_ipv6(&self, addr: Option<cidr::Ipv6Inet>) {
@@ -376,18 +378,18 @@ pub mod tests {
let mut subscriber = global_ctx.subscribe(); let mut subscriber = global_ctx.subscribe();
let peer_id = new_peer_id(); let peer_id = new_peer_id();
global_ctx.issue_event(GlobalCtxEvent::PeerAdded(peer_id.clone())); global_ctx.issue_event(GlobalCtxEvent::PeerAdded(peer_id));
global_ctx.issue_event(GlobalCtxEvent::PeerRemoved(peer_id.clone())); global_ctx.issue_event(GlobalCtxEvent::PeerRemoved(peer_id));
global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default())); global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default()));
global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default())); global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default()));
assert_eq!( assert_eq!(
subscriber.recv().await.unwrap(), subscriber.recv().await.unwrap(),
GlobalCtxEvent::PeerAdded(peer_id.clone()) GlobalCtxEvent::PeerAdded(peer_id)
); );
assert_eq!( assert_eq!(
subscriber.recv().await.unwrap(), subscriber.recv().await.unwrap(),
GlobalCtxEvent::PeerRemoved(peer_id.clone()) GlobalCtxEvent::PeerRemoved(peer_id)
); );
assert_eq!( assert_eq!(
subscriber.recv().await.unwrap(), subscriber.recv().await.unwrap(),
@@ -404,7 +406,7 @@ pub mod tests {
) -> ArcGlobalCtx { ) -> ArcGlobalCtx {
let config_fs = TomlConfigLoader::default(); let config_fs = TomlConfigLoader::default();
config_fs.set_inst_name(format!("test_{}", config_fs.get_id())); config_fs.set_inst_name(format!("test_{}", config_fs.get_id()));
config_fs.set_network_identity(network_identy.unwrap_or(NetworkIdentity::default())); config_fs.set_network_identity(network_identy.unwrap_or_default());
let ctx = Arc::new(GlobalCtx::new(config_fs)); let ctx = Arc::new(GlobalCtx::new(config_fs));
ctx.replace_stun_info_collector(Box::new(MockStunInfoCollector { ctx.replace_stun_info_collector(Box::new(MockStunInfoCollector {

View File

@@ -1,6 +1,6 @@
#[cfg(any(target_os = "macos", target_os = "freebsd"))] #[cfg(any(target_os = "macos", target_os = "freebsd"))]
mod darwin; mod darwin;
#[cfg(any(target_os = "linux"))] #[cfg(target_os = "linux")]
mod netlink; mod netlink;
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
mod win; mod win;
@@ -141,7 +141,7 @@ pub struct DummyIfConfiger {}
#[async_trait] #[async_trait]
impl IfConfiguerTrait for DummyIfConfiger {} impl IfConfiguerTrait for DummyIfConfiger {}
#[cfg(any(target_os = "linux"))] #[cfg(target_os = "linux")]
pub type IfConfiger = netlink::NetlinkIfConfiger; pub type IfConfiger = netlink::NetlinkIfConfiger;
#[cfg(any(target_os = "macos", target_os = "freebsd"))] #[cfg(any(target_os = "macos", target_os = "freebsd"))]

View File

@@ -85,14 +85,14 @@ fn send_netlink_req_and_wait_one_resp<T: NetlinkDeserializable + NetlinkSerializ
match ret.payload { match ret.payload {
NetlinkPayload::Error(e) => { NetlinkPayload::Error(e) => {
if e.code == NonZero::new(0) { if e.code == NonZero::new(0) {
return Ok(()); Ok(())
} else { } else {
return Err(e.to_io().into()); Err(e.to_io().into())
} }
} }
p => { p => {
tracing::error!("Unexpected netlink response: {:?}", p); tracing::error!("Unexpected netlink response: {:?}", p);
return Err(anyhow::anyhow!("Unexpected netlink response").into()); Err(anyhow::anyhow!("Unexpected netlink response").into())
} }
} }
} }
@@ -263,8 +263,8 @@ impl NetlinkIfConfiger {
let (address, netmask) = match (address.family(), netmask.family()) { let (address, netmask) = match (address.family(), netmask.family()) {
(Some(Inet), Some(Inet)) => ( (Some(Inet), Some(Inet)) => (
IpAddr::V4(address.as_sockaddr_in().unwrap().ip().into()), IpAddr::V4(address.as_sockaddr_in().unwrap().ip()),
IpAddr::V4(netmask.as_sockaddr_in().unwrap().ip().into()), IpAddr::V4(netmask.as_sockaddr_in().unwrap().ip()),
), ),
(Some(Inet6), Some(Inet6)) => ( (Some(Inet6), Some(Inet6)) => (
IpAddr::V6(address.as_sockaddr_in6().unwrap().ip()), IpAddr::V6(address.as_sockaddr_in6().unwrap().ip()),
@@ -333,7 +333,7 @@ impl NetlinkIfConfiger {
let mut resp = Vec::<u8>::new(); let mut resp = Vec::<u8>::new();
loop { loop {
if resp.len() == 0 { if resp.is_empty() {
let (new_resp, _) = s.recv_from_full()?; let (new_resp, _) = s.recv_from_full()?;
resp = new_resp; resp = new_resp;
} }

View File

@@ -727,7 +727,7 @@ impl InterfaceLuid {
if family == (AF_INET6 as ADDRESS_FAMILY) { if family == (AF_INET6 as ADDRESS_FAMILY) {
// ipv6 mtu must be at least 1280 // ipv6 mtu must be at least 1280
mtu = 1280.max(mtu); mtu = 1280.max(mtu);
} }
// https://stackoverflow.com/questions/54857292/setipinterfaceentry-returns-error-invalid-parameter // https://stackoverflow.com/questions/54857292/setipinterfaceentry-returns-error-invalid-parameter
row.SitePrefixLength = 0; row.SitePrefixLength = 0;

View File

@@ -1,3 +1,3 @@
pub mod luid;
pub mod netsh; pub mod netsh;
pub mod types; pub mod types;
pub mod luid;

View File

@@ -115,4 +115,4 @@ pub fn add_dns_ipv6(if_index: u32, dnses: &[Ipv6Addr]) -> Result<(), String> {
} }
let dnses_str: Vec<String> = dnses.iter().map(|addr| addr.to_string()).collect(); let dnses_str: Vec<String> = dnses.iter().map(|addr| addr.to_string()).collect();
add_dns(AF_INET6 as _, if_index, &dnses_str) add_dns(AF_INET6 as _, if_index, &dnses_str)
} }

View File

@@ -100,4 +100,4 @@ pub fn u16_ptr_to_string(ptr: *const u16) -> String {
let slice = unsafe { std::slice::from_raw_parts(ptr, len) }; let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
String::from_utf16_lossy(slice) String::from_utf16_lossy(slice)
} }

View File

@@ -140,8 +140,8 @@ pub fn get_machine_id() -> uuid::Uuid {
)))] )))]
let gen_mid = None; let gen_mid = None;
if gen_mid.is_some() { if let Some(mid) = gen_mid {
return gen_mid.unwrap(); return mid;
} }
let gen_mid = uuid::Uuid::new_v4(); let gen_mid = uuid::Uuid::new_v4();

View File

@@ -34,13 +34,12 @@ impl NetNSGuard {
return; return;
} }
let ns_path: String;
let name = name.unwrap(); let name = name.unwrap();
if name == ROOT_NETNS_NAME { let ns_path: String = if name == ROOT_NETNS_NAME {
ns_path = "/proc/1/ns/net".to_string(); "/proc/1/ns/net".to_string()
} else { } else {
ns_path = format!("/var/run/netns/{}", name); format!("/var/run/netns/{}", name)
} };
let ns = std::fs::File::open(ns_path).unwrap(); let ns = std::fs::File::open(ns_path).unwrap();
tracing::info!( tracing::info!(

View File

@@ -211,7 +211,7 @@ impl IPCollector {
cached_ip_list.read().await.public_ipv6 cached_ip_list.read().await.public_ipv6
); );
let sleep_sec = if !cached_ip_list.read().await.public_ipv4.is_none() { let sleep_sec = if cached_ip_list.read().await.public_ipv4.is_some() {
CACHED_IP_LIST_TIMEOUT_SEC CACHED_IP_LIST_TIMEOUT_SEC
} else { } else {
3 3
@@ -252,14 +252,11 @@ impl IPCollector {
for iface in ifaces { for iface in ifaces {
for ip in iface.ips { for ip in iface.ips {
let ip: std::net::IpAddr = ip.ip(); let ip: std::net::IpAddr = ip.ip();
match ip { if let std::net::IpAddr::V4(v4) = ip {
std::net::IpAddr::V4(v4) => { if ip.is_loopback() || ip.is_multicast() {
if ip.is_loopback() || ip.is_multicast() { continue;
continue;
}
ret.interface_ipv4s.push(v4.into());
} }
_ => {} ret.interface_ipv4s.push(v4.into());
} }
} }
} }
@@ -269,14 +266,11 @@ impl IPCollector {
for iface in ifaces { for iface in ifaces {
for ip in iface.ips { for ip in iface.ips {
let ip: std::net::IpAddr = ip.ip(); let ip: std::net::IpAddr = ip.ip();
match ip { if let std::net::IpAddr::V6(v6) = ip {
std::net::IpAddr::V6(v6) => { if v6.is_multicast() || v6.is_loopback() || v6.is_unicast_link_local() {
if v6.is_multicast() || v6.is_loopback() || v6.is_unicast_link_local() { continue;
continue;
}
ret.interface_ipv6s.push(v6.into());
} }
_ => {} ret.interface_ipv6s.push(v6.into());
} }
} }
} }

View File

@@ -277,6 +277,12 @@ pub struct UnsafeCounter {
value: UnsafeCell<u64>, value: UnsafeCell<u64>,
} }
impl Default for UnsafeCounter {
fn default() -> Self {
Self::new()
}
}
impl UnsafeCounter { impl UnsafeCounter {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@@ -396,14 +402,15 @@ impl MetricKey {
fn new(name: MetricName, labels: LabelSet) -> Self { fn new(name: MetricName, labels: LabelSet) -> Self {
Self { name, labels } Self { name, labels }
} }
}
/// Generate a string representation for this metric key impl fmt::Display for MetricKey {
fn to_string(&self) -> String { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let label_str = self.labels.to_key(); let label_str = self.labels.to_key();
if label_str.is_empty() { if label_str.is_empty() {
self.name.to_string() f.write_str(self.name.to_string().as_str())
} else { } else {
format!("{}[{}]", self.name, label_str) f.write_str(format!("{}[{}]", self.name, label_str).as_str())
} }
} }
} }

View File

@@ -282,9 +282,7 @@ impl StunClient {
.with_context(|| "encode stun message")?; .with_context(|| "encode stun message")?;
tids.push(tid); tids.push(tid);
tracing::trace!(?message, ?msg, tid, "send stun request"); tracing::trace!(?message, ?msg, tid, "send stun request");
self.socket self.socket.send_to(msg.as_slice(), &stun_host).await?;
.send_to(msg.as_slice().into(), &stun_host)
.await?;
} }
let now = Instant::now(); let now = Instant::now();
@@ -372,7 +370,7 @@ impl StunClientBuilder {
pub async fn stop(&mut self) { pub async fn stop(&mut self) {
self.task_set.abort_all(); self.task_set.abort_all();
while let Some(_) = self.task_set.join_next().await {} while self.task_set.join_next().await.is_some() {}
} }
} }
@@ -417,7 +415,7 @@ impl UdpNatTypeDetectResult {
return true; return true;
} }
} }
return false; false
} }
fn is_pat(&self) -> bool { fn is_pat(&self) -> bool {
@@ -457,16 +455,16 @@ impl UdpNatTypeDetectResult {
if self.is_cone() { if self.is_cone() {
if self.has_ip_changed_resp() { if self.has_ip_changed_resp() {
if self.is_open_internet() { if self.is_open_internet() {
return NatType::OpenInternet; NatType::OpenInternet
} else if self.is_pat() { } else if self.is_pat() {
return NatType::NoPat; NatType::NoPat
} else { } else {
return NatType::FullCone; NatType::FullCone
} }
} else if self.has_port_changed_resp() { } else if self.has_port_changed_resp() {
return NatType::Restricted; NatType::Restricted
} else { } else {
return NatType::PortRestricted; NatType::PortRestricted
} }
} else if !self.stun_resps.is_empty() { } else if !self.stun_resps.is_empty() {
if self.public_ips().len() != 1 if self.public_ips().len() != 1
@@ -480,7 +478,7 @@ impl UdpNatTypeDetectResult {
.mapped_socket_addr .mapped_socket_addr
.is_none() .is_none()
{ {
return NatType::Symmetric; NatType::Symmetric
} else { } else {
let extra_bind_test = self.extra_bind_test.as_ref().unwrap(); let extra_bind_test = self.extra_bind_test.as_ref().unwrap();
let extra_port = extra_bind_test.mapped_socket_addr.unwrap().port(); let extra_port = extra_bind_test.mapped_socket_addr.unwrap().port();
@@ -488,15 +486,15 @@ impl UdpNatTypeDetectResult {
let max_port_diff = extra_port.saturating_sub(self.max_port()); let max_port_diff = extra_port.saturating_sub(self.max_port());
let min_port_diff = self.min_port().saturating_sub(extra_port); let min_port_diff = self.min_port().saturating_sub(extra_port);
if max_port_diff != 0 && max_port_diff < 100 { if max_port_diff != 0 && max_port_diff < 100 {
return NatType::SymmetricEasyInc; NatType::SymmetricEasyInc
} else if min_port_diff != 0 && min_port_diff < 100 { } else if min_port_diff != 0 && min_port_diff < 100 {
return NatType::SymmetricEasyDec; NatType::SymmetricEasyDec
} else { } else {
return NatType::Symmetric; NatType::Symmetric
} }
} }
} else { } else {
return NatType::Unknown; NatType::Unknown
} }
} }
@@ -679,7 +677,7 @@ impl StunInfoCollectorTrait for StunInfoCollector {
.unwrap() .unwrap()
.clone() .clone()
.map(|x| x.collect_available_stun_server()) .map(|x| x.collect_available_stun_server())
.unwrap_or(vec![]); .unwrap_or_default();
if stun_servers.is_empty() { if stun_servers.is_empty() {
let mut host_resolver = let mut host_resolver =
@@ -740,7 +738,7 @@ impl StunInfoCollector {
pub fn get_default_servers() -> Vec<String> { pub fn get_default_servers() -> Vec<String> {
// NOTICE: we may need to choose stun stun server based on geo location // NOTICE: we may need to choose stun stun server based on geo location
// stun server cross nation may return a external ip address with high latency and loss rate // stun server cross nation may return a external ip address with high latency and loss rate
vec![ [
"txt:stun.easytier.cn", "txt:stun.easytier.cn",
"stun.miwifi.com", "stun.miwifi.com",
"stun.chat.bilibili.com", "stun.chat.bilibili.com",
@@ -752,16 +750,16 @@ impl StunInfoCollector {
} }
pub fn get_default_servers_v6() -> Vec<String> { pub fn get_default_servers_v6() -> Vec<String> {
vec!["txt:stun-v6.easytier.cn"] ["txt:stun-v6.easytier.cn"]
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect() .collect()
} }
async fn get_public_ipv6(servers: &Vec<String>) -> Option<Ipv6Addr> { async fn get_public_ipv6(servers: &[String]) -> Option<Ipv6Addr> {
let mut ips = HostResolverIter::new(servers.to_vec(), 10, true); let mut ips = HostResolverIter::new(servers.to_vec(), 10, true);
while let Some(ip) = ips.next().await { while let Some(ip) = ips.next().await {
let Ok(udp_socket) = UdpSocket::bind(format!("[::]:0")).await else { let Ok(udp_socket) = UdpSocket::bind("[::]:0".to_string()).await else {
break; break;
}; };
let udp = Arc::new(udp_socket); let udp = Arc::new(udp_socket);
@@ -770,11 +768,8 @@ impl StunInfoCollector {
.bind_request(false, false) .bind_request(false, false)
.await; .await;
tracing::debug!(?ret, "finish ipv6 udp nat type detect"); tracing::debug!(?ret, "finish ipv6 udp nat type detect");
match ret.map(|x| x.mapped_socket_addr.map(|x| x.ip())) { if let Ok(Some(IpAddr::V6(v6))) = ret.map(|x| x.mapped_socket_addr.map(|x| x.ip())) {
Ok(Some(IpAddr::V6(v6))) => { return Some(v6);
return Some(v6);
}
_ => {}
} }
} }
None None
@@ -854,9 +849,9 @@ impl StunInfoCollector {
self.tasks.lock().unwrap().spawn(async move { self.tasks.lock().unwrap().spawn(async move {
loop { loop {
let servers = stun_servers.read().unwrap().clone(); let servers = stun_servers.read().unwrap().clone();
Self::get_public_ipv6(&servers) if let Some(x) = Self::get_public_ipv6(&servers).await {
.await stored_ipv6.store(Some(x))
.map(|x| stored_ipv6.store(Some(x))); }
let sleep_sec = if stored_ipv6.load().is_none() { let sleep_sec = if stored_ipv6.load().is_none() {
60 60

View File

@@ -34,7 +34,7 @@ impl From<LimiterConfig> for BucketConfig {
.unwrap_or(Duration::from_millis(10)); .unwrap_or(Duration::from_millis(10));
BucketConfig { BucketConfig {
capacity: burst_rate * fill_rate, capacity: burst_rate * fill_rate,
fill_rate: fill_rate, fill_rate,
refill_interval, refill_interval,
} }
} }
@@ -162,6 +162,12 @@ pub struct TokenBucketManager {
retain_task: ScopedTask<()>, retain_task: ScopedTask<()>,
} }
impl Default for TokenBucketManager {
fn default() -> Self {
Self::new()
}
}
impl TokenBucketManager { impl TokenBucketManager {
/// Creates a new TokenBucketManager /// Creates a new TokenBucketManager
pub fn new() -> Self { pub fn new() -> Self {
@@ -318,7 +324,7 @@ mod tests {
// Should have accumulated about 100 tokens (10,000 tokens/s * 0.001s) // Should have accumulated about 100 tokens (10,000 tokens/s * 0.001s)
let tokens = bucket.available_tokens.load(Ordering::Relaxed); let tokens = bucket.available_tokens.load(Ordering::Relaxed);
assert!( assert!(
tokens >= 100 && tokens <= 200, (100..=200).contains(&tokens),
"Unexpected token count: {}", "Unexpected token count: {}",
tokens tokens
); );
@@ -355,8 +361,7 @@ mod tests {
.list_foreign_networks() .list_foreign_networks()
.await .await
.foreign_networks .foreign_networks
.len() .is_empty()
== 0
}, },
Duration::from_secs(5), Duration::from_secs(5),
) )
@@ -370,8 +375,7 @@ mod tests {
.get_global_ctx() .get_global_ctx()
.token_bucket_manager() .token_bucket_manager()
.buckets .buckets
.len() .is_empty()
== 0
}, },
Duration::from_secs(10), Duration::from_secs(10),
) )

View File

@@ -180,16 +180,13 @@ impl DirectConnectorManagerData {
// ask remote to send v6 hole punch packet // ask remote to send v6 hole punch packet
// and no matter what the result is, continue to connect // and no matter what the result is, continue to connect
let _ = self let _ = self
.remote_send_v6_hole_punch_packet(dst_peer_id, &local_socket, &remote_url) .remote_send_v6_hole_punch_packet(dst_peer_id, &local_socket, remote_url)
.await; .await;
let udp_connector = UdpTunnelConnector::new(remote_url.clone()); let udp_connector = UdpTunnelConnector::new(remote_url.clone());
let remote_addr = super::check_scheme_and_get_socket_addr::<SocketAddr>( let remote_addr =
&remote_url, super::check_scheme_and_get_socket_addr::<SocketAddr>(remote_url, "udp", IpVersion::V6)
"udp", .await?;
IpVersion::V6,
)
.await?;
let ret = udp_connector let ret = udp_connector
.try_connect_with_socket(local_socket, remote_addr) .try_connect_with_socket(local_socket, remote_addr)
.await?; .await?;
@@ -233,8 +230,8 @@ impl DirectConnectorManagerData {
dst_peer_id: PeerId, dst_peer_id: PeerId,
addr: String, addr: String,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut rand_gen = rand::rngs::OsRng::default(); let mut rand_gen = rand::rngs::OsRng;
let backoff_ms = vec![1000, 2000, 4000]; let backoff_ms = [1000, 2000, 4000];
let mut backoff_idx = 0; let mut backoff_idx = 0;
tracing::debug!(?dst_peer_id, ?addr, "try_connect_to_ip start"); tracing::debug!(?dst_peer_id, ?addr, "try_connect_to_ip start");
@@ -243,10 +240,7 @@ impl DirectConnectorManagerData {
if self if self
.dst_listener_blacklist .dst_listener_blacklist
.contains(&DstListenerUrlBlackListItem( .contains(&DstListenerUrlBlackListItem(dst_peer_id, addr.clone()))
dst_peer_id.clone(),
addr.clone(),
))
{ {
return Err(Error::UrlInBlacklist); return Err(Error::UrlInBlacklist);
} }
@@ -281,7 +275,7 @@ impl DirectConnectorManagerData {
continue; continue;
} else { } else {
self.dst_listener_blacklist.insert( self.dst_listener_blacklist.insert(
DstListenerUrlBlackListItem(dst_peer_id.clone(), addr), DstListenerUrlBlackListItem(dst_peer_id, addr),
(), (),
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC), std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
); );
@@ -315,7 +309,7 @@ impl DirectConnectorManagerData {
if addr.set_host(Some(ip.to_string().as_str())).is_ok() { if addr.set_host(Some(ip.to_string().as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip( tasks.spawn(Self::try_connect_to_ip(
self.clone(), self.clone(),
dst_peer_id.clone(), dst_peer_id,
addr.to_string(), addr.to_string(),
)); ));
} else { } else {
@@ -330,7 +324,7 @@ impl DirectConnectorManagerData {
} else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) { } else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) {
tasks.spawn(Self::try_connect_to_ip( tasks.spawn(Self::try_connect_to_ip(
self.clone(), self.clone(),
dst_peer_id.clone(), dst_peer_id,
listener.to_string(), listener.to_string(),
)); ));
} }
@@ -355,13 +349,10 @@ impl DirectConnectorManagerData {
.iter() .iter()
.for_each(|ip| { .for_each(|ip| {
let mut addr = (*listener).clone(); let mut addr = (*listener).clone();
if addr if addr.set_host(Some(format!("[{}]", ip).as_str())).is_ok() {
.set_host(Some(format!("[{}]", ip.to_string()).as_str()))
.is_ok()
{
tasks.spawn(Self::try_connect_to_ip( tasks.spawn(Self::try_connect_to_ip(
self.clone(), self.clone(),
dst_peer_id.clone(), dst_peer_id,
addr.to_string(), addr.to_string(),
)); ));
} else { } else {
@@ -376,7 +367,7 @@ impl DirectConnectorManagerData {
} else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) { } else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) {
tasks.spawn(Self::try_connect_to_ip( tasks.spawn(Self::try_connect_to_ip(
self.clone(), self.clone(),
dst_peer_id.clone(), dst_peer_id,
listener.to_string(), listener.to_string(),
)); ));
} }
@@ -436,13 +427,8 @@ impl DirectConnectorManagerData {
} }
tracing::debug!("try direct connect to peer with listener: {}", listener); tracing::debug!("try direct connect to peer with listener: {}", listener);
self.spawn_direct_connect_task( self.spawn_direct_connect_task(dst_peer_id, &ip_list, listener, &mut tasks)
dst_peer_id.clone(), .await;
&ip_list,
&listener,
&mut tasks,
)
.await;
listener_list.push(listener.clone().to_string()); listener_list.push(listener.clone().to_string());
available_listeners.pop(); available_listeners.pop();

View File

@@ -124,11 +124,11 @@ impl DNSTunnelConnector {
let responses = responses.clone(); let responses = responses.clone();
async move { async move {
let response = resolver.srv_lookup(srv_domain).await.with_context(|| { let response = resolver.srv_lookup(srv_domain).await.with_context(|| {
format!("srv_lookup failed, srv_domain: {}", srv_domain.to_string()) format!("srv_lookup failed, srv_domain: {}", srv_domain)
})?; })?;
tracing::info!(?response, ?srv_domain, "srv_lookup response"); tracing::info!(?response, ?srv_domain, "srv_lookup response");
for record in response.iter() { for record in response.iter() {
let parsed_record = Self::handle_one_srv_record(record, &protocol); let parsed_record = Self::handle_one_srv_record(record, protocol);
tracing::info!(?parsed_record, ?srv_domain, "parsed_record"); tracing::info!(?parsed_record, ?srv_domain, "parsed_record");
if parsed_record.is_err() { if parsed_record.is_err() {
eprintln!( eprintln!(
@@ -153,8 +153,7 @@ impl DNSTunnelConnector {
let url = weighted_choice(srv_records.as_slice()).with_context(|| { let url = weighted_choice(srv_records.as_slice()).with_context(|| {
format!( format!(
"failed to choose a srv record, domain_name: {}, srv_records: {:?}", "failed to choose a srv record, domain_name: {}, srv_records: {:?}",
domain_name.to_string(), domain_name, srv_records
srv_records
) )
})?; })?;

View File

@@ -93,7 +93,7 @@ impl HttpTunnelConnector {
tracing::info!("try to create connector by url: {}", query[0]); tracing::info!("try to create connector by url: {}", query[0]);
self.redirect_type = HttpRedirectType::RedirectToQuery; self.redirect_type = HttpRedirectType::RedirectToQuery;
return create_connector_by_url( return create_connector_by_url(
&query[0].to_string(), query[0].as_ref(),
&self.global_ctx, &self.global_ctx,
self.ip_version, self.ip_version,
) )
@@ -193,7 +193,7 @@ impl HttpTunnelConnector {
.ok_or_else(|| Error::InvalidUrl("no redirect address found".to_string()))?; .ok_or_else(|| Error::InvalidUrl("no redirect address found".to_string()))?;
let new_url = url::Url::parse(redirect_url.as_str()) let new_url = url::Url::parse(redirect_url.as_str())
.with_context(|| format!("parsing redirect url failed. url: {}", redirect_url))?; .with_context(|| format!("parsing redirect url failed. url: {}", redirect_url))?;
return self.handle_302_redirect(new_url, &redirect_url).await; return self.handle_302_redirect(new_url, redirect_url).await;
} else if res.status_code().is_success() { } else if res.status_code().is_success() {
return self.handle_200_success(&body).await; return self.handle_200_success(&body).await;
} else { } else {

View File

@@ -131,7 +131,7 @@ impl ManualConnectorManager {
.data .data
.connectors .connectors
.iter() .iter()
.map(|x| x.key().clone().into()) .map(|x| x.key().clone())
.collect(); .collect();
let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone()) let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
@@ -155,12 +155,8 @@ impl ManualConnectorManager {
); );
} }
let reconnecting_urls: BTreeSet<String> = self let reconnecting_urls: BTreeSet<String> =
.data self.data.reconnecting.iter().map(|x| x.clone()).collect();
.reconnecting
.iter()
.map(|x| x.clone().into())
.collect();
for conn_url in reconnecting_urls { for conn_url in reconnecting_urls {
ret.insert( ret.insert(
@@ -282,7 +278,7 @@ impl ManualConnectorManager {
let remove_later = DashSet::new(); let remove_later = DashSet::new();
for it in data.removed_conn_urls.iter() { for it in data.removed_conn_urls.iter() {
let url = it.key(); let url = it.key();
if let Some(_) = data.connectors.remove(url) { if data.connectors.remove(url).is_some() {
tracing::warn!("connector: {}, removed", url); tracing::warn!("connector: {}, removed", url);
continue; continue;
} else if data.reconnecting.contains(url) { } else if data.reconnecting.contains(url) {
@@ -301,11 +297,7 @@ impl ManualConnectorManager {
async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> { async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
Self::handle_remove_connector(data.clone()); Self::handle_remove_connector(data.clone());
let all_urls: BTreeSet<String> = data let all_urls: BTreeSet<String> = data.connectors.iter().map(|x| x.key().clone()).collect();
.connectors
.iter()
.map(|x| x.key().clone().into())
.collect();
let mut ret = BTreeSet::new(); let mut ret = BTreeSet::new();
for url in all_urls.iter() { for url in all_urls.iter() {
if !data.alive_conn_urls.contains(url) { if !data.alive_conn_urls.contains(url) {
@@ -400,21 +392,28 @@ impl ManualConnectorManager {
.await; .await;
tracing::info!("reconnect: {} done, ret: {:?}", dead_url, ret); tracing::info!("reconnect: {} done, ret: {:?}", dead_url, ret);
if ret.is_ok() && ret.as_ref().unwrap().is_ok() { match ret {
reconn_ret = ret.unwrap(); Ok(Ok(_)) => {
break; // 外层和内层都成功:解包并跳出
} else { reconn_ret = ret.unwrap();
if ret.is_err() { break;
reconn_ret = Err(ret.unwrap_err().into()); }
} else if ret.as_ref().unwrap().is_err() { Ok(Err(e)) => {
reconn_ret = Err(ret.unwrap().unwrap_err()); // 外层成功,内层失败
reconn_ret = Err(e);
}
Err(e) => {
// 外层失败
reconn_ret = Err(e.into());
} }
data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
dead_url.clone(),
format!("{:?}", ip_version),
format!("{:?}", reconn_ret),
));
} }
// 发送事件(只有在未 break 时才执行)
data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
dead_url.clone(),
format!("{:?}", ip_version),
format!("{:?}", reconn_ret),
));
} }
reconn_ret reconn_ret

View File

@@ -389,7 +389,7 @@ pub mod tests {
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40164").await.unwrap()); let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40164").await.unwrap());
// 144 - DST_PORT_OFFSET = 124 // 144 - DST_PORT_OFFSET = 124
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40124").await.unwrap()); let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40124").await.unwrap());
let udps = vec![udp1, udp2]; let udps = [udp1, udp2];
let counter = Arc::new(AtomicU32::new(0)); let counter = Arc::new(AtomicU32::new(0));

View File

@@ -67,9 +67,9 @@ impl From<NatType> for UdpNatType {
} }
} }
impl Into<NatType> for UdpNatType { impl From<UdpNatType> for NatType {
fn into(self) -> NatType { fn from(val: UdpNatType) -> Self {
match self { match val {
UdpNatType::Unknown => NatType::Unknown, UdpNatType::Unknown => NatType::Unknown,
UdpNatType::Open(nat_type) => nat_type, UdpNatType::Open(nat_type) => nat_type,
UdpNatType::Cone(nat_type) => nat_type, UdpNatType::Cone(nat_type) => nat_type,
@@ -249,7 +249,7 @@ impl UdpSocketArray {
tracing::info!(?addr, ?tid, "got hole punching packet with intreast tid"); tracing::info!(?addr, ?tid, "got hole punching packet with intreast tid");
tid_to_socket tid_to_socket
.entry(tid) .entry(tid)
.or_insert_with(Vec::new) .or_default()
.push(PunchedUdpSocket { .push(PunchedUdpSocket {
socket: socket.clone(), socket: socket.clone(),
tid, tid,
@@ -556,7 +556,7 @@ impl PunchHoleServerCommon {
#[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))] #[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))]
pub(crate) async fn send_symmetric_hole_punch_packet( pub(crate) async fn send_symmetric_hole_punch_packet(
ports: &Vec<u16>, ports: &[u16],
udp: Arc<UdpSocket>, udp: Arc<UdpSocket>,
transaction_id: u32, transaction_id: u32,
public_ips: &Vec<Ipv4Addr>, public_ips: &Vec<Ipv4Addr>,
@@ -628,5 +628,5 @@ pub(crate) async fn try_connect_with_socket(
connector connector
.try_connect_with_socket(socket, remote_mapped_addr) .try_connect_with_socket(socket, remote_mapped_addr)
.await .await
.map_err(|e| Error::from(e)) .map_err(Error::from)
} }

View File

@@ -172,7 +172,7 @@ impl PunchConeHoleClient {
udp_array udp_array
.send_with_all( .send_with_all(
&new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(), &new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(),
remote_mapped_addr.clone().into(), remote_mapped_addr.into(),
) )
.await .await
.with_context(|| "failed to send hole punch packet from local") .with_context(|| "failed to send hole punch packet from local")
@@ -188,7 +188,7 @@ impl PunchConeHoleClient {
..Default::default() ..Default::default()
}, },
SendPunchPacketConeRequest { SendPunchPacketConeRequest {
listener_mapped_addr: Some(remote_mapped_addr.into()), listener_mapped_addr: Some(remote_mapped_addr),
dest_addr: Some(local_mapped_addr.into()), dest_addr: Some(local_mapped_addr.into()),
transaction_id: tid, transaction_id: tid,
packet_count_per_batch: 2, packet_count_per_batch: 2,

View File

@@ -39,7 +39,7 @@ pub(crate) mod cone;
pub(crate) mod sym_to_cone; pub(crate) mod sym_to_cone;
// sym punch should be serialized // sym punch should be serialized
static SYM_PUNCH_LOCK: Lazy<DashMap<PeerId, Arc<Mutex<()>>>> = Lazy::new(|| DashMap::new()); static SYM_PUNCH_LOCK: Lazy<DashMap<PeerId, Arc<Mutex<()>>>> = Lazy::new(DashMap::new);
pub static RUN_TESTING: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(false)); pub static RUN_TESTING: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(false));
// Blacklist timeout in seconds // Blacklist timeout in seconds
@@ -223,7 +223,7 @@ impl UdpHoePunchConnectorData {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
async fn handle_punch_result( async fn handle_punch_result(
self: &Self, &self,
ret: Result<Option<Box<dyn Tunnel>>, Error>, ret: Result<Option<Box<dyn Tunnel>>, Error>,
backoff: Option<&mut BackOff>, backoff: Option<&mut BackOff>,
round: Option<&mut u32>, round: Option<&mut u32>,
@@ -236,10 +236,8 @@ impl UdpHoePunchConnectorData {
if let Some(round) = round { if let Some(round) = round {
*round = round.saturating_sub(1); *round = round.saturating_sub(1);
} }
} else { } else if let Some(round) = round {
if let Some(round) = round { *round += 1;
*round += 1;
}
} }
}; };
@@ -464,7 +462,7 @@ impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher {
} }
let conns = data.peer_mgr.list_peer_conns(peer_id).await; let conns = data.peer_mgr.list_peer_conns(peer_id).await;
if conns.is_some() && conns.unwrap().len() > 0 { if conns.is_some() && !conns.unwrap().is_empty() {
continue; continue;
} }

View File

@@ -80,9 +80,9 @@ impl PunchSymToConeHoleServer {
let public_ips = request let public_ips = request
.public_ips .public_ips
.into_iter() .into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip)) .map(std::net::Ipv4Addr::from)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if public_ips.len() == 0 { if public_ips.is_empty() {
tracing::warn!("send_punch_packet_easy_sym got zero len public ip"); tracing::warn!("send_punch_packet_easy_sym got zero len public ip");
return Err( return Err(
anyhow::anyhow!("send_punch_packet_easy_sym got zero len public ip").into(), anyhow::anyhow!("send_punch_packet_easy_sym got zero len public ip").into(),
@@ -158,9 +158,9 @@ impl PunchSymToConeHoleServer {
let public_ips = request let public_ips = request
.public_ips .public_ips
.into_iter() .into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip)) .map(std::net::Ipv4Addr::from)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if public_ips.len() == 0 { if public_ips.is_empty() {
tracing::warn!("try_punch_symmetric got zero len public ip"); tracing::warn!("try_punch_symmetric got zero len public ip");
return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into()); return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into());
} }
@@ -281,7 +281,7 @@ impl PunchSymToConeHoleClient {
return; return;
}; };
let req = SendPunchPacketEasySymRequest { let req = SendPunchPacketEasySymRequest {
listener_mapped_addr: remote_mapped_addr.clone().into(), listener_mapped_addr: remote_mapped_addr.into(),
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(), public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
transaction_id: tid, transaction_id: tid,
base_port_num: base_port_for_easy_sym.unwrap() as u32, base_port_num: base_port_for_easy_sym.unwrap() as u32,
@@ -313,7 +313,7 @@ impl PunchSymToConeHoleClient {
port_index: u32, port_index: u32,
) -> Option<u32> { ) -> Option<u32> {
let req = SendPunchPacketHardSymRequest { let req = SendPunchPacketHardSymRequest {
listener_mapped_addr: remote_mapped_addr.clone().into(), listener_mapped_addr: remote_mapped_addr.into(),
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(), public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
transaction_id: tid, transaction_id: tid,
round, round,
@@ -333,9 +333,9 @@ impl PunchSymToConeHoleClient {
{ {
Err(e) => { Err(e) => {
tracing::error!(?e, "failed to send punch packet for hard sym"); tracing::error!(?e, "failed to send punch packet for hard sym");
return None; None
} }
Ok(resp) => return Some(resp.next_port_index), Ok(resp) => Some(resp.next_port_index),
} }
} }
@@ -366,7 +366,7 @@ impl PunchSymToConeHoleClient {
let mut finish_time: Option<Instant> = None; let mut finish_time: Option<Instant> = None;
while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 { while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 {
udp_array udp_array
.send_with_all(&packet, remote_mapped_addr.into()) .send_with_all(packet, remote_mapped_addr.into())
.await?; .await?;
tokio::time::sleep(Duration::from_millis(200)).await; tokio::time::sleep(Duration::from_millis(200)).await;
@@ -484,7 +484,7 @@ impl PunchSymToConeHoleClient {
rpc_stub, rpc_stub,
base_port_for_easy_sym, base_port_for_easy_sym,
my_nat_info, my_nat_info,
remote_mapped_addr.clone(), remote_mapped_addr,
public_ips.clone(), public_ips.clone(),
tid, tid,
)) ))
@@ -494,7 +494,7 @@ impl PunchSymToConeHoleClient {
&udp_array, &udp_array,
&packet, &packet,
tid, tid,
remote_mapped_addr.clone(), remote_mapped_addr,
&scoped_punch_task, &scoped_punch_task,
) )
.await?; .await?;
@@ -510,7 +510,7 @@ impl PunchSymToConeHoleClient {
let scoped_punch_task: ScopedTask<Option<u32>> = let scoped_punch_task: ScopedTask<Option<u32>> =
tokio::spawn(Self::remote_send_hole_punch_packet_random( tokio::spawn(Self::remote_send_hole_punch_packet_random(
rpc_stub, rpc_stub,
remote_mapped_addr.clone(), remote_mapped_addr,
public_ips.clone(), public_ips.clone(),
tid, tid,
round, round,
@@ -522,7 +522,7 @@ impl PunchSymToConeHoleClient {
&udp_array, &udp_array,
&packet, &packet,
tid, tid,
remote_mapped_addr.clone(), remote_mapped_addr,
&scoped_punch_task, &scoped_punch_task,
) )
.await?; .await?;

View File

@@ -4,7 +4,6 @@ use std::{
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
path::PathBuf, path::PathBuf,
str::FromStr, str::FromStr,
sync::Mutex,
time::Duration, time::Duration,
vec, vec,
}; };
@@ -30,15 +29,16 @@ use easytier::{
cli::{ cli::{
list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, AddPortForwardRequest, list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, AddPortForwardRequest,
ConnectorManageRpc, ConnectorManageRpcClientFactory, DumpRouteRequest, ConnectorManageRpc, ConnectorManageRpcClientFactory, DumpRouteRequest,
GetAclStatsRequest, GetPrometheusStatsRequest, GetStatsRequest, GetVpnPortalInfoRequest, GetWhitelistRequest, ListConnectorRequest, GetAclStatsRequest, GetPrometheusStatsRequest, GetStatsRequest,
GetVpnPortalInfoRequest, GetWhitelistRequest, ListConnectorRequest,
ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, ListMappedListenerRequest, ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, ListMappedListenerRequest,
ListPeerRequest, ListPeerResponse, ListPortForwardRequest, ListRouteRequest, ListPeerRequest, ListPeerResponse, ListPortForwardRequest, ListRouteRequest,
ListRouteResponse, ManageMappedListenerRequest, MappedListenerManageAction, ListRouteResponse, ManageMappedListenerRequest, MappedListenerManageAction,
MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc, MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc,
PeerManageRpcClientFactory, PortForwardManageRpc, PortForwardManageRpcClientFactory, PeerManageRpcClientFactory, PortForwardManageRpc, PortForwardManageRpcClientFactory,
RemovePortForwardRequest, SetWhitelistRequest, ShowNodeInfoRequest, StatsRpc, StatsRpcClientFactory, TcpProxyEntryState, RemovePortForwardRequest, SetWhitelistRequest, ShowNodeInfoRequest, StatsRpc,
TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc, StatsRpcClientFactory, TcpProxyEntryState, TcpProxyEntryTransportType, TcpProxyRpc,
VpnPortalRpcClientFactory, TcpProxyRpcClientFactory, VpnPortalRpc, VpnPortalRpcClientFactory,
}, },
common::{NatType, SocketType}, common::{NatType, SocketType},
peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory}, peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory},
@@ -325,7 +325,7 @@ struct InstallArgs {
type Error = anyhow::Error; type Error = anyhow::Error;
struct CommandHandler<'a> { struct CommandHandler<'a> {
client: Mutex<RpcClient>, client: tokio::sync::Mutex<RpcClient>,
verbose: bool, verbose: bool,
output_format: &'a OutputFormat, output_format: &'a OutputFormat,
} }
@@ -339,7 +339,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<PeerManageRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<PeerManageRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get peer manager client")?) .with_context(|| "failed to get peer manager client")?)
@@ -351,7 +351,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<ConnectorManageRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<ConnectorManageRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get connector manager client")?) .with_context(|| "failed to get connector manager client")?)
@@ -363,7 +363,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<MappedListenerManageRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<MappedListenerManageRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get mapped listener manager client")?) .with_context(|| "failed to get mapped listener manager client")?)
@@ -375,7 +375,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<PeerCenterRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<PeerCenterRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get peer center client")?) .with_context(|| "failed to get peer center client")?)
@@ -387,7 +387,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<VpnPortalRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<VpnPortalRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get vpn portal client")?) .with_context(|| "failed to get vpn portal client")?)
@@ -399,7 +399,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<AclManageRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<AclManageRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get acl manager client")?) .with_context(|| "failed to get acl manager client")?)
@@ -412,7 +412,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<TcpProxyRpcClientFactory<BaseController>>(transport_type.to_string()) .scoped_client::<TcpProxyRpcClientFactory<BaseController>>(transport_type.to_string())
.await .await
.with_context(|| "failed to get vpn portal client")?) .with_context(|| "failed to get vpn portal client")?)
@@ -424,7 +424,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<PortForwardManageRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<PortForwardManageRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get port forward manager client")?) .with_context(|| "failed to get port forward manager client")?)
@@ -436,7 +436,7 @@ impl CommandHandler<'_> {
Ok(self Ok(self
.client .client
.lock() .lock()
.unwrap() .await
.scoped_client::<StatsRpcClientFactory<BaseController>>("".to_string()) .scoped_client::<StatsRpcClientFactory<BaseController>>("".to_string())
.await .await
.with_context(|| "failed to get stats client")?) .with_context(|| "failed to get stats client")?)
@@ -865,7 +865,7 @@ impl CommandHandler<'_> {
Ok(()) Ok(())
} }
async fn handle_mapped_listener_add(&self, url: &String) -> Result<(), Error> { async fn handle_mapped_listener_add(&self, url: &str) -> Result<(), Error> {
let url = Self::mapped_listener_validate_url(url)?; let url = Self::mapped_listener_validate_url(url)?;
let client = self.get_mapped_listener_manager_client().await?; let client = self.get_mapped_listener_manager_client().await?;
let request = ManageMappedListenerRequest { let request = ManageMappedListenerRequest {
@@ -878,7 +878,7 @@ impl CommandHandler<'_> {
Ok(()) Ok(())
} }
async fn handle_mapped_listener_remove(&self, url: &String) -> Result<(), Error> { async fn handle_mapped_listener_remove(&self, url: &str) -> Result<(), Error> {
let url = Self::mapped_listener_validate_url(url)?; let url = Self::mapped_listener_validate_url(url)?;
let client = self.get_mapped_listener_manager_client().await?; let client = self.get_mapped_listener_manager_client().await?;
let request = ManageMappedListenerRequest { let request = ManageMappedListenerRequest {
@@ -891,7 +891,7 @@ impl CommandHandler<'_> {
Ok(()) Ok(())
} }
fn mapped_listener_validate_url(url: &String) -> Result<url::Url, Error> { fn mapped_listener_validate_url(url: &str) -> Result<url::Url, Error> {
let url = url::Url::parse(url)?; let url = url::Url::parse(url)?;
if url.scheme() != "tcp" && url.scheme() != "udp" { if url.scheme() != "tcp" && url.scheme() != "udp" {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
@@ -925,8 +925,8 @@ impl CommandHandler<'_> {
cfg: Some( cfg: Some(
PortForwardConfig { PortForwardConfig {
proto: protocol.to_string(), proto: protocol.to_string(),
bind_addr: bind_addr.into(), bind_addr,
dst_addr: dst_addr.into(), dst_addr,
} }
.into(), .into(),
), ),
@@ -961,11 +961,10 @@ impl CommandHandler<'_> {
cfg: Some( cfg: Some(
PortForwardConfig { PortForwardConfig {
proto: protocol.to_string(), proto: protocol.to_string(),
bind_addr: bind_addr.into(), bind_addr,
dst_addr: dst_addr dst_addr: dst_addr
.map(|s| s.parse::<SocketAddr>().unwrap()) .map(|s| s.parse::<SocketAddr>().unwrap())
.map(Into::into) .unwrap_or("0.0.0.0:0".parse::<SocketAddr>().unwrap()),
.unwrap_or("0.0.0.0:0".parse::<SocketAddr>().unwrap().into()),
} }
.into(), .into(),
), ),
@@ -1458,7 +1457,7 @@ async fn main() -> Result<(), Error> {
.unwrap(), .unwrap(),
)); ));
let handler = CommandHandler { let handler = CommandHandler {
client: Mutex::new(client), client: tokio::sync::Mutex::new(client),
verbose: cli.verbose, verbose: cli.verbose,
output_format: &cli.output_format, output_format: &cli.output_format,
}; };
@@ -1716,16 +1715,10 @@ async fn main() -> Result<(), Error> {
format!("{:?}", stun_info.udp_nat_type()).as_str(), format!("{:?}", stun_info.udp_nat_type()).as_str(),
]); ]);
ip_list.interface_ipv4s.iter().for_each(|ip| { ip_list.interface_ipv4s.iter().for_each(|ip| {
builder.push_record(vec![ builder.push_record(vec!["Interface IPv4", ip.to_string().as_str()]);
"Interface IPv4",
format!("{}", ip.to_string()).as_str(),
]);
}); });
ip_list.interface_ipv6s.iter().for_each(|ip| { ip_list.interface_ipv6s.iter().for_each(|ip| {
builder.push_record(vec![ builder.push_record(vec!["Interface IPv6", ip.to_string().as_str()]);
"Interface IPv6",
format!("{}", ip.to_string()).as_str(),
]);
}); });
for (idx, l) in node_info.listeners.iter().enumerate() { for (idx, l) in node_info.listeners.iter().enumerate() {
if l.starts_with("ring") { if l.starts_with("ring") {
@@ -1911,9 +1904,7 @@ async fn main() -> Result<(), Error> {
Some(StatsSubCommand::Show) | None => { Some(StatsSubCommand::Show) | None => {
let client = handler.get_stats_client().await?; let client = handler.get_stats_client().await?;
let request = GetStatsRequest {}; let request = GetStatsRequest {};
let response = client let response = client.get_stats(BaseController::default(), request).await?;
.get_stats(BaseController::default(), request)
.await?;
if cli.output_format == OutputFormat::Json { if cli.output_format == OutputFormat::Json {
println!("{}", serde_json::to_string_pretty(&response.metrics)?); println!("{}", serde_json::to_string_pretty(&response.metrics)?);
@@ -1942,7 +1933,7 @@ async fn main() -> Result<(), Error> {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", ") .join(", ")
}; };
let formatted_value = if metric.name.contains("bytes") { let formatted_value = if metric.name.contains("bytes") {
format_size(metric.value, humansize::BINARY) format_size(metric.value, humansize::BINARY)
} else if metric.name.contains("duration") { } else if metric.name.contains("duration") {

View File

@@ -61,7 +61,7 @@ pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:19\0
fn set_prof_active(_active: bool) { fn set_prof_active(_active: bool) {
#[cfg(feature = "jemalloc-prof")] #[cfg(feature = "jemalloc-prof")]
{ {
const PROF_ACTIVE: &'static [u8] = b"prof.active\0"; const PROF_ACTIVE: &[u8] = b"prof.active\0";
let name = PROF_ACTIVE.name(); let name = PROF_ACTIVE.name();
name.write(_active).expect("Should succeed to set prof"); name.write(_active).expect("Should succeed to set prof");
} }
@@ -70,7 +70,7 @@ fn set_prof_active(_active: bool) {
fn dump_profile(_cur_allocated: usize) { fn dump_profile(_cur_allocated: usize) {
#[cfg(feature = "jemalloc-prof")] #[cfg(feature = "jemalloc-prof")]
{ {
const PROF_DUMP: &'static [u8] = b"prof.dump\0"; const PROF_DUMP: &[u8] = b"prof.dump\0";
static mut PROF_DUMP_FILE_NAME: [u8; 128] = [0; 128]; static mut PROF_DUMP_FILE_NAME: [u8; 128] = [0; 128];
let file_name_str = format!( let file_name_str = format!(
"profile-{}-{}.out", "profile-{}-{}.out",
@@ -701,7 +701,7 @@ impl NetworkOptions {
.map(|s| s.parse().unwrap()) .map(|s| s.parse().unwrap())
.collect(), .collect(),
); );
} else if cfg.get_listeners() == None { } else if cfg.get_listeners().is_none() {
cfg.set_listeners( cfg.set_listeners(
Cli::parse_listeners(false, vec!["11010".to_string()])? Cli::parse_listeners(false, vec!["11010".to_string()])?
.into_iter() .into_iter()
@@ -740,7 +740,7 @@ impl NetworkOptions {
} }
for n in self.proxy_networks.iter() { for n in self.proxy_networks.iter() {
add_proxy_network_to_config(n, &cfg)?; add_proxy_network_to_config(n, cfg)?;
} }
let rpc_portal = if let Some(r) = &self.rpc_portal { let rpc_portal = if let Some(r) = &self.rpc_portal {
@@ -754,9 +754,9 @@ impl NetworkOptions {
cfg.set_rpc_portal(rpc_portal); cfg.set_rpc_portal(rpc_portal);
if let Some(rpc_portal_whitelist) = &self.rpc_portal_whitelist { if let Some(rpc_portal_whitelist) = &self.rpc_portal_whitelist {
let mut whitelist = cfg.get_rpc_portal_whitelist().unwrap_or_else(|| Vec::new()); let mut whitelist = cfg.get_rpc_portal_whitelist().unwrap_or_default();
for cidr in rpc_portal_whitelist { for cidr in rpc_portal_whitelist {
whitelist.push((*cidr).clone()); whitelist.push(*cidr);
} }
cfg.set_rpc_portal_whitelist(Some(whitelist)); cfg.set_rpc_portal_whitelist(Some(whitelist));
} }
@@ -825,18 +825,18 @@ impl NetworkOptions {
port_forward.port().expect("local bind port is missing") port_forward.port().expect("local bind port is missing")
) )
.parse() .parse()
.expect(format!("failed to parse local bind addr {}", example_str).as_str()); .unwrap_or_else(|_| panic!("failed to parse local bind addr {}", example_str));
let dst_addr = format!( let dst_addr = port_forward
"{}", .path_segments()
port_forward .unwrap_or_else(|| panic!("remote destination addr is missing {}", example_str))
.path_segments() .next()
.expect(format!("remote destination addr is missing {}", example_str).as_str()) .unwrap_or_else(|| panic!("remote destination addr is missing {}", example_str))
.next() .to_string()
.expect(format!("remote destination addr is missing {}", example_str).as_str()) .parse()
) .unwrap_or_else(|_| {
.parse() panic!("failed to parse remote destination addr {}", example_str)
.expect(format!("failed to parse remote destination addr {}", example_str).as_str()); });
let port_forward_item = PortForwardConfig { let port_forward_item = PortForwardConfig {
bind_addr, bind_addr,
@@ -1141,7 +1141,7 @@ async fn run_main(cli: Cli) -> anyhow::Result<()> {
let mut cfg = TomlConfigLoader::default(); let mut cfg = TomlConfigLoader::default();
cli.network_options cli.network_options
.merge_into(&mut cfg) .merge_into(&mut cfg)
.with_context(|| format!("failed to create config from cli"))?; .with_context(|| "failed to create config from cli".to_string())?;
println!("Starting easytier from cli with config:"); println!("Starting easytier from cli with config:");
println!("############### TOML ###############\n"); println!("############### TOML ###############\n");
println!("{}", cfg.dump()); println!("{}", cfg.dump());
@@ -1156,7 +1156,7 @@ async fn run_main(cli: Cli) -> anyhow::Result<()> {
.into_values() .into_values()
.filter_map(|info| info.error_msg) .filter_map(|info| info.error_msg)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if errs.len() > 0 { if !errs.is_empty() {
return Err(anyhow::anyhow!("some instances stopped with errors")); return Err(anyhow::anyhow!("some instances stopped with errors"));
} }
} }

View File

@@ -294,7 +294,7 @@ pub fn new_udp_header<T: ToTargetAddr>(target_addr: T) -> Result<Vec<u8>> {
} }
/// Parse data from UDP client on raw buffer, return (frag, target_addr, payload). /// Parse data from UDP client on raw buffer, return (frag, target_addr, payload).
pub async fn parse_udp_request<'a>(mut req: &'a [u8]) -> Result<(u8, TargetAddr, &'a [u8])> { pub async fn parse_udp_request(mut req: &[u8]) -> Result<(u8, TargetAddr, &[u8])> {
let rsv = read_exact!(req, [0u8; 2]).context("Malformed request")?; let rsv = read_exact!(req, [0u8; 2]).context("Malformed request")?;
if !rsv.eq(&[0u8; 2]) { if !rsv.eq(&[0u8; 2]) {

View File

@@ -455,16 +455,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin, A: Authentication, C: AsyncTcpConnector>
info!("User logged successfully."); info!("User logged successfully.");
return Ok(credentials); Ok(credentials)
} else { } else {
self.inner self.inner
.write_all(&[1, consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE]) .write_all(&[1, consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE])
.await .await
.context("Can't reply with auth method not acceptable.")?; .context("Can't reply with auth method not acceptable.")?;
return Err(SocksError::AuthenticationRejected(format!( Err(SocksError::AuthenticationRejected(
"Authentication, rejected." "Authentication, rejected.".to_string(),
))); ))
} }
} }

View File

@@ -72,10 +72,7 @@ impl TargetAddr {
} }
pub fn is_ip(&self) -> bool { pub fn is_ip(&self) -> bool {
match self { matches!(self, TargetAddr::Ip(_))
TargetAddr::Ip(_) => true,
_ => false,
}
} }
pub fn is_domain(&self) -> bool { pub fn is_domain(&self) -> bool {
@@ -104,7 +101,7 @@ impl TargetAddr {
} }
TargetAddr::Domain(ref domain, port) => { TargetAddr::Domain(ref domain, port) => {
debug!("TargetAddr::Domain"); debug!("TargetAddr::Domain");
if domain.len() > u8::max_value() as usize { if domain.len() > u8::MAX as usize {
return Err(SocksError::ExceededMaxDomainLen(domain.len()).into()); return Err(SocksError::ExceededMaxDomainLen(domain.len()).into());
} }
buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME, domain.len() as u8]); buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME, domain.len() as u8]);
@@ -125,8 +122,7 @@ impl std::net::ToSocketAddrs for TargetAddr {
fn to_socket_addrs(&self) -> io::Result<IntoIter<SocketAddr>> { fn to_socket_addrs(&self) -> io::Result<IntoIter<SocketAddr>> {
match *self { match *self {
TargetAddr::Ip(addr) => Ok(vec![addr].into_iter()), TargetAddr::Ip(addr) => Ok(vec![addr].into_iter()),
TargetAddr::Domain(_, _) => Err(io::Error::new( TargetAddr::Domain(_, _) => Err(io::Error::other(
io::ErrorKind::Other,
"Domain name has to be explicitly resolved, please use TargetAddr::resolve_dns().", "Domain name has to be explicitly resolved, please use TargetAddr::resolve_dns().",
)), )),
} }
@@ -149,7 +145,7 @@ pub trait ToTargetAddr {
fn to_target_addr(&self) -> io::Result<TargetAddr>; fn to_target_addr(&self) -> io::Result<TargetAddr>;
} }
impl<'a> ToTargetAddr for (&'a str, u16) { impl ToTargetAddr for (&str, u16) {
fn to_target_addr(&self) -> io::Result<TargetAddr> { fn to_target_addr(&self) -> io::Result<TargetAddr> {
// try to parse as an IP first // try to parse as an IP first
if let Ok(addr) = self.0.parse::<Ipv4Addr>() { if let Ok(addr) = self.0.parse::<Ipv4Addr>() {

View File

@@ -23,6 +23,7 @@ use tracing::Instrument;
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
gateway::ip_reassembler::ComposeIpv4PacketArgs,
peers::{peer_manager::PeerManager, PeerPacketFilter}, peers::{peer_manager::PeerManager, PeerPacketFilter},
tunnel::packet_def::{PacketType, ZCPacket}, tunnel::packet_def::{PacketType, ZCPacket},
}; };
@@ -118,7 +119,7 @@ fn socket_recv_loop(
} }
}; };
if len <= 0 { if len == 0 {
tracing::error!("recv empty packet, len: {}", len); tracing::error!("recv empty packet, len: {}", len);
return; return;
} }
@@ -158,20 +159,18 @@ fn socket_recv_loop(
let payload_len = len - ipv4_packet.get_header_length() as usize * 4; let payload_len = len - ipv4_packet.get_header_length() as usize * 4;
let id = ipv4_packet.get_identification(); let id = ipv4_packet.get_identification();
let _ = compose_ipv4_packet( let _ = compose_ipv4_packet(
&mut buf[..], ComposeIpv4PacketArgs {
&v.mapped_dst_ip, buf: &mut buf[..],
&dest_ip, src_v4: &v.mapped_dst_ip,
IpNextHeaderProtocols::Icmp, dst_v4: &dest_ip,
payload_len, next_protocol: IpNextHeaderProtocols::Icmp,
1200, payload_len,
id, payload_mtu: 1200,
ip_id: id,
},
|buf| { |buf| {
let mut p = ZCPacket::new_with_payload(buf); let mut p = ZCPacket::new_with_payload(buf);
p.fill_peer_manager_hdr( p.fill_peer_manager_hdr(v.my_peer_id, v.src_peer_id, PacketType::Data as u8);
v.my_peer_id.into(),
v.src_peer_id.into(),
PacketType::Data as u8,
);
p.mut_peer_manager_header().unwrap().set_no_proxy(true); p.mut_peer_manager_header().unwrap().set_no_proxy(true);
if let Err(e) = sender.send(p) { if let Err(e) = sender.send(p) {
@@ -186,7 +185,7 @@ fn socket_recv_loop(
#[async_trait::async_trait] #[async_trait::async_trait]
impl PeerPacketFilter for IcmpProxy { impl PeerPacketFilter for IcmpProxy {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
if let Some(_) = self.try_handle_peer_packet(&packet).await { if self.try_handle_peer_packet(&packet).await.is_some() {
return None; return None;
} else { } else {
return Some(packet); return Some(packet);
@@ -320,10 +319,7 @@ impl IcmpProxy {
.unwrap() .unwrap()
.as_ref() .as_ref()
.with_context(|| "icmp socket not created")? .with_context(|| "icmp socket not created")?
.send_to( .send_to(icmp_packet.packet(), &SocketAddrV4::new(dst_ip, 0).into())?;
icmp_packet.packet(),
&SocketAddrV4::new(dst_ip.into(), 0).into(),
)?;
Ok(()) Ok(())
} }
@@ -349,13 +345,15 @@ impl IcmpProxy {
let len = buf.len() - 20; let len = buf.len() - 20;
let _ = compose_ipv4_packet( let _ = compose_ipv4_packet(
&mut buf[..], ComposeIpv4PacketArgs {
src_ip, buf: &mut buf[..],
dst_ip, src_v4: src_ip,
IpNextHeaderProtocols::Icmp, dst_v4: dst_ip,
len, next_protocol: IpNextHeaderProtocols::Icmp,
1200, payload_len: len,
rand::random(), payload_mtu: 1200,
ip_id: rand::random(),
},
|buf| { |buf| {
let mut packet = ZCPacket::new_with_payload(buf); let mut packet = ZCPacket::new_with_payload(buf);
packet.fill_peer_manager_hdr(src_peer_id, dst_peer_id, PacketType::Data as u8); packet.fill_peer_manager_hdr(src_peer_id, dst_peer_id, PacketType::Data as u8);
@@ -387,7 +385,7 @@ impl IcmpProxy {
return None; return None;
}; };
let ipv4 = Ipv4Packet::new(&packet.payload())?; let ipv4 = Ipv4Packet::new(packet.payload())?;
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
{ {
@@ -396,17 +394,17 @@ impl IcmpProxy {
let mut real_dst_ip = ipv4.get_destination(); let mut real_dst_ip = ipv4.get_destination();
if !self if !(self
.cidr_set .cidr_set
.contains_v4(ipv4.get_destination(), &mut real_dst_ip) .contains_v4(ipv4.get_destination(), &mut real_dst_ip)
&& !is_exit_node || is_exit_node
&& !(self.global_ctx.no_tun() || (self.global_ctx.no_tun()
&& Some(ipv4.get_destination()) && Some(ipv4.get_destination())
== self == self
.global_ctx .global_ctx
.get_ipv4() .get_ipv4()
.as_ref() .as_ref()
.map(cidr::Ipv4Inet::address)) .map(cidr::Ipv4Inet::address)))
{ {
return None; return None;
} }
@@ -416,12 +414,10 @@ impl IcmpProxy {
resembled_buf = resembled_buf =
self.ip_resemmbler self.ip_resemmbler
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4); .add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
if resembled_buf.is_none() { resembled_buf.as_ref()?;
return None;
};
icmp::echo_request::EchoRequestPacket::new(resembled_buf.as_ref().unwrap())? icmp::echo_request::EchoRequestPacket::new(resembled_buf.as_ref().unwrap())?
} else { } else {
icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())? icmp::echo_request::EchoRequestPacket::new(ipv4.payload())?
}; };
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest { if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
@@ -484,10 +480,9 @@ impl Drop for IcmpProxy {
"dropping icmp proxy, {:?}", "dropping icmp proxy, {:?}",
self.socket.lock().unwrap().as_ref() self.socket.lock().unwrap().as_ref()
); );
self.socket.lock().unwrap().as_ref().and_then(|s| { if let Some(s) = self.socket.lock().unwrap().as_ref() {
tracing::info!("shutting down icmp socket"); tracing::info!("shutting down icmp socket");
let _ = s.shutdown(std::net::Shutdown::Both); let _ = s.shutdown(std::net::Shutdown::Both);
Some(()) }
});
} }
} }

View File

@@ -190,33 +190,36 @@ impl IpReassembler {
} }
} }
pub struct ComposeIpv4PacketArgs<'a> {
pub buf: &'a mut [u8],
pub src_v4: &'a Ipv4Addr,
pub dst_v4: &'a Ipv4Addr,
pub next_protocol: IpNextHeaderProtocol,
pub payload_len: usize,
pub payload_mtu: usize,
pub ip_id: u16,
}
// ip payload should be in buf[20..] // ip payload should be in buf[20..]
pub fn compose_ipv4_packet<F>( pub fn compose_ipv4_packet<F>(args: ComposeIpv4PacketArgs, cb: F) -> Result<(), Error>
buf: &mut [u8],
src_v4: &Ipv4Addr,
dst_v4: &Ipv4Addr,
next_protocol: IpNextHeaderProtocol,
payload_len: usize,
payload_mtu: usize,
ip_id: u16,
cb: F,
) -> Result<(), Error>
where where
F: Fn(&[u8]) -> Result<(), Error>, F: Fn(&[u8]) -> Result<(), Error>,
{ {
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu; let total_pieces = args.payload_len.div_ceil(args.payload_mtu);
let mut buf_offset = 0; let mut buf_offset = 0;
let mut fragment_offset = 0; let mut fragment_offset = 0;
let mut cur_piece = 0; let mut cur_piece = 0;
while fragment_offset < payload_len { while fragment_offset < args.payload_len {
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len); let next_fragment_offset =
std::cmp::min(fragment_offset + args.payload_mtu, args.payload_len);
let fragment_len = next_fragment_offset - fragment_offset; let fragment_len = next_fragment_offset - fragment_offset;
let mut ipv4_packet = let mut ipv4_packet =
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap(); MutableIpv4Packet::new(&mut args.buf[buf_offset..buf_offset + fragment_len + 20])
.unwrap();
ipv4_packet.set_version(4); ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5); ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length((fragment_len + 20) as u16); ipv4_packet.set_total_length((fragment_len + 20) as u16);
ipv4_packet.set_identification(ip_id); ipv4_packet.set_identification(args.ip_id);
if total_pieces > 1 { if total_pieces > 1 {
if cur_piece != total_pieces - 1 { if cur_piece != total_pieces - 1 {
ipv4_packet.set_flags(Ipv4Flags::MoreFragments); ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
@@ -232,9 +235,9 @@ where
ipv4_packet.set_ecn(0); ipv4_packet.set_ecn(0);
ipv4_packet.set_dscp(0); ipv4_packet.set_dscp(0);
ipv4_packet.set_ttl(32); ipv4_packet.set_ttl(32);
ipv4_packet.set_source(src_v4.clone()); ipv4_packet.set_source(*args.src_v4);
ipv4_packet.set_destination(dst_v4.clone()); ipv4_packet.set_destination(*args.dst_v4);
ipv4_packet.set_next_level_protocol(next_protocol); ipv4_packet.set_next_level_protocol(args.next_protocol);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
tracing::trace!(?ipv4_packet, "udp nat packet response send"); tracing::trace!(?ipv4_packet, "udp nat packet response send");
@@ -254,7 +257,7 @@ mod tests {
#[test] #[test]
fn resembler() { fn resembler() {
let raw_packets = vec![ let raw_packets = [
// last packet // last packet
vec![ vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
@@ -282,7 +285,7 @@ mod tests {
let resembler = IpReassembler::new(Duration::from_secs(1)); let resembler = IpReassembler::new(Duration::from_secs(1));
for (idx, raw_packet) in raw_packets.iter().enumerate() { for (idx, raw_packet) in raw_packets.iter().enumerate() {
if let Some(packet) = Ipv4Packet::new(&raw_packet) { if let Some(packet) = Ipv4Packet::new(raw_packet) {
let ret = resembler.add_fragment(source, destination, &packet); let ret = resembler.add_fragment(source, destination, &packet);
if idx != 2 { if idx != 2 {
assert!(ret.is_none()); assert!(ret.is_none());

View File

@@ -70,7 +70,9 @@ impl PeerPacketFilter for KcpEndpointFilter {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
let t = packet.peer_manager_header().unwrap().packet_type; let t = packet.peer_manager_header().unwrap().packet_type;
if t == PacketType::KcpSrc as u8 && !self.is_src { if t == PacketType::KcpSrc as u8 && !self.is_src {
// src packet, but we are dst
} else if t == PacketType::KcpDst as u8 && self.is_src { } else if t == PacketType::KcpDst as u8 && self.is_src {
// dst packet, but we are src
} else { } else {
return Some(packet); return Some(packet);
} }
@@ -103,7 +105,7 @@ async fn handle_kcp_output(
PacketType::KcpDst as u8 PacketType::KcpDst as u8
}; };
let mut packet = ZCPacket::new_with_payload(&packet.inner().freeze()); let mut packet = ZCPacket::new_with_payload(&packet.inner().freeze());
packet.fill_peer_manager_hdr(peer_mgr.my_peer_id(), dst_peer_id, packet_type as u8); packet.fill_peer_manager_hdr(peer_mgr.my_peer_id(), dst_peer_id, packet_type);
if let Err(e) = peer_mgr.send_msg(packet, dst_peer_id).await { if let Err(e) = peer_mgr.send_msg(packet, dst_peer_id).await {
tracing::error!("failed to send kcp packet to peer: {:?}", e); tracing::error!("failed to send kcp packet to peer: {:?}", e);
@@ -171,7 +173,7 @@ impl NatDstConnector for NatDstKcpConnector {
let kcp_endpoint = self.kcp_endpoint.clone(); let kcp_endpoint = self.kcp_endpoint.clone();
let my_peer_id = peer_mgr.my_peer_id(); let my_peer_id = peer_mgr.my_peer_id();
let conn_data_clone = conn_data.clone(); let conn_data_clone = conn_data;
connect_tasks.spawn(async move { connect_tasks.spawn(async move {
kcp_endpoint kcp_endpoint
@@ -182,9 +184,7 @@ impl NatDstConnector for NatDstKcpConnector {
Bytes::from(conn_data_clone.encode_to_vec()), Bytes::from(conn_data_clone.encode_to_vec()),
) )
.await .await
.with_context(|| { .with_context(|| format!("failed to connect to nat dst: {}", nat_dst))
format!("failed to connect to nat dst: {}", nat_dst.to_string())
})
}); });
} }
@@ -203,7 +203,7 @@ impl NatDstConnector for NatDstKcpConnector {
_ipv4: &Ipv4Packet, _ipv4: &Ipv4Packet,
_real_dst_ip: &mut Ipv4Addr, _real_dst_ip: &mut Ipv4Addr,
) -> bool { ) -> bool {
return hdr.from_peer_id == hdr.to_peer_id && hdr.is_kcp_src_modified(); hdr.from_peer_id == hdr.to_peer_id && hdr.is_kcp_src_modified()
} }
fn transport_type(&self) -> TcpProxyEntryTransportType { fn transport_type(&self) -> TcpProxyEntryTransportType {
@@ -230,7 +230,10 @@ impl TcpProxyForKcpSrcTrait for TcpProxyForKcpSrc {
} }
async fn check_dst_allow_kcp_input(&self, dst_ip: &Ipv4Addr) -> bool { async fn check_dst_allow_kcp_input(&self, dst_ip: &Ipv4Addr) -> bool {
self.0.get_peer_manager().check_allow_kcp_to_dst(&IpAddr::V4(*dst_ip)).await self.0
.get_peer_manager()
.check_allow_kcp_to_dst(&IpAddr::V4(*dst_ip))
.await
} }
} }
@@ -456,14 +459,11 @@ impl KcpProxyDst {
.into(); .into();
let src_socket: SocketAddr = parsed_conn_data.src.unwrap_or_default().into(); let src_socket: SocketAddr = parsed_conn_data.src.unwrap_or_default().into();
match dst_socket.ip() { if let IpAddr::V4(dst_v4_ip) = dst_socket.ip() {
IpAddr::V4(dst_v4_ip) => { let mut real_ip = dst_v4_ip;
let mut real_ip = dst_v4_ip; if cidr_set.contains_v4(dst_v4_ip, &mut real_ip) {
if cidr_set.contains_v4(dst_v4_ip, &mut real_ip) { dst_socket.set_ip(real_ip.into());
dst_socket.set_ip(real_ip.into());
}
} }
_ => {}
}; };
let conn_id = kcp_stream.conn_id(); let conn_id = kcp_stream.conn_id();
@@ -578,7 +578,7 @@ impl TcpProxyRpc for KcpProxyDstRpcService {
let mut reply = ListTcpProxyEntryResponse::default(); let mut reply = ListTcpProxyEntryResponse::default();
if let Some(tcp_proxy) = self.0.upgrade() { if let Some(tcp_proxy) = self.0.upgrade() {
for item in tcp_proxy.iter() { for item in tcp_proxy.iter() {
reply.entries.push(item.value().clone()); reply.entries.push(*item.value());
} }
} }
Ok(reply) Ok(reply)

View File

@@ -56,11 +56,11 @@ impl CidrSet {
cidr_set.lock().unwrap().clear(); cidr_set.lock().unwrap().clear();
for cidr in cidrs.iter() { for cidr in cidrs.iter() {
let real_cidr = cidr.cidr; let real_cidr = cidr.cidr;
let mapped = cidr.mapped_cidr.unwrap_or(real_cidr.clone()); let mapped = cidr.mapped_cidr.unwrap_or(real_cidr);
cidr_set.lock().unwrap().push(mapped.clone()); cidr_set.lock().unwrap().push(mapped);
if mapped != real_cidr { if mapped != real_cidr {
mapped_to_real.insert(mapped.clone(), real_cidr.clone()); mapped_to_real.insert(mapped, real_cidr);
} }
} }
} }
@@ -70,11 +70,11 @@ impl CidrSet {
} }
pub fn contains_v4(&self, ipv4: std::net::Ipv4Addr, real_ip: &mut std::net::Ipv4Addr) -> bool { pub fn contains_v4(&self, ipv4: std::net::Ipv4Addr, real_ip: &mut std::net::Ipv4Addr) -> bool {
let ip = ipv4.into(); let ip = ipv4;
let s = self.cidr_set.lock().unwrap(); let s = self.cidr_set.lock().unwrap();
for cidr in s.iter() { for cidr in s.iter() {
if cidr.contains(&ip) { if cidr.contains(&ip) {
if let Some(real_cidr) = self.mapped_to_real.get(&cidr).map(|v| v.value().clone()) { if let Some(real_cidr) = self.mapped_to_real.get(cidr).map(|v| *v.value()) {
let origin_network_bits = real_cidr.first().address().to_bits(); let origin_network_bits = real_cidr.first().address().to_bits();
let network_mask = cidr.mask().to_bits(); let network_mask = cidr.mask().to_bits();

View File

@@ -172,7 +172,7 @@ impl NatDstConnector for NatDstQUICConnector {
_ipv4: &Ipv4Packet, _ipv4: &Ipv4Packet,
_real_dst_ip: &mut Ipv4Addr, _real_dst_ip: &mut Ipv4Addr,
) -> bool { ) -> bool {
return hdr.from_peer_id == hdr.to_peer_id && !hdr.is_kcp_src_modified(); hdr.from_peer_id == hdr.to_peer_id && !hdr.is_kcp_src_modified()
} }
fn transport_type(&self) -> TcpProxyEntryTransportType { fn transport_type(&self) -> TcpProxyEntryTransportType {
@@ -457,7 +457,7 @@ impl TcpProxyRpc for QUICProxyDstRpcService {
let mut reply = ListTcpProxyEntryResponse::default(); let mut reply = ListTcpProxyEntryResponse::default();
if let Some(tcp_proxy) = self.0.upgrade() { if let Some(tcp_proxy) = self.0.upgrade() {
for item in tcp_proxy.iter() { for item in tcp_proxy.iter() {
reply.entries.push(item.value().clone()); reply.entries.push(*item.value());
} }
} }
Ok(reply) Ok(reply)

View File

@@ -72,9 +72,9 @@ impl SocksUdpSocket {
} }
enum SocksTcpStream { enum SocksTcpStream {
TcpStream(tokio::net::TcpStream), Tcp(tokio::net::TcpStream),
SmolTcpStream(super::tokio_smoltcp::TcpStream), SmolTcp(super::tokio_smoltcp::TcpStream),
KcpStream(KcpStream), Kcp(KcpStream),
} }
impl AsyncRead for SocksTcpStream { impl AsyncRead for SocksTcpStream {
@@ -84,15 +84,11 @@ impl AsyncRead for SocksTcpStream {
buf: &mut tokio::io::ReadBuf<'_>, buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> { ) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => { SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
std::pin::Pin::new(stream).poll_read(cx, buf) SocksTcpStream::SmolTcp(ref mut stream) => {
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
SocksTcpStream::KcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf) std::pin::Pin::new(stream).poll_read(cx, buf)
} }
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
} }
} }
} }
@@ -104,15 +100,11 @@ impl AsyncWrite for SocksTcpStream {
buf: &[u8], buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> { ) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => { SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
std::pin::Pin::new(stream).poll_write(cx, buf) SocksTcpStream::SmolTcp(ref mut stream) => {
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
SocksTcpStream::KcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf) std::pin::Pin::new(stream).poll_write(cx, buf)
} }
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
} }
} }
@@ -121,11 +113,9 @@ impl AsyncWrite for SocksTcpStream {
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx), SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
SocksTcpStream::SmolTcpStream(ref mut stream) => { SocksTcpStream::SmolTcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
std::pin::Pin::new(stream).poll_flush(cx) SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
}
SocksTcpStream::KcpStream(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
} }
} }
@@ -134,15 +124,9 @@ impl AsyncWrite for SocksTcpStream {
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() { match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => { SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
std::pin::Pin::new(stream).poll_shutdown(cx) SocksTcpStream::SmolTcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
} SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
SocksTcpStream::KcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
} }
} }
} }
@@ -196,7 +180,7 @@ impl AsyncTcpConnector for SmolTcpConnector {
let modified_addr = let modified_addr =
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port()); SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port());
Ok(SocksTcpStream::TcpStream( Ok(SocksTcpStream::Tcp(
tcp_connect_with_timeout(modified_addr, timeout_s).await?, tcp_connect_with_timeout(modified_addr, timeout_s).await?,
)) ))
} else { } else {
@@ -207,9 +191,9 @@ impl AsyncTcpConnector for SmolTcpConnector {
.await .await
.with_context(|| "connect to remote timeout")?; .with_context(|| "connect to remote timeout")?;
Ok(SocksTcpStream::SmolTcpStream(remote_socket.map_err( Ok(SocksTcpStream::SmolTcp(remote_socket.map_err(|e| {
|e| super::fast_socks5::SocksError::Other(e.into()), super::fast_socks5::SocksError::Other(e.into())
)?)) })?))
} }
} }
} }
@@ -249,7 +233,7 @@ impl AsyncTcpConnector for Socks5KcpConnector {
.connect(self.src_addr, addr) .connect(self.src_addr, addr)
.await .await
.map_err(|e| super::fast_socks5::SocksError::Other(e.into()))?; .map_err(|e| super::fast_socks5::SocksError::Other(e.into()))?;
Ok(SocksTcpStream::KcpStream(ret)) Ok(SocksTcpStream::Kcp(ret))
} }
} }
@@ -560,16 +544,16 @@ impl Socks5Server {
tcp_forward_task.lock().unwrap().abort_all(); tcp_forward_task.lock().unwrap().abort_all();
udp_client_map.clear(); udp_client_map.clear();
if cur_ipv4.is_none() { if let Some(cur_ipv4) = cur_ipv4 {
let _ = net.lock().await.take();
} else {
net.lock().await.replace(Socks5ServerNet::new( net.lock().await.replace(Socks5ServerNet::new(
cur_ipv4.unwrap(), cur_ipv4,
None, None,
peer_manager.clone(), peer_manager.clone(),
packet_recv.clone(), packet_recv.clone(),
entries.clone(), entries.clone(),
)); ));
} else {
let _ = net.lock().await.take();
} }
} }
@@ -621,7 +605,7 @@ impl Socks5Server {
let cfgs = self.global_ctx.config.get_port_forwards(); let cfgs = self.global_ctx.config.get_port_forwards();
self.reload_port_forwards(&cfgs).await?; self.reload_port_forwards(&cfgs).await?;
need_start = need_start || cfgs.len() > 0; need_start = need_start || !cfgs.is_empty();
if need_start { if need_start {
self.peer_manager self.peer_manager
@@ -756,23 +740,21 @@ impl Socks5Server {
continue; continue;
}; };
let dst_allow_kcp = peer_mgr_arc.check_allow_kcp_to_dst(&dst_addr.ip()).await; let dst_allow_kcp = peer_mgr_arc.check_allow_kcp_to_dst(&dst_addr.ip()).await;
tracing::debug!("dst_allow_kcp: {:?}", dst_allow_kcp); tracing::debug!("dst_allow_kcp: {:?}", dst_allow_kcp);
let connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send> = let connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send> =
if kcp_endpoint.is_none() || !dst_allow_kcp { match (&kcp_endpoint, dst_allow_kcp) {
Box::new(SmolTcpConnector { (Some(kcp_endpoint), true) => Box::new(Socks5KcpConnector {
kcp_endpoint: kcp_endpoint.clone(),
peer_mgr: peer_mgr.clone(),
src_addr: addr,
}),
(_, _) => Box::new(SmolTcpConnector {
net: net.smoltcp_net.clone(), net: net.smoltcp_net.clone(),
entries: entries.clone(), entries: entries.clone(),
current_entry: std::sync::Mutex::new(None), current_entry: std::sync::Mutex::new(None),
}) }),
} else {
let kcp_endpoint = kcp_endpoint.as_ref().unwrap().clone();
Box::new(Socks5KcpConnector {
kcp_endpoint,
peer_mgr: peer_mgr.clone(),
src_addr: addr,
})
}; };
forward_tasks forward_tasks
@@ -962,10 +944,10 @@ impl Socks5Server {
udp_client_map.retain(|_, client_info| { udp_client_map.retain(|_, client_info| {
now.duration_since(client_info.last_active.load()).as_secs() < 600 now.duration_since(client_info.last_active.load()).as_secs() < 600
}); });
udp_forward_task.retain(|k, _| udp_client_map.contains_key(&k)); udp_forward_task.retain(|k, _| udp_client_map.contains_key(k));
entries.retain(|_, data| match data { entries.retain(|_, data| match data {
Socks5EntryData::Udp((_, udp_client_key)) => { Socks5EntryData::Udp((_, udp_client_key)) => {
udp_client_map.contains_key(&udp_client_key) udp_client_map.contains_key(udp_client_key)
} }
_ => true, _ => true,
}); });

View File

@@ -109,9 +109,9 @@ impl NatDstConnector for NatDstTcpConnector {
) -> bool { ) -> bool {
let is_exit_node = hdr.is_exit_node(); let is_exit_node = hdr.is_exit_node();
if !cidr_set.contains_v4(ipv4.get_destination(), real_dst_ip) if !(cidr_set.contains_v4(ipv4.get_destination(), real_dst_ip)
&& !is_exit_node || is_exit_node
&& !(global_ctx.no_tun() || global_ctx.no_tun()
&& Some(ipv4.get_destination()) && Some(ipv4.get_destination())
== global_ctx.get_ipv4().as_ref().map(Ipv4Inet::address)) == global_ctx.get_ipv4().as_ref().map(Ipv4Inet::address))
{ {
@@ -154,10 +154,10 @@ impl NatDstEntry {
} }
} }
fn into_pb(&self, transport_type: TcpProxyEntryTransportType) -> TcpProxyEntry { fn parse_as_pb(&self, transport_type: TcpProxyEntryTransportType) -> TcpProxyEntry {
TcpProxyEntry { TcpProxyEntry {
src: Some(self.src.clone().into()), src: Some(self.src.into()),
dst: Some(self.real_dst.clone().into()), dst: Some(self.real_dst.into()),
start_time: self.start_time_local.timestamp() as u64, start_time: self.start_time_local.timestamp() as u64,
state: self.state.load().into(), state: self.state.load().into(),
transport_type: transport_type.into(), transport_type: transport_type.into(),
@@ -332,16 +332,14 @@ pub struct TcpProxy<C: NatDstConnector> {
#[async_trait::async_trait] #[async_trait::async_trait]
impl<C: NatDstConnector> PeerPacketFilter for TcpProxy<C> { impl<C: NatDstConnector> PeerPacketFilter for TcpProxy<C> {
async fn try_process_packet_from_peer(&self, mut packet: ZCPacket) -> Option<ZCPacket> { async fn try_process_packet_from_peer(&self, mut packet: ZCPacket) -> Option<ZCPacket> {
if let Some(_) = self.try_handle_peer_packet(&mut packet).await { if self.try_handle_peer_packet(&mut packet).await.is_some() {
if self.is_smoltcp_enabled() { if self.is_smoltcp_enabled() {
let smoltcp_stack_sender = self.smoltcp_stack_sender.as_ref().unwrap(); let smoltcp_stack_sender = self.smoltcp_stack_sender.as_ref().unwrap();
if let Err(e) = smoltcp_stack_sender.try_send(packet) { if let Err(e) = smoltcp_stack_sender.try_send(packet) {
tracing::error!("send to smoltcp stack failed: {:?}", e); tracing::error!("send to smoltcp stack failed: {:?}", e);
} }
} else { } else if let Err(e) = self.peer_manager.get_nic_channel().send(packet).await {
if let Err(e) = self.peer_manager.get_nic_channel().send(packet).await { tracing::error!("send to nic failed: {:?}", e);
tracing::error!("send to nic failed: {:?}", e);
}
} }
return None; return None;
} else { } else {
@@ -610,7 +608,7 @@ impl<C: NatDstConnector> TcpProxy<C> {
self.enable_smoltcp self.enable_smoltcp
.store(false, std::sync::atomic::Ordering::Relaxed); .store(false, std::sync::atomic::Ordering::Relaxed);
return Ok(ProxyTcpListener::KernelTcpListener(tcp_listener)); Ok(ProxyTcpListener::KernelTcpListener(tcp_listener))
} }
} }
@@ -917,10 +915,10 @@ impl<C: NatDstConnector> TcpProxy<C> {
let mut entries: Vec<TcpProxyEntry> = Vec::new(); let mut entries: Vec<TcpProxyEntry> = Vec::new();
let transport_type = self.connector.transport_type(); let transport_type = self.connector.transport_type();
for entry in self.syn_map.iter() { for entry in self.syn_map.iter() {
entries.push(entry.value().as_ref().into_pb(transport_type)); entries.push(entry.value().as_ref().parse_as_pb(transport_type));
} }
for entry in self.conn_map.iter() { for entry in self.conn_map.iter() {
entries.push(entry.value().as_ref().into_pb(transport_type)); entries.push(entry.value().as_ref().parse_as_pb(transport_type));
} }
entries entries
} }

View File

@@ -17,11 +17,17 @@ pub struct ChannelDevice {
caps: DeviceCapabilities, caps: DeviceCapabilities,
} }
pub type ChannelDeviceNewRet = (
ChannelDevice,
Sender<io::Result<Vec<u8>>>,
Receiver<Vec<u8>>,
);
impl ChannelDevice { impl ChannelDevice {
/// Make a new `ChannelDevice` with the given `recv` and `send` channels. /// Make a new `ChannelDevice` with the given `recv` and `send` channels.
/// ///
/// The `caps` is used to determine the device capabilities. `DeviceCapabilities::max_transmission_unit` must be set. /// The `caps` is used to determine the device capabilities. `DeviceCapabilities::max_transmission_unit` must be set.
pub fn new(caps: DeviceCapabilities) -> (Self, Sender<io::Result<Vec<u8>>>, Receiver<Vec<u8>>) { pub fn new(caps: DeviceCapabilities) -> ChannelDeviceNewRet {
let (tx1, rx1) = channel(1000); let (tx1, rx1) = channel(1000);
let (tx2, rx2) = channel(1000); let (tx2, rx2) = channel(1000);
( (
@@ -45,7 +51,7 @@ impl Stream for ChannelDevice {
} }
fn map_err(e: PollSendError<Vec<u8>>) -> io::Error { fn map_err(e: PollSendError<Vec<u8>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e) io::Error::other(e)
} }
impl Sink<Vec<u8>> for ChannelDevice { impl Sink<Vec<u8>> for ChannelDevice {

View File

@@ -46,8 +46,8 @@ impl RxToken for BufferRxToken {
F: FnOnce(&[u8]) -> R, F: FnOnce(&[u8]) -> R,
{ {
let p = &mut self.0; let p = &mut self.0;
let result = f(p);
result f(p)
} }
} }
@@ -79,10 +79,9 @@ impl Device for BufferDevice {
Self: 'a; Self: 'a;
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
match self.recv_queue.pop_front() { self.recv_queue
Some(p) => Some((BufferRxToken(p), BufferTxToken(self))), .pop_front()
None => None, .map(|p| (BufferRxToken(p), BufferTxToken(self)))
}
} }
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> { fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {

View File

@@ -4,7 +4,7 @@
use std::{ use std::{
io, io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, net::{IpAddr, SocketAddr},
sync::{ sync::{
atomic::{AtomicU16, Ordering}, atomic::{AtomicU16, Ordering},
Arc, Arc,
@@ -34,7 +34,7 @@ mod socket_allocator;
/// Can be used to create a forever timestamp in neighbor. /// Can be used to create a forever timestamp in neighbor.
// The 60_000 is the same as NeighborCache::ENTRY_LIFETIME. // The 60_000 is the same as NeighborCache::ENTRY_LIFETIME.
pub const FOREVER: Instant = pub const FOREVER: Instant =
Instant::from_micros_const(i64::max_value() - Duration::from_millis(60_000).micros() as i64); Instant::from_micros_const(i64::MAX - Duration::from_millis(60_000).micros() as i64);
pub struct Neighbor { pub struct Neighbor {
pub protocol_addr: IpAddress, pub protocol_addr: IpAddress,
@@ -173,8 +173,8 @@ impl Net {
fn set_address(&self, mut addr: SocketAddr) -> SocketAddr { fn set_address(&self, mut addr: SocketAddr) -> SocketAddr {
if addr.ip().is_unspecified() { if addr.ip().is_unspecified() {
addr.set_ip(match self.ip_addr.address() { addr.set_ip(match self.ip_addr.address() {
IpAddress::Ipv4(ip) => Ipv4Addr::from(ip).into(), IpAddress::Ipv4(ip) => ip.into(),
IpAddress::Ipv6(ip) => Ipv6Addr::from(ip).into(), IpAddress::Ipv6(ip) => ip.into(),
#[allow(unreachable_patterns)] #[allow(unreachable_patterns)]
_ => panic!("address must not be unspecified"), _ => panic!("address must not be unspecified"),
}); });

View File

@@ -51,9 +51,7 @@ async fn run(
loop { loop {
let packets = device.take_send_queue(); let packets = device.take_send_queue();
async_iface async_iface.send_all(&mut iter(packets).map(Ok)).await?;
.send_all(&mut iter(packets).map(|p| Ok(p)))
.await?;
if recv_buf.is_empty() && device.need_wait() { if recv_buf.is_empty() && device.need_wait() {
let start = Instant::now(); let start = Instant::now();
@@ -94,14 +92,10 @@ async fn run(
// wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets) // wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets)
for (_, socket) in socket_allocator.sockets().lock().iter_mut() { for (_, socket) in socket_allocator.sockets().lock().iter_mut() {
match socket { if let Socket::Tcp(tcp) = socket {
Socket::Tcp(tcp) => { if tcp.state() == smoltcp::socket::tcp::State::Closed {
if tcp.state() == smoltcp::socket::tcp::State::Closed { tcp.abort();
tcp.abort();
}
} }
#[allow(unreachable_patterns)]
_ => {}
} }
} }
} }
@@ -164,10 +158,8 @@ impl Reactor {
impl Drop for Reactor { impl Drop for Reactor {
fn drop(&mut self) { fn drop(&mut self) {
for (_, socket) in self.socket_allocator.sockets().lock().iter_mut() { for (_, socket) in self.socket_allocator.sockets().lock().iter_mut() {
match socket { if let Socket::Tcp(tcp) = socket {
Socket::Tcp(tcp) => tcp.close(), tcp.close()
#[allow(unreachable_patterns)]
_ => {}
} }
} }
} }

View File

@@ -5,7 +5,7 @@ pub use smoltcp::socket::tcp;
use smoltcp::socket::udp; use smoltcp::socket::udp;
use smoltcp::wire::{IpAddress, IpEndpoint}; use smoltcp::wire::{IpAddress, IpEndpoint};
use std::mem::replace; use std::mem::replace;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::IpAddr;
use std::{ use std::{
io, io,
net::SocketAddr, net::SocketAddr,
@@ -25,7 +25,7 @@ pub struct TcpListener {
} }
fn map_err<E: std::error::Error>(e: E) -> io::Error { fn map_err<E: std::error::Error>(e: E) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string()) io::Error::other(e.to_string())
} }
impl TcpListener { impl TcpListener {
@@ -95,8 +95,8 @@ impl Stream for Incoming {
fn ep2sa(ep: &IpEndpoint) -> SocketAddr { fn ep2sa(ep: &IpEndpoint) -> SocketAddr {
match ep.addr { match ep.addr {
IpAddress::Ipv4(v4) => SocketAddr::new(IpAddr::V4(Ipv4Addr::from(v4)), ep.port), IpAddress::Ipv4(v4) => SocketAddr::new(IpAddr::V4(v4), ep.port),
IpAddress::Ipv6(v6) => SocketAddr::new(IpAddr::V6(Ipv6Addr::from(v6)), ep.port), IpAddress::Ipv6(v6) => SocketAddr::new(IpAddr::V6(v6), ep.port),
#[allow(unreachable_patterns)] #[allow(unreachable_patterns)]
_ => unreachable!(), _ => unreachable!(),
} }

View File

@@ -26,7 +26,7 @@ use tracing::Level;
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, PeerId}, common::{error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, PeerId},
gateway::ip_reassembler::compose_ipv4_packet, gateway::ip_reassembler::{compose_ipv4_packet, ComposeIpv4PacketArgs},
peers::{peer_manager::PeerManager, PeerPacketFilter}, peers::{peer_manager::PeerManager, PeerPacketFilter},
tunnel::{ tunnel::{
common::{reserve_buf, setup_sokcet2}, common::{reserve_buf, setup_sokcet2},
@@ -110,13 +110,15 @@ impl UdpNatEntry {
)); ));
compose_ipv4_packet( compose_ipv4_packet(
&mut buf[..], ComposeIpv4PacketArgs {
src_v4.ip(), buf: &mut buf[..],
nat_src_v4.ip(), src_v4: src_v4.ip(),
IpNextHeaderProtocols::Udp, dst_v4: nat_src_v4.ip(),
payload_len + 8, // include udp header next_protocol: IpNextHeaderProtocols::Udp,
payload_mtu, payload_len: payload_len + 8, // include udp header
ip_id, payload_mtu,
ip_id,
},
|buf| { |buf| {
let mut p = ZCPacket::new_with_payload(buf); let mut p = ZCPacket::new_with_payload(buf);
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8); p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
@@ -273,11 +275,12 @@ impl UdpProxy {
} }
let mut real_dst_ip = ipv4.get_destination(); let mut real_dst_ip = ipv4.get_destination();
if !self
if !(self
.cidr_set .cidr_set
.contains_v4(ipv4.get_destination(), &mut real_dst_ip) .contains_v4(ipv4.get_destination(), &mut real_dst_ip)
&& !is_exit_node || is_exit_node
&& !(self.global_ctx.no_tun() || self.global_ctx.no_tun()
&& Some(ipv4.get_destination()) && Some(ipv4.get_destination())
== self.global_ctx.get_ipv4().as_ref().map(Ipv4Inet::address)) == self.global_ctx.get_ipv4().as_ref().map(Ipv4Inet::address))
{ {
@@ -289,9 +292,7 @@ impl UdpProxy {
resembled_buf = resembled_buf =
self.ip_resemmbler self.ip_resemmbler
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4); .add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
if resembled_buf.is_none() { resembled_buf.as_ref()?;
return None;
};
udp::UdpPacket::new(resembled_buf.as_ref().unwrap())? udp::UdpPacket::new(resembled_buf.as_ref().unwrap())?
} else { } else {
udp::UdpPacket::new(ipv4.payload())? udp::UdpPacket::new(ipv4.payload())?
@@ -374,7 +375,7 @@ impl UdpProxy {
#[async_trait::async_trait] #[async_trait::async_trait]
impl PeerPacketFilter for UdpProxy { impl PeerPacketFilter for UdpProxy {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
if let Some(_) = self.try_handle_packet(&packet).await { if self.try_handle_packet(&packet).await.is_some() {
return None; return None;
} else { } else {
return Some(packet); return Some(packet);

View File

@@ -59,7 +59,7 @@ impl MagicDnsClientInstance {
tokio::time::sleep(Duration::from_millis(500)).await; tokio::time::sleep(Duration::from_millis(500)).await;
continue; continue;
} }
prev_last_update = Some(last_update);
let mut routes = peer_mgr.list_routes().await; let mut routes = peer_mgr.list_routes().await;
// add self as a route // add self as a route
let ctx = peer_mgr.get_global_ctx(); let ctx = peer_mgr.get_global_ctx();
@@ -79,6 +79,11 @@ impl MagicDnsClientInstance {
rpc_stub rpc_stub
.update_dns_record(BaseController::default(), req) .update_dns_record(BaseController::default(), req)
.await?; .await?;
let last_update_after_rpc = peer_mgr.get_route_peer_info_last_update_time().await;
if last_update_after_rpc == last_update {
prev_last_update = Some(last_update);
}
} }
} }

View File

@@ -74,7 +74,7 @@ impl Record {
} }
fn rr_type(&self) -> rr::RecordType { fn rr_type(&self) -> rr::RecordType {
self.rr_type.clone().into() self.rr_type
} }
} }
@@ -167,7 +167,7 @@ ttl = "61s"
let (domain, records) = config let (domain, records) = config
.zones .zones
.get_key_value("et.internal") .get_key_value("et.internal")
.map_or(Err(anyhow!("parse error")), |x| Ok(x))?; .ok_or(anyhow!("et.internal not found"))?;
assert_eq!(domain, "et.internal"); assert_eq!(domain, "et.internal");
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
let record = &records[0]; let record = &records[0];
@@ -179,7 +179,7 @@ ttl = "61s"
let (domain, records) = config let (domain, records) = config
.zones .zones
.get_key_value("et.top") .get_key_value("et.top")
.map_or(Err(anyhow!("parse error")), |x| Ok(x))?; .ok_or(anyhow!("et.top not found"))?;
assert_eq!(domain, "et.top"); assert_eq!(domain, "et.top");
assert_eq!(records.len(), 1); assert_eq!(records.len(), 1);
let record = &records[0]; let record = &records[0];

View File

@@ -65,7 +65,7 @@ impl DnsRunner {
self.client = Some(client); self.client = Some(client);
self.client.as_mut().unwrap().run_and_wait().await; self.client.as_mut().unwrap().run_and_wait().await;
return Err(anyhow::anyhow!("Client instance exit")); Err(anyhow::anyhow!("Client instance exit"))
} }
pub async fn run(&mut self, canel_token: CancellationToken) { pub async fn run(&mut self, canel_token: CancellationToken) {

View File

@@ -96,12 +96,12 @@ impl Server {
.0 .0
.name_servers() .name_servers()
.iter() .iter()
.cloned() .filter(|&x| {
.filter(|x| {
!config !config
.excluded_forward_nameservers() .excluded_forward_nameservers()
.contains(&x.socket_addr.ip()) .contains(&x.socket_addr.ip())
}) })
.cloned()
.collect::<Vec<_>>() .collect::<Vec<_>>()
.into(), .into(),
options: Some(system_conf.1), options: Some(system_conf.1),
@@ -148,7 +148,7 @@ impl Server {
.with_context(|| { .with_context(|| {
format!( format!(
"DNS Server failed to create UDP socket for address {}", "DNS Server failed to create UDP socket for address {}",
address.to_string() address
) )
})?; })?;
socket2::SockRef::from(&socket) socket2::SockRef::from(&socket)
@@ -156,7 +156,7 @@ impl Server {
.with_context(|| { .with_context(|| {
format!( format!(
"DNS Server failed to set reuse address on socket {}", "DNS Server failed to set reuse address on socket {}",
address.to_string() address
) )
})?; })?;
socket.bind(&bind_addr.into()).with_context(|| { socket.bind(&bind_addr.into()).with_context(|| {
@@ -164,17 +164,17 @@ impl Server {
})?; })?;
socket socket
.set_nonblocking(true) .set_nonblocking(true)
.with_context(|| format!("DNS Server failed to set socket to non-blocking"))?; .with_context(|| "DNS Server failed to set socket to non-blocking".to_string())?;
let socket = UdpSocket::from_std(socket.into()).with_context(|| { let socket = UdpSocket::from_std(socket.into()).with_context(|| {
format!( format!(
"DNS Server failed to convert socket to UdpSocket for address {}", "DNS Server failed to convert socket to UdpSocket for address {}",
address.to_string() address
) )
})?; })?;
let local_addr = socket let local_addr = socket
.local_addr() .local_addr()
.with_context(|| format!("DNS Server failed to get local address"))?; .with_context(|| "DNS Server failed to get local address".to_string())?;
self.server.register_socket(socket); self.server.register_socket(socket);
Ok(local_addr) Ok(local_addr)

View File

@@ -70,6 +70,20 @@ pub(super) struct MagicDnsServerInstanceData {
} }
impl MagicDnsServerInstanceData { impl MagicDnsServerInstanceData {
fn is_valid_subdomain_label(s: &str) -> bool {
let s = s.trim();
// 长度检查1-63 个字符
if s.is_empty() || s.len() > 63 {
return false;
}
// 检查每个字符是否合法,并确保不以 '-' 开头或结尾
s.chars().all(|c| matches!(c, 'a'..='z' | '0'..='9' | '-'))
&& !s.starts_with('-')
&& !s.ends_with('-')
}
pub async fn update_dns_records<'a, T: Iterator<Item = &'a Route>>( pub async fn update_dns_records<'a, T: Iterator<Item = &'a Route>>(
&self, &self,
routes: T, routes: T,
@@ -81,6 +95,11 @@ impl MagicDnsServerInstanceData {
continue; continue;
} }
// check host name valid for dns
if !Self::is_valid_subdomain_label(&route.hostname) {
continue;
}
let Some(ipv4_addr) = route.ipv4_addr.unwrap_or_default().address else { let Some(ipv4_addr) = route.ipv4_addr.unwrap_or_default().address else {
continue; continue;
}; };
@@ -432,7 +451,7 @@ impl MagicDnsServerInstance {
if !self.tun_inet.contains(&self.data.fake_ip) && self.data.tun_dev.is_some() { if !self.tun_inet.contains(&self.data.fake_ip) && self.data.tun_dev.is_some() {
let ifcfg = IfConfiger {}; let ifcfg = IfConfiger {};
let _ = ifcfg let _ = ifcfg
.remove_ipv4_route(&self.data.tun_dev.as_ref().unwrap(), self.data.fake_ip, 32) .remove_ipv4_route(self.data.tun_dev.as_ref().unwrap(), self.data.fake_ip, 32)
.await; .await;
} }

View File

@@ -26,13 +26,19 @@ struct DNSConfigError {
source: Option<anyhow::Error>, source: Option<anyhow::Error>,
} }
type DbusPingFn = dyn Fn(&str, &str) -> Result<()>;
type DbusReadStringFn = dyn Fn(&str, &str, &str, &str) -> Result<String>;
type NmIsUsingResolvedFn = dyn Fn() -> Result<()>;
type NmVersionBetweenFn = dyn Fn(&str, &str) -> Result<bool>;
type ResolvconfStyleFn = dyn Fn() -> String;
// 配置环境结构体 // 配置环境结构体
struct OSConfigEnv { struct OSConfigEnv {
fs: Box<dyn FileSystem>, fs: Box<dyn FileSystem>,
dbus_ping: Box<dyn Fn(&str, &str) -> Result<()>>, dbus_ping: Box<DbusPingFn>,
dbus_read_string: Box<dyn Fn(&str, &str, &str, &str) -> Result<String>>, dbus_read_string: Box<DbusReadStringFn>,
nm_is_using_resolved: Box<dyn Fn() -> Result<()>>, nm_is_using_resolved: Box<NmIsUsingResolvedFn>,
nm_version_between: Box<dyn Fn(&str, &str) -> Result<bool>>, nm_version_between: Box<NmVersionBetweenFn>,
resolvconf_style: Box<dyn Fn() -> String>, resolvconf_style: Box<dyn Fn() -> String>,
} }
@@ -86,8 +92,7 @@ pub fn nm_is_using_resolved() -> Result<()> {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"NetworkManager is not using systemd-resolved, found: {:?}", "NetworkManager is not using systemd-resolved, found: {:?}",
value value
) ));
.into());
} }
Ok(()) Ok(())

View File

@@ -41,7 +41,7 @@ pub async fn prepare_env(dns_name: &str, tun_ip: Ipv4Inet) -> (Arc<PeerManager>,
pub async fn check_dns_record(fake_ip: &Ipv4Addr, domain: &str, expected_ip: &str) { pub async fn check_dns_record(fake_ip: &Ipv4Addr, domain: &str, expected_ip: &str) {
let stream = UdpClientStream::builder( let stream = UdpClientStream::builder(
SocketAddr::new(fake_ip.clone().into(), 53), SocketAddr::new((*fake_ip).into(), 53),
TokioRuntimeProvider::default(), TokioRuntimeProvider::default(),
) )
.build(); .build();

View File

@@ -258,7 +258,7 @@ pub struct Instance {
} }
impl Instance { impl Instance {
pub fn new(config: impl ConfigLoader + Send + Sync + 'static) -> Self { pub fn new(config: impl ConfigLoader + 'static) -> Self {
let global_ctx = Arc::new(GlobalCtx::new(config)); let global_ctx = Arc::new(GlobalCtx::new(config));
tracing::info!( tracing::info!(
@@ -304,10 +304,10 @@ impl Instance {
#[cfg(feature = "socks5")] #[cfg(feature = "socks5")]
let socks5_server = Socks5Server::new(global_ctx.clone(), peer_manager.clone(), None); let socks5_server = Socks5Server::new(global_ctx.clone(), peer_manager.clone(), None);
let rpc_server = global_ctx.config.get_rpc_portal().and_then(|s| { let rpc_server = global_ctx.config.get_rpc_portal().map(|s| {
Some(StandAloneServer::new(TcpTunnelListener::new( StandAloneServer::new(TcpTunnelListener::new(
format!("tcp://{}", s).parse().unwrap(), format!("tcp://{}", s).parse().unwrap(),
))) ))
}); });
Instance { Instance {
@@ -470,7 +470,7 @@ impl Instance {
continue; continue;
} }
let last_ip = current_dhcp_ip.clone(); let last_ip = current_dhcp_ip;
tracing::debug!( tracing::debug!(
?current_dhcp_ip, ?current_dhcp_ip,
?candidate_ipv4_addr, ?candidate_ipv4_addr,
@@ -509,11 +509,7 @@ impl Instance {
Self::use_new_nic_ctx( Self::use_new_nic_ctx(
nic_ctx.clone(), nic_ctx.clone(),
new_nic_ctx, new_nic_ctx,
Self::create_magic_dns_runner( Self::create_magic_dns_runner(peer_manager_c.clone(), ifname, ip),
peer_manager_c.clone(),
ifname,
ip.clone(),
),
) )
.await; .await;
} }
@@ -890,7 +886,7 @@ impl Instance {
) -> Result<GetStatsResponse, rpc_types::error::Error> { ) -> Result<GetStatsResponse, rpc_types::error::Error> {
let stats_manager = self.global_ctx.stats_manager(); let stats_manager = self.global_ctx.stats_manager();
let snapshots = stats_manager.get_all_metrics(); let snapshots = stats_manager.get_all_metrics();
let metrics = snapshots let metrics = snapshots
.into_iter() .into_iter()
.map(|snapshot| { .map(|snapshot| {
@@ -898,7 +894,7 @@ impl Instance {
for label in snapshot.labels.labels() { for label in snapshot.labels.labels() {
labels.insert(label.key.clone(), label.value.clone()); labels.insert(label.key.clone(), label.value.clone());
} }
MetricSnapshot { MetricSnapshot {
name: snapshot.name_str(), name: snapshot.name_str(),
value: snapshot.value, value: snapshot.value,
@@ -906,7 +902,7 @@ impl Instance {
} }
}) })
.collect(); .collect();
Ok(GetStatsResponse { metrics }) Ok(GetStatsResponse { metrics })
} }
@@ -917,7 +913,7 @@ impl Instance {
) -> Result<GetPrometheusStatsResponse, rpc_types::error::Error> { ) -> Result<GetPrometheusStatsResponse, rpc_types::error::Error> {
let stats_manager = self.global_ctx.stats_manager(); let stats_manager = self.global_ctx.stats_manager();
let prometheus_text = stats_manager.export_prometheus(); let prometheus_text = stats_manager.export_prometheus();
Ok(GetPrometheusStatsResponse { prometheus_text }) Ok(GetPrometheusStatsResponse { prometheus_text })
} }
} }

View File

@@ -56,7 +56,7 @@ pub fn get_listener_by_url(
} }
pub fn is_url_host_ipv6(l: &url::Url) -> bool { pub fn is_url_host_ipv6(l: &url::Url) -> bool {
l.host_str().map_or(false, |h| h.contains(':')) l.host_str().is_some_and(|h| h.contains(':'))
} }
pub fn is_url_host_unspecified(l: &url::Url) -> bool { pub fn is_url_host_unspecified(l: &url::Url) -> bool {

View File

@@ -1,5 +1,7 @@
pub mod dns_server; pub mod dns_server;
#[allow(clippy::module_inception)]
pub mod instance; pub mod instance;
pub mod listeners; pub mod listeners;
#[cfg(feature = "tun")] #[cfg(feature = "tun")]

View File

@@ -68,10 +68,10 @@ impl Stream for TunStream {
type Item = StreamItem; type Item = StreamItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<StreamItem>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<StreamItem>> {
let mut self_mut = self.project(); let self_mut = self.project();
let mut g = ready!(self_mut.l.poll_lock(cx)); let mut g = ready!(self_mut.l.poll_lock(cx));
reserve_buf(&mut self_mut.cur_buf, 2500, 4 * 1024); reserve_buf(self_mut.cur_buf, 2500, 4 * 1024);
if self_mut.cur_buf.len() == 0 { if self_mut.cur_buf.is_empty() {
unsafe { unsafe {
self_mut.cur_buf.set_len(*self_mut.payload_offset); self_mut.cur_buf.set_len(*self_mut.payload_offset);
} }
@@ -117,10 +117,7 @@ impl PacketProtocol {
match self { match self {
PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16), PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16),
PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16), PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16),
PacketProtocol::Other(_) => Err(io::Error::new( PacketProtocol::Other(_) => Err(io::Error::other("neither an IPv4 nor IPv6 packet")),
io::ErrorKind::Other,
"neither an IPv4 nor IPv6 packet",
)),
} }
} }
@@ -175,7 +172,7 @@ impl TunZCPacketToBytes {
} }
impl ZCPacketToBytes for TunZCPacketToBytes { impl ZCPacketToBytes for TunZCPacketToBytes {
fn into_bytes(&self, zc_packet: ZCPacket) -> Result<Bytes, TunnelError> { fn zcpacket_into_bytes(&self, zc_packet: ZCPacket) -> Result<Bytes, TunnelError> {
let payload_offset = zc_packet.payload_offset(); let payload_offset = zc_packet.payload_offset();
let mut inner = zc_packet.inner(); let mut inner = zc_packet.inner();
// we have peer manager header, so payload offset must larger than 4 // we have peer manager header, so payload offset must larger than 4
@@ -383,11 +380,11 @@ impl VirtualNic {
let dev_name = self.global_ctx.get_flags().dev_name; let dev_name = self.global_ctx.get_flags().dev_name;
if !dev_name.is_empty() { if !dev_name.is_empty() {
config.tun_name(format!("{}", dev_name)); config.tun_name(&dev_name);
} }
} }
#[cfg(any(target_os = "macos"))] #[cfg(target_os = "macos")]
config.platform_config(|config| { config.platform_config(|config| {
// disable packet information so we can process the header by ourselves, see tun2 impl for more details // disable packet information so we can process the header by ourselves, see tun2 impl for more details
config.packet_information(false); config.packet_information(false);
@@ -515,9 +512,7 @@ impl VirtualNic {
{ {
// set mtu by ourselves, rust-tun does not handle it correctly on windows // set mtu by ourselves, rust-tun does not handle it correctly on windows
let _g = self.global_ctx.net_ns.guard(); let _g = self.global_ctx.net_ns.guard();
self.ifcfg self.ifcfg.set_mtu(ifname.as_str(), mtu_in_config).await?;
.set_mtu(ifname.as_str(), mtu_in_config as u32)
.await?;
} }
let has_packet_info = cfg!(target_os = "macos"); let has_packet_info = cfg!(target_os = "macos");
@@ -643,7 +638,7 @@ impl NicCtx {
) -> Self { ) -> Self {
NicCtx { NicCtx {
global_ctx: global_ctx.clone(), global_ctx: global_ctx.clone(),
peer_mgr: Arc::downgrade(&peer_manager), peer_mgr: Arc::downgrade(peer_manager),
peer_packet_receiver, peer_packet_receiver,
nic: Arc::new(Mutex::new(VirtualNic::new(global_ctx))), nic: Arc::new(Mutex::new(VirtualNic::new(global_ctx))),
tasks: JoinSet::new(), tasks: JoinSet::new(),

View File

@@ -18,6 +18,12 @@ pub struct NetworkInstanceManager {
stop_check_notifier: Arc<tokio::sync::Notify>, stop_check_notifier: Arc<tokio::sync::Notify>,
} }
impl Default for NetworkInstanceManager {
fn default() -> Self {
Self::new()
}
}
impl NetworkInstanceManager { impl NetworkInstanceManager {
pub fn new() -> Self { pub fn new() -> Self {
NetworkInstanceManager { NetworkInstanceManager {
@@ -65,11 +71,9 @@ impl NetworkInstanceManager {
let Some(instance_stop_notifier) = instance_stop_notifier else { let Some(instance_stop_notifier) = instance_stop_notifier else {
return; return;
}; };
let _t = if let Some(event) = instance_event_receiver.flatten() { let _t = instance_event_receiver
Some(ScopedTask::from(handle_event(instance_id, event))) .flatten()
} else { .map(|event| ScopedTask::from(handle_event(instance_id, event)));
None
};
instance_stop_notifier.notified().await; instance_stop_notifier.notified().await;
if let Some(instance) = instance_map.get(&instance_id) { if let Some(instance) = instance_map.get(&instance_id) {
if let Some(e) = instance.get_latest_error_msg() { if let Some(e) = instance.get_latest_error_msg() {
@@ -124,17 +128,14 @@ impl NetworkInstanceManager {
let mut ret = BTreeMap::new(); let mut ret = BTreeMap::new();
for instance in self.instance_map.iter() { for instance in self.instance_map.iter() {
if let Some(info) = instance.get_running_info() { if let Some(info) = instance.get_running_info() {
ret.insert(instance.key().clone(), info); ret.insert(*instance.key(), info);
} }
} }
Ok(ret) Ok(ret)
} }
pub fn list_network_instance_ids(&self) -> Vec<uuid::Uuid> { pub fn list_network_instance_ids(&self) -> Vec<uuid::Uuid> {
self.instance_map self.instance_map.iter().map(|item| *item.key()).collect()
.iter()
.map(|item| item.key().clone())
.collect()
} }
pub fn get_network_instance_name(&self, instance_id: &uuid::Uuid) -> Option<String> { pub fn get_network_instance_name(&self, instance_id: &uuid::Uuid) -> Option<String> {
@@ -299,8 +300,8 @@ fn handle_event(
instance_id, instance_id,
format!( format!(
"port forward added. local: {}, remote: {}, proto: {}", "port forward added. local: {}, remote: {}, proto: {}",
cfg.bind_addr.unwrap().to_string(), cfg.bind_addr.unwrap(),
cfg.dst_addr.unwrap().to_string(), cfg.dst_addr.unwrap(),
cfg.socket_type().as_str_name() cfg.socket_type().as_str_name()
), ),
); );
@@ -347,9 +348,8 @@ mod tests {
let instance_id1 = manager let instance_id1 = manager
.run_network_instance( .run_network_instance(
TomlConfigLoader::new_from_str(cfg_str) TomlConfigLoader::new_from_str(cfg_str)
.map(|c| { .inspect(|c| {
c.set_listeners(vec![format!("tcp://0.0.0.0:{}", port).parse().unwrap()]); c.set_listeners(vec![format!("tcp://0.0.0.0:{}", port).parse().unwrap()]);
c
}) })
.unwrap(), .unwrap(),
ConfigSource::Cli, ConfigSource::Cli,
@@ -426,9 +426,8 @@ mod tests {
assert!(manager assert!(manager
.run_network_instance( .run_network_instance(
TomlConfigLoader::new_from_str(cfg_str) TomlConfigLoader::new_from_str(cfg_str)
.map(|c| { .inspect(|c| {
c.set_listeners(vec![format!("tcp://0.0.0.0:{}", port).parse().unwrap()]); c.set_listeners(vec![format!("tcp://0.0.0.0:{}", port).parse().unwrap()]);
c
}) })
.unwrap(), .unwrap(),
ConfigSource::GUI, ConfigSource::GUI,

View File

@@ -1,8 +1,5 @@
use std::{ use crate::common::config::PortForwardConfig;
collections::VecDeque, use crate::proto::web;
sync::{atomic::AtomicBool, Arc, RwLock},
};
use std::net::SocketAddr;
use crate::{ use crate::{
common::{ common::{
config::{ config::{
@@ -19,9 +16,12 @@ use crate::{
}; };
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};
use std::net::SocketAddr;
use std::{
collections::VecDeque,
sync::{atomic::AtomicBool, Arc, RwLock},
};
use tokio::{sync::broadcast, task::JoinSet}; use tokio::{sync::broadcast, task::JoinSet};
use crate::common::config::PortForwardConfig;
use crate::proto::web;
pub type MyNodeInfo = crate::proto::web::MyNodeInfo; pub type MyNodeInfo = crate::proto::web::MyNodeInfo;
@@ -89,7 +89,7 @@ impl EasyTierLauncher {
let _ = data.event_subscriber.read().unwrap().send(event.clone()); let _ = data.event_subscriber.read().unwrap().send(event.clone());
events.push_front(Event { events.push_front(Event {
time: chrono::Local::now(), time: chrono::Local::now(),
event: event, event,
}); });
if events.len() > 20 { if events.len() > 20 {
events.pop_back(); events.pop_back();
@@ -380,9 +380,7 @@ impl NetworkInstance {
} }
pub fn get_running_info(&self) -> Option<NetworkInstanceRunningInfo> { pub fn get_running_info(&self) -> Option<NetworkInstanceRunningInfo> {
if self.launcher.is_none() { self.launcher.as_ref()?;
return None;
}
let launcher = self.launcher.as_ref().unwrap(); let launcher = self.launcher.as_ref().unwrap();
@@ -434,19 +432,15 @@ impl NetworkInstance {
} }
pub fn subscribe_event(&self) -> Option<broadcast::Receiver<GlobalCtxEvent>> { pub fn subscribe_event(&self) -> Option<broadcast::Receiver<GlobalCtxEvent>> {
if let Some(launcher) = self.launcher.as_ref() { self.launcher
Some(launcher.data.event_subscriber.read().unwrap().subscribe()) .as_ref()
} else { .map(|launcher| launcher.data.event_subscriber.read().unwrap().subscribe())
None
}
} }
pub fn get_stop_notifier(&self) -> Option<Arc<tokio::sync::Notify>> { pub fn get_stop_notifier(&self) -> Option<Arc<tokio::sync::Notify>> {
if let Some(launcher) = self.launcher.as_ref() { self.launcher
Some(launcher.data.instance_stop_notifier.clone()) .as_ref()
} else { .map(|launcher| launcher.data.instance_stop_notifier.clone())
None
}
} }
pub fn get_latest_error_msg(&self) -> Option<String> { pub fn get_latest_error_msg(&self) -> Option<String> {
@@ -511,7 +505,7 @@ impl NetworkConfig {
if !cfg.get_dhcp() { if !cfg.get_dhcp() {
let virtual_ipv4 = self.virtual_ipv4.clone().unwrap_or_default(); let virtual_ipv4 = self.virtual_ipv4.clone().unwrap_or_default();
if virtual_ipv4.len() > 0 { if !virtual_ipv4.is_empty() {
let ip = format!("{}/{}", virtual_ipv4, self.network_length.unwrap_or(24)) let ip = format!("{}/{}", virtual_ipv4, self.network_length.unwrap_or(24))
.parse() .parse()
.with_context(|| { .with_context(|| {
@@ -596,8 +590,10 @@ impl NetworkConfig {
.iter() .iter()
.filter(|pf| !pf.bind_ip.is_empty() && !pf.dst_ip.is_empty()) .filter(|pf| !pf.bind_ip.is_empty() && !pf.dst_ip.is_empty())
.filter_map(|pf| { .filter_map(|pf| {
let bind_addr = format!("{}:{}", pf.bind_ip, pf.bind_port).parse::<SocketAddr>(); let bind_addr =
let dst_addr = format!("{}:{}", pf.dst_ip, pf.dst_port).parse::<SocketAddr>(); format!("{}:{}", pf.bind_ip, pf.bind_port).parse::<SocketAddr>();
let dst_addr =
format!("{}:{}", pf.dst_ip, pf.dst_port).parse::<SocketAddr>();
match (bind_addr, dst_addr) { match (bind_addr, dst_addr) {
(Ok(bind_addr), Ok(dst_addr)) => Some(PortForwardConfig { (Ok(bind_addr), Ok(dst_addr)) => Some(PortForwardConfig {
@@ -608,7 +604,7 @@ impl NetworkConfig {
_ => None, _ => None,
} }
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>(),
); );
} }
@@ -650,7 +646,7 @@ impl NetworkConfig {
cfg.set_routes(Some(routes)); cfg.set_routes(Some(routes));
} }
if self.exit_nodes.len() > 0 { if !self.exit_nodes.is_empty() {
let mut exit_nodes = Vec::<std::net::IpAddr>::with_capacity(self.exit_nodes.len()); let mut exit_nodes = Vec::<std::net::IpAddr>::with_capacity(self.exit_nodes.len());
for node in self.exit_nodes.iter() { for node in self.exit_nodes.iter() {
exit_nodes.push( exit_nodes.push(
@@ -669,7 +665,7 @@ impl NetworkConfig {
} }
} }
if self.mapped_listeners.len() > 0 { if !self.mapped_listeners.is_empty() {
cfg.set_mapped_listeners(Some( cfg.set_mapped_listeners(Some(
self.mapped_listeners self.mapped_listeners
.iter() .iter()
@@ -754,7 +750,7 @@ impl NetworkConfig {
} }
if self.enable_relay_network_whitelist.unwrap_or_default() { if self.enable_relay_network_whitelist.unwrap_or_default() {
if self.relay_network_whitelist.len() > 0 { if !self.relay_network_whitelist.is_empty() {
flags.relay_network_whitelist = self.relay_network_whitelist.join(" "); flags.relay_network_whitelist = self.relay_network_whitelist.join(" ");
} else { } else {
flags.relay_network_whitelist = "".to_string(); flags.relay_network_whitelist = "".to_string();
@@ -784,7 +780,9 @@ impl NetworkConfig {
pub fn new_from_config(config: &TomlConfigLoader) -> Result<Self, anyhow::Error> { pub fn new_from_config(config: &TomlConfigLoader) -> Result<Self, anyhow::Error> {
let default_config = TomlConfigLoader::default(); let default_config = TomlConfigLoader::default();
let mut result = Self::default(); let mut result = Self {
..Default::default()
};
result.instance_id = Some(config.get_id().to_string()); result.instance_id = Some(config.get_id().to_string());
if config.get_hostname() != default_config.get_hostname() { if config.get_hostname() != default_config.get_hostname() {
@@ -819,7 +817,7 @@ impl NetworkConfig {
result.listener_urls = config result.listener_urls = config
.get_listeners() .get_listeners()
.unwrap_or_else(|| vec![]) .unwrap_or_default()
.iter() .iter()
.map(|l| l.to_string()) .map(|l| l.to_string())
.collect(); .collect();
@@ -846,17 +844,16 @@ impl NetworkConfig {
let port_forwards = config.get_port_forwards(); let port_forwards = config.get_port_forwards();
if !port_forwards.is_empty() { if !port_forwards.is_empty() {
result.port_forwards = port_forwards.iter() result.port_forwards = port_forwards
.map(|f| { .iter()
web::PortForwardConfig { .map(|f| web::PortForwardConfig {
proto: f.proto.clone(), proto: f.proto.clone(),
bind_ip: f.bind_addr.ip().to_string(), bind_ip: f.bind_addr.ip().to_string(),
bind_port: f.bind_addr.port() as u32, bind_port: f.bind_addr.port() as u32,
dst_ip: f.dst_addr.ip().to_string(), dst_ip: f.dst_addr.ip().to_string(),
dst_port: f.dst_addr.port() as u32, dst_port: f.dst_addr.port() as u32,
} })
}). .collect();
collect();
} }
if let Some(vpn_config) = config.get_vpn_portal_config() { if let Some(vpn_config) = config.get_vpn_portal_config() {

View File

@@ -13,8 +13,8 @@ mod vpn_portal;
pub mod common; pub mod common;
pub mod connector; pub mod connector;
pub mod launcher;
pub mod instance_manager; pub mod instance_manager;
pub mod launcher;
pub mod peers; pub mod peers;
pub mod proto; pub mod proto;
pub mod tunnel; pub mod tunnel;
@@ -27,6 +27,6 @@ mod tests;
pub const VERSION: &str = common::constants::EASYTIER_VERSION; pub const VERSION: &str = common::constants::EASYTIER_VERSION;
rust_i18n::i18n!("locales", fallback = "en"); rust_i18n::i18n!("locales", fallback = "en");
pub fn print_completions<G: Generator>(generator: G, cmd: &mut Command, bin_name:&str) { pub fn print_completions<G: Generator>(generator: G, cmd: &mut Command, bin_name: &str) {
clap_complete::generate(generator, cmd, bin_name, &mut io::stdout()); clap_complete::generate(generator, cmd, bin_name, &mut io::stdout());
} }

View File

@@ -97,13 +97,13 @@ impl PeerCenterBase {
&self, &self,
job_ctx: T, job_ctx: T,
job_fn: impl Fn( job_fn: impl Fn(
Box<dyn PeerCenterRpc<Controller = BaseController> + Send>, Box<dyn PeerCenterRpc<Controller = BaseController> + Send>,
Arc<PeridicJobCtx<T>>, Arc<PeridicJobCtx<T>>,
) -> Fut ) -> Fut
+ Send + Send
+ Sync + Sync
+ 'static, + 'static,
) -> () { ) {
let my_peer_id = self.my_peer_id; let my_peer_id = self.my_peer_id;
let peer_mgr = self.peer_mgr.clone(); let peer_mgr = self.peer_mgr.clone();
let lock = self.lock.clone(); let lock = self.lock.clone();
@@ -126,7 +126,7 @@ impl PeerCenterBase {
return; return;
}; };
ctx.center_peer.store(center_peer.clone()); ctx.center_peer.store(center_peer);
tracing::trace!(?center_peer, "run periodic job"); tracing::trace!(?center_peer, "run periodic job");
let _g = lock.lock().await; let _g = lock.lock().await;
let stub = rpc_mgr let stub = rpc_mgr
@@ -310,7 +310,7 @@ impl PeerCenterInstance {
.init_periodic_job(ctx, |client, ctx| async move { .init_periodic_job(ctx, |client, ctx| async move {
let my_node_id = ctx.my_peer_id; let my_node_id = ctx.my_peer_id;
let peers = ctx.job_ctx.peer_mgr.list_peers().await; let peers = ctx.job_ctx.peer_mgr.list_peers().await;
let peer_list = peers.direct_peers.keys().map(|k| *k).collect(); let peer_list = peers.direct_peers.keys().copied().collect();
let job_ctx = &ctx.job_ctx; let job_ctx = &ctx.job_ctx;
// only report when: // only report when:
@@ -370,7 +370,7 @@ impl PeerCenterInstance {
.map .map
.get(&src) .get(&src)
.and_then(|src_peer_info| src_peer_info.direct_peers.get(&dst)) .and_then(|src_peer_info| src_peer_info.direct_peers.get(&dst))
.and_then(|info| Some(info.latency_ms)) .map(|info| info.latency_ms)
} }
} }
@@ -503,7 +503,7 @@ mod tests {
let peer_center_b = PeerCenterInstance::new(peer_mgr_b.clone()); let peer_center_b = PeerCenterInstance::new(peer_mgr_b.clone());
let peer_center_c = PeerCenterInstance::new(peer_mgr_c.clone()); let peer_center_c = PeerCenterInstance::new(peer_mgr_c.clone());
let peer_centers = vec![&peer_center_a, &peer_center_b, &peer_center_c]; let peer_centers = [&peer_center_a, &peer_center_b, &peer_center_c];
for pc in peer_centers.iter() { for pc in peer_centers.iter() {
pc.init().await; pc.init().await;
} }

View File

@@ -162,7 +162,7 @@ impl PeerCenterRpc for PeerCenterServer {
direct_peers: Default::default(), direct_peers: Default::default(),
}) })
.direct_peers .direct_peers
.insert(pair.dst, entry.info.clone()); .insert(pair.dst, entry.info);
} }
Ok(GetGlobalPeerMapResponse { Ok(GetGlobalPeerMapResponse {

View File

@@ -27,6 +27,12 @@ pub struct AclFilter {
acl_enabled: Arc<AtomicBool>, acl_enabled: Arc<AtomicBool>,
} }
impl Default for AclFilter {
fn default() -> Self {
Self::new()
}
}
impl AclFilter { impl AclFilter {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@@ -75,8 +81,8 @@ impl AclFilter {
let rules_stats = processor.get_rules_stats(); let rules_stats = processor.get_rules_stats();
AclStats { AclStats {
global: global_stats.into_iter().map(|(k, v)| (k, v)).collect(), global: global_stats.into_iter().collect(),
conn_track: conn_track.iter().map(|x| x.value().clone()).collect(), conn_track: conn_track.iter().map(|x| *x.value()).collect(),
rules: rules_stats, rules: rules_stats,
} }
} }

View File

@@ -15,21 +15,21 @@ pub struct AesGcmCipher {
#[derive(Clone)] #[derive(Clone)]
pub enum AesGcmEnum { pub enum AesGcmEnum {
AES128GCM(Aes128Gcm), AES128GCM(Box<Aes128Gcm>),
AES256GCM(Aes256Gcm), AES256GCM(Box<Aes256Gcm>),
} }
impl AesGcmCipher { impl AesGcmCipher {
pub fn new_128(key: [u8; 16]) -> Self { pub fn new_128(key: [u8; 16]) -> Self {
let key: &Key<Aes128Gcm> = &key.into(); let key: &Key<Aes128Gcm> = &key.into();
Self { Self {
cipher: AesGcmEnum::AES128GCM(Aes128Gcm::new(key)), cipher: AesGcmEnum::AES128GCM(Box::new(Aes128Gcm::new(key))),
} }
} }
pub fn new_256(key: [u8; 32]) -> Self { pub fn new_256(key: [u8; 32]) -> Self {
let key: &Key<Aes256Gcm> = &key.into(); let key: &Key<Aes256Gcm> = &key.into();
Self { Self {
cipher: AesGcmEnum::AES256GCM(Aes256Gcm::new(key)), cipher: AesGcmEnum::AES256GCM(Box::new(Aes256Gcm::new(key))),
} }
} }
} }
@@ -80,7 +80,7 @@ impl Encryptor for AesGcmCipher {
zc_packet zc_packet
.mut_inner() .mut_inner()
.truncate(old_len - AES_GCM_ENCRYPTION_RESERVED); .truncate(old_len - AES_GCM_ENCRYPTION_RESERVED);
return Ok(()); Ok(())
} }
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
@@ -104,7 +104,7 @@ impl Encryptor for AesGcmCipher {
} }
}; };
return match rs { match rs {
Ok(tag) => { Ok(tag) => {
tail.tag.copy_from_slice(tag.as_slice()); tail.tag.copy_from_slice(tag.as_slice());
@@ -114,7 +114,7 @@ impl Encryptor for AesGcmCipher {
Ok(()) Ok(())
} }
Err(_) => Err(Error::EncryptionFailed), Err(_) => Err(Error::EncryptionFailed),
}; }
} }
} }
@@ -137,10 +137,10 @@ mod tests {
packet.payload().len(), packet.payload().len(),
text.len() + AES_GCM_ENCRYPTION_RESERVED text.len() + AES_GCM_ENCRYPTION_RESERVED
); );
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), true); assert!(packet.peer_manager_header().unwrap().is_encrypted());
cipher.decrypt(&mut packet).unwrap(); cipher.decrypt(&mut packet).unwrap();
assert_eq!(packet.payload(), text); assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), false); assert!(!packet.peer_manager_header().unwrap().is_encrypted());
} }
} }

View File

@@ -43,7 +43,7 @@ impl Encryptor for NullCipher {
fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap(); let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() { if pm_header.is_encrypted() {
return Err(Error::DecryptionFailed); Err(Error::DecryptionFailed)
} else { } else {
Ok(()) Ok(())
} }

View File

@@ -212,12 +212,12 @@ mod tests {
// 加密 // 加密
cipher.encrypt(&mut packet).unwrap(); cipher.encrypt(&mut packet).unwrap();
assert!(packet.payload().len() > text.len() + OPENSSL_ENCRYPTION_RESERVED); assert!(packet.payload().len() > text.len() + OPENSSL_ENCRYPTION_RESERVED);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), true); assert!(packet.peer_manager_header().unwrap().is_encrypted());
// 解密 // 解密
cipher.decrypt(&mut packet).unwrap(); cipher.decrypt(&mut packet).unwrap();
assert_eq!(packet.payload(), text); assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), false); assert!(!packet.peer_manager_header().unwrap().is_encrypted());
} }
#[test] #[test]
@@ -231,11 +231,11 @@ mod tests {
// 加密 // 加密
cipher.encrypt(&mut packet).unwrap(); cipher.encrypt(&mut packet).unwrap();
assert!(packet.payload().len() > text.len()); assert!(packet.payload().len() > text.len());
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), true); assert!(packet.peer_manager_header().unwrap().is_encrypted());
// 解密 // 解密
cipher.decrypt(&mut packet).unwrap(); cipher.decrypt(&mut packet).unwrap();
assert_eq!(packet.payload(), text); assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), false); assert!(!packet.peer_manager_header().unwrap().is_encrypted());
} }
} }

View File

@@ -65,7 +65,7 @@ impl Encryptor for AesGcmCipher {
let text_and_tag_len = payload_len - AES_GCM_ENCRYPTION_RESERVED + 16; let text_and_tag_len = payload_len - AES_GCM_ENCRYPTION_RESERVED + 16;
let aes_tail = AesGcmTail::ref_from_suffix(zc_packet.payload()).unwrap(); let aes_tail = AesGcmTail::ref_from_suffix(zc_packet.payload()).unwrap();
let nonce = aead::Nonce::assume_unique_for_key(aes_tail.nonce.clone()); let nonce = aead::Nonce::assume_unique_for_key(aes_tail.nonce);
let rs = match &self.cipher { let rs = match &self.cipher {
AesGcmEnum::AesGCM128(cipher, _) => cipher.open_in_place( AesGcmEnum::AesGCM128(cipher, _) => cipher.open_in_place(
@@ -79,7 +79,7 @@ impl Encryptor for AesGcmCipher {
&mut zc_packet.mut_payload()[..text_and_tag_len], &mut zc_packet.mut_payload()[..text_and_tag_len],
), ),
}; };
if let Err(_) = rs { if rs.is_err() {
return Err(Error::DecryptionFailed); return Err(Error::DecryptionFailed);
} }
@@ -89,7 +89,7 @@ impl Encryptor for AesGcmCipher {
zc_packet zc_packet
.mut_inner() .mut_inner()
.truncate(old_len - AES_GCM_ENCRYPTION_RESERVED); .truncate(old_len - AES_GCM_ENCRYPTION_RESERVED);
return Ok(()); Ok(())
} }
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
@@ -101,7 +101,7 @@ impl Encryptor for AesGcmCipher {
let mut tail = AesGcmTail::default(); let mut tail = AesGcmTail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce); rand::thread_rng().fill_bytes(&mut tail.nonce);
let nonce = aead::Nonce::assume_unique_for_key(tail.nonce.clone()); let nonce = aead::Nonce::assume_unique_for_key(tail.nonce);
let rs = match &self.cipher { let rs = match &self.cipher {
AesGcmEnum::AesGCM128(cipher, _) => cipher.seal_in_place_separate_tag( AesGcmEnum::AesGCM128(cipher, _) => cipher.seal_in_place_separate_tag(
@@ -115,7 +115,7 @@ impl Encryptor for AesGcmCipher {
zc_packet.mut_payload(), zc_packet.mut_payload(),
), ),
}; };
return match rs { match rs {
Ok(tag) => { Ok(tag) => {
let tag = tag.as_ref(); let tag = tag.as_ref();
if tag.len() != 16 { if tag.len() != 16 {
@@ -129,7 +129,7 @@ impl Encryptor for AesGcmCipher {
Ok(()) Ok(())
} }
Err(_) => Err(Error::EncryptionFailed), Err(_) => Err(Error::EncryptionFailed),
}; }
} }
} }
@@ -152,10 +152,10 @@ mod tests {
packet.payload().len(), packet.payload().len(),
text.len() + AES_GCM_ENCRYPTION_RESERVED text.len() + AES_GCM_ENCRYPTION_RESERVED
); );
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), true); assert!(packet.peer_manager_header().unwrap().is_encrypted());
cipher.decrypt(&mut packet).unwrap(); cipher.decrypt(&mut packet).unwrap();
assert_eq!(packet.payload(), text); assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), false); assert!(!packet.peer_manager_header().unwrap().is_encrypted());
} }
} }

Some files were not shown because too many files have changed in this diff Show More