Use customized RPC implementation, remove Tarpc & Tonic (#348)

This patch removes Tarpc and the Tonic gRPC stack and implements a customized RPC framework, which is used by both the peer RPC and the CLI interface.

The web config server can also use this RPC framework.

Moreover, the public server logic is rewritten to use OSPF routing for public-server-based networking, which makes a public server mesh possible.
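
For reference, a condensed sketch of how the new framework is used, assembled from the hunks in this diff (direct connector, UDP hole punch, connector manager): implement the trait generated from the proto file, wrap it in the generated *RpcServer and register it on the peer RPC manager, then call it through a scoped client stub built from the generated *RpcClientFactory. The generated types come from the proto modules added by this patch; the helper functions and parameters such as peer_mgr, dst_peer_id, and network_name are illustrative glue, not part of the patch itself.

use std::sync::Arc;

use anyhow::Context;

use crate::{
    common::PeerId,
    peers::peer_manager::PeerManager,
    proto::{
        peer_rpc::{
            DirectConnectorRpc, DirectConnectorRpcClientFactory, DirectConnectorRpcServer,
            GetIpListRequest, GetIpListResponse,
        },
        rpc_types::{self, controller::BaseController},
    },
};

// 1) Implement the trait generated from peer_rpc.proto for a handler type.
struct MyDirectConnectorService;

#[async_trait::async_trait]
impl DirectConnectorRpc for MyDirectConnectorService {
    type Controller = BaseController;

    async fn get_ip_list(
        &self,
        _: BaseController,
        _: GetIpListRequest,
    ) -> rpc_types::error::Result<GetIpListResponse> {
        // A real handler collects local/public addresses; defaults suffice here.
        Ok(GetIpListResponse::default())
    }
}

// 2) Server side: wrap the handler in the generated *RpcServer and register it
//    on the peer RPC manager, keyed by the network name.
fn register_direct_connector(peer_mgr: &PeerManager, network_name: &str) {
    peer_mgr
        .get_peer_rpc_mgr()
        .rpc_server()
        .registry()
        .register(
            DirectConnectorRpcServer::new(MyDirectConnectorService),
            network_name,
        );
}

// 3) Client side: build a scoped stub from the generated *RpcClientFactory and
//    call the remote service over the peer RPC transport.
async fn fetch_ip_list(
    peer_mgr: Arc<PeerManager>,
    dst_peer_id: PeerId,
    network_name: String,
) -> anyhow::Result<GetIpListResponse> {
    let stub = peer_mgr
        .get_peer_rpc_mgr()
        .rpc_client()
        .scoped_client::<DirectConnectorRpcClientFactory<BaseController>>(
            peer_mgr.my_peer_id(),
            dst_peer_id,
            network_name,
        );
    stub.get_ip_list(BaseController {}, GetIpListRequest {})
        .await
        .with_context(|| format!("get ip list from peer {}", dst_peer_id))
}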
Sijie.Sun
2024-09-18 21:55:28 +08:00
committed by GitHub
parent 0467b0a3dc
commit 1b03223537
77 changed files with 3844 additions and 2856 deletions

Cargo.lock (generated)
View File

@@ -369,15 +369,6 @@ dependencies = [
"system-deps", "system-deps",
] ]
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
dependencies = [
"critical-section",
]
[[package]] [[package]]
name = "atomic-shim" name = "atomic-shim"
version = "0.2.0" version = "0.2.0"
@@ -427,53 +418,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "axum"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
dependencies = [
"async-trait",
"axum-core",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"itoa 1.0.11",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper 1.0.1",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper 0.1.2",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.73" version = "0.3.73"
@@ -960,12 +904,6 @@ dependencies = [
"error-code", "error-code",
] ]
[[package]]
name = "cobs"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15"
[[package]] [[package]]
name = "cocoa" name = "cocoa"
version = "0.25.0" version = "0.25.0"
@@ -1176,12 +1114,6 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "critical-section"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216"
[[package]] [[package]]
name = "crossbeam" name = "crossbeam"
version = "0.8.4" version = "0.8.4"
@@ -1638,14 +1570,16 @@ dependencies = [
"petgraph", "petgraph",
"pin-project-lite", "pin-project-lite",
"pnet", "pnet",
"postcard",
"prost", "prost",
"prost-build",
"prost-types",
"quinn", "quinn",
"rand 0.8.5", "rand 0.8.5",
"rcgen", "rcgen",
"regex", "regex",
"reqwest 0.11.27", "reqwest 0.11.27",
"ring 0.17.8", "ring 0.17.8",
"rpc_build",
"rstest", "rstest",
"rust-i18n", "rust-i18n",
"rustls", "rustls",
@@ -1657,7 +1591,6 @@ dependencies = [
"sys-locale", "sys-locale",
"tabled", "tabled",
"tachyonix", "tachyonix",
"tarpc",
"thiserror", "thiserror",
"time", "time",
"timedmap", "timedmap",
@@ -1668,7 +1601,6 @@ dependencies = [
"tokio-util", "tokio-util",
"tokio-websockets", "tokio-websockets",
"toml 0.8.19", "toml 0.8.19",
"tonic",
"tonic-build", "tonic-build",
"tracing", "tracing",
"tracing-appender", "tracing-appender",
@@ -1736,12 +1668,6 @@ version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7" checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7"
[[package]]
name = "embedded-io"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
[[package]] [[package]]
name = "encoding" name = "encoding"
version = "0.2.33" version = "0.2.33"
@@ -2532,25 +2458,6 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "h2"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http 1.1.0",
"indexmap 2.4.0",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]] [[package]]
name = "half" name = "half"
version = "2.4.1" version = "2.4.1"
@@ -2561,15 +2468,6 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "hash32"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
dependencies = [
"byteorder",
]
[[package]] [[package]]
name = "hash32" name = "hash32"
version = "0.3.1" version = "0.3.1"
@@ -2591,27 +2489,13 @@ version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "heapless"
version = "0.7.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
dependencies = [
"atomic-polyfill",
"hash32 0.2.1",
"rustc_version",
"serde",
"spin 0.9.8",
"stable_deref_trait",
]
[[package]] [[package]]
name = "heapless" name = "heapless"
version = "0.8.0" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad"
dependencies = [ dependencies = [
"hash32 0.3.1", "hash32",
"stable_deref_trait", "stable_deref_trait",
] ]
@@ -2754,12 +2638,6 @@ dependencies = [
"libm", "libm",
] ]
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "0.14.30" version = "0.14.30"
@@ -2770,7 +2648,7 @@ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
"futures-util", "futures-util",
"h2 0.3.26", "h2",
"http 0.2.12", "http 0.2.12",
"http-body 0.4.6", "http-body 0.4.6",
"httparse", "httparse",
@@ -2793,11 +2671,9 @@ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
"futures-util", "futures-util",
"h2 0.4.5",
"http 1.1.0", "http 1.1.0",
"http-body 1.0.1", "http-body 1.0.1",
"httparse", "httparse",
"httpdate",
"itoa 1.0.11", "itoa 1.0.11",
"pin-project-lite", "pin-project-lite",
"smallvec", "smallvec",
@@ -2805,19 +2681,6 @@ dependencies = [
"want", "want",
] ]
[[package]]
name = "hyper-timeout"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793"
dependencies = [
"hyper 1.4.1",
"hyper-util",
"pin-project-lite",
"tokio",
"tower-service",
]
[[package]] [[package]]
name = "hyper-tls" name = "hyper-tls"
version = "0.5.0" version = "0.5.0"
@@ -3380,12 +3243,6 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5"
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]] [[package]]
name = "md5" name = "md5"
version = "0.7.0" version = "0.7.0"
@@ -3954,25 +3811,6 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "opentelemetry"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6105e89802af13fdf48c49d7646d3b533a70e536d818aae7e78ba0433d01acb8"
dependencies = [
"async-trait",
"crossbeam-channel",
"futures-channel",
"futures-executor",
"futures-util",
"js-sys",
"lazy_static",
"percent-encoding",
"pin-project",
"rand 0.8.5",
"thiserror",
]
[[package]] [[package]]
name = "option-ext" name = "option-ext"
version = "0.2.0" version = "0.2.0"
@@ -4482,18 +4320,6 @@ dependencies = [
"universal-hash", "universal-hash",
] ]
[[package]]
name = "postcard"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a55c51ee6c0db07e68448e336cf8ea4131a620edefebf9893e759b2d793420f8"
dependencies = [
"cobs",
"embedded-io",
"heapless 0.7.17",
"serde",
]
[[package]] [[package]]
name = "powerfmt" name = "powerfmt"
version = "0.2.0" version = "0.2.0"
@@ -4606,9 +4432,9 @@ dependencies = [
[[package]] [[package]]
name = "prost" name = "prost"
version = "0.13.1" version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13db3d3fde688c61e2446b4d843bc27a7e8af269a69440c0308021dc92333cc" checksum = "3b2ecbe40f08db5c006b5764a2645f7f3f141ce756412ac9e1dd6087e6d32995"
dependencies = [ dependencies = [
"bytes", "bytes",
"prost-derive", "prost-derive",
@@ -4616,9 +4442,9 @@ dependencies = [
[[package]] [[package]]
name = "prost-build" name = "prost-build"
version = "0.13.1" version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bb182580f71dd070f88d01ce3de9f4da5021db7115d2e1c3605a754153b77c1" checksum = "f8650aabb6c35b860610e9cff5dc1af886c9e25073b7b1712a68972af4281302"
dependencies = [ dependencies = [
"bytes", "bytes",
"heck 0.5.0", "heck 0.5.0",
@@ -4637,9 +4463,9 @@ dependencies = [
[[package]] [[package]]
name = "prost-derive" name = "prost-derive"
version = "0.13.1" version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" checksum = "acf0c195eebb4af52c752bec4f52f645da98b6e92077a04110c7f349477ae5ac"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"itertools 0.13.0", "itertools 0.13.0",
@@ -4650,9 +4476,9 @@ dependencies = [
[[package]] [[package]]
name = "prost-types" name = "prost-types"
version = "0.13.1" version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cee5168b05f49d4b0ca581206eb14a7b22fafd963efe729ac48eb03266e25cc2" checksum = "60caa6738c7369b940c3d49246a8d1749323674c65cb13010134f5c9bad5b519"
dependencies = [ dependencies = [
"prost", "prost",
] ]
@@ -4939,7 +4765,7 @@ dependencies = [
"encoding_rs", "encoding_rs",
"futures-core", "futures-core",
"futures-util", "futures-util",
"h2 0.3.26", "h2",
"http 0.2.12", "http 0.2.12",
"http-body 0.4.6", "http-body 0.4.6",
"hyper 0.14.30", "hyper 0.14.30",
@@ -5035,6 +4861,14 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "rpc_build"
version = "0.1.0"
dependencies = [
"heck 0.5.0",
"prost-build",
]
[[package]] [[package]]
name = "rstest" name = "rstest"
version = "0.18.2" version = "0.18.2"
@@ -5667,7 +5501,7 @@ dependencies = [
"byteorder", "byteorder",
"cfg-if", "cfg-if",
"defmt", "defmt",
"heapless 0.8.0", "heapless",
"managed", "managed",
] ]
@@ -6000,40 +5834,6 @@ version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tarpc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f07cb5fb67b0a90ea954b5ffd2fac9944ffef5937c801b987d3f8913f0c37348"
dependencies = [
"anyhow",
"fnv",
"futures",
"humantime",
"opentelemetry",
"pin-project",
"rand 0.8.5",
"serde",
"static_assertions",
"tarpc-plugins",
"thiserror",
"tokio",
"tokio-util",
"tracing",
"tracing-opentelemetry",
]
[[package]]
name = "tarpc-plugins"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee42b4e559f17bce0385ebf511a7beb67d5cc33c12c96b7f4e9789919d9c10f"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "tauri" name = "tauri"
version = "2.0.0-rc.2" version = "2.0.0-rc.2"
@@ -6590,7 +6390,6 @@ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
"pin-project-lite", "pin-project-lite",
"slab",
"tokio", "tokio",
] ]
@@ -6696,36 +6495,6 @@ dependencies = [
"winnow 0.6.18", "winnow 0.6.18",
] ]
[[package]]
name = "tonic"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38659f4a91aba8598d27821589f5db7dddd94601e7a01b1e485a50e5484c7401"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64 0.22.1",
"bytes",
"h2 0.4.5",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"prost",
"socket2",
"tokio",
"tokio-stream",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "tonic-build" name = "tonic-build"
version = "0.12.1" version = "0.12.1"
@@ -6747,16 +6516,11 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-util", "futures-util",
"indexmap 1.9.3",
"pin-project", "pin-project",
"pin-project-lite", "pin-project-lite",
"rand 0.8.5",
"slab",
"tokio", "tokio",
"tokio-util",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
@@ -6827,19 +6591,6 @@ dependencies = [
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-opentelemetry"
version = "0.17.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbbe89715c1dbbb790059e2565353978564924ee85017b5fff365c872ff6721f"
dependencies = [
"once_cell",
"opentelemetry",
"tracing",
"tracing-core",
"tracing-subscriber",
]
[[package]] [[package]]
name = "tracing-subscriber" name = "tracing-subscriber"
version = "0.3.18" version = "0.3.18"

View File

@@ -10,4 +10,3 @@ panic = "unwind"
panic = "abort" panic = "abort"
lto = true lto = true
codegen-units = 1 codegen-units = 1
strip = true

View File

@@ -49,7 +49,7 @@ futures = { version = "0.3", features = ["bilock", "unstable"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1" tokio-stream = "0.1"
tokio-util = { version = "0.7.9", features = ["codec", "net"] } tokio-util = { version = "0.7.9", features = ["codec", "net", "io"] }
async-stream = "0.3.5" async-stream = "0.3.5"
async-trait = "0.1.74" async-trait = "0.1.74"
@@ -101,14 +101,10 @@ uuid = { version = "1.5.0", features = [
crossbeam-queue = "0.3" crossbeam-queue = "0.3"
once_cell = "1.18.0" once_cell = "1.18.0"
# for packet
postcard = { "version" = "1.0.8", features = ["alloc"] }
# for rpc # for rpc
tonic = "0.12"
prost = "0.13" prost = "0.13"
prost-types = "0.13"
anyhow = "1.0" anyhow = "1.0"
tarpc = { version = "0.32", features = ["tokio1", "serde1"] }
url = { version = "2.5", features = ["serde"] } url = { version = "2.5", features = ["serde"] }
percent-encoding = "2.3.1" percent-encoding = "2.3.1"
@@ -194,6 +190,8 @@ winreg = "0.52"
tonic-build = "0.12" tonic-build = "0.12"
globwalk = "0.8.1" globwalk = "0.8.1"
regex = "1" regex = "1"
prost-build = "0.13.2"
rpc_build = { path = "src/proto/rpc_build" }
[target.'cfg(windows)'.build-dependencies] [target.'cfg(windows)'.build-dependencies]
reqwest = { version = "0.11", features = ["blocking"] } reqwest = { version = "0.11", features = ["blocking"] }

View File

@@ -129,14 +129,31 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
WindowsBuild::check_for_win(); WindowsBuild::check_for_win();
tonic_build::configure() prost_build::Config::new()
.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".common", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute("cli.DirectConnectedPeerInfo", "#[derive(Hash)]") .type_attribute(".error", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute("cli.PeerInfoForGlobalMap", "#[derive(Hash)]") .type_attribute(".cli", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute(
"peer_rpc.GetIpListResponse",
"#[derive(serde::Serialize, serde::Deserialize)]",
)
.type_attribute("peer_rpc.DirectConnectedPeerInfo", "#[derive(Hash)]")
.type_attribute("peer_rpc.PeerInfoForGlobalMap", "#[derive(Hash)]")
.type_attribute("common.RpcDescriptor", "#[derive(Hash, Eq)]")
.service_generator(Box::new(rpc_build::ServiceGenerator::new()))
.btree_map(&["."]) .btree_map(&["."])
.compile(&["proto/cli.proto"], &["proto/"]) .compile_protos(
&[
"src/proto/peer_rpc.proto",
"src/proto/common.proto",
"src/proto/error.proto",
"src/proto/tests.proto",
"src/proto/cli.proto",
],
&["src/proto/"],
)
.unwrap(); .unwrap();
// tonic_build::compile_protos("proto/cli.proto")?;
check_locale(); check_locale();
Ok(()) Ok(())
} }
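
The build script above replaces tonic_build with prost_build driven by the in-tree rpc_build::ServiceGenerator. A hedged sketch of the artifacts that generator presumably emits for each proto service, inferred solely from the impl blocks elsewhere in this diff (hypothetical; the actual generated code is not shown in this excerpt):

use crate::proto::{
    peer_rpc::{GetIpListRequest, GetIpListResponse},
    rpc_types,
};

// Inferred shape of what rpc_build::ServiceGenerator emits for one proto
// service (an educated guess from the impl blocks in this diff, not the
// actual generator output):
#[async_trait::async_trait]
pub trait DirectConnectorRpc {
    // Every generated service is parameterized over a controller type; the
    // code in this patch plugs in BaseController everywhere.
    type Controller;

    async fn get_ip_list(
        &self,
        ctrl: Self::Controller,
        req: GetIpListRequest,
    ) -> rpc_types::error::Result<GetIpListResponse>;
}

// Presumably also generated alongside the trait:
//   DirectConnectorRpcServer<T>         -- wraps a handler for the server registry
//   DirectConnectorRpcClientFactory<C>  -- builds scoped client stubs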

View File

@@ -31,8 +31,6 @@ pub enum Error {
// RpcListenError(String), // RpcListenError(String),
#[error("Rpc connect error: {0}")] #[error("Rpc connect error: {0}")]
RpcConnectError(String), RpcConnectError(String),
#[error("Rpc error: {0}")]
RpcClientError(#[from] tarpc::client::RpcError),
#[error("Timeout error: {0}")] #[error("Timeout error: {0}")]
Timeout(#[from] tokio::time::error::Elapsed), Timeout(#[from] tokio::time::error::Elapsed),
#[error("url in blacklist")] #[error("url in blacklist")]

View File

@@ -4,7 +4,7 @@ use std::{
sync::{Arc, Mutex}, sync::{Arc, Mutex},
}; };
use crate::rpc::PeerConnInfo; use crate::proto::cli::PeerConnInfo;
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
use super::{ use super::{
@@ -179,6 +179,10 @@ impl GlobalCtx {
self.config.get_network_identity() self.config.get_network_identity()
} }
pub fn get_network_name(&self) -> String {
self.get_network_identity().network_name
}
pub fn get_ip_collector(&self) -> Arc<IPCollector> { pub fn get_ip_collector(&self) -> Arc<IPCollector> {
self.ip_collector.clone() self.ip_collector.clone()
} }
@@ -191,7 +195,6 @@ impl GlobalCtx {
self.stun_info_collection.as_ref() self.stun_info_collection.as_ref()
} }
#[cfg(test)]
pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) { pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) {
// force replace the stun_info_collection without mut and drop the old one // force replace the stun_info_collection without mut and drop the old one
let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>; let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>;

View File

@@ -1,12 +1,13 @@
use std::{net::IpAddr, ops::Deref, sync::Arc}; use std::{net::IpAddr, ops::Deref, sync::Arc};
use crate::rpc::peer::GetIpListResponse;
use pnet::datalink::NetworkInterface; use pnet::datalink::NetworkInterface;
use tokio::{ use tokio::{
sync::{Mutex, RwLock}, sync::{Mutex, RwLock},
task::JoinSet, task::JoinSet,
}; };
use crate::proto::peer_rpc::GetIpListResponse;
use super::{netns::NetNS, stun::StunInfoCollectorTrait}; use super::{netns::NetNS, stun::StunInfoCollectorTrait};
pub const CACHED_IP_LIST_TIMEOUT_SEC: u64 = 60; pub const CACHED_IP_LIST_TIMEOUT_SEC: u64 = 60;
@@ -163,7 +164,7 @@ pub struct IPCollector {
impl IPCollector { impl IPCollector {
pub fn new<T: StunInfoCollectorTrait + 'static>(net_ns: NetNS, stun_info_collector: T) -> Self { pub fn new<T: StunInfoCollectorTrait + 'static>(net_ns: NetNS, stun_info_collector: T) -> Self {
Self { Self {
cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::new())), cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::default())),
collect_ip_task: Mutex::new(JoinSet::new()), collect_ip_task: Mutex::new(JoinSet::new()),
net_ns, net_ns,
stun_info_collector: Arc::new(Box::new(stun_info_collector)), stun_info_collector: Arc::new(Box::new(stun_info_collector)),
@@ -195,14 +196,18 @@ impl IPCollector {
let Ok(ip_addr) = ip.parse::<IpAddr>() else { let Ok(ip_addr) = ip.parse::<IpAddr>() else {
continue; continue;
}; };
if ip_addr.is_ipv4() {
cached_ip_list.write().await.public_ipv4 = ip.clone(); match ip_addr {
} else { IpAddr::V4(v) => {
cached_ip_list.write().await.public_ipv6 = ip.clone(); cached_ip_list.write().await.public_ipv4 = Some(v.into())
}
IpAddr::V6(v) => {
cached_ip_list.write().await.public_ipv6 = Some(v.into())
}
} }
} }
let sleep_sec = if !cached_ip_list.read().await.public_ipv4.is_empty() { let sleep_sec = if !cached_ip_list.read().await.public_ipv4.is_none() {
CACHED_IP_LIST_TIMEOUT_SEC CACHED_IP_LIST_TIMEOUT_SEC
} else { } else {
3 3
@@ -236,7 +241,7 @@ impl IPCollector {
#[tracing::instrument(skip(net_ns))] #[tracing::instrument(skip(net_ns))]
async fn do_collect_local_ip_addrs(net_ns: NetNS) -> GetIpListResponse { async fn do_collect_local_ip_addrs(net_ns: NetNS) -> GetIpListResponse {
let mut ret = crate::rpc::peer::GetIpListResponse::new(); let mut ret = GetIpListResponse::default();
let ifaces = Self::collect_interfaces(net_ns.clone()).await; let ifaces = Self::collect_interfaces(net_ns.clone()).await;
let _g = net_ns.guard(); let _g = net_ns.guard();
@@ -246,25 +251,28 @@ impl IPCollector {
if ip.is_loopback() || ip.is_multicast() { if ip.is_loopback() || ip.is_multicast() {
continue; continue;
} }
if ip.is_ipv4() { match ip {
ret.interface_ipv4s.push(ip.to_string()); std::net::IpAddr::V4(v4) => {
} else if ip.is_ipv6() { ret.interface_ipv4s.push(v4.into());
ret.interface_ipv6s.push(ip.to_string()); }
std::net::IpAddr::V6(v6) => {
ret.interface_ipv6s.push(v6.into());
}
} }
} }
} }
if let Ok(v4_addr) = local_ipv4().await { if let Ok(v4_addr) = local_ipv4().await {
tracing::trace!("got local ipv4: {}", v4_addr); tracing::trace!("got local ipv4: {}", v4_addr);
if !ret.interface_ipv4s.contains(&v4_addr.to_string()) { if !ret.interface_ipv4s.contains(&v4_addr.into()) {
ret.interface_ipv4s.push(v4_addr.to_string()); ret.interface_ipv4s.push(v4_addr.into());
} }
} }
if let Ok(v6_addr) = local_ipv6().await { if let Ok(v6_addr) = local_ipv6().await {
tracing::trace!("got local ipv6: {}", v6_addr); tracing::trace!("got local ipv6: {}", v6_addr);
if !ret.interface_ipv6s.contains(&v6_addr.to_string()) { if !ret.interface_ipv6s.contains(&v6_addr.into()) {
ret.interface_ipv6s.push(v6_addr.to_string()); ret.interface_ipv6s.push(v6_addr.into());
} }
} }

View File

@@ -1,9 +1,10 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use crate::rpc::{NatType, StunInfo}; use crate::proto::common::{NatType, StunInfo};
use anyhow::Context; use anyhow::Context;
use chrono::Local; use chrono::Local;
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
@@ -161,7 +162,7 @@ impl StunClient {
continue; continue;
}; };
tracing::debug!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg); tracing::trace!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg);
if msg.class() != MessageClass::SuccessResponse if msg.class() != MessageClass::SuccessResponse
|| msg.method() != BINDING || msg.method() != BINDING
@@ -216,7 +217,7 @@ impl StunClient {
changed_addr changed_addr
} }
#[tracing::instrument(ret, err, level = Level::DEBUG)] #[tracing::instrument(ret, level = Level::TRACE)]
pub async fn bind_request( pub async fn bind_request(
self, self,
change_ip: bool, change_ip: bool,
@@ -243,7 +244,7 @@ impl StunClient {
.encode_into_bytes(message.clone()) .encode_into_bytes(message.clone())
.with_context(|| "encode stun message")?; .with_context(|| "encode stun message")?;
tids.push(tid as u128); tids.push(tid as u128);
tracing::debug!(?message, ?msg, tid, "send stun request"); tracing::trace!(?message, ?msg, tid, "send stun request");
self.socket self.socket
.send_to(msg.as_slice().into(), &stun_host) .send_to(msg.as_slice().into(), &stun_host)
.await?; .await?;
@@ -276,7 +277,7 @@ impl StunClient {
latency_us: now.elapsed().as_micros() as u32, latency_us: now.elapsed().as_micros() as u32,
}; };
tracing::debug!( tracing::trace!(
?stun_host, ?stun_host,
?recv_addr, ?recv_addr,
?changed_socket_addr, ?changed_socket_addr,
@@ -303,14 +304,14 @@ impl StunClientBuilder {
task_set.spawn( task_set.spawn(
async move { async move {
let mut buf = [0; 1620]; let mut buf = [0; 1620];
tracing::info!("start stun packet listener"); tracing::trace!("start stun packet listener");
loop { loop {
let Ok((len, addr)) = udp_clone.recv_from(&mut buf).await else { let Ok((len, addr)) = udp_clone.recv_from(&mut buf).await else {
tracing::error!("udp recv_from error"); tracing::error!("udp recv_from error");
break; break;
}; };
let data = buf[..len].to_vec(); let data = buf[..len].to_vec();
tracing::debug!(?addr, ?data, "recv udp stun packet"); tracing::trace!(?addr, ?data, "recv udp stun packet");
let _ = stun_packet_sender_clone.send(StunPacket { data, addr }); let _ = stun_packet_sender_clone.send(StunPacket { data, addr });
} }
} }
@@ -552,12 +553,15 @@ pub struct StunInfoCollector {
udp_nat_test_result: Arc<RwLock<Option<UdpNatTypeDetectResult>>>, udp_nat_test_result: Arc<RwLock<Option<UdpNatTypeDetectResult>>>,
nat_test_result_time: Arc<AtomicCell<chrono::DateTime<Local>>>, nat_test_result_time: Arc<AtomicCell<chrono::DateTime<Local>>>,
redetect_notify: Arc<tokio::sync::Notify>, redetect_notify: Arc<tokio::sync::Notify>,
tasks: JoinSet<()>, tasks: std::sync::Mutex<JoinSet<()>>,
started: AtomicBool,
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl StunInfoCollectorTrait for StunInfoCollector { impl StunInfoCollectorTrait for StunInfoCollector {
fn get_stun_info(&self) -> StunInfo { fn get_stun_info(&self) -> StunInfo {
self.start_stun_routine();
let Some(result) = self.udp_nat_test_result.read().unwrap().clone() else { let Some(result) = self.udp_nat_test_result.read().unwrap().clone() else {
return Default::default(); return Default::default();
}; };
@@ -572,6 +576,8 @@ impl StunInfoCollectorTrait for StunInfoCollector {
} }
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> { async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
self.start_stun_routine();
let stun_servers = self let stun_servers = self
.udp_nat_test_result .udp_nat_test_result
.read() .read()
@@ -605,17 +611,14 @@ impl StunInfoCollectorTrait for StunInfoCollector {
impl StunInfoCollector { impl StunInfoCollector {
pub fn new(stun_servers: Vec<String>) -> Self { pub fn new(stun_servers: Vec<String>) -> Self {
let mut ret = Self { Self {
stun_servers: Arc::new(RwLock::new(stun_servers)), stun_servers: Arc::new(RwLock::new(stun_servers)),
udp_nat_test_result: Arc::new(RwLock::new(None)), udp_nat_test_result: Arc::new(RwLock::new(None)),
nat_test_result_time: Arc::new(AtomicCell::new(Local::now())), nat_test_result_time: Arc::new(AtomicCell::new(Local::now())),
redetect_notify: Arc::new(tokio::sync::Notify::new()), redetect_notify: Arc::new(tokio::sync::Notify::new()),
tasks: JoinSet::new(), tasks: std::sync::Mutex::new(JoinSet::new()),
}; started: AtomicBool::new(false),
}
ret.start_stun_routine();
ret
} }
pub fn new_with_default_servers() -> Self { pub fn new_with_default_servers() -> Self {
@@ -648,12 +651,18 @@ impl StunInfoCollector {
.collect() .collect()
} }
fn start_stun_routine(&mut self) { fn start_stun_routine(&self) {
if self.started.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
self.started
.store(true, std::sync::atomic::Ordering::Relaxed);
let stun_servers = self.stun_servers.clone(); let stun_servers = self.stun_servers.clone();
let udp_nat_test_result = self.udp_nat_test_result.clone(); let udp_nat_test_result = self.udp_nat_test_result.clone();
let udp_test_time = self.nat_test_result_time.clone(); let udp_test_time = self.nat_test_result_time.clone();
let redetect_notify = self.redetect_notify.clone(); let redetect_notify = self.redetect_notify.clone();
self.tasks.spawn(async move { self.tasks.lock().unwrap().spawn(async move {
loop { loop {
let servers = stun_servers.read().unwrap().clone(); let servers = stun_servers.read().unwrap().clone();
// use first three and random choose one from the rest // use first three and random choose one from the rest
@@ -712,6 +721,31 @@ impl StunInfoCollector {
} }
} }
pub struct MockStunInfoCollector {
pub udp_nat_type: NatType,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for MockStunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
StunInfo {
udp_nat_type: self.udp_nat_type as i32,
tcp_nat_type: NatType::Unknown as i32,
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
min_port: 100,
max_port: 200,
..Default::default()
}
}
async fn get_udp_port_mapping(&self, mut port: u16) -> Result<std::net::SocketAddr, Error> {
if port == 0 {
port = 40144;
}
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@@ -5,9 +5,17 @@ use std::{net::SocketAddr, sync::Arc};
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager}, peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager},
proto::{
peer_rpc::{
DirectConnectorRpc, DirectConnectorRpcClientFactory, DirectConnectorRpcServer,
GetIpListRequest, GetIpListResponse,
},
rpc_types::{self, controller::BaseController},
},
}; };
use crate::rpc::{peer::GetIpListResponse, PeerConnInfo}; use crate::proto::cli::PeerConnInfo;
use anyhow::Context;
use tokio::{task::JoinSet, time::timeout}; use tokio::{task::JoinSet, time::timeout};
use tracing::Instrument; use tracing::Instrument;
use url::Host; use url::Host;
@@ -17,11 +25,6 @@ use super::create_connector_by_url;
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1; pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300; pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
#[tarpc::service]
pub trait DirectConnectorRpc {
async fn get_ip_list() -> GetIpListResponse;
}
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait PeerManagerForDirectConnector { pub trait PeerManagerForDirectConnector {
async fn list_peers(&self) -> Vec<PeerId>; async fn list_peers(&self) -> Vec<PeerId>;
@@ -57,12 +60,23 @@ struct DirectConnectorManagerRpcServer {
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
} }
#[tarpc::server] #[async_trait::async_trait]
impl DirectConnectorRpc for DirectConnectorManagerRpcServer { impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse { type Controller = BaseController;
async fn get_ip_list(
&self,
_: BaseController,
_: GetIpListRequest,
) -> rpc_types::error::Result<GetIpListResponse> {
let mut ret = self.global_ctx.get_ip_collector().collect_ip_addrs().await; let mut ret = self.global_ctx.get_ip_collector().collect_ip_addrs().await;
ret.listeners = self.global_ctx.get_running_listeners(); ret.listeners = self
ret .global_ctx
.get_running_listeners()
.into_iter()
.map(Into::into)
.collect();
Ok(ret)
} }
} }
@@ -130,9 +144,16 @@ impl DirectConnectorManager {
} }
pub fn run_as_server(&mut self) { pub fn run_as_server(&mut self) {
self.data.peer_manager.get_peer_rpc_mgr().run_service( self.data
DIRECT_CONNECTOR_SERVICE_ID, .peer_manager
DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(), .get_peer_rpc_mgr()
.rpc_server()
.registry()
.register(
DirectConnectorRpcServer::new(DirectConnectorManagerRpcServer::new(
self.global_ctx.clone(),
)),
&self.data.global_ctx.get_network_name(),
); );
} }
@@ -238,7 +259,8 @@ impl DirectConnectorManager {
let enable_ipv6 = data.global_ctx.get_flags().enable_ipv6; let enable_ipv6 = data.global_ctx.get_flags().enable_ipv6;
let available_listeners = ip_list let available_listeners = ip_list
.listeners .listeners
.iter() .into_iter()
.map(Into::<url::Url>::into)
.filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None }) .filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None })
.filter(|l| l.port().is_some() && l.host().is_some()) .filter(|l| l.port().is_some() && l.host().is_some())
.filter(|l| { .filter(|l| {
@@ -268,28 +290,7 @@ impl DirectConnectorManager {
Some(SocketAddr::V4(_)) => { Some(SocketAddr::V4(_)) => {
ip_list.interface_ipv4s.iter().for_each(|ip| { ip_list.interface_ipv4s.iter().for_each(|ip| {
let mut addr = (*listener).clone(); let mut addr = (*listener).clone();
if addr.set_host(Some(ip.as_str())).is_ok() { if addr.set_host(Some(ip.to_string().as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
}
});
let mut addr = (*listener).clone();
if addr.set_host(Some(ip_list.public_ipv4.as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
}
}
Some(SocketAddr::V6(_)) => {
ip_list.interface_ipv6s.iter().for_each(|ip| {
let mut addr = (*listener).clone();
if addr.set_host(Some(format!("[{}]", ip).as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip( tasks.spawn(Self::try_connect_to_ip(
data.clone(), data.clone(),
dst_peer_id.clone(), dst_peer_id.clone(),
@@ -298,9 +299,10 @@ impl DirectConnectorManager {
} }
}); });
if let Some(public_ipv4) = ip_list.public_ipv4 {
let mut addr = (*listener).clone(); let mut addr = (*listener).clone();
if addr if addr
.set_host(Some(format!("[{}]", ip_list.public_ipv6).as_str())) .set_host(Some(public_ipv4.to_string().as_str()))
.is_ok() .is_ok()
{ {
tasks.spawn(Self::try_connect_to_ip( tasks.spawn(Self::try_connect_to_ip(
@@ -310,6 +312,36 @@ impl DirectConnectorManager {
)); ));
} }
} }
}
Some(SocketAddr::V6(_)) => {
ip_list.interface_ipv6s.iter().for_each(|ip| {
let mut addr = (*listener).clone();
if addr
.set_host(Some(format!("[{}]", ip.to_string()).as_str()))
.is_ok()
{
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
}
});
if let Some(public_ipv6) = ip_list.public_ipv6 {
let mut addr = (*listener).clone();
if addr
.set_host(Some(format!("[{}]", public_ipv6.to_string()).as_str()))
.is_ok()
{
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
}
}
}
p => { p => {
tracing::error!(?p, ?listener, "failed to parse ip version from listener"); tracing::error!(?p, ?listener, "failed to parse ip version from listener");
} }
@@ -351,16 +383,21 @@ impl DirectConnectorManager {
tracing::trace!("try direct connect to peer: {}", dst_peer_id); tracing::trace!("try direct connect to peer: {}", dst_peer_id);
let ip_list = peer_manager let rpc_stub = peer_manager
.get_peer_rpc_mgr() .get_peer_rpc_mgr()
.do_client_rpc_scoped(1, dst_peer_id, |c| async { .rpc_client()
let client = .scoped_client::<DirectConnectorRpcClientFactory<BaseController>>(
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn(); peer_manager.my_peer_id(),
let ip_list = client.get_ip_list(tarpc::context::current()).await; dst_peer_id,
data.global_ctx.get_network_name(),
);
let ip_list = rpc_stub
.get_ip_list(BaseController {}, GetIpListRequest {})
.await
.with_context(|| format!("get ip list from peer {}", dst_peer_id))?;
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list"); tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
ip_list
})
.await?;
Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await
} }
@@ -380,7 +417,7 @@ mod tests {
connect_peer_manager, create_mock_peer_manager, wait_route_appear, connect_peer_manager, create_mock_peer_manager, wait_route_appear,
wait_route_appear_with_cost, wait_route_appear_with_cost,
}, },
rpc::peer::GetIpListResponse, proto::peer_rpc::GetIpListResponse,
}; };
#[rstest::rstest] #[rstest::rstest]
@@ -436,12 +473,14 @@ mod tests {
p_a.get_global_ctx(), p_a.get_global_ctx(),
p_a.clone(), p_a.clone(),
)); ));
let mut ip_list = GetIpListResponse::new(); let mut ip_list = GetIpListResponse::default();
ip_list ip_list
.listeners .listeners
.push("tcp://127.0.0.1:10222".parse().unwrap()); .push("tcp://127.0.0.1:10222".parse().unwrap());
ip_list.interface_ipv4s.push("127.0.0.1".to_string()); ip_list
.interface_ipv4s
.push("127.0.0.1".parse::<std::net::Ipv4Addr>().unwrap().into());
DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone()) DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone())
.await .await

View File

@@ -11,7 +11,12 @@ use tokio::{
use crate::{ use crate::{
common::PeerId, common::PeerId,
peers::peer_conn::PeerConnId, peers::peer_conn::PeerConnId,
rpc as easytier_rpc, proto::{
cli::{
ConnectorManageAction, ListConnectorResponse, ManageConnectorResponse, PeerConnInfo,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{IpVersion, TunnelConnector}, tunnel::{IpVersion, TunnelConnector},
}; };
@@ -23,9 +28,9 @@ use crate::{
}, },
connector::set_bind_addr_for_peer_connector, connector::set_bind_addr_for_peer_connector,
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
rpc::{ proto::cli::{
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus, Connector, ConnectorManageRpc, ConnectorStatus, ListConnectorRequest,
ListConnectorRequest, ManageConnectorRequest, ManageConnectorRequest,
}, },
use_global_var, use_global_var,
}; };
@@ -105,12 +110,18 @@ impl ManualConnectorManager {
Ok(()) Ok(())
} }
pub async fn remove_connector(&self, url: &str) -> Result<(), Error> { pub async fn remove_connector(&self, url: url::Url) -> Result<(), Error> {
tracing::info!("remove_connector: {}", url); tracing::info!("remove_connector: {}", url);
if !self.list_connectors().await.iter().any(|x| x.url == url) { let url = url.into();
if !self
.list_connectors()
.await
.iter()
.any(|x| x.url.as_ref() == Some(&url))
{
return Err(Error::NotFound); return Err(Error::NotFound);
} }
self.data.removed_conn_urls.insert(url.into()); self.data.removed_conn_urls.insert(url.to_string());
Ok(()) Ok(())
} }
@@ -137,7 +148,7 @@ impl ManualConnectorManager {
ret.insert( ret.insert(
0, 0,
Connector { Connector {
url: conn_url, url: Some(conn_url.parse().unwrap()),
status: status.into(), status: status.into(),
}, },
); );
@@ -154,7 +165,7 @@ impl ManualConnectorManager {
ret.insert( ret.insert(
0, 0,
Connector { Connector {
url: conn_url, url: Some(conn_url.parse().unwrap()),
status: ConnectorStatus::Connecting.into(), status: ConnectorStatus::Connecting.into(),
}, },
); );
@@ -213,14 +224,14 @@ impl ManualConnectorManager {
} }
async fn handle_event(event: &GlobalCtxEvent, data: &ConnectorManagerData) { async fn handle_event(event: &GlobalCtxEvent, data: &ConnectorManagerData) {
let need_add_alive = |conn_info: &easytier_rpc::PeerConnInfo| conn_info.is_client; let need_add_alive = |conn_info: &PeerConnInfo| conn_info.is_client;
match event { match event {
GlobalCtxEvent::PeerConnAdded(conn_info) => { GlobalCtxEvent::PeerConnAdded(conn_info) => {
if !need_add_alive(conn_info) { if !need_add_alive(conn_info) {
return; return;
} }
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone(); let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
data.alive_conn_urls.insert(addr); data.alive_conn_urls.insert(addr.unwrap().to_string());
tracing::warn!("peer conn added: {:?}", conn_info); tracing::warn!("peer conn added: {:?}", conn_info);
} }
@@ -229,7 +240,7 @@ impl ManualConnectorManager {
return; return;
} }
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone(); let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
data.alive_conn_urls.remove(&addr); data.alive_conn_urls.remove(&addr.unwrap().to_string());
tracing::warn!("peer conn removed: {:?}", conn_info); tracing::warn!("peer conn removed: {:?}", conn_info);
} }
@@ -303,7 +314,7 @@ impl ManualConnectorManager {
tracing::info!("reconnect get tunnel succ: {:?}", tunnel); tracing::info!("reconnect get tunnel succ: {:?}", tunnel);
assert_eq!( assert_eq!(
dead_url, dead_url,
tunnel.info().unwrap().remote_addr, tunnel.info().unwrap().remote_addr.unwrap().to_string(),
"info: {:?}", "info: {:?}",
tunnel.info() tunnel.info()
); );
@@ -385,45 +396,43 @@ impl ManualConnectorManager {
} }
} }
#[derive(Clone)]
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>); pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
#[tonic::async_trait] #[async_trait::async_trait]
impl ConnectorManageRpc for ConnectorManagerRpcService { impl ConnectorManageRpc for ConnectorManagerRpcService {
type Controller = BaseController;
async fn list_connector( async fn list_connector(
&self, &self,
_request: tonic::Request<ListConnectorRequest>, _: BaseController,
) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> { _request: ListConnectorRequest,
let mut ret = easytier_rpc::ListConnectorResponse::default(); ) -> Result<ListConnectorResponse, rpc_types::error::Error> {
let mut ret = ListConnectorResponse::default();
let connectors = self.0.list_connectors().await; let connectors = self.0.list_connectors().await;
ret.connectors = connectors; ret.connectors = connectors;
Ok(tonic::Response::new(ret)) Ok(ret)
} }
async fn manage_connector( async fn manage_connector(
&self, &self,
request: tonic::Request<ManageConnectorRequest>, _: BaseController,
) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> { req: ManageConnectorRequest,
let req = request.into_inner(); ) -> Result<ManageConnectorResponse, rpc_types::error::Error> {
let url = url::Url::parse(&req.url) let url: url::Url = req.url.ok_or(anyhow::anyhow!("url is empty"))?.into();
.map_err(|_| tonic::Status::invalid_argument("invalid url"))?; if req.action == ConnectorManageAction::Remove as i32 {
if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 { self.0
self.0.remove_connector(url.path()).await.map_err(|e| { .remove_connector(url.clone())
tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e)) .await
})?; .with_context(|| format!("remove connector failed: {:?}", url))?;
return Ok(tonic::Response::new( return Ok(ManageConnectorResponse::default());
easytier_rpc::ManageConnectorResponse::default(),
));
} else { } else {
self.0 self.0
.add_connector_by_url(url.as_str()) .add_connector_by_url(url.as_str())
.await .await
.map_err(|e| { .with_context(|| format!("add connector failed: {:?}", url))?;
tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
})?;
} }
Ok(tonic::Response::new( Ok(ManageConnectorResponse::default())
easytier_rpc::ManageConnectorResponse::default(),
))
} }
} }

View File

@@ -32,14 +32,14 @@ async fn set_bind_addr_for_peer_connector(
if is_ipv4 { if is_ipv4 {
let mut bind_addrs = vec![]; let mut bind_addrs = vec![];
for ipv4 in ips.interface_ipv4s { for ipv4 in ips.interface_ipv4s {
let socket_addr = SocketAddrV4::new(ipv4.parse().unwrap(), 0).into(); let socket_addr = SocketAddrV4::new(ipv4.into(), 0).into();
bind_addrs.push(socket_addr); bind_addrs.push(socket_addr);
} }
connector.set_bind_addrs(bind_addrs); connector.set_bind_addrs(bind_addrs);
} else { } else {
let mut bind_addrs = vec![]; let mut bind_addrs = vec![];
for ipv6 in ips.interface_ipv6s { for ipv6 in ips.interface_ipv6s {
let socket_addr = SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into(); let socket_addr = SocketAddrV6::new(ipv6.into(), 0, 0, 0).into();
bind_addrs.push(socket_addr); bind_addrs.push(socket_addr);
} }
connector.set_bind_addrs(bind_addrs); connector.set_bind_addrs(bind_addrs);

View File

@@ -5,6 +5,7 @@ use std::{
Arc, Arc,
}, },
time::Duration, time::Duration,
u16,
}; };
use anyhow::Context; use anyhow::Context;
@@ -21,12 +22,20 @@ use zerocopy::FromBytes;
use crate::{ use crate::{
common::{ common::{
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId, scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId,
}, },
defer, defer,
peers::peer_manager::PeerManager, peers::peer_manager::PeerManager,
rpc::NatType, proto::{
common::NatType,
peer_rpc::{
TryPunchHoleRequest, TryPunchHoleResponse, TryPunchSymmetricRequest,
TryPunchSymmetricResponse, UdpHolePunchRpc, UdpHolePunchRpcClientFactory,
UdpHolePunchRpcServer,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{ tunnel::{
common::setup_sokcet2, common::setup_sokcet2,
packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE}, packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE},
@@ -186,21 +195,6 @@ impl std::fmt::Debug for UdpSocketArray {
} }
} }
#[tarpc::service]
pub trait UdpHolePunchService {
async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
async fn try_punch_symmetric(
listener_addr: SocketAddr,
port: u16,
public_ips: Vec<Ipv4Addr>,
min_port: u16,
max_port: u16,
transaction_id: u32,
round: u32,
last_port_index: usize,
) -> Option<usize>;
}
#[derive(Debug)] #[derive(Debug)]
struct UdpHolePunchListener { struct UdpHolePunchListener {
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
@@ -324,23 +318,34 @@ impl UdpHolePunchConnectorData {
} }
#[derive(Clone)] #[derive(Clone)]
struct UdpHolePunchRpcServer { struct UdpHolePunchRpcService {
data: Arc<UdpHolePunchConnectorData>, data: Arc<UdpHolePunchConnectorData>,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>, tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
} }
#[tarpc::server] #[async_trait::async_trait]
impl UdpHolePunchService for UdpHolePunchRpcServer { impl UdpHolePunchRpc for UdpHolePunchRpcService {
type Controller = BaseController;
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
async fn try_punch_hole( async fn try_punch_hole(
self, &self,
_: tarpc::context::Context, _: BaseController,
local_mapped_addr: SocketAddr, request: TryPunchHoleRequest,
) -> Option<SocketAddr> { ) -> Result<TryPunchHoleResponse, rpc_types::error::Error> {
let local_mapped_addr = request.local_mapped_addr.ok_or(anyhow::anyhow!(
"try_punch_hole request missing local_mapped_addr"
))?;
let local_mapped_addr = std::net::SocketAddr::from(local_mapped_addr);
// local mapped addr will be unspecified if peer is symmetric // local mapped addr will be unspecified if peer is symmetric
let peer_is_symmetric = local_mapped_addr.ip().is_unspecified(); let peer_is_symmetric = local_mapped_addr.ip().is_unspecified();
let (socket, mapped_addr) = self.select_listener(peer_is_symmetric).await?; let (socket, mapped_addr) =
self.select_listener(peer_is_symmetric)
.await
.ok_or(anyhow::anyhow!(
"failed to select listener for hole punching"
))?;
tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching"); tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");
if !peer_is_symmetric { if !peer_is_symmetric {
@@ -380,32 +385,48 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
} }
} }
Some(mapped_addr) Ok(TryPunchHoleResponse {
remote_mapped_addr: Some(mapped_addr.into()),
})
} }
#[instrument(skip(self))] #[instrument(skip(self))]
async fn try_punch_symmetric( async fn try_punch_symmetric(
self, &self,
_: tarpc::context::Context, _: BaseController,
listener_addr: SocketAddr, request: TryPunchSymmetricRequest,
port: u16, ) -> Result<TryPunchSymmetricResponse, rpc_types::error::Error> {
public_ips: Vec<Ipv4Addr>, let listener_addr = request.listener_addr.ok_or(anyhow::anyhow!(
mut min_port: u16, "try_punch_symmetric request missing listener_addr"
mut max_port: u16, ))?;
transaction_id: u32, let listener_addr = std::net::SocketAddr::from(listener_addr);
round: u32, let port = request.port as u16;
last_port_index: usize, let public_ips = request
) -> Option<usize> { .public_ips
.into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip))
.collect::<Vec<_>>();
let mut min_port = request.min_port as u16;
let mut max_port = request.max_port as u16;
let transaction_id = request.transaction_id;
let round = request.round;
let last_port_index = request.last_port_index as usize;
tracing::info!("try_punch_symmetric start"); tracing::info!("try_punch_symmetric start");
let punch_predictablely = self.data.punch_predicablely.load(Ordering::Relaxed); let punch_predictablely = self.data.punch_predicablely.load(Ordering::Relaxed);
let punch_randomly = self.data.punch_randomly.load(Ordering::Relaxed); let punch_randomly = self.data.punch_randomly.load(Ordering::Relaxed);
let total_port_count = self.data.shuffled_port_vec.len(); let total_port_count = self.data.shuffled_port_vec.len();
let listener = self.find_listener(&listener_addr).await?; let listener = self
.find_listener(&listener_addr)
.await
.ok_or(anyhow::anyhow!(
"try_punch_symmetric failed to find listener"
))?;
let ip_count = public_ips.len(); let ip_count = public_ips.len();
if ip_count == 0 { if ip_count == 0 {
tracing::warn!("try_punch_symmetric got zero len public ip"); tracing::warn!("try_punch_symmetric got zero len public ip");
return None; return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into());
} }
min_port = std::cmp::max(1, min_port); min_port = std::cmp::max(1, min_port);
@@ -447,7 +468,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
&ports, &ports,
) )
.await .await
.ok()?; .with_context(|| "failed to send symmetric hole punch packet predict")?;
} }
if punch_randomly { if punch_randomly {
@@ -461,20 +482,22 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
&self.data.shuffled_port_vec[start..end], &self.data.shuffled_port_vec[start..end],
) )
.await .await
.ok()?; .with_context(|| "failed to send symmetric hole punch packet randomly")?;
return if end >= self.data.shuffled_port_vec.len() { return if end >= self.data.shuffled_port_vec.len() {
Some(1) Ok(TryPunchSymmetricResponse { last_port_index: 1 })
} else { } else {
Some(end) Ok(TryPunchSymmetricResponse {
last_port_index: end as u32,
})
}; };
} }
return Some(1); return Ok(TryPunchSymmetricResponse { last_port_index: 1 });
} }
} }
impl UdpHolePunchRpcServer { impl UdpHolePunchRpcService {
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self { pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned()); join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
@@ -593,9 +616,14 @@ impl UdpHolePunchConnector {
} }
pub async fn run_as_server(&mut self) -> Result<(), Error> { pub async fn run_as_server(&mut self) -> Result<(), Error> {
self.data.peer_mgr.get_peer_rpc_mgr().run_service( self.data
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID, .peer_mgr
UdpHolePunchRpcServer::new(self.data.clone()).serve(), .get_peer_rpc_mgr()
.rpc_server()
.registry()
.register(
UdpHolePunchRpcServer::new(UdpHolePunchRpcService::new(self.data.clone())),
&self.data.global_ctx.get_network_name(),
); );
Ok(()) Ok(())
@@ -736,26 +764,26 @@ impl UdpHolePunchConnector {
.with_context(|| "failed to get udp port mapping")?; .with_context(|| "failed to get udp port mapping")?;
// client -> server: tell server the mapped port, server will return the mapped address of listening port. // client -> server: tell server the mapped port, server will return the mapped address of listening port.
let Some(remote_mapped_addr) = data let rpc_stub = data
.peer_mgr .peer_mgr
.get_peer_rpc_mgr() .get_peer_rpc_mgr()
.do_client_rpc_scoped( .rpc_client()
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID, .scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
data.peer_mgr.my_peer_id(),
dst_peer_id, dst_peer_id,
|c| async { data.global_ctx.get_network_name(),
let client = );
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
let remote_mapped_addr = client let remote_mapped_addr = rpc_stub
.try_punch_hole(tarpc::context::current(), local_mapped_addr) .try_punch_hole(
.await; BaseController {},
tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr"); TryPunchHoleRequest {
remote_mapped_addr local_mapped_addr: Some(local_mapped_addr.into()),
}, },
) )
.await? .await?
else { .remote_mapped_addr
return Err(anyhow::anyhow!("failed to get remote mapped addr")); .ok_or(anyhow::anyhow!("failed to get remote mapped addr"))?;
};
// server: will send some punching resps, total 10 packets. // server: will send some punching resps, total 10 packets.
// client: use the socket to create UdpTunnel with UdpTunnelConnector // client: use the socket to create UdpTunnel with UdpTunnelConnector
@@ -769,9 +797,11 @@ impl UdpHolePunchConnector {
setup_sokcet2(&socket2_socket, &local_socket_addr)?; setup_sokcet2(&socket2_socket, &local_socket_addr)?;
let socket = Arc::new(UdpSocket::from_std(socket2_socket.into())?); let socket = Arc::new(UdpSocket::from_std(socket2_socket.into())?);
Ok(Self::try_connect_with_socket(socket, remote_mapped_addr) Ok(
Self::try_connect_with_socket(socket, remote_mapped_addr.into())
.await .await
.with_context(|| "UdpTunnelConnector failed to connect remote")?) .with_context(|| "UdpTunnelConnector failed to connect remote")?,
)
} }
#[tracing::instrument(err(level = Level::ERROR))] #[tracing::instrument(err(level = Level::ERROR))]
@@ -783,30 +813,28 @@ impl UdpHolePunchConnector {
return Err(anyhow::anyhow!("udp array not started")); return Err(anyhow::anyhow!("udp array not started"));
}; };
let Some(remote_mapped_addr) = data let rpc_stub = data
.peer_mgr .peer_mgr
.get_peer_rpc_mgr() .get_peer_rpc_mgr()
.do_client_rpc_scoped( .rpc_client()
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID, .scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
data.peer_mgr.my_peer_id(),
dst_peer_id, dst_peer_id,
|c| async { data.global_ctx.get_network_name(),
let client =
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
let remote_mapped_addr = client
.try_punch_hole(tarpc::context::current(), "0.0.0.0:0".parse().unwrap())
.await;
tracing::debug!(
?remote_mapped_addr,
?dst_peer_id,
"hole punching symmetric got remote mapped addr"
); );
remote_mapped_addr
let local_mapped_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
let remote_mapped_addr = rpc_stub
.try_punch_hole(
BaseController {},
TryPunchHoleRequest {
local_mapped_addr: Some(local_mapped_addr.into()),
}, },
) )
.await? .await?
else { .remote_mapped_addr
return Err(anyhow::anyhow!("failed to get remote mapped addr")); .ok_or(anyhow::anyhow!("failed to get remote mapped addr"))?
}; .into();
// try direct connect first // try direct connect first
if data.try_direct_connect.load(Ordering::Relaxed) { if data.try_direct_connect.load(Ordering::Relaxed) {
@@ -852,38 +880,26 @@ impl UdpHolePunchConnector {
let mut last_port_idx = rand::thread_rng().gen_range(0..data.shuffled_port_vec.len()); let mut last_port_idx = rand::thread_rng().gen_range(0..data.shuffled_port_vec.len());
for round in 0..5 { for round in 0..5 {
let ret = data let ret = rpc_stub
.peer_mgr
.get_peer_rpc_mgr()
.do_client_rpc_scoped(
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
dst_peer_id,
|c| async {
let client =
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c)
.spawn();
let last_port_idx = client
.try_punch_symmetric( .try_punch_symmetric(
tarpc::context::current(), BaseController {},
remote_mapped_addr, TryPunchSymmetricRequest {
port, listener_addr: Some(remote_mapped_addr.into()),
public_ips.clone(), port: port as u32,
stun_info.min_port as u16, public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
stun_info.max_port as u16, min_port: stun_info.min_port as u32,
tid, max_port: stun_info.max_port as u32,
transaction_id: tid,
round, round,
last_port_idx, last_port_index: last_port_idx as u32,
)
.await;
tracing::info!(?last_port_idx, ?dst_peer_id, "punch symmetric return");
last_port_idx
}, },
) )
.await; .await;
tracing::info!(?ret, "punch symmetric return");
let next_last_port_idx = match ret { let next_last_port_idx = match ret {
Ok(Some(idx)) => idx, Ok(s) => s.last_port_index as usize,
err => { Err(err) => {
tracing::error!(?err, "failed to get remote mapped addr"); tracing::error!(?err, "failed to get remote mapped addr");
rand::thread_rng().gen_range(0..data.shuffled_port_vec.len()) rand::thread_rng().gen_range(0..data.shuffled_port_vec.len())
} }
@@ -1027,11 +1043,11 @@ pub mod tests {
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::rpc::{NatType, StunInfo}; use crate::common::stun::MockStunInfoCollector;
use crate::proto::common::NatType;
use crate::tunnel::common::tests::wait_for_condition; use crate::tunnel::common::tests::wait_for_condition;
use crate::{ use crate::{
common::{error::Error, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::UdpHolePunchConnector, connector::udp_hole_punch::UdpHolePunchConnector,
peers::{ peers::{
peer_manager::PeerManager, peer_manager::PeerManager,
@@ -1042,31 +1058,6 @@ pub mod tests {
}, },
}; };
struct MockStunInfoCollector {
udp_nat_type: NatType,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for MockStunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
StunInfo {
udp_nat_type: self.udp_nat_type as i32,
tcp_nat_type: NatType::Unknown as i32,
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
min_port: 100,
max_port: 200,
..Default::default()
}
}
async fn get_udp_port_mapping(&self, mut port: u16) -> Result<std::net::SocketAddr, Error> {
if port == 0 {
port = 40144;
}
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
}
}
pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) { pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
let collector = Box::new(MockStunInfoCollector { udp_nat_type }); let collector = Box::new(MockStunInfoCollector { udp_nat_type });
peer_mgr peer_mgr
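Note: the side-by-side rendering above interleaves the removed Tarpc client code with its replacement, so the new client path in this file is hard to follow. Below is a sketch reassembled from the right-hand column of the hunks above; identifiers such as `data`, `dst_peer_id` and `local_mapped_addr` belong to the surrounding function and are assumed to be in scope, and the factory/request types come from the new `proto` module introduced by this patch.

// Sketch only: reassembled from the new (right-hand) side of the diff above.
// `data`, `dst_peer_id` and `local_mapped_addr` are assumed to be in scope.
let rpc_stub = data
    .peer_mgr
    .get_peer_rpc_mgr()
    .rpc_client()
    .scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
        data.peer_mgr.my_peer_id(),
        dst_peer_id,
        data.global_ctx.get_network_name(),
    );

let remote_mapped_addr = rpc_stub
    .try_punch_hole(
        BaseController {},
        TryPunchHoleRequest {
            local_mapped_addr: Some(local_mapped_addr.into()),
        },
    )
    .await?
    .remote_mapped_addr
    .ok_or(anyhow::anyhow!("failed to get remote mapped addr"))?;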
View File
@@ -1,26 +1,29 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::{net::SocketAddr, time::Duration, vec}; use std::{net::SocketAddr, sync::Mutex, time::Duration, vec};
use anyhow::{Context, Ok};
use clap::{command, Args, Parser, Subcommand}; use clap::{command, Args, Parser, Subcommand};
use common::stun::StunInfoCollectorTrait; use common::stun::StunInfoCollectorTrait;
use rpc::vpn_portal_rpc_client::VpnPortalRpcClient; use proto::{
common::NatType,
peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory},
rpc_impl::standalone::StandAloneClient,
rpc_types::controller::BaseController,
};
use tokio::time::timeout; use tokio::time::timeout;
use tunnel::tcp::TcpTunnelConnector;
use utils::{list_peer_route_pair, PeerRoutePair}; use utils::{list_peer_route_pair, PeerRoutePair};
mod arch; mod arch;
mod common; mod common;
mod rpc; mod proto;
mod tunnel; mod tunnel;
mod utils; mod utils;
use crate::{ use crate::{
common::stun::StunInfoCollector, common::stun::StunInfoCollector,
rpc::{ proto::cli::*,
connector_manage_rpc_client::ConnectorManageRpcClient,
peer_center_rpc_client::PeerCenterRpcClient, peer_manage_rpc_client::PeerManageRpcClient,
*,
},
utils::{cost_to_str, float_to_str}, utils::{cost_to_str, float_to_str},
}; };
use humansize::format_size; use humansize::format_size;
@@ -114,58 +117,76 @@ struct NodeArgs {
sub_command: Option<NodeSubCommand>, sub_command: Option<NodeSubCommand>,
} }
#[derive(thiserror::Error, Debug)] type Error = anyhow::Error;
enum Error {
#[error("tonic transport error")]
TonicTransportError(#[from] tonic::transport::Error),
#[error("tonic rpc error")]
TonicRpcError(#[from] tonic::Status),
#[error("anyhow error")]
Anyhow(#[from] anyhow::Error),
}
struct CommandHandler { struct CommandHandler {
addr: String, client: Mutex<RpcClient>,
verbose: bool, verbose: bool,
} }
type RpcClient = StandAloneClient<TcpTunnelConnector>;
impl CommandHandler { impl CommandHandler {
async fn get_peer_manager_client( async fn get_peer_manager_client(
&self, &self,
) -> Result<PeerManageRpcClient<tonic::transport::Channel>, Error> { ) -> Result<Box<dyn PeerManageRpc<Controller = BaseController>>, Error> {
Ok(PeerManageRpcClient::connect(self.addr.clone()).await?) Ok(self
.client
.lock()
.unwrap()
.scoped_client::<PeerManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get peer manager client")?)
} }
async fn get_connector_manager_client( async fn get_connector_manager_client(
&self, &self,
) -> Result<ConnectorManageRpcClient<tonic::transport::Channel>, Error> { ) -> Result<Box<dyn ConnectorManageRpc<Controller = BaseController>>, Error> {
Ok(ConnectorManageRpcClient::connect(self.addr.clone()).await?) Ok(self
.client
.lock()
.unwrap()
.scoped_client::<ConnectorManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get connector manager client")?)
} }
async fn get_peer_center_client( async fn get_peer_center_client(
&self, &self,
) -> Result<PeerCenterRpcClient<tonic::transport::Channel>, Error> { ) -> Result<Box<dyn PeerCenterRpc<Controller = BaseController>>, Error> {
Ok(PeerCenterRpcClient::connect(self.addr.clone()).await?) Ok(self
.client
.lock()
.unwrap()
.scoped_client::<PeerCenterRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get peer center client")?)
} }
async fn get_vpn_portal_client( async fn get_vpn_portal_client(
&self, &self,
) -> Result<VpnPortalRpcClient<tonic::transport::Channel>, Error> { ) -> Result<Box<dyn VpnPortalRpc<Controller = BaseController>>, Error> {
Ok(VpnPortalRpcClient::connect(self.addr.clone()).await?) Ok(self
.client
.lock()
.unwrap()
.scoped_client::<VpnPortalRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get vpn portal client")?)
} }
async fn list_peers(&self) -> Result<ListPeerResponse, Error> { async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
let mut client = self.get_peer_manager_client().await?; let client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(ListPeerRequest::default()); let request = ListPeerRequest::default();
let response = client.list_peer(request).await?; let response = client.list_peer(BaseController {}, request).await?;
Ok(response.into_inner()) Ok(response)
} }
async fn list_routes(&self) -> Result<ListRouteResponse, Error> { async fn list_routes(&self) -> Result<ListRouteResponse, Error> {
let mut client = self.get_peer_manager_client().await?; let client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(ListRouteRequest::default()); let request = ListRouteRequest::default();
let response = client.list_route(request).await?; let response = client.list_route(BaseController {}, request).await?;
Ok(response.into_inner()) Ok(response)
} }
async fn list_peer_route_pair(&self) -> Result<Vec<PeerRoutePair>, Error> { async fn list_peer_route_pair(&self) -> Result<Vec<PeerRoutePair>, Error> {
@@ -251,11 +272,10 @@ impl CommandHandler {
return Ok(()); return Ok(());
} }
let mut client = self.get_peer_manager_client().await?; let client = self.get_peer_manager_client().await?;
let node_info = client let node_info = client
.show_node_info(ShowNodeInfoRequest::default()) .show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.await? .await?
.into_inner()
.node_info .node_info
.ok_or(anyhow::anyhow!("node info not found"))?; .ok_or(anyhow::anyhow!("node info not found"))?;
items.push(node_info.into()); items.push(node_info.into());
@@ -273,18 +293,20 @@ impl CommandHandler {
} }
async fn handle_route_dump(&self) -> Result<(), Error> { async fn handle_route_dump(&self) -> Result<(), Error> {
let mut client = self.get_peer_manager_client().await?; let client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(DumpRouteRequest::default()); let request = DumpRouteRequest::default();
let response = client.dump_route(request).await?; let response = client.dump_route(BaseController {}, request).await?;
println!("response: {}", response.into_inner().result); println!("response: {}", response.result);
Ok(()) Ok(())
} }
async fn handle_foreign_network_list(&self) -> Result<(), Error> { async fn handle_foreign_network_list(&self) -> Result<(), Error> {
let mut client = self.get_peer_manager_client().await?; let client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(ListForeignNetworkRequest::default()); let request = ListForeignNetworkRequest::default();
let response = client.list_foreign_network(request).await?; let response = client
let network_map = response.into_inner(); .list_foreign_network(BaseController {}, request)
.await?;
let network_map = response;
if self.verbose { if self.verbose {
println!("{:#?}", network_map); println!("{:#?}", network_map);
return Ok(()); return Ok(());
@@ -303,7 +325,7 @@ impl CommandHandler {
"remote_addr: {}, rx_bytes: {}, tx_bytes: {}, latency_us: {}", "remote_addr: {}, rx_bytes: {}, tx_bytes: {}, latency_us: {}",
conn.tunnel conn.tunnel
.as_ref() .as_ref()
.map(|t| t.remote_addr.clone()) .map(|t| t.remote_addr.clone().unwrap_or_default())
.unwrap_or_default(), .unwrap_or_default(),
conn.stats.as_ref().map(|s| s.rx_bytes).unwrap_or_default(), conn.stats.as_ref().map(|s| s.rx_bytes).unwrap_or_default(),
conn.stats.as_ref().map(|s| s.tx_bytes).unwrap_or_default(), conn.stats.as_ref().map(|s| s.tx_bytes).unwrap_or_default(),
@@ -334,11 +356,10 @@ impl CommandHandler {
} }
let mut items: Vec<RouteTableItem> = vec![]; let mut items: Vec<RouteTableItem> = vec![];
let mut client = self.get_peer_manager_client().await?; let client = self.get_peer_manager_client().await?;
let node_info = client let node_info = client
.show_node_info(ShowNodeInfoRequest::default()) .show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.await? .await?
.into_inner()
.node_info .node_info
.ok_or(anyhow::anyhow!("node info not found"))?; .ok_or(anyhow::anyhow!("node info not found"))?;
@@ -403,10 +424,10 @@ impl CommandHandler {
} }
async fn handle_connector_list(&self) -> Result<(), Error> { async fn handle_connector_list(&self) -> Result<(), Error> {
let mut client = self.get_connector_manager_client().await?; let client = self.get_connector_manager_client().await?;
let request = tonic::Request::new(ListConnectorRequest::default()); let request = ListConnectorRequest::default();
let response = client.list_connector(request).await?; let response = client.list_connector(BaseController {}, request).await?;
println!("response: {:#?}", response.into_inner()); println!("response: {:#?}", response);
Ok(()) Ok(())
} }
} }
@@ -415,8 +436,13 @@ impl CommandHandler {
#[tracing::instrument] #[tracing::instrument]
async fn main() -> Result<(), Error> { async fn main() -> Result<(), Error> {
let cli = Cli::parse(); let cli = Cli::parse();
let client = RpcClient::new(TcpTunnelConnector::new(
format!("tcp://{}:{}", cli.rpc_portal.ip(), cli.rpc_portal.port())
.parse()
.unwrap(),
));
let handler = CommandHandler { let handler = CommandHandler {
addr: format!("http://{}:{}", cli.rpc_portal.ip(), cli.rpc_portal.port()), client: Mutex::new(client),
verbose: cli.verbose, verbose: cli.verbose,
}; };
@@ -476,11 +502,10 @@ async fn main() -> Result<(), Error> {
.unwrap(); .unwrap();
} }
SubCommand::PeerCenter => { SubCommand::PeerCenter => {
let mut peer_center_client = handler.get_peer_center_client().await?; let peer_center_client = handler.get_peer_center_client().await?;
let resp = peer_center_client let resp = peer_center_client
.get_global_peer_map(GetGlobalPeerMapRequest::default()) .get_global_peer_map(BaseController {}, GetGlobalPeerMapRequest::default())
.await? .await?;
.into_inner();
#[derive(tabled::Tabled)] #[derive(tabled::Tabled)]
struct PeerCenterTableItem { struct PeerCenterTableItem {
@@ -510,11 +535,10 @@ async fn main() -> Result<(), Error> {
); );
} }
SubCommand::VpnPortal => { SubCommand::VpnPortal => {
let mut vpn_portal_client = handler.get_vpn_portal_client().await?; let vpn_portal_client = handler.get_vpn_portal_client().await?;
let resp = vpn_portal_client let resp = vpn_portal_client
.get_vpn_portal_info(GetVpnPortalInfoRequest::default()) .get_vpn_portal_info(BaseController {}, GetVpnPortalInfoRequest::default())
.await? .await?
.into_inner()
.vpn_portal_info .vpn_portal_info
.unwrap_or_default(); .unwrap_or_default();
println!("portal_name: {}", resp.vpn_type); println!("portal_name: {}", resp.vpn_type);
@@ -529,11 +553,10 @@ async fn main() -> Result<(), Error> {
println!("connected_clients:\n{:#?}", resp.connected_clients); println!("connected_clients:\n{:#?}", resp.connected_clients);
} }
SubCommand::Node(sub_cmd) => { SubCommand::Node(sub_cmd) => {
let mut client = handler.get_peer_manager_client().await?; let client = handler.get_peer_manager_client().await?;
let node_info = client let node_info = client
.show_node_info(ShowNodeInfoRequest::default()) .show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.await? .await?
.into_inner()
.node_info .node_info
.ok_or(anyhow::anyhow!("node info not found"))?; .ok_or(anyhow::anyhow!("node info not found"))?;
match sub_cmd.sub_command { match sub_cmd.sub_command {
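Taken together, the cli changes above replace the per-service tonic channels with a single StandAloneClient over a TcpTunnelConnector, from which typed stubs are obtained through the generated *RpcClientFactory types. A sketch of the wiring, condensed from the right-hand column of this file's diff (error handling trimmed):

// Sketch only: condensed from the new cli code above.
let client = RpcClient::new(TcpTunnelConnector::new(
    format!("tcp://{}:{}", cli.rpc_portal.ip(), cli.rpc_portal.port())
        .parse()
        .unwrap(),
));
let handler = CommandHandler {
    client: Mutex::new(client),
    verbose: cli.verbose,
};

// A typed stub is created per service; the string argument is the scope
// (empty here, exactly as in the patch).
let peer_client = handler
    .client
    .lock()
    .unwrap()
    .scoped_client::<PeerManageRpcClientFactory<BaseController>>("".to_string())
    .await?;
let peers = peer_client
    .list_peer(BaseController {}, ListPeerRequest::default())
    .await?;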
View File
@@ -21,7 +21,7 @@ mod gateway;
mod instance; mod instance;
mod peer_center; mod peer_center;
mod peers; mod peers;
mod rpc; mod proto;
mod tunnel; mod tunnel;
mod utils; mod utils;
mod vpn_portal; mod vpn_portal;
@@ -548,7 +548,7 @@ fn print_event(msg: String) {
); );
} }
fn peer_conn_info_to_string(p: crate::rpc::PeerConnInfo) -> String { fn peer_conn_info_to_string(p: crate::proto::cli::PeerConnInfo) -> String {
format!( format!(
"my_peer_id: {}, dst_peer_id: {}, tunnel_info: {:?}", "my_peer_id: {}, dst_peer_id: {}, tunnel_info: {:?}",
p.my_peer_id, p.peer_id, p.tunnel p.my_peer_id, p.peer_id, p.tunnel
View File
@@ -8,8 +8,6 @@ use anyhow::Context;
use cidr::Ipv4Inet; use cidr::Ipv4Inet;
use tokio::{sync::Mutex, task::JoinSet}; use tokio::{sync::Mutex, task::JoinSet};
use tonic::transport::server::TcpIncoming;
use tonic::transport::Server;
use crate::common::config::ConfigLoader; use crate::common::config::ConfigLoader;
use crate::common::error::Error; use crate::common::error::Error;
@@ -26,8 +24,13 @@ use crate::peers::peer_conn::PeerConnId;
use crate::peers::peer_manager::{PeerManager, RouteAlgoType}; use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
use crate::peers::rpc_service::PeerManagerRpcService; use crate::peers::rpc_service::PeerManagerRpcService;
use crate::peers::PacketRecvChanReceiver; use crate::peers::PacketRecvChanReceiver;
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc; use crate::proto::cli::VpnPortalRpc;
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo}; use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::proto::peer_rpc::PeerCenterRpcServer;
use crate::proto::rpc_impl::standalone::StandAloneServer;
use crate::proto::rpc_types;
use crate::proto::rpc_types::controller::BaseController;
use crate::tunnel::tcp::TcpTunnelListener;
use crate::vpn_portal::{self, VpnPortal}; use crate::vpn_portal::{self, VpnPortal};
use super::listeners::ListenerManager; use super::listeners::ListenerManager;
@@ -104,8 +107,6 @@ pub struct Instance {
nic_ctx: ArcNicCtx, nic_ctx: ArcNicCtx,
tasks: JoinSet<()>,
peer_packet_receiver: Arc<Mutex<PacketRecvChanReceiver>>, peer_packet_receiver: Arc<Mutex<PacketRecvChanReceiver>>,
peer_manager: Arc<PeerManager>, peer_manager: Arc<PeerManager>,
listener_manager: Arc<Mutex<ListenerManager<PeerManager>>>, listener_manager: Arc<Mutex<ListenerManager<PeerManager>>>,
@@ -122,6 +123,8 @@ pub struct Instance {
#[cfg(feature = "socks5")] #[cfg(feature = "socks5")]
socks5_server: Arc<Socks5Server>, socks5_server: Arc<Socks5Server>,
rpc_server: Option<StandAloneServer<TcpTunnelListener>>,
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
} }
@@ -170,6 +173,12 @@ impl Instance {
#[cfg(feature = "socks5")] #[cfg(feature = "socks5")]
let socks5_server = Socks5Server::new(global_ctx.clone(), peer_manager.clone(), None); let socks5_server = Socks5Server::new(global_ctx.clone(), peer_manager.clone(), None);
let rpc_server = global_ctx.config.get_rpc_portal().and_then(|s| {
Some(StandAloneServer::new(TcpTunnelListener::new(
format!("tcp://{}", s).parse().unwrap(),
)))
});
Instance { Instance {
inst_name: global_ctx.inst_name.clone(), inst_name: global_ctx.inst_name.clone(),
id, id,
@@ -177,7 +186,6 @@ impl Instance {
peer_packet_receiver: Arc::new(Mutex::new(peer_packet_receiver)), peer_packet_receiver: Arc::new(Mutex::new(peer_packet_receiver)),
nic_ctx: Arc::new(Mutex::new(None)), nic_ctx: Arc::new(Mutex::new(None)),
tasks: JoinSet::new(),
peer_manager, peer_manager,
listener_manager, listener_manager,
conn_manager, conn_manager,
@@ -193,6 +201,8 @@ impl Instance {
#[cfg(feature = "socks5")] #[cfg(feature = "socks5")]
socks5_server, socks5_server,
rpc_server,
global_ctx, global_ctx,
} }
} }
@@ -375,7 +385,7 @@ impl Instance {
self.check_dhcp_ip_conflict(); self.check_dhcp_ip_conflict();
} }
self.run_rpc_server()?; self.run_rpc_server().await?;
// run after tun device created, so listener can bind to tun device, which may be required by win 10 // run after tun device created, so listener can bind to tun device, which may be required by win 10
self.ip_proxy = Some(IpProxy::new( self.ip_proxy = Some(IpProxy::new(
@@ -441,11 +451,8 @@ impl Instance {
Ok(()) Ok(())
} }
pub async fn wait(&mut self) { pub async fn wait(&self) {
while let Some(ret) = self.tasks.join_next().await { self.peer_manager.wait().await;
tracing::info!("task finished: {:?}", ret);
ret.unwrap();
}
} }
pub fn id(&self) -> uuid::Uuid { pub fn id(&self) -> uuid::Uuid {
@@ -456,24 +463,28 @@ impl Instance {
self.peer_manager.my_peer_id() self.peer_manager.my_peer_id()
} }
fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc { fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc<Controller = BaseController> + Clone {
#[derive(Clone)]
struct VpnPortalRpcService { struct VpnPortalRpcService {
peer_mgr: Weak<PeerManager>, peer_mgr: Weak<PeerManager>,
vpn_portal: Weak<Mutex<Box<dyn VpnPortal>>>, vpn_portal: Weak<Mutex<Box<dyn VpnPortal>>>,
} }
#[tonic::async_trait] #[async_trait::async_trait]
impl VpnPortalRpc for VpnPortalRpcService { impl VpnPortalRpc for VpnPortalRpcService {
type Controller = BaseController;
async fn get_vpn_portal_info( async fn get_vpn_portal_info(
&self, &self,
_request: tonic::Request<GetVpnPortalInfoRequest>, _: BaseController,
) -> Result<tonic::Response<GetVpnPortalInfoResponse>, tonic::Status> { _request: GetVpnPortalInfoRequest,
) -> Result<GetVpnPortalInfoResponse, rpc_types::error::Error> {
let Some(vpn_portal) = self.vpn_portal.upgrade() else { let Some(vpn_portal) = self.vpn_portal.upgrade() else {
return Err(tonic::Status::unavailable("vpn portal not available")); return Err(anyhow::anyhow!("vpn portal not available").into());
}; };
let Some(peer_mgr) = self.peer_mgr.upgrade() else { let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(tonic::Status::unavailable("peer manager not available")); return Err(anyhow::anyhow!("peer manager not available").into());
}; };
let vpn_portal = vpn_portal.lock().await; let vpn_portal = vpn_portal.lock().await;
@@ -485,7 +496,7 @@ impl Instance {
}), }),
}; };
Ok(tonic::Response::new(ret)) Ok(ret)
} }
} }
@@ -495,46 +506,36 @@ impl Instance {
} }
} }
fn run_rpc_server(&mut self) -> Result<(), Error> { async fn run_rpc_server(&mut self) -> Result<(), Error> {
let Some(addr) = self.global_ctx.config.get_rpc_portal() else { let Some(_) = self.global_ctx.config.get_rpc_portal() else {
tracing::info!("rpc server not enabled, because rpc_portal is not set."); tracing::info!("rpc server not enabled, because rpc_portal is not set.");
return Ok(()); return Ok(());
}; };
use crate::proto::cli::*;
let peer_mgr = self.peer_manager.clone(); let peer_mgr = self.peer_manager.clone();
let conn_manager = self.conn_manager.clone(); let conn_manager = self.conn_manager.clone();
let net_ns = self.global_ctx.net_ns.clone();
let peer_center = self.peer_center.clone(); let peer_center = self.peer_center.clone();
let vpn_portal_rpc = self.get_vpn_portal_rpc_service(); let vpn_portal_rpc = self.get_vpn_portal_rpc_service();
let incoming = TcpIncoming::new(addr, true, None) let s = self.rpc_server.as_mut().unwrap();
.map_err(|e| anyhow::anyhow!("create rpc server failed. addr: {}, err: {}", addr, e))?; s.registry().register(
self.tasks.spawn(async move { PeerManageRpcServer::new(PeerManagerRpcService::new(peer_mgr)),
let _g = net_ns.guard(); "",
Server::builder() );
.add_service( s.registry().register(
crate::rpc::peer_manage_rpc_server::PeerManageRpcServer::new( ConnectorManageRpcServer::new(ConnectorManagerRpcService(conn_manager)),
PeerManagerRpcService::new(peer_mgr), "",
), );
)
.add_service( s.registry()
crate::rpc::connector_manage_rpc_server::ConnectorManageRpcServer::new( .register(PeerCenterRpcServer::new(peer_center.get_rpc_service()), "");
ConnectorManagerRpcService(conn_manager.clone()), s.registry()
), .register(VpnPortalRpcServer::new(vpn_portal_rpc), "");
)
.add_service( let _g = self.global_ctx.net_ns.guard();
crate::rpc::peer_center_rpc_server::PeerCenterRpcServer::new( Ok(s.serve().await.with_context(|| "rpc server start failed")?)
peer_center.get_rpc_service(),
),
)
.add_service(crate::rpc::vpn_portal_rpc_server::VpnPortalRpcServer::new(
vpn_portal_rpc,
))
.serve_with_incoming(incoming)
.await
.with_context(|| format!("rpc server failed. addr: {}", addr))
.unwrap();
});
Ok(())
} }
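For reference, here is the new rpc_portal path reassembled from the right-hand column of the hunk above: the StandAloneServer built earlier from get_rpc_portal() has the generated *RpcServer wrappers registered on its registry, then serves inside the instance's network namespace. The second argument of register() is left empty here, exactly as in the patch.

// Sketch only: reassembled from the new run_rpc_server() above.
let s = self.rpc_server.as_mut().unwrap();
s.registry().register(
    PeerManageRpcServer::new(PeerManagerRpcService::new(peer_mgr)),
    "",
);
s.registry().register(
    ConnectorManageRpcServer::new(ConnectorManagerRpcService(conn_manager)),
    "",
);
s.registry()
    .register(PeerCenterRpcServer::new(peer_center.get_rpc_service()), "");
s.registry()
    .register(VpnPortalRpcServer::new(vpn_portal_rpc), "");

let _g = self.global_ctx.net_ns.guard();
s.serve().await.with_context(|| "rpc server start failed")?;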
pub fn get_global_ctx(&self) -> ArcGlobalCtx { pub fn get_global_ctx(&self) -> ArcGlobalCtx {

View File
let tunnel_info = ret.info().unwrap(); let tunnel_info = ret.info().unwrap();
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted( global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
tunnel_info.local_addr.clone(), tunnel_info
tunnel_info.remote_addr.clone(), .local_addr
.clone()
.unwrap_or_default()
.to_string(),
tunnel_info
.remote_addr
.clone()
.unwrap_or_default()
.to_string(),
)); ));
tracing::info!(ret = ?ret, "conn accepted"); tracing::info!(ret = ?ret, "conn accepted");
let peer_manager = peer_manager.clone(); let peer_manager = peer_manager.clone();
@@ -169,8 +177,8 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
let server_ret = peer_manager.handle_tunnel(ret).await; let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret { if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError( global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
tunnel_info.local_addr, tunnel_info.local_addr.unwrap_or_default().to_string(),
tunnel_info.remote_addr, tunnel_info.remote_addr.unwrap_or_default().to_string(),
e.to_string(), e.to_string(),
)); ));
tracing::error!(error = ?e, "handle conn error"); tracing::error!(error = ?e, "handle conn error");
View File
@@ -11,9 +11,10 @@ use crate::{
}, },
instance::instance::Instance, instance::instance::Instance,
peers::rpc_service::PeerManagerRpcService, peers::rpc_service::PeerManagerRpcService,
rpc::{ proto::{
cli::{PeerInfo, Route, StunInfo}, cli::{PeerInfo, Route},
peer::GetIpListResponse, common::StunInfo,
peer_rpc::GetIpListResponse,
}, },
utils::{list_peer_route_pair, PeerRoutePair}, utils::{list_peer_route_pair, PeerRoutePair},
}; };
View File
@@ -6,11 +6,11 @@ mod gateway;
mod instance; mod instance;
mod peer_center; mod peer_center;
mod peers; mod peers;
mod proto;
mod vpn_portal; mod vpn_portal;
pub mod common; pub mod common;
pub mod launcher; pub mod launcher;
pub mod rpc;
pub mod tunnel; pub mod tunnel;
pub mod utils; pub mod utils;
View File
@@ -1,7 +1,7 @@
use std::{ use std::{
collections::BTreeSet, collections::BTreeSet,
sync::Arc, sync::Arc,
time::{Duration, Instant, SystemTime}, time::{Duration, Instant},
}; };
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
@@ -18,14 +18,17 @@ use crate::{
route_trait::{RouteCostCalculator, RouteCostCalculatorInterface}, route_trait::{RouteCostCalculator, RouteCostCalculatorInterface},
rpc_service::PeerManagerRpcService, rpc_service::PeerManagerRpcService,
}, },
rpc::{GetGlobalPeerMapRequest, GetGlobalPeerMapResponse}, proto::{
peer_rpc::{
GetGlobalPeerMapRequest, GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterRpc,
PeerCenterRpcClientFactory, PeerCenterRpcServer, PeerInfoForGlobalMap,
ReportPeersRequest, ReportPeersResponse,
},
rpc_types::{self, controller::BaseController},
},
}; };
use super::{ use super::{server::PeerCenterServer, Digest, Error};
server::PeerCenterServer,
service::{GlobalPeerMap, PeerCenterService, PeerCenterServiceClient, PeerInfoForGlobalMap},
Digest, Error,
};
struct PeerCenterBase { struct PeerCenterBase {
peer_mgr: Arc<PeerManager>, peer_mgr: Arc<PeerManager>,
@@ -44,11 +47,14 @@ struct PeridicJobCtx<T> {
impl PeerCenterBase { impl PeerCenterBase {
pub async fn init(&self) -> Result<(), Error> { pub async fn init(&self) -> Result<(), Error> {
self.peer_mgr.get_peer_rpc_mgr().run_service( self.peer_mgr
SERVICE_ID, .get_peer_rpc_mgr()
PeerCenterServer::new(self.peer_mgr.my_peer_id()).serve(), .rpc_server()
.registry()
.register(
PeerCenterRpcServer::new(PeerCenterServer::new(self.peer_mgr.my_peer_id())),
&self.peer_mgr.get_global_ctx().get_network_name(),
); );
Ok(()) Ok(())
} }
@@ -70,11 +76,17 @@ impl PeerCenterBase {
async fn init_periodic_job< async fn init_periodic_job<
T: Send + Sync + 'static + Clone, T: Send + Sync + 'static + Clone,
Fut: Future<Output = Result<u32, tarpc::client::RpcError>> + Send + 'static, Fut: Future<Output = Result<u32, rpc_types::error::Error>> + Send + 'static,
>( >(
&self, &self,
job_ctx: T, job_ctx: T,
job_fn: (impl Fn(PeerCenterServiceClient, Arc<PeridicJobCtx<T>>) -> Fut + Send + Sync + 'static), job_fn: (impl Fn(
Box<dyn PeerCenterRpc<Controller = BaseController> + Send>,
Arc<PeridicJobCtx<T>>,
) -> Fut
+ Send
+ Sync
+ 'static),
) -> () { ) -> () {
let my_peer_id = self.peer_mgr.my_peer_id(); let my_peer_id = self.peer_mgr.my_peer_id();
let peer_mgr = self.peer_mgr.clone(); let peer_mgr = self.peer_mgr.clone();
@@ -96,14 +108,14 @@ impl PeerCenterBase {
tracing::trace!(?center_peer, "run periodic job"); tracing::trace!(?center_peer, "run periodic job");
let rpc_mgr = peer_mgr.get_peer_rpc_mgr(); let rpc_mgr = peer_mgr.get_peer_rpc_mgr();
let _g = lock.lock().await; let _g = lock.lock().await;
let ret = rpc_mgr let stub = rpc_mgr
.do_client_rpc_scoped(SERVICE_ID, center_peer, |c| async { .rpc_client()
let client = .scoped_client::<PeerCenterRpcClientFactory<BaseController>>(
PeerCenterServiceClient::new(tarpc::client::Config::default(), c) my_peer_id,
.spawn(); center_peer,
job_fn(client, ctx.clone()).await peer_mgr.get_global_ctx().get_network_name(),
}) );
.await; let ret = job_fn(stub, ctx.clone()).await;
drop(_g); drop(_g);
let Ok(sleep_time_ms) = ret else { let Ok(sleep_time_ms) = ret else {
@@ -130,25 +142,34 @@ impl PeerCenterBase {
} }
} }
#[derive(Clone)]
pub struct PeerCenterInstanceService { pub struct PeerCenterInstanceService {
global_peer_map: Arc<RwLock<GlobalPeerMap>>, global_peer_map: Arc<RwLock<GlobalPeerMap>>,
global_peer_map_digest: Arc<AtomicCell<Digest>>, global_peer_map_digest: Arc<AtomicCell<Digest>>,
} }
#[tonic::async_trait] #[async_trait::async_trait]
impl crate::rpc::cli::peer_center_rpc_server::PeerCenterRpc for PeerCenterInstanceService { impl PeerCenterRpc for PeerCenterInstanceService {
type Controller = BaseController;
async fn get_global_peer_map( async fn get_global_peer_map(
&self, &self,
_request: tonic::Request<GetGlobalPeerMapRequest>, _: BaseController,
) -> Result<tonic::Response<GetGlobalPeerMapResponse>, tonic::Status> { _: GetGlobalPeerMapRequest,
let global_peer_map = self.global_peer_map.read().unwrap().clone(); ) -> Result<GetGlobalPeerMapResponse, rpc_types::error::Error> {
Ok(tonic::Response::new(GetGlobalPeerMapResponse { let global_peer_map = self.global_peer_map.read().unwrap();
global_peer_map: global_peer_map Ok(GetGlobalPeerMapResponse {
.map global_peer_map: global_peer_map.map.clone(),
.into_iter() digest: Some(self.global_peer_map_digest.load()),
.map(|(k, v)| (k, v)) })
.collect(), }
}))
async fn report_peers(
&self,
_: BaseController,
_req: ReportPeersRequest,
) -> Result<ReportPeersResponse, rpc_types::error::Error> {
Err(anyhow::anyhow!("not implemented").into())
} }
} }
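Since the interleaving above makes the new service definition hard to read, here is the instance-side PeerCenterRpc implementation copied cleanly from the right-hand column: it now implements the generated trait with a BaseController instead of the tonic server trait, serves the cached global peer map, and rejects report_peers (reports are handled by PeerCenterServer).

// Clean copy of the new impl above; types come from crate::proto::peer_rpc
// and crate::proto::rpc_types.
#[async_trait::async_trait]
impl PeerCenterRpc for PeerCenterInstanceService {
    type Controller = BaseController;

    async fn get_global_peer_map(
        &self,
        _: BaseController,
        _: GetGlobalPeerMapRequest,
    ) -> Result<GetGlobalPeerMapResponse, rpc_types::error::Error> {
        let global_peer_map = self.global_peer_map.read().unwrap();
        Ok(GetGlobalPeerMapResponse {
            global_peer_map: global_peer_map.map.clone(),
            digest: Some(self.global_peer_map_digest.load()),
        })
    }

    async fn report_peers(
        &self,
        _: BaseController,
        _req: ReportPeersRequest,
    ) -> Result<ReportPeersResponse, rpc_types::error::Error> {
        // The instance-side service only serves cached reads.
        Err(anyhow::anyhow!("not implemented").into())
    }
}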
@@ -166,7 +187,7 @@ impl PeerCenterInstance {
PeerCenterInstance { PeerCenterInstance {
peer_mgr: peer_mgr.clone(), peer_mgr: peer_mgr.clone(),
client: Arc::new(PeerCenterBase::new(peer_mgr.clone())), client: Arc::new(PeerCenterBase::new(peer_mgr.clone())),
global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::new())), global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::default())),
global_peer_map_digest: Arc::new(AtomicCell::new(Digest::default())), global_peer_map_digest: Arc::new(AtomicCell::new(Digest::default())),
global_peer_map_update_time: Arc::new(AtomicCell::new(Instant::now())), global_peer_map_update_time: Arc::new(AtomicCell::new(Instant::now())),
} }
@@ -193,9 +214,6 @@ impl PeerCenterInstance {
self.client self.client
.init_periodic_job(ctx, |client, ctx| async move { .init_periodic_job(ctx, |client, ctx| async move {
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
if ctx if ctx
.job_ctx .job_ctx
.global_peer_map_update_time .global_peer_map_update_time
@@ -208,8 +226,13 @@ impl PeerCenterInstance {
} }
let ret = client let ret = client
.get_global_peer_map(rpc_ctx, ctx.job_ctx.global_peer_map_digest.load()) .get_global_peer_map(
.await?; BaseController {},
GetGlobalPeerMapRequest {
digest: ctx.job_ctx.global_peer_map_digest.load(),
},
)
.await;
let Ok(resp) = ret else { let Ok(resp) = ret else {
tracing::error!( tracing::error!(
@@ -219,9 +242,10 @@ impl PeerCenterInstance {
return Ok(1000); return Ok(1000);
}; };
let Some(resp) = resp else { if resp == GetGlobalPeerMapResponse::default() {
// digest match, no need to update
return Ok(5000); return Ok(5000);
}; }
tracing::info!( tracing::info!(
"get global info from center server: {:?}, digest: {:?}", "get global info from center server: {:?}, digest: {:?}",
@@ -229,8 +253,12 @@ impl PeerCenterInstance {
resp.digest resp.digest
); );
*ctx.job_ctx.global_peer_map.write().unwrap() = resp.global_peer_map; *ctx.job_ctx.global_peer_map.write().unwrap() = GlobalPeerMap {
ctx.job_ctx.global_peer_map_digest.store(resp.digest); map: resp.global_peer_map,
};
ctx.job_ctx
.global_peer_map_digest
.store(resp.digest.unwrap_or_default());
ctx.job_ctx ctx.job_ctx
.global_peer_map_update_time .global_peer_map_update_time
.store(Instant::now()); .store(Instant::now());
@@ -274,12 +302,15 @@ impl PeerCenterInstance {
return Ok(5000); return Ok(5000);
} }
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
let ret = client let ret = client
.report_peers(rpc_ctx, my_node_id.clone(), peers) .report_peers(
.await?; BaseController {},
ReportPeersRequest {
my_peer_id: my_node_id,
peer_infos: Some(peers),
},
)
.await;
if ret.is_ok() { if ret.is_ok() {
ctx.job_ctx.last_center_peer.store(ctx.center_peer.load()); ctx.job_ctx.last_center_peer.store(ctx.center_peer.load());
@@ -339,7 +370,7 @@ impl PeerCenterInstance {
Box::new(RouteCostCalculatorImpl { Box::new(RouteCostCalculatorImpl {
global_peer_map: self.global_peer_map.clone(), global_peer_map: self.global_peer_map.clone(),
global_peer_map_clone: GlobalPeerMap::new(), global_peer_map_clone: GlobalPeerMap::default(),
last_update_time: AtomicCell::new( last_update_time: AtomicCell::new(
self.global_peer_map_update_time.load() - Duration::from_secs(1), self.global_peer_map_update_time.load() - Duration::from_secs(1),
), ),
View File
@@ -5,9 +5,13 @@
// peer center is not guaranteed to be stable and can be changed when peer enter or leave. // peer center is not guaranteed to be stable and can be changed when peer enter or leave.
// it's used to reduce the cost to exchange infos between peers. // it's used to reduce the cost to exchange infos between peers.
use std::collections::BTreeMap;
use crate::proto::cli::PeerInfo;
use crate::proto::peer_rpc::{DirectConnectedPeerInfo, PeerInfoForGlobalMap};
pub mod instance; pub mod instance;
mod server; mod server;
mod service;
#[derive(thiserror::Error, Debug, serde::Deserialize, serde::Serialize)] #[derive(thiserror::Error, Debug, serde::Deserialize, serde::Serialize)]
pub enum Error { pub enum Error {
@@ -18,3 +22,29 @@ pub enum Error {
} }
pub type Digest = u64; pub type Digest = u64;
impl From<Vec<PeerInfo>> for PeerInfoForGlobalMap {
fn from(peers: Vec<PeerInfo>) -> Self {
let mut peer_map = BTreeMap::new();
for peer in peers {
let Some(min_lat) = peer
.conns
.iter()
.map(|conn| conn.stats.as_ref().unwrap().latency_us)
.min()
else {
continue;
};
let dp_info = DirectConnectedPeerInfo {
latency_ms: std::cmp::max(1, (min_lat as u32 / 1000) as i32),
};
// sort conn info so hash result is stable
peer_map.insert(peer.peer_id, dp_info);
}
PeerInfoForGlobalMap {
direct_peers: peer_map,
}
}
}
View File
@@ -7,15 +7,22 @@ use std::{
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
use dashmap::DashMap; use dashmap::DashMap;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use tokio::{task::JoinSet}; use tokio::task::JoinSet;
use crate::{common::PeerId, rpc::DirectConnectedPeerInfo}; use crate::{
common::PeerId,
use super::{ proto::{
service::{GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterService, PeerInfoForGlobalMap}, peer_rpc::{
Digest, Error, DirectConnectedPeerInfo, GetGlobalPeerMapRequest, GetGlobalPeerMapResponse,
GlobalPeerMap, PeerCenterRpc, PeerInfoForGlobalMap, ReportPeersRequest,
ReportPeersResponse,
},
rpc_types::{self, controller::BaseController},
},
}; };
use super::Digest;
#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)] #[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
pub(crate) struct SrcDstPeerPair { pub(crate) struct SrcDstPeerPair {
src: PeerId, src: PeerId,
@@ -95,15 +102,19 @@ impl PeerCenterServer {
} }
} }
#[tarpc::server] #[async_trait::async_trait]
impl PeerCenterService for PeerCenterServer { impl PeerCenterRpc for PeerCenterServer {
type Controller = BaseController;
#[tracing::instrument()] #[tracing::instrument()]
async fn report_peers( async fn report_peers(
self, &self,
_: tarpc::context::Context, _: BaseController,
my_peer_id: PeerId, req: ReportPeersRequest,
peers: PeerInfoForGlobalMap, ) -> Result<ReportPeersResponse, rpc_types::error::Error> {
) -> Result<(), Error> { let my_peer_id = req.my_peer_id;
let peers = req.peer_infos.unwrap_or_default();
tracing::debug!("receive report_peers"); tracing::debug!("receive report_peers");
let data = get_global_data(self.my_node_id); let data = get_global_data(self.my_node_id);
@@ -125,20 +136,23 @@ impl PeerCenterService for PeerCenterServer {
data.digest data.digest
.store(PeerCenterServer::calc_global_digest(self.my_node_id)); .store(PeerCenterServer::calc_global_digest(self.my_node_id));
Ok(()) Ok(ReportPeersResponse::default())
} }
#[tracing::instrument()]
async fn get_global_peer_map( async fn get_global_peer_map(
self, &self,
_: tarpc::context::Context, _: BaseController,
digest: Digest, req: GetGlobalPeerMapRequest,
) -> Result<Option<GetGlobalPeerMapResponse>, Error> { ) -> Result<GetGlobalPeerMapResponse, rpc_types::error::Error> {
let digest = req.digest;
let data = get_global_data(self.my_node_id); let data = get_global_data(self.my_node_id);
if digest == data.digest.load() && digest != 0 { if digest == data.digest.load() && digest != 0 {
return Ok(None); return Ok(GetGlobalPeerMapResponse::default());
} }
let mut global_peer_map = GlobalPeerMap::new(); let mut global_peer_map = GlobalPeerMap::default();
for item in data.global_peer_map.iter() { for item in data.global_peer_map.iter() {
let (pair, entry) = item.pair(); let (pair, entry) = item.pair();
global_peer_map global_peer_map
@@ -151,9 +165,9 @@ impl PeerCenterService for PeerCenterServer {
.insert(pair.dst, entry.info.clone()); .insert(pair.dst, entry.info.clone());
} }
Ok(Some(GetGlobalPeerMapResponse { Ok(GetGlobalPeerMapResponse {
global_peer_map, global_peer_map: global_peer_map.map,
digest: data.digest.load(), digest: Some(data.digest.load()),
})) })
} }
} }
View File
@@ -1,64 +0,0 @@
use std::collections::BTreeMap;
use crate::{common::PeerId, rpc::DirectConnectedPeerInfo};
use super::{Digest, Error};
use crate::rpc::PeerInfo;
pub type PeerInfoForGlobalMap = crate::rpc::cli::PeerInfoForGlobalMap;
impl From<Vec<PeerInfo>> for PeerInfoForGlobalMap {
fn from(peers: Vec<PeerInfo>) -> Self {
let mut peer_map = BTreeMap::new();
for peer in peers {
let Some(min_lat) = peer
.conns
.iter()
.map(|conn| conn.stats.as_ref().unwrap().latency_us)
.min()
else {
continue;
};
let dp_info = DirectConnectedPeerInfo {
latency_ms: std::cmp::max(1, (min_lat as u32 / 1000) as i32),
};
// sort conn info so hash result is stable
peer_map.insert(peer.peer_id, dp_info);
}
PeerInfoForGlobalMap {
direct_peers: peer_map,
}
}
}
// a global peer topology map, peers can use it to find optimal path to other peers
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GlobalPeerMap {
pub map: BTreeMap<PeerId, PeerInfoForGlobalMap>,
}
impl GlobalPeerMap {
pub fn new() -> Self {
GlobalPeerMap {
map: BTreeMap::new(),
}
}
}
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct GetGlobalPeerMapResponse {
pub global_peer_map: GlobalPeerMap,
pub digest: Digest,
}
#[tarpc::service]
pub trait PeerCenterService {
// report center server which peer is directly connected to me
// digest is a hash of current peer map, if digest not match, we need to transfer the whole map
async fn report_peers(my_peer_id: PeerId, peers: PeerInfoForGlobalMap) -> Result<(), Error>;
async fn get_global_peer_map(digest: Digest)
-> Result<Option<GetGlobalPeerMapResponse>, Error>;
}
View File
@@ -1,27 +1,11 @@
use std::{ use std::sync::{Arc, Mutex};
sync::Arc,
time::{Duration, SystemTime},
};
use dashmap::DashMap;
use tokio::{sync::Mutex, task::JoinSet};
use crate::{ use crate::{
common::{ common::{error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, PeerId},
error::Error,
global_ctx::{ArcGlobalCtx, NetworkIdentity},
PeerId,
},
tunnel::packet_def::ZCPacket, tunnel::packet_def::ZCPacket,
}; };
use super::{ use super::{peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager, PacketRecvChan};
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::PeerRpcManager,
PacketRecvChan,
};
pub struct ForeignNetworkClient { pub struct ForeignNetworkClient {
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
@@ -29,9 +13,7 @@ pub struct ForeignNetworkClient {
my_peer_id: PeerId, my_peer_id: PeerId,
peer_map: Arc<PeerMap>, peer_map: Arc<PeerMap>,
task: Mutex<Option<ScopedTask<()>>>,
next_hop: Arc<DashMap<PeerId, PeerId>>,
tasks: Mutex<JoinSet<()>>,
} }
impl ForeignNetworkClient { impl ForeignNetworkClient {
@@ -46,17 +28,13 @@ impl ForeignNetworkClient {
global_ctx.clone(), global_ctx.clone(),
my_peer_id, my_peer_id,
)); ));
let next_hop = Arc::new(DashMap::new());
Self { Self {
global_ctx, global_ctx,
peer_rpc, peer_rpc,
my_peer_id, my_peer_id,
peer_map, peer_map,
task: Mutex::new(None),
next_hop,
tasks: Mutex::new(JoinSet::new()),
} }
} }
@@ -65,91 +43,19 @@ impl ForeignNetworkClient {
self.peer_map.add_new_peer_conn(peer_conn).await self.peer_map.add_new_peer_conn(peer_conn).await
} }
async fn collect_next_hop_in_foreign_network_task(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
next_hop: Arc<DashMap<PeerId, PeerId>>,
) {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
peer_map.clean_peer_without_conn().await;
let new_next_hop = Self::collect_next_hop_in_foreign_network(
network_identity.clone(),
peer_map.clone(),
peer_rpc.clone(),
)
.await;
next_hop.clear();
for (k, v) in new_next_hop.into_iter() {
next_hop.insert(k, v);
}
}
}
async fn collect_next_hop_in_foreign_network(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
) -> DashMap<PeerId, PeerId> {
let peers = peer_map.list_peers().await;
let mut tasks = JoinSet::new();
if !peers.is_empty() {
tracing::warn!(?peers, my_peer_id = ?peer_rpc.my_peer_id(), "collect next hop in foreign network");
}
for peer in peers {
let peer_rpc = peer_rpc.clone();
let network_identity = network_identity.clone();
tasks.spawn(async move {
let Ok(Some(peers_in_foreign)) = peer_rpc
.do_client_rpc_scoped(FOREIGN_NETWORK_SERVICE_ID, peer, |c| async {
let c =
ForeignNetworkServiceClient::new(tarpc::client::Config::default(), c)
.spawn();
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(2);
let ret = c.list_network_peers(rpc_ctx, network_identity).await;
ret
})
.await
else {
return (peer, vec![]);
};
(peer, peers_in_foreign)
});
}
let new_next_hop = DashMap::new();
while let Some(join_ret) = tasks.join_next().await {
let Ok((gateway, peer_ids)) = join_ret else {
tracing::error!(?join_ret, "collect next hop in foreign network failed");
continue;
};
for ret in peer_ids {
new_next_hop.insert(ret, gateway);
}
}
new_next_hop
}
pub fn has_next_hop(&self, peer_id: PeerId) -> bool { pub fn has_next_hop(&self, peer_id: PeerId) -> bool {
self.get_next_hop(peer_id).is_some() self.get_next_hop(peer_id).is_some()
} }
pub fn is_peer_public_node(&self, peer_id: &PeerId) -> bool { pub async fn list_public_peers(&self) -> Vec<PeerId> {
self.peer_map.has_peer(*peer_id) self.peer_map.list_peers().await
} }
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> { pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
if self.peer_map.has_peer(peer_id) { if self.peer_map.has_peer(peer_id) {
return Some(peer_id.clone()); return Some(peer_id.clone());
} }
self.next_hop.get(&peer_id).map(|v| v.clone()) None
} }
pub async fn send_msg(&self, msg: ZCPacket, peer_id: PeerId) -> Result<(), Error> { pub async fn send_msg(&self, msg: ZCPacket, peer_id: PeerId) -> Result<(), Error> {
@@ -162,40 +68,32 @@ impl ForeignNetworkClient {
?next_hop, ?next_hop,
"foreign network client send msg failed" "foreign network client send msg failed"
); );
} else {
tracing::info!(
?peer_id,
?next_hop,
"foreign network client send msg success"
);
} }
return ret; return ret;
} }
Err(Error::RouteError(Some("no next hop".to_string()))) Err(Error::RouteError(Some("no next hop".to_string())))
} }
pub fn list_foreign_peers(&self) -> Vec<PeerId> {
let mut peers = vec![];
for item in self.next_hop.iter() {
if item.key() != &self.my_peer_id {
peers.push(item.key().clone());
}
}
peers
}
pub async fn run(&self) { pub async fn run(&self) {
self.tasks let peer_map = Arc::downgrade(&self.peer_map);
.lock() *self.task.lock().unwrap() = Some(
.await tokio::spawn(async move {
.spawn(Self::collect_next_hop_in_foreign_network_task( loop {
self.global_ctx.get_network_identity(), tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
self.peer_map.clone(), let Some(peer_map) = peer_map.upgrade() else {
self.peer_rpc.clone(), break;
self.next_hop.clone(), };
)); peer_map.clean_peer_without_conn().await;
} }
})
pub fn get_next_hop_table(&self) -> DashMap<PeerId, PeerId> { .into(),
let next_hop = DashMap::new(); );
for item in self.next_hop.iter() {
next_hop.insert(item.key().clone(), item.value().clone());
}
next_hop
} }
pub fn get_peer_map(&self) -> Arc<PeerMap> { pub fn get_peer_map(&self) -> Arc<PeerMap> {
View File
@@ -5,12 +5,12 @@ only forward packets of peers that directly connected to this node.
in future, with the help of peer center we can forward packets of peers that in future, with the help of peer center we can forward packets of peers that
connected to any node in the local network. connected to any node in the local network.
*/ */
use std::sync::Arc; use std::sync::{Arc, Weak};
use dashmap::DashMap; use dashmap::DashMap;
use tokio::{ use tokio::{
sync::{ sync::{
mpsc::{self, unbounded_channel, UnboundedReceiver, UnboundedSender}, mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex, Mutex,
}, },
task::JoinSet, task::JoinSet,
@@ -18,26 +18,35 @@ use tokio::{
use crate::{ use crate::{
common::{ common::{
config::{ConfigLoader, TomlConfigLoader},
error::Error, error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity}, global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent, NetworkIdentity},
stun::MockStunInfoCollector,
PeerId, PeerId,
}, },
rpc::{ForeignNetworkEntryPb, ListForeignNetworkResponse, PeerInfo}, peers::route_trait::{Route, RouteInterface},
proto::{
cli::{ForeignNetworkEntryPb, ListForeignNetworkResponse, PeerInfo},
common::NatType,
},
tunnel::packet_def::{PacketType, ZCPacket}, tunnel::packet_def::{PacketType, ZCPacket},
}; };
use super::{ use super::{
peer_conn::PeerConn, peer_conn::PeerConn,
peer_map::PeerMap, peer_map::PeerMap,
peer_ospf_route::PeerRoute,
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport}, peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
route_trait::NextHopPolicy, route_trait::{ArcRoute, NextHopPolicy},
PacketRecvChan, PacketRecvChanReceiver, PacketRecvChan, PacketRecvChanReceiver,
}; };
struct ForeignNetworkEntry { struct ForeignNetworkEntry {
global_ctx: ArcGlobalCtx,
network: NetworkIdentity, network: NetworkIdentity,
peer_map: Arc<PeerMap>, peer_map: Arc<PeerMap>,
relay_data: bool, relay_data: bool,
route: ArcRoute,
} }
impl ForeignNetworkEntry { impl ForeignNetworkEntry {
@@ -47,19 +56,70 @@ impl ForeignNetworkEntry {
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
my_peer_id: PeerId, my_peer_id: PeerId,
relay_data: bool, relay_data: bool,
peer_rpc: Arc<PeerRpcManager>,
) -> Self { ) -> Self {
let peer_map = Arc::new(PeerMap::new(packet_sender, global_ctx, my_peer_id)); let config = TomlConfigLoader::default();
config.set_network_identity(network.clone());
config.set_hostname(Some(format!("PublicServer_{}", global_ctx.get_hostname())));
let foreign_global_ctx = Arc::new(GlobalCtx::new(config));
foreign_global_ctx.replace_stun_info_collector(Box::new(MockStunInfoCollector {
udp_nat_type: NatType::Unknown,
}));
let peer_map = Arc::new(PeerMap::new(
packet_sender,
foreign_global_ctx.clone(),
my_peer_id,
));
let route = PeerRoute::new(my_peer_id, foreign_global_ctx.clone(), peer_rpc);
Self { Self {
global_ctx: foreign_global_ctx,
network, network,
peer_map, peer_map,
relay_data, relay_data,
route: Arc::new(Box::new(route)),
} }
} }
async fn prepare(&self, my_peer_id: PeerId) {
struct Interface {
my_peer_id: PeerId,
peer_map: Weak<PeerMap>,
}
#[async_trait::async_trait]
impl RouteInterface for Interface {
async fn list_peers(&self) -> Vec<PeerId> {
let Some(peer_map) = self.peer_map.upgrade() else {
return vec![];
};
peer_map.list_peers_with_conn().await
}
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
}
self.route
.open(Box::new(Interface {
my_peer_id,
peer_map: Arc::downgrade(&self.peer_map),
}))
.await
.unwrap();
self.peer_map.add_route(self.route.clone()).await;
}
} }
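This is the core of the public-server rework described in the commit message: each foreign network entry now carries its own PeerRoute (the OSPF-style route) over an isolated GlobalCtx, which is what makes a mesh of public servers possible. A sketch of the route wiring inside prepare(), reassembled from the right-hand column above:

// Sketch only: the entry exposes its PeerMap to the OSPF route through a
// minimal RouteInterface, then registers the route on the map.
struct Interface {
    my_peer_id: PeerId,
    peer_map: Weak<PeerMap>,
}

#[async_trait::async_trait]
impl RouteInterface for Interface {
    async fn list_peers(&self) -> Vec<PeerId> {
        let Some(peer_map) = self.peer_map.upgrade() else {
            return vec![];
        };
        peer_map.list_peers_with_conn().await
    }

    fn my_peer_id(&self) -> PeerId {
        self.my_peer_id
    }
}

self.route
    .open(Box::new(Interface {
        my_peer_id,
        peer_map: Arc::downgrade(&self.peer_map),
    }))
    .await
    .unwrap();
self.peer_map.add_route(self.route.clone()).await;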
struct ForeignNetworkManagerData { struct ForeignNetworkManagerData {
network_peer_maps: DashMap<String, Arc<ForeignNetworkEntry>>, network_peer_maps: DashMap<String, Arc<ForeignNetworkEntry>>,
peer_network_map: DashMap<PeerId, String>, peer_network_map: DashMap<PeerId, String>,
lock: std::sync::Mutex<()>,
} }
impl ForeignNetworkManagerData { impl ForeignNetworkManagerData {
@@ -88,18 +148,27 @@ impl ForeignNetworkManagerData {
self.network_peer_maps.get(network_name).map(|v| v.clone()) self.network_peer_maps.get(network_name).map(|v| v.clone())
} }
fn remove_peer(&self, peer_id: PeerId) { fn remove_peer(&self, peer_id: PeerId, network_name: &String) {
let _l = self.lock.lock().unwrap();
self.peer_network_map.remove(&peer_id); self.peer_network_map.remove(&peer_id);
self.network_peer_maps.retain(|_, v| !v.peer_map.is_empty()); self.network_peer_maps
.remove_if(network_name, |_, v| v.peer_map.is_empty());
} }
fn clear_no_conn_peer(&self) { async fn clear_no_conn_peer(&self, network_name: &String) {
for item in self.network_peer_maps.iter() { let peer_map = self
let peer_map = item.value().peer_map.clone(); .network_peer_maps
tokio::spawn(async move { .get(network_name)
.unwrap()
.peer_map
.clone();
peer_map.clean_peer_without_conn().await; peer_map.clean_peer_without_conn().await;
});
} }
fn remove_network(&self, network_name: &String) {
let _l = self.lock.lock().unwrap();
self.peer_network_map.retain(|_, v| v != network_name);
self.network_peer_maps.remove(network_name);
} }
} }
@@ -117,11 +186,16 @@ impl PeerRpcManagerTransport for RpcTransport {
} }
async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
tracing::debug!(
"foreign network manager send rpc to peer: {:?}",
dst_peer_id
);
self.data.send_msg(msg, dst_peer_id).await self.data.send_msg(msg, dst_peer_id).await
} }
async fn recv(&self) -> Result<ZCPacket, Error> { async fn recv(&self) -> Result<ZCPacket, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await { if let Some(o) = self.packet_recv.lock().await.recv().await {
tracing::info!("recv rpc packet in foreign network manager rpc transport");
Ok(o) Ok(o)
} else { } else {
Err(Error::Unknown) Err(Error::Unknown)
@@ -131,23 +205,6 @@ impl PeerRpcManagerTransport for RpcTransport {
pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1; pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1;
#[tarpc::service]
pub trait ForeignNetworkService {
async fn list_network_peers(network_identy: NetworkIdentity) -> Option<Vec<PeerId>>;
}
#[tarpc::server]
impl ForeignNetworkService for Arc<ForeignNetworkManagerData> {
async fn list_network_peers(
self,
_: tarpc::context::Context,
network_identy: NetworkIdentity,
) -> Option<Vec<PeerId>> {
let entry = self.network_peer_maps.get(&network_identy.network_name)?;
Some(entry.peer_map.list_peers().await)
}
}
pub struct ForeignNetworkManager { pub struct ForeignNetworkManager {
my_peer_id: PeerId, my_peer_id: PeerId,
global_ctx: ArcGlobalCtx, global_ctx: ArcGlobalCtx,
@@ -175,6 +232,7 @@ impl ForeignNetworkManager {
let data = Arc::new(ForeignNetworkManagerData { let data = Arc::new(ForeignNetworkManagerData {
network_peer_maps: DashMap::new(), network_peer_maps: DashMap::new(),
peer_network_map: DashMap::new(), peer_network_map: DashMap::new(),
lock: std::sync::Mutex::new(()),
}); });
// handle rpc from foreign networks // handle rpc from foreign networks
@@ -225,17 +283,23 @@ impl ForeignNetworkManager {
return ret; return ret;
} }
let mut new_added = false;
let entry = {
let _l = self.data.lock.lock().unwrap();
let entry = self let entry = self
.data .data
.network_peer_maps .network_peer_maps
.entry(peer_conn.get_network_identity().network_name.clone()) .entry(peer_conn.get_network_identity().network_name.clone())
.or_insert_with(|| { .or_insert_with(|| {
new_added = true;
Arc::new(ForeignNetworkEntry::new( Arc::new(ForeignNetworkEntry::new(
peer_conn.get_network_identity(), peer_conn.get_network_identity(),
self.packet_sender.clone(), self.packet_sender.clone(),
self.global_ctx.clone(), self.global_ctx.clone(),
self.my_peer_id, self.my_peer_id,
!ret.is_err(), !ret.is_err(),
self.rpc_mgr.clone(),
)) ))
}) })
.clone(); .clone();
@@ -245,6 +309,14 @@ impl ForeignNetworkManager {
peer_conn.get_network_identity().network_name.clone(), peer_conn.get_network_identity().network_name.clone(),
); );
entry
};
if new_added {
entry.prepare(self.my_peer_id).await;
self.start_event_handler(&entry).await;
}
if entry.network != peer_conn.get_network_identity() { if entry.network != peer_conn.get_network_identity() {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"network secret not match. exp: {:?} real: {:?}", "network secret not match. exp: {:?} real: {:?}",
@@ -257,28 +329,26 @@ impl ForeignNetworkManager {
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await) Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
} }
async fn start_global_event_handler(&self) { async fn start_event_handler(&self, entry: &ForeignNetworkEntry) {
let data = self.data.clone(); let data = self.data.clone();
let mut s = self.global_ctx.subscribe(); let network_name = entry.network.network_name.clone();
let (ev_tx, mut ev_rx) = unbounded_channel(); let mut s = entry.global_ctx.subscribe();
self.tasks.lock().await.spawn(async move { self.tasks.lock().await.spawn(async move {
while let Ok(e) = s.recv().await { while let Ok(e) = s.recv().await {
ev_tx.send(e).unwrap();
}
panic!("global event handler at foreign network manager exit");
});
self.tasks.lock().await.spawn(async move {
while let Some(e) = ev_rx.recv().await {
if let GlobalCtxEvent::PeerRemoved(peer_id) = &e { if let GlobalCtxEvent::PeerRemoved(peer_id) = &e {
tracing::info!(?e, "remove peer from foreign network manager"); tracing::info!(?e, "remove peer from foreign network manager");
data.remove_peer(*peer_id); data.remove_peer(*peer_id, &network_name);
} else if let GlobalCtxEvent::PeerConnRemoved(..) = &e { } else if let GlobalCtxEvent::PeerConnRemoved(..) = &e {
tracing::info!(?e, "clear no conn peer from foreign network manager"); tracing::info!(?e, "clear no conn peer from foreign network manager");
data.clear_no_conn_peer(); data.clear_no_conn_peer(&network_name).await;
} }
} }
// if lagged or recv done just remove the network
tracing::error!("global event handler at foreign network manager exit");
data.remove_network(&network_name);
}); });
self.tasks.lock().await.spawn(async move {});
} }
async fn start_packet_recv(&self) { async fn start_packet_recv(&self) {
@@ -294,10 +364,14 @@ impl ForeignNetworkManager {
tracing::warn!("invalid packet, skip"); tracing::warn!("invalid packet, skip");
continue; continue;
}; };
tracing::info!(?hdr, "recv packet in foreign network manager");
let from_peer_id = hdr.from_peer_id.get(); let from_peer_id = hdr.from_peer_id.get();
let to_peer_id = hdr.to_peer_id.get(); let to_peer_id = hdr.to_peer_id.get();
if to_peer_id == my_node_id { if to_peer_id == my_node_id {
if hdr.packet_type == PacketType::TaRpc as u8 { if hdr.packet_type == PacketType::TaRpc as u8
|| hdr.packet_type == PacketType::RpcReq as u8
|| hdr.packet_type == PacketType::RpcResp as u8
{
rpc_sender.send(packet_bytes).unwrap(); rpc_sender.send(packet_bytes).unwrap();
continue; continue;
} }
@@ -335,16 +409,9 @@ impl ForeignNetworkManager {
}); });
} }
async fn register_peer_rpc_service(&self) {
self.rpc_mgr.run();
self.rpc_mgr
.run_service(FOREIGN_NETWORK_SERVICE_ID, self.data.clone().serve())
}
pub async fn run(&self) { pub async fn run(&self) {
self.start_global_event_handler().await;
self.start_packet_recv().await; self.start_packet_recv().await;
self.register_peer_rpc_service().await; self.rpc_mgr.run();
} }
pub async fn list_foreign_networks(&self) -> ListForeignNetworkResponse { pub async fn list_foreign_networks(&self) -> ListForeignNetworkResponse {
@@ -380,8 +447,17 @@ impl ForeignNetworkManager {
} }
} }
impl Drop for ForeignNetworkManager {
fn drop(&mut self) {
self.data.peer_network_map.clear();
self.data.network_peer_maps.clear();
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration;
use crate::{ use crate::{
common::global_ctx::tests::get_mock_global_ctx_with_network, common::global_ctx::tests::get_mock_global_ctx_with_network,
connector::udp_hole_punch::tests::{ connector::udp_hole_punch::tests::{
@@ -391,7 +467,8 @@ mod tests {
peer_manager::{PeerManager, RouteAlgoType}, peer_manager::{PeerManager, RouteAlgoType},
tests::{connect_peer_manager, wait_route_appear}, tests::{connect_peer_manager, wait_route_appear},
}, },
rpc::NatType, proto::common::NatType,
tunnel::common::tests::wait_for_condition,
}; };
use super::*; use super::*;
@@ -413,7 +490,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn foreign_network_basic() { async fn foreign_network_basic() {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await; let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
tracing::debug!("pm_center: {:?}", pm_center.my_peer_id()); tracing::debug!("pm_center: {:?}", pm_center.my_peer_id());
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await; let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
@@ -428,8 +505,10 @@ mod tests {
wait_route_appear(pma_net1.clone(), pmb_net1.clone()) wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await .await
.unwrap(); .unwrap();
assert_eq!(1, pma_net1.list_routes().await.len()); assert_eq!(2, pma_net1.list_routes().await.len());
assert_eq!(1, pmb_net1.list_routes().await.len()); assert_eq!(2, pmb_net1.list_routes().await.len());
println!("{:?}", pmb_net1.list_routes().await);
let rpc_resp = pm_center let rpc_resp = pm_center
.get_foreign_network_manager() .get_foreign_network_manager()
@@ -440,7 +519,7 @@ mod tests {
} }
async fn foreign_network_whitelist_helper(name: String) { async fn foreign_network_whitelist_helper(name: String) {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await; let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
tracing::debug!("pm_center: {:?}", pm_center.my_peer_id()); tracing::debug!("pm_center: {:?}", pm_center.my_peer_id());
let mut flag = pm_center.get_global_ctx().get_flags(); let mut flag = pm_center.get_global_ctx().get_flags();
flag.foreign_network_whitelist = vec!["net1".to_string(), "net2*".to_string()].join(" "); flag.foreign_network_whitelist = vec!["net1".to_string(), "net2*".to_string()].join(" ");
@@ -466,7 +545,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn only_relay_peer_rpc() { async fn only_relay_peer_rpc() {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await; let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let mut flag = pm_center.get_global_ctx().get_flags(); let mut flag = pm_center.get_global_ctx().get_flags();
flag.foreign_network_whitelist = "".to_string(); flag.foreign_network_whitelist = "".to_string();
flag.relay_all_peer_rpc = true; flag.relay_all_peer_rpc = true;
@@ -485,8 +564,8 @@ mod tests {
wait_route_appear(pma_net1.clone(), pmb_net1.clone()) wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await .await
.unwrap(); .unwrap();
assert_eq!(1, pma_net1.list_routes().await.len()); assert_eq!(2, pma_net1.list_routes().await.len());
assert_eq!(1, pmb_net1.list_routes().await.len()); assert_eq!(2, pmb_net1.list_routes().await.len());
} }
#[tokio::test] #[tokio::test]
@@ -497,9 +576,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_foreign_network_manager() { async fn test_foreign_network_manager() {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await; let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let pm_center2 = let pm_center2 = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
connect_peer_manager(pm_center.clone(), pm_center2.clone()).await; connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;
tracing::debug!( tracing::debug!(
@@ -519,17 +597,9 @@ mod tests {
pmb_net1.my_peer_id() pmb_net1.my_peer_id()
); );
let now = std::time::Instant::now(); wait_route_appear(pma_net1.clone(), pmb_net1.clone())
let mut succ = false; .await
while now.elapsed().as_secs() < 10 { .unwrap();
let table = pma_net1.get_foreign_network_client().get_next_hop_table();
if table.len() >= 1 {
succ = true;
break;
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
assert!(succ);
assert_eq!( assert_eq!(
vec![pm_center.my_peer_id()], vec![pm_center.my_peer_id()],
@@ -547,11 +617,9 @@ mod tests {
.list_peers() .list_peers()
.await .await
); );
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await assert_eq!(2, pma_net1.list_routes().await.len());
.unwrap(); assert_eq!(2, pmb_net1.list_routes().await.len());
assert_eq!(1, pma_net1.list_routes().await.len());
assert_eq!(1, pmb_net1.list_routes().await.len());
let pmc_net1 = create_mock_peer_manager_for_foreign_network("net1").await; let pmc_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
connect_peer_manager(pmc_net1.clone(), pm_center.clone()).await; connect_peer_manager(pmc_net1.clone(), pm_center.clone()).await;
@@ -561,7 +629,7 @@ mod tests {
wait_route_appear(pmb_net1.clone(), pmc_net1.clone()) wait_route_appear(pmb_net1.clone(), pmc_net1.clone())
.await .await
.unwrap(); .unwrap();
assert_eq!(2, pmc_net1.list_routes().await.len()); assert_eq!(3, pmc_net1.list_routes().await.len());
tracing::debug!("pmc_net1: {:?}", pmc_net1.my_peer_id()); tracing::debug!("pmc_net1: {:?}", pmc_net1.my_peer_id());
@@ -577,8 +645,8 @@ mod tests {
wait_route_appear(pma_net2.clone(), pmb_net2.clone()) wait_route_appear(pma_net2.clone(), pmb_net2.clone())
.await .await
.unwrap(); .unwrap();
assert_eq!(1, pma_net2.list_routes().await.len()); assert_eq!(2, pma_net2.list_routes().await.len());
assert_eq!(1, pmb_net2.list_routes().await.len()); assert_eq!(2, pmb_net2.list_routes().await.len());
assert_eq!( assert_eq!(
5, 5,
@@ -635,4 +703,27 @@ mod tests {
.len() .len()
); );
} }
#[tokio::test]
async fn test_disconnect_foreign_network() {
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
tracing::debug!("pm_center: {:?}", pm_center.my_peer_id());
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
tracing::debug!("pma_net1: {:?}", pma_net1.my_peer_id(),);
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
wait_for_condition(
|| async { pma_net1.list_routes().await.len() == 1 },
Duration::from_secs(5),
)
.await;
drop(pm_center);
wait_for_condition(
|| async { pma_net1.list_routes().await.len() == 0 },
Duration::from_secs(5),
)
.await;
}
} }


@@ -5,7 +5,6 @@ pub mod peer_conn_ping;
pub mod peer_manager; pub mod peer_manager;
pub mod peer_map; pub mod peer_map;
pub mod peer_ospf_route; pub mod peer_ospf_route;
pub mod peer_rip_route;
pub mod peer_rpc; pub mod peer_rpc;
pub mod route_trait; pub mod route_trait;
pub mod rpc_service; pub mod rpc_service;


@@ -11,7 +11,7 @@ use super::{
peer_conn::{PeerConn, PeerConnId}, peer_conn::{PeerConn, PeerConnId},
PacketRecvChan, PacketRecvChan,
}; };
use crate::rpc::PeerConnInfo; use crate::proto::cli::PeerConnInfo;
use crate::{ use crate::{
common::{ common::{
error::Error, error::Error,


@@ -29,8 +29,18 @@ use crate::{
global_ctx::ArcGlobalCtx, global_ctx::ArcGlobalCtx,
PeerId, PeerId,
}, },
rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo}, proto::{
tunnel::{filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter}, mpsc::{MpscTunnel, MpscTunnelSender}, packet_def::{PacketType, ZCPacket}, stats::{Throughput, WindowLatency}, Tunnel, TunnelError, ZCPacketStream}, cli::{PeerConnInfo, PeerConnStats},
common::TunnelInfo,
peer_rpc::HandshakeRequest,
},
tunnel::{
filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter},
mpsc::{MpscTunnel, MpscTunnelSender},
packet_def::{PacketType, ZCPacket},
stats::{Throughput, WindowLatency},
Tunnel, TunnelError, ZCPacketStream,
},
}; };
use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan}; use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan};


@@ -17,7 +17,6 @@ use tokio::{
task::JoinSet, task::JoinSet,
}; };
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tokio_util::bytes::Bytes;
use crate::{ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId}, common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
@@ -27,6 +26,7 @@ use crate::{
route_trait::{NextHopPolicy, RouteInterface}, route_trait::{NextHopPolicy, RouteInterface},
PeerPacketFilter, PeerPacketFilter,
}, },
proto::cli,
tunnel::{ tunnel::{
self, self,
packet_def::{PacketType, ZCPacket}, packet_def::{PacketType, ZCPacket},
@@ -41,7 +41,6 @@ use super::{
peer_conn::PeerConnId, peer_conn::PeerConnId,
peer_map::PeerMap, peer_map::PeerMap,
peer_ospf_route::PeerRoute, peer_ospf_route::PeerRoute,
peer_rip_route::BasicRoute,
peer_rpc::PeerRpcManager, peer_rpc::PeerRpcManager,
route_trait::{ArcRoute, Route}, route_trait::{ArcRoute, Route},
BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChanReceiver, BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChanReceiver,
@@ -75,7 +74,15 @@ impl PeerRpcManagerTransport for RpcTransport {
.ok_or(Error::Unknown)?; .ok_or(Error::Unknown)?;
let peers = self.peers.upgrade().ok_or(Error::Unknown)?; let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
if let Some(gateway_id) = peers if foreign_peers.has_next_hop(dst_peer_id) {
// do not encrypt data sent to the public server
tracing::debug!(
?dst_peer_id,
?self.my_peer_id,
"failed to send msg to peer, try foreign network",
);
foreign_peers.send_msg(msg, dst_peer_id).await
} else if let Some(gateway_id) = peers
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop) .get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
.await .await
{ {
@@ -88,20 +95,11 @@ impl PeerRpcManagerTransport for RpcTransport {
self.encryptor self.encryptor
.encrypt(&mut msg) .encrypt(&mut msg)
.with_context(|| "encrypt failed")?; .with_context(|| "encrypt failed")?;
if peers.has_peer(gateway_id) {
peers.send_msg_directly(msg, gateway_id).await peers.send_msg_directly(msg, gateway_id).await
} else if foreign_peers.has_next_hop(dst_peer_id) { } else {
if !foreign_peers.is_peer_public_node(&dst_peer_id) { foreign_peers.send_msg(msg, gateway_id).await
// do not encrypt for msg sending to public node
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
} }
tracing::debug!(
?dst_peer_id,
?self.my_peer_id,
"failed to send msg to peer, try foreign network",
);
foreign_peers.send_msg(msg, dst_peer_id).await
} else { } else {
Err(Error::RouteError(Some(format!( Err(Error::RouteError(Some(format!(
"peermgr RpcTransport no route for dst_peer_id: {}", "peermgr RpcTransport no route for dst_peer_id: {}",
@@ -120,13 +118,11 @@ impl PeerRpcManagerTransport for RpcTransport {
} }
pub enum RouteAlgoType { pub enum RouteAlgoType {
Rip,
Ospf, Ospf,
None, None,
} }
enum RouteAlgoInst { enum RouteAlgoInst {
Rip(Arc<BasicRoute>),
Ospf(Arc<PeerRoute>), Ospf(Arc<PeerRoute>),
None, None,
} }
@@ -217,9 +213,6 @@ impl PeerManager {
let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone())); let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
let route_algo_inst = match route_algo { let route_algo_inst = match route_algo {
RouteAlgoType::Rip => {
RouteAlgoInst::Rip(Arc::new(BasicRoute::new(my_peer_id, global_ctx.clone())))
}
RouteAlgoType::Ospf => RouteAlgoInst::Ospf(PeerRoute::new( RouteAlgoType::Ospf => RouteAlgoInst::Ospf(PeerRoute::new(
my_peer_id, my_peer_id,
global_ctx.clone(), global_ctx.clone(),
@@ -438,7 +431,10 @@ impl PeerManager {
impl PeerPacketFilter for PeerRpcPacketProcessor { impl PeerPacketFilter for PeerRpcPacketProcessor {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> { async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
let hdr = packet.peer_manager_header().unwrap(); let hdr = packet.peer_manager_header().unwrap();
if hdr.packet_type == PacketType::TaRpc as u8 { if hdr.packet_type == PacketType::TaRpc as u8
|| hdr.packet_type == PacketType::RpcReq as u8
|| hdr.packet_type == PacketType::RpcResp as u8
{
self.peer_rpc_tspt_sender.send(packet).unwrap(); self.peer_rpc_tspt_sender.send(packet).unwrap();
None None
} else { } else {
@@ -477,33 +473,11 @@ impl PeerManager {
return vec![]; return vec![];
}; };
let mut peers = foreign_client.list_foreign_peers(); let mut peers = foreign_client.list_public_peers().await;
peers.extend(peer_map.list_peers_with_conn().await); peers.extend(peer_map.list_peers_with_conn().await);
peers peers
} }
async fn send_route_packet(
&self,
msg: Bytes,
_route_id: u8,
dst_peer_id: PeerId,
) -> Result<(), Error> {
let foreign_client = self
.foreign_network_client
.upgrade()
.ok_or(Error::Unknown)?;
let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
let mut zc_packet = ZCPacket::new_with_payload(&msg);
zc_packet.fill_peer_manager_hdr(
self.my_peer_id,
dst_peer_id,
PacketType::Route as u8,
);
if foreign_client.has_next_hop(dst_peer_id) {
foreign_client.send_msg(zc_packet, dst_peer_id).await
} else {
peer_map.send_msg_directly(zc_packet, dst_peer_id).await
}
}
fn my_peer_id(&self) -> PeerId { fn my_peer_id(&self) -> PeerId {
self.my_peer_id self.my_peer_id
} }
@@ -525,13 +499,12 @@ impl PeerManager {
pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> { pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> {
match &self.route_algo_inst { match &self.route_algo_inst {
RouteAlgoInst::Rip(route) => Box::new(route.clone()),
RouteAlgoInst::Ospf(route) => Box::new(route.clone()), RouteAlgoInst::Ospf(route) => Box::new(route.clone()),
RouteAlgoInst::None => panic!("no route"), RouteAlgoInst::None => panic!("no route"),
} }
} }
pub async fn list_routes(&self) -> Vec<crate::rpc::Route> { pub async fn list_routes(&self) -> Vec<cli::Route> {
self.get_route().list_routes().await self.get_route().list_routes().await
} }
@@ -649,13 +622,23 @@ impl PeerManager {
.get_gateway_peer_id(*peer_id, next_hop_policy.clone()) .get_gateway_peer_id(*peer_id, next_hop_policy.clone())
.await .await
{ {
if self.peers.has_peer(gateway) {
if let Err(e) = self.peers.send_msg_directly(msg, gateway).await { if let Err(e) = self.peers.send_msg_directly(msg, gateway).await {
errs.push(e); errs.push(e);
} }
} else if self.foreign_network_client.has_next_hop(*peer_id) { } else if self.foreign_network_client.has_next_hop(gateway) {
if let Err(e) = self.foreign_network_client.send_msg(msg, *peer_id).await { if let Err(e) = self.foreign_network_client.send_msg(msg, gateway).await {
errs.push(e); errs.push(e);
} }
} else {
tracing::warn!(
?gateway,
?peer_id,
"cannot send msg to peer through gateway"
);
}
} else {
tracing::debug!(?peer_id, "no gateway for peer");
} }
} }
@@ -693,7 +676,6 @@ impl PeerManager {
pub async fn run(&self) -> Result<(), Error> { pub async fn run(&self) -> Result<(), Error> {
match &self.route_algo_inst { match &self.route_algo_inst {
RouteAlgoInst::Ospf(route) => self.add_route(route.clone()).await, RouteAlgoInst::Ospf(route) => self.add_route(route.clone()).await,
RouteAlgoInst::Rip(route) => self.add_route(route.clone()).await,
RouteAlgoInst::None => {} RouteAlgoInst::None => {}
}; };
@@ -732,13 +714,6 @@ impl PeerManager {
self.nic_channel.clone() self.nic_channel.clone()
} }
pub fn get_basic_route(&self) -> Arc<BasicRoute> {
match &self.route_algo_inst {
RouteAlgoInst::Rip(route) => route.clone(),
_ => panic!("not rip route"),
}
}
pub fn get_foreign_network_manager(&self) -> Arc<ForeignNetworkManager> { pub fn get_foreign_network_manager(&self) -> Arc<ForeignNetworkManager> {
self.foreign_network_manager.clone() self.foreign_network_manager.clone()
} }
@@ -747,8 +722,8 @@ impl PeerManager {
self.foreign_network_client.clone() self.foreign_network_client.clone()
} }
pub fn get_my_info(&self) -> crate::rpc::NodeInfo { pub fn get_my_info(&self) -> cli::NodeInfo {
crate::rpc::NodeInfo { cli::NodeInfo {
peer_id: self.my_peer_id, peer_id: self.my_peer_id,
ipv4_addr: self ipv4_addr: self
.global_ctx .global_ctx
@@ -774,6 +749,12 @@ impl PeerManager {
version: env!("CARGO_PKG_VERSION").to_string(), version: env!("CARGO_PKG_VERSION").to_string(),
} }
} }
pub async fn wait(&self) {
while !self.tasks.lock().await.is_empty() {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
} }
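
The new wait() helper above simply polls the internal JoinSet until every background task has finished. A minimal usage sketch, assuming a caller that owns an Arc<PeerManager> (the wrapper function is illustrative, not part of the patch):

```rust
// Sketch only: block until all of the peer manager's background tasks exit.
async fn run_until_done(
    peer_mgr: std::sync::Arc<PeerManager>,
) -> Result<(), crate::common::error::Error> {
    peer_mgr.run().await?; // start routing, packet processing and rpc
    peer_mgr.wait().await; // returns once tasks.lock().await.is_empty()
    Ok(())
}
```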
#[cfg(test)] #[cfg(test)]
@@ -789,12 +770,11 @@ mod tests {
instance::listeners::get_listener_by_url, instance::listeners::get_listener_by_url,
peers::{ peers::{
peer_manager::RouteAlgoType, peer_manager::RouteAlgoType,
peer_rpc::tests::{MockService, TestRpcService, TestRpcServiceClient}, peer_rpc::tests::register_service,
tests::{connect_peer_manager, wait_route_appear}, tests::{connect_peer_manager, wait_route_appear},
}, },
rpc::NatType, proto::common::NatType,
tunnel::common::tests::wait_for_condition, tunnel::{common::tests::wait_for_condition, TunnelConnector, TunnelListener},
tunnel::{TunnelConnector, TunnelListener},
}; };
use super::PeerManager; use super::PeerManager;
@@ -857,25 +837,18 @@ mod tests {
#[values("tcp", "udp", "wg", "quic")] proto1: &str, #[values("tcp", "udp", "wg", "quic")] proto1: &str,
#[values("tcp", "udp", "wg", "quic")] proto2: &str, #[values("tcp", "udp", "wg", "quic")] proto2: &str,
) { ) {
use crate::proto::{
rpc_impl::RpcController,
tests::{GreetingClientFactory, SayHelloRequest},
};
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_a.get_peer_rpc_mgr().run_service( register_service(&peer_mgr_a.peer_rpc_mgr, "", 0, "hello a");
100,
MockService {
prefix: "hello a".to_owned(),
}
.serve(),
);
let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_c.get_peer_rpc_mgr().run_service( register_service(&peer_mgr_c.peer_rpc_mgr, "", 0, "hello c");
100,
MockService {
prefix: "hello c".to_owned(),
}
.serve(),
);
let mut listener1 = get_listener_by_url( let mut listener1 = get_listener_by_url(
&format!("{}://0.0.0.0:31013", proto1).parse().unwrap(), &format!("{}://0.0.0.0:31013", proto1).parse().unwrap(),
@@ -913,16 +886,26 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let ret = peer_mgr_a let stub = peer_mgr_a
.get_peer_rpc_mgr() .peer_rpc_mgr
.do_client_rpc_scoped(100, peer_mgr_c.my_peer_id(), |c| async { .rpc_client()
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); .scoped_client::<GreetingClientFactory<RpcController>>(
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await; peer_mgr_a.my_peer_id,
ret peer_mgr_c.my_peer_id,
}) "".to_string(),
);
let ret = stub
.say_hello(
RpcController {},
SayHelloRequest {
name: "abc".to_string(),
},
)
.await .await
.unwrap(); .unwrap();
assert_eq!(ret, "hello c abc");
assert_eq!(ret.greeting, "hello c abc!");
} }
#[tokio::test] #[tokio::test]
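
The rewritten test above replaces tarpc's do_client_rpc_scoped with a typed stub from the new rpc framework. A minimal sketch of the client-side pattern, using the GreetingClientFactory/SayHelloRequest test types and RpcController from the diff; the wrapper function and its arguments are illustrative:

```rust
// Sketch only: calling a remote service through the new rpc framework,
// modeled on the test above. The empty string is the network-name scope.
use crate::common::PeerId;
use crate::proto::{
    rpc_impl::RpcController,
    tests::{GreetingClientFactory, SayHelloRequest},
};

async fn greet(peer_mgr: &PeerManager, dst_peer_id: PeerId) -> String {
    let stub = peer_mgr
        .get_peer_rpc_mgr()
        .rpc_client()
        .scoped_client::<GreetingClientFactory<RpcController>>(
            peer_mgr.my_peer_id(),
            dst_peer_id,
            "".to_string(),
        );

    stub.say_hello(
        RpcController {},
        SayHelloRequest { name: "abc".to_string() },
    )
    .await
    .unwrap()
    .greeting
}
```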


@@ -10,7 +10,7 @@ use crate::{
global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
PeerId, PeerId,
}, },
rpc::PeerConnInfo, proto::cli::PeerConnInfo,
tunnel::packet_def::ZCPacket, tunnel::packet_def::ZCPacket,
tunnel::TunnelError, tunnel::TunnelError,
}; };
@@ -66,7 +66,7 @@ impl PeerMap {
} }
pub fn has_peer(&self, peer_id: PeerId) -> bool { pub fn has_peer(&self, peer_id: PeerId) -> bool {
self.peer_map.contains_key(&peer_id) peer_id == self.my_peer_id || self.peer_map.contains_key(&peer_id)
} }
pub async fn send_msg_directly(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { pub async fn send_msg_directly(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
@@ -113,12 +113,10 @@ impl PeerMap {
.get_next_hop_with_policy(dst_peer_id, policy.clone()) .get_next_hop_with_policy(dst_peer_id, policy.clone())
.await .await
{ {
// for foreign network, gateway_peer_id may not connect to me // NOTICE: for foreign network, gateway_peer_id may not be directly connected to me
if self.has_peer(gateway_peer_id) {
return Some(gateway_peer_id); return Some(gateway_peer_id);
} }
} }
}
None None
} }


@@ -25,7 +25,18 @@ use tokio::{
use crate::{ use crate::{
common::{global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId}, common::{global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
peers::route_trait::{Route, RouteInterfaceBox}, peers::route_trait::{Route, RouteInterfaceBox},
rpc::{NatType, StunInfo}, proto::common::{NatType, StunInfo},
proto::{
peer_rpc::{
OspfRouteRpc, OspfRouteRpcClientFactory, OspfRouteRpcServer, PeerIdVersion,
RoutePeerInfo, RoutePeerInfos, SyncRouteInfoError, SyncRouteInfoRequest,
SyncRouteInfoResponse,
},
rpc_types::{
self,
controller::{BaseController, Controller},
},
},
}; };
use super::{ use super::{
@@ -76,31 +87,17 @@ impl From<Version> for AtomicVersion {
} }
} }
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
struct RoutePeerInfo {
// means next hop in route table.
peer_id: PeerId,
inst_id: uuid::Uuid,
cost: u8,
ipv4_addr: Option<Ipv4Addr>,
proxy_cidrs: Vec<String>,
hostname: Option<String>,
udp_stun_info: i8,
last_update: SystemTime,
version: Version,
}
impl RoutePeerInfo { impl RoutePeerInfo {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
peer_id: 0, peer_id: 0,
inst_id: uuid::Uuid::nil(), inst_id: Some(uuid::Uuid::nil().into()),
cost: 0, cost: 0,
ipv4_addr: None, ipv4_addr: None,
proxy_cidrs: Vec::new(), proxy_cidrs: Vec::new(),
hostname: None, hostname: None,
udp_stun_info: 0, udp_stun_info: 0,
last_update: SystemTime::now(), last_update: Some(SystemTime::now().into()),
version: 0, version: 0,
} }
} }
@@ -108,9 +105,9 @@ impl RoutePeerInfo {
pub fn update_self(&self, my_peer_id: PeerId, global_ctx: &ArcGlobalCtx) -> Self { pub fn update_self(&self, my_peer_id: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
let mut new = Self { let mut new = Self {
peer_id: my_peer_id, peer_id: my_peer_id,
inst_id: global_ctx.get_id(), inst_id: Some(global_ctx.get_id().into()),
cost: 0, cost: 0,
ipv4_addr: global_ctx.get_ipv4(), ipv4_addr: global_ctx.get_ipv4().map(|x| x.into()),
proxy_cidrs: global_ctx proxy_cidrs: global_ctx
.get_proxy_cidrs() .get_proxy_cidrs()
.iter() .iter()
@@ -121,20 +118,22 @@ impl RoutePeerInfo {
udp_stun_info: global_ctx udp_stun_info: global_ctx
.get_stun_info_collector() .get_stun_info_collector()
.get_stun_info() .get_stun_info()
.udp_nat_type as i8, .udp_nat_type,
// following fields do not participate in comparison. // following fields do not participate in comparison.
last_update: self.last_update, last_update: self.last_update,
version: self.version, version: self.version,
}; };
let need_update_periodically = if let Ok(d) = new.last_update.elapsed() { let need_update_periodically = if let Ok(Ok(d)) =
SystemTime::try_from(new.last_update.unwrap()).map(|x| x.elapsed())
{
d > UPDATE_PEER_INFO_PERIOD d > UPDATE_PEER_INFO_PERIOD
} else { } else {
true true
}; };
if new != *self || need_update_periodically { if new != *self || need_update_periodically {
new.last_update = SystemTime::now(); new.last_update = Some(SystemTime::now().into());
new.version += 1; new.version += 1;
} }
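
With RoutePeerInfo now generated from protobuf, last_update becomes an Option of the generated timestamp type and has to be converted back to SystemTime before the elapsed-time check above. A minimal sketch of that conversion pattern, assuming only that the timestamp type converts into SystemTime via TryFrom as the diff's SystemTime::try_from call implies (the helper itself is hypothetical):

```rust
// Sketch only: "is this peer info older than `period`?", mirroring the
// Ok(Ok(d)) pattern above; conversion or clock errors count as "too old".
use std::time::{Duration, SystemTime};

fn older_than<T>(last_update: Option<T>, period: Duration) -> bool
where
    SystemTime: TryFrom<T>,
{
    let Some(ts) = last_update else { return true };
    match SystemTime::try_from(ts).map(|t| t.elapsed()) {
        Ok(Ok(elapsed)) => elapsed > period,
        _ => true,
    }
}
```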
@@ -142,9 +141,9 @@ impl RoutePeerInfo {
} }
} }
impl Into<crate::rpc::Route> for RoutePeerInfo { impl Into<crate::proto::cli::Route> for RoutePeerInfo {
fn into(self) -> crate::rpc::Route { fn into(self) -> crate::proto::cli::Route {
crate::rpc::Route { crate::proto::cli::Route {
peer_id: self.peer_id, peer_id: self.peer_id,
ipv4_addr: if let Some(ipv4_addr) = self.ipv4_addr { ipv4_addr: if let Some(ipv4_addr) = self.ipv4_addr {
ipv4_addr.to_string() ipv4_addr.to_string()
@@ -162,7 +161,7 @@ impl Into<crate::rpc::Route> for RoutePeerInfo {
} }
Some(stun_info) Some(stun_info)
}, },
inst_id: self.inst_id.to_string(), inst_id: self.inst_id.map(|x| x.to_string()).unwrap_or_default(),
version: env!("CARGO_PKG_VERSION").to_string(), version: env!("CARGO_PKG_VERSION").to_string(),
} }
} }
@@ -174,6 +173,35 @@ struct RouteConnBitmap {
bitmap: Vec<u8>, bitmap: Vec<u8>,
} }
impl Into<crate::proto::peer_rpc::RouteConnBitmap> for RouteConnBitmap {
fn into(self) -> crate::proto::peer_rpc::RouteConnBitmap {
crate::proto::peer_rpc::RouteConnBitmap {
peer_ids: self
.peer_ids
.into_iter()
.map(|x| PeerIdVersion {
peer_id: x.0,
version: x.1,
})
.collect(),
bitmap: self.bitmap,
}
}
}
impl From<crate::proto::peer_rpc::RouteConnBitmap> for RouteConnBitmap {
fn from(v: crate::proto::peer_rpc::RouteConnBitmap) -> Self {
RouteConnBitmap {
peer_ids: v
.peer_ids
.into_iter()
.map(|x| (x.peer_id, x.version))
.collect(),
bitmap: v.bitmap,
}
}
}
impl RouteConnBitmap { impl RouteConnBitmap {
fn new() -> Self { fn new() -> Self {
RouteConnBitmap { RouteConnBitmap {
@@ -200,28 +228,7 @@ impl RouteConnBitmap {
} }
} }
#[derive(Debug, Serialize, Deserialize, Clone)] type Error = SyncRouteInfoError;
enum Error {
DuplicatePeerId,
Stopped,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct SyncRouteInfoResponse {
is_initiator: bool,
session_id: SessionId,
}
#[tarpc::service]
trait RouteService {
async fn sync_route_info(
my_peer_id: PeerId,
my_session_id: SessionId,
is_initiator: bool,
peer_infos: Option<Vec<RoutePeerInfo>>,
conn_bitmap: Option<RouteConnBitmap>,
) -> Result<SyncRouteInfoResponse, Error>;
}
// constructed with all infos synced from all peers. // constructed with all infos synced from all peers.
#[derive(Debug)] #[derive(Debug)]
@@ -299,7 +306,7 @@ impl SyncedRouteInfo {
for mut route_info in peer_infos.iter().map(Clone::clone) { for mut route_info in peer_infos.iter().map(Clone::clone) {
// time between peers may not be synchronized, so update last_update to local now. // time between peers may not be synchronized, so update last_update to local now.
// note only last_update with larger version will be updated to local saved peer info. // note only last_update with larger version will be updated to local saved peer info.
route_info.last_update = SystemTime::now(); route_info.last_update = Some(SystemTime::now().into());
self.peer_infos self.peer_infos
.entry(route_info.peer_id) .entry(route_info.peer_id)
@@ -581,7 +588,7 @@ impl RouteTable {
let info = item.value(); let info = item.value();
if let Some(ipv4_addr) = info.ipv4_addr { if let Some(ipv4_addr) = info.ipv4_addr {
self.ipv4_peer_id_map.insert(ipv4_addr, *peer_id); self.ipv4_peer_id_map.insert(ipv4_addr.into(), *peer_id);
} }
for cidr in info.proxy_cidrs.iter() { for cidr in info.proxy_cidrs.iter() {
@@ -996,7 +1003,8 @@ impl PeerRouteServiceImpl {
let now = SystemTime::now(); let now = SystemTime::now();
let mut to_remove = Vec::new(); let mut to_remove = Vec::new();
for item in self.synced_route_info.peer_infos.iter() { for item in self.synced_route_info.peer_infos.iter() {
if let Ok(d) = now.duration_since(item.value().last_update) { if let Ok(d) = now.duration_since(item.value().last_update.unwrap().try_into().unwrap())
{
if d > REMOVE_DEAD_PEER_INFO_AFTER { if d > REMOVE_DEAD_PEER_INFO_AFTER {
to_remove.push(*item.key()); to_remove.push(*item.key());
} }
@@ -1021,7 +1029,7 @@ impl PeerRouteServiceImpl {
let my_peer_id = self.my_peer_id; let my_peer_id = self.my_peer_id;
let (peer_infos, conn_bitmap) = self.build_sync_request(&session); let (peer_infos, conn_bitmap) = self.build_sync_request(&session);
tracing::info!("my_id {:?}, dst_peer_id: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}", tracing::info!("building sync_route request. my_id {:?}, dst_peer_id: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}",
my_peer_id, dst_peer_id, peer_infos, conn_bitmap, self.synced_route_info, session); my_peer_id, dst_peer_id, peer_infos, conn_bitmap, self.synced_route_info, session);
if peer_infos.is_none() if peer_infos.is_none()
@@ -1035,33 +1043,60 @@ impl PeerRouteServiceImpl {
.need_sync_initiator_info .need_sync_initiator_info
.store(false, Ordering::Relaxed); .store(false, Ordering::Relaxed);
let ret = peer_rpc let rpc_stub = peer_rpc
.do_client_rpc_scoped(SERVICE_ID, dst_peer_id, |c| async { .rpc_client()
let client = RouteServiceClient::new(tarpc::client::Config::default(), c).spawn(); .scoped_client::<OspfRouteRpcClientFactory<BaseController>>(
let mut rpc_ctx = tarpc::context::current(); self.my_peer_id,
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3); dst_peer_id,
client self.global_ctx.get_network_name(),
);
let mut ctrl = BaseController {};
ctrl.set_timeout_ms(3000);
let ret = rpc_stub
.sync_route_info( .sync_route_info(
rpc_ctx, ctrl,
SyncRouteInfoRequest {
my_peer_id, my_peer_id,
session.my_session_id.load(Ordering::Relaxed), my_session_id: session.my_session_id.load(Ordering::Relaxed),
session.we_are_initiator.load(Ordering::Relaxed), is_initiator: session.we_are_initiator.load(Ordering::Relaxed),
peer_infos.clone(), peer_infos: peer_infos.clone().map(|x| RoutePeerInfos { items: x }),
conn_bitmap.clone(), conn_bitmap: conn_bitmap.clone().map(Into::into),
},
) )
.await
})
.await; .await;
match ret { if let Err(e) = &ret {
Ok(Ok(ret)) => { tracing::error!(
?ret,
?my_peer_id,
?dst_peer_id,
?e,
"sync_route_info failed"
);
session
.need_sync_initiator_info
.store(true, Ordering::Relaxed);
} else {
let resp = ret.as_ref().unwrap();
if resp.error.is_some() {
let err = resp.error.unwrap();
if err == Error::DuplicatePeerId as i32 {
panic!("duplicate peer id");
} else {
tracing::error!(?ret, ?my_peer_id, ?dst_peer_id, "sync_route_info failed");
session
.need_sync_initiator_info
.store(true, Ordering::Relaxed);
}
} else {
session.rpc_tx_count.fetch_add(1, Ordering::Relaxed); session.rpc_tx_count.fetch_add(1, Ordering::Relaxed);
session session
.dst_is_initiator .dst_is_initiator
.store(ret.is_initiator, Ordering::Relaxed); .store(resp.is_initiator, Ordering::Relaxed);
session.update_dst_session_id(ret.session_id); session.update_dst_session_id(resp.session_id);
if let Some(peer_infos) = &peer_infos { if let Some(peer_infos) = &peer_infos {
session.update_dst_saved_peer_info_version(&peer_infos); session.update_dst_saved_peer_info_version(&peer_infos);
@@ -1071,17 +1106,6 @@ impl PeerRouteServiceImpl {
session.update_dst_saved_conn_bitmap_version(&conn_bitmap); session.update_dst_saved_conn_bitmap_version(&conn_bitmap);
} }
} }
Ok(Err(Error::DuplicatePeerId)) => {
panic!("duplicate peer id");
}
_ => {
tracing::error!(?ret, ?my_peer_id, ?dst_peer_id, "sync_route_info failed");
session
.need_sync_initiator_info
.store(true, Ordering::Relaxed);
}
} }
return false; return false;
} }
@@ -1103,59 +1127,37 @@ impl Debug for RouteSessionManager {
} }
} }
#[tarpc::server] #[async_trait::async_trait]
impl RouteService for RouteSessionManager { impl OspfRouteRpc for RouteSessionManager {
type Controller = BaseController;
async fn sync_route_info( async fn sync_route_info(
self, &self,
_: tarpc::context::Context, _ctrl: BaseController,
from_peer_id: PeerId, request: SyncRouteInfoRequest,
from_session_id: SessionId, ) -> Result<SyncRouteInfoResponse, rpc_types::error::Error> {
is_initiator: bool, let from_peer_id = request.my_peer_id;
peer_infos: Option<Vec<RoutePeerInfo>>, let from_session_id = request.my_session_id;
conn_bitmap: Option<RouteConnBitmap>, let is_initiator = request.is_initiator;
) -> Result<SyncRouteInfoResponse, Error> { let peer_infos = request.peer_infos.map(|x| x.items);
let Some(service_impl) = self.service_impl.upgrade() else { let conn_bitmap = request.conn_bitmap.map(Into::into);
return Err(Error::Stopped);
};
let my_peer_id = service_impl.my_peer_id; let ret = self
let session = self.get_or_start_session(from_peer_id)?; .do_sync_route_info(
session.rpc_rx_count.fetch_add(1, Ordering::Relaxed);
session.update_dst_session_id(from_session_id);
if let Some(peer_infos) = &peer_infos {
service_impl.synced_route_info.update_peer_infos(
my_peer_id,
from_peer_id, from_peer_id,
peer_infos, from_session_id,
)?;
session.update_dst_saved_peer_info_version(peer_infos);
}
if let Some(conn_bitmap) = &conn_bitmap {
service_impl.synced_route_info.update_conn_map(&conn_bitmap);
session.update_dst_saved_conn_bitmap_version(conn_bitmap);
}
service_impl.update_route_table_and_cached_local_conn_bitmap();
tracing::info!(
"sync_route_info: from_peer_id: {:?}, is_initiator: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}, new_route_table: {:?}",
from_peer_id, is_initiator, peer_infos, conn_bitmap, service_impl.synced_route_info, session, service_impl.route_table);
session
.dst_is_initiator
.store(is_initiator, Ordering::Relaxed);
let is_initiator = session.we_are_initiator.load(Ordering::Relaxed);
let session_id = session.my_session_id.load(Ordering::Relaxed);
self.sync_now("sync_route_info");
Ok(SyncRouteInfoResponse {
is_initiator, is_initiator,
session_id, peer_infos,
conn_bitmap,
)
.await;
Ok(match ret {
Ok(v) => v,
Err(e) => {
let mut resp = SyncRouteInfoResponse::default();
resp.error = Some(e as i32);
resp
}
}) })
} }
} }
@@ -1366,6 +1368,60 @@ impl RouteSessionManager {
let ret = self.sync_now_broadcast.send(()); let ret = self.sync_now_broadcast.send(());
tracing::debug!(?ret, ?reason, "sync_now_broadcast.send"); tracing::debug!(?ret, ?reason, "sync_now_broadcast.send");
} }
async fn do_sync_route_info(
&self,
from_peer_id: PeerId,
from_session_id: SessionId,
is_initiator: bool,
peer_infos: Option<Vec<RoutePeerInfo>>,
conn_bitmap: Option<RouteConnBitmap>,
) -> Result<SyncRouteInfoResponse, Error> {
let Some(service_impl) = self.service_impl.upgrade() else {
return Err(Error::Stopped);
};
let my_peer_id = service_impl.my_peer_id;
let session = self.get_or_start_session(from_peer_id)?;
session.rpc_rx_count.fetch_add(1, Ordering::Relaxed);
session.update_dst_session_id(from_session_id);
if let Some(peer_infos) = &peer_infos {
service_impl.synced_route_info.update_peer_infos(
my_peer_id,
from_peer_id,
peer_infos,
)?;
session.update_dst_saved_peer_info_version(peer_infos);
}
if let Some(conn_bitmap) = &conn_bitmap {
service_impl.synced_route_info.update_conn_map(&conn_bitmap);
session.update_dst_saved_conn_bitmap_version(conn_bitmap);
}
service_impl.update_route_table_and_cached_local_conn_bitmap();
tracing::info!(
"handling sync_route_info rpc: from_peer_id: {:?}, is_initiator: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}, new_route_table: {:?}",
from_peer_id, is_initiator, peer_infos, conn_bitmap, service_impl.synced_route_info, session, service_impl.route_table);
session
.dst_is_initiator
.store(is_initiator, Ordering::Relaxed);
let is_initiator = session.we_are_initiator.load(Ordering::Relaxed);
let session_id = session.my_session_id.load(Ordering::Relaxed);
self.sync_now("sync_route_info");
Ok(SyncRouteInfoResponse {
is_initiator,
session_id,
error: None,
})
}
} }
pub struct PeerRoute { pub struct PeerRoute {
@@ -1415,7 +1471,7 @@ impl PeerRoute {
tokio::time::sleep(Duration::from_secs(60)).await; tokio::time::sleep(Duration::from_secs(60)).await;
service_impl.clear_expired_peer(); service_impl.clear_expired_peer();
// TODO: use debug log level for this. // TODO: use debug log level for this.
tracing::info!(?service_impl, "clear_expired_peer"); tracing::debug!(?service_impl, "clear_expired_peer");
} }
} }
@@ -1453,8 +1509,10 @@ impl PeerRoute {
} }
async fn start(&self) { async fn start(&self) {
self.peer_rpc self.peer_rpc.rpc_server().registry().register(
.run_service(SERVICE_ID, RouteService::serve(self.session_mgr.clone())); OspfRouteRpcServer::new(self.session_mgr.clone()),
&self.global_ctx.get_network_name(),
);
self.tasks self.tasks
.lock() .lock()
@@ -1479,6 +1537,15 @@ impl PeerRoute {
} }
} }
impl Drop for PeerRoute {
fn drop(&mut self) {
self.peer_rpc.rpc_server().registry().unregister(
OspfRouteRpcServer::new(self.session_mgr.clone()),
&self.global_ctx.get_network_name(),
);
}
}
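
PeerRoute now owns the lifetime of its rpc service: start() registers OspfRouteRpcServer in the rpc server registry under the instance's network name, and the new Drop impl above unregisters it again. A minimal sketch of that ownership pattern for an arbitrary service; the struct and its service type are hypothetical placeholders, while registry(), register() and unregister() are the calls used in the hunks above:

```rust
// Sketch only: tie a generated rpc service to its owner's lifetime.
// `MyServiceServer` is a placeholder for a server wrapper such as
// OspfRouteRpcServer and is assumed to be cheap to clone.
struct ScopedRpcService {
    peer_rpc: Arc<PeerRpcManager>,
    network_name: String,
    server: MyServiceServer, // hypothetical
}

impl ScopedRpcService {
    fn register(&self) {
        self.peer_rpc
            .rpc_server()
            .registry()
            .register(self.server.clone(), &self.network_name);
    }
}

impl Drop for ScopedRpcService {
    fn drop(&mut self) {
        // mirror PeerRoute::drop above so the service vanishes with its owner
        self.peer_rpc
            .rpc_server()
            .registry()
            .unregister(self.server.clone(), &self.network_name);
    }
}
```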
#[async_trait::async_trait] #[async_trait::async_trait]
impl Route for PeerRoute { impl Route for PeerRoute {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> { async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
@@ -1507,7 +1574,7 @@ impl Route for PeerRoute {
route_table.get_next_hop(dst_peer_id).map(|x| x.0) route_table.get_next_hop(dst_peer_id).map(|x| x.0)
} }
async fn list_routes(&self) -> Vec<crate::rpc::Route> { async fn list_routes(&self) -> Vec<crate::proto::cli::Route> {
let route_table = &self.service_impl.route_table; let route_table = &self.service_impl.route_table;
let mut routes = Vec::new(); let mut routes = Vec::new();
for item in route_table.peer_infos.iter() { for item in route_table.peer_infos.iter() {
@@ -1517,7 +1584,7 @@ impl Route for PeerRoute {
let Some(next_hop_peer) = route_table.get_next_hop(*item.key()) else { let Some(next_hop_peer) = route_table.get_next_hop(*item.key()) else {
continue; continue;
}; };
let mut route: crate::rpc::Route = item.value().clone().into(); let mut route: crate::proto::cli::Route = item.value().clone().into();
route.next_hop_peer_id = next_hop_peer.0; route.next_hop_peer_id = next_hop_peer.0;
route.cost = next_hop_peer.1; route.cost = next_hop_peer.1;
routes.push(route); routes.push(route);
@@ -1567,7 +1634,7 @@ mod tests {
route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface}, route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface},
tests::connect_peer_manager, tests::connect_peer_manager,
}, },
rpc::NatType, proto::common::NatType,
tunnel::common::tests::wait_for_condition, tunnel::common::tests::wait_for_condition,
}; };


@@ -1,753 +0,0 @@
use std::{
net::Ipv4Addr,
sync::{atomic::AtomicU32, Arc},
time::{Duration, Instant},
};
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
peers::route_trait::{Route, RouteInterfaceBox},
rpc::{NatType, StunInfo},
tunnel::packet_def::{PacketType, ZCPacket},
};
use super::PeerPacketFilter;
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
const ROUTE_EXPIRED_SEC: u64 = 70;
type Version = u32;
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)]
// Derives can be passed through to the generated type:
pub struct SyncPeerInfo {
// means next hop in route table.
pub peer_id: PeerId,
pub cost: u32,
pub ipv4_addr: Option<Ipv4Addr>,
pub proxy_cidrs: Vec<String>,
pub hostname: Option<String>,
pub udp_stun_info: i8,
}
impl SyncPeerInfo {
pub fn new_self(from_peer: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
SyncPeerInfo {
peer_id: from_peer,
cost: 0,
ipv4_addr: global_ctx.get_ipv4(),
proxy_cidrs: global_ctx
.get_proxy_cidrs()
.iter()
.map(|x| x.to_string())
.chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string()))
.collect(),
hostname: Some(global_ctx.get_hostname()),
udp_stun_info: global_ctx
.get_stun_info_collector()
.get_stun_info()
.udp_nat_type as i8,
}
}
pub fn clone_for_route_table(&self, next_hop: PeerId, cost: u32, from: &Self) -> Self {
SyncPeerInfo {
peer_id: next_hop,
cost,
ipv4_addr: from.ipv4_addr.clone(),
proxy_cidrs: from.proxy_cidrs.clone(),
hostname: from.hostname.clone(),
udp_stun_info: from.udp_stun_info,
}
}
}
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
pub struct SyncPeer {
pub myself: SyncPeerInfo,
pub neighbors: Vec<SyncPeerInfo>,
// the route table version of myself
pub version: Version,
// the route table version of peer that we have received last time
pub peer_version: Option<Version>,
// if we do not have latest peer version, need_reply is true
pub need_reply: bool,
}
impl SyncPeer {
pub fn new(
from_peer: PeerId,
_to_peer: PeerId,
neighbors: Vec<SyncPeerInfo>,
global_ctx: ArcGlobalCtx,
version: Version,
peer_version: Option<Version>,
need_reply: bool,
) -> Self {
SyncPeer {
myself: SyncPeerInfo::new_self(from_peer, &global_ctx),
neighbors,
version,
peer_version,
need_reply,
}
}
}
#[derive(Debug)]
struct SyncPeerFromRemote {
packet: SyncPeer,
last_update: std::time::Instant,
}
type SyncPeerFromRemoteMap = Arc<DashMap<PeerId, SyncPeerFromRemote>>;
#[derive(Debug)]
struct RouteTable {
route_info: DashMap<PeerId, SyncPeerInfo>,
ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
}
impl RouteTable {
fn new() -> Self {
RouteTable {
route_info: DashMap::new(),
ipv4_peer_id_map: DashMap::new(),
cidr_peer_id_map: DashMap::new(),
}
}
fn copy_from(&self, other: &Self) {
self.route_info.clear();
for item in other.route_info.iter() {
let (k, v) = item.pair();
self.route_info.insert(*k, v.clone());
}
self.ipv4_peer_id_map.clear();
for item in other.ipv4_peer_id_map.iter() {
let (k, v) = item.pair();
self.ipv4_peer_id_map.insert(*k, *v);
}
self.cidr_peer_id_map.clear();
for item in other.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
self.cidr_peer_id_map.insert(*k, *v);
}
}
}
#[derive(Debug, Clone)]
struct RouteVersion(Arc<AtomicU32>);
impl RouteVersion {
fn new() -> Self {
// RouteVersion(Arc::new(AtomicU32::new(rand::random())))
RouteVersion(Arc::new(AtomicU32::new(0)))
}
fn get(&self) -> Version {
self.0.load(std::sync::atomic::Ordering::Relaxed)
}
fn inc(&self) {
self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
pub struct BasicRoute {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
route_table: Arc<RouteTable>,
sync_peer_from_remote: SyncPeerFromRemoteMap,
tasks: Mutex<JoinSet<()>>,
need_sync_notifier: Arc<tokio::sync::Notify>,
version: RouteVersion,
myself: Arc<RwLock<SyncPeerInfo>>,
last_send_time_map: Arc<DashMap<PeerId, (Version, Option<Version>, Instant)>>,
}
impl BasicRoute {
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx) -> Self {
BasicRoute {
my_peer_id,
global_ctx: global_ctx.clone(),
interface: Arc::new(Mutex::new(None)),
route_table: Arc::new(RouteTable::new()),
sync_peer_from_remote: Arc::new(DashMap::new()),
tasks: Mutex::new(JoinSet::new()),
need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
version: RouteVersion::new(),
myself: Arc::new(RwLock::new(SyncPeerInfo::new_self(
my_peer_id.into(),
&global_ctx,
))),
last_send_time_map: Arc::new(DashMap::new()),
}
}
fn update_route_table(
my_id: PeerId,
sync_peer_reqs: SyncPeerFromRemoteMap,
route_table: Arc<RouteTable>,
) {
tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
let new_route_table = Arc::new(RouteTable::new());
for item in sync_peer_reqs.iter() {
Self::update_route_table_with_req(my_id, &item.value().packet, new_route_table.clone());
}
route_table.copy_from(&new_route_table);
}
async fn update_myself(
my_peer_id: PeerId,
myself: &Arc<RwLock<SyncPeerInfo>>,
global_ctx: &ArcGlobalCtx,
) -> bool {
let new_myself = SyncPeerInfo::new_self(my_peer_id, &global_ctx);
if *myself.read().await != new_myself {
*myself.write().await = new_myself;
true
} else {
false
}
}
fn update_route_table_with_req(my_id: PeerId, packet: &SyncPeer, route_table: Arc<RouteTable>) {
let peer_id = packet.myself.peer_id.clone();
let update = |cost: u32, peer_info: &SyncPeerInfo| {
let node_id: PeerId = peer_info.peer_id.into();
let ret = route_table
.route_info
.entry(node_id.clone().into())
.and_modify(|info| {
if info.cost > cost {
*info = info.clone_for_route_table(peer_id, cost, &peer_info);
}
})
.or_insert(
peer_info
.clone()
.clone_for_route_table(peer_id, cost, &peer_info),
)
.value()
.clone();
if ret.cost > 6 {
tracing::error!(
"cost too large: {}, may lost connection, remove it",
ret.cost
);
route_table.route_info.remove(&node_id);
}
tracing::trace!(
"update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
node_id,
peer_id,
cost,
&peer_info
);
if let Some(ipv4) = peer_info.ipv4_addr {
route_table
.ipv4_peer_id_map
.insert(ipv4.clone(), node_id.clone().into());
}
for cidr in peer_info.proxy_cidrs.iter() {
let cidr: cidr::IpCidr = cidr.parse().unwrap();
route_table
.cidr_peer_id_map
.insert(cidr, node_id.clone().into());
}
};
for neighbor in packet.neighbors.iter() {
if neighbor.peer_id == my_id {
continue;
}
update(neighbor.cost + 1, &neighbor);
tracing::trace!("route info: {:?}", neighbor);
}
// add the sender peer to route info
update(1, &packet.myself);
tracing::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
}
async fn send_sync_peer_request(
interface: &RouteInterfaceBox,
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
peer_id: PeerId,
route_table: Arc<RouteTable>,
my_version: Version,
peer_version: Option<Version>,
need_reply: bool,
) -> Result<(), Error> {
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
// copy the route info
for item in route_table.route_info.iter() {
let (k, v) = item.pair();
route_info_copy.push(v.clone().clone_for_route_table(*k, v.cost, &v));
}
let msg = SyncPeer::new(
my_peer_id,
peer_id,
route_info_copy,
global_ctx,
my_version,
peer_version,
need_reply,
);
// TODO: this may exceed the MTU of the tunnel
interface
.send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id)
.await
}
async fn sync_peer_periodically(&self) {
let route_table = self.route_table.clone();
let global_ctx = self.global_ctx.clone();
let my_peer_id = self.my_peer_id.clone();
let interface = self.interface.clone();
let notifier = self.need_sync_notifier.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let myself = self.myself.clone();
let version = self.version.clone();
let last_send_time_map = self.last_send_time_map.clone();
self.tasks.lock().await.spawn(
async move {
loop {
if Self::update_myself(my_peer_id,&myself, &global_ctx).await {
version.inc();
tracing::info!(
my_id = ?my_peer_id,
version = version.get(),
"update route table version when myself changed"
);
}
let lockd_interface = interface.lock().await;
let interface = lockd_interface.as_ref().unwrap();
let last_send_time_map_new = DashMap::new();
let peers = interface.list_peers().await;
for peer in peers.iter() {
let last_send_time = last_send_time_map.get(peer).map(|v| *v).unwrap_or((0, None, Instant::now() - Duration::from_secs(3600)));
let my_version_peer_saved = sync_peer_from_remote.get(peer).and_then(|v| v.packet.peer_version);
let peer_have_latest_version = my_version_peer_saved == Some(version.get());
if peer_have_latest_version && last_send_time.2.elapsed().as_secs() < SEND_ROUTE_PERIOD_SEC {
last_send_time_map_new.insert(*peer, last_send_time);
continue;
}
tracing::trace!(
my_id = ?my_peer_id,
dst_peer_id = ?peer,
version = version.get(),
?my_version_peer_saved,
last_send_version = ?last_send_time.0,
last_send_peer_version = ?last_send_time.1,
last_send_elapse = ?last_send_time.2.elapsed().as_secs(),
"need send route info"
);
let peer_version_we_saved = sync_peer_from_remote.get(&peer).and_then(|v| Some(v.packet.version));
last_send_time_map_new.insert(*peer, (version.get(), peer_version_we_saved, Instant::now()));
let ret = Self::send_sync_peer_request(
interface,
my_peer_id.clone(),
global_ctx.clone(),
*peer,
route_table.clone(),
version.get(),
peer_version_we_saved,
!peer_have_latest_version,
)
.await;
match &ret {
Ok(_) => {
tracing::trace!("send sync peer request to peer: {}", peer);
}
Err(Error::PeerNoConnectionError(_)) => {
tracing::trace!("peer {} no connection", peer);
}
Err(e) => {
tracing::error!(
"send sync peer request to peer: {} error: {:?}",
peer,
e
);
}
};
}
last_send_time_map.clear();
for item in last_send_time_map_new.iter() {
let (k, v) = item.pair();
last_send_time_map.insert(*k, *v);
}
tokio::select! {
_ = notifier.notified() => {
tracing::trace!("sync peer request triggered by notifier");
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
tracing::trace!("sync peer request triggered by timeout");
}
}
}
}
.instrument(
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
),
);
}
async fn check_expired_sync_peer_from_remote(&self) {
let route_table = self.route_table.clone();
let my_peer_id = self.my_peer_id.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let notifier = self.need_sync_notifier.clone();
let interface = self.interface.clone();
let version = self.version.clone();
self.tasks.lock().await.spawn(async move {
loop {
let mut need_update_route = false;
let now = std::time::Instant::now();
let mut need_remove = Vec::new();
let connected_peers = interface.lock().await.as_ref().unwrap().list_peers().await;
for item in sync_peer_from_remote.iter() {
let (k, v) = item.pair();
if now.duration_since(v.last_update).as_secs() > ROUTE_EXPIRED_SEC
|| !connected_peers.contains(k)
{
need_update_route = true;
need_remove.insert(0, k.clone());
}
}
for k in need_remove.iter() {
tracing::warn!("remove expired sync peer: {:?}", k);
sync_peer_from_remote.remove(k);
}
if need_update_route {
Self::update_route_table(
my_peer_id,
sync_peer_from_remote.clone(),
route_table.clone(),
);
version.inc();
tracing::info!(
my_id = ?my_peer_id,
version = version.get(),
"update route table when check expired peer"
);
notifier.notify_one();
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
let ipv4 = std::net::IpAddr::V4(*ipv4);
for item in self.route_table.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
if k.contains(&ipv4) {
return Some(*v);
}
}
None
}
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) {
let packet = postcard::from_bytes::<SyncPeer>(&packet).unwrap();
let p = &packet;
let mut updated = true;
assert_eq!(packet.myself.peer_id, src_peer_id);
self.sync_peer_from_remote
.entry(packet.myself.peer_id.into())
.and_modify(|v| {
if v.packet.myself == p.myself && v.packet.neighbors == p.neighbors {
updated = false;
} else {
v.packet = p.clone();
}
v.packet.version = p.version;
v.packet.peer_version = p.peer_version;
v.last_update = std::time::Instant::now();
})
.or_insert(SyncPeerFromRemote {
packet: p.clone(),
last_update: std::time::Instant::now(),
});
if updated {
Self::update_route_table(
self.my_peer_id.clone(),
self.sync_peer_from_remote.clone(),
self.route_table.clone(),
);
self.version.inc();
tracing::info!(
my_id = ?self.my_peer_id,
?p,
version = self.version.get(),
"update route table when receive route packet"
);
}
if packet.need_reply {
self.last_send_time_map
.entry(packet.myself.peer_id.into())
.and_modify(|v| {
const FAST_REPLY_DURATION: u64 =
SEND_ROUTE_PERIOD_SEC - SEND_ROUTE_FAST_REPLY_SEC;
if v.0 != self.version.get() || v.1 != Some(p.version) {
v.2 = Instant::now() - Duration::from_secs(3600);
} else if v.2.elapsed().as_secs() < FAST_REPLY_DURATION {
// do not send same version route info too frequently
v.2 = Instant::now() - Duration::from_secs(FAST_REPLY_DURATION);
}
});
}
if updated || packet.need_reply {
self.need_sync_notifier.notify_one();
}
}
}
#[async_trait]
impl Route for BasicRoute {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
*self.interface.lock().await = Some(interface);
self.sync_peer_periodically().await;
self.check_expired_sync_peer_from_remote().await;
Ok(1)
}
async fn close(&self) {}
async fn get_next_hop(&self, dst_peer_id: PeerId) -> Option<PeerId> {
match self.route_table.route_info.get(&dst_peer_id) {
Some(info) => {
return Some(info.peer_id.clone().into());
}
None => {
tracing::error!("no route info for dst_peer_id: {}", dst_peer_id);
return None;
}
}
}
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
let mut routes = Vec::new();
let parse_route_info = |real_peer_id: PeerId, route_info: &SyncPeerInfo| {
let mut route = crate::rpc::Route::default();
route.ipv4_addr = if let Some(ipv4_addr) = route_info.ipv4_addr {
ipv4_addr.to_string()
} else {
"".to_string()
};
route.peer_id = real_peer_id;
route.next_hop_peer_id = route_info.peer_id;
route.cost = route_info.cost as i32;
route.proxy_cidrs = route_info.proxy_cidrs.clone();
route.hostname = route_info.hostname.clone().unwrap_or_default();
let mut stun_info = StunInfo::default();
if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
stun_info.set_udp_nat_type(udp_nat_type);
}
route.stun_info = Some(stun_info);
route
};
self.route_table.route_info.iter().for_each(|item| {
routes.push(parse_route_info(*item.key(), item.value()));
});
routes
}
async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
return Some(*peer_id);
}
if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
return Some(peer_id);
}
tracing::info!("no peer id for ipv4: {}", ipv4_addr);
return None;
}
}
#[async_trait::async_trait]
impl PeerPacketFilter for BasicRoute {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
let hdr = packet.peer_manager_header().unwrap();
if hdr.packet_type == PacketType::Route as u8 {
let b = packet.payload().to_vec();
self.handle_route_packet(hdr.from_peer_id.get(), b.into())
.await;
None
} else {
Some(packet)
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
common::{global_ctx::tests::get_mock_global_ctx, PeerId},
connector::udp_hole_punch::tests::replace_stun_info_collector,
peers::{
peer_manager::{PeerManager, RouteAlgoType},
peer_rip_route::Version,
tests::{connect_peer_manager, wait_route_appear},
},
rpc::NatType,
};
async fn create_mock_pmgr() -> Arc<PeerManager> {
let (s, _r) = tokio::sync::mpsc::channel(1000);
let peer_mgr = Arc::new(PeerManager::new(
RouteAlgoType::Rip,
get_mock_global_ctx(),
s,
));
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
peer_mgr.run().await.unwrap();
peer_mgr
}
#[tokio::test]
async fn test_rip_route() {
let peer_mgr_a = create_mock_pmgr().await;
let peer_mgr_b = create_mock_pmgr().await;
let peer_mgr_c = create_mock_pmgr().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
.await
.unwrap();
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
.await
.unwrap();
let mgrs = vec![peer_mgr_a.clone(), peer_mgr_b.clone(), peer_mgr_c.clone()];
tokio::time::sleep(tokio::time::Duration::from_secs(4)).await;
let check_version = |version: Version, peer_id: PeerId, mgrs: &Vec<Arc<PeerManager>>| {
for mgr in mgrs.iter() {
tracing::warn!(
"check version: {:?}, {:?}, {:?}, {:?}",
version,
peer_id,
mgr,
mgr.get_basic_route().sync_peer_from_remote
);
assert_eq!(
version,
mgr.get_basic_route()
.sync_peer_from_remote
.get(&peer_id)
.unwrap()
.packet
.version,
);
assert_eq!(
mgr.get_basic_route()
.sync_peer_from_remote
.get(&peer_id)
.unwrap()
.packet
.peer_version
.unwrap(),
mgr.get_basic_route().version.get()
);
}
};
let check_sanity = || {
// check peer version in other peer mgr are correct.
check_version(
peer_mgr_b.get_basic_route().version.get(),
peer_mgr_b.my_peer_id(),
&vec![peer_mgr_a.clone(), peer_mgr_c.clone()],
);
check_version(
peer_mgr_a.get_basic_route().version.get(),
peer_mgr_a.my_peer_id(),
&vec![peer_mgr_b.clone()],
);
check_version(
peer_mgr_c.get_basic_route().version.get(),
peer_mgr_c.my_peer_id(),
&vec![peer_mgr_b.clone()],
);
};
check_sanity();
let versions = mgrs
.iter()
.map(|x| x.get_basic_route().version.get())
.collect::<Vec<_>>();
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
let versions2 = mgrs
.iter()
.map(|x| x.get_basic_route().version.get())
.collect::<Vec<_>>();
assert_eq!(versions, versions2);
check_sanity();
assert!(peer_mgr_a.get_basic_route().version.get() <= 3);
assert!(peer_mgr_b.get_basic_route().version.get() <= 6);
assert!(peer_mgr_c.get_basic_route().version.get() <= 3);
}
}


@@ -1,27 +1,10 @@
use std::{ use std::sync::Arc;
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
time::Instant,
};
use crossbeam::atomic::AtomicCell; use futures::StreamExt;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use prost::Message;
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
use tokio::{
sync::mpsc::{self, UnboundedSender},
task::JoinSet,
};
use tracing::Instrument;
use crate::{ use crate::{
common::{error::Error, PeerId}, common::{error::Error, PeerId},
rpc::TaRpcPacket, proto::rpc_impl,
tunnel::packet_def::{PacketType, ZCPacket}, tunnel::packet_def::{PacketType, ZCPacket},
}; };
@@ -38,33 +21,11 @@ pub trait PeerRpcManagerTransport: Send + Sync + 'static {
     async fn recv(&self) -> Result<ZCPacket, Error>;
 }
-type PacketSender = UnboundedSender<ZCPacket>;
-struct PeerRpcEndPoint {
-    peer_id: PeerId,
-    packet_sender: PacketSender,
-    create_time: AtomicCell<Instant>,
-    finished: Arc<AtomicBool>,
-    tasks: JoinSet<()>,
-}
-type PeerRpcEndPointCreator =
-    Box<dyn Fn(PeerId, PeerRpcTransactId) -> PeerRpcEndPoint + Send + Sync + 'static>;
-#[derive(Hash, Eq, PartialEq, Clone)]
-struct PeerRpcClientCtxKey(PeerId, PeerRpcServiceId, PeerRpcTransactId);
 // handle rpc request from one peer
 pub struct PeerRpcManager {
-    service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
-    tasks: JoinSet<()>,
     tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
-    service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
-    peer_rpc_endpoints: Arc<DashMap<PeerRpcClientCtxKey, PeerRpcEndPoint>>,
-    client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
-    transact_id: AtomicU32,
+    rpc_client: rpc_impl::client::Client,
+    rpc_server: rpc_impl::server::Server,
 }
 impl std::fmt::Debug for PeerRpcManager {
@@ -75,293 +36,55 @@ impl std::fmt::Debug for PeerRpcManager {
} }
} }
struct PacketMerger {
first_piece: Option<TaRpcPacket>,
pieces: Vec<TaRpcPacket>,
}
impl PacketMerger {
fn new() -> Self {
Self {
first_piece: None,
pieces: Vec::new(),
}
}
fn try_merge_pieces(&self) -> Option<TaRpcPacket> {
if self.first_piece.is_none() || self.pieces.is_empty() {
return None;
}
for p in &self.pieces {
// some piece is missing
if p.total_pieces == 0 {
return None;
}
}
// all pieces are received
let mut content = Vec::new();
for p in &self.pieces {
content.extend_from_slice(&p.content);
}
let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone();
tmpl_packet.total_pieces = 1;
tmpl_packet.piece_idx = 0;
tmpl_packet.content = content;
Some(tmpl_packet)
}
fn feed(
&mut self,
packet: ZCPacket,
expected_tid: Option<PeerRpcTransactId>,
) -> Result<Option<TaRpcPacket>, Error> {
let payload = packet.payload();
let rpc_packet =
TaRpcPacket::decode(payload).map_err(|e| Error::MessageDecodeError(e.to_string()))?;
if expected_tid.is_some() && rpc_packet.transact_id != expected_tid.unwrap() {
return Ok(None);
}
let total_pieces = rpc_packet.total_pieces;
let piece_idx = rpc_packet.piece_idx;
// for compatibility with old version
if total_pieces == 0 && piece_idx == 0 {
return Ok(Some(rpc_packet));
}
if total_pieces > 100 || total_pieces == 0 {
return Err(Error::MessageDecodeError(format!(
"total_pieces is invalid: {}",
total_pieces
)));
}
if piece_idx >= total_pieces {
return Err(Error::MessageDecodeError(
"piece_idx >= total_pieces".to_owned(),
));
}
if self.first_piece.is_none()
|| self.first_piece.as_ref().unwrap().transact_id != rpc_packet.transact_id
|| self.first_piece.as_ref().unwrap().from_peer != rpc_packet.from_peer
{
self.first_piece = Some(rpc_packet.clone());
self.pieces.clear();
}
self.pieces
.resize(total_pieces as usize, Default::default());
self.pieces[piece_idx as usize] = rpc_packet;
Ok(self.try_merge_pieces())
}
}
impl PeerRpcManager { impl PeerRpcManager {
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self { pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
Self { Self {
service_map: Arc::new(DashMap::new()),
tasks: JoinSet::new(),
tspt: Arc::new(Box::new(tspt)), tspt: Arc::new(Box::new(tspt)),
rpc_client: rpc_impl::client::Client::new(),
service_registry: Arc::new(DashMap::new()), rpc_server: rpc_impl::server::Server::new(),
peer_rpc_endpoints: Arc::new(DashMap::new()),
client_resp_receivers: Arc::new(DashMap::new()),
transact_id: AtomicU32::new(0),
} }
} }
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
where
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Resp:
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Fut: Send + 'static,
{
let tspt = self.tspt.clone();
let creator = Box::new(move |peer_id: PeerId, transact_id: PeerRpcTransactId| {
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel();
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
let finished = Arc::new(AtomicBool::new(false));
let my_peer_id_clone = tspt.my_peer_id();
let peer_id_clone = peer_id.clone();
let o = server.execute(s.clone());
tasks.spawn(o);
let tspt = tspt.clone();
let finished_clone = finished.clone();
tasks.spawn(async move {
let mut packet_merger = PacketMerger::new();
loop {
tokio::select! {
Some(resp) = client_transport.next() => {
tracing::debug!(resp = ?resp, ?transact_id, ?peer_id, "server recv packet from service provider");
if resp.is_err() {
tracing::warn!(err = ?resp.err(),
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
continue;
}
let resp = resp.unwrap();
let serialized_resp = postcard::to_allocvec(&resp);
if serialized_resp.is_err() {
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
continue;
}
let msgs = Self::build_rpc_packet(
tspt.my_peer_id(),
peer_id,
service_id,
transact_id,
false,
serialized_resp.as_ref().unwrap(),
);
for msg in msgs {
if let Err(e) = tspt.send(msg, peer_id).await {
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
break;
}
}
finished_clone.store(true, Ordering::Relaxed);
}
Some(packet) = packet_receiver.recv() => {
tracing::trace!("recv packet from peer, packet: {:?}", packet);
let info = match packet_merger.feed(packet, None) {
Err(e) => {
tracing::error!(error = ?e, "feed packet to merger failed");
continue;
},
Ok(None) => {
continue;
},
Ok(Some(info)) => {
info
}
};
assert_eq!(info.service_id, service_id);
assert_eq!(info.from_peer, peer_id);
assert_eq!(info.transact_id, transact_id);
let decoded_ret = postcard::from_bytes(&info.content.as_slice());
if let Err(e) = decoded_ret {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
if let Err(e) = client_transport.send(decoded).await {
tracing::error!(error = ?e, "send to req to client transport failed");
}
}
else => {
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
}
}
}
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
tracing::info!(
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
peer_id,
service_id
);
return PeerRpcEndPoint {
peer_id,
packet_sender,
create_time: AtomicCell::new(Instant::now()),
finished,
tasks,
};
// let resp = client_transport.next().await;
});
if let Some(_) = self.service_registry.insert(service_id, creator) {
panic!(
"[PEER RPC MGR] service {} is already registered",
service_id
);
}
tracing::info!(
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
service_id,
self.tspt.my_peer_id()
)
}
fn parse_rpc_packet(packet: &ZCPacket) -> Result<TaRpcPacket, Error> {
let payload = packet.payload();
TaRpcPacket::decode(payload).map_err(|e| Error::MessageDecodeError(e.to_string()))
}
fn build_rpc_packet(
from_peer: PeerId,
to_peer: PeerId,
service_id: PeerRpcServiceId,
transact_id: PeerRpcTransactId,
is_req: bool,
content: &Vec<u8>,
) -> Vec<ZCPacket> {
let mut ret = Vec::new();
let content_mtu = RPC_PACKET_CONTENT_MTU;
let total_pieces = (content.len() + content_mtu - 1) / content_mtu;
let mut cur_offset = 0;
while cur_offset < content.len() {
let mut cur_len = content_mtu;
if cur_offset + cur_len > content.len() {
cur_len = content.len() - cur_offset;
}
let mut cur_content = Vec::new();
cur_content.extend_from_slice(&content[cur_offset..cur_offset + cur_len]);
let cur_packet = TaRpcPacket {
from_peer,
to_peer,
service_id,
transact_id,
is_req,
total_pieces: total_pieces as u32,
piece_idx: (cur_offset / content_mtu) as u32,
content: cur_content,
};
cur_offset += cur_len;
let mut buf = Vec::new();
cur_packet.encode(&mut buf).unwrap();
let mut zc_packet = ZCPacket::new_with_payload(&buf);
zc_packet.fill_peer_manager_hdr(from_peer, to_peer, PacketType::TaRpc as u8);
ret.push(zc_packet);
}
ret
}
pub fn run(&self) { pub fn run(&self) {
self.rpc_client.run();
self.rpc_server.run();
let (server_tx, mut server_rx) = (
self.rpc_server.get_transport_sink(),
self.rpc_server.get_transport_stream(),
);
let (client_tx, mut client_rx) = (
self.rpc_client.get_transport_sink(),
self.rpc_client.get_transport_stream(),
);
let tspt = self.tspt.clone();
tokio::spawn(async move {
loop {
let packet = tokio::select! {
Some(Ok(packet)) = server_rx.next() => {
tracing::trace!(?packet, "recv rpc packet from server");
packet
}
Some(Ok(packet)) = client_rx.next() => {
tracing::trace!(?packet, "recv rpc packet from client");
packet
}
else => {
tracing::warn!("rpc transport read aborted, exiting");
break;
}
};
let dst_peer_id = packet.peer_manager_header().unwrap().to_peer_id.into();
if let Err(e) = tspt.send(packet, dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
}
}
});
let tspt = self.tspt.clone(); let tspt = self.tspt.clone();
let service_registry = self.service_registry.clone();
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
let client_resp_receivers = self.client_resp_receivers.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let Ok(o) = tspt.recv().await else { let Ok(o) = tspt.recv().await else {
@@ -369,176 +92,24 @@ impl PeerRpcManager {
break; break;
}; };
let info = Self::parse_rpc_packet(&o).unwrap(); if o.peer_manager_header().unwrap().packet_type == PacketType::RpcReq as u8 {
tracing::debug!(?info, "recv rpc packet from peer"); server_tx.send(o).await.unwrap();
if info.is_req {
if !service_registry.contains_key(&info.service_id) {
tracing::warn!(
"service {} not found, my_node_id: {}",
info.service_id,
tspt.my_peer_id()
);
continue; continue;
} } else if o.peer_manager_header().unwrap().packet_type == PacketType::RpcResp as u8
let endpoint = peer_rpc_endpoints
.entry(PeerRpcClientCtxKey(
info.from_peer,
info.service_id,
info.transact_id,
))
.or_insert_with(|| {
service_registry.get(&info.service_id).unwrap()(
info.from_peer,
info.transact_id,
)
});
endpoint.packet_sender.send(o).unwrap();
} else {
if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
info.from_peer,
info.service_id,
info.transact_id,
)) {
tracing::trace!("recv resp: {:?}", info);
if let Err(e) = a.send(o) {
tracing::error!(error = ?e, "send resp to client failed");
}
} else {
tracing::warn!("client resp receiver not found, info: {:?}", info);
}
}
}
});
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
peer_rpc_endpoints.retain(|_, v| {
v.create_time.load().elapsed().as_secs() < 30
&& !v.finished.load(Ordering::Relaxed)
});
}
});
}
#[tracing::instrument(skip(f))]
pub async fn do_client_rpc_scoped<Resp, Req, RpcRet, Fut>(
&self,
service_id: PeerRpcServiceId,
dst_peer_id: PeerId,
f: impl FnOnce(UnboundedChannel<Resp, Req>) -> Fut,
) -> RpcRet
where
Resp: serde::Serialize
+ for<'a> serde::Deserialize<'a>
+ Send
+ Sync
+ std::fmt::Debug
+ 'static,
Req: serde::Serialize
+ for<'a> serde::Deserialize<'a>
+ Send
+ Sync
+ std::fmt::Debug
+ 'static,
Fut: std::future::Future<Output = RpcRet>,
{ {
let mut tasks = JoinSet::new(); client_tx.send(o).await.unwrap();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel();
let (client_transport, server_transport) =
tarpc::transport::channel::unbounded::<Resp, Req>();
let (mut server_s, mut server_r) = server_transport.split();
let transact_id = self
.transact_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let tspt = self.tspt.clone();
tasks.spawn(async move {
while let Some(a) = server_r.next().await {
if a.is_err() {
tracing::error!(error = ?a.err(), "channel error");
continue; continue;
} }
let req = postcard::to_allocvec(&a.unwrap());
if req.is_err() {
tracing::error!(error = ?req.err(), "bincode serialize failed");
continue;
} }
let packets = Self::build_rpc_packet(
tspt.my_peer_id(),
dst_peer_id,
service_id,
transact_id,
true,
req.as_ref().unwrap(),
);
tracing::debug!(?packets, ?req, ?transact_id, "client send rpc packet to peer");
for packet in packets {
if let Err(e) = tspt.send(packet, dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
break;
}
}
}
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
}); });
tasks.spawn(async move {
let mut packet_merger = PacketMerger::new();
while let Some(packet) = packet_receiver.recv().await {
tracing::trace!("tunnel recv: {:?}", packet);
let info = match packet_merger.feed(packet, Some(transact_id)) {
Err(e) => {
tracing::error!(error = ?e, "feed packet to merger failed");
continue;
}
Ok(None) => {
continue;
}
Ok(Some(info)) => info,
};
let decoded = postcard::from_bytes(&info.content.as_slice());
tracing::debug!(?info, ?decoded, "client recv rpc packet from peer");
assert_eq!(info.transact_id, transact_id);
if let Err(e) = decoded {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
} }
if let Err(e) = server_s.send(decoded.unwrap()).await { pub fn rpc_client(&self) -> &rpc_impl::client::Client {
tracing::error!(error = ?e, "send to rpc server channel failed"); &self.rpc_client
}
} }
tracing::warn!("[PEER RPC MGR] server packet read aborted"); pub fn rpc_server(&self) -> &rpc_impl::server::Server {
}); &self.rpc_server
let key = PeerRpcClientCtxKey(dst_peer_id, service_id, transact_id);
let _insert_ret = self
.client_resp_receivers
.insert(key.clone(), packet_sender);
let ret = f(client_transport).await;
self.client_resp_receivers.remove(&key);
ret
} }
pub fn my_peer_id(&self) -> PeerId { pub fn my_peer_id(&self) -> PeerId {
@@ -548,7 +119,7 @@ impl PeerRpcManager {
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use std::{pin::Pin, sync::Arc, time::Duration}; use std::{pin::Pin, sync::Arc};
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use tokio::sync::Mutex; use tokio::sync::Mutex;
@@ -559,31 +130,18 @@ pub mod tests {
peer_rpc::PeerRpcManager, peer_rpc::PeerRpcManager,
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear}, tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
}, },
proto::{
rpc_impl::RpcController,
tests::{GreetingClientFactory, GreetingServer, GreetingService, SayHelloRequest},
},
tunnel::{ tunnel::{
common::tests::wait_for_condition, packet_def::ZCPacket, ring::create_ring_tunnel_pair, packet_def::ZCPacket, ring::create_ring_tunnel_pair, Tunnel,
Tunnel, ZCPacketSink, ZCPacketStream, ZCPacketSink, ZCPacketStream,
}, },
}; };
use super::PeerRpcManagerTransport; use super::PeerRpcManagerTransport;
#[tarpc::service]
pub trait TestRpcService {
async fn hello(s: String) -> String;
}
#[derive(Clone)]
pub struct MockService {
pub prefix: String,
}
#[tarpc::server]
impl TestRpcService for MockService {
async fn hello(self, _: tarpc::context::Context, s: String) -> String {
format!("{} {}", self.prefix, s)
}
}
fn random_string(len: usize) -> String { fn random_string(len: usize) -> String {
use rand::distributions::Alphanumeric; use rand::distributions::Alphanumeric;
use rand::Rng; use rand::Rng;
@@ -595,6 +153,16 @@ pub mod tests {
String::from_utf8(s).unwrap() String::from_utf8(s).unwrap()
} }
pub fn register_service(rpc_mgr: &PeerRpcManager, domain: &str, delay_ms: u64, prefix: &str) {
rpc_mgr.rpc_server().registry().register(
GreetingServer::new(GreetingService {
delay_ms,
prefix: prefix.to_string(),
}),
domain,
);
}
#[tokio::test] #[tokio::test]
async fn peer_rpc_basic_test() { async fn peer_rpc_basic_test() {
struct MockTransport { struct MockTransport {
@@ -630,10 +198,7 @@ pub mod tests {
my_peer_id: new_peer_id(), my_peer_id: new_peer_id(),
}); });
server_rpc_mgr.run(); server_rpc_mgr.run();
let s = MockService { register_service(&server_rpc_mgr, "test", 0, "Hello");
prefix: "hello".to_owned(),
};
server_rpc_mgr.run_service(1, s.serve());
let client_rpc_mgr = PeerRpcManager::new(MockTransport { let client_rpc_mgr = PeerRpcManager::new(MockTransport {
sink: Arc::new(Mutex::new(stsr)), sink: Arc::new(Mutex::new(stsr)),
@@ -642,35 +207,27 @@ pub mod tests {
}); });
client_rpc_mgr.run(); client_rpc_mgr.run();
let stub = client_rpc_mgr
.rpc_client()
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());
let msg = random_string(8192); let msg = random_string(8192);
let ret = client_rpc_mgr let ret = stub
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async { .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); .await
let ret = c.hello(tarpc::context::current(), msg.clone()).await; .unwrap();
ret
})
.await;
println!("ret: {:?}", ret); println!("ret: {:?}", ret);
assert_eq!(ret.unwrap(), format!("hello {}", msg)); assert_eq!(ret.greeting, format!("Hello {}!", msg));
let msg = random_string(10); let msg = random_string(10);
let ret = client_rpc_mgr let ret = stub
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async { .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); .await
let ret = c.hello(tarpc::context::current(), msg.clone()).await; .unwrap();
ret
})
.await;
println!("ret: {:?}", ret); println!("ret: {:?}", ret);
assert_eq!(ret.unwrap(), format!("hello {}", msg)); assert_eq!(ret.greeting, format!("Hello {}!", msg));
wait_for_condition(
|| async { server_rpc_mgr.peer_rpc_endpoints.is_empty() },
Duration::from_secs(10),
)
.await;
} }
#[tokio::test] #[tokio::test]
@@ -680,6 +237,7 @@ pub mod tests {
let peer_mgr_c = create_mock_peer_manager().await; let peer_mgr_c = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await; connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await; connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone()) wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
.await .await
.unwrap(); .unwrap();
@@ -699,51 +257,42 @@ pub mod tests {
peer_mgr_b.my_peer_id() peer_mgr_b.my_peer_id()
); );
let s = MockService { register_service(&peer_mgr_b.get_peer_rpc_mgr(), "test", 0, "Hello");
prefix: "hello".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
let msg = random_string(16 * 1024); let msg = random_string(16 * 1024);
let ip_list = peer_mgr_a let stub = peer_mgr_a
.get_peer_rpc_mgr() .get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { .rpc_client()
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); .scoped_client::<GreetingClientFactory<RpcController>>(
let ret = c.hello(tarpc::context::current(), msg.clone()).await; peer_mgr_a.my_peer_id(),
ret peer_mgr_b.my_peer_id(),
}) "test".to_string(),
.await; );
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.unwrap(), format!("hello {}", msg)); let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
// call again // call again
let msg = random_string(16 * 1024); let msg = random_string(16 * 1024);
let ip_list = peer_mgr_a let ret = stub
.get_peer_rpc_mgr() .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { .await
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); .unwrap();
let ret = c.hello(tarpc::context::current(), msg.clone()).await; assert_eq!(ret.greeting, format!("Hello {}!", msg));
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.unwrap(), format!("hello {}", msg));
let msg = random_string(16 * 1024); let msg = random_string(16 * 1024);
let ip_list = peer_mgr_c let ret = stub
.get_peer_rpc_mgr() .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { .await
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); .unwrap();
let ret = c.hello(tarpc::context::current(), msg.clone()).await; assert_eq!(ret.greeting, format!("Hello {}!", msg));
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.unwrap(), format!("hello {}", msg));
} }
#[tokio::test] #[tokio::test]
async fn test_multi_service_with_peer_manager() { async fn test_multi_domain_with_peer_manager() {
let peer_mgr_a = create_mock_peer_manager().await; let peer_mgr_a = create_mock_peer_manager().await;
let peer_mgr_b = create_mock_peer_manager().await; let peer_mgr_b = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await; connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
@@ -757,42 +306,37 @@ pub mod tests {
peer_mgr_b.my_peer_id() peer_mgr_b.my_peer_id()
); );
let s = MockService { register_service(&peer_mgr_b.get_peer_rpc_mgr(), "test1", 0, "Hello");
prefix: "hello_a".to_owned(), register_service(&peer_mgr_b.get_peer_rpc_mgr(), "test2", 20000, "Hello2");
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve()); let stub1 = peer_mgr_a
let b = MockService { .get_peer_rpc_mgr()
prefix: "hello_b".to_owned(), .rpc_client()
}; .scoped_client::<GreetingClientFactory<RpcController>>(
peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve()); peer_mgr_a.my_peer_id(),
peer_mgr_b.my_peer_id(),
"test1".to_string(),
);
let stub2 = peer_mgr_a
.get_peer_rpc_mgr()
.rpc_client()
.scoped_client::<GreetingClientFactory<RpcController>>(
peer_mgr_a.my_peer_id(),
peer_mgr_b.my_peer_id(),
"test2".to_string(),
);
let msg = random_string(16 * 1024); let msg = random_string(16 * 1024);
let ip_list = peer_mgr_a let ret = stub1
.get_peer_rpc_mgr() .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { .await
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); .unwrap();
let ret = c.hello(tarpc::context::current(), msg.clone()).await; assert_eq!(ret.greeting, format!("Hello {}!", msg));
ret
})
.await;
assert_eq!(ip_list.unwrap(), format!("hello_a {}", msg));
let msg = random_string(16 * 1024); let ret = stub2
let ip_list = peer_mgr_a .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.get_peer_rpc_mgr()
.do_client_rpc_scoped(2, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
.await;
assert_eq!(ip_list.unwrap(), format!("hello_b {}", msg));
wait_for_condition(
|| async { peer_mgr_b.get_peer_rpc_mgr().peer_rpc_endpoints.is_empty() },
Duration::from_secs(10),
)
.await; .await;
assert!(ret.is_err() && ret.unwrap_err().to_string().contains("Timeout"));
} }
} }
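Condensed from the tests above, the new end-to-end calling pattern is: register a service under a domain name on the server's registry, then build a typed stub with `scoped_client` on the client side. A hedged sketch using the same names as the tests (the peer ids and the `?` error handling are illustrative):

// Server side: expose GreetingService under the "test" domain.
server_rpc_mgr.rpc_server().registry().register(
    GreetingServer::new(GreetingService {
        delay_ms: 0,
        prefix: "Hello".to_string(),
    }),
    "test",
);

// Client side: a stub bound to (src peer, dst peer, domain) behaves like a
// normal async trait object; transport details stay inside PeerRpcManager.
let stub = client_rpc_mgr.rpc_client().scoped_client::<GreetingClientFactory<RpcController>>(
    client_rpc_mgr.my_peer_id(),
    server_rpc_mgr.my_peer_id(),
    "test".to_string(),
);
let resp = stub
    .say_hello(RpcController {}, SayHelloRequest { name: "world".to_string() })
    .await?;
assert_eq!(resp.greeting, "Hello world!");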

View File

@@ -1,9 +1,6 @@
 use std::{net::Ipv4Addr, sync::Arc};
-use async_trait::async_trait;
-use tokio_util::bytes::Bytes;
-use crate::common::{error::Error, PeerId};
+use crate::common::PeerId;
 #[derive(Clone, Debug)]
 pub enum NextHopPolicy {
@@ -17,15 +14,9 @@ impl Default for NextHopPolicy {
     }
 }
-#[async_trait]
+#[async_trait::async_trait]
 pub trait RouteInterface {
     async fn list_peers(&self) -> Vec<PeerId>;
-    async fn send_route_packet(
-        &self,
-        msg: Bytes,
-        route_id: u8,
-        dst_peer_id: PeerId,
-    ) -> Result<(), Error>;
     fn my_peer_id(&self) -> PeerId;
 }
@@ -56,7 +47,7 @@ impl RouteCostCalculatorInterface for DefaultRouteCostCalculator {}
 pub type RouteCostCalculator = Box<dyn RouteCostCalculatorInterface>;
-#[async_trait]
+#[async_trait::async_trait]
 #[auto_impl::auto_impl(Box, Arc)]
 pub trait Route {
     async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
@@ -71,7 +62,7 @@ pub trait Route {
         self.get_next_hop(peer_id).await
     }
-    async fn list_routes(&self) -> Vec<crate::rpc::Route>;
+    async fn list_routes(&self) -> Vec<crate::proto::cli::Route>;
     async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
         None
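With `send_route_packet` gone, `RouteInterface` reduces to peer discovery plus identity, so an implementation no longer needs any packet plumbing; a minimal sketch (the struct name and fields are hypothetical):

struct StaticRouteInterface {
    my_peer_id: PeerId,
    peers: Vec<PeerId>,
}

#[async_trait::async_trait]
impl RouteInterface for StaticRouteInterface {
    async fn list_peers(&self) -> Vec<PeerId> {
        // A fixed peer list; a real implementation would query the peer manager.
        self.peers.clone()
    }

    fn my_peer_id(&self) -> PeerId {
        self.my_peer_id
    }
}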

View File

@@ -1,14 +1,17 @@
use std::sync::Arc; use std::sync::Arc;
use crate::rpc::{ use crate::proto::{
cli::PeerInfo, peer_manage_rpc_server::PeerManageRpc, DumpRouteRequest, DumpRouteResponse, cli::{
ListForeignNetworkRequest, ListForeignNetworkResponse, ListPeerRequest, ListPeerResponse, DumpRouteRequest, DumpRouteResponse, ListForeignNetworkRequest, ListForeignNetworkResponse,
ListRouteRequest, ListRouteResponse, ShowNodeInfoRequest, ShowNodeInfoResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse, PeerInfo,
PeerManageRpc, ShowNodeInfoRequest, ShowNodeInfoResponse,
},
rpc_types::{self, controller::BaseController},
}; };
use tonic::{Request, Response, Status};
use super::peer_manager::PeerManager; use super::peer_manager::PeerManager;
#[derive(Clone)]
pub struct PeerManagerRpcService { pub struct PeerManagerRpcService {
peer_manager: Arc<PeerManager>, peer_manager: Arc<PeerManager>,
} }
@@ -36,12 +39,14 @@ impl PeerManagerRpcService {
} }
} }
#[tonic::async_trait] #[async_trait::async_trait]
impl PeerManageRpc for PeerManagerRpcService { impl PeerManageRpc for PeerManagerRpcService {
type Controller = BaseController;
async fn list_peer( async fn list_peer(
&self, &self,
_request: Request<ListPeerRequest>, // Accept request of type HelloRequest _: BaseController,
) -> Result<Response<ListPeerResponse>, Status> { _request: ListPeerRequest, // Accept request of type HelloRequest
) -> Result<ListPeerResponse, rpc_types::error::Error> {
let mut reply = ListPeerResponse::default(); let mut reply = ListPeerResponse::default();
let peers = self.list_peers().await; let peers = self.list_peers().await;
@@ -49,45 +54,49 @@ impl PeerManageRpc for PeerManagerRpcService {
reply.peer_infos.push(peer); reply.peer_infos.push(peer);
} }
Ok(Response::new(reply)) Ok(reply)
} }
async fn list_route( async fn list_route(
&self, &self,
_request: Request<ListRouteRequest>, // Accept request of type HelloRequest _: BaseController,
) -> Result<Response<ListRouteResponse>, Status> { _request: ListRouteRequest, // Accept request of type HelloRequest
) -> Result<ListRouteResponse, rpc_types::error::Error> {
let mut reply = ListRouteResponse::default(); let mut reply = ListRouteResponse::default();
reply.routes = self.peer_manager.list_routes().await; reply.routes = self.peer_manager.list_routes().await;
Ok(Response::new(reply)) Ok(reply)
} }
async fn dump_route( async fn dump_route(
&self, &self,
_request: Request<DumpRouteRequest>, // Accept request of type HelloRequest _: BaseController,
) -> Result<Response<DumpRouteResponse>, Status> { _request: DumpRouteRequest, // Accept request of type HelloRequest
) -> Result<DumpRouteResponse, rpc_types::error::Error> {
let mut reply = DumpRouteResponse::default(); let mut reply = DumpRouteResponse::default();
reply.result = self.peer_manager.dump_route().await; reply.result = self.peer_manager.dump_route().await;
Ok(Response::new(reply)) Ok(reply)
} }
async fn list_foreign_network( async fn list_foreign_network(
&self, &self,
_request: Request<ListForeignNetworkRequest>, // Accept request of type HelloRequest _: BaseController,
) -> Result<Response<ListForeignNetworkResponse>, Status> { _request: ListForeignNetworkRequest, // Accept request of type HelloRequest
) -> Result<ListForeignNetworkResponse, rpc_types::error::Error> {
let reply = self let reply = self
.peer_manager .peer_manager
.get_foreign_network_manager() .get_foreign_network_manager()
.list_foreign_networks() .list_foreign_networks()
.await; .await;
Ok(Response::new(reply)) Ok(reply)
} }
async fn show_node_info( async fn show_node_info(
&self, &self,
_request: Request<ShowNodeInfoRequest>, // Accept request of type HelloRequest _: BaseController,
) -> Result<Response<ShowNodeInfoResponse>, Status> { _request: ShowNodeInfoRequest, // Accept request of type HelloRequest
Ok(Response::new(ShowNodeInfoResponse { ) -> Result<ShowNodeInfoResponse, rpc_types::error::Error> {
Ok(ShowNodeInfoResponse {
node_info: Some(self.peer_manager.get_my_info()), node_info: Some(self.peer_manager.get_my_info()),
})) })
} }
} }
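To actually serve these methods, the service is wrapped in the generated `PeerManageRpcServer` type and placed into an rpc server registry, mirroring the `register_service` helper seen in the peer_rpc tests; this is only a hedged sketch, and both the constructor name and the empty domain string below are assumptions:

// Hedged sketch: expose PeerManagerRpcService over the new rpc framework.
let service = PeerManagerRpcService::new(peer_manager.clone()); // constructor name assumed
peer_manager
    .get_peer_rpc_mgr()
    .rpc_server()
    .registry()
    .register(PeerManageRpcServer::new(service), ""); // domain name assumed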

View File

@@ -1,4 +1,7 @@
syntax = "proto3"; syntax = "proto3";
import "common.proto";
package cli; package cli;
message Status { message Status {
@@ -16,18 +19,12 @@ message PeerConnStats {
uint64 latency_us = 5; uint64 latency_us = 5;
} }
message TunnelInfo {
string tunnel_type = 1;
string local_addr = 2;
string remote_addr = 3;
}
message PeerConnInfo { message PeerConnInfo {
string conn_id = 1; string conn_id = 1;
uint32 my_peer_id = 2; uint32 my_peer_id = 2;
uint32 peer_id = 3; uint32 peer_id = 3;
repeated string features = 4; repeated string features = 4;
TunnelInfo tunnel = 5; common.TunnelInfo tunnel = 5;
PeerConnStats stats = 6; PeerConnStats stats = 6;
float loss_rate = 7; float loss_rate = 7;
bool is_client = 8; bool is_client = 8;
@@ -46,27 +43,6 @@ message ListPeerResponse {
NodeInfo my_info = 2; NodeInfo my_info = 2;
} }
enum NatType {
// has NAT; but own a single public IP, port is not changed
Unknown = 0;
OpenInternet = 1;
NoPAT = 2;
FullCone = 3;
Restricted = 4;
PortRestricted = 5;
Symmetric = 6;
SymUdpFirewall = 7;
}
message StunInfo {
NatType udp_nat_type = 1;
NatType tcp_nat_type = 2;
int64 last_update_time = 3;
repeated string public_ip = 4;
uint32 min_port = 5;
uint32 max_port = 6;
}
message Route { message Route {
uint32 peer_id = 1; uint32 peer_id = 1;
string ipv4_addr = 2; string ipv4_addr = 2;
@@ -74,7 +50,7 @@ message Route {
int32 cost = 4; int32 cost = 4;
repeated string proxy_cidrs = 5; repeated string proxy_cidrs = 5;
string hostname = 6; string hostname = 6;
StunInfo stun_info = 7; common.StunInfo stun_info = 7;
string inst_id = 8; string inst_id = 8;
string version = 9; string version = 9;
} }
@@ -84,7 +60,7 @@ message NodeInfo {
string ipv4_addr = 2; string ipv4_addr = 2;
repeated string proxy_cidrs = 3; repeated string proxy_cidrs = 3;
string hostname = 4; string hostname = 4;
StunInfo stun_info = 5; common.StunInfo stun_info = 5;
string inst_id = 6; string inst_id = 6;
repeated string listeners = 7; repeated string listeners = 7;
string config = 8; string config = 8;
@@ -127,7 +103,7 @@ enum ConnectorStatus {
} }
message Connector { message Connector {
string url = 1; common.Url url = 1;
ConnectorStatus status = 2; ConnectorStatus status = 2;
} }
@@ -142,7 +118,7 @@ enum ConnectorManageAction {
message ManageConnectorRequest { message ManageConnectorRequest {
ConnectorManageAction action = 1; ConnectorManageAction action = 1;
string url = 2; common.Url url = 2;
} }
message ManageConnectorResponse {} message ManageConnectorResponse {}
@@ -152,23 +128,6 @@ service ConnectorManageRpc {
rpc ManageConnector(ManageConnectorRequest) returns (ManageConnectorResponse); rpc ManageConnector(ManageConnectorRequest) returns (ManageConnectorResponse);
} }
message DirectConnectedPeerInfo { int32 latency_ms = 1; }
message PeerInfoForGlobalMap {
map<uint32, DirectConnectedPeerInfo> direct_peers = 1;
}
message GetGlobalPeerMapRequest {}
message GetGlobalPeerMapResponse {
map<uint32, PeerInfoForGlobalMap> global_peer_map = 1;
}
service PeerCenterRpc {
rpc GetGlobalPeerMap(GetGlobalPeerMapRequest)
returns (GetGlobalPeerMapResponse);
}
message VpnPortalInfo { message VpnPortalInfo {
string vpn_type = 1; string vpn_type = 1;
string client_config = 2; string client_config = 2;
@@ -182,24 +141,3 @@ service VpnPortalRpc {
rpc GetVpnPortalInfo(GetVpnPortalInfoRequest) rpc GetVpnPortalInfo(GetVpnPortalInfoRequest)
returns (GetVpnPortalInfoResponse); returns (GetVpnPortalInfoResponse);
} }
message HandshakeRequest {
uint32 magic = 1;
uint32 my_peer_id = 2;
uint32 version = 3;
repeated string features = 4;
string network_name = 5;
bytes network_secret_digrest = 6;
}
message TaRpcPacket {
uint32 from_peer = 1;
uint32 to_peer = 2;
uint32 service_id = 3;
uint32 transact_id = 4;
bool is_req = 5;
bytes content = 6;
uint32 total_pieces = 7;
uint32 piece_idx = 8;
}

View File

@@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/cli.rs"));

View File

@@ -0,0 +1,92 @@
syntax = "proto3";
import "error.proto";
package common;
message RpcDescriptor {
  // allow the same service to be registered multiple times under different domains
string domain_name = 1;
string proto_name = 2;
string service_name = 3;
uint32 method_index = 4;
}
message RpcRequest {
RpcDescriptor descriptor = 1;
bytes request = 2;
int32 timeout_ms = 3;
}
message RpcResponse {
bytes response = 1;
error.Error error = 2;
uint64 runtime_us = 3;
}
message RpcPacket {
uint32 from_peer = 1;
uint32 to_peer = 2;
int64 transaction_id = 3;
RpcDescriptor descriptor = 4;
bytes body = 5;
bool is_request = 6;
uint32 total_pieces = 7;
uint32 piece_idx = 8;
int32 trace_id = 9;
}
message UUID {
uint64 high = 1;
uint64 low = 2;
}
enum NatType {
// has NAT; but own a single public IP, port is not changed
Unknown = 0;
OpenInternet = 1;
NoPAT = 2;
FullCone = 3;
Restricted = 4;
PortRestricted = 5;
Symmetric = 6;
SymUdpFirewall = 7;
}
message Ipv4Addr { uint32 addr = 1; }
message Ipv6Addr {
uint64 high = 1;
uint64 low = 2;
}
message Url { string url = 1; }
message SocketAddr {
oneof ip {
Ipv4Addr ipv4 = 1;
Ipv6Addr ipv6 = 2;
};
uint32 port = 3;
}
message TunnelInfo {
string tunnel_type = 1;
common.Url local_addr = 2;
common.Url remote_addr = 3;
}
message StunInfo {
NatType udp_nat_type = 1;
NatType tcp_nat_type = 2;
int64 last_update_time = 3;
repeated string public_ip = 4;
uint32 min_port = 5;
uint32 max_port = 6;
}
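On the wire, a call is an `RpcRequest` serialized into the `body` of one or more `RpcPacket`s, fragmented via `total_pieces`/`piece_idx` and reassembled by the receiver. A rough prost-level sketch of single-piece framing, where all field values are illustrative and `encoded_input` is a placeholder for the prost-encoded method input:

use prost::Message;

// Identify the target method...
let desc = RpcDescriptor {
    domain_name: "test".to_string(),
    proto_name: "peer_rpc".to_string(),
    service_name: "OspfRouteRpc".to_string(),
    method_index: 1,
};
// ...wrap the encoded method input together with a timeout...
let req = RpcRequest {
    descriptor: Some(desc.clone()),
    request: encoded_input, // Vec<u8>, placeholder
    timeout_ms: 5000,
};
// ...and carry it in a single-piece RpcPacket addressed to the remote peer.
let packet = RpcPacket {
    from_peer: 1,
    to_peer: 2,
    transaction_id: 42,
    descriptor: Some(desc),
    body: req.encode_to_vec(),
    is_request: true,
    total_pieces: 1,
    piece_idx: 0,
    trace_id: 0,
};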

View File

@@ -0,0 +1,131 @@
use std::{fmt::Display, str::FromStr};
include!(concat!(env!("OUT_DIR"), "/common.rs"));
impl From<uuid::Uuid> for Uuid {
fn from(uuid: uuid::Uuid) -> Self {
let (high, low) = uuid.as_u64_pair();
Uuid { low, high }
}
}
impl From<Uuid> for uuid::Uuid {
fn from(uuid: Uuid) -> Self {
uuid::Uuid::from_u64_pair(uuid.high, uuid.low)
}
}
impl Display for Uuid {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", uuid::Uuid::from(self.clone()))
}
}
impl From<std::net::Ipv4Addr> for Ipv4Addr {
fn from(value: std::net::Ipv4Addr) -> Self {
Self {
addr: u32::from_be_bytes(value.octets()),
}
}
}
impl From<Ipv4Addr> for std::net::Ipv4Addr {
fn from(value: Ipv4Addr) -> Self {
std::net::Ipv4Addr::from(value.addr)
}
}
impl ToString for Ipv4Addr {
fn to_string(&self) -> String {
std::net::Ipv4Addr::from(self.addr).to_string()
}
}
impl From<std::net::Ipv6Addr> for Ipv6Addr {
fn from(value: std::net::Ipv6Addr) -> Self {
let b = value.octets();
Self {
low: u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]),
high: u64::from_be_bytes([b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]]),
}
}
}
impl From<Ipv6Addr> for std::net::Ipv6Addr {
fn from(value: Ipv6Addr) -> Self {
let low = value.low.to_be_bytes();
let high = value.high.to_be_bytes();
std::net::Ipv6Addr::from([
low[0], low[1], low[2], low[3], low[4], low[5], low[6], low[7], high[0], high[1],
high[2], high[3], high[4], high[5], high[6], high[7],
])
}
}
impl ToString for Ipv6Addr {
fn to_string(&self) -> String {
std::net::Ipv6Addr::from(self.clone()).to_string()
}
}
impl From<url::Url> for Url {
fn from(value: url::Url) -> Self {
Url {
url: value.to_string(),
}
}
}
impl From<Url> for url::Url {
fn from(value: Url) -> Self {
url::Url::parse(&value.url).unwrap()
}
}
impl FromStr for Url {
type Err = url::ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Url {
url: s.parse::<url::Url>()?.to_string(),
})
}
}
impl Display for Url {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.url)
}
}
impl From<std::net::SocketAddr> for SocketAddr {
fn from(value: std::net::SocketAddr) -> Self {
match value {
std::net::SocketAddr::V4(v4) => SocketAddr {
ip: Some(socket_addr::Ip::Ipv4(v4.ip().clone().into())),
port: v4.port() as u32,
},
std::net::SocketAddr::V6(v6) => SocketAddr {
ip: Some(socket_addr::Ip::Ipv6(v6.ip().clone().into())),
port: v6.port() as u32,
},
}
}
}
impl From<SocketAddr> for std::net::SocketAddr {
fn from(value: SocketAddr) -> Self {
match value.ip.unwrap() {
socket_addr::Ip::Ipv4(ip) => std::net::SocketAddr::V4(std::net::SocketAddrV4::new(
std::net::Ipv4Addr::from(ip),
value.port as u16,
)),
socket_addr::Ip::Ipv6(ip) => std::net::SocketAddr::V6(std::net::SocketAddrV6::new(
std::net::Ipv6Addr::from(ip),
value.port as u16,
0,
0,
)),
}
}
}
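These impls are intended as lossless bridges to the std, uuid, and url types; a short usage sketch:

// std -> proto -> std round trips for the helper messages above.
let ip: Ipv4Addr = std::net::Ipv4Addr::new(10, 0, 0, 1).into();
assert_eq!(std::net::Ipv4Addr::from(ip), std::net::Ipv4Addr::new(10, 0, 0, 1));

let sa: SocketAddr = "10.0.0.1:8080".parse::<std::net::SocketAddr>().unwrap().into();
assert_eq!(std::net::SocketAddr::from(sa).port(), 8080);

let id: Uuid = uuid::Uuid::new_v4().into();
let _round_trip: uuid::Uuid = id.into();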

View File

@@ -0,0 +1,34 @@
syntax = "proto3";
package error;
message OtherError { string error_message = 1; }
message InvalidMethodIndex {
string service_name = 1;
uint32 method_index = 2;
}
message InvalidService { string service_name = 1; }
message ProstDecodeError {}
message ProstEncodeError {}
message ExecuteError { string error_message = 1; }
message MalformatRpcPacket { string error_message = 1; }
message Timeout { string error_message = 1; }
message Error {
oneof error {
OtherError other_error = 1;
InvalidMethodIndex invalid_method_index = 2;
InvalidService invalid_service = 3;
ProstDecodeError prost_decode_error = 4;
ProstEncodeError prost_encode_error = 5;
ExecuteError execute_error = 6;
MalformatRpcPacket malformat_rpc_packet = 7;
Timeout timeout = 8;
}
}

View File

@@ -0,0 +1,84 @@
use prost::DecodeError;
use super::rpc_types;
include!(concat!(env!("OUT_DIR"), "/error.rs"));
impl From<&rpc_types::error::Error> for Error {
fn from(e: &rpc_types::error::Error) -> Self {
use super::error::error::Error as ProtoError;
match e {
rpc_types::error::Error::ExecutionError(e) => Self {
error: Some(ProtoError::ExecuteError(ExecuteError {
error_message: e.to_string(),
})),
},
rpc_types::error::Error::DecodeError(_) => Self {
error: Some(ProtoError::ProstDecodeError(ProstDecodeError {})),
},
rpc_types::error::Error::EncodeError(_) => Self {
error: Some(ProtoError::ProstEncodeError(ProstEncodeError {})),
},
rpc_types::error::Error::InvalidMethodIndex(m, s) => Self {
error: Some(ProtoError::InvalidMethodIndex(InvalidMethodIndex {
method_index: *m as u32,
service_name: s.to_string(),
})),
},
rpc_types::error::Error::InvalidServiceKey(s, _) => Self {
error: Some(ProtoError::InvalidService(InvalidService {
service_name: s.to_string(),
})),
},
rpc_types::error::Error::MalformatRpcPacket(e) => Self {
error: Some(ProtoError::MalformatRpcPacket(MalformatRpcPacket {
error_message: e.to_string(),
})),
},
rpc_types::error::Error::Timeout(e) => Self {
error: Some(ProtoError::Timeout(Timeout {
error_message: e.to_string(),
})),
},
#[allow(unreachable_patterns)]
e => Self {
error: Some(ProtoError::OtherError(OtherError {
error_message: e.to_string(),
})),
},
}
}
}
impl From<&Error> for rpc_types::error::Error {
fn from(e: &Error) -> Self {
use super::error::error::Error as ProtoError;
match &e.error {
Some(ProtoError::ExecuteError(e)) => {
Self::ExecutionError(anyhow::anyhow!(e.error_message.clone()))
}
Some(ProtoError::ProstDecodeError(_)) => {
Self::DecodeError(DecodeError::new("decode error"))
}
Some(ProtoError::ProstEncodeError(_)) => {
Self::DecodeError(DecodeError::new("encode error"))
}
Some(ProtoError::InvalidMethodIndex(e)) => {
Self::InvalidMethodIndex(e.method_index as u8, e.service_name.clone())
}
Some(ProtoError::InvalidService(e)) => {
Self::InvalidServiceKey(e.service_name.clone(), "".to_string())
}
Some(ProtoError::MalformatRpcPacket(e)) => {
Self::MalformatRpcPacket(e.error_message.clone())
}
Some(ProtoError::Timeout(e)) => {
Self::ExecutionError(anyhow::anyhow!(e.error_message.clone()))
}
Some(ProtoError::OtherError(e)) => {
Self::ExecutionError(anyhow::anyhow!(e.error_message.clone()))
}
None => Self::ExecutionError(anyhow::anyhow!("unknown error {:?}", e)),
}
}
}
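At the call boundary these conversions run in both directions: the server side stores an `rpc_types` error into `RpcResponse.error`, and the client maps it back into an `rpc_types` error. A small sketch:

use crate::proto::{error, rpc_types};

// Service-side failure -> wire error (carried in RpcResponse.error) -> client-side error.
let app_err = rpc_types::error::Error::ExecutionError(anyhow::anyhow!("disk full"));
let wire: error::Error = (&app_err).into();
let back: rpc_types::error::Error = (&wire).into();
assert!(matches!(back, rpc_types::error::Error::ExecutionError(_)));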

View File

@@ -0,0 +1,9 @@
pub mod rpc_impl;
pub mod rpc_types;
pub mod cli;
pub mod common;
pub mod error;
pub mod peer_rpc;
pub mod tests;

View File

@@ -0,0 +1,129 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
import "common.proto";
package peer_rpc;
message RoutePeerInfo {
// means next hop in route table.
uint32 peer_id = 1;
common.UUID inst_id = 2;
uint32 cost = 3;
optional common.Ipv4Addr ipv4_addr = 4;
repeated string proxy_cidrs = 5;
optional string hostname = 6;
common.NatType udp_stun_info = 7;
google.protobuf.Timestamp last_update = 8;
uint32 version = 9;
}
message PeerIdVersion {
uint32 peer_id = 1;
uint32 version = 2;
}
message RouteConnBitmap {
repeated PeerIdVersion peer_ids = 1;
bytes bitmap = 2;
}
message RoutePeerInfos { repeated RoutePeerInfo items = 1; }
message SyncRouteInfoRequest {
uint32 my_peer_id = 1;
uint64 my_session_id = 2;
bool is_initiator = 3;
RoutePeerInfos peer_infos = 4;
RouteConnBitmap conn_bitmap = 5;
}
enum SyncRouteInfoError {
DuplicatePeerId = 0;
Stopped = 1;
}
message SyncRouteInfoResponse {
bool is_initiator = 1;
uint64 session_id = 2;
optional SyncRouteInfoError error = 3;
}
service OspfRouteRpc {
  // Syncs route info (peer infos and the connection bitmap) with the remote peer.
rpc SyncRouteInfo(SyncRouteInfoRequest) returns (SyncRouteInfoResponse);
}
message GetIpListRequest {}
message GetIpListResponse {
common.Ipv4Addr public_ipv4 = 1;
repeated common.Ipv4Addr interface_ipv4s = 2;
common.Ipv6Addr public_ipv6 = 3;
repeated common.Ipv6Addr interface_ipv6s = 4;
repeated common.Url listeners = 5;
}
service DirectConnectorRpc {
rpc GetIpList(GetIpListRequest) returns (GetIpListResponse);
}
message TryPunchHoleRequest { common.SocketAddr local_mapped_addr = 1; }
message TryPunchHoleResponse { common.SocketAddr remote_mapped_addr = 1; }
message TryPunchSymmetricRequest {
common.SocketAddr listener_addr = 1;
uint32 port = 2;
repeated common.Ipv4Addr public_ips = 3;
uint32 min_port = 4;
uint32 max_port = 5;
uint32 transaction_id = 6;
uint32 round = 7;
uint32 last_port_index = 8;
}
message TryPunchSymmetricResponse { uint32 last_port_index = 1; }
service UdpHolePunchRpc {
rpc TryPunchHole(TryPunchHoleRequest) returns (TryPunchHoleResponse);
rpc TryPunchSymmetric(TryPunchSymmetricRequest)
returns (TryPunchSymmetricResponse);
}
message DirectConnectedPeerInfo { int32 latency_ms = 1; }
message PeerInfoForGlobalMap {
map<uint32, DirectConnectedPeerInfo> direct_peers = 1;
}
message ReportPeersRequest {
uint32 my_peer_id = 1;
PeerInfoForGlobalMap peer_infos = 2;
}
message ReportPeersResponse {}
message GlobalPeerMap { map<uint32, PeerInfoForGlobalMap> map = 1; }
message GetGlobalPeerMapRequest { uint64 digest = 1; }
message GetGlobalPeerMapResponse {
map<uint32, PeerInfoForGlobalMap> global_peer_map = 1;
optional uint64 digest = 2;
}
service PeerCenterRpc {
rpc ReportPeers(ReportPeersRequest) returns (ReportPeersResponse);
rpc GetGlobalPeerMap(GetGlobalPeerMapRequest)
returns (GetGlobalPeerMapResponse);
}
message HandshakeRequest {
uint32 magic = 1;
uint32 my_peer_id = 2;
uint32 version = 3;
repeated string features = 4;
string network_name = 5;
bytes network_secret_digrest = 6;
}
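All of these services are consumed through the same generated `*ClientFactory` path as the greeting example; a hedged sketch of calling `DirectConnectorRpc` from a peer manager, where `dst_peer_id` is a placeholder, the empty domain string for built-in services is an assumption, and error handling is illustrative:

let stub = peer_mgr
    .get_peer_rpc_mgr()
    .rpc_client()
    .scoped_client::<DirectConnectorRpcClientFactory<RpcController>>(
        peer_mgr.my_peer_id(),
        dst_peer_id,
        "".to_string(), // domain; assumed for built-in services
    );
let ip_list = stub
    .get_ip_list(RpcController {}, GetIpListRequest {})
    .await?;
println!("remote listeners: {:?}", ip_list.listeners);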

View File

@@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/peer_rpc.rs"));

View File

@@ -0,0 +1,8 @@
[package]
name = "rpc_build"
version = "0.1.0"
edition = "2021"
[dependencies]
heck = "0.5"
prost-build = "0.13"

View File

@@ -0,0 +1,383 @@
extern crate heck;
extern crate prost_build;
use std::fmt;
const NAMESPACE: &str = "crate::proto::rpc_types";
/// The service generator to be used with `prost-build` to generate client and server
/// implementations for the `crate::proto::rpc_types` RPC framework.
///
/// See the crate-level documentation for more info.
#[allow(missing_copy_implementations)]
#[derive(Clone, Debug)]
pub struct ServiceGenerator {
_private: (),
}
impl ServiceGenerator {
/// Create a new `ServiceGenerator` instance with the default options set.
pub fn new() -> ServiceGenerator {
ServiceGenerator { _private: () }
}
}
impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) {
use std::fmt::Write;
let descriptor_name = format!("{}Descriptor", service.name);
let server_name = format!("{}Server", service.name);
let client_name = format!("{}Client", service.name);
let method_descriptor_name = format!("{}MethodDescriptor", service.name);
let mut trait_methods = String::new();
let mut enum_methods = String::new();
let mut list_enum_methods = String::new();
let mut client_methods = String::new();
let mut client_own_methods = String::new();
let mut match_name_methods = String::new();
let mut match_proto_name_methods = String::new();
let mut match_input_type_methods = String::new();
let mut match_input_proto_type_methods = String::new();
let mut match_output_type_methods = String::new();
let mut match_output_proto_type_methods = String::new();
let mut match_handle_methods = String::new();
let mut match_method_try_from = String::new();
for (idx, method) in service.methods.iter().enumerate() {
assert!(
!method.client_streaming,
"Client streaming not yet supported for method {}",
method.proto_name
);
assert!(
!method.server_streaming,
"Server streaming not yet supported for method {}",
method.proto_name
);
ServiceGenerator::write_comments(&mut trait_methods, 4, &method.comments).unwrap();
writeln!(
trait_methods,
r#" async fn {name}(&self, ctrl: Self::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}>;"#,
name = method.name,
input_type = method.input_type,
output_type = method.output_type,
namespace = NAMESPACE,
)
.unwrap();
ServiceGenerator::write_comments(&mut enum_methods, 4, &method.comments).unwrap();
writeln!(
enum_methods,
" {name} = {index},",
name = method.proto_name,
index = format!("{}", idx + 1)
)
.unwrap();
writeln!(
match_method_try_from,
" {index} => Ok({service_name}MethodDescriptor::{name}),",
service_name = service.name,
name = method.proto_name,
index = format!("{}", idx + 1),
)
.unwrap();
writeln!(
list_enum_methods,
" {service_name}MethodDescriptor::{name},",
service_name = service.name,
name = method.proto_name
)
.unwrap();
writeln!(
client_methods,
r#" async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{client_name}::{name}_inner(self.0.clone(), ctrl, input).await
}}"#,
name = method.name,
input_type = method.input_type,
output_type = method.output_type,
client_name = format!("{}Client", service.name),
namespace = NAMESPACE,
)
.unwrap();
writeln!(
client_own_methods,
r#" async fn {name}_inner(handler: H, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{namespace}::__rt::call_method(handler, ctrl, {method_descriptor_name}::{proto_name}, input).await
}}"#,
name = method.name,
method_descriptor_name = method_descriptor_name,
proto_name = method.proto_name,
input_type = method.input_type,
output_type = method.output_type,
namespace = NAMESPACE,
).unwrap();
let case = format!(
" {service_name}MethodDescriptor::{proto_name} => ",
service_name = service.name,
proto_name = method.proto_name
);
writeln!(match_name_methods, "{}{:?},", case, method.name).unwrap();
writeln!(match_proto_name_methods, "{}{:?},", case, method.proto_name).unwrap();
writeln!(
match_input_type_methods,
"{}::std::any::TypeId::of::<{}>(),",
case, method.input_type
)
.unwrap();
writeln!(
match_input_proto_type_methods,
"{}{:?},",
case, method.input_proto_type
)
.unwrap();
writeln!(
match_output_type_methods,
"{}::std::any::TypeId::of::<{}>(),",
case, method.output_type
)
.unwrap();
writeln!(
match_output_proto_type_methods,
"{}{:?},",
case, method.output_proto_type
)
.unwrap();
write!(
match_handle_methods,
r#"{} {{
let decoded: {input_type} = {namespace}::__rt::decode(input)?;
let ret = service.{name}(ctrl, decoded).await?;
{namespace}::__rt::encode(ret)
}}
"#,
case,
input_type = method.input_type,
name = method.name,
namespace = NAMESPACE,
)
.unwrap();
}
ServiceGenerator::write_comments(&mut buf, 0, &service.comments).unwrap();
write!(
buf,
r#"
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait {name} {{
type Controller: {namespace}::controller::Controller;
{trait_methods}
}}
/// A service descriptor for a `{name}`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Default)]
pub struct {descriptor_name};
/// Methods available on a `{name}`.
///
/// This can be used as a key when routing requests for servers/clients of a `{name}`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum {method_descriptor_name} {{
{enum_methods}
}}
impl std::convert::TryFrom<u8> for {method_descriptor_name} {{
type Error = {namespace}::error::Error;
fn try_from(value: u8) -> {namespace}::error::Result<Self> {{
match value {{
{match_method_try_from}
_ => Err({namespace}::error::Error::InvalidMethodIndex(value, "{name}".to_string())),
}}
}}
}}
/// A client for a `{name}`.
///
/// This implements the `{name}` trait by dispatching all method calls to the supplied `Handler`.
#[derive(Clone, Debug)]
pub struct {client_name}<H>(H) where H: {namespace}::handler::Handler;
impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
/// Creates a new client instance that delegates all method calls to the supplied handler.
pub fn new(handler: H) -> {client_name}<H> {{
{client_name}(handler)
}}
}}
impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
{client_own_methods}
}}
#[async_trait::async_trait]
impl<H> {name} for {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
type Controller = H::Controller;
{client_methods}
}}
pub struct {client_name}Factory<C: {namespace}::controller::Controller>(std::marker::PhantomData<C>);
impl<C: {namespace}::controller::Controller> Clone for {client_name}Factory<C> {{
fn clone(&self) -> Self {{
Self(std::marker::PhantomData)
}}
}}
impl<C> {namespace}::__rt::RpcClientFactory for {client_name}Factory<C> where C: {namespace}::controller::Controller {{
type Descriptor = {descriptor_name};
type ClientImpl = Box<dyn {name}<Controller = C> + Send + 'static>;
type Controller = C;
fn new(handler: impl {namespace}::handler::Handler<Descriptor = Self::Descriptor, Controller = Self::Controller>) -> Self::ClientImpl {{
Box::new({client_name}::new(handler))
}}
}}
/// A server for a `{name}`.
///
/// This implements the `Server` trait by handling requests and dispatch them to methods on the
/// supplied `{name}`.
#[derive(Clone, Debug)]
pub struct {server_name}<A>(A) where A: {name} + Clone + Send + 'static;
impl<A> {server_name}<A> where A: {name} + Clone + Send + 'static {{
/// Creates a new server instance that dispatches all calls to the supplied service.
pub fn new(service: A) -> {server_name}<A> {{
{server_name}(service)
}}
async fn call_inner(
service: A,
method: {method_descriptor_name},
ctrl: A::Controller,
input: ::bytes::Bytes)
-> {namespace}::error::Result<::bytes::Bytes> {{
match method {{
{match_handle_methods}
}}
}}
}}
impl {namespace}::descriptor::ServiceDescriptor for {descriptor_name} {{
type Method = {method_descriptor_name};
fn name(&self) -> &'static str {{ {name:?} }}
fn proto_name(&self) -> &'static str {{ {proto_name:?} }}
fn package(&self) -> &'static str {{ {package:?} }}
fn methods(&self) -> &'static [Self::Method] {{
&[ {list_enum_methods} ]
}}
}}
#[async_trait::async_trait]
impl<A> {namespace}::handler::Handler for {server_name}<A>
where
A: {name} + Clone + Send + Sync + 'static {{
type Descriptor = {descriptor_name};
type Controller = A::Controller;
async fn call(
&self,
ctrl: A::Controller,
method: {method_descriptor_name},
input: ::bytes::Bytes)
-> {namespace}::error::Result<::bytes::Bytes> {{
{server_name}::call_inner(self.0.clone(), method, ctrl, input).await
}}
}}
impl {namespace}::descriptor::MethodDescriptor for {method_descriptor_name} {{
fn name(&self) -> &'static str {{
match *self {{
{match_name_methods}
}}
}}
fn proto_name(&self) -> &'static str {{
match *self {{
{match_proto_name_methods}
}}
}}
fn input_type(&self) -> ::std::any::TypeId {{
match *self {{
{match_input_type_methods}
}}
}}
fn input_proto_type(&self) -> &'static str {{
match *self {{
{match_input_proto_type_methods}
}}
}}
fn output_type(&self) -> ::std::any::TypeId {{
match *self {{
{match_output_type_methods}
}}
}}
fn output_proto_type(&self) -> &'static str {{
match *self {{
{match_output_proto_type_methods}
}}
}}
fn index(&self) -> u8 {{
*self as u8
}}
}}
"#,
name = service.name,
descriptor_name = descriptor_name,
server_name = server_name,
client_name = client_name,
method_descriptor_name = method_descriptor_name,
proto_name = service.proto_name,
package = service.package,
trait_methods = trait_methods,
enum_methods = enum_methods,
list_enum_methods = list_enum_methods,
client_own_methods = client_own_methods,
client_methods = client_methods,
match_name_methods = match_name_methods,
match_proto_name_methods = match_proto_name_methods,
match_input_type_methods = match_input_type_methods,
match_input_proto_type_methods = match_input_proto_type_methods,
match_output_type_methods = match_output_type_methods,
match_output_proto_type_methods = match_output_proto_type_methods,
match_handle_methods = match_handle_methods,
namespace = NAMESPACE,
).unwrap();
}
}
impl ServiceGenerator {
fn write_comments<W>(
mut write: W,
indent: usize,
comments: &prost_build::Comments,
) -> fmt::Result
where
W: fmt::Write,
{
for comment in &comments.leading {
for line in comment.lines().filter(|s| !s.is_empty()) {
writeln!(write, "{}///{}", " ".repeat(indent), line)?;
}
}
Ok(())
}
}
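This generator plugs into `prost-build` as a custom `ServiceGenerator`; a hedged sketch of the corresponding `build.rs` wiring in the main crate (the proto file paths are assumptions):

// build.rs (sketch)
fn main() {
    prost_build::Config::new()
        // Use rpc_build's generator instead of prost's default (no services).
        .service_generator(Box::new(rpc_build::ServiceGenerator::new()))
        .compile_protos(
            &[
                "src/proto/common.proto",
                "src/proto/error.proto",
                "src/proto/cli.proto",
                "src/proto/peer_rpc.proto",
            ],
            &["src/proto/"],
        )
        .unwrap();
}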

View File

@@ -0,0 +1,240 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use bytes::Bytes;
use dashmap::DashMap;
use prost::Message;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use tokio::time::timeout;
use tokio_stream::StreamExt;
use crate::common::PeerId;
use crate::defer;
use crate::proto::common::{RpcDescriptor, RpcPacket, RpcRequest, RpcResponse};
use crate::proto::rpc_impl::packet::build_rpc_packet;
use crate::proto::rpc_types::controller::Controller;
use crate::proto::rpc_types::descriptor::MethodDescriptor;
use crate::proto::rpc_types::{
__rt::RpcClientFactory, descriptor::ServiceDescriptor, handler::Handler,
};
use crate::proto::rpc_types::error::Result;
use crate::tunnel::mpsc::{MpscTunnel, MpscTunnelSender};
use crate::tunnel::packet_def::ZCPacket;
use crate::tunnel::ring::create_ring_tunnel_pair;
use crate::tunnel::{Tunnel, TunnelError, ZCPacketStream};
use super::packet::PacketMerger;
use super::{RpcTransactId, Transport};
static CUR_TID: once_cell::sync::Lazy<atomic_shim::AtomicI64> =
once_cell::sync::Lazy::new(|| atomic_shim::AtomicI64::new(rand::random()));
type RpcPacketSender = mpsc::UnboundedSender<RpcPacket>;
type RpcPacketReceiver = mpsc::UnboundedReceiver<RpcPacket>;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct InflightRequestKey {
from_peer_id: PeerId,
to_peer_id: PeerId,
transaction_id: RpcTransactId,
}
struct InflightRequest {
sender: RpcPacketSender,
merger: PacketMerger,
start_time: std::time::Instant,
}
type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
pub struct Client {
mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
transport: Mutex<Transport>,
inflight_requests: InflightRequestTable,
tasks: Arc<Mutex<JoinSet<()>>>,
}
impl Client {
pub fn new() -> Self {
let (ring_a, ring_b) = create_ring_tunnel_pair();
Self {
mpsc: Mutex::new(MpscTunnel::new(ring_a)),
transport: Mutex::new(MpscTunnel::new(ring_b)),
inflight_requests: Arc::new(DashMap::new()),
tasks: Arc::new(Mutex::new(JoinSet::new())),
}
}
pub fn get_transport_sink(&self) -> MpscTunnelSender {
self.transport.lock().unwrap().get_sink()
}
pub fn get_transport_stream(&self) -> Pin<Box<dyn ZCPacketStream>> {
self.transport.lock().unwrap().get_stream()
}
pub fn run(&self) {
let mut tasks = self.tasks.lock().unwrap();
let mut rx = self.mpsc.lock().unwrap().get_stream();
let inflight_requests = self.inflight_requests.clone();
tasks.spawn(async move {
while let Some(packet) = rx.next().await {
if let Err(err) = packet {
tracing::error!(?err, "Failed to receive packet");
continue;
}
let packet = match RpcPacket::decode(packet.unwrap().payload()) {
Err(err) => {
tracing::error!(?err, "Failed to decode packet");
continue;
}
Ok(packet) => packet,
};
if packet.is_request {
tracing::warn!(?packet, "Received non-response packet");
continue;
}
let key = InflightRequestKey {
from_peer_id: packet.to_peer,
to_peer_id: packet.from_peer,
transaction_id: packet.transaction_id,
};
let Some(mut inflight_request) = inflight_requests.get_mut(&key) else {
tracing::warn!(?key, "No inflight request found for key");
continue;
};
let ret = inflight_request.merger.feed(packet);
match ret {
Ok(Some(rpc_packet)) => {
inflight_request.sender.send(rpc_packet).unwrap();
}
Ok(None) => {}
Err(err) => {
tracing::error!(?err, "Failed to feed packet to merger");
}
}
}
});
}
pub fn scoped_client<F: RpcClientFactory>(
&self,
from_peer_id: PeerId,
to_peer_id: PeerId,
domain_name: String,
) -> F::ClientImpl {
#[derive(Clone)]
struct HandlerImpl<F> {
domain_name: String,
from_peer_id: PeerId,
to_peer_id: PeerId,
zc_packet_sender: MpscTunnelSender,
inflight_requests: InflightRequestTable,
_phan: PhantomData<F>,
}
impl<F: RpcClientFactory> HandlerImpl<F> {
async fn do_rpc(
&self,
packets: Vec<ZCPacket>,
rx: &mut RpcPacketReceiver,
) -> Result<RpcPacket> {
for packet in packets {
self.zc_packet_sender.send(packet).await?;
}
Ok(rx.recv().await.ok_or(TunnelError::Shutdown)?)
}
}
#[async_trait::async_trait]
impl<F: RpcClientFactory> Handler for HandlerImpl<F> {
type Descriptor = F::Descriptor;
type Controller = F::Controller;
async fn call(
&self,
ctrl: Self::Controller,
method: <Self::Descriptor as ServiceDescriptor>::Method,
input: bytes::Bytes,
) -> Result<bytes::Bytes> {
let transaction_id = CUR_TID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let (tx, mut rx) = mpsc::unbounded_channel();
let key = InflightRequestKey {
from_peer_id: self.from_peer_id,
to_peer_id: self.to_peer_id,
transaction_id,
};
defer!(self.inflight_requests.remove(&key););
self.inflight_requests.insert(
key.clone(),
InflightRequest {
sender: tx,
merger: PacketMerger::new(),
start_time: std::time::Instant::now(),
},
);
let desc = self.service_descriptor();
let rpc_desc = RpcDescriptor {
domain_name: self.domain_name.clone(),
proto_name: desc.proto_name().to_string(),
service_name: desc.name().to_string(),
method_index: method.index() as u32,
};
let rpc_req = RpcRequest {
descriptor: Some(rpc_desc.clone()),
request: input.into(),
timeout_ms: ctrl.timeout_ms(),
};
let packets = build_rpc_packet(
self.from_peer_id,
self.to_peer_id,
rpc_desc,
transaction_id,
true,
&rpc_req.encode_to_vec(),
ctrl.trace_id(),
);
let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
let rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
assert_eq!(rpc_packet.transaction_id, transaction_id);
let rpc_resp = RpcResponse::decode(Bytes::from(rpc_packet.body))?;
if let Some(err) = &rpc_resp.error {
return Err(err.into());
}
Ok(bytes::Bytes::from(rpc_resp.response))
}
}
F::new(HandlerImpl::<F> {
domain_name: domain_name.to_string(),
from_peer_id,
to_peer_id,
zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
inflight_requests: self.inflight_requests.clone(),
_phan: PhantomData,
})
}
pub fn inflight_count(&self) -> usize {
self.inflight_requests.len()
}
}


@@ -0,0 +1,12 @@
use crate::tunnel::{mpsc::MpscTunnel, Tunnel};
pub type RpcController = super::rpc_types::controller::BaseController;
pub mod client;
pub mod packet;
pub mod server;
pub mod service_registry;
pub mod standalone;
pub type Transport = MpscTunnel<Box<dyn Tunnel>>;
pub type RpcTransactId = i64;


@@ -0,0 +1,161 @@
use prost::Message as _;
use crate::{
common::PeerId,
proto::{
common::{RpcDescriptor, RpcPacket},
rpc_types::error::Error,
},
tunnel::packet_def::{PacketType, ZCPacket},
};
use super::RpcTransactId;
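// Upper bound on the body bytes carried by a single RPC piece; presumably
// sized so an encoded piece still fits in one tunnel frame on typical paths.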
const RPC_PACKET_CONTENT_MTU: usize = 1300;
pub struct PacketMerger {
first_piece: Option<RpcPacket>,
pieces: Vec<RpcPacket>,
last_updated: std::time::Instant,
}
impl PacketMerger {
pub fn new() -> Self {
Self {
first_piece: None,
pieces: Vec::new(),
last_updated: std::time::Instant::now(),
}
}
fn try_merge_pieces(&self) -> Option<RpcPacket> {
if self.first_piece.is_none() || self.pieces.is_empty() {
return None;
}
for p in &self.pieces {
// some piece is missing
if p.total_pieces == 0 {
return None;
}
}
// all pieces are received
let mut body = Vec::new();
for p in &self.pieces {
body.extend_from_slice(&p.body);
}
let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone();
tmpl_packet.total_pieces = 1;
tmpl_packet.piece_idx = 0;
tmpl_packet.body = body;
Some(tmpl_packet)
}
pub fn feed(&mut self, rpc_packet: RpcPacket) -> Result<Option<RpcPacket>, Error> {
let total_pieces = rpc_packet.total_pieces;
let piece_idx = rpc_packet.piece_idx;
if rpc_packet.descriptor.is_none() {
return Err(Error::MalformatRpcPacket(
"descriptor is missing".to_owned(),
));
}
// for compatibility with old version
if total_pieces == 0 && piece_idx == 0 {
return Ok(Some(rpc_packet));
}
// cap at 32 * 1024 pieces (roughly 40 MiB of body at 1300 bytes per piece)
if total_pieces > 32 * 1024 || total_pieces == 0 {
return Err(Error::MalformatRpcPacket(format!(
"total_pieces is invalid: {}",
total_pieces
)));
}
if piece_idx >= total_pieces {
return Err(Error::MalformatRpcPacket(
"piece_idx >= total_pieces".to_owned(),
));
}
if self.first_piece.is_none()
|| self.first_piece.as_ref().unwrap().transaction_id != rpc_packet.transaction_id
|| self.first_piece.as_ref().unwrap().from_peer != rpc_packet.from_peer
{
self.first_piece = Some(rpc_packet.clone());
self.pieces.clear();
}
self.pieces
.resize(total_pieces as usize, Default::default());
self.pieces[piece_idx as usize] = rpc_packet;
self.last_updated = std::time::Instant::now();
Ok(self.try_merge_pieces())
}
pub fn last_updated(&self) -> std::time::Instant {
self.last_updated
}
}
pub fn build_rpc_packet(
from_peer: PeerId,
to_peer: PeerId,
rpc_desc: RpcDescriptor,
transaction_id: RpcTransactId,
is_req: bool,
content: &Vec<u8>,
trace_id: i32,
) -> Vec<ZCPacket> {
let mut ret = Vec::new();
let content_mtu = RPC_PACKET_CONTENT_MTU;
let total_pieces = (content.len() + content_mtu - 1) / content_mtu;
let mut cur_offset = 0;
while cur_offset < content.len() || content.len() == 0 {
let mut cur_len = content_mtu;
if cur_offset + cur_len > content.len() {
cur_len = content.len() - cur_offset;
}
let mut cur_content = Vec::new();
cur_content.extend_from_slice(&content[cur_offset..cur_offset + cur_len]);
let cur_packet = RpcPacket {
from_peer,
to_peer,
descriptor: Some(rpc_desc.clone()),
is_request: is_req,
total_pieces: total_pieces as u32,
piece_idx: (cur_offset / content_mtu) as u32,
transaction_id,
body: cur_content,
trace_id,
};
cur_offset += cur_len;
let packet_type = if is_req {
PacketType::RpcReq
} else {
PacketType::RpcResp
};
let mut buf = Vec::new();
cur_packet.encode(&mut buf).unwrap();
let mut zc_packet = ZCPacket::new_with_payload(&buf);
zc_packet.fill_peer_manager_hdr(from_peer, to_peer, packet_type as u8);
ret.push(zc_packet);
if content.len() == 0 {
break;
}
}
ret
}
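To make the fragmentation scheme above concrete: an encoded request or response body is split into ceil(len / 1300) pieces, each wrapped in its own ZCPacket, and PacketMerger on the receiving side only yields a merged RpcPacket once every piece slot has been filled. A minimal standalone sketch of the same piece-count arithmetic (illustrative only, not part of the patch):

// Mirrors the total_pieces computation in build_rpc_packet above.
fn total_pieces(body_len: usize, mtu: usize) -> usize {
    (body_len + mtu - 1) / mtu
}

fn main() {
    const MTU: usize = 1300;
    assert_eq!(total_pieces(1, MTU), 1);
    assert_eq!(total_pieces(1300, MTU), 1);
    assert_eq!(total_pieces(1301, MTU), 2);
    // The 20 MiB payload used by the tests later in this patch becomes ~16k pieces.
    assert_eq!(total_pieces(20 * 1024 * 1024, MTU), 16132);
    // Empty bodies are special-cased: total_pieces stays 0, yet exactly one
    // (empty) packet is still emitted, and the merger passes it straight through.
    assert_eq!(total_pieces(0, MTU), 0);
}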


@@ -0,0 +1,207 @@
use std::{
pin::Pin,
sync::{Arc, Mutex},
};
use bytes::Bytes;
use dashmap::DashMap;
use prost::Message;
use tokio::{task::JoinSet, time::timeout};
use tokio_stream::StreamExt;
use crate::{
common::{join_joinset_background, PeerId},
proto::{
common::{self, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse},
rpc_types::error::Result,
},
tunnel::{
mpsc::{MpscTunnel, MpscTunnelSender},
ring::create_ring_tunnel_pair,
Tunnel, ZCPacketStream,
},
};
use super::{
packet::{build_rpc_packet, PacketMerger},
service_registry::ServiceRegistry,
RpcController, Transport,
};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PacketMergerKey {
from_peer_id: PeerId,
rpc_desc: RpcDescriptor,
transaction_id: i64,
}
pub struct Server {
registry: Arc<ServiceRegistry>,
mpsc: Mutex<Option<MpscTunnel<Box<dyn Tunnel>>>>,
transport: Mutex<Transport>,
tasks: Arc<Mutex<JoinSet<()>>>,
packet_mergers: Arc<DashMap<PacketMergerKey, PacketMerger>>,
}
impl Server {
pub fn new() -> Self {
Server::new_with_registry(Arc::new(ServiceRegistry::new()))
}
pub fn new_with_registry(registry: Arc<ServiceRegistry>) -> Self {
let (ring_a, ring_b) = create_ring_tunnel_pair();
Self {
registry,
mpsc: Mutex::new(Some(MpscTunnel::new(ring_a))),
transport: Mutex::new(MpscTunnel::new(ring_b)),
tasks: Arc::new(Mutex::new(JoinSet::new())),
packet_mergers: Arc::new(DashMap::new()),
}
}
pub fn registry(&self) -> &ServiceRegistry {
&self.registry
}
pub fn get_transport_sink(&self) -> MpscTunnelSender {
self.transport.lock().unwrap().get_sink()
}
pub fn get_transport_stream(&self) -> Pin<Box<dyn ZCPacketStream>> {
self.transport.lock().unwrap().get_stream()
}
pub fn run(&self) {
let tasks = self.tasks.clone();
join_joinset_background(tasks.clone(), "rpc server".to_string());
let mpsc = self.mpsc.lock().unwrap().take().unwrap();
let packet_merges = self.packet_mergers.clone();
let reg = self.registry.clone();
let t = tasks.clone();
tasks.lock().unwrap().spawn(async move {
let mut mpsc = mpsc;
let mut rx = mpsc.get_stream();
while let Some(packet) = rx.next().await {
if let Err(err) = packet {
tracing::error!(?err, "Failed to receive packet");
continue;
}
let packet = match common::RpcPacket::decode(packet.unwrap().payload()) {
Err(err) => {
tracing::error!(?err, "Failed to decode packet");
continue;
}
Ok(packet) => packet,
};
if !packet.is_request {
tracing::warn!(?packet, "Received non-request packet");
continue;
}
let key = PacketMergerKey {
from_peer_id: packet.from_peer,
rpc_desc: packet.descriptor.clone().unwrap_or_default(),
transaction_id: packet.transaction_id,
};
let ret = packet_merges
.entry(key.clone())
.or_insert_with(PacketMerger::new)
.feed(packet);
match ret {
Ok(Some(packet)) => {
packet_merges.remove(&key);
t.lock().unwrap().spawn(Self::handle_rpc(
mpsc.get_sink(),
packet,
reg.clone(),
));
}
Ok(None) => {}
Err(err) => {
tracing::error!("Failed to feed packet to merger, {}", err.to_string());
}
}
}
});
let packet_mergers = self.packet_mergers.clone();
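// Sweep partially received requests: any merger that has not seen a new
// piece for 10 seconds is dropped by this 5-second periodic task.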
tasks.lock().unwrap().spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
packet_mergers.retain(|_, v| v.last_updated().elapsed().as_secs() < 10);
}
});
}
async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?;
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
let ctrl = RpcController {};
Ok(timeout(
timeout_duration,
reg.call_method(
packet.descriptor.unwrap(),
ctrl,
Bytes::from(rpc_request.request),
),
)
.await??)
}
async fn handle_rpc(sender: MpscTunnelSender, packet: RpcPacket, reg: Arc<ServiceRegistry>) {
let from_peer = packet.from_peer;
let to_peer = packet.to_peer;
let transaction_id = packet.transaction_id;
let trace_id = packet.trace_id;
let desc = packet.descriptor.clone().unwrap();
let mut resp_msg = RpcResponse::default();
let now = std::time::Instant::now();
let resp_bytes = Self::handle_rpc_request(packet, reg).await;
match &resp_bytes {
Ok(r) => {
resp_msg.response = r.clone().into();
}
Err(err) => {
resp_msg.error = Some(err.into());
}
};
resp_msg.runtime_us = now.elapsed().as_micros() as u64;
let packets = build_rpc_packet(
to_peer,
from_peer,
desc,
transaction_id,
false,
&resp_msg.encode_to_vec(),
trace_id,
);
for packet in packets {
if let Err(err) = sender.send(packet).await {
tracing::error!(?err, "Failed to send response packet");
}
}
}
pub fn inflight_count(&self) -> usize {
self.packet_mergers.len()
}
pub fn close(&self) {
self.transport.lock().unwrap().close();
}
}


@@ -0,0 +1,105 @@
use std::sync::Arc;
use dashmap::DashMap;
use crate::proto::common::RpcDescriptor;
use crate::proto::rpc_types;
use crate::proto::rpc_types::descriptor::ServiceDescriptor;
use crate::proto::rpc_types::handler::{Handler, HandlerExt};
use super::RpcController;
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct ServiceKey {
pub domain_name: String,
pub service_name: String,
pub proto_name: String,
}
impl From<&RpcDescriptor> for ServiceKey {
fn from(desc: &RpcDescriptor) -> Self {
Self {
domain_name: desc.domain_name.to_string(),
service_name: desc.service_name.to_string(),
proto_name: desc.proto_name.to_string(),
}
}
}
#[derive(Clone)]
struct ServiceEntry {
service: Arc<Box<dyn HandlerExt<Controller = RpcController>>>,
}
impl ServiceEntry {
fn new<H: Handler<Controller = RpcController>>(h: H) -> Self {
Self {
service: Arc::new(Box::new(h)),
}
}
async fn call_method(
&self,
ctrl: RpcController,
method_index: u8,
input: bytes::Bytes,
) -> rpc_types::error::Result<bytes::Bytes> {
self.service.call_method(ctrl, method_index, input).await
}
}
pub struct ServiceRegistry {
table: DashMap<ServiceKey, ServiceEntry>,
}
impl ServiceRegistry {
pub fn new() -> Self {
Self {
table: DashMap::new(),
}
}
pub fn register<H: Handler<Controller = RpcController>>(&self, h: H, domain_name: &str) {
let desc = h.service_descriptor();
let key = ServiceKey {
domain_name: domain_name.to_string(),
service_name: desc.name().to_string(),
proto_name: desc.proto_name().to_string(),
};
let entry = ServiceEntry::new(h);
self.table.insert(key, entry);
}
pub fn unregister<H: Handler<Controller = RpcController>>(
&self,
h: H,
domain_name: &str,
) -> Option<()> {
let desc = h.service_descriptor();
let key = ServiceKey {
domain_name: domain_name.to_string(),
service_name: desc.name().to_string(),
proto_name: desc.proto_name().to_string(),
};
self.table.remove(&key).map(|_| ())
}
pub async fn call_method(
&self,
rpc_desc: RpcDescriptor,
ctrl: RpcController,
input: bytes::Bytes,
) -> rpc_types::error::Result<bytes::Bytes> {
let service_key = ServiceKey::from(&rpc_desc);
let method_index = rpc_desc.method_index as u8;
let entry = self
.table
.get(&service_key)
.ok_or(rpc_types::error::Error::InvalidServiceKey(
service_key.service_name.clone(),
service_key.proto_name.clone(),
))?
.clone();
entry.call_method(ctrl, method_index, input).await
}
}
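The registry keys each handler by (domain_name, service_name, proto_name); the domain name appears to act as a namespace so one process can expose the same service for several scopes (e.g. different network instances). A rough standalone sketch of that lookup shape, using a plain HashMap and string placeholders instead of DashMap and the generated handler types (illustrative only):

use std::collections::HashMap;

// Hypothetical stand-in for ServiceKey; field names mirror the real struct.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
struct Key {
    domain_name: String,
    service_name: String,
    proto_name: String,
}

fn key(domain: &str) -> Key {
    Key {
        domain_name: domain.to_string(),
        service_name: "Greeting".to_string(),
        proto_name: "tests".to_string(),
    }
}

fn main() {
    let mut table: HashMap<Key, &str> = HashMap::new();
    // The same service registered under two different domains.
    table.insert(key("net_a"), "handler for net_a");
    table.insert(key("net_b"), "handler for net_b");
    // An incoming RpcDescriptor selects the handler by the full key.
    assert_eq!(table[&key("net_b")], "handler for net_b");
}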


@@ -0,0 +1,245 @@
use std::{
sync::{atomic::AtomicU32, Arc, Mutex},
time::Duration,
};
use anyhow::Context as _;
use futures::{SinkExt as _, StreamExt};
use tokio::task::JoinSet;
use crate::{
common::join_joinset_background,
proto::rpc_types::{__rt::RpcClientFactory, error::Error},
tunnel::{Tunnel, TunnelConnector, TunnelListener},
};
use super::{client::Client, server::Server, service_registry::ServiceRegistry};
struct StandAloneServerOneTunnel {
tunnel: Box<dyn Tunnel>,
rpc_server: Server,
}
impl StandAloneServerOneTunnel {
pub fn new(tunnel: Box<dyn Tunnel>, registry: Arc<ServiceRegistry>) -> Self {
let rpc_server = Server::new_with_registry(registry);
StandAloneServerOneTunnel { tunnel, rpc_server }
}
pub async fn run(self) {
use tokio_stream::StreamExt as _;
let (tunnel_rx, tunnel_tx) = self.tunnel.split();
let (rpc_rx, rpc_tx) = (
self.rpc_server.get_transport_stream(),
self.rpc_server.get_transport_sink(),
);
let mut tasks = JoinSet::new();
tasks.spawn(async move {
let ret = tunnel_rx.timeout(Duration::from_secs(60));
tokio::pin!(ret);
while let Ok(Some(Ok(p))) = ret.try_next().await {
if let Err(e) = rpc_tx.send(p).await {
tracing::error!("tunnel_rx send to rpc_tx error: {:?}", e);
break;
}
}
tracing::info!("forward tunnel_rx to rpc_tx done");
});
tasks.spawn(async move {
let ret = rpc_rx.forward(tunnel_tx).await;
tracing::info!("rpc_rx forward tunnel_tx done: {:?}", ret);
});
self.rpc_server.run();
while let Some(ret) = tasks.join_next().await {
self.rpc_server.close();
tracing::info!("task done: {:?}", ret);
}
tracing::info!("all tasks done");
}
}
pub struct StandAloneServer<L> {
registry: Arc<ServiceRegistry>,
listener: Option<L>,
inflight_server: Arc<AtomicU32>,
tasks: Arc<Mutex<JoinSet<()>>>,
}
impl<L: TunnelListener + 'static> StandAloneServer<L> {
pub fn new(listener: L) -> Self {
StandAloneServer {
registry: Arc::new(ServiceRegistry::new()),
listener: Some(listener),
inflight_server: Arc::new(AtomicU32::new(0)),
tasks: Arc::new(Mutex::new(JoinSet::new())),
}
}
pub fn registry(&self) -> &ServiceRegistry {
&self.registry
}
pub async fn serve(&mut self) -> Result<(), Error> {
let tasks = self.tasks.clone();
let mut listener = self.listener.take().unwrap();
let registry = self.registry.clone();
join_joinset_background(tasks.clone(), "standalone server tasks".to_string());
listener
.listen()
.await
.with_context(|| "failed to listen")?;
let inflight_server = self.inflight_server.clone();
self.tasks.lock().unwrap().spawn(async move {
while let Ok(tunnel) = listener.accept().await {
let server = StandAloneServerOneTunnel::new(tunnel, registry.clone());
let inflight_server = inflight_server.clone();
inflight_server.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tasks.lock().unwrap().spawn(async move {
server.run().await;
inflight_server.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
});
}
panic!("standalone server listener exit");
});
Ok(())
}
pub fn inflight_server(&self) -> u32 {
self.inflight_server
.load(std::sync::atomic::Ordering::Relaxed)
}
}
struct StandAloneClientOneTunnel {
rpc_client: Client,
tasks: Arc<Mutex<JoinSet<()>>>,
error: Arc<Mutex<Option<Error>>>,
}
impl StandAloneClientOneTunnel {
pub fn new(tunnel: Box<dyn Tunnel>) -> Self {
let rpc_client = Client::new();
let (mut rpc_rx, rpc_tx) = (
rpc_client.get_transport_stream(),
rpc_client.get_transport_sink(),
);
let tasks = Arc::new(Mutex::new(JoinSet::new()));
let (mut tunnel_rx, mut tunnel_tx) = tunnel.split();
let error_store = Arc::new(Mutex::new(None));
let error = error_store.clone();
tasks.lock().unwrap().spawn(async move {
while let Some(p) = rpc_rx.next().await {
match p {
Ok(p) => {
if let Err(e) = tunnel_tx
.send(p)
.await
.with_context(|| "failed to send packet")
{
*error.lock().unwrap() = Some(e.into());
}
}
Err(e) => {
*error.lock().unwrap() = Some(anyhow::Error::from(e).into());
}
}
}
*error.lock().unwrap() = Some(anyhow::anyhow!("rpc_rx next exit").into());
});
let error = error_store.clone();
tasks.lock().unwrap().spawn(async move {
while let Some(p) = tunnel_rx.next().await {
match p {
Ok(p) => {
if let Err(e) = rpc_tx
.send(p)
.await
.with_context(|| "failed to send packet")
{
*error.lock().unwrap() = Some(e.into());
}
}
Err(e) => {
*error.lock().unwrap() = Some(anyhow::Error::from(e).into());
}
}
}
*error.lock().unwrap() = Some(anyhow::anyhow!("tunnel_rx next exit").into());
});
rpc_client.run();
StandAloneClientOneTunnel {
rpc_client,
tasks,
error: error_store,
}
}
pub fn take_error(&self) -> Option<Error> {
self.error.lock().unwrap().take()
}
}
pub struct StandAloneClient<C: TunnelConnector> {
connector: C,
client: Option<StandAloneClientOneTunnel>,
}
impl<C: TunnelConnector> StandAloneClient<C> {
pub fn new(connector: C) -> Self {
StandAloneClient {
connector,
client: None,
}
}
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, Error> {
Ok(self.connector.connect().await.with_context(|| {
format!(
"failed to connect to server: {:?}",
self.connector.remote_url()
)
})?)
}
pub async fn scoped_client<F: RpcClientFactory>(
&mut self,
domain_name: String,
) -> Result<F::ClientImpl, Error> {
let mut c = self.client.take();
let error = c.as_ref().and_then(|c| c.take_error());
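// Reuse the existing tunnel if it is still healthy; an error recorded by the
// forwarding tasks (or no client yet) triggers a lazy reconnect below.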
if c.is_none() || error.is_some() {
tracing::info!("reconnect due to error: {:?}", error);
let tunnel = self.connect().await?;
c = Some(StandAloneClientOneTunnel::new(tunnel));
}
self.client = c;
Ok(self
.client
.as_ref()
.unwrap()
.rpc_client
.scoped_client::<F>(1, 1, domain_name))
}
}


@@ -0,0 +1,57 @@
//! Utility functions used by generated code; this is *not* part of the crate's public API!
use bytes;
use prost;
use super::controller;
use super::descriptor;
use super::descriptor::ServiceDescriptor;
use super::error;
use super::handler;
use super::handler::Handler;
/// Efficiently decode a particular message type from a byte buffer.
pub fn decode<M>(buf: bytes::Bytes) -> error::Result<M>
where
M: prost::Message + Default,
{
let message = prost::Message::decode(buf)?;
Ok(message)
}
/// Efficiently encode a particular message into a byte buffer.
pub fn encode<M>(message: M) -> error::Result<bytes::Bytes>
where
M: prost::Message,
{
let len = prost::Message::encoded_len(&message);
let mut buf = ::bytes::BytesMut::with_capacity(len);
prost::Message::encode(&message, &mut buf)?;
Ok(buf.freeze())
}
pub async fn call_method<H, I, O>(
handler: H,
ctrl: H::Controller,
method: <H::Descriptor as descriptor::ServiceDescriptor>::Method,
input: I,
) -> super::error::Result<O>
where
H: handler::Handler,
I: prost::Message,
O: prost::Message + Default,
{
type Error = super::error::Error;
let input_bytes = encode(input)?;
let ret_msg = handler.call(ctrl, method, input_bytes).await?;
decode(ret_msg)
}
pub trait RpcClientFactory: Clone + Send + Sync + 'static {
type Descriptor: ServiceDescriptor + Default;
type ClientImpl;
type Controller: controller::Controller;
fn new(
handler: impl Handler<Descriptor = Self::Descriptor, Controller = Self::Controller>,
) -> Self::ClientImpl;
}


@@ -0,0 +1,18 @@
pub trait Controller: Send + Sync + 'static {
fn timeout_ms(&self) -> i32 {
5000
}
fn set_timeout_ms(&mut self, _timeout_ms: i32) {}
fn set_trace_id(&mut self, _trace_id: i32) {}
fn trace_id(&self) -> i32 {
0
}
}
#[derive(Debug)]
pub struct BaseController {}
impl Controller for BaseController {}
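The controller keeps per-call knobs (timeout, trace id) behind defaulted trait methods, so BaseController {} is enough at most call sites. A standalone sketch of the same default-method pattern with a hypothetical controller that overrides the timeout (illustrative, not part of the patch):

trait Controller: Send + Sync + 'static {
    fn timeout_ms(&self) -> i32 {
        5000
    }
    fn trace_id(&self) -> i32 {
        0
    }
}

struct BaseController;
impl Controller for BaseController {}

// Hypothetical controller that shortens the RPC deadline for one call site.
struct ShortTimeoutController;
impl Controller for ShortTimeoutController {
    fn timeout_ms(&self) -> i32 {
        500
    }
}

fn main() {
    assert_eq!(BaseController.timeout_ms(), 5000);
    assert_eq!(ShortTimeoutController.timeout_ms(), 500);
}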


@@ -0,0 +1,50 @@
//! Traits for defining generic service descriptor definitions.
//!
//! These traits are built on the assumption that some form of code generation is being used (e.g.
//! using only `&'static str`s) but it's of course possible to implement these traits manually.
use std::any;
use std::fmt;
/// A descriptor for an available RPC service.
pub trait ServiceDescriptor: Clone + fmt::Debug + Send + Sync {
/// The associated type of method descriptors.
type Method: MethodDescriptor + fmt::Debug + TryFrom<u8>;
/// The name of the service, used in Rust code and perhaps for human readability.
fn name(&self) -> &'static str;
/// The raw protobuf name of the service.
fn proto_name(&self) -> &'static str;
/// The package name of the service.
fn package(&self) -> &'static str {
""
}
/// All of the available methods on the service.
fn methods(&self) -> &'static [Self::Method];
}
/// A descriptor for a method available on an RPC service.
pub trait MethodDescriptor: Clone + Copy + fmt::Debug + Send + Sync {
/// The name of the method, used in Rust code and perhaps for human readability.
fn name(&self) -> &'static str;
/// The raw protobuf name of the method.
fn proto_name(&self) -> &'static str;
/// The Rust `TypeId` for the input that this method accepts.
fn input_type(&self) -> any::TypeId;
/// The raw protobuf name for the input type that this method accepts.
fn input_proto_type(&self) -> &'static str;
/// The Rust `TypeId` for the output that this method produces.
fn output_type(&self) -> any::TypeId;
/// The raw protobuf name for the output type that this method produces.
fn output_proto_type(&self) -> &'static str;
/// The index of the method in the service descriptor.
fn index(&self) -> u8;
}


@@ -0,0 +1,34 @@
//! Error type definitions for errors that can occur during RPC interactions.
use std::result;
use prost;
use thiserror;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("execution error: {0}")]
ExecutionError(#[from] anyhow::Error),
#[error("Decode error: {0}")]
DecodeError(#[from] prost::DecodeError),
#[error("Encode error: {0}")]
EncodeError(#[from] prost::EncodeError),
#[error("Invalid method index: {0}, service: {1}")]
InvalidMethodIndex(u8, String),
#[error("Invalid service name: {0}, proto name: {1}")]
InvalidServiceKey(String, String),
#[error("Invalid packet: {0}")]
MalformatRpcPacket(String),
#[error("Timeout: {0}")]
Timeout(#[from] tokio::time::error::Elapsed),
#[error("Tunnel error: {0}")]
TunnelError(#[from] crate::tunnel::TunnelError),
}
pub type Result<T> = result::Result<T, Error>;


@@ -0,0 +1,67 @@
//! Traits for defining generic RPC handlers.
use super::{
controller::Controller,
descriptor::{self, ServiceDescriptor},
};
use bytes;
/// An implementation of a specific RPC handler.
///
/// This can be an actual implementation of a service, or something that will send a request over
/// a network to fulfill a request.
#[async_trait::async_trait]
pub trait Handler: Clone + Send + Sync + 'static {
/// The service descriptor for the service whose requests this handler can handle.
type Descriptor: descriptor::ServiceDescriptor + Default;
type Controller: super::controller::Controller;
///
/// Perform a raw call to the specified service and method.
async fn call(
&self,
ctrl: Self::Controller,
method: <Self::Descriptor as descriptor::ServiceDescriptor>::Method,
input: bytes::Bytes,
) -> super::error::Result<bytes::Bytes>;
fn service_descriptor(&self) -> Self::Descriptor {
Self::Descriptor::default()
}
fn get_method_from_index(
&self,
index: u8,
) -> super::error::Result<<Self::Descriptor as descriptor::ServiceDescriptor>::Method> {
let desc = self.service_descriptor();
<Self::Descriptor as descriptor::ServiceDescriptor>::Method::try_from(index)
.map_err(|_| super::error::Error::InvalidMethodIndex(index, desc.name().to_string()))
}
}
#[async_trait::async_trait]
pub trait HandlerExt: Send + Sync + 'static {
type Controller;
async fn call_method(
&self,
ctrl: Self::Controller,
method_index: u8,
input: bytes::Bytes,
) -> super::error::Result<bytes::Bytes>;
}
#[async_trait::async_trait]
impl<C: Controller, T: Handler<Controller = C>> HandlerExt for T {
type Controller = C;
async fn call_method(
&self,
ctrl: Self::Controller,
method_index: u8,
input: bytes::Bytes,
) -> super::error::Result<bytes::Bytes> {
let method = self.get_method_from_index(method_index)?;
self.call(ctrl, method, input).await
}
}
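The blanket HandlerExt impl is what turns the u8 method index carried in an RpcDescriptor back into a typed method before dispatch; an unknown index becomes Error::InvalidMethodIndex rather than a panic. A standalone sketch of that index-to-method mapping (illustrative only; the real method enum is generated from the .proto definition):

#[derive(Clone, Copy, Debug, PartialEq)]
enum Method {
    SayHello,
    SayGoodbye,
}

impl TryFrom<u8> for Method {
    type Error = ();
    fn try_from(index: u8) -> Result<Self, ()> {
        match index {
            0 => Ok(Method::SayHello),
            1 => Ok(Method::SayGoodbye),
            _ => Err(()), // surfaces as InvalidMethodIndex in the real code
        }
    }
}

fn main() {
    assert_eq!(Method::try_from(1), Ok(Method::SayGoodbye));
    assert!(Method::try_from(9).is_err());
}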


@@ -0,0 +1,5 @@
pub mod __rt;
pub mod controller;
pub mod descriptor;
pub mod error;
pub mod handler;


@@ -0,0 +1,24 @@
syntax = "proto3";
package tests;
/// The Greeting service. This service is used to generate greetings for various
/// use-cases.
service Greeting {
// Generates a "hello" greeting based on the supplied info.
rpc SayHello(SayHelloRequest) returns (SayHelloResponse);
// Generates a "goodbye" greeting based on the supplied info.
rpc SayGoodbye(SayGoodbyeRequest) returns (SayGoodbyeResponse);
}
// The request for an `Greeting.SayHello` call.
message SayHelloRequest { string name = 1; }
// The response for an `Greeting.SayHello` call.
message SayHelloResponse { string greeting = 1; }
// The request for an `Greeting.SayGoodbye` call.
message SayGoodbyeRequest { string name = 1; }
// The response for an `Greeting.SayGoodbye` call.
message SayGoodbyeResponse { string greeting = 1; }

easytier/src/proto/tests.rs Normal file

@@ -0,0 +1,225 @@
include!(concat!(env!("OUT_DIR"), "/tests.rs"));
use std::sync::{Arc, Mutex};
use futures::StreamExt as _;
use tokio::task::JoinSet;
use super::rpc_impl::RpcController;
#[derive(Clone)]
pub struct GreetingService {
pub delay_ms: u64,
pub prefix: String,
}
#[async_trait::async_trait]
impl Greeting for GreetingService {
type Controller = RpcController;
async fn say_hello(
&self,
_ctrl: Self::Controller,
input: SayHelloRequest,
) -> crate::proto::rpc_types::error::Result<SayHelloResponse> {
let resp = SayHelloResponse {
greeting: format!("{} {}!", self.prefix, input.name),
};
tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
Ok(resp)
}
/// Generates a "goodbye" greeting based on the supplied info.
async fn say_goodbye(
&self,
_ctrl: Self::Controller,
input: SayGoodbyeRequest,
) -> crate::proto::rpc_types::error::Result<SayGoodbyeResponse> {
let resp = SayGoodbyeResponse {
greeting: format!("Goodbye, {}!", input.name),
};
tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
Ok(resp)
}
}
use crate::proto::rpc_impl::client::Client;
use crate::proto::rpc_impl::server::Server;
struct TestContext {
client: Client,
server: Server,
tasks: Arc<Mutex<JoinSet<()>>>,
}
impl TestContext {
fn new() -> Self {
let rpc_server = Server::new();
rpc_server.run();
let client = Client::new();
client.run();
let tasks = Arc::new(Mutex::new(JoinSet::new()));
let (mut rx, tx) = (
rpc_server.get_transport_stream(),
client.get_transport_sink(),
);
tasks.lock().unwrap().spawn(async move {
while let Some(Ok(packet)) = rx.next().await {
if let Err(err) = tx.send(packet).await {
println!("{:?}", err);
break;
}
}
});
let (mut rx, tx) = (
client.get_transport_stream(),
rpc_server.get_transport_sink(),
);
tasks.lock().unwrap().spawn(async move {
while let Some(Ok(packet)) = rx.next().await {
if let Err(err) = tx.send(packet).await {
println!("{:?}", err);
break;
}
}
});
Self {
client,
server: rpc_server,
tasks,
}
}
}
fn random_string(len: usize) -> String {
use rand::distributions::Alphanumeric;
use rand::Rng;
let mut rng = rand::thread_rng();
let s: Vec<u8> = std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.take(len)
.collect();
String::from_utf8(s).unwrap()
}
#[tokio::test]
async fn rpc_basic_test() {
let ctx = TestContext::new();
let server = GreetingServer::new(GreetingService {
delay_ms: 0,
prefix: "Hello".to_string(),
});
ctx.server.registry().register(server, "");
let out = ctx
.client
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "".to_string());
// small size req and resp
let ctrl = RpcController {};
let input = SayHelloRequest {
name: "world".to_string(),
};
let ret = out.say_hello(ctrl, input).await;
assert_eq!(ret.unwrap().greeting, "Hello world!");
let ctrl = RpcController {};
let input = SayGoodbyeRequest {
name: "world".to_string(),
};
let ret = out.say_goodbye(ctrl, input).await;
assert_eq!(ret.unwrap().greeting, "Goodbye, world!");
// large size req and resp
let ctrl = RpcController {};
let name = random_string(20 * 1024 * 1024);
let input = SayGoodbyeRequest { name: name.clone() };
let ret = out.say_goodbye(ctrl, input).await;
assert_eq!(ret.unwrap().greeting, format!("Goodbye, {}!", name));
assert_eq!(0, ctx.client.inflight_count());
assert_eq!(0, ctx.server.inflight_count());
}
#[tokio::test]
async fn rpc_timeout_test() {
let ctx = TestContext::new();
let server = GreetingServer::new(GreetingService {
delay_ms: 10000,
prefix: "Hello".to_string(),
});
ctx.server.registry().register(server, "test");
let out = ctx
.client
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());
let ctrl = RpcController {};
let input = SayHelloRequest {
name: "world".to_string(),
};
let ret = out.say_hello(ctrl, input).await;
assert!(ret.is_err());
assert!(matches!(
ret.unwrap_err(),
crate::proto::rpc_types::error::Error::Timeout(_)
));
assert_eq!(0, ctx.client.inflight_count());
assert_eq!(0, ctx.server.inflight_count());
}
#[tokio::test]
async fn standalone_rpc_test() {
use crate::proto::rpc_impl::standalone::{StandAloneClient, StandAloneServer};
use crate::tunnel::tcp::{TcpTunnelConnector, TcpTunnelListener};
let mut server = StandAloneServer::new(TcpTunnelListener::new(
"tcp://0.0.0.0:33455".parse().unwrap(),
));
let service = GreetingServer::new(GreetingService {
delay_ms: 0,
prefix: "Hello".to_string(),
});
server.registry().register(service, "test");
server.serve().await.unwrap();
let mut client = StandAloneClient::new(TcpTunnelConnector::new(
"tcp://127.0.0.1:33455".parse().unwrap(),
));
let out = client
.scoped_client::<GreetingClientFactory<RpcController>>("test".to_string())
.await
.unwrap();
let ctrl = RpcController {};
let input = SayHelloRequest {
name: "world".to_string(),
};
let ret = out.say_hello(ctrl, input).await;
assert_eq!(ret.unwrap().greeting, "Hello world!");
let out = client
.scoped_client::<GreetingClientFactory<RpcController>>("test".to_string())
.await
.unwrap();
let ctrl = RpcController {};
let input = SayGoodbyeRequest {
name: "world".to_string(),
};
let ret = out.say_goodbye(ctrl, input).await;
assert_eq!(ret.unwrap().greeting, "Goodbye, world!");
drop(client);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
assert_eq!(0, server.inflight_server());
}


@@ -1 +0,0 @@
tonic::include_proto!("cli"); // The string specified here must match the proto package name


@@ -1,4 +0,0 @@
pub mod cli;
pub use cli::*;
pub mod peer;


@@ -1,22 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Default)]
pub struct GetIpListResponse {
pub public_ipv4: String,
pub interface_ipv4s: Vec<String>,
pub public_ipv6: String,
pub interface_ipv6s: Vec<String>,
pub listeners: Vec<url::Url>,
}
impl GetIpListResponse {
pub fn new() -> Self {
GetIpListResponse {
public_ipv4: "".to_string(),
interface_ipv4s: vec![],
public_ipv6: "".to_string(),
interface_ipv6s: vec![],
listeners: vec![],
}
}
}


@@ -127,7 +127,7 @@ pub fn enable_log() {
         .init();
 }
 
-fn check_route(ipv4: &str, dst_peer_id: PeerId, routes: Vec<crate::rpc::Route>) {
+fn check_route(ipv4: &str, dst_peer_id: PeerId, routes: Vec<crate::proto::cli::Route>) {
     let mut found = false;
     for r in routes.iter() {
         if r.ipv4_addr == ipv4.to_string() {


@@ -518,8 +518,8 @@ pub async fn foreign_network_forward_nic_data() {
     wait_for_condition(
         || async {
-            inst1.get_peer_manager().list_routes().await.len() == 1
-                && inst2.get_peer_manager().list_routes().await.len() == 1
+            inst1.get_peer_manager().list_routes().await.len() == 2
+                && inst2.get_peer_manager().list_routes().await.len() == 2
         },
         Duration::from_secs(5),
     )


@@ -16,10 +16,9 @@ use tokio_stream::StreamExt;
 use tokio_util::io::poll_write_buf;
 use zerocopy::FromBytes as _;
 
-use crate::{
-    rpc::TunnelInfo,
-    tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE},
-};
+use super::TunnelInfo;
+
+use crate::tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE};
 
 use super::{
     buf::BufList,
@@ -505,8 +504,8 @@ pub mod tests {
             let ret = listener.accept().await.unwrap();
             println!("accept: {:?}", ret.info());
             assert_eq!(
-                ret.info().unwrap().local_addr,
-                listener.local_url().to_string()
+                url::Url::from(ret.info().unwrap().local_addr.unwrap()),
+                listener.local_url()
             );
             _tunnel_echo_server(ret, false).await
         });
@@ -515,8 +514,8 @@
         println!("connect: {:?}", tunnel.info());
         assert_eq!(
-            tunnel.info().unwrap().remote_addr,
-            connector.remote_url().to_string()
+            url::Url::from(tunnel.info().unwrap().remote_addr.unwrap()),
+            connector.remote_url(),
         );
 
         let (mut recv, mut send) = tunnel.split();


@@ -3,10 +3,11 @@ use std::{
     task::{Context, Poll},
 };
 
-use crate::rpc::TunnelInfo;
 use auto_impl::auto_impl;
 use futures::{Sink, SinkExt, Stream, StreamExt};
 
+use crate::proto::common::TunnelInfo;
+
 use self::stats::Throughput;
 
 use super::*;


@@ -8,7 +8,7 @@ use std::fmt::Debug;
 use tokio::time::error::Elapsed;
 
-use crate::rpc::TunnelInfo;
+use crate::proto::common::TunnelInfo;
 
 use self::packet_def::ZCPacket;


@@ -3,9 +3,13 @@
 use std::{pin::Pin, time::Duration};
 
 use anyhow::Context;
-use tokio::{task::JoinHandle, time::timeout};
+use tokio::time::timeout;
 
-use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream};
+use crate::common::scoped_task::ScopedTask;
+
+use super::{
+    packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
+};
 
 use tachyonix::{channel, Receiver, Sender};
@@ -29,12 +33,12 @@ impl MpscTunnelSender {
 }
 
 pub struct MpscTunnel<T> {
-    tx: Sender<ZCPacket>,
+    tx: Option<Sender<ZCPacket>>,
     tunnel: T,
     stream: Option<Pin<Box<dyn ZCPacketStream>>>,
-    task: Option<JoinHandle<()>>,
+    task: ScopedTask<()>,
 }
 
 impl<T: Tunnel> MpscTunnel<T> {
@@ -54,10 +58,10 @@ impl<T: Tunnel> MpscTunnel<T> {
         });
 
         Self {
-            tx,
+            tx: Some(tx),
             tunnel,
             stream: Some(stream),
-            task: Some(task),
+            task: task.into(),
         }
     }
@@ -81,7 +85,12 @@ impl<T: Tunnel> MpscTunnel<T> {
     }
 
     pub fn get_sink(&self) -> MpscTunnelSender {
-        MpscTunnelSender(self.tx.clone())
+        MpscTunnelSender(self.tx.as_ref().unwrap().clone())
+    }
+
+    pub fn close(&mut self) {
+        self.tx.take();
+        self.task.abort();
     }
 }


@@ -54,6 +54,8 @@ pub enum PacketType {
     Pong = 5,
     TaRpc = 6,
     Route = 7,
+    RpcReq = 8,
+    RpcResp = 9,
 }
 
 bitflags::bitflags! {


@@ -4,12 +4,10 @@
 use std::{error::Error, net::SocketAddr, sync::Arc};
 
-use crate::{
-    rpc::TunnelInfo,
-    tunnel::{
-        check_scheme_and_get_socket_addr_ext,
-        common::{FramedReader, FramedWriter, TunnelWrapper},
-    },
-};
+use crate::tunnel::{
+    check_scheme_and_get_socket_addr_ext,
+    common::{FramedReader, FramedWriter, TunnelWrapper},
+    TunnelInfo,
+};
 use anyhow::Context;
 use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Connection, Endpoint, ServerConfig};
@@ -113,8 +111,10 @@ impl TunnelListener for QUICTunnelListener {
         let info = TunnelInfo {
             tunnel_type: "quic".to_owned(),
-            local_addr: self.local_url().into(),
-            remote_addr: super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
+            local_addr: Some(self.local_url().into()),
+            remote_addr: Some(
+                super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
+            ),
         };
 
         Ok(Box::new(TunnelWrapper::new(
@@ -177,8 +177,10 @@ impl TunnelConnector for QUICTunnelConnector {
         let info = TunnelInfo {
             tunnel_type: "quic".to_owned(),
-            local_addr: super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
-            remote_addr: self.addr.to_string(),
+            local_addr: Some(
+                super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
+            ),
+            remote_addr: Some(self.addr.clone().into()),
         };
 
         let arc_conn = Arc::new(ConnWrapper { conn: connection });


@@ -261,8 +261,8 @@ fn get_tunnel_for_client(conn: Arc<Connection>) -> impl Tunnel {
         RingSink::new(conn.server.clone()),
         Some(TunnelInfo {
             tunnel_type: "ring".to_owned(),
-            local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
-            remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
+            local_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
+            remote_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
         }),
     )
 }
@@ -273,8 +273,8 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
         RingSink::new(conn.client.clone()),
         Some(TunnelInfo {
             tunnel_type: "ring".to_owned(),
-            local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
-            remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
+            local_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
+            remote_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
         }),
     )
 }


@@ -4,7 +4,8 @@ use async_trait::async_trait;
 use futures::stream::FuturesUnordered;
 use tokio::net::{TcpListener, TcpSocket, TcpStream};
 
-use crate::{rpc::TunnelInfo, tunnel::common::setup_sokcet2};
+use super::TunnelInfo;
+use crate::tunnel::common::setup_sokcet2;
 
 use super::{
     check_scheme_and_get_socket_addr, check_scheme_and_get_socket_addr_ext,
@@ -56,9 +57,10 @@ impl TunnelListener for TcpTunnelListener {
         stream.set_nodelay(true).unwrap();
         let info = TunnelInfo {
             tunnel_type: "tcp".to_owned(),
-            local_addr: self.local_url().into(),
-            remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
-                .into(),
+            local_addr: Some(self.local_url().into()),
+            remote_addr: Some(
+                super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
+            ),
         };
 
         let (r, w) = stream.into_split();
@@ -82,9 +84,10 @@ fn get_tunnel_with_tcp_stream(
     let info = TunnelInfo {
         tunnel_type: "tcp".to_owned(),
-        local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
-            .into(),
-        remote_addr: remote_url.into(),
+        local_addr: Some(
+            super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp").into(),
+        ),
+        remote_addr: Some(remote_url.into()),
     };
 
     let (r, w) = stream.into_split();


@@ -15,9 +15,9 @@ use tokio::{
 use tracing::{instrument, Instrument};
 
+use super::TunnelInfo;
 use crate::{
     common::join_joinset_background,
-    rpc::TunnelInfo,
     tunnel::{
         build_url_from_socket_addr,
         common::{reserve_buf, TunnelWrapper},
@@ -317,8 +317,10 @@ impl UdpTunnelListenerData {
                 Box::new(RingSink::new(ring_for_send_udp)),
                 Some(TunnelInfo {
                     tunnel_type: "udp".to_owned(),
-                    local_addr: self.local_url.clone().into(),
-                    remote_addr: build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
+                    local_addr: Some(self.local_url.clone().into()),
+                    remote_addr: Some(
+                        build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
+                    ),
                 }),
             ));
@@ -607,9 +609,10 @@ impl UdpTunnelConnector {
             Box::new(RingSink::new(ring_for_send_udp)),
             Some(TunnelInfo {
                 tunnel_type: "udp".to_owned(),
-                local_addr: build_url_from_socket_addr(&socket.local_addr()?.to_string(), "udp")
-                    .into(),
-                remote_addr: self.addr.clone().into(),
+                local_addr: Some(
+                    build_url_from_socket_addr(&socket.local_addr()?.to_string(), "udp").into(),
+                ),
+                remote_addr: Some(self.addr.clone().into()),
             }),
         )))
     }
@@ -708,7 +711,7 @@ impl super::TunnelConnector for UdpTunnelConnector {
 #[cfg(test)]
 mod tests {
-    use std::time::Duration;
+    use std::{net::IpAddr, time::Duration};
 
     use futures::SinkExt;
     use tokio::time::timeout;
@@ -786,7 +789,11 @@ mod tests {
         loop {
             let ret = listener.accept().await.unwrap();
             assert_eq!(
-                ret.info().unwrap().local_addr,
+                ret.info()
+                    .unwrap()
+                    .local_addr
+                    .unwrap_or_default()
+                    .to_string(),
                 listener.local_url().to_string()
             );
             tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
@@ -801,15 +808,15 @@ mod tests {
     tokio::spawn(timeout(
         Duration::from_secs(2),
-        send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
+        send_random_data_to_socket(t1.info().unwrap().local_addr.unwrap().into()),
     ));
     tokio::spawn(timeout(
         Duration::from_secs(2),
-        send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
+        send_random_data_to_socket(t1.info().unwrap().remote_addr.unwrap().into()),
     ));
     tokio::spawn(timeout(
         Duration::from_secs(2),
-        send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
+        send_random_data_to_socket(t2.info().unwrap().remote_addr.unwrap().into()),
     ));
 
     let sender1 = tokio::spawn(async move {
@@ -854,12 +861,12 @@ mod tests {
         if ips.is_empty() {
             return;
         }
 
-        let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());
+        let bind_dev = get_interface_name_by_ip(&IpAddr::V4(ips[0].into()));
 
         for ip in ips {
             println!("bind to ip: {:?}, {:?}", ip, bind_dev);
             let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
-                &format!("udp://{}:11111", ip).parse().unwrap(),
+                &format!("udp://{}:11111", ip.to_string()).parse().unwrap(),
                 "udp",
             )
             .unwrap();


@@ -8,7 +8,8 @@ use tokio_rustls::TlsAcceptor;
 use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message};
 use zerocopy::AsBytes;
 
-use crate::{rpc::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
+use super::TunnelInfo;
+use crate::tunnel::insecure_tls::get_insecure_tls_client_config;
 
 use super::{
     common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
@@ -72,12 +73,14 @@ impl WSTunnelListener {
     async fn try_accept(&mut self, stream: TcpStream) -> Result<Box<dyn Tunnel>, TunnelError> {
         let info = TunnelInfo {
             tunnel_type: self.addr.scheme().to_owned(),
-            local_addr: self.local_url().into(),
-            remote_addr: super::build_url_from_socket_addr(
-                &stream.peer_addr()?.to_string(),
-                self.addr.scheme().to_string().as_str(),
-            )
-            .into(),
+            local_addr: Some(self.local_url().into()),
+            remote_addr: Some(
+                super::build_url_from_socket_addr(
+                    &stream.peer_addr()?.to_string(),
+                    self.addr.scheme().to_string().as_str(),
+                )
+                .into(),
+            ),
         };
 
         let server_bulder = tokio_websockets::ServerBuilder::new().limits(Limits::unlimited());
@@ -182,12 +185,14 @@ impl WSTunnelConnector {
         let info = TunnelInfo {
             tunnel_type: addr.scheme().to_owned(),
-            local_addr: super::build_url_from_socket_addr(
-                &stream.local_addr()?.to_string(),
-                addr.scheme().to_string().as_str(),
-            )
-            .into(),
-            remote_addr: addr.to_string(),
+            local_addr: Some(
+                super::build_url_from_socket_addr(
+                    &stream.local_addr()?.to_string(),
+                    addr.scheme().to_string().as_str(),
+                )
+                .into(),
+            ),
+            remote_addr: Some(addr.clone().into()),
         };
 
         let c = ClientBuilder::from_uri(http::Uri::try_from(addr.to_string()).unwrap());


@@ -20,13 +20,11 @@ use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
 use rand::RngCore;
 use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
 
-use crate::{
-    rpc::TunnelInfo,
-    tunnel::{
-        build_url_from_socket_addr,
-        common::TunnelWrapper,
-        packet_def::{ZCPacket, WG_TUNNEL_HEADER_SIZE},
-    },
-};
+use super::TunnelInfo;
+use crate::tunnel::{
+    build_url_from_socket_addr,
+    common::TunnelWrapper,
+    packet_def::{ZCPacket, WG_TUNNEL_HEADER_SIZE},
+};
 
 use super::{
@@ -522,12 +520,16 @@ impl WgTunnelListener {
                     sink,
                     Some(TunnelInfo {
                         tunnel_type: "wg".to_owned(),
-                        local_addr: build_url_from_socket_addr(
-                            &socket.local_addr().unwrap().to_string(),
-                            "wg",
-                        )
-                        .into(),
-                        remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(),
+                        local_addr: Some(
+                            build_url_from_socket_addr(
+                                &socket.local_addr().unwrap().to_string(),
+                                "wg",
+                            )
+                            .into(),
+                        ),
+                        remote_addr: Some(
+                            build_url_from_socket_addr(&addr.to_string(), "wg").into(),
+                        ),
                     }),
                 ));
                 if let Err(e) = conn_sender.send(tunnel) {
@@ -670,8 +672,8 @@ impl WgTunnelConnector {
             sink,
             Some(TunnelInfo {
                 tunnel_type: "wg".to_owned(),
-                local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
-                remote_addr: addr_url.to_string(),
+                local_addr: Some(super::build_url_from_socket_addr(&local_addr, "wg").into()),
+                remote_addr: Some(addr_url.into()),
             }),
             Some(Box::new(wg_peer)),
        ));


@@ -5,7 +5,10 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte
 
 use crate::{
     common::{config::ConfigLoader, get_logger_timer_rfc3339},
-    rpc::cli::{NatType, PeerInfo, Route},
+    proto::{
+        cli::{PeerInfo, Route},
+        common::NatType,
+    },
 };
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -114,17 +117,11 @@ pub fn list_peer_route_pair(peers: Vec<PeerInfo>, routes: Vec<Route>) -> Vec<Pee
     for route in routes.iter() {
         let peer = peers.iter().find(|peer| peer.peer_id == route.peer_id);
-        let has_tunnel = peer.map(|p| !p.conns.is_empty()).unwrap_or(false);
-        let mut pair = PeerRoutePair {
+        let pair = PeerRoutePair {
             route: route.clone(),
             peer: peer.cloned(),
         };
-
-        // it is relayed by public server, adjust the cost
-        if !has_tunnel && pair.route.cost == 1 {
-            pair.route.cost = 2;
-        }
-
         pairs.push(pair);
     }


@@ -89,8 +89,8 @@ impl WireGuardImpl {
         peer_mgr
             .get_global_ctx()
             .issue_event(GlobalCtxEvent::VpnPortalClientConnected(
-                info.local_addr.clone(),
-                info.remote_addr.clone(),
+                info.local_addr.clone().unwrap_or_default().to_string(),
+                info.remote_addr.clone().unwrap_or_default().to_string(),
             ));
 
         let mut map_key = None;
@@ -120,7 +120,7 @@ impl WireGuardImpl {
         };
         if !ip_registered {
             let client_entry = Arc::new(ClientEntry {
-                endpoint_addr: remote_addr.parse().ok(),
+                endpoint_addr: remote_addr.clone().map(Into::into),
                 sink: mpsc_tunnel.get_sink(),
             });
             map_key = Some(i.get_source());
@@ -142,8 +142,8 @@ impl WireGuardImpl {
         peer_mgr
             .get_global_ctx()
             .issue_event(GlobalCtxEvent::VpnPortalClientDisconnected(
-                info.local_addr,
-                info.remote_addr,
+                info.local_addr.unwrap_or_default().to_string(),
+                info.remote_addr.unwrap_or_default().to_string(),
             ));
     }