mirror of
https://github.com/lbl8603/vnts.git
synced 2025-12-24 12:47:51 +08:00
支持wg
This commit is contained in:
254
Cargo.lock
generated
254
Cargo.lock
generated
@@ -53,7 +53,7 @@ dependencies = [
|
||||
"actix-service",
|
||||
"actix-utils",
|
||||
"ahash",
|
||||
"base64",
|
||||
"base64 0.21.7",
|
||||
"bitflags 2.5.0",
|
||||
"brotli",
|
||||
"bytes",
|
||||
@@ -368,12 +368,24 @@ dependencies = [
|
||||
"rustc-demangle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
version = "1.6.0"
|
||||
@@ -392,6 +404,15 @@ version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
|
||||
|
||||
[[package]]
|
||||
name = "blake2"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
|
||||
dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
@@ -401,6 +422,28 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "boringtun"
|
||||
version = "0.6.0"
|
||||
dependencies = [
|
||||
"aead",
|
||||
"base64 0.13.1",
|
||||
"blake2",
|
||||
"chacha20poly1305",
|
||||
"hex",
|
||||
"hmac",
|
||||
"ip_network",
|
||||
"ip_network_table",
|
||||
"libc",
|
||||
"nix",
|
||||
"parking_lot",
|
||||
"rand_core",
|
||||
"ring",
|
||||
"tracing",
|
||||
"untrusted",
|
||||
"x25519-dalek",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "brotli"
|
||||
version = "3.5.0"
|
||||
@@ -466,6 +509,30 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "chacha20"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chacha20poly1305"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35"
|
||||
dependencies = [
|
||||
"aead",
|
||||
"chacha20",
|
||||
"cipher",
|
||||
"poly1305",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "change-detection"
|
||||
version = "1.2.0"
|
||||
@@ -498,6 +565,7 @@ checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"inout",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -671,12 +739,39 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
name = "curve25519-dalek"
|
||||
version = "4.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
||||
checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"curve25519-dalek-derive",
|
||||
"fiat-crypto",
|
||||
"rustc_version",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curve25519-dalek-derive"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "6.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"hashbrown 0.14.3",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
@@ -691,9 +786,9 @@ checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.6.1"
|
||||
version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1a467a65c5e759bce6e65eaf91cc29f466cdc57cb65777bd646872a8a1fd4de"
|
||||
checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0"
|
||||
dependencies = [
|
||||
"const-oid",
|
||||
"pem-rfc7468",
|
||||
@@ -748,6 +843,7 @@ dependencies = [
|
||||
"block-buffer",
|
||||
"const-oid",
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -808,6 +904,12 @@ version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
|
||||
|
||||
[[package]]
|
||||
name = "fiat-crypto"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.0.29"
|
||||
@@ -973,6 +1075,21 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
||||
|
||||
[[package]]
|
||||
name = "hex"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||
dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "home"
|
||||
version = "0.5.9"
|
||||
@@ -1090,6 +1207,37 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ip_network"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa2f047c0a98b2f299aa5d6d7088443570faae494e9ae1305e48be000c9e0eb1"
|
||||
|
||||
[[package]]
|
||||
name = "ip_network_table"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4099b7cfc5c5e2fe8c5edf3f6f7adf7a714c9cc697534f63a5a5da30397cb2c0"
|
||||
dependencies = [
|
||||
"ip_network",
|
||||
"ip_network_table-deps-treebitmap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ip_network_table-deps-treebitmap"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e537132deb99c0eb4b752f0346b6a836200eaaa3516dd7e5514b63930a09e5d"
|
||||
|
||||
[[package]]
|
||||
name = "ipnetwork"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.12"
|
||||
@@ -1300,6 +1448,18 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nix"
|
||||
version = "0.25.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f346ff70e7dbfd675fe90590b92d59ef2de15a8779ae305ebcbfd3f0caf59be4"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"bitflags 1.3.2",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint-dig"
|
||||
version = "0.8.4"
|
||||
@@ -1458,9 +1618,9 @@ checksum = "498a099351efa4becc6a19c72aa9270598e8fd274ca47052e37455241c88b696"
|
||||
|
||||
[[package]]
|
||||
name = "pem-rfc7468"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d159833a9105500e0398934e205e0773f0b27529557134ecfc51c27646adac"
|
||||
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
|
||||
dependencies = [
|
||||
"base64ct",
|
||||
]
|
||||
@@ -1485,21 +1645,20 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "pkcs1"
|
||||
version = "0.4.1"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eff33bdbdfc54cc98a2eca766ebdec3e1b8fb7387523d5c9c9a2891da856f719"
|
||||
checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f"
|
||||
dependencies = [
|
||||
"der",
|
||||
"pkcs8",
|
||||
"spki",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pkcs8"
|
||||
version = "0.9.0"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba"
|
||||
checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7"
|
||||
dependencies = [
|
||||
"der",
|
||||
"spki",
|
||||
@@ -1511,6 +1670,17 @@ version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
|
||||
|
||||
[[package]]
|
||||
name = "poly1305"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf"
|
||||
dependencies = [
|
||||
"cpufeatures",
|
||||
"opaque-debug",
|
||||
"universal-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "polyval"
|
||||
version = "0.6.2"
|
||||
@@ -1774,21 +1944,20 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rsa"
|
||||
version = "0.7.2"
|
||||
version = "0.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "094052d5470cbcef561cb848a7209968c9f12dfa6d668f4bca048ac5de51099c"
|
||||
checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"const-oid",
|
||||
"digest",
|
||||
"num-bigint-dig",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-traits",
|
||||
"pkcs1",
|
||||
"pkcs8",
|
||||
"rand_core",
|
||||
"signature",
|
||||
"smallvec",
|
||||
"spki",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
@@ -1938,9 +2107,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "signature"
|
||||
version = "1.6.4"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c"
|
||||
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
|
||||
dependencies = [
|
||||
"digest",
|
||||
"rand_core",
|
||||
@@ -1985,9 +2154,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
version = "0.6.0"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b"
|
||||
checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d"
|
||||
dependencies = [
|
||||
"base64ct",
|
||||
"der",
|
||||
@@ -2210,9 +2379,21 @@ checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-attributes"
|
||||
version = "0.1.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-core"
|
||||
version = "0.1.32"
|
||||
@@ -2370,6 +2551,8 @@ dependencies = [
|
||||
"aes-gcm",
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"boringtun",
|
||||
"chrono",
|
||||
"clap",
|
||||
"colored",
|
||||
@@ -2378,6 +2561,7 @@ dependencies = [
|
||||
"dashmap",
|
||||
"dirs",
|
||||
"futures-util",
|
||||
"ipnetwork",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"log4rs",
|
||||
@@ -2653,6 +2837,18 @@ version = "0.52.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
|
||||
|
||||
[[package]]
|
||||
name = "x25519-dalek"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277"
|
||||
dependencies = [
|
||||
"curve25519-dalek",
|
||||
"rand_core",
|
||||
"serde",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.7.32"
|
||||
@@ -2678,6 +2874,20 @@ name = "zeroize"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
|
||||
dependencies = [
|
||||
"zeroize_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize_derive"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd"
|
||||
|
||||
10
Cargo.toml
10
Cargo.toml
@@ -13,10 +13,10 @@ log4rs = "1.3"
|
||||
dirs = "5"
|
||||
crossbeam = "0.8"
|
||||
parking_lot = "0.12"
|
||||
dashmap = "5.5"
|
||||
dashmap = "6.0.1"
|
||||
|
||||
rsa = { version = "0.7.2", features = [] }
|
||||
spki = { version = "0.6.0", features = ["fingerprint", "alloc"] }
|
||||
rsa = { version = "0.9.6", features = [] }
|
||||
spki = { version = "0.7.3", features = ["fingerprint", "alloc", "base64"] }
|
||||
aes-gcm = { version = "0.10.2", optional = true }
|
||||
ring = { version = "0.17", optional = true }
|
||||
rand = "0.8"
|
||||
@@ -39,6 +39,10 @@ actix-files = { version = "0.6", optional = true }
|
||||
actix-web-static-files = { version = "4.0.1", optional = true }
|
||||
tokio-tungstenite = "0.23.1"
|
||||
|
||||
boringtun = { path = "lib/boringtun", features = [] }
|
||||
ipnetwork = "0.20.0"
|
||||
base64 = "0.22.1"
|
||||
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
crossbeam-utils = "0.8"
|
||||
futures-util = "0.3"
|
||||
|
||||
64
lib/boringtun/Cargo.toml
Normal file
64
lib/boringtun/Cargo.toml
Normal file
@@ -0,0 +1,64 @@
|
||||
[package]
|
||||
name = "boringtun"
|
||||
description = "an implementation of the WireGuard® protocol designed for portability and speed"
|
||||
version = "0.6.0"
|
||||
authors = [
|
||||
"Noah Kennedy <nkennedy@cloudflare.com>",
|
||||
"Andy Grover <agrover@cloudflare.com>",
|
||||
"Jeff Hiner <jhiner@cloudflare.com>",
|
||||
]
|
||||
license = "BSD-3-Clause"
|
||||
repository = "https://github.com/cloudflare/boringtun"
|
||||
documentation = "https://docs.rs/boringtun/0.5.2/boringtun/"
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
device = ["socket2", "thiserror"]
|
||||
jni-bindings = ["ffi-bindings", "jni"]
|
||||
ffi-bindings = ["tracing-subscriber"]
|
||||
# mocks std::time::Instant with mock_instant
|
||||
mock-instant = ["mock_instant"]
|
||||
|
||||
[dependencies]
|
||||
base64 = "0.13"
|
||||
hex = "0.4"
|
||||
untrusted = "0.9.0"
|
||||
libc = "0.2"
|
||||
parking_lot = "0.12"
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3", features = ["fmt"], optional = true }
|
||||
ip_network = "0.4.1"
|
||||
ip_network_table = "0.2.0"
|
||||
ring = "0.17"
|
||||
x25519-dalek = { version = "2.0.0", features = [
|
||||
"reusable_secrets",
|
||||
"static_secrets",
|
||||
] }
|
||||
rand_core = { version = "0.6.4", features = ["getrandom"] }
|
||||
chacha20poly1305 = "0.10.0-pre.1"
|
||||
aead = "0.5.0-pre.2"
|
||||
blake2 = "0.10"
|
||||
hmac = "0.12"
|
||||
jni = { version = "0.19.0", optional = true }
|
||||
mock_instant = { version = "0.3", optional = true }
|
||||
socket2 = { version = "0.4.7", features = ["all"], optional = true }
|
||||
thiserror = { version = "1", optional = true }
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
nix = { version = "0.25", default-features = false, features = [
|
||||
"time",
|
||||
"user",
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
etherparse = "0.13"
|
||||
tracing-subscriber = "0.3"
|
||||
criterion = { version = "0.3.5", features = ["html_reports"] }
|
||||
|
||||
[lib]
|
||||
crate-type = ["staticlib", "cdylib", "rlib"]
|
||||
|
||||
[[bench]]
|
||||
name = "crypto_benches"
|
||||
harness = false
|
||||
90
lib/boringtun/benches/crypto_benches/blake2s_benching.rs
Normal file
90
lib/boringtun/benches/crypto_benches/blake2s_benching.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use blake2::digest::{FixedOutput, KeyInit};
|
||||
use blake2::{Blake2s256, Blake2sMac, Digest};
|
||||
use criterion::{BenchmarkId, Criterion, Throughput};
|
||||
use ring::rand::{SecureRandom, SystemRandom};
|
||||
|
||||
pub fn bench_blake2s_hash(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("blake2s_hash");
|
||||
|
||||
group.sample_size(1000);
|
||||
|
||||
for size in [32, 64, 128] {
|
||||
group.throughput(Throughput::Bytes(size as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| {
|
||||
let buf_in = vec![0u8; size];
|
||||
|
||||
b.iter(|| {
|
||||
let mut hasher = Blake2s256::new();
|
||||
hasher.update(&buf_in);
|
||||
hasher.finalize();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_blake2s_hmac(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("blake2s_hmac");
|
||||
|
||||
group.sample_size(1000);
|
||||
|
||||
for size in [16, 32] {
|
||||
group.throughput(Throughput::Bytes(size as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| {
|
||||
let buf_in = vec![0u8; size];
|
||||
let rng = SystemRandom::new();
|
||||
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let mut key = [0u8; 32];
|
||||
rng.fill(&mut key).unwrap();
|
||||
key
|
||||
},
|
||||
|key| {
|
||||
use blake2::digest::Update;
|
||||
type HmacBlake2s = hmac::SimpleHmac<blake2::Blake2s256>;
|
||||
let mut hmac = HmacBlake2s::new_from_slice(&key).unwrap();
|
||||
hmac.update(&buf_in);
|
||||
hmac.finalize_fixed();
|
||||
},
|
||||
criterion::BatchSize::SmallInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_blake2s_keyed(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("blake2s_keyed_mac");
|
||||
|
||||
group.sample_size(1000);
|
||||
|
||||
for size in [128, 1024] {
|
||||
group.throughput(Throughput::Bytes(size as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| {
|
||||
let buf_in = vec![0u8; size];
|
||||
let rng = SystemRandom::new();
|
||||
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let mut key = [0u8; 16];
|
||||
rng.fill(&mut key).unwrap();
|
||||
key
|
||||
},
|
||||
|key| -> [u8; 16] {
|
||||
let mut hmac = Blake2sMac::new_from_slice(&key).unwrap();
|
||||
blake2::digest::Update::update(&mut hmac, &buf_in);
|
||||
hmac.finalize_fixed().into()
|
||||
},
|
||||
criterion::BatchSize::SmallInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
use aead::{AeadInPlace, KeyInit};
|
||||
use criterion::{BenchmarkId, Criterion, Throughput};
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
|
||||
fn chacha20poly1305_ring(key_bytes: &[u8], buf: &mut [u8]) {
|
||||
let len = buf.len();
|
||||
let n = len - 16;
|
||||
|
||||
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key_bytes).unwrap());
|
||||
|
||||
let tag = key
|
||||
.seal_in_place_separate_tag(
|
||||
Nonce::assume_unique_for_key([0u8; 12]),
|
||||
Aad::from(&[]),
|
||||
&mut buf[..n],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
buf[n..].copy_from_slice(tag.as_ref())
|
||||
}
|
||||
|
||||
fn chacha20poly1305_non_ring(key_bytes: &[u8], buf: &mut [u8]) {
|
||||
let len = buf.len();
|
||||
let n = len - 16;
|
||||
|
||||
let aead = chacha20poly1305::ChaCha20Poly1305::new_from_slice(key_bytes).unwrap();
|
||||
let nonce = chacha20poly1305::Nonce::default();
|
||||
|
||||
let tag = aead
|
||||
.encrypt_in_place_detached(&nonce, &[], &mut buf[..n])
|
||||
.unwrap();
|
||||
|
||||
buf[n..].copy_from_slice(tag.as_ref());
|
||||
}
|
||||
|
||||
pub fn bench_chacha20poly1305(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("chacha20poly1305");
|
||||
|
||||
group.sample_size(1000);
|
||||
|
||||
for size in [128, 192, 1400, 8192] {
|
||||
group.throughput(Throughput::Bytes(size as u64));
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("chacha20poly1305_ring", size),
|
||||
&size,
|
||||
|b, i| {
|
||||
let mut key = [0; 32];
|
||||
let mut buf = vec![0; i + 16];
|
||||
|
||||
let mut rng = OsRng::default();
|
||||
|
||||
rng.fill_bytes(&mut key);
|
||||
rng.fill_bytes(&mut buf);
|
||||
|
||||
b.iter(|| chacha20poly1305_ring(&key, &mut buf));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("chacha20poly1305_non_ring", size),
|
||||
&size,
|
||||
|b, i| {
|
||||
let mut key = [0; 32];
|
||||
let mut buf = vec![0; i + 16];
|
||||
|
||||
let mut rng = OsRng::default();
|
||||
|
||||
rng.fill_bytes(&mut key);
|
||||
rng.fill_bytes(&mut buf);
|
||||
|
||||
b.iter(|| chacha20poly1305_non_ring(&key, &mut buf));
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
20
lib/boringtun/benches/crypto_benches/main.rs
Normal file
20
lib/boringtun/benches/crypto_benches/main.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
use blake2s_benching::{bench_blake2s_hash, bench_blake2s_hmac, bench_blake2s_keyed};
|
||||
use chacha20poly1305_benching::bench_chacha20poly1305;
|
||||
use x25519_public_key_benching::bench_x25519_public_key;
|
||||
use x25519_shared_key_benching::bench_x25519_shared_key;
|
||||
|
||||
mod blake2s_benching;
|
||||
mod chacha20poly1305_benching;
|
||||
mod x25519_public_key_benching;
|
||||
mod x25519_shared_key_benching;
|
||||
|
||||
criterion::criterion_group!(
|
||||
crypto_benches,
|
||||
bench_chacha20poly1305,
|
||||
bench_blake2s_hash,
|
||||
bench_blake2s_hmac,
|
||||
bench_blake2s_keyed,
|
||||
bench_x25519_shared_key,
|
||||
bench_x25519_public_key
|
||||
);
|
||||
criterion::criterion_main!(crypto_benches);
|
||||
@@ -0,0 +1,30 @@
|
||||
use criterion::Criterion;
|
||||
use rand_core::OsRng;
|
||||
|
||||
pub fn bench_x25519_public_key(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("x25519_public_key");
|
||||
|
||||
group.sample_size(1000);
|
||||
|
||||
group.bench_function("x25519_public_key_dalek", |b| {
|
||||
b.iter(|| {
|
||||
let secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = x25519_dalek::PublicKey::from(&secret_key);
|
||||
|
||||
(secret_key, public_key)
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("x25519_public_key_ring", |b| {
|
||||
let rng = ring::rand::SystemRandom::new();
|
||||
|
||||
b.iter(|| {
|
||||
let my_private_key =
|
||||
ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng)
|
||||
.unwrap();
|
||||
my_private_key.compute_public_key().unwrap()
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
use criterion::{BatchSize, Criterion};
|
||||
use rand_core::OsRng;
|
||||
|
||||
pub fn bench_x25519_shared_key(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("x25519_shared_key");
|
||||
|
||||
group.sample_size(1000);
|
||||
|
||||
group.bench_function("x25519_shared_key_dalek", |b| {
|
||||
let public_key =
|
||||
x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random_from_rng(OsRng));
|
||||
|
||||
b.iter_batched(
|
||||
|| x25519_dalek::StaticSecret::random_from_rng(OsRng),
|
||||
|secret_key| secret_key.diffie_hellman(&public_key),
|
||||
BatchSize::SmallInput,
|
||||
);
|
||||
});
|
||||
|
||||
group.bench_function("x25519_shared_key_ring", |b| {
|
||||
let rng = ring::rand::SystemRandom::new();
|
||||
|
||||
let peer_public_key = {
|
||||
let peer_private_key =
|
||||
ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng)
|
||||
.unwrap();
|
||||
peer_private_key.compute_public_key().unwrap()
|
||||
};
|
||||
let peer_public_key_alg = &ring::agreement::X25519;
|
||||
|
||||
let my_public_key =
|
||||
ring::agreement::UnparsedPublicKey::new(peer_public_key_alg, &peer_public_key);
|
||||
|
||||
b.iter_batched(
|
||||
|| {
|
||||
ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng)
|
||||
.unwrap()
|
||||
},
|
||||
|my_private_key| {
|
||||
ring::agreement::agree_ephemeral(my_private_key, &my_public_key, |_key_material| ())
|
||||
.unwrap()
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
);
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
389
lib/boringtun/src/device/allowed_ips.rs
Normal file
389
lib/boringtun/src/device/allowed_ips.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use crate::device::peer::AllowedIP;
|
||||
|
||||
use ip_network::IpNetwork;
|
||||
use ip_network_table::IpNetworkTable;
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::iter::FromIterator;
|
||||
use std::net::IpAddr;
|
||||
|
||||
/// A trie of IP/cidr addresses
|
||||
#[derive(Default)]
|
||||
pub struct AllowedIps<D> {
|
||||
ips: IpNetworkTable<D>,
|
||||
}
|
||||
|
||||
impl<'a, D> FromIterator<(&'a AllowedIP, D)> for AllowedIps<D> {
|
||||
fn from_iter<I: IntoIterator<Item = (&'a AllowedIP, D)>>(iter: I) -> Self {
|
||||
let mut allowed_ips = AllowedIps::new();
|
||||
|
||||
for (ip, data) in iter {
|
||||
allowed_ips.insert(ip.addr, ip.cidr as u32, data);
|
||||
}
|
||||
|
||||
allowed_ips
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> AllowedIps<D> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
ips: IpNetworkTable::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.ips = IpNetworkTable::new();
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: IpAddr, cidr: u32, data: D) -> Option<D> {
|
||||
// These are networks, it doesn't make sense for host bits to be set, so
|
||||
// use new_truncate().
|
||||
self.ips.insert(
|
||||
IpNetwork::new_truncate(key, cidr as u8).expect("cidr is valid length"),
|
||||
data,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn find(&self, key: IpAddr) -> Option<&D> {
|
||||
self.ips.longest_match(key).map(|(_net, data)| data)
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, predicate: &dyn Fn(&D) -> bool) {
|
||||
self.ips.retain(|_, v| !predicate(v));
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> Iter<D> {
|
||||
Iter(
|
||||
self.ips
|
||||
.iter()
|
||||
.map(|(ipa, d)| (d, ipa.network_address(), ipa.netmask()))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Iter<'a, D: 'a>(VecDeque<(&'a D, IpAddr, u8)>);
|
||||
|
||||
impl<'a, D> Iterator for Iter<'a, D> {
|
||||
type Item = (&'a D, IpAddr, u8);
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.0.pop_front()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn build_allowed_ips() -> AllowedIps<char> {
|
||||
let mut map: AllowedIps<char> = Default::default();
|
||||
map.insert(IpAddr::from([127, 0, 0, 1]), 32, '1');
|
||||
map.insert(IpAddr::from([45, 25, 15, 1]), 30, '6');
|
||||
map.insert(IpAddr::from([127, 0, 15, 1]), 16, '2');
|
||||
map.insert(IpAddr::from([127, 1, 15, 1]), 24, '3');
|
||||
map.insert(IpAddr::from([255, 1, 15, 1]), 24, '4');
|
||||
map.insert(IpAddr::from([60, 25, 15, 1]), 32, '5');
|
||||
map.insert(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0]), 128, '7');
|
||||
map
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_ips_insert_find() {
|
||||
let map = build_allowed_ips();
|
||||
assert_eq!(map.find(IpAddr::from([127, 0, 0, 1])), Some(&'1'));
|
||||
assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2'));
|
||||
assert_eq!(map.find(IpAddr::from([127, 1, 255, 255])), None);
|
||||
assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2'));
|
||||
assert_eq!(map.find(IpAddr::from([127, 1, 15, 255])), Some(&'3'));
|
||||
assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2'));
|
||||
assert_eq!(map.find(IpAddr::from([127, 1, 15, 255])), Some(&'3'));
|
||||
assert_eq!(map.find(IpAddr::from([255, 1, 15, 2])), Some(&'4'));
|
||||
assert_eq!(map.find(IpAddr::from([60, 25, 15, 1])), Some(&'5'));
|
||||
assert_eq!(map.find(IpAddr::from([20, 0, 0, 100])), None);
|
||||
assert_eq!(
|
||||
map.find(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0])),
|
||||
Some(&'7')
|
||||
);
|
||||
assert_eq!(map.find(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 1])), None);
|
||||
assert_eq!(map.find(IpAddr::from([45, 25, 15, 1])), Some(&'6'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_ips_remove() {
|
||||
let mut map = build_allowed_ips();
|
||||
map.remove(&|c| *c == '5' || *c == '1' || *c == '7');
|
||||
|
||||
let mut map_iter = map.iter();
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'6', IpAddr::from([45, 25, 15, 0]), 30))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'2', IpAddr::from([127, 0, 0, 0]), 16))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'3', IpAddr::from([127, 1, 15, 0]), 24))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'4', IpAddr::from([255, 1, 15, 0]), 24))
|
||||
);
|
||||
assert_eq!(map_iter.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_ips_iter() {
|
||||
let map = build_allowed_ips();
|
||||
let mut map_iter = map.iter();
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'6', IpAddr::from([45, 25, 15, 0]), 30))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'5', IpAddr::from([60, 25, 15, 1]), 32))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'2', IpAddr::from([127, 0, 0, 0]), 16))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'1', IpAddr::from([127, 0, 0, 1]), 32))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'3', IpAddr::from([127, 1, 15, 0]), 24))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'4', IpAddr::from([255, 1, 15, 0]), 24))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'7', IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0]), 128))
|
||||
);
|
||||
assert_eq!(map_iter.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_ips_v4_kernel_compatibility() {
|
||||
// Test case from wireguard-go
|
||||
let mut map: AllowedIps<char> = Default::default();
|
||||
|
||||
map.insert(IpAddr::from([192, 168, 4, 0]), 24, 'a');
|
||||
map.insert(IpAddr::from([192, 168, 4, 4]), 32, 'b');
|
||||
map.insert(IpAddr::from([192, 168, 0, 0]), 16, 'c');
|
||||
map.insert(IpAddr::from([192, 95, 5, 64]), 27, 'd');
|
||||
map.insert(IpAddr::from([192, 95, 5, 65]), 27, 'c');
|
||||
map.insert(IpAddr::from([0, 0, 0, 0]), 0, 'e');
|
||||
map.insert(IpAddr::from([64, 15, 112, 0]), 20, 'g');
|
||||
map.insert(IpAddr::from([64, 15, 123, 211]), 25, 'h');
|
||||
map.insert(IpAddr::from([10, 0, 0, 0]), 25, 'a');
|
||||
map.insert(IpAddr::from([10, 0, 0, 128]), 25, 'b');
|
||||
map.insert(IpAddr::from([10, 1, 0, 0]), 30, 'a');
|
||||
map.insert(IpAddr::from([10, 1, 0, 4]), 30, 'b');
|
||||
map.insert(IpAddr::from([10, 1, 0, 8]), 29, 'c');
|
||||
map.insert(IpAddr::from([10, 1, 0, 16]), 29, 'd');
|
||||
|
||||
assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 168, 4, 20])));
|
||||
assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 168, 4, 0])));
|
||||
assert_eq!(Some(&'b'), map.find(IpAddr::from([192, 168, 4, 4])));
|
||||
assert_eq!(Some(&'c'), map.find(IpAddr::from([192, 168, 200, 182])));
|
||||
assert_eq!(Some(&'c'), map.find(IpAddr::from([192, 95, 5, 68])));
|
||||
assert_eq!(Some(&'e'), map.find(IpAddr::from([192, 95, 5, 96])));
|
||||
assert_eq!(Some(&'g'), map.find(IpAddr::from([64, 15, 116, 26])));
|
||||
assert_eq!(Some(&'g'), map.find(IpAddr::from([64, 15, 127, 3])));
|
||||
|
||||
map.insert(IpAddr::from([1, 0, 0, 0]), 32, 'a');
|
||||
map.insert(IpAddr::from([64, 0, 0, 0]), 32, 'a');
|
||||
map.insert(IpAddr::from([128, 0, 0, 0]), 32, 'a');
|
||||
map.insert(IpAddr::from([192, 0, 0, 0]), 32, 'a');
|
||||
map.insert(IpAddr::from([255, 0, 0, 0]), 32, 'a');
|
||||
|
||||
assert_eq!(Some(&'a'), map.find(IpAddr::from([1, 0, 0, 0])));
|
||||
assert_eq!(Some(&'a'), map.find(IpAddr::from([64, 0, 0, 0])));
|
||||
assert_eq!(Some(&'a'), map.find(IpAddr::from([128, 0, 0, 0])));
|
||||
assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 0, 0, 0])));
|
||||
assert_eq!(Some(&'a'), map.find(IpAddr::from([255, 0, 0, 0])));
|
||||
|
||||
map.remove(&|c| *c == 'a');
|
||||
|
||||
assert_ne!(Some(&'a'), map.find(IpAddr::from([1, 0, 0, 0])));
|
||||
assert_ne!(Some(&'a'), map.find(IpAddr::from([64, 0, 0, 0])));
|
||||
assert_ne!(Some(&'a'), map.find(IpAddr::from([128, 0, 0, 0])));
|
||||
assert_ne!(Some(&'a'), map.find(IpAddr::from([192, 0, 0, 0])));
|
||||
assert_ne!(Some(&'a'), map.find(IpAddr::from([255, 0, 0, 0])));
|
||||
|
||||
map.clear();
|
||||
|
||||
map.insert(IpAddr::from([192, 168, 0, 0]), 16, 'a');
|
||||
map.insert(IpAddr::from([192, 168, 0, 0]), 24, 'a');
|
||||
|
||||
map.remove(&|c| *c == 'a');
|
||||
|
||||
assert_ne!(Some(&'a'), map.find(IpAddr::from([192, 168, 0, 1])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_ips_v6_kernel_compatibility() {
|
||||
// Test case from wireguard-go
|
||||
let mut map: AllowedIps<char> = Default::default();
|
||||
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0xc05f, 0x0543,
|
||||
]),
|
||||
128,
|
||||
'd',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0x0000, 0x0000,
|
||||
]),
|
||||
64,
|
||||
'c',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
|
||||
]),
|
||||
0,
|
||||
'e',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
|
||||
]),
|
||||
0,
|
||||
'f',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x2404, 0x6800, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
|
||||
]),
|
||||
32,
|
||||
'g',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef,
|
||||
]),
|
||||
64,
|
||||
'h',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef,
|
||||
]),
|
||||
128,
|
||||
'a',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x2444, 0x6800, 0x40e4, 0x0800, 0xdeae, 0xbeef, 0x0def, 0xbeef,
|
||||
]),
|
||||
128,
|
||||
'c',
|
||||
);
|
||||
map.insert(
|
||||
IpAddr::from([
|
||||
0x2444, 0x6800, 0xf0e4, 0x0800, 0xeeae, 0xbeef, 0x0000, 0x0000,
|
||||
]),
|
||||
98,
|
||||
'b',
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Some(&'d'),
|
||||
map.find(IpAddr::from([
|
||||
0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0xc05f, 0x0543
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'c'),
|
||||
map.find(IpAddr::from([
|
||||
0x2607, 0x5300, 0x6000, 0x6b00, 0, 0, 0xc02e, 0x01ee
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'f'),
|
||||
map.find(IpAddr::from([0x2607, 0x5300, 0x6000, 0x6b01, 0, 0, 0, 0]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'g'),
|
||||
map.find(IpAddr::from([
|
||||
0x2404, 0x6800, 0x4004, 0x0806, 0, 0, 0, 0x1006
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'g'),
|
||||
map.find(IpAddr::from([
|
||||
0x2404, 0x6800, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'f'),
|
||||
map.find(IpAddr::from([
|
||||
0x2404, 0x67ff, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'f'),
|
||||
map.find(IpAddr::from([
|
||||
0x2404, 0x6801, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'h'),
|
||||
map.find(IpAddr::from([
|
||||
0x2404, 0x6800, 0x4004, 0x0800, 0, 0x1234, 0, 0x5678
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'h'),
|
||||
map.find(IpAddr::from([0x2404, 0x6800, 0x4004, 0x0800, 0, 0, 0, 0]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'h'),
|
||||
map.find(IpAddr::from([
|
||||
0x2404, 0x6800, 0x4004, 0x0800, 0x1010, 0x1010, 0x1010, 0x1010
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
Some(&'a'),
|
||||
map.find(IpAddr::from([
|
||||
0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef
|
||||
]))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_ips_iter_zero_leaf_bits() {
|
||||
let mut map: AllowedIps<char> = Default::default();
|
||||
map.insert(IpAddr::from([10, 111, 0, 1]), 32, '1');
|
||||
map.insert(IpAddr::from([10, 111, 0, 2]), 32, '2');
|
||||
map.insert(IpAddr::from([10, 111, 0, 3]), 32, '3');
|
||||
|
||||
let mut map_iter = map.iter();
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'1', IpAddr::from([10, 111, 0, 1]), 32))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'2', IpAddr::from([10, 111, 0, 2]), 32))
|
||||
);
|
||||
assert_eq!(
|
||||
map_iter.next(),
|
||||
Some((&'3', IpAddr::from([10, 111, 0, 3]), 32))
|
||||
);
|
||||
assert_eq!(map_iter.next(), None);
|
||||
}
|
||||
}
|
||||
368
lib/boringtun/src/device/api.rs
Normal file
368
lib/boringtun/src/device/api.rs
Normal file
@@ -0,0 +1,368 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::dev_lock::LockReadGuard;
|
||||
use super::drop_privileges::get_saved_ids;
|
||||
use super::{AllowedIP, Device, Error, SocketAddr};
|
||||
use crate::device::Action;
|
||||
use crate::serialization::KeyBytes;
|
||||
use crate::x25519;
|
||||
use hex::encode as encode_hex;
|
||||
use libc::*;
|
||||
use std::fs::{create_dir, remove_file};
|
||||
use std::io::{BufRead, BufReader, BufWriter, Write};
|
||||
use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||
use std::os::unix::net::{UnixListener, UnixStream};
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
const SOCK_DIR: &str = "/var/run/wireguard/";
|
||||
|
||||
fn create_sock_dir() {
|
||||
let _ = create_dir(SOCK_DIR); // Create the directory if it does not exist
|
||||
|
||||
if let Ok((saved_uid, saved_gid)) = get_saved_ids() {
|
||||
unsafe {
|
||||
let c_path = std::ffi::CString::new(SOCK_DIR).unwrap();
|
||||
// The directory is under the root user, but we want to be able to
|
||||
// delete the files there when we exit, so we need to change the owner
|
||||
chown(
|
||||
c_path.as_bytes_with_nul().as_ptr() as _,
|
||||
saved_uid,
|
||||
saved_gid,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
/// Register the api handler for this Device. The api handler receives stream connections on a Unix socket
|
||||
/// with a known path: /var/run/wireguard/{tun_name}.sock.
|
||||
pub fn register_api_handler(&mut self) -> Result<(), Error> {
|
||||
let path = format!("{}/{}.sock", SOCK_DIR, self.iface.name()?);
|
||||
|
||||
create_sock_dir();
|
||||
|
||||
let _ = remove_file(&path); // Attempt to remove the socket if already exists
|
||||
|
||||
let api_listener = UnixListener::bind(&path).map_err(Error::ApiSocket)?; // Bind a new socket to the path
|
||||
|
||||
self.cleanup_paths.push(path.clone());
|
||||
|
||||
self.queue.new_event(
|
||||
api_listener.as_raw_fd(),
|
||||
Box::new(move |d, _| {
|
||||
// This is the closure that listens on the api unix socket
|
||||
let (api_conn, _) = match api_listener.accept() {
|
||||
Ok(conn) => conn,
|
||||
_ => return Action::Continue,
|
||||
};
|
||||
|
||||
let mut reader = BufReader::new(&api_conn);
|
||||
let mut writer = BufWriter::new(&api_conn);
|
||||
let mut cmd = String::new();
|
||||
if reader.read_line(&mut cmd).is_ok() {
|
||||
cmd.pop(); // pop the new line character
|
||||
let status = match cmd.as_ref() {
|
||||
// Only two commands are legal according to the protocol, get=1 and set=1.
|
||||
"get=1" => api_get(&mut writer, d),
|
||||
"set=1" => api_set(&mut reader, d),
|
||||
_ => EIO,
|
||||
};
|
||||
// The protocol requires to return an error code as the response, or zero on success
|
||||
writeln!(writer, "errno={}\n", status).ok();
|
||||
}
|
||||
Action::Continue // Indicates the worker thread should continue as normal
|
||||
}),
|
||||
)?;
|
||||
|
||||
self.register_monitor(path)?;
|
||||
self.register_api_signal_handlers()
|
||||
}
|
||||
|
||||
pub fn register_api_fd(&mut self, fd: i32) -> Result<(), Error> {
|
||||
let io_file = unsafe { UnixStream::from_raw_fd(fd) };
|
||||
|
||||
self.queue.new_event(
|
||||
io_file.as_raw_fd(),
|
||||
Box::new(move |d, _| {
|
||||
// This is the closure that listens on the api file descriptor
|
||||
|
||||
let mut reader = BufReader::new(&io_file);
|
||||
let mut writer = BufWriter::new(&io_file);
|
||||
let mut cmd = String::new();
|
||||
if reader.read_line(&mut cmd).is_ok() {
|
||||
cmd.pop(); // pop the new line character
|
||||
let status = match cmd.as_ref() {
|
||||
// Only two commands are legal according to the protocol, get=1 and set=1.
|
||||
"get=1" => api_get(&mut writer, d),
|
||||
"set=1" => api_set(&mut reader, d),
|
||||
_ => EIO,
|
||||
};
|
||||
// The protocol requires to return an error code as the response, or zero on success
|
||||
writeln!(writer, "errno={}\n", status).ok();
|
||||
} else {
|
||||
// The remote side is likely closed; we should trigger an exit.
|
||||
d.trigger_exit();
|
||||
return Action::Exit;
|
||||
}
|
||||
|
||||
Action::Continue // Indicates the worker thread should continue as normal
|
||||
}),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_monitor(&self, path: String) -> Result<(), Error> {
|
||||
self.queue.new_periodic_event(
|
||||
Box::new(move |d, _| {
|
||||
// This is not a very nice hack to detect if the control socket was removed
|
||||
// and exiting nicely as a result. We check every 3 seconds in a loop if the
|
||||
// file was deleted by stating it.
|
||||
// The problem is that on linux inotify can be used quite beautifully to detect
|
||||
// deletion, and kqueue EVFILT_VNODE can be used for the same purpose, but that
|
||||
// will require introducing new events, for no measurable benefit.
|
||||
// TODO: Could this be an issue if we restart the service too quickly?
|
||||
let path = std::path::Path::new(&path);
|
||||
if !path.exists() {
|
||||
d.trigger_exit();
|
||||
return Action::Exit;
|
||||
}
|
||||
|
||||
// Periodically read the mtu of the interface in case it changes
|
||||
if let Ok(mtu) = d.iface.mtu() {
|
||||
d.mtu.store(mtu, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
Action::Continue
|
||||
}),
|
||||
std::time::Duration::from_millis(1000),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_api_signal_handlers(&self) -> Result<(), Error> {
|
||||
self.queue
|
||||
.new_signal_event(SIGINT, Box::new(move |_, _| Action::Exit))?;
|
||||
|
||||
self.queue
|
||||
.new_signal_event(SIGTERM, Box::new(move |_, _| Action::Exit))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_must_use)]
|
||||
fn api_get(writer: &mut BufWriter<&UnixStream>, d: &Device) -> i32 {
|
||||
// get command requires an empty line, but there is no reason to be religious about it
|
||||
if let Some(ref k) = d.key_pair {
|
||||
writeln!(writer, "own_public_key={}", encode_hex(k.1.as_bytes()));
|
||||
}
|
||||
|
||||
if d.listen_port != 0 {
|
||||
writeln!(writer, "listen_port={}", d.listen_port);
|
||||
}
|
||||
|
||||
if let Some(fwmark) = d.fwmark {
|
||||
writeln!(writer, "fwmark={}", fwmark);
|
||||
}
|
||||
|
||||
for (k, p) in d.peers.iter() {
|
||||
let p = p.lock();
|
||||
writeln!(writer, "public_key={}", encode_hex(k.as_bytes()));
|
||||
|
||||
if let Some(ref key) = p.preshared_key() {
|
||||
writeln!(writer, "preshared_key={}", encode_hex(key));
|
||||
}
|
||||
|
||||
if let Some(keepalive) = p.persistent_keepalive() {
|
||||
writeln!(writer, "persistent_keepalive_interval={}", keepalive);
|
||||
}
|
||||
|
||||
if let Some(ref addr) = p.endpoint().addr {
|
||||
writeln!(writer, "endpoint={}", addr);
|
||||
}
|
||||
|
||||
for (ip, cidr) in p.allowed_ips() {
|
||||
writeln!(writer, "allowed_ip={}/{}", ip, cidr);
|
||||
}
|
||||
|
||||
if let Some(time) = p.time_since_last_handshake() {
|
||||
writeln!(writer, "last_handshake_time_sec={}", time.as_secs());
|
||||
writeln!(writer, "last_handshake_time_nsec={}", time.subsec_nanos());
|
||||
}
|
||||
|
||||
let (_, tx_bytes, rx_bytes, ..) = p.tunnel.stats();
|
||||
|
||||
writeln!(writer, "rx_bytes={}", rx_bytes);
|
||||
writeln!(writer, "tx_bytes={}", tx_bytes);
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
fn api_set(reader: &mut BufReader<&UnixStream>, d: &mut LockReadGuard<Device>) -> i32 {
|
||||
d.try_writeable(
|
||||
|device| device.trigger_yield(),
|
||||
|device| {
|
||||
device.cancel_yield();
|
||||
|
||||
let mut cmd = String::new();
|
||||
|
||||
while reader.read_line(&mut cmd).is_ok() {
|
||||
cmd.pop(); // remove newline if any
|
||||
if cmd.is_empty() {
|
||||
return 0; // Done
|
||||
}
|
||||
{
|
||||
let parsed_cmd: Vec<&str> = cmd.split('=').collect();
|
||||
if parsed_cmd.len() != 2 {
|
||||
return EPROTO;
|
||||
}
|
||||
|
||||
let (key, val) = (parsed_cmd[0], parsed_cmd[1]);
|
||||
|
||||
match key {
|
||||
"private_key" => match val.parse::<KeyBytes>() {
|
||||
Ok(key_bytes) => {
|
||||
device.set_key(x25519::StaticSecret::from(key_bytes.0))
|
||||
}
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"listen_port" => match val.parse::<u16>() {
|
||||
Ok(port) => match device.open_listen_socket(port) {
|
||||
Ok(()) => {}
|
||||
Err(_) => return EADDRINUSE,
|
||||
},
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
#[cfg(any(
|
||||
target_os = "android",
|
||||
target_os = "fuchsia",
|
||||
target_os = "linux"
|
||||
))]
|
||||
"fwmark" => match val.parse::<u32>() {
|
||||
Ok(mark) => match device.set_fwmark(mark) {
|
||||
Ok(()) => {}
|
||||
Err(_) => return EADDRINUSE,
|
||||
},
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"replace_peers" => match val.parse::<bool>() {
|
||||
Ok(true) => device.clear_peers(),
|
||||
Ok(false) => {}
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"public_key" => match val.parse::<KeyBytes>() {
|
||||
// Indicates a new peer section
|
||||
Ok(key_bytes) => {
|
||||
return api_set_peer(
|
||||
reader,
|
||||
device,
|
||||
x25519::PublicKey::from(key_bytes.0),
|
||||
)
|
||||
}
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
_ => return EINVAL,
|
||||
}
|
||||
}
|
||||
cmd.clear();
|
||||
}
|
||||
|
||||
0
|
||||
},
|
||||
)
|
||||
.unwrap_or(EIO)
|
||||
}
|
||||
|
||||
fn api_set_peer(
|
||||
reader: &mut BufReader<&UnixStream>,
|
||||
d: &mut Device,
|
||||
pub_key: x25519::PublicKey,
|
||||
) -> i32 {
|
||||
let mut cmd = String::new();
|
||||
|
||||
let mut remove = false;
|
||||
let mut replace_ips = false;
|
||||
let mut endpoint = None;
|
||||
let mut keepalive = None;
|
||||
let mut public_key = pub_key;
|
||||
let mut preshared_key = None;
|
||||
let mut allowed_ips: Vec<AllowedIP> = vec![];
|
||||
while reader.read_line(&mut cmd).is_ok() {
|
||||
cmd.pop(); // remove newline if any
|
||||
if cmd.is_empty() {
|
||||
d.update_peer(
|
||||
public_key,
|
||||
remove,
|
||||
replace_ips,
|
||||
endpoint,
|
||||
allowed_ips.as_slice(),
|
||||
keepalive,
|
||||
preshared_key,
|
||||
);
|
||||
allowed_ips.clear(); //clear the vector content after update
|
||||
return 0; // Done
|
||||
}
|
||||
{
|
||||
let parsed_cmd: Vec<&str> = cmd.splitn(2, '=').collect();
|
||||
if parsed_cmd.len() != 2 {
|
||||
return EPROTO;
|
||||
}
|
||||
let (key, val) = (parsed_cmd[0], parsed_cmd[1]);
|
||||
match key {
|
||||
"remove" => match val.parse::<bool>() {
|
||||
Ok(true) => remove = true,
|
||||
Ok(false) => remove = false,
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"preshared_key" => match val.parse::<KeyBytes>() {
|
||||
Ok(key_bytes) => preshared_key = Some(key_bytes.0),
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"endpoint" => match val.parse::<SocketAddr>() {
|
||||
Ok(addr) => endpoint = Some(addr),
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"persistent_keepalive_interval" => match val.parse::<u16>() {
|
||||
Ok(interval) => keepalive = Some(interval),
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"replace_allowed_ips" => match val.parse::<bool>() {
|
||||
Ok(true) => replace_ips = true,
|
||||
Ok(false) => replace_ips = false,
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"allowed_ip" => match val.parse::<AllowedIP>() {
|
||||
Ok(ip) => allowed_ips.push(ip),
|
||||
Err(_) => return EINVAL,
|
||||
},
|
||||
"public_key" => {
|
||||
// Indicates a new peer section. Commit changes for current peer, and continue to next peer
|
||||
d.update_peer(
|
||||
public_key,
|
||||
remove,
|
||||
replace_ips,
|
||||
endpoint,
|
||||
allowed_ips.as_slice(),
|
||||
keepalive,
|
||||
preshared_key,
|
||||
);
|
||||
allowed_ips.clear(); //clear the vector content after update
|
||||
match val.parse::<KeyBytes>() {
|
||||
Ok(key_bytes) => public_key = key_bytes.0.into(),
|
||||
Err(_) => return EINVAL,
|
||||
}
|
||||
}
|
||||
"protocol_version" => match val.parse::<u32>() {
|
||||
Ok(1) => {} // Only version 1 is legal
|
||||
_ => return EINVAL,
|
||||
},
|
||||
_ => return EINVAL,
|
||||
}
|
||||
}
|
||||
cmd.clear();
|
||||
}
|
||||
0
|
||||
}
|
||||
108
lib/boringtun/src/device/dev_lock.rs
Normal file
108
lib/boringtun/src/device/dev_lock.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard};
|
||||
use std::ops::Deref;
|
||||
|
||||
/// A special type of read/write lock, that makes the following assumptions:
|
||||
/// a) Read access is frequent, and has to be very fast, so we want to hold it indefinitely
|
||||
/// b) Write access is very rare (think less than once per second) and can be a bit slower
|
||||
/// c) A thread that holds a read lock, can ask for an upgrade to a write lock, cooperatively asking other threads to yield their locks
|
||||
pub struct Lock<T: ?Sized> {
|
||||
wants_write: (Mutex<bool>, Condvar),
|
||||
inner: RwLock<T>, // Although parking lot lock is upgradable, it does not allow a two staged mark + lock upgrade
|
||||
}
|
||||
|
||||
impl<T> Lock<T> {
|
||||
/// New lock
|
||||
pub fn new(user_data: T) -> Lock<T> {
|
||||
Lock {
|
||||
wants_write: (Mutex::new(false), Condvar::new()),
|
||||
inner: RwLock::new(user_data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Lock<T> {
|
||||
/// Acquire a read lock
|
||||
pub fn read(&self) -> LockReadGuard<T> {
|
||||
let (ref lock, ref cvar) = &self.wants_write;
|
||||
let mut wants_write = lock.lock();
|
||||
while *wants_write {
|
||||
// We have a writer and we want to wait for it to go away
|
||||
cvar.wait(&mut wants_write);
|
||||
}
|
||||
|
||||
LockReadGuard {
|
||||
wants_write: &self.wants_write,
|
||||
inner: self.inner.read(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LockReadGuard<'a, T: 'a + ?Sized> {
|
||||
wants_write: &'a (Mutex<bool>, Condvar),
|
||||
inner: RwLockReadGuard<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T: ?Sized> LockReadGuard<'a, T> {
|
||||
/// Perform a closure on a mutable reference of the inner locked value.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// `prep_func` - A closure that will run once, after the lock marks its intention to write,
|
||||
/// this can be used to tell other threads to yield their read locks temporarily. It will be passed
|
||||
/// an immutable reference to the inner value.
|
||||
///
|
||||
/// `mut_func` - A closure that will run once write access is gained. It iwll be passed a mutable reference
|
||||
/// to the inner value.
|
||||
///
|
||||
pub fn try_writeable<U, P: FnOnce(&T), F: FnOnce(&mut T) -> U>(
|
||||
&mut self,
|
||||
prep_func: P,
|
||||
mut_func: F,
|
||||
) -> Option<U> {
|
||||
// First tell everyone that we want to write now, this will prevent any new reader from starting until we are done.
|
||||
{
|
||||
let &(ref lock, cvar) = &self.wants_write;
|
||||
let mut wants_write = lock.lock();
|
||||
|
||||
RwLockReadGuard::unlocked(&mut self.inner, move || {
|
||||
while *wants_write {
|
||||
// We have a writer and we want to wait for it to go away
|
||||
cvar.wait(&mut wants_write);
|
||||
}
|
||||
|
||||
*wants_write = true;
|
||||
});
|
||||
}
|
||||
|
||||
// Second stage is to run the prep function
|
||||
prep_func(&*self.inner);
|
||||
|
||||
let lock = RwLockReadGuard::rwlock(&self.inner);
|
||||
|
||||
// Third stage is to perform our op under a write lock
|
||||
let ret = Some(RwLockReadGuard::unlocked(&mut self.inner, move || {
|
||||
// There is no race here because wants_write blocks other threads
|
||||
let mut write_access = lock.write();
|
||||
mut_func(&mut *write_access)
|
||||
}));
|
||||
|
||||
// Finally signal other threads
|
||||
let (ref lock, ref cvar) = &self.wants_write;
|
||||
let mut wants_write = lock.lock();
|
||||
*wants_write = false;
|
||||
cvar.notify_all();
|
||||
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: ?Sized> Deref for LockReadGuard<'a, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &T {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
75
lib/boringtun/src/device/drop_privileges.rs
Normal file
75
lib/boringtun/src/device/drop_privileges.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use crate::device::Error;
|
||||
use libc::{gid_t, setgid, setuid, uid_t};
|
||||
use std::io;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
use nix::unistd::User;
|
||||
|
||||
pub fn get_saved_ids() -> Result<(uid_t, gid_t), Error> {
|
||||
// Get the user name of the sudoer
|
||||
#[cfg(target_os = "macos")]
|
||||
match std::env::var("USER") {
|
||||
Ok(uname) => match User::from_name(&uname) {
|
||||
Ok(Some(user)) => Ok((uid_t::from(user.uid), gid_t::from(user.gid))),
|
||||
Err(e) => Err(Error::DropPrivileges(format!(
|
||||
"Failed parse user; err: {:?}",
|
||||
e
|
||||
))),
|
||||
Ok(None) => Err(Error::DropPrivileges("Failed to find user".to_owned())),
|
||||
},
|
||||
Err(e) => Err(Error::DropPrivileges(format!(
|
||||
"Could not get environment variable for user; err: {:?}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
use libc::{getlogin, getpwnam};
|
||||
|
||||
let uname = unsafe { getlogin() };
|
||||
if uname.is_null() {
|
||||
return Err(Error::DropPrivileges("NULL from getlogin".to_owned()));
|
||||
}
|
||||
let userinfo = unsafe { getpwnam(uname) };
|
||||
if userinfo.is_null() {
|
||||
return Err(Error::DropPrivileges("NULL from getpwnam".to_owned()));
|
||||
}
|
||||
|
||||
// Saved group ID
|
||||
let saved_gid = unsafe { (*userinfo).pw_gid };
|
||||
// Saved user ID
|
||||
let saved_uid = unsafe { (*userinfo).pw_uid };
|
||||
|
||||
Ok((saved_uid, saved_gid))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn drop_privileges() -> Result<(), Error> {
|
||||
let (saved_uid, saved_gid) = get_saved_ids()?;
|
||||
|
||||
if -1 == unsafe { setgid(saved_gid) } {
|
||||
// Set real and effective group ID
|
||||
return Err(Error::DropPrivileges(
|
||||
io::Error::last_os_error().to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if -1 == unsafe { setuid(saved_uid) } {
|
||||
// Set real and effective user ID
|
||||
return Err(Error::DropPrivileges(
|
||||
io::Error::last_os_error().to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Validated we can't get sudo back again
|
||||
if unsafe { (setgid(0) != -1) || (setuid(0) != -1) } {
|
||||
Err(Error::DropPrivileges(
|
||||
"Failed to permanently drop privileges".to_owned(),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
416
lib/boringtun/src/device/epoll.rs
Normal file
416
lib/boringtun/src/device/epoll.rs
Normal file
@@ -0,0 +1,416 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::Error;
|
||||
use libc::*;
|
||||
use parking_lot::Mutex;
|
||||
use std::io;
|
||||
use std::ops::Deref;
|
||||
use std::os::unix::io::RawFd;
|
||||
use std::ptr::null_mut;
|
||||
use std::time::Duration;
|
||||
|
||||
/// A return type for the EventPoll::wait() function
|
||||
pub enum WaitResult<'a, H> {
|
||||
/// Event triggered normally
|
||||
Ok(EventGuard<'a, H>),
|
||||
/// Event triggered due to End of File conditions
|
||||
EoF(EventGuard<'a, H>),
|
||||
/// There was an error
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Implements a registry of pollable events
|
||||
pub struct EventPoll<H: Sized> {
|
||||
events: Mutex<Vec<Option<Box<Event<H>>>>>,
|
||||
epoll: RawFd, // The OS epoll
|
||||
}
|
||||
|
||||
/// A type that hold a reference to a triggered Event
|
||||
/// While an EventGuard exists for a given Event, it will not be triggered by any other thread
|
||||
/// Once the EventGuard goes out of scope, the underlying Event will be re-enabled
|
||||
pub struct EventGuard<'a, H> {
|
||||
epoll: RawFd,
|
||||
event: &'a mut Event<H>,
|
||||
poll: &'a EventPoll<H>,
|
||||
}
|
||||
|
||||
/// A reference to a single event in an EventPoll
|
||||
pub struct EventRef {
|
||||
trigger: RawFd,
|
||||
}
|
||||
|
||||
struct Event<H> {
|
||||
event: epoll_event, // The epoll event description
|
||||
fd: RawFd, // The associated fd
|
||||
handler: H, // The associated data
|
||||
notifier: bool, // Is a notification event
|
||||
needs_read: bool, // This event needs to be read to be cleared
|
||||
}
|
||||
|
||||
impl<H> Drop for EventPoll<H> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { close(self.epoll) };
|
||||
}
|
||||
}
|
||||
|
||||
impl<H: Sync + Send> EventPoll<H> {
|
||||
/// Create a new event registry
|
||||
pub fn new() -> Result<EventPoll<H>, Error> {
|
||||
let epoll = match unsafe { epoll_create(1) } {
|
||||
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
|
||||
epoll => epoll,
|
||||
};
|
||||
|
||||
Ok(EventPoll {
|
||||
events: Mutex::new(vec![]),
|
||||
epoll,
|
||||
})
|
||||
}
|
||||
|
||||
/// Add and enable a new event with the factory.
|
||||
/// The event is triggered when a Read operation on the provided trigger becomes available
|
||||
/// If the trigger fd is closed, the event won't be triggered anymore, but it's data won't be
|
||||
/// automatically released.
|
||||
/// The safe way to delete an event, is using the cancel method of an EventGuard.
|
||||
/// If the same trigger is used with multiple events in the same EventPoll, the last added
|
||||
/// event overrides all previous events. In case the same trigger is used with multiple polls,
|
||||
/// each event will be triggered independently.
|
||||
/// The event will keep triggering until a Read operation is no longer possible on the trigger.
|
||||
/// When triggered, one of the threads waiting on the poll will receive the handler via an
|
||||
/// appropriate EventGuard. It is guaranteed that only a single thread can have a reference to
|
||||
/// the handler at any given time.
|
||||
pub fn new_event(&self, trigger: RawFd, handler: H) -> Result<EventRef, Error> {
|
||||
// Create an event descriptor
|
||||
let flags = EPOLLIN | EPOLLONESHOT;
|
||||
let ev = Event {
|
||||
event: epoll_event {
|
||||
events: flags as _,
|
||||
u64: 0,
|
||||
},
|
||||
fd: trigger,
|
||||
handler,
|
||||
notifier: false,
|
||||
needs_read: false,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
/// Add and enable a new write event with the factory.
|
||||
/// The event is triggered when a Write operation on the provided trigger becomes possible
|
||||
/// For TCP sockets it means that the socket was succesfully connected
|
||||
#[allow(dead_code)]
|
||||
pub fn new_write_event(&self, trigger: RawFd, handler: H) -> Result<EventRef, Error> {
|
||||
// Create an event descriptor
|
||||
let flags = EPOLLOUT | EPOLLET | EPOLLONESHOT;
|
||||
let ev = Event {
|
||||
event: epoll_event {
|
||||
events: flags as _,
|
||||
u64: 0,
|
||||
},
|
||||
fd: trigger,
|
||||
handler,
|
||||
notifier: false,
|
||||
needs_read: false,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
/// Add and enable a new timed event with the factory.
|
||||
/// The even will be triggered for the first time after period time, and henceforth triggered
|
||||
/// every period time. Period is counted from the moment the appropriate EventGuard is released.
|
||||
pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result<EventRef, Error> {
|
||||
// The periodic event on Linux uses the timerfd
|
||||
let tfd = match unsafe { timerfd_create(CLOCK_BOOTTIME, TFD_NONBLOCK) } {
|
||||
-1 => match unsafe { timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK) } {
|
||||
// A fallback for kernels < 3.15
|
||||
-1 => return Err(Error::Timer(io::Error::last_os_error())),
|
||||
efd => efd,
|
||||
},
|
||||
efd => efd,
|
||||
};
|
||||
|
||||
let ts = timespec {
|
||||
tv_sec: period.as_secs() as _,
|
||||
tv_nsec: i64::from(period.subsec_nanos()) as _,
|
||||
};
|
||||
|
||||
let spec = itimerspec {
|
||||
it_value: ts,
|
||||
it_interval: ts,
|
||||
};
|
||||
|
||||
if unsafe { timerfd_settime(tfd, 0, &spec, std::ptr::null_mut()) } == -1 {
|
||||
unsafe { close(tfd) };
|
||||
return Err(Error::Timer(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
let ev = Event {
|
||||
event: epoll_event {
|
||||
events: (EPOLLIN | EPOLLONESHOT) as _,
|
||||
u64: 0,
|
||||
},
|
||||
fd: tfd,
|
||||
handler,
|
||||
notifier: false,
|
||||
needs_read: true,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
/// Add and enable a new notification event with the factory.
|
||||
/// The event can only be triggered manually, using the trigger_notification method.
|
||||
/// The event will remain in a triggered state until the stop_notification method is
|
||||
/// called. Both methods should only be called with the producing EventPoll.
|
||||
pub fn new_notifier(&self, handler: H) -> Result<EventRef, Error> {
|
||||
// The notifier on Linux uses the eventfd for notifications.
|
||||
// The way it works is when a non zero value is written into the eventfd it will trigger
|
||||
// the EPOLLIN event. Since we don't enable ONESHOT it will keep triggering until
|
||||
// canceled.
|
||||
// When we want to stop the event, we read something once from the file descriptor.
|
||||
let efd = match unsafe { eventfd(0, EFD_NONBLOCK) } {
|
||||
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
|
||||
efd => efd,
|
||||
};
|
||||
|
||||
let ev = Event {
|
||||
event: epoll_event {
|
||||
events: (EPOLLIN) as _,
|
||||
u64: 0,
|
||||
},
|
||||
fd: efd,
|
||||
handler,
|
||||
notifier: true,
|
||||
needs_read: false,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
/// Add and enable a new signal handler
|
||||
pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result<EventRef, Error> {
|
||||
let sfd = match unsafe {
|
||||
let mut sigset = std::mem::zeroed();
|
||||
sigemptyset(&mut sigset);
|
||||
sigaddset(&mut sigset, signal);
|
||||
sigprocmask(SIG_BLOCK, &sigset, null_mut());
|
||||
signalfd(-1, &sigset, SFD_NONBLOCK)
|
||||
} {
|
||||
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
|
||||
sfd => sfd,
|
||||
};
|
||||
|
||||
let ev = Event {
|
||||
event: epoll_event {
|
||||
events: (EPOLLIN | EPOLLONESHOT) as _,
|
||||
u64: 0,
|
||||
},
|
||||
fd: sfd,
|
||||
handler,
|
||||
notifier: false,
|
||||
needs_read: true,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
/// Wait until one of the registered events becomes triggered. Once an event
|
||||
/// is triggered, a single caller thread gets the handler for that event.
|
||||
/// In case a notifier is triggered, all waiting threads will receive the same
|
||||
/// handler.
|
||||
pub fn wait(&self) -> WaitResult<'_, H> {
|
||||
let mut event = epoll_event { events: 0, u64: 0 };
|
||||
match unsafe { epoll_wait(self.epoll, &mut event, 1, -1) } {
|
||||
-1 => return WaitResult::Error(io::Error::last_os_error().to_string()),
|
||||
1 => {}
|
||||
_ => return WaitResult::Error("unexpected number of events returned".to_string()),
|
||||
}
|
||||
|
||||
let event_data = unsafe { (event.u64 as *mut Event<H>).as_mut().unwrap() };
|
||||
|
||||
let guard = EventGuard {
|
||||
epoll: self.epoll,
|
||||
event: event_data,
|
||||
poll: self,
|
||||
};
|
||||
|
||||
if event.events & EPOLLHUP as u32 != 0 {
|
||||
// End of file flag
|
||||
WaitResult::EoF(guard)
|
||||
} else {
|
||||
WaitResult::Ok(guard)
|
||||
}
|
||||
}
|
||||
|
||||
// Register an event with this poll.
|
||||
fn register_event(&self, ev: Event<H>) -> Result<EventRef, Error> {
|
||||
// To register an event we
|
||||
// * Create a reference to self in the inner event
|
||||
// * Store the Event in the events vector
|
||||
// * Dispose of a previous Event under same fd if any
|
||||
// * Add the Event to epoll
|
||||
let trigger = ev.fd;
|
||||
let mut ev = Box::new(ev);
|
||||
// The inner event points back to the wrapper
|
||||
ev.event.u64 = ev.as_mut() as *mut Event<H> as _;
|
||||
let mut event_desc = ev.event;
|
||||
// Now add the pointer to the events vector, this is a place from which we can drop the event
|
||||
self.insert_at(trigger as _, ev);
|
||||
// Add the event to epoll
|
||||
if unsafe { epoll_ctl(self.epoll, EPOLL_CTL_ADD, trigger, &mut event_desc) } == -1 {
|
||||
return Err(Error::EventQueue(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
Ok(EventRef { trigger })
|
||||
}
|
||||
|
||||
// Insert an event into the events vector
|
||||
fn insert_at(&self, index: usize, data: Box<Event<H>>) {
|
||||
let mut events = self.events.lock();
|
||||
while events.len() <= index {
|
||||
// Resize the vector to be able to fit the new index
|
||||
// We trust the OS to allocate file descriptors in a sane order
|
||||
events.push(None); // resize doesn't work because Clone is not satisfied
|
||||
}
|
||||
|
||||
if events[index].take().is_some() {
|
||||
// Properly remove the previous event first
|
||||
unsafe {
|
||||
epoll_ctl(self.epoll, EPOLL_CTL_DEL, index as _, null_mut());
|
||||
};
|
||||
}
|
||||
|
||||
events[index] = Some(data);
|
||||
}
|
||||
|
||||
/// Trigger a notification
|
||||
pub fn trigger_notification(&self, notification_event: &EventRef) {
|
||||
let events = self.events.lock();
|
||||
|
||||
let event_ref = &(*events)[notification_event.trigger as usize];
|
||||
let event_data = event_ref.as_ref().expect("Expected an event");
|
||||
|
||||
if !event_data.notifier {
|
||||
panic!("Can only trigger a notification event");
|
||||
}
|
||||
|
||||
// Write some data to the eventfd to trigger an EPOLLIN event
|
||||
unsafe {
|
||||
write(
|
||||
notification_event.trigger,
|
||||
&(std::u64::MAX - 1).to_ne_bytes()[0] as *const u8 as _,
|
||||
8,
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
/// Stop a notification
|
||||
pub fn stop_notification(&self, notification_event: &EventRef) {
|
||||
let events = self.events.lock();
|
||||
|
||||
let event_ref = &(*events)[notification_event.trigger as usize];
|
||||
let event_data = event_ref.as_ref().expect("Expected an event");
|
||||
|
||||
if !event_data.notifier {
|
||||
panic!("Can only trigger a notification event");
|
||||
}
|
||||
|
||||
let mut buf = [0u8; 8];
|
||||
unsafe {
|
||||
read(
|
||||
notification_event.trigger,
|
||||
buf.as_mut_ptr() as _,
|
||||
buf.len() as _,
|
||||
)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl<H> EventPoll<H> {
|
||||
/// Disable and remove the event and associated handler, using the fd that
|
||||
/// was used to register it.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This function is only safe to call when the event loop is not running,
|
||||
/// otherwise the memory of the handler may get freed while in use.
|
||||
pub unsafe fn clear_event_by_fd(&self, index: RawFd) {
|
||||
let mut events = self.events.lock();
|
||||
assert!(index >= 0);
|
||||
if events[index as usize].take().is_some() {
|
||||
epoll_ctl(self.epoll, EPOLL_CTL_DEL, index, null_mut());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, H> Deref for EventGuard<'a, H> {
|
||||
type Target = H;
|
||||
fn deref(&self) -> &H {
|
||||
&self.event.handler
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, H> Drop for EventGuard<'a, H> {
|
||||
fn drop(&mut self) {
|
||||
if self.event.needs_read {
|
||||
// Must read from the event to reset it before we enable it
|
||||
let mut buf: [std::mem::MaybeUninit<u8>; 256] =
|
||||
unsafe { std::mem::MaybeUninit::uninit().assume_init() };
|
||||
while unsafe { read(self.event.fd, buf.as_mut_ptr() as _, buf.len() as _) } != -1 {}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
epoll_ctl(
|
||||
self.epoll,
|
||||
EPOLL_CTL_MOD,
|
||||
self.event.fd,
|
||||
&mut self.event.event,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, H> EventGuard<'a, H> {
|
||||
/// Get a mutable reference to the stored value
|
||||
#[allow(dead_code)]
|
||||
pub fn get_mut(&mut self) -> &mut H {
|
||||
&mut self.event.handler
|
||||
}
|
||||
|
||||
/// Cancel and remove the event referenced by this guard
|
||||
pub fn cancel(self) {
|
||||
unsafe { self.poll.clear_event_by_fd(self.event.fd) };
|
||||
std::mem::forget(self); // Don't call the regular drop that would enable the event
|
||||
}
|
||||
|
||||
pub fn fd(&self) -> i32 {
|
||||
self.event.fd
|
||||
}
|
||||
|
||||
/// Change the event flags to enable or disable notifying when the fd is writable
|
||||
pub fn notify_writable(&mut self, enabled: bool) {
|
||||
let flags = if enabled {
|
||||
EPOLLOUT | EPOLLIN | EPOLLET | EPOLLONESHOT
|
||||
} else {
|
||||
EPOLLIN | EPOLLONESHOT
|
||||
};
|
||||
self.event.event.events = flags as _;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn block_signal(signal: c_int) -> Result<sigset_t, String> {
|
||||
unsafe {
|
||||
let mut sigset = std::mem::zeroed();
|
||||
sigemptyset(&mut sigset);
|
||||
if sigaddset(&mut sigset, signal) == -1 {
|
||||
return Err(io::Error::last_os_error().to_string());
|
||||
}
|
||||
if sigprocmask(SIG_BLOCK, &sigset, null_mut()) == -1 {
|
||||
return Err(io::Error::last_os_error().to_string());
|
||||
}
|
||||
Ok(sigset)
|
||||
}
|
||||
}
|
||||
849
lib/boringtun/src/device/integration_tests/mod.rs
Normal file
849
lib/boringtun/src/device/integration_tests/mod.rs
Normal file
@@ -0,0 +1,849 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// This module contains some integration tests for boringtun
|
||||
// Those tests require docker and sudo privileges to run
|
||||
#[cfg(all(test, not(target_os = "macos")))]
|
||||
mod tests {
|
||||
use crate::device::{DeviceConfig, DeviceHandle};
|
||||
use crate::x25519::{PublicKey, StaticSecret};
|
||||
use base64::encode as base64encode;
|
||||
use hex::encode;
|
||||
use rand_core::OsRng;
|
||||
use ring::rand::{SecureRandom, SystemRandom};
|
||||
use std::fmt::Write as _;
|
||||
use std::io::{BufRead, BufReader, Read, Write};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::os::unix::net::UnixStream;
|
||||
use std::process::Command;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
static NEXT_IFACE_IDX: AtomicUsize = AtomicUsize::new(100); // utun 100+ should be vacant during testing on CI
|
||||
static NEXT_PORT: AtomicUsize = AtomicUsize::new(61111); // Use ports starting with 61111, hoping we don't run into a taken port 🤷
|
||||
static NEXT_IP: AtomicUsize = AtomicUsize::new(0xc0000200); // Use 192.0.2.0/24 for those tests, we might use more than 256 addresses though, usize must be >=32 bits on all supported platforms
|
||||
static NEXT_IP_V6: AtomicUsize = AtomicUsize::new(0); // Use the 2001:db8:: address space, append this atomic counter for bottom 32 bits
|
||||
|
||||
fn next_ip() -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::from(
|
||||
NEXT_IP.fetch_add(1, Ordering::Relaxed) as u32
|
||||
))
|
||||
}
|
||||
|
||||
fn next_ip_v6() -> IpAddr {
|
||||
let addr = 0x2001_0db8_0000_0000_0000_0000_0000_0000_u128
|
||||
+ u128::from(NEXT_IP_V6.fetch_add(1, Ordering::Relaxed) as u32);
|
||||
|
||||
IpAddr::V6(Ipv6Addr::from(addr))
|
||||
}
|
||||
|
||||
fn next_port() -> u16 {
|
||||
NEXT_PORT.fetch_add(1, Ordering::Relaxed) as u16
|
||||
}
|
||||
|
||||
/// Represents an allowed IP and cidr for a peer
|
||||
struct AllowedIp {
|
||||
ip: IpAddr,
|
||||
cidr: u8,
|
||||
}
|
||||
|
||||
/// Represents a single peer running in a container
|
||||
struct Peer {
|
||||
key: StaticSecret,
|
||||
endpoint: SocketAddr,
|
||||
allowed_ips: Vec<AllowedIp>,
|
||||
container_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Represents a single WireGuard interface on local machine
|
||||
struct WGHandle {
|
||||
_device: DeviceHandle,
|
||||
name: String,
|
||||
addr_v4: IpAddr,
|
||||
addr_v6: IpAddr,
|
||||
started: bool,
|
||||
peers: Vec<Arc<Peer>>,
|
||||
}
|
||||
|
||||
impl Drop for Peer {
|
||||
fn drop(&mut self) {
|
||||
if let Some(name) = &self.container_name {
|
||||
Command::new("docker")
|
||||
.args([
|
||||
"stop", // Run docker
|
||||
&name[5..],
|
||||
])
|
||||
.status()
|
||||
.ok();
|
||||
|
||||
std::fs::remove_file(name).ok();
|
||||
std::fs::remove_file(format!("{}.ngx", name)).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
/// Create a new peer with a given endpoint and a list of allowed IPs
|
||||
fn new(endpoint: SocketAddr, allowed_ips: Vec<AllowedIp>) -> Peer {
|
||||
Peer {
|
||||
key: StaticSecret::random_from_rng(OsRng),
|
||||
endpoint,
|
||||
allowed_ips,
|
||||
container_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new configuration file that can be used by wg-quick
|
||||
fn gen_wg_conf(
|
||||
&self,
|
||||
local_key: &PublicKey,
|
||||
local_addr: &IpAddr,
|
||||
local_port: u16,
|
||||
) -> String {
|
||||
let mut conf = String::from("[Interface]\n");
|
||||
// Each allowed ip, becomes a possible address in the config
|
||||
for ip in &self.allowed_ips {
|
||||
let _ = writeln!(conf, "Address = {}/{}", ip.ip, ip.cidr);
|
||||
}
|
||||
|
||||
// The local endpoint port is the remote listen port
|
||||
let _ = writeln!(conf, "ListenPort = {}", self.endpoint.port());
|
||||
// HACK: this should consume the key so it can't be reused instead of cloning and serializing
|
||||
let _ = writeln!(conf, "PrivateKey = {}", base64encode(self.key.to_bytes()));
|
||||
|
||||
// We are the peer
|
||||
let _ = writeln!(conf, "[Peer]");
|
||||
let _ = writeln!(conf, "PublicKey = {}", base64encode(local_key.as_bytes()));
|
||||
let _ = writeln!(conf, "AllowedIPs = {}", local_addr);
|
||||
let _ = write!(conf, "Endpoint = 127.0.0.1:{}", local_port);
|
||||
|
||||
conf
|
||||
}
|
||||
|
||||
/// Create a simple nginx config, that will respond with the peer public key
|
||||
fn gen_nginx_conf(&self) -> String {
|
||||
format!(
|
||||
"server {{\n\
|
||||
listen 80;\n\
|
||||
listen [::]:80;\n\
|
||||
location / {{\n\
|
||||
return 200 '{}';\n\
|
||||
}}\n\
|
||||
}}",
|
||||
encode(PublicKey::from(&self.key).as_bytes())
|
||||
)
|
||||
}
|
||||
|
||||
fn start_in_container(
|
||||
&mut self,
|
||||
local_key: &PublicKey,
|
||||
local_addr: &IpAddr,
|
||||
local_port: u16,
|
||||
) {
|
||||
let peer_config = self.gen_wg_conf(local_key, local_addr, local_port);
|
||||
let peer_config_file = temp_path();
|
||||
std::fs::write(&peer_config_file, peer_config).unwrap();
|
||||
let nginx_config = self.gen_nginx_conf();
|
||||
let nginx_config_file = format!("{}.ngx", peer_config_file);
|
||||
std::fs::write(&nginx_config_file, nginx_config).unwrap();
|
||||
|
||||
Command::new("docker")
|
||||
.args([
|
||||
"run", // Run docker
|
||||
"-d", // In detached mode
|
||||
"--cap-add=NET_ADMIN", // Grant permissions to open a tunnel
|
||||
"--device=/dev/net/tun",
|
||||
"--sysctl", // Enable ipv6
|
||||
"net.ipv6.conf.all.disable_ipv6=0",
|
||||
"--sysctl",
|
||||
"net.ipv6.conf.default.disable_ipv6=0",
|
||||
"-p", // Open port for the endpoint
|
||||
&format!("{0}:{0}/udp", self.endpoint.port()),
|
||||
"-v", // Map the generated WireGuard config file
|
||||
&format!("{}:/wireguard/wg.conf", peer_config_file),
|
||||
"-v", // Map the nginx config file
|
||||
&format!("{}:/etc/nginx/conf.d/default.conf", nginx_config_file),
|
||||
"--rm", // Cleanup
|
||||
"--name",
|
||||
&peer_config_file[5..],
|
||||
"vkrasnov/wireguard-test",
|
||||
])
|
||||
.status()
|
||||
.expect("Failed to run docker");
|
||||
|
||||
self.container_name = Some(peer_config_file);
|
||||
}
|
||||
|
||||
fn connect(&self) -> std::net::TcpStream {
|
||||
let http_addr = SocketAddr::new(self.allowed_ips[0].ip, 80);
|
||||
for _i in 0..5 {
|
||||
let res = std::net::TcpStream::connect(http_addr);
|
||||
if let Err(err) = res {
|
||||
println!("failed to connect: {:?}", err);
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
|
||||
return res.unwrap();
|
||||
}
|
||||
|
||||
panic!("failed to connect");
|
||||
}
|
||||
|
||||
fn get_request(&self) -> String {
|
||||
let mut tcp_conn = self.connect();
|
||||
|
||||
write!(
|
||||
tcp_conn,
|
||||
"GET / HTTP/1.1\nHost: localhost\nAccept: */*\nConnection: close\n\n"
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
tcp_conn
|
||||
.set_read_timeout(Some(std::time::Duration::from_secs(60)))
|
||||
.ok();
|
||||
|
||||
let mut reader = BufReader::new(tcp_conn);
|
||||
let mut line = String::new();
|
||||
let mut response = String::new();
|
||||
let mut len = 0usize;
|
||||
|
||||
// Read response code
|
||||
if reader.read_line(&mut line).is_ok() && !line.starts_with("HTTP/1.1 200") {
|
||||
return response;
|
||||
}
|
||||
line.clear();
|
||||
|
||||
// Read headers
|
||||
while reader.read_line(&mut line).is_ok() {
|
||||
if line.trim() == "" {
|
||||
break;
|
||||
}
|
||||
|
||||
{
|
||||
let parsed_line: Vec<&str> = line.split(':').collect();
|
||||
if parsed_line.len() < 2 {
|
||||
return response;
|
||||
}
|
||||
|
||||
let (key, val) = (parsed_line[0], parsed_line[1]);
|
||||
if key.to_lowercase() == "content-length" {
|
||||
len = match val.trim().parse() {
|
||||
Err(_) => return response,
|
||||
Ok(len) => len,
|
||||
};
|
||||
}
|
||||
}
|
||||
line.clear();
|
||||
}
|
||||
|
||||
// Read body
|
||||
let mut buf = [0u8; 256];
|
||||
while len > 0 {
|
||||
let to_read = len.min(buf.len());
|
||||
if reader.read_exact(&mut buf[..to_read]).is_err() {
|
||||
return response;
|
||||
}
|
||||
response.push_str(&String::from_utf8_lossy(&buf[..to_read]));
|
||||
len -= to_read;
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
impl WGHandle {
|
||||
/// Create a new interface for the tunnel with the given address
|
||||
fn init(addr_v4: IpAddr, addr_v6: IpAddr) -> WGHandle {
|
||||
WGHandle::init_with_config(
|
||||
addr_v4,
|
||||
addr_v6,
|
||||
DeviceConfig {
|
||||
n_threads: 2,
|
||||
use_connected_socket: true,
|
||||
#[cfg(target_os = "linux")]
|
||||
use_multi_queue: true,
|
||||
#[cfg(target_os = "linux")]
|
||||
uapi_fd: -1,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a new interface for the tunnel with the given address
|
||||
fn init_with_config(addr_v4: IpAddr, addr_v6: IpAddr, config: DeviceConfig) -> WGHandle {
|
||||
// Generate a new name, utun100+ should work on macOS and Linux
|
||||
let name = format!("utun{}", NEXT_IFACE_IDX.fetch_add(1, Ordering::Relaxed));
|
||||
let _device = DeviceHandle::new(&name, config).unwrap();
|
||||
WGHandle {
|
||||
_device,
|
||||
name,
|
||||
addr_v4,
|
||||
addr_v6,
|
||||
started: false,
|
||||
peers: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
/// Starts the tunnel
|
||||
fn start(&mut self) {
|
||||
// Assign the ipv4 address to the interface
|
||||
Command::new("ifconfig")
|
||||
.args(&[
|
||||
&self.name,
|
||||
&self.addr_v4.to_string(),
|
||||
&self.addr_v4.to_string(),
|
||||
"alias",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to assign ip to tunnel");
|
||||
|
||||
// Assign the ipv6 address to the interface
|
||||
Command::new("ifconfig")
|
||||
.args(&[
|
||||
&self.name,
|
||||
"inet6",
|
||||
&self.addr_v6.to_string(),
|
||||
"prefixlen",
|
||||
"128",
|
||||
"alias",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to assign ipv6 to tunnel");
|
||||
|
||||
// Start the tunnel
|
||||
Command::new("ifconfig")
|
||||
.args(&[&self.name, "up"])
|
||||
.status()
|
||||
.expect("failed to start the tunnel");
|
||||
|
||||
self.started = true;
|
||||
|
||||
// Add each peer to the routing table
|
||||
for p in &self.peers {
|
||||
for r in &p.allowed_ips {
|
||||
let inet_flag = match r.ip {
|
||||
IpAddr::V4(_) => "-inet",
|
||||
IpAddr::V6(_) => "-inet6",
|
||||
};
|
||||
|
||||
Command::new("route")
|
||||
.args(&[
|
||||
"-q",
|
||||
"-n",
|
||||
"add",
|
||||
inet_flag,
|
||||
&format!("{}/{}", r.ip, r.cidr),
|
||||
"-interface",
|
||||
&self.name,
|
||||
])
|
||||
.status()
|
||||
.expect("failed to add route");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
/// Starts the tunnel
|
||||
fn start(&mut self) {
|
||||
Command::new("ip")
|
||||
.args([
|
||||
"address",
|
||||
"add",
|
||||
&self.addr_v4.to_string(),
|
||||
"dev",
|
||||
&self.name,
|
||||
])
|
||||
.status()
|
||||
.expect("failed to assign ip to tunnel");
|
||||
|
||||
Command::new("ip")
|
||||
.args([
|
||||
"address",
|
||||
"add",
|
||||
&self.addr_v6.to_string(),
|
||||
"dev",
|
||||
&self.name,
|
||||
])
|
||||
.status()
|
||||
.expect("failed to assign ipv6 to tunnel");
|
||||
|
||||
// Start the tunnel
|
||||
Command::new("ip")
|
||||
.args(["link", "set", "mtu", "1400", "up", "dev", &self.name])
|
||||
.status()
|
||||
.expect("failed to start the tunnel");
|
||||
|
||||
self.started = true;
|
||||
|
||||
// Add each peer to the routing table
|
||||
for p in &self.peers {
|
||||
for r in &p.allowed_ips {
|
||||
Command::new("ip")
|
||||
.args([
|
||||
"route",
|
||||
"add",
|
||||
&format!("{}/{}", r.ip, r.cidr),
|
||||
"dev",
|
||||
&self.name,
|
||||
])
|
||||
.status()
|
||||
.expect("failed to add route");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Issue a get command on the interface
|
||||
fn wg_get(&self) -> String {
|
||||
let path = format!("/var/run/wireguard/{}.sock", self.name);
|
||||
|
||||
let mut socket = UnixStream::connect(path).unwrap();
|
||||
write!(socket, "get=1\n\n").unwrap();
|
||||
|
||||
let mut ret = String::new();
|
||||
socket.read_to_string(&mut ret).unwrap();
|
||||
ret
|
||||
}
|
||||
|
||||
/// Issue a set command on the interface
|
||||
fn wg_set(&self, setting: &str) -> String {
|
||||
let path = format!("/var/run/wireguard/{}.sock", self.name);
|
||||
let mut socket = UnixStream::connect(path).unwrap();
|
||||
write!(socket, "set=1\n{}\n\n", setting).unwrap();
|
||||
|
||||
let mut ret = String::new();
|
||||
socket.read_to_string(&mut ret).unwrap();
|
||||
ret
|
||||
}
|
||||
|
||||
/// Assign a listen_port to the interface
|
||||
fn wg_set_port(&self, port: u16) -> String {
|
||||
self.wg_set(&format!("listen_port={}", port))
|
||||
}
|
||||
|
||||
/// Assign a private_key to the interface
|
||||
fn wg_set_key(&self, key: StaticSecret) -> String {
|
||||
self.wg_set(&format!("private_key={}", encode(key.to_bytes())))
|
||||
}
|
||||
|
||||
/// Assign a peer to the interface (with public_key, endpoint and a series of nallowed_ip)
|
||||
fn wg_set_peer(
|
||||
&self,
|
||||
key: &PublicKey,
|
||||
ep: &SocketAddr,
|
||||
allowed_ips: &[AllowedIp],
|
||||
) -> String {
|
||||
let mut req = format!("public_key={}\nendpoint={}", encode(key.as_bytes()), ep);
|
||||
for AllowedIp { ip, cidr } in allowed_ips {
|
||||
let _ = write!(req, "\nallowed_ip={}/{}", ip, cidr);
|
||||
}
|
||||
|
||||
self.wg_set(&req)
|
||||
}
|
||||
|
||||
/// Add a new known peer
|
||||
fn add_peer(&mut self, peer: Arc<Peer>) {
|
||||
self.wg_set_peer(
|
||||
&PublicKey::from(&peer.key),
|
||||
&peer.endpoint,
|
||||
&peer.allowed_ips,
|
||||
);
|
||||
self.peers.push(peer);
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new filename in the /tmp dir
|
||||
fn temp_path() -> String {
|
||||
let mut path = String::from("/tmp/");
|
||||
let mut buf = [0u8; 32];
|
||||
SystemRandom::new().fill(&mut buf[..]).unwrap();
|
||||
path.push_str(&encode(buf));
|
||||
path
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
/// Test if wireguard starts and creates a unix socket that we can read from
|
||||
fn test_wireguard_get() {
|
||||
let wg = WGHandle::init("192.0.2.0".parse().unwrap(), "::2".parse().unwrap());
|
||||
let response = wg.wg_get();
|
||||
assert!(response.ends_with("errno=0\n\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
/// Test if wireguard starts and creates a unix socket that we can use to set settings
|
||||
fn test_wireguard_set() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let own_public_key = PublicKey::from(&private_key);
|
||||
|
||||
let wg = WGHandle::init("192.0.2.0".parse().unwrap(), "::2".parse().unwrap());
|
||||
assert!(wg.wg_get().ends_with("errno=0\n\n"));
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
// Check that the response matches what we expect
|
||||
assert_eq!(
|
||||
wg.wg_get(),
|
||||
format!(
|
||||
"own_public_key={}\nlisten_port={}\nerrno=0\n\n",
|
||||
encode(own_public_key.as_bytes()),
|
||||
port
|
||||
)
|
||||
);
|
||||
|
||||
let peer_key = StaticSecret::random_from_rng(OsRng);
|
||||
let peer_pub_key = PublicKey::from(&peer_key);
|
||||
let endpoint = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 0, 0, 1)), 50001);
|
||||
let allowed_ips = [
|
||||
AllowedIp {
|
||||
ip: IpAddr::V4(Ipv4Addr::new(172, 0, 0, 2)),
|
||||
cidr: 32,
|
||||
},
|
||||
AllowedIp {
|
||||
ip: IpAddr::V6(Ipv6Addr::new(0xf120, 0, 0, 2, 2, 2, 0, 0)),
|
||||
cidr: 100,
|
||||
},
|
||||
];
|
||||
|
||||
assert_eq!(
|
||||
wg.wg_set_peer(&peer_pub_key, &endpoint, &allowed_ips),
|
||||
"errno=0\n\n"
|
||||
);
|
||||
|
||||
// Check that the response matches what we expect
|
||||
assert_eq!(
|
||||
wg.wg_get(),
|
||||
format!(
|
||||
"own_public_key={}\n\
|
||||
listen_port={}\n\
|
||||
public_key={}\n\
|
||||
endpoint={}\n\
|
||||
allowed_ip={}/{}\n\
|
||||
allowed_ip={}/{}\n\
|
||||
rx_bytes=0\n\
|
||||
tx_bytes=0\n\
|
||||
errno=0\n\n",
|
||||
encode(own_public_key.as_bytes()),
|
||||
port,
|
||||
encode(peer_pub_key.as_bytes()),
|
||||
endpoint,
|
||||
allowed_ips[0].ip,
|
||||
allowed_ips[0].cidr,
|
||||
allowed_ips[1].ip,
|
||||
allowed_ips[1].cidr
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Test if wireguard can handle simple ipv4 connections, don't use a connected socket
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_wg_start_ipv4_non_connected() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = PublicKey::from(&private_key);
|
||||
let addr_v4 = next_ip();
|
||||
let addr_v6 = next_ip_v6();
|
||||
|
||||
let mut wg = WGHandle::init_with_config(
|
||||
addr_v4,
|
||||
addr_v6,
|
||||
DeviceConfig {
|
||||
n_threads: 2,
|
||||
use_connected_socket: false,
|
||||
#[cfg(target_os = "linux")]
|
||||
use_multi_queue: true,
|
||||
#[cfg(target_os = "linux")]
|
||||
uapi_fd: -1,
|
||||
},
|
||||
);
|
||||
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
// Create a new peer whose endpoint is on this machine
|
||||
let mut peer = Peer::new(
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
|
||||
vec![AllowedIp {
|
||||
ip: next_ip(),
|
||||
cidr: 32,
|
||||
}],
|
||||
);
|
||||
|
||||
peer.start_in_container(&public_key, &addr_v4, port);
|
||||
|
||||
let peer = Arc::new(peer);
|
||||
|
||||
wg.add_peer(Arc::clone(&peer));
|
||||
wg.start();
|
||||
|
||||
let response = peer.get_request();
|
||||
|
||||
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
|
||||
}
|
||||
|
||||
/// Test if wireguard can handle simple ipv4 connections
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_wg_start_ipv4() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = PublicKey::from(&private_key);
|
||||
let addr_v4 = next_ip();
|
||||
let addr_v6 = next_ip_v6();
|
||||
|
||||
let mut wg = WGHandle::init(addr_v4, addr_v6);
|
||||
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
// Create a new peer whose endpoint is on this machine
|
||||
let mut peer = Peer::new(
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
|
||||
vec![AllowedIp {
|
||||
ip: next_ip(),
|
||||
cidr: 32,
|
||||
}],
|
||||
);
|
||||
|
||||
peer.start_in_container(&public_key, &addr_v4, port);
|
||||
|
||||
let peer = Arc::new(peer);
|
||||
|
||||
wg.add_peer(Arc::clone(&peer));
|
||||
wg.start();
|
||||
|
||||
let response = peer.get_request();
|
||||
|
||||
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
/// Test if wireguard can handle simple ipv6 connections
|
||||
fn test_wg_start_ipv6() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = PublicKey::from(&private_key);
|
||||
let addr_v4 = next_ip();
|
||||
let addr_v6 = next_ip_v6();
|
||||
|
||||
let mut wg = WGHandle::init(addr_v4, addr_v6);
|
||||
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
let mut peer = Peer::new(
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
|
||||
vec![AllowedIp {
|
||||
ip: next_ip_v6(),
|
||||
cidr: 128,
|
||||
}],
|
||||
);
|
||||
|
||||
peer.start_in_container(&public_key, &addr_v6, port);
|
||||
|
||||
let peer = Arc::new(peer);
|
||||
|
||||
wg.add_peer(Arc::clone(&peer));
|
||||
wg.start();
|
||||
|
||||
let response = peer.get_request();
|
||||
|
||||
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
|
||||
}
|
||||
|
||||
/// Test if wireguard can handle connection with an ipv6 endpoint
|
||||
#[test]
|
||||
#[ignore]
|
||||
#[cfg(target_os = "linux")] // Can't make docker work with ipv6 on macOS ATM
|
||||
fn test_wg_start_ipv6_endpoint() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = PublicKey::from(&private_key);
|
||||
let addr_v4 = next_ip();
|
||||
let addr_v6 = next_ip_v6();
|
||||
|
||||
let mut wg = WGHandle::init(addr_v4, addr_v6);
|
||||
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
let mut peer = Peer::new(
|
||||
SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
|
||||
next_port(),
|
||||
),
|
||||
vec![AllowedIp {
|
||||
ip: next_ip_v6(),
|
||||
cidr: 128,
|
||||
}],
|
||||
);
|
||||
|
||||
peer.start_in_container(&public_key, &addr_v6, port);
|
||||
|
||||
let peer = Arc::new(peer);
|
||||
|
||||
wg.add_peer(Arc::clone(&peer));
|
||||
wg.start();
|
||||
|
||||
let response = peer.get_request();
|
||||
|
||||
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
|
||||
}
|
||||
|
||||
/// Test if wireguard can handle connection with an ipv6 endpoint
|
||||
#[test]
|
||||
#[ignore]
|
||||
#[cfg(target_os = "linux")] // Can't make docker work with ipv6 on macOS ATM
|
||||
fn test_wg_start_ipv6_endpoint_not_connected() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = PublicKey::from(&private_key);
|
||||
let addr_v4 = next_ip();
|
||||
let addr_v6 = next_ip_v6();
|
||||
|
||||
let mut wg = WGHandle::init_with_config(
|
||||
addr_v4,
|
||||
addr_v6,
|
||||
DeviceConfig {
|
||||
n_threads: 2,
|
||||
use_connected_socket: false,
|
||||
#[cfg(target_os = "linux")]
|
||||
use_multi_queue: true,
|
||||
#[cfg(target_os = "linux")]
|
||||
uapi_fd: -1,
|
||||
},
|
||||
);
|
||||
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
let mut peer = Peer::new(
|
||||
SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
|
||||
next_port(),
|
||||
),
|
||||
vec![AllowedIp {
|
||||
ip: next_ip_v6(),
|
||||
cidr: 128,
|
||||
}],
|
||||
);
|
||||
|
||||
peer.start_in_container(&public_key, &addr_v6, port);
|
||||
|
||||
let peer = Arc::new(peer);
|
||||
|
||||
wg.add_peer(Arc::clone(&peer));
|
||||
wg.start();
|
||||
|
||||
let response = peer.get_request();
|
||||
|
||||
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
|
||||
}
|
||||
|
||||
/// Test many concurrent connections
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_wg_concurrent() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = PublicKey::from(&private_key);
|
||||
let addr_v4 = next_ip();
|
||||
let addr_v6 = next_ip_v6();
|
||||
|
||||
let mut wg = WGHandle::init(addr_v4, addr_v6);
|
||||
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
for _ in 0..5 {
|
||||
// Create a new peer whose endpoint is on this machine
|
||||
let mut peer = Peer::new(
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
|
||||
vec![AllowedIp {
|
||||
ip: next_ip(),
|
||||
cidr: 32,
|
||||
}],
|
||||
);
|
||||
|
||||
peer.start_in_container(&public_key, &addr_v4, port);
|
||||
|
||||
let peer = Arc::new(peer);
|
||||
|
||||
wg.add_peer(Arc::clone(&peer));
|
||||
}
|
||||
|
||||
wg.start();
|
||||
|
||||
let mut threads = vec![];
|
||||
|
||||
for p in wg.peers {
|
||||
let pub_key = PublicKey::from(&p.key);
|
||||
threads.push(thread::spawn(move || {
|
||||
for _ in 0..100 {
|
||||
let response = p.get_request();
|
||||
assert_eq!(response, encode(pub_key.as_bytes()));
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for t in threads {
|
||||
t.join().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Test many concurrent connections
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_wg_concurrent_v6() {
|
||||
let port = next_port();
|
||||
let private_key = StaticSecret::random_from_rng(OsRng);
|
||||
let public_key = PublicKey::from(&private_key);
|
||||
let addr_v4 = next_ip();
|
||||
let addr_v6 = next_ip_v6();
|
||||
|
||||
let mut wg = WGHandle::init(addr_v4, addr_v6);
|
||||
|
||||
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
|
||||
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
|
||||
|
||||
for _ in 0..5 {
|
||||
// Create a new peer whose endpoint is on this machine
|
||||
let mut peer = Peer::new(
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
|
||||
vec![AllowedIp {
|
||||
ip: next_ip_v6(),
|
||||
cidr: 128,
|
||||
}],
|
||||
);
|
||||
|
||||
peer.start_in_container(&public_key, &addr_v6, port);
|
||||
|
||||
let peer = Arc::new(peer);
|
||||
|
||||
wg.add_peer(Arc::clone(&peer));
|
||||
}
|
||||
|
||||
wg.start();
|
||||
|
||||
let mut threads = vec![];
|
||||
|
||||
for p in wg.peers {
|
||||
let pub_key = PublicKey::from(&p.key);
|
||||
threads.push(thread::spawn(move || {
|
||||
for _ in 0..100 {
|
||||
let response = p.get_request();
|
||||
assert_eq!(response, encode(pub_key.as_bytes()));
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for t in threads {
|
||||
t.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
337
lib/boringtun/src/device/kqueue.rs
Normal file
337
lib/boringtun/src/device/kqueue.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::Error;
|
||||
use libc::*;
|
||||
use parking_lot::Mutex;
|
||||
use std::io;
|
||||
use std::ops::Deref;
|
||||
use std::os::unix::io::RawFd;
|
||||
use std::ptr::{null, null_mut};
|
||||
use std::time::Duration;
|
||||
|
||||
/// A return type for the EventPoll::wait() function
|
||||
pub enum WaitResult<'a, H> {
|
||||
/// Event triggered normally
|
||||
Ok(EventGuard<'a, H>),
|
||||
/// Event triggered due to End of File conditions
|
||||
EoF(EventGuard<'a, H>),
|
||||
/// There was an error
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Implements a registry of pollable events
|
||||
pub struct EventPoll<H: Sized> {
|
||||
events: Mutex<Vec<Option<Box<Event<H>>>>>, // Events with a file descriptor
|
||||
custom: Mutex<Vec<Option<Box<Event<H>>>>>, // Other events (i.e. timers & notifiers)
|
||||
signals: Mutex<Vec<Option<Box<Event<H>>>>>, // Signal handlers
|
||||
kqueue: RawFd, // The OS kqueue
|
||||
}
|
||||
|
||||
/// A type that hold a reference to a triggered Event
|
||||
/// While an EventGuard exists for a given Event, it will not be triggered by any other thread
|
||||
/// Once the EventGuard goes out of scope, the underlying Event will be re-enabled
|
||||
pub struct EventGuard<'a, H> {
|
||||
kqueue: RawFd,
|
||||
event: &'a Event<H>,
|
||||
poll: &'a EventPoll<H>,
|
||||
}
|
||||
|
||||
/// A reference to a single event in an EventPoll
|
||||
pub struct EventRef {
|
||||
trigger: RawFd,
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum EventKind {
|
||||
FD,
|
||||
Notifier,
|
||||
Signal,
|
||||
Timer,
|
||||
}
|
||||
|
||||
// A single event
|
||||
struct Event<H> {
|
||||
event: kevent, // The kqueue event description
|
||||
handler: H, // The associated data
|
||||
kind: EventKind,
|
||||
}
|
||||
|
||||
impl<H> Drop for EventPoll<H> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { close(self.kqueue) };
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<H> Send for EventPoll<H> {}
|
||||
unsafe impl<H> Sync for EventPoll<H> {}
|
||||
|
||||
impl<H: Send + Sync> EventPoll<H> {
|
||||
/// Create a new event registry
|
||||
pub fn new() -> Result<EventPoll<H>, Error> {
|
||||
let kqueue = match unsafe { kqueue() } {
|
||||
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
|
||||
kqueue => kqueue,
|
||||
};
|
||||
|
||||
Ok(EventPoll {
|
||||
events: Mutex::new(vec![]),
|
||||
custom: Mutex::new(vec![]),
|
||||
signals: Mutex::new(vec![]),
|
||||
kqueue,
|
||||
})
|
||||
}
|
||||
|
||||
/// Add and enable a new event with the factory.
|
||||
/// The event is triggered when a Read operation on the provided trigger becomes available
|
||||
/// If the trigger fd is closed, the event won't be triggered anymore, but it's data won't be
|
||||
/// automatically released.
|
||||
/// The safe way to delete an event, is using the cancel method of an EventGuard.
|
||||
/// If the same trigger is used with multiple events in the same EventPoll, the last added
|
||||
/// event overrides all previous events. In case the same trigger is used with multiple polls,
|
||||
/// each event will be triggered independently.
|
||||
/// The event will keep triggering until a Read operation is no longer possible on the trigger.
|
||||
/// When triggered, one of the threads waiting on the poll will receive the handler via an
|
||||
/// appropriate EventGuard. It is guaranteed that only a single thread can have a reference to
|
||||
/// the handler at any given time.
|
||||
pub fn new_event(&self, trigger: RawFd, handler: H) -> Result<EventRef, Error> {
|
||||
// Create an event descriptor
|
||||
let flags = EV_ENABLE | EV_DISPATCH;
|
||||
|
||||
let ev = Event {
|
||||
event: kevent {
|
||||
ident: trigger as _,
|
||||
filter: EVFILT_READ,
|
||||
flags,
|
||||
fflags: 0,
|
||||
data: 0,
|
||||
udata: null_mut(),
|
||||
},
|
||||
handler,
|
||||
kind: EventKind::FD,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result<EventRef, Error> {
|
||||
// The periodic event in BSD uses EVFILT_TIMER
|
||||
let ev = Event {
|
||||
event: kevent {
|
||||
ident: 0,
|
||||
filter: EVFILT_TIMER,
|
||||
flags: EV_ENABLE | EV_DISPATCH,
|
||||
fflags: NOTE_NSECONDS,
|
||||
data: period
|
||||
.as_secs()
|
||||
.checked_mul(1_000_000_000)
|
||||
.unwrap()
|
||||
.checked_add(u64::from(period.subsec_nanos()))
|
||||
.unwrap() as _,
|
||||
udata: null_mut(),
|
||||
},
|
||||
handler,
|
||||
kind: EventKind::Timer,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
pub fn new_notifier(&self, handler: H) -> Result<EventRef, Error> {
|
||||
// The notifier in BSD uses EVFILT_USER for notifications.
|
||||
let ev = Event {
|
||||
event: kevent {
|
||||
ident: 0,
|
||||
filter: EVFILT_USER,
|
||||
flags: EV_ENABLE,
|
||||
fflags: 0,
|
||||
data: 0,
|
||||
udata: null_mut(),
|
||||
},
|
||||
handler,
|
||||
kind: EventKind::Notifier,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
/// Add and enable a new signal handler
|
||||
pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result<EventRef, Error> {
|
||||
let ev = Event {
|
||||
event: kevent {
|
||||
ident: signal as _,
|
||||
filter: EVFILT_SIGNAL,
|
||||
flags: EV_ENABLE | EV_DISPATCH,
|
||||
fflags: 0,
|
||||
data: 0,
|
||||
udata: null_mut(),
|
||||
},
|
||||
handler,
|
||||
kind: EventKind::Signal,
|
||||
};
|
||||
|
||||
self.register_event(ev)
|
||||
}
|
||||
|
||||
/// Wait until one of the registered events becomes triggered. Once an event
|
||||
/// is triggered, a single caller thread gets the handler for that event.
|
||||
/// In case a notifier is triggered, all waiting threads will receive the same
|
||||
/// handler.
|
||||
pub fn wait(&'_ self) -> WaitResult<'_, H> {
|
||||
let mut event = kevent {
|
||||
ident: 0,
|
||||
filter: 0,
|
||||
flags: 0,
|
||||
fflags: 0,
|
||||
data: 0,
|
||||
udata: null_mut(),
|
||||
};
|
||||
|
||||
if unsafe { kevent(self.kqueue, null(), 0, &mut event, 1, null()) } == -1 {
|
||||
return WaitResult::Error(io::Error::last_os_error().to_string());
|
||||
}
|
||||
|
||||
let event_data = unsafe { (event.udata as *mut Event<H>).as_ref().unwrap() };
|
||||
|
||||
let guard = EventGuard {
|
||||
kqueue: self.kqueue,
|
||||
event: event_data,
|
||||
poll: self,
|
||||
};
|
||||
|
||||
if event.flags & EV_EOF != 0 {
|
||||
WaitResult::EoF(guard)
|
||||
} else {
|
||||
WaitResult::Ok(guard)
|
||||
}
|
||||
}
|
||||
|
||||
// Register an event with this poll.
|
||||
fn register_event(&self, ev: Event<H>) -> Result<EventRef, Error> {
|
||||
let mut events = match ev.kind {
|
||||
EventKind::FD => self.events.lock(),
|
||||
EventKind::Timer | EventKind::Notifier => self.custom.lock(),
|
||||
EventKind::Signal => self.signals.lock(),
|
||||
};
|
||||
|
||||
let (trigger, index) = match ev.kind {
|
||||
EventKind::FD | EventKind::Signal => (ev.event.ident as RawFd, ev.event.ident as usize),
|
||||
EventKind::Timer | EventKind::Notifier => (-(events.len() as RawFd) - 1, events.len()), // Custom events get negative identifiers, hopefully we will never have more than 2^31 events of each type
|
||||
};
|
||||
|
||||
// Expand events vector if needed
|
||||
while events.len() <= index {
|
||||
// Resize the vector to be able to fit the new index
|
||||
// We trust the OS to allocate file descriptors in a sane order
|
||||
events.push(None); // resize doesn't work because Clone is not satisfied
|
||||
}
|
||||
|
||||
let mut ev = Box::new(ev);
|
||||
// The inner event points back to the wrapper
|
||||
ev.event.ident = trigger as _;
|
||||
ev.event.udata = ev.as_mut() as *mut Event<H> as _;
|
||||
|
||||
let mut kev = ev.event;
|
||||
kev.flags |= EV_ADD;
|
||||
|
||||
if unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) } == -1 {
|
||||
return Err(Error::EventQueue(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
if let Some(mut event) = events[index].take() {
|
||||
// Properly remove any previous event first
|
||||
event.event.flags = EV_DELETE;
|
||||
unsafe { kevent(self.kqueue, &event.event, 1, null_mut(), 0, null()) };
|
||||
}
|
||||
|
||||
if ev.kind == EventKind::Signal {
|
||||
// Mask the signal if successfully added to kqueue
|
||||
unsafe { signal(trigger, SIG_IGN) };
|
||||
}
|
||||
|
||||
events[index] = Some(ev);
|
||||
|
||||
Ok(EventRef { trigger })
|
||||
}
|
||||
|
||||
pub fn trigger_notification(&self, notification_event: &EventRef) {
|
||||
let events = self.custom.lock();
|
||||
let ev_index = -notification_event.trigger - 1; // Custom events have negative index from -1
|
||||
|
||||
let event_ref = &(*events)[ev_index as usize];
|
||||
let event_data = event_ref.as_ref().expect("Expected an event");
|
||||
|
||||
if event_data.kind != EventKind::Notifier {
|
||||
panic!("Can only trigger a notification event");
|
||||
}
|
||||
|
||||
let mut kev = event_data.event;
|
||||
kev.fflags = NOTE_TRIGGER;
|
||||
|
||||
unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) };
|
||||
}
|
||||
|
||||
pub fn stop_notification(&self, notification_event: &EventRef) {
|
||||
let events = self.custom.lock();
|
||||
let ev_index = -notification_event.trigger - 1; // Custom events have negative index from -1
|
||||
|
||||
let event_ref = &(*events)[ev_index as usize];
|
||||
let event_data = event_ref.as_ref().expect("Expected an event");
|
||||
|
||||
if event_data.kind != EventKind::Notifier {
|
||||
panic!("Can only stop a notification event");
|
||||
}
|
||||
|
||||
let mut kev = event_data.event;
|
||||
kev.flags = EV_DISABLE;
|
||||
kev.fflags = 0;
|
||||
|
||||
unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) };
|
||||
}
|
||||
}
|
||||
|
||||
impl<H> EventPoll<H> {
|
||||
// This function is only safe to call when the event loop is not running
|
||||
pub unsafe fn clear_event_by_fd(&self, index: RawFd) {
|
||||
let (mut events, index) = if index >= 0 {
|
||||
(self.events.lock(), index as usize)
|
||||
} else {
|
||||
(self.custom.lock(), (-index - 1) as usize)
|
||||
};
|
||||
|
||||
if let Some(mut event) = events[index].take() {
|
||||
// Properly remove any previous event first
|
||||
event.event.flags = EV_DELETE;
|
||||
kevent(self.kqueue, &event.event, 1, null_mut(), 0, null());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, H> Deref for EventGuard<'a, H> {
|
||||
type Target = H;
|
||||
fn deref(&self) -> &H {
|
||||
&self.event.handler
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, H> Drop for EventGuard<'a, H> {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
// Re-enable the event once EventGuard goes out of scope
|
||||
kevent(self.kqueue, &self.event.event, 1, null_mut(), 0, null());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, H> EventGuard<'a, H> {
|
||||
/// Cancel and remove the event represented by this guard
|
||||
pub fn cancel(self) {
|
||||
unsafe { self.poll.clear_event_by_fd(self.event.event.ident as RawFd) };
|
||||
std::mem::forget(self); // Don't call the regular drop that would enable the event
|
||||
}
|
||||
|
||||
/// Stub: only used for Linux-specific features.
|
||||
pub fn fd(&self) -> i32 {
|
||||
-1
|
||||
}
|
||||
}
|
||||
884
lib/boringtun/src/device/mod.rs
Normal file
884
lib/boringtun/src/device/mod.rs
Normal file
@@ -0,0 +1,884 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
pub mod allowed_ips;
|
||||
pub mod api;
|
||||
mod dev_lock;
|
||||
pub mod drop_privileges;
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
pub mod peer;
|
||||
|
||||
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
|
||||
#[path = "kqueue.rs"]
|
||||
pub mod poll;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
#[path = "epoll.rs"]
|
||||
pub mod poll;
|
||||
|
||||
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
|
||||
#[path = "tun_darwin.rs"]
|
||||
pub mod tun;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
#[path = "tun_linux.rs"]
|
||||
pub mod tun;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io::{self, Write as _};
|
||||
use std::mem::MaybeUninit;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::os::unix::io::AsRawFd;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::thread::JoinHandle;
|
||||
|
||||
use crate::noise::errors::WireGuardError;
|
||||
use crate::noise::handshake::parse_handshake_anon;
|
||||
use crate::noise::rate_limiter::RateLimiter;
|
||||
use crate::noise::{Packet, Tunn, TunnResult};
|
||||
use crate::x25519;
|
||||
use allowed_ips::AllowedIps;
|
||||
use parking_lot::Mutex;
|
||||
use peer::{AllowedIP, Peer};
|
||||
use poll::{EventPoll, EventRef, WaitResult};
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use socket2::{Domain, Protocol, Type};
|
||||
use tun::TunSocket;
|
||||
|
||||
use dev_lock::{Lock, LockReadGuard};
|
||||
|
||||
const HANDSHAKE_RATE_LIMIT: u64 = 100; // The number of handshakes per second we can tolerate before using cookies
|
||||
|
||||
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
|
||||
const MAX_ITR: usize = 100; // Number of packets to handle per handler call
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("i/o error: {0}")]
|
||||
IoError(#[from] io::Error),
|
||||
#[error("{0}")]
|
||||
Socket(io::Error),
|
||||
#[error("{0}")]
|
||||
Bind(String),
|
||||
#[error("{0}")]
|
||||
FCntl(io::Error),
|
||||
#[error("{0}")]
|
||||
EventQueue(io::Error),
|
||||
#[error("{0}")]
|
||||
IOCtl(io::Error),
|
||||
#[error("{0}")]
|
||||
Connect(String),
|
||||
#[error("{0}")]
|
||||
SetSockOpt(String),
|
||||
#[error("Invalid tunnel name")]
|
||||
InvalidTunnelName,
|
||||
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
|
||||
#[error("{0}")]
|
||||
GetSockOpt(io::Error),
|
||||
#[error("{0}")]
|
||||
GetSockName(String),
|
||||
#[cfg(target_os = "linux")]
|
||||
#[error("{0}")]
|
||||
Timer(io::Error),
|
||||
#[error("iface read: {0}")]
|
||||
IfaceRead(io::Error),
|
||||
#[error("{0}")]
|
||||
DropPrivileges(String),
|
||||
#[error("API socket error: {0}")]
|
||||
ApiSocket(io::Error),
|
||||
}
|
||||
|
||||
// What the event loop should do after a handler returns
|
||||
enum Action {
|
||||
Continue, // Continue the loop
|
||||
Yield, // Yield the read lock and acquire it again
|
||||
Exit, // Stop the loop
|
||||
}
|
||||
|
||||
// Event handler function
|
||||
type Handler = Box<dyn Fn(&mut LockReadGuard<Device>, &mut ThreadData) -> Action + Send + Sync>;
|
||||
|
||||
pub struct DeviceHandle {
|
||||
device: Arc<Lock<Device>>, // The interface this handle owns
|
||||
threads: Vec<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct DeviceConfig {
|
||||
pub n_threads: usize,
|
||||
pub use_connected_socket: bool,
|
||||
#[cfg(target_os = "linux")]
|
||||
pub use_multi_queue: bool,
|
||||
#[cfg(target_os = "linux")]
|
||||
pub uapi_fd: i32,
|
||||
}
|
||||
|
||||
impl Default for DeviceConfig {
|
||||
fn default() -> Self {
|
||||
DeviceConfig {
|
||||
n_threads: 4,
|
||||
use_connected_socket: true,
|
||||
#[cfg(target_os = "linux")]
|
||||
use_multi_queue: true,
|
||||
#[cfg(target_os = "linux")]
|
||||
uapi_fd: -1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Device {
|
||||
key_pair: Option<(x25519::StaticSecret, x25519::PublicKey)>,
|
||||
queue: Arc<EventPoll<Handler>>,
|
||||
|
||||
listen_port: u16,
|
||||
fwmark: Option<u32>,
|
||||
|
||||
iface: Arc<TunSocket>,
|
||||
udp4: Option<socket2::Socket>,
|
||||
udp6: Option<socket2::Socket>,
|
||||
|
||||
yield_notice: Option<EventRef>,
|
||||
exit_notice: Option<EventRef>,
|
||||
|
||||
peers: HashMap<x25519::PublicKey, Arc<Mutex<Peer>>>,
|
||||
peers_by_ip: AllowedIps<Arc<Mutex<Peer>>>,
|
||||
peers_by_idx: HashMap<u32, Arc<Mutex<Peer>>>,
|
||||
next_index: IndexLfsr,
|
||||
|
||||
config: DeviceConfig,
|
||||
|
||||
cleanup_paths: Vec<String>,
|
||||
|
||||
mtu: AtomicUsize,
|
||||
|
||||
rate_limiter: Option<Arc<RateLimiter>>,
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
uapi_fd: i32,
|
||||
}
|
||||
|
||||
struct ThreadData {
|
||||
iface: Arc<TunSocket>,
|
||||
src_buf: [u8; MAX_UDP_SIZE],
|
||||
dst_buf: [u8; MAX_UDP_SIZE],
|
||||
}
|
||||
|
||||
impl DeviceHandle {
|
||||
pub fn new(name: &str, config: DeviceConfig) -> Result<DeviceHandle, Error> {
|
||||
let n_threads = config.n_threads;
|
||||
let mut wg_interface = Device::new(name, config)?;
|
||||
wg_interface.open_listen_socket(0)?; // Start listening on a random port
|
||||
|
||||
let interface_lock = Arc::new(Lock::new(wg_interface));
|
||||
|
||||
let mut threads = vec![];
|
||||
|
||||
for i in 0..n_threads {
|
||||
threads.push({
|
||||
let dev = Arc::clone(&interface_lock);
|
||||
thread::spawn(move || DeviceHandle::event_loop(i, &dev))
|
||||
});
|
||||
}
|
||||
|
||||
Ok(DeviceHandle {
|
||||
device: interface_lock,
|
||||
threads,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn wait(&mut self) {
|
||||
while let Some(thread) = self.threads.pop() {
|
||||
thread.join().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clean(&mut self) {
|
||||
for path in &self.device.read().cleanup_paths {
|
||||
// attempt to remove any file we created in the work dir
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
|
||||
fn event_loop(_i: usize, device: &Lock<Device>) {
|
||||
#[cfg(target_os = "linux")]
|
||||
let mut thread_local = ThreadData {
|
||||
src_buf: [0u8; MAX_UDP_SIZE],
|
||||
dst_buf: [0u8; MAX_UDP_SIZE],
|
||||
iface: if _i == 0 || !device.read().config.use_multi_queue {
|
||||
// For the first thread use the original iface
|
||||
Arc::clone(&device.read().iface)
|
||||
} else {
|
||||
// For for the rest create a new iface queue
|
||||
let iface_local = Arc::new(
|
||||
TunSocket::new(&device.read().iface.name().unwrap())
|
||||
.unwrap()
|
||||
.set_non_blocking()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
device
|
||||
.read()
|
||||
.register_iface_handler(Arc::clone(&iface_local))
|
||||
.ok();
|
||||
|
||||
iface_local
|
||||
},
|
||||
};
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
let mut thread_local = ThreadData {
|
||||
src_buf: [0u8; MAX_UDP_SIZE],
|
||||
dst_buf: [0u8; MAX_UDP_SIZE],
|
||||
iface: Arc::clone(&device.read().iface),
|
||||
};
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
let uapi_fd = -1;
|
||||
#[cfg(target_os = "linux")]
|
||||
let uapi_fd = device.read().uapi_fd;
|
||||
|
||||
loop {
|
||||
// The event loop keeps a read lock on the device, because we assume write access is rarely needed
|
||||
let mut device_lock = device.read();
|
||||
let queue = Arc::clone(&device_lock.queue);
|
||||
|
||||
loop {
|
||||
match queue.wait() {
|
||||
WaitResult::Ok(handler) => {
|
||||
let action = (*handler)(&mut device_lock, &mut thread_local);
|
||||
match action {
|
||||
Action::Continue => {}
|
||||
Action::Yield => break,
|
||||
Action::Exit => {
|
||||
device_lock.trigger_exit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
WaitResult::EoF(handler) => {
|
||||
if uapi_fd >= 0 && uapi_fd == handler.fd() {
|
||||
device_lock.trigger_exit();
|
||||
return;
|
||||
}
|
||||
handler.cancel();
|
||||
}
|
||||
WaitResult::Error(e) => tracing::error!(message = "Poll error", error = ?e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DeviceHandle {
|
||||
fn drop(&mut self) {
|
||||
self.device.read().trigger_exit();
|
||||
self.clean();
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
fn next_index(&mut self) -> u32 {
|
||||
self.next_index.next()
|
||||
}
|
||||
|
||||
fn remove_peer(&mut self, pub_key: &x25519::PublicKey) {
|
||||
if let Some(peer) = self.peers.remove(pub_key) {
|
||||
// Found a peer to remove, now purge all references to it:
|
||||
{
|
||||
let p = peer.lock();
|
||||
p.shutdown_endpoint(); // close open udp socket and free the closure
|
||||
self.peers_by_idx.remove(&p.index());
|
||||
}
|
||||
self.peers_by_ip
|
||||
.remove(&|p: &Arc<Mutex<Peer>>| Arc::ptr_eq(&peer, p));
|
||||
|
||||
tracing::info!("Peer removed");
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn update_peer(
|
||||
&mut self,
|
||||
pub_key: x25519::PublicKey,
|
||||
remove: bool,
|
||||
_replace_ips: bool,
|
||||
endpoint: Option<SocketAddr>,
|
||||
allowed_ips: &[AllowedIP],
|
||||
keepalive: Option<u16>,
|
||||
preshared_key: Option<[u8; 32]>,
|
||||
) {
|
||||
if remove {
|
||||
// Completely remove a peer
|
||||
return self.remove_peer(&pub_key);
|
||||
}
|
||||
|
||||
// Update an existing peer
|
||||
if self.peers.get(&pub_key).is_some() {
|
||||
// We already have a peer, we need to merge the existing config into the newly created one
|
||||
panic!("Modifying existing peers is not yet supported. Remove and add again instead.");
|
||||
}
|
||||
|
||||
let next_index = self.next_index();
|
||||
let device_key_pair = self
|
||||
.key_pair
|
||||
.as_ref()
|
||||
.expect("Private key must be set first");
|
||||
|
||||
let tunn = Tunn::new(
|
||||
device_key_pair.0.clone(),
|
||||
pub_key,
|
||||
preshared_key,
|
||||
keepalive,
|
||||
next_index,
|
||||
None,
|
||||
);
|
||||
|
||||
let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, preshared_key);
|
||||
|
||||
let peer = Arc::new(Mutex::new(peer));
|
||||
self.peers.insert(pub_key, Arc::clone(&peer));
|
||||
self.peers_by_idx.insert(next_index, Arc::clone(&peer));
|
||||
|
||||
for AllowedIP { addr, cidr } in allowed_ips {
|
||||
self.peers_by_ip
|
||||
.insert(*addr, *cidr as _, Arc::clone(&peer));
|
||||
}
|
||||
|
||||
tracing::info!("Peer added");
|
||||
}
|
||||
|
||||
pub fn new(name: &str, config: DeviceConfig) -> Result<Device, Error> {
|
||||
let poll = EventPoll::<Handler>::new()?;
|
||||
|
||||
// Create a tunnel device
|
||||
let iface = Arc::new(TunSocket::new(name)?.set_non_blocking()?);
|
||||
let mtu = iface.mtu()?;
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
let uapi_fd = -1;
|
||||
#[cfg(target_os = "linux")]
|
||||
let uapi_fd = config.uapi_fd;
|
||||
|
||||
let mut device = Device {
|
||||
queue: Arc::new(poll),
|
||||
iface,
|
||||
config,
|
||||
exit_notice: Default::default(),
|
||||
yield_notice: Default::default(),
|
||||
fwmark: Default::default(),
|
||||
key_pair: Default::default(),
|
||||
listen_port: Default::default(),
|
||||
next_index: Default::default(),
|
||||
peers: Default::default(),
|
||||
peers_by_idx: Default::default(),
|
||||
peers_by_ip: AllowedIps::new(),
|
||||
udp4: Default::default(),
|
||||
udp6: Default::default(),
|
||||
cleanup_paths: Default::default(),
|
||||
mtu: AtomicUsize::new(mtu),
|
||||
rate_limiter: None,
|
||||
#[cfg(target_os = "linux")]
|
||||
uapi_fd,
|
||||
};
|
||||
|
||||
if uapi_fd >= 0 {
|
||||
device.register_api_fd(uapi_fd)?;
|
||||
} else {
|
||||
device.register_api_handler()?;
|
||||
}
|
||||
device.register_iface_handler(Arc::clone(&device.iface))?;
|
||||
device.register_notifiers()?;
|
||||
device.register_timers()?;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
// Only for macOS write the actual socket name into WG_TUN_NAME_FILE
|
||||
if let Ok(name_file) = std::env::var("WG_TUN_NAME_FILE") {
|
||||
if name == "utun" {
|
||||
std::fs::write(&name_file, device.iface.name().unwrap().as_bytes()).unwrap();
|
||||
device.cleanup_paths.push(name_file);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(device)
|
||||
}
|
||||
|
||||
fn open_listen_socket(&mut self, mut port: u16) -> Result<(), Error> {
|
||||
// Binds the network facing interfaces
|
||||
// First close any existing open socket, and remove them from the event loop
|
||||
if let Some(s) = self.udp4.take() {
|
||||
unsafe {
|
||||
// This is safe because the event loop is not running yet
|
||||
self.queue.clear_event_by_fd(s.as_raw_fd())
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(s) = self.udp6.take() {
|
||||
unsafe { self.queue.clear_event_by_fd(s.as_raw_fd()) };
|
||||
}
|
||||
|
||||
for peer in self.peers.values() {
|
||||
peer.lock().shutdown_endpoint();
|
||||
}
|
||||
|
||||
// Then open new sockets and bind to the port
|
||||
let udp_sock4 = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
udp_sock4.set_reuse_address(true)?;
|
||||
udp_sock4.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into())?;
|
||||
udp_sock4.set_nonblocking(true)?;
|
||||
|
||||
if port == 0 {
|
||||
// Random port was assigned
|
||||
port = udp_sock4.local_addr()?.as_socket().unwrap().port();
|
||||
}
|
||||
|
||||
let udp_sock6 = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
udp_sock6.set_reuse_address(true)?;
|
||||
udp_sock6.bind(&SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into())?;
|
||||
udp_sock6.set_nonblocking(true)?;
|
||||
|
||||
self.register_udp_handler(udp_sock4.try_clone().unwrap())?;
|
||||
self.register_udp_handler(udp_sock6.try_clone().unwrap())?;
|
||||
self.udp4 = Some(udp_sock4);
|
||||
self.udp6 = Some(udp_sock6);
|
||||
|
||||
self.listen_port = port;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_key(&mut self, private_key: x25519::StaticSecret) {
|
||||
let public_key = x25519::PublicKey::from(&private_key);
|
||||
let key_pair = Some((private_key.clone(), public_key));
|
||||
|
||||
// x25519 (rightly) doesn't let us expose secret keys for comparison.
|
||||
// If the public keys are the same, then the private keys are the same.
|
||||
if Some(&public_key) == self.key_pair.as_ref().map(|p| &p.1) {
|
||||
return;
|
||||
}
|
||||
|
||||
let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT));
|
||||
|
||||
for peer in self.peers.values_mut() {
|
||||
peer.lock().tunnel.set_static_private(
|
||||
private_key.clone(),
|
||||
public_key,
|
||||
Some(Arc::clone(&rate_limiter)),
|
||||
)
|
||||
}
|
||||
|
||||
self.key_pair = key_pair;
|
||||
self.rate_limiter = Some(rate_limiter);
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
fn set_fwmark(&mut self, mark: u32) -> Result<(), Error> {
|
||||
self.fwmark = Some(mark);
|
||||
|
||||
// First set fwmark on listeners
|
||||
if let Some(ref sock) = self.udp4 {
|
||||
sock.set_mark(mark)?;
|
||||
}
|
||||
|
||||
if let Some(ref sock) = self.udp6 {
|
||||
sock.set_mark(mark)?;
|
||||
}
|
||||
|
||||
// Then on all currently connected sockets
|
||||
for peer in self.peers.values() {
|
||||
if let Some(ref sock) = peer.lock().endpoint().conn {
|
||||
sock.set_mark(mark)?
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear_peers(&mut self) {
|
||||
self.peers.clear();
|
||||
self.peers_by_idx.clear();
|
||||
self.peers_by_ip.clear();
|
||||
}
|
||||
|
||||
fn register_notifiers(&mut self) -> Result<(), Error> {
|
||||
let yield_ev = self
|
||||
.queue
|
||||
// The notification event handler simply returns Action::Yield
|
||||
.new_notifier(Box::new(|_, _| Action::Yield))?;
|
||||
self.yield_notice = Some(yield_ev);
|
||||
|
||||
let exit_ev = self
|
||||
.queue
|
||||
// The exit event handler simply returns Action::Exit
|
||||
.new_notifier(Box::new(|_, _| Action::Exit))?;
|
||||
self.exit_notice = Some(exit_ev);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_timers(&self) -> Result<(), Error> {
|
||||
self.queue.new_periodic_event(
|
||||
// Reset the rate limiter every second give or take
|
||||
Box::new(|d, _| {
|
||||
if let Some(r) = d.rate_limiter.as_ref() {
|
||||
r.reset_count()
|
||||
}
|
||||
Action::Continue
|
||||
}),
|
||||
std::time::Duration::from_secs(1),
|
||||
)?;
|
||||
|
||||
self.queue.new_periodic_event(
|
||||
// Execute the timed function of every peer in the list
|
||||
Box::new(|d, t| {
|
||||
let peer_map = &d.peers;
|
||||
|
||||
let (udp4, udp6) = match (d.udp4.as_ref(), d.udp6.as_ref()) {
|
||||
(Some(udp4), Some(udp6)) => (udp4, udp6),
|
||||
_ => return Action::Continue,
|
||||
};
|
||||
|
||||
// Go over each peer and invoke the timer function
|
||||
for peer in peer_map.values() {
|
||||
let mut p = peer.lock();
|
||||
let endpoint_addr = match p.endpoint().addr {
|
||||
Some(addr) => addr,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
match p.update_timers(&mut t.dst_buf[..]) {
|
||||
TunnResult::Done => {}
|
||||
TunnResult::Err(WireGuardError::ConnectionExpired) => {
|
||||
p.shutdown_endpoint(); // close open udp socket
|
||||
}
|
||||
TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e),
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
match endpoint_addr {
|
||||
SocketAddr::V4(_) => {
|
||||
udp4.send_to(packet, &endpoint_addr.into()).ok()
|
||||
}
|
||||
SocketAddr::V6(_) => {
|
||||
udp6.send_to(packet, &endpoint_addr.into()).ok()
|
||||
}
|
||||
};
|
||||
}
|
||||
_ => panic!("Unexpected result from update_timers"),
|
||||
};
|
||||
}
|
||||
Action::Continue
|
||||
}),
|
||||
std::time::Duration::from_millis(250),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn trigger_yield(&self) {
|
||||
self.queue
|
||||
.trigger_notification(self.yield_notice.as_ref().unwrap())
|
||||
}
|
||||
|
||||
pub(crate) fn trigger_exit(&self) {
|
||||
self.queue
|
||||
.trigger_notification(self.exit_notice.as_ref().unwrap())
|
||||
}
|
||||
|
||||
pub(crate) fn cancel_yield(&self) {
|
||||
self.queue
|
||||
.stop_notification(self.yield_notice.as_ref().unwrap())
|
||||
}
|
||||
|
||||
fn register_udp_handler(&self, udp: socket2::Socket) -> Result<(), Error> {
|
||||
self.queue.new_event(
|
||||
udp.as_raw_fd(),
|
||||
Box::new(move |d, t| {
|
||||
// Handler that handles anonymous packets over UDP
|
||||
let mut iter = MAX_ITR;
|
||||
let (private_key, public_key) = d.key_pair.as_ref().expect("Key not set");
|
||||
|
||||
let rate_limiter = d.rate_limiter.as_ref().unwrap();
|
||||
|
||||
// Loop while we have packets on the anonymous connection
|
||||
|
||||
// Safety: the `recv_from` implementation promises not to write uninitialised
|
||||
// bytes to the buffer, so this casting is safe.
|
||||
let src_buf =
|
||||
unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit<u8>]) };
|
||||
while let Ok((packet_len, addr)) = udp.recv_from(src_buf) {
|
||||
let packet = &t.src_buf[..packet_len];
|
||||
// The rate limiter initially checks mac1 and mac2, and optionally asks to send a cookie
|
||||
let parsed_packet = match rate_limiter.verify_packet(
|
||||
Some(addr.as_socket().unwrap().ip()),
|
||||
packet,
|
||||
&mut t.dst_buf,
|
||||
) {
|
||||
Ok(packet) => packet,
|
||||
Err(TunnResult::WriteToNetwork(cookie)) => {
|
||||
let _: Result<_, _> = udp.send_to(cookie, &addr);
|
||||
continue;
|
||||
}
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let peer = match &parsed_packet {
|
||||
Packet::HandshakeInit(p) => {
|
||||
parse_handshake_anon(private_key, public_key, p)
|
||||
.ok()
|
||||
.and_then(|hh| {
|
||||
d.peers.get(&x25519::PublicKey::from(hh.peer_static_public))
|
||||
})
|
||||
}
|
||||
Packet::HandshakeResponse(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
|
||||
Packet::PacketCookieReply(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
|
||||
Packet::PacketData(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
|
||||
};
|
||||
|
||||
let peer = match peer {
|
||||
None => continue,
|
||||
Some(peer) => peer,
|
||||
};
|
||||
|
||||
let mut p = peer.lock();
|
||||
|
||||
// We found a peer, use it to decapsulate the message+
|
||||
let mut flush = false; // Are there packets to send from the queue?
|
||||
match p
|
||||
.tunnel
|
||||
.handle_verified_packet(parsed_packet, &mut t.dst_buf[..])
|
||||
{
|
||||
TunnResult::Done => {}
|
||||
TunnResult::Err(_) => continue,
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
flush = true;
|
||||
let _: Result<_, _> = udp.send_to(packet, &addr);
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
||||
if p.is_allowed_ip(addr) {
|
||||
t.iface.write4(packet);
|
||||
}
|
||||
}
|
||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
||||
if p.is_allowed_ip(addr) {
|
||||
t.iface.write6(packet);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if flush {
|
||||
// Flush pending queue
|
||||
while let TunnResult::WriteToNetwork(packet) =
|
||||
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
|
||||
{
|
||||
let _: Result<_, _> = udp.send_to(packet, &addr);
|
||||
}
|
||||
}
|
||||
|
||||
// This packet was OK, that means we want to create a connected socket for this peer
|
||||
let addr = addr.as_socket().unwrap();
|
||||
let ip_addr = addr.ip();
|
||||
p.set_endpoint(addr);
|
||||
if d.config.use_connected_socket {
|
||||
if let Ok(sock) = p.connect_endpoint(d.listen_port, d.fwmark) {
|
||||
d.register_conn_handler(Arc::clone(peer), sock, ip_addr)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
iter -= 1;
|
||||
if iter == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Action::Continue
|
||||
}),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_conn_handler(
|
||||
&self,
|
||||
peer: Arc<Mutex<Peer>>,
|
||||
udp: socket2::Socket,
|
||||
peer_addr: IpAddr,
|
||||
) -> Result<(), Error> {
|
||||
self.queue.new_event(
|
||||
udp.as_raw_fd(),
|
||||
Box::new(move |_, t| {
|
||||
// The conn_handler handles packet received from a connected UDP socket, associated
|
||||
// with a known peer, this saves us the hustle of finding the right peer. If another
|
||||
// peer gets the same ip, it will be ignored until the socket does not expire.
|
||||
let iface = &t.iface;
|
||||
let mut iter = MAX_ITR;
|
||||
|
||||
// Safety: the `recv_from` implementation promises not to write uninitialised
|
||||
// bytes to the buffer, so this casting is safe.
|
||||
let src_buf =
|
||||
unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit<u8>]) };
|
||||
|
||||
while let Ok(read_bytes) = udp.recv(src_buf) {
|
||||
let mut flush = false;
|
||||
let mut p = peer.lock();
|
||||
match p.tunnel.decapsulate(
|
||||
Some(peer_addr),
|
||||
&t.src_buf[..read_bytes],
|
||||
&mut t.dst_buf[..],
|
||||
) {
|
||||
TunnResult::Done => {}
|
||||
TunnResult::Err(e) => eprintln!("Decapsulate error {:?}", e),
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
flush = true;
|
||||
let _: Result<_, _> = udp.send(packet);
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
||||
if p.is_allowed_ip(addr) {
|
||||
iface.write4(packet);
|
||||
}
|
||||
}
|
||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
||||
if p.is_allowed_ip(addr) {
|
||||
iface.write6(packet);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if flush {
|
||||
// Flush pending queue
|
||||
while let TunnResult::WriteToNetwork(packet) =
|
||||
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
|
||||
{
|
||||
let _: Result<_, _> = udp.send(packet);
|
||||
}
|
||||
}
|
||||
|
||||
iter -= 1;
|
||||
if iter == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Action::Continue
|
||||
}),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_iface_handler(&self, iface: Arc<TunSocket>) -> Result<(), Error> {
|
||||
self.queue.new_event(
|
||||
iface.as_raw_fd(),
|
||||
Box::new(move |d, t| {
|
||||
// The iface_handler handles packets received from the WireGuard virtual network
|
||||
// interface. The flow is as follows:
|
||||
// * Read a packet
|
||||
// * Determine peer based on packet destination ip
|
||||
// * Encapsulate the packet for the given peer
|
||||
// * Send encapsulated packet to the peer's endpoint
|
||||
let mtu = d.mtu.load(Ordering::Relaxed);
|
||||
|
||||
let udp4 = d.udp4.as_ref().expect("Not connected");
|
||||
let udp6 = d.udp6.as_ref().expect("Not connected");
|
||||
|
||||
let peers = &d.peers_by_ip;
|
||||
for _ in 0..MAX_ITR {
|
||||
let src = match iface.read(&mut t.src_buf[..mtu]) {
|
||||
Ok(src) => src,
|
||||
Err(Error::IfaceRead(e)) => {
|
||||
let ek = e.kind();
|
||||
if ek == io::ErrorKind::Interrupted || ek == io::ErrorKind::WouldBlock {
|
||||
break;
|
||||
}
|
||||
eprintln!("Fatal read error on tun interface: {:?}", e);
|
||||
return Action::Exit;
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Unexpected error on tun interface: {:?}", e);
|
||||
return Action::Exit;
|
||||
}
|
||||
};
|
||||
|
||||
let dst_addr = match Tunn::dst_address(src) {
|
||||
Some(addr) => addr,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let mut peer = match peers.find(dst_addr) {
|
||||
Some(peer) => peer.lock(),
|
||||
None => continue,
|
||||
};
|
||||
|
||||
match peer.tunnel.encapsulate(src, &mut t.dst_buf[..]) {
|
||||
TunnResult::Done => {}
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!(message = "Encapsulate error", error = ?e)
|
||||
}
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
let mut endpoint = peer.endpoint_mut();
|
||||
if let Some(conn) = endpoint.conn.as_mut() {
|
||||
// Prefer to send using the connected socket
|
||||
let _: Result<_, _> = conn.write(packet);
|
||||
} else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr {
|
||||
let _: Result<_, _> = udp4.send_to(packet, &addr.into());
|
||||
} else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr {
|
||||
let _: Result<_, _> = udp6.send_to(packet, &addr.into());
|
||||
} else {
|
||||
tracing::error!("No endpoint");
|
||||
}
|
||||
}
|
||||
_ => panic!("Unexpected result from encapsulate"),
|
||||
};
|
||||
}
|
||||
Action::Continue
|
||||
}),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A basic linear-feedback shift register implemented as xorshift, used to
|
||||
/// distribute peer indexes across the 24-bit address space reserved for peer
|
||||
/// identification.
|
||||
/// The purpose is to obscure the total number of peers using the system and to
|
||||
/// ensure it requires a non-trivial amount of processing power and/or samples
|
||||
/// to guess other peers' indices. Anything more ambitious than this is wasted
|
||||
/// with only 24 bits of space.
|
||||
struct IndexLfsr {
|
||||
initial: u32,
|
||||
lfsr: u32,
|
||||
mask: u32,
|
||||
}
|
||||
|
||||
impl IndexLfsr {
|
||||
/// Generate a random 24-bit nonzero integer
|
||||
fn random_index() -> u32 {
|
||||
const LFSR_MAX: u32 = 0xffffff; // 24-bit seed
|
||||
loop {
|
||||
let i = OsRng.next_u32() & LFSR_MAX;
|
||||
if i > 0 {
|
||||
// LFSR seed must be non-zero
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate the next value in the pseudorandom sequence
|
||||
fn next(&mut self) -> u32 {
|
||||
// 24-bit polynomial for randomness. This is arbitrarily chosen to
|
||||
// inject bitflips into the value.
|
||||
const LFSR_POLY: u32 = 0xd80000; // 24-bit polynomial
|
||||
let value = self.lfsr - 1; // lfsr will never have value of 0
|
||||
self.lfsr = (self.lfsr >> 1) ^ ((0u32.wrapping_sub(self.lfsr & 1u32)) & LFSR_POLY);
|
||||
assert!(self.lfsr != self.initial, "Too many peers created");
|
||||
value ^ self.mask
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for IndexLfsr {
|
||||
fn default() -> Self {
|
||||
let seed = Self::random_index();
|
||||
IndexLfsr {
|
||||
initial: seed,
|
||||
lfsr: seed,
|
||||
mask: Self::random_index(),
|
||||
}
|
||||
}
|
||||
}
|
||||
170
lib/boringtun/src/device/peer.rs
Normal file
170
lib/boringtun/src/device/peer.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use socket2::{Domain, Protocol, Type};
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::device::{AllowedIps, Error};
|
||||
use crate::noise::{Tunn, TunnResult};
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct Endpoint {
|
||||
pub addr: Option<SocketAddr>,
|
||||
pub conn: Option<socket2::Socket>,
|
||||
}
|
||||
|
||||
pub struct Peer {
|
||||
/// The associated tunnel struct
|
||||
pub(crate) tunnel: Tunn,
|
||||
/// The index the tunnel uses
|
||||
index: u32,
|
||||
endpoint: RwLock<Endpoint>,
|
||||
allowed_ips: AllowedIps<()>,
|
||||
preshared_key: Option<[u8; 32]>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct AllowedIP {
|
||||
pub addr: IpAddr,
|
||||
pub cidr: u8,
|
||||
}
|
||||
|
||||
impl FromStr for AllowedIP {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let ip: Vec<&str> = s.split('/').collect();
|
||||
if ip.len() != 2 {
|
||||
return Err("Invalid IP format".to_owned());
|
||||
}
|
||||
|
||||
let (addr, cidr) = (ip[0].parse::<IpAddr>(), ip[1].parse::<u8>());
|
||||
match (addr, cidr) {
|
||||
(Ok(addr @ IpAddr::V4(_)), Ok(cidr)) if cidr <= 32 => Ok(AllowedIP { addr, cidr }),
|
||||
(Ok(addr @ IpAddr::V6(_)), Ok(cidr)) if cidr <= 128 => Ok(AllowedIP { addr, cidr }),
|
||||
_ => Err("Invalid IP format".to_owned()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
tunnel: Tunn,
|
||||
index: u32,
|
||||
endpoint: Option<SocketAddr>,
|
||||
allowed_ips: &[AllowedIP],
|
||||
preshared_key: Option<[u8; 32]>,
|
||||
) -> Peer {
|
||||
Peer {
|
||||
tunnel,
|
||||
index,
|
||||
endpoint: RwLock::new(Endpoint {
|
||||
addr: endpoint,
|
||||
conn: None,
|
||||
}),
|
||||
allowed_ips: allowed_ips.iter().map(|ip| (ip, ())).collect(),
|
||||
preshared_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
|
||||
self.tunnel.update_timers(dst)
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> {
|
||||
self.endpoint.read()
|
||||
}
|
||||
|
||||
pub(crate) fn endpoint_mut(&self) -> parking_lot::RwLockWriteGuard<'_, Endpoint> {
|
||||
self.endpoint.write()
|
||||
}
|
||||
|
||||
pub fn shutdown_endpoint(&self) {
|
||||
if let Some(conn) = self.endpoint.write().conn.take() {
|
||||
tracing::info!("Disconnecting from endpoint");
|
||||
conn.shutdown(Shutdown::Both).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_endpoint(&self, addr: SocketAddr) {
|
||||
let mut endpoint = self.endpoint.write();
|
||||
if endpoint.addr != Some(addr) {
|
||||
// We only need to update the endpoint if it differs from the current one
|
||||
if let Some(conn) = endpoint.conn.take() {
|
||||
conn.shutdown(Shutdown::Both).unwrap();
|
||||
}
|
||||
|
||||
endpoint.addr = Some(addr);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connect_endpoint(
|
||||
&self,
|
||||
port: u16,
|
||||
fwmark: Option<u32>,
|
||||
) -> Result<socket2::Socket, Error> {
|
||||
let mut endpoint = self.endpoint.write();
|
||||
|
||||
if endpoint.conn.is_some() {
|
||||
return Err(Error::Connect("Connected".to_owned()));
|
||||
}
|
||||
|
||||
let addr = endpoint
|
||||
.addr
|
||||
.expect("Attempt to connect to undefined endpoint");
|
||||
|
||||
let udp_conn =
|
||||
socket2::Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::UDP))?;
|
||||
udp_conn.set_reuse_address(true)?;
|
||||
let bind_addr = if addr.is_ipv4() {
|
||||
SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into()
|
||||
} else {
|
||||
SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into()
|
||||
};
|
||||
udp_conn.bind(&bind_addr)?;
|
||||
udp_conn.connect(&addr.into())?;
|
||||
udp_conn.set_nonblocking(true)?;
|
||||
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(fwmark) = fwmark {
|
||||
udp_conn.set_mark(fwmark)?;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
message="Connected endpoint",
|
||||
port=port,
|
||||
endpoint=?endpoint.addr.unwrap()
|
||||
);
|
||||
|
||||
endpoint.conn = Some(udp_conn.try_clone().unwrap());
|
||||
|
||||
Ok(udp_conn)
|
||||
}
|
||||
|
||||
pub fn is_allowed_ip<I: Into<IpAddr>>(&self, addr: I) -> bool {
|
||||
self.allowed_ips.find(addr.into()).is_some()
|
||||
}
|
||||
|
||||
pub fn allowed_ips(&self) -> impl Iterator<Item = (IpAddr, u8)> + '_ {
|
||||
self.allowed_ips.iter().map(|(_, ip, cidr)| (ip, cidr))
|
||||
}
|
||||
|
||||
pub fn time_since_last_handshake(&self) -> Option<std::time::Duration> {
|
||||
self.tunnel.time_since_last_handshake()
|
||||
}
|
||||
|
||||
pub fn persistent_keepalive(&self) -> Option<u16> {
|
||||
self.tunnel.persistent_keepalive()
|
||||
}
|
||||
|
||||
pub fn preshared_key(&self) -> Option<&[u8; 32]> {
|
||||
self.preshared_key.as_ref()
|
||||
}
|
||||
|
||||
pub fn index(&self) -> u32 {
|
||||
self.index
|
||||
}
|
||||
}
|
||||
256
lib/boringtun/src/device/tun_darwin.rs
Normal file
256
lib/boringtun/src/device/tun_darwin.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::Error;
|
||||
use libc::*;
|
||||
use std::io;
|
||||
use std::mem::size_of;
|
||||
use std::mem::size_of_val;
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
use std::ptr::null_mut;
|
||||
|
||||
const CTRL_NAME: &[u8] = b"com.apple.net.utun_control";
|
||||
|
||||
#[repr(C)]
|
||||
pub struct ctl_info {
|
||||
pub ctl_id: u32,
|
||||
pub ctl_name: [c_uchar; 96],
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
union IfrIfru {
|
||||
ifru_addr: sockaddr,
|
||||
ifru_addr_v4: sockaddr_in,
|
||||
ifru_addr_v6: sockaddr_in,
|
||||
ifru_dstaddr: sockaddr,
|
||||
ifru_broadaddr: sockaddr,
|
||||
ifru_flags: c_short,
|
||||
ifru_metric: c_int,
|
||||
ifru_mtu: c_int,
|
||||
ifru_phys: c_int,
|
||||
ifru_media: c_int,
|
||||
ifru_intval: c_int,
|
||||
//ifru_data: caddr_t,
|
||||
//ifru_devmtu: ifdevmtu,
|
||||
//ifru_kpi: ifkpi,
|
||||
ifru_wake_flags: u32,
|
||||
ifru_route_refcnt: u32,
|
||||
ifru_cap: [c_int; 2],
|
||||
ifru_functional_type: u32,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct ifreq {
|
||||
ifr_name: [c_uchar; IF_NAMESIZE],
|
||||
ifr_ifru: IfrIfru,
|
||||
}
|
||||
|
||||
const CTLIOCGINFO: u64 = 0x0000_0000_c064_4e03;
|
||||
const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct TunSocket {
|
||||
fd: RawFd,
|
||||
}
|
||||
|
||||
impl Drop for TunSocket {
|
||||
fn drop(&mut self) {
|
||||
unsafe { close(self.fd) };
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for TunSocket {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.fd
|
||||
}
|
||||
}
|
||||
|
||||
// On Darwin tunnel can only be named utunXXX
|
||||
pub fn parse_utun_name(name: &str) -> Result<u32, Error> {
|
||||
if !name.starts_with("utun") {
|
||||
return Err(Error::InvalidTunnelName);
|
||||
}
|
||||
|
||||
match name.get(4..) {
|
||||
None | Some("") => {
|
||||
// The name is simply "utun"
|
||||
Ok(0)
|
||||
}
|
||||
Some(idx) => {
|
||||
// Everything past utun should represent an integer index
|
||||
idx.parse::<u32>()
|
||||
.map_err(|_| Error::InvalidTunnelName)
|
||||
.map(|x| x + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TunSocket {
|
||||
fn write(&self, src: &[u8], af: u8) -> usize {
|
||||
let mut hdr = [0u8, 0u8, 0u8, af as u8];
|
||||
let mut iov = [
|
||||
iovec {
|
||||
iov_base: hdr.as_mut_ptr() as _,
|
||||
iov_len: hdr.len(),
|
||||
},
|
||||
iovec {
|
||||
iov_base: src.as_ptr() as _,
|
||||
iov_len: src.len(),
|
||||
},
|
||||
];
|
||||
|
||||
let msg_hdr = msghdr {
|
||||
msg_name: null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: &mut iov[0],
|
||||
msg_iovlen: iov.len() as _,
|
||||
msg_control: null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
};
|
||||
|
||||
match unsafe { sendmsg(self.fd, &msg_hdr, 0) } {
|
||||
-1 => 0,
|
||||
n => n as usize,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(name: &str) -> Result<TunSocket, Error> {
|
||||
let idx = parse_utun_name(name)?;
|
||||
|
||||
let fd = match unsafe { socket(PF_SYSTEM, SOCK_DGRAM, SYSPROTO_CONTROL) } {
|
||||
-1 => return Err(Error::Socket(io::Error::last_os_error())),
|
||||
fd => fd,
|
||||
};
|
||||
|
||||
let mut info = ctl_info {
|
||||
ctl_id: 0,
|
||||
ctl_name: [0u8; 96],
|
||||
};
|
||||
info.ctl_name[..CTRL_NAME.len()].copy_from_slice(CTRL_NAME);
|
||||
|
||||
if unsafe { ioctl(fd, CTLIOCGINFO, &mut info as *mut ctl_info) } < 0 {
|
||||
unsafe { close(fd) };
|
||||
return Err(Error::IOCtl(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
let addr = sockaddr_ctl {
|
||||
sc_len: size_of::<sockaddr_ctl>() as u8,
|
||||
sc_family: AF_SYSTEM as u8,
|
||||
ss_sysaddr: AF_SYS_CONTROL as u16,
|
||||
sc_id: info.ctl_id,
|
||||
sc_unit: idx,
|
||||
sc_reserved: Default::default(),
|
||||
};
|
||||
|
||||
if unsafe {
|
||||
connect(
|
||||
fd,
|
||||
&addr as *const sockaddr_ctl as _,
|
||||
size_of_val(&addr) as _,
|
||||
)
|
||||
} < 0
|
||||
{
|
||||
unsafe { close(fd) };
|
||||
let mut err_string = io::Error::last_os_error().to_string();
|
||||
err_string.push_str("(did you run with sudo?)");
|
||||
return Err(Error::Connect(err_string));
|
||||
}
|
||||
|
||||
Ok(TunSocket { fd })
|
||||
}
|
||||
|
||||
pub fn set_non_blocking(self) -> Result<TunSocket, Error> {
|
||||
match unsafe { fcntl(self.fd, F_GETFL) } {
|
||||
-1 => Err(Error::FCntl(io::Error::last_os_error())),
|
||||
flags => match unsafe { fcntl(self.fd, F_SETFL, flags | O_NONBLOCK) } {
|
||||
-1 => Err(Error::FCntl(io::Error::last_os_error())),
|
||||
_ => Ok(self),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> Result<String, Error> {
|
||||
let mut tunnel_name = [0u8; 256];
|
||||
let mut tunnel_name_len: socklen_t = tunnel_name.len() as u32;
|
||||
if unsafe {
|
||||
getsockopt(
|
||||
self.fd,
|
||||
SYSPROTO_CONTROL,
|
||||
UTUN_OPT_IFNAME,
|
||||
tunnel_name.as_mut_ptr() as _,
|
||||
&mut tunnel_name_len,
|
||||
)
|
||||
} < 0
|
||||
|| tunnel_name_len == 0
|
||||
{
|
||||
return Err(Error::GetSockOpt(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8_lossy(&tunnel_name[..(tunnel_name_len - 1) as usize]).to_string())
|
||||
}
|
||||
|
||||
/// Get the current MTU value
|
||||
pub fn mtu(&self) -> Result<usize, Error> {
|
||||
let fd = match unsafe { socket(AF_INET, SOCK_STREAM, IPPROTO_IP) } {
|
||||
-1 => return Err(Error::Socket(io::Error::last_os_error())),
|
||||
fd => fd,
|
||||
};
|
||||
|
||||
let name = self.name()?;
|
||||
let iface_name: &[u8] = name.as_ref();
|
||||
let mut ifr = ifreq {
|
||||
ifr_name: [0; IF_NAMESIZE],
|
||||
ifr_ifru: IfrIfru { ifru_mtu: 0 },
|
||||
};
|
||||
|
||||
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
|
||||
|
||||
if unsafe { ioctl(fd, SIOCGIFMTU, &ifr) } < 0 {
|
||||
return Err(Error::IOCtl(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
unsafe { close(fd) };
|
||||
|
||||
Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _)
|
||||
}
|
||||
|
||||
pub fn write4(&self, src: &[u8]) -> usize {
|
||||
self.write(src, AF_INET as u8)
|
||||
}
|
||||
|
||||
pub fn write6(&self, src: &[u8]) -> usize {
|
||||
self.write(src, AF_INET6 as u8)
|
||||
}
|
||||
|
||||
pub fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error> {
|
||||
let mut hdr = [0u8; 4];
|
||||
|
||||
let mut iov = [
|
||||
iovec {
|
||||
iov_base: hdr.as_mut_ptr() as _,
|
||||
iov_len: hdr.len(),
|
||||
},
|
||||
iovec {
|
||||
iov_base: dst.as_mut_ptr() as _,
|
||||
iov_len: dst.len(),
|
||||
},
|
||||
];
|
||||
|
||||
let mut msg_hdr = msghdr {
|
||||
msg_name: null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: &mut iov[0],
|
||||
msg_iovlen: iov.len() as _,
|
||||
msg_control: null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
};
|
||||
|
||||
match unsafe { recvmsg(self.fd, &mut msg_hdr, 0) } {
|
||||
-1 => Err(Error::IfaceRead(io::Error::last_os_error())),
|
||||
0..=4 => Ok(&mut dst[..0]),
|
||||
n => Ok(&mut dst[..(n - 4) as usize]),
|
||||
}
|
||||
}
|
||||
}
|
||||
159
lib/boringtun/src/device/tun_linux.rs
Normal file
159
lib/boringtun/src/device/tun_linux.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::Error;
|
||||
use libc::*;
|
||||
use std::io;
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
|
||||
const TUNSETIFF: u64 = 0x4004_54ca;
|
||||
|
||||
#[repr(C)]
|
||||
union IfrIfru {
|
||||
ifru_addr: sockaddr,
|
||||
ifru_addr_v4: sockaddr_in,
|
||||
ifru_addr_v6: sockaddr_in,
|
||||
ifru_dstaddr: sockaddr,
|
||||
ifru_broadaddr: sockaddr,
|
||||
ifru_flags: c_short,
|
||||
ifru_metric: c_int,
|
||||
ifru_mtu: c_int,
|
||||
ifru_phys: c_int,
|
||||
ifru_media: c_int,
|
||||
ifru_intval: c_int,
|
||||
//ifru_data: caddr_t,
|
||||
//ifru_devmtu: ifdevmtu,
|
||||
//ifru_kpi: ifkpi,
|
||||
ifru_wake_flags: u32,
|
||||
ifru_route_refcnt: u32,
|
||||
ifru_cap: [c_int; 2],
|
||||
ifru_functional_type: u32,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct ifreq {
|
||||
ifr_name: [c_uchar; IFNAMSIZ],
|
||||
ifr_ifru: IfrIfru,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct TunSocket {
|
||||
fd: RawFd,
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl Drop for TunSocket {
|
||||
fn drop(&mut self) {
|
||||
unsafe { close(self.fd) };
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawFd for TunSocket {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.fd
|
||||
}
|
||||
}
|
||||
|
||||
impl TunSocket {
|
||||
fn write(&self, buf: &[u8]) -> usize {
|
||||
match unsafe { write(self.fd, buf.as_ptr() as _, buf.len() as _) } {
|
||||
-1 => 0,
|
||||
n => n as usize,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(name: &str) -> Result<TunSocket, Error> {
|
||||
// If the provided name appears to be a FD, use that.
|
||||
let provided_fd = name.parse::<i32>();
|
||||
if let Ok(fd) = provided_fd {
|
||||
return Ok(TunSocket {
|
||||
fd,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let fd = match unsafe { open(b"/dev/net/tun\0".as_ptr() as _, O_RDWR) } {
|
||||
-1 => return Err(Error::Socket(io::Error::last_os_error())),
|
||||
fd => fd,
|
||||
};
|
||||
let iface_name = name.as_bytes();
|
||||
let mut ifr = ifreq {
|
||||
ifr_name: [0; IFNAMSIZ],
|
||||
ifr_ifru: IfrIfru {
|
||||
ifru_flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as _,
|
||||
},
|
||||
};
|
||||
|
||||
if iface_name.len() >= ifr.ifr_name.len() {
|
||||
return Err(Error::InvalidTunnelName);
|
||||
}
|
||||
|
||||
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
|
||||
|
||||
if unsafe { ioctl(fd, TUNSETIFF as _, &ifr) } < 0 {
|
||||
return Err(Error::IOCtl(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
let name = name.to_string();
|
||||
Ok(TunSocket { fd, name })
|
||||
}
|
||||
|
||||
pub fn set_non_blocking(self) -> Result<TunSocket, Error> {
|
||||
match unsafe { fcntl(self.fd, F_GETFL) } {
|
||||
-1 => Err(Error::FCntl(io::Error::last_os_error())),
|
||||
flags => match unsafe { fcntl(self.fd, F_SETFL, flags | O_NONBLOCK) } {
|
||||
-1 => Err(Error::FCntl(io::Error::last_os_error())),
|
||||
_ => Ok(self),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> Result<String, Error> {
|
||||
Ok(self.name.clone())
|
||||
}
|
||||
|
||||
/// Get the current MTU value
|
||||
pub fn mtu(&self) -> Result<usize, Error> {
|
||||
let provided_fd = self.name.parse::<i32>();
|
||||
if provided_fd.is_ok() {
|
||||
return Ok(1500);
|
||||
}
|
||||
|
||||
let fd = match unsafe { socket(AF_INET, SOCK_STREAM, IPPROTO_IP) } {
|
||||
-1 => return Err(Error::Socket(io::Error::last_os_error())),
|
||||
fd => fd,
|
||||
};
|
||||
|
||||
let name = self.name()?;
|
||||
let iface_name: &[u8] = name.as_ref();
|
||||
let mut ifr = ifreq {
|
||||
ifr_name: [0; IF_NAMESIZE],
|
||||
ifr_ifru: IfrIfru { ifru_mtu: 0 },
|
||||
};
|
||||
|
||||
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
|
||||
|
||||
if unsafe { ioctl(fd, SIOCGIFMTU as _, &ifr) } < 0 {
|
||||
return Err(Error::IOCtl(io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
unsafe { close(fd) };
|
||||
|
||||
Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _)
|
||||
}
|
||||
|
||||
pub fn write4(&self, src: &[u8]) -> usize {
|
||||
self.write(src)
|
||||
}
|
||||
|
||||
pub fn write6(&self, src: &[u8]) -> usize {
|
||||
self.write(src)
|
||||
}
|
||||
|
||||
pub fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error> {
|
||||
match unsafe { read(self.fd, dst.as_mut_ptr() as _, dst.len()) } {
|
||||
-1 => Err(Error::IfaceRead(io::Error::last_os_error())),
|
||||
n => Ok(&mut dst[..n as usize]),
|
||||
}
|
||||
}
|
||||
}
|
||||
397
lib/boringtun/src/ffi/mod.rs
Normal file
397
lib/boringtun/src/ffi/mod.rs
Normal file
@@ -0,0 +1,397 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Requiring explicit per-fn "Safety" docs not worth it. Just pass in valid
|
||||
// pointers and buffers/lengths to these, ok?
|
||||
#![allow(clippy::missing_safety_doc)]
|
||||
|
||||
//! C bindings for the BoringTun library
|
||||
use super::noise::{Tunn, TunnResult};
|
||||
use crate::x25519::{PublicKey, StaticSecret};
|
||||
use base64::{decode, encode};
|
||||
use hex::encode as encode_hex;
|
||||
use libc::{raise, SIGSEGV};
|
||||
use parking_lot::Mutex;
|
||||
use rand_core::OsRng;
|
||||
use tracing;
|
||||
use tracing_subscriber::fmt;
|
||||
|
||||
use crate::serialization::KeyBytes;
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::io::{Error, ErrorKind, Write};
|
||||
use std::os::raw::c_char;
|
||||
use std::panic;
|
||||
use std::ptr;
|
||||
use std::ptr::null_mut;
|
||||
use std::slice;
|
||||
use std::sync::Once;
|
||||
|
||||
static PANIC_HOOK: Once = Once::new();
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[repr(C)]
|
||||
/// Indicates the operation required from the caller
|
||||
pub enum result_type {
|
||||
/// No operation is required.
|
||||
WIREGUARD_DONE = 0,
|
||||
/// Write dst buffer to network. Size indicates the number of bytes to write.
|
||||
WRITE_TO_NETWORK = 1,
|
||||
/// Some error occurred, no operation is required. Size indicates error code.
|
||||
WIREGUARD_ERROR = 2,
|
||||
/// Write dst buffer to the interface as an ipv4 packet. Size indicates the number of bytes to write.
|
||||
WRITE_TO_TUNNEL_IPV4 = 4,
|
||||
/// Write dst buffer to the interface as an ipv6 packet. Size indicates the number of bytes to write.
|
||||
WRITE_TO_TUNNEL_IPV6 = 6,
|
||||
}
|
||||
|
||||
/// The return type of WireGuard functions
|
||||
#[repr(C)]
|
||||
pub struct wireguard_result {
|
||||
/// The operation to be performed by the caller
|
||||
pub op: result_type,
|
||||
/// Additional information, required to perform the operation
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct stats {
|
||||
pub time_since_last_handshake: i64,
|
||||
pub tx_bytes: usize,
|
||||
pub rx_bytes: usize,
|
||||
pub estimated_loss: f32,
|
||||
pub estimated_rtt: i32,
|
||||
reserved: [u8; 56], // Make sure to add new fields in this space, keeping total size constant
|
||||
}
|
||||
|
||||
impl<'a> From<TunnResult<'a>> for wireguard_result {
|
||||
fn from(res: TunnResult<'a>) -> wireguard_result {
|
||||
match res {
|
||||
TunnResult::Done => wireguard_result {
|
||||
op: result_type::WIREGUARD_DONE,
|
||||
size: 0,
|
||||
},
|
||||
TunnResult::Err(e) => wireguard_result {
|
||||
op: result_type::WIREGUARD_ERROR,
|
||||
size: e as _,
|
||||
},
|
||||
TunnResult::WriteToNetwork(b) => wireguard_result {
|
||||
op: result_type::WRITE_TO_NETWORK,
|
||||
size: b.len(),
|
||||
},
|
||||
TunnResult::WriteToTunnelV4(b, _) => wireguard_result {
|
||||
op: result_type::WRITE_TO_TUNNEL_IPV4,
|
||||
size: b.len(),
|
||||
},
|
||||
TunnResult::WriteToTunnelV6(b, _) => wireguard_result {
|
||||
op: result_type::WRITE_TO_TUNNEL_IPV6,
|
||||
size: b.len(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct x25519_key {
|
||||
pub key: [u8; 32],
|
||||
}
|
||||
|
||||
/// Generates a new x25519 secret key.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn x25519_secret_key() -> x25519_key {
|
||||
x25519_key {
|
||||
key: StaticSecret::random_from_rng(OsRng).to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes a public x25519 key from a secret key.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn x25519_public_key(private_key: x25519_key) -> x25519_key {
|
||||
let private = StaticSecret::from(private_key.key);
|
||||
let public = PublicKey::from(&private);
|
||||
x25519_key {
|
||||
key: public.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the base64 encoding of a key as a UTF8 C-string.
|
||||
///
|
||||
/// The memory has to be freed by calling `x25519_key_to_str_free`
|
||||
#[no_mangle]
|
||||
pub extern "C" fn x25519_key_to_base64(key: x25519_key) -> *const c_char {
|
||||
let encoded_key = encode(key.key);
|
||||
CString::into_raw(CString::new(encoded_key).unwrap())
|
||||
}
|
||||
|
||||
/// Returns the hex encoding of a key as a UTF8 C-string.
|
||||
///
|
||||
/// The memory has to be freed by calling `x25519_key_to_str_free`
|
||||
#[no_mangle]
|
||||
pub extern "C" fn x25519_key_to_hex(key: x25519_key) -> *const c_char {
|
||||
let encoded_key = encode_hex(key.key);
|
||||
CString::into_raw(CString::new(encoded_key).unwrap())
|
||||
}
|
||||
|
||||
/// Frees memory of the string given by `x25519_key_to_hex` or `x25519_key_to_base64`
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn x25519_key_to_str_free(stringified_key: *mut c_char) {
|
||||
drop(CString::from_raw(stringified_key));
|
||||
}
|
||||
|
||||
/// Check if the input C-string represents a valid base64 encoded x25519 key.
|
||||
/// Return 1 if valid 0 otherwise.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn check_base64_encoded_x25519_key(key: *const c_char) -> i32 {
|
||||
let c_str = CStr::from_ptr(key);
|
||||
let utf8_key = match c_str.to_str() {
|
||||
Err(_) => return 0,
|
||||
Ok(string) => string,
|
||||
};
|
||||
|
||||
if let Ok(key) = decode(utf8_key) {
|
||||
let len = key.len();
|
||||
let mut zero = 0u8;
|
||||
for b in key {
|
||||
zero |= b
|
||||
}
|
||||
if len == 32 && zero != 0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom tracing_subscriber writer to an external function pointer
|
||||
struct FFIFunctionPointerWriter {
|
||||
log_func: unsafe extern "C" fn(*const c_char),
|
||||
}
|
||||
|
||||
/// Implements Write trait for use with tracing_subscriber
|
||||
impl Write for FFIFunctionPointerWriter {
|
||||
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
|
||||
let out_str = String::from_utf8_lossy(buf).to_string();
|
||||
if let Ok(c_string) = CString::new(out_str) {
|
||||
unsafe { (self.log_func)(c_string.as_ptr()) }
|
||||
Ok(buf.len())
|
||||
} else {
|
||||
Err(Error::new(
|
||||
ErrorKind::Other,
|
||||
"Failed to create CString from buffer.",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), std::io::Error> {
|
||||
// no-op
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the default tracing_subscriber to write to `log_func`.
|
||||
///
|
||||
/// Uses Compact format without level, target, thread ids, thread names, or ansi control characters.
|
||||
/// Subscribes to TRACE level events.
|
||||
///
|
||||
/// This function should only be called once as setting the default tracing_subscriber
|
||||
/// more than once will result in an error.
|
||||
///
|
||||
/// Returns false on failure.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `c_char` will be freed by the library after calling `log_func`. If the value needs
|
||||
/// to be stored then `log_func` needs to create a copy, e.g. `strcpy`.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn set_logging_function(
|
||||
log_func: unsafe extern "C" fn(*const c_char),
|
||||
) -> bool {
|
||||
let result = std::panic::catch_unwind(|| -> bool {
|
||||
let writer = FFIFunctionPointerWriter { log_func };
|
||||
let format = fmt::format()
|
||||
// don't include levels in formatted output
|
||||
.with_level(false)
|
||||
// don't include targets
|
||||
.with_target(false)
|
||||
// don't 'include the thread ID of the current thread
|
||||
.with_thread_ids(false)
|
||||
// don't 'include the name of the current thread
|
||||
.with_thread_names(false)
|
||||
// use the `Compact` formatting style.
|
||||
.compact()
|
||||
// disable terminal escape codes
|
||||
.with_ansi(false);
|
||||
|
||||
fmt()
|
||||
.event_format(format)
|
||||
.with_writer(std::sync::Mutex::new(writer))
|
||||
.with_max_level(tracing::Level::TRACE)
|
||||
.with_ansi(false)
|
||||
.try_init()
|
||||
.is_ok()
|
||||
});
|
||||
if let Ok(value) = result {
|
||||
value
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a new tunnel, return NULL on failure.
|
||||
/// Keys must be valid base64 encoded 32-byte keys.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn new_tunnel(
|
||||
static_private: *const c_char,
|
||||
server_static_public: *const c_char,
|
||||
preshared_key: *const c_char,
|
||||
keep_alive: u16,
|
||||
index: u32,
|
||||
) -> *mut Mutex<Tunn> {
|
||||
let c_str = CStr::from_ptr(static_private);
|
||||
let static_private = match c_str.to_str() {
|
||||
Err(_) => return ptr::null_mut(),
|
||||
Ok(string) => string,
|
||||
};
|
||||
|
||||
let c_str = CStr::from_ptr(server_static_public);
|
||||
let server_static_public = match c_str.to_str() {
|
||||
Err(_) => return ptr::null_mut(),
|
||||
Ok(string) => string,
|
||||
};
|
||||
|
||||
let preshared_key = if preshared_key.is_null() {
|
||||
None
|
||||
} else {
|
||||
let c_str = CStr::from_ptr(preshared_key);
|
||||
|
||||
if let Ok(string) = c_str.to_str() {
|
||||
if let Ok(key) = string.parse::<KeyBytes>() {
|
||||
Some(key.0)
|
||||
} else {
|
||||
return null_mut();
|
||||
}
|
||||
} else {
|
||||
return null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
let private_key = match static_private.parse::<KeyBytes>() {
|
||||
Err(_) => return ptr::null_mut(),
|
||||
Ok(key) => StaticSecret::from(key.0),
|
||||
};
|
||||
|
||||
let public_key = match server_static_public.parse::<KeyBytes>() {
|
||||
Err(_) => return ptr::null_mut(),
|
||||
Ok(key) => PublicKey::from(key.0),
|
||||
};
|
||||
|
||||
let keep_alive = if keep_alive == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(keep_alive)
|
||||
};
|
||||
|
||||
let tunnel = Box::new(Mutex::new(Tunn::new(
|
||||
private_key,
|
||||
public_key,
|
||||
preshared_key,
|
||||
keep_alive,
|
||||
index,
|
||||
None,
|
||||
)));
|
||||
|
||||
PANIC_HOOK.call_once(|| {
|
||||
// FFI won't properly unwind on panic, but it will if we cause a segmentation fault
|
||||
panic::set_hook(Box::new(move |_| {
|
||||
raise(SIGSEGV);
|
||||
}));
|
||||
});
|
||||
|
||||
Box::into_raw(tunnel)
|
||||
}
|
||||
|
||||
/// Drops the Tunn object
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn tunnel_free(tunnel: *mut Mutex<Tunn>) {
|
||||
drop(Box::from_raw(tunnel));
|
||||
}
|
||||
|
||||
/// Write an IP packet from the tunnel interface.
|
||||
/// For more details check noise::tunnel_to_network functions.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn wireguard_write(
|
||||
tunnel: *const Mutex<Tunn>,
|
||||
src: *const u8,
|
||||
src_size: u32,
|
||||
dst: *mut u8,
|
||||
dst_size: u32,
|
||||
) -> wireguard_result {
|
||||
let mut tunnel = tunnel.as_ref().unwrap().lock();
|
||||
// Slices are not owned, and therefore will not be freed by Rust
|
||||
let src = slice::from_raw_parts(src, src_size as usize);
|
||||
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
|
||||
wireguard_result::from(tunnel.encapsulate(src, dst))
|
||||
}
|
||||
|
||||
/// Read a UDP packet from the server.
|
||||
/// For more details check noise::network_to_tunnel functions.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn wireguard_read(
|
||||
tunnel: *const Mutex<Tunn>,
|
||||
src: *const u8,
|
||||
src_size: u32,
|
||||
dst: *mut u8,
|
||||
dst_size: u32,
|
||||
) -> wireguard_result {
|
||||
let mut tunnel = tunnel.as_ref().unwrap().lock();
|
||||
// Slices are not owned, and therefore will not be freed by Rust
|
||||
let src = slice::from_raw_parts(src, src_size as usize);
|
||||
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
|
||||
wireguard_result::from(tunnel.decapsulate(None, src, dst))
|
||||
}
|
||||
|
||||
/// This is a state keeping function, that need to be called periodically.
|
||||
/// Recommended interval: 100ms.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn wireguard_tick(
|
||||
tunnel: *const Mutex<Tunn>,
|
||||
dst: *mut u8,
|
||||
dst_size: u32,
|
||||
) -> wireguard_result {
|
||||
let mut tunnel = tunnel.as_ref().unwrap().lock();
|
||||
// Slices are not owned, and therefore will not be freed by Rust
|
||||
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
|
||||
wireguard_result::from(tunnel.update_timers(dst))
|
||||
}
|
||||
|
||||
/// Force the tunnel to initiate a new handshake, dst buffer must be at least 148 byte long.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn wireguard_force_handshake(
|
||||
tunnel: *const Mutex<Tunn>,
|
||||
dst: *mut u8,
|
||||
dst_size: u32,
|
||||
) -> wireguard_result {
|
||||
let mut tunnel = tunnel.as_ref().unwrap().lock();
|
||||
// Slices are not owned, and therefore will not be freed by Rust
|
||||
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
|
||||
wireguard_result::from(tunnel.format_handshake_initiation(dst, true))
|
||||
}
|
||||
|
||||
/// Returns stats from the tunnel:
|
||||
/// Time of last handshake in seconds (or -1 if no handshake occurred)
|
||||
/// Number of data bytes encapsulated
|
||||
/// Number of data bytes decapsulated
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn wireguard_stats(tunnel: *const Mutex<Tunn>) -> stats {
|
||||
let tunnel = tunnel.as_ref().unwrap().lock();
|
||||
let (time, tx_bytes, rx_bytes, estimated_loss, estimated_rtt) = tunnel.stats();
|
||||
stats {
|
||||
time_since_last_handshake: time.map(|t| t.as_secs() as i64).unwrap_or(-1),
|
||||
tx_bytes,
|
||||
rx_bytes,
|
||||
estimated_loss,
|
||||
estimated_rtt: estimated_rtt.map(|r| r as i32).unwrap_or(-1),
|
||||
reserved: [0u8; 56],
|
||||
}
|
||||
}
|
||||
271
lib/boringtun/src/jni.rs
Normal file
271
lib/boringtun/src/jni.rs
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// temporary, we need to do some verification around these bindings later
|
||||
#![allow(clippy::missing_safety_doc)]
|
||||
|
||||
/// JNI bindings for BoringTun library
|
||||
use std::os::raw::c_char;
|
||||
use std::ptr;
|
||||
|
||||
use jni::objects::{JByteBuffer, JClass, JString};
|
||||
use jni::strings::JNIStr;
|
||||
use jni::sys::{jbyteArray, jint, jlong, jshort, jstring};
|
||||
use jni::JNIEnv;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::ffi::new_tunnel;
|
||||
use crate::ffi::wireguard_read;
|
||||
use crate::ffi::wireguard_result;
|
||||
use crate::ffi::wireguard_tick;
|
||||
use crate::ffi::wireguard_write;
|
||||
use crate::ffi::x25519_key;
|
||||
use crate::ffi::x25519_key_to_base64;
|
||||
use crate::ffi::x25519_key_to_hex;
|
||||
use crate::ffi::x25519_public_key;
|
||||
use crate::ffi::x25519_secret_key;
|
||||
|
||||
use crate::noise::Tunn;
|
||||
|
||||
pub extern "C" fn log_print(_log_string: *const c_char) {
|
||||
/*
|
||||
XXX:
|
||||
Define callback function in app.
|
||||
*/
|
||||
}
|
||||
|
||||
/// Generates new x25519 secret key and converts into java byte array.
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1secret_1key"]
|
||||
pub extern "C" fn generate_secret_key(env: JNIEnv, _class: JClass) -> jbyteArray {
|
||||
match env.byte_array_from_slice(&x25519_secret_key().key) {
|
||||
Ok(v) => v,
|
||||
Err(_) => ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes public x25519 key from secret key and converts into java byte array.
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1public_1key"]
|
||||
pub unsafe extern "C" fn generate_public_key1(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
arg_secret_key: jbyteArray,
|
||||
) -> jbyteArray {
|
||||
let mut key_inner = [0; 32];
|
||||
|
||||
if env
|
||||
.get_byte_array_region(arg_secret_key, 0, &mut key_inner)
|
||||
.is_err()
|
||||
{
|
||||
return ptr::null_mut();
|
||||
}
|
||||
|
||||
let secret_key = x25519_key {
|
||||
key: std::mem::transmute::<[i8; 32], [u8; 32]>(key_inner),
|
||||
};
|
||||
|
||||
match env.byte_array_from_slice(&x25519_public_key(secret_key).key) {
|
||||
Ok(v) => v,
|
||||
Err(_) => ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts x25519 key to hex string.
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1key_1to_1hex"]
|
||||
pub unsafe extern "C" fn convert_x25519_key_to_hex(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
arg_key: jbyteArray,
|
||||
) -> jstring {
|
||||
let mut key = [0; 32];
|
||||
|
||||
if env.get_byte_array_region(arg_key, 0, &mut key).is_err() {
|
||||
return ptr::null_mut();
|
||||
}
|
||||
|
||||
let x25519_key = x25519_key {
|
||||
key: std::mem::transmute::<[i8; 32], [u8; 32]>(key),
|
||||
};
|
||||
|
||||
let output = match env.new_string(JNIStr::from_ptr(x25519_key_to_hex(x25519_key)).to_owned()) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return ptr::null_mut(),
|
||||
};
|
||||
|
||||
output.into_inner()
|
||||
}
|
||||
|
||||
/// Converts x25519 key to base64 string.
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1key_1to_1base64"]
|
||||
pub unsafe extern "C" fn convert_x25519_key_to_base64(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
arg_key: jbyteArray,
|
||||
) -> jstring {
|
||||
let mut key = [0; 32];
|
||||
|
||||
if env.get_byte_array_region(arg_key, 0, &mut key).is_err() {
|
||||
return ptr::null_mut();
|
||||
}
|
||||
|
||||
let x25519_key = x25519_key {
|
||||
key: std::mem::transmute::<[i8; 32], [u8; 32]>(key),
|
||||
};
|
||||
|
||||
let output = match env.new_string(JNIStr::from_ptr(x25519_key_to_base64(x25519_key)).to_owned())
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(_) => return ptr::null_mut(),
|
||||
};
|
||||
|
||||
output.into_inner()
|
||||
}
|
||||
|
||||
/// Creates new tunnel
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_new_1tunnel"]
|
||||
pub unsafe extern "C" fn create_new_tunnel(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
arg_secret_key: JString,
|
||||
arg_public_key: JString,
|
||||
arg_preshared_key: JString,
|
||||
keep_alive: jshort,
|
||||
index: jint,
|
||||
) -> jlong {
|
||||
let secret_key = match env.get_string_utf_chars(arg_secret_key) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let public_key = match env.get_string_utf_chars(arg_public_key) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let preshared_key = if arg_preshared_key.is_null() {
|
||||
ptr::null_mut()
|
||||
} else {
|
||||
match env.get_string_utf_chars(arg_preshared_key) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return 0,
|
||||
}
|
||||
};
|
||||
|
||||
let tunnel = new_tunnel(
|
||||
secret_key,
|
||||
public_key,
|
||||
preshared_key,
|
||||
keep_alive as u16,
|
||||
index as u32,
|
||||
);
|
||||
|
||||
if tunnel.is_null() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
tunnel as jlong
|
||||
}
|
||||
|
||||
/// Encrypts raw IP packets into WG formatted packets.
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1write"]
|
||||
pub unsafe extern "C" fn encrypt_raw_packet(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
tunnel: jlong,
|
||||
src: jbyteArray,
|
||||
src_size: jint,
|
||||
dst: JByteBuffer,
|
||||
dst_size: jint,
|
||||
op: JByteBuffer,
|
||||
) -> jint {
|
||||
let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) {
|
||||
Ok(v) => v.as_mut_ptr(),
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) {
|
||||
Ok(v) => v.as_mut_ptr(),
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let output: wireguard_result = wireguard_write(
|
||||
tunnel as *const Mutex<Tunn>,
|
||||
env.convert_byte_array(src).unwrap().as_mut_ptr(),
|
||||
src_size as u32,
|
||||
dst_ptr,
|
||||
dst_size as u32,
|
||||
);
|
||||
*op_ptr = output.op as u8;
|
||||
|
||||
output.size as i32
|
||||
}
|
||||
|
||||
/// Decrypts WG formatted packets into raw IP packets.
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1read"]
|
||||
pub unsafe extern "C" fn decrypt_to_raw_packet(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
tunnel: jlong,
|
||||
src: jbyteArray,
|
||||
src_size: jint,
|
||||
dst: JByteBuffer,
|
||||
dst_size: jint,
|
||||
op: JByteBuffer,
|
||||
) -> jint {
|
||||
let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) {
|
||||
Ok(v) => v.as_mut_ptr(),
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) {
|
||||
Ok(v) => v.as_mut_ptr(),
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let output: wireguard_result = wireguard_read(
|
||||
tunnel as *const Mutex<Tunn>,
|
||||
env.convert_byte_array(src).unwrap().as_mut_ptr(),
|
||||
src_size as u32,
|
||||
dst_ptr,
|
||||
dst_size as u32,
|
||||
);
|
||||
|
||||
*op_ptr = output.op as u8;
|
||||
|
||||
output.size as i32
|
||||
}
|
||||
|
||||
/// Periodic function that writes WG formatted packets into destination buffer
|
||||
#[no_mangle]
|
||||
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1tick"]
|
||||
pub unsafe extern "C" fn run_periodic_task(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
tunnel: jlong,
|
||||
dst: JByteBuffer,
|
||||
dst_size: jint,
|
||||
op: JByteBuffer,
|
||||
) -> jint {
|
||||
let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) {
|
||||
Ok(v) => v.as_mut_ptr(),
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) {
|
||||
Ok(v) => v.as_mut_ptr(),
|
||||
Err(_) => return 0,
|
||||
};
|
||||
|
||||
let output: wireguard_result =
|
||||
wireguard_tick(tunnel as *const Mutex<Tunn>, dst_ptr, dst_size as u32);
|
||||
|
||||
*op_ptr = output.op as u8;
|
||||
|
||||
output.size as i32
|
||||
}
|
||||
27
lib/boringtun/src/lib.rs
Normal file
27
lib/boringtun/src/lib.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//! Simple implementation of the client-side of the WireGuard protocol.
|
||||
//!
|
||||
//! <code>git clone https://github.com/cloudflare/boringtun.git</code>
|
||||
|
||||
#[cfg(feature = "device")]
|
||||
pub mod device;
|
||||
|
||||
#[cfg(feature = "ffi-bindings")]
|
||||
pub mod ffi;
|
||||
#[cfg(feature = "jni-bindings")]
|
||||
pub mod jni;
|
||||
pub mod noise;
|
||||
|
||||
#[cfg(not(feature = "mock-instant"))]
|
||||
pub(crate) mod sleepyinstant;
|
||||
|
||||
pub(crate) mod serialization;
|
||||
|
||||
/// Re-export of the x25519 types
|
||||
pub mod x25519 {
|
||||
pub use x25519_dalek::{
|
||||
EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret,
|
||||
};
|
||||
}
|
||||
23
lib/boringtun/src/noise/errors.rs
Normal file
23
lib/boringtun/src/noise/errors.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WireGuardError {
|
||||
DestinationBufferTooSmall,
|
||||
IncorrectPacketLength,
|
||||
UnexpectedPacket,
|
||||
WrongPacketType,
|
||||
WrongIndex,
|
||||
WrongKey,
|
||||
InvalidTai64nTimestamp,
|
||||
WrongTai64nTimestamp,
|
||||
InvalidMac,
|
||||
InvalidAeadTag,
|
||||
InvalidCounter,
|
||||
DuplicateCounter,
|
||||
InvalidPacket,
|
||||
NoCurrentSession,
|
||||
LockFailed,
|
||||
ConnectionExpired,
|
||||
UnderLoad,
|
||||
}
|
||||
940
lib/boringtun/src/noise/handshake.rs
Normal file
940
lib/boringtun/src/noise/handshake.rs
Normal file
@@ -0,0 +1,940 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::{HandshakeInit, HandshakeResponse, PacketCookieReply};
|
||||
use crate::noise::errors::WireGuardError;
|
||||
use crate::noise::session::Session;
|
||||
#[cfg(not(feature = "mock-instant"))]
|
||||
use crate::sleepyinstant::Instant;
|
||||
use crate::x25519;
|
||||
use aead::{Aead, Payload};
|
||||
use blake2::digest::{FixedOutput, KeyInit};
|
||||
use blake2::{Blake2s256, Blake2sMac, Digest};
|
||||
use chacha20poly1305::XChaCha20Poly1305;
|
||||
use rand_core::OsRng;
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
use std::convert::TryInto;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
#[cfg(feature = "mock-instant")]
|
||||
use mock_instant::Instant;
|
||||
|
||||
pub(crate) const LABEL_MAC1: &[u8; 8] = b"mac1----";
|
||||
pub(crate) const LABEL_COOKIE: &[u8; 8] = b"cookie--";
|
||||
const KEY_LEN: usize = 32;
|
||||
const TIMESTAMP_LEN: usize = 12;
|
||||
|
||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
||||
const INITIAL_CHAIN_KEY: [u8; KEY_LEN] = [
|
||||
96, 226, 109, 174, 243, 39, 239, 192, 46, 195, 53, 226, 160, 37, 210, 208, 22, 235, 66, 6, 248,
|
||||
114, 119, 245, 45, 56, 209, 152, 139, 120, 205, 54,
|
||||
];
|
||||
|
||||
// initiator.chaining_hash = HASH(initiator.chaining_key || IDENTIFIER)
|
||||
const INITIAL_CHAIN_HASH: [u8; KEY_LEN] = [
|
||||
34, 17, 179, 97, 8, 26, 197, 102, 105, 18, 67, 219, 69, 138, 213, 50, 45, 156, 108, 102, 34,
|
||||
147, 232, 183, 14, 225, 156, 101, 186, 7, 158, 243,
|
||||
];
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn b2s_hash(data1: &[u8], data2: &[u8]) -> [u8; 32] {
|
||||
let mut hash = Blake2s256::new();
|
||||
hash.update(data1);
|
||||
hash.update(data2);
|
||||
hash.finalize().into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// RFC 2401 HMAC+Blake2s, not to be confused with *keyed* Blake2s
|
||||
pub(crate) fn b2s_hmac(key: &[u8], data1: &[u8]) -> [u8; 32] {
|
||||
use blake2::digest::Update;
|
||||
type HmacBlake2s = hmac::SimpleHmac<Blake2s256>;
|
||||
let mut hmac = HmacBlake2s::new_from_slice(key).unwrap();
|
||||
hmac.update(data1);
|
||||
hmac.finalize_fixed().into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Like b2s_hmac, but chain data1 and data2 together
|
||||
pub(crate) fn b2s_hmac2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 32] {
|
||||
use blake2::digest::Update;
|
||||
type HmacBlake2s = hmac::SimpleHmac<Blake2s256>;
|
||||
let mut hmac = HmacBlake2s::new_from_slice(key).unwrap();
|
||||
hmac.update(data1);
|
||||
hmac.update(data2);
|
||||
hmac.finalize_fixed().into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn b2s_keyed_mac_16(key: &[u8], data1: &[u8]) -> [u8; 16] {
|
||||
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
|
||||
blake2::digest::Update::update(&mut hmac, data1);
|
||||
hmac.finalize_fixed().into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn b2s_keyed_mac_16_2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 16] {
|
||||
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
|
||||
blake2::digest::Update::update(&mut hmac, data1);
|
||||
blake2::digest::Update::update(&mut hmac, data2);
|
||||
hmac.finalize_fixed().into()
|
||||
}
|
||||
|
||||
pub(crate) fn b2s_mac_24(key: &[u8], data1: &[u8]) -> [u8; 24] {
|
||||
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
|
||||
blake2::digest::Update::update(&mut hmac, data1);
|
||||
hmac.finalize_fixed().into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// This wrapper involves an extra copy and MAY BE SLOWER
|
||||
fn aead_chacha20_seal(ciphertext: &mut [u8], key: &[u8], counter: u64, data: &[u8], aad: &[u8]) {
|
||||
let mut nonce: [u8; 12] = [0; 12];
|
||||
nonce[4..12].copy_from_slice(&counter.to_le_bytes());
|
||||
|
||||
aead_chacha20_seal_inner(ciphertext, key, nonce, data, aad)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn aead_chacha20_seal_inner(
|
||||
ciphertext: &mut [u8],
|
||||
key: &[u8],
|
||||
nonce: [u8; 12],
|
||||
data: &[u8],
|
||||
aad: &[u8],
|
||||
) {
|
||||
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap());
|
||||
|
||||
ciphertext[..data.len()].copy_from_slice(data);
|
||||
|
||||
let tag = key
|
||||
.seal_in_place_separate_tag(
|
||||
Nonce::assume_unique_for_key(nonce),
|
||||
Aad::from(aad),
|
||||
&mut ciphertext[..data.len()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
ciphertext[data.len()..].copy_from_slice(tag.as_ref());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// This wrapper involves an extra copy and MAY BE SLOWER
|
||||
fn aead_chacha20_open(
|
||||
buffer: &mut [u8],
|
||||
key: &[u8],
|
||||
counter: u64,
|
||||
data: &[u8],
|
||||
aad: &[u8],
|
||||
) -> Result<(), WireGuardError> {
|
||||
let mut nonce: [u8; 12] = [0; 12];
|
||||
nonce[4..].copy_from_slice(&counter.to_le_bytes());
|
||||
|
||||
aead_chacha20_open_inner(buffer, key, nonce, data, aad)
|
||||
.map_err(|_| WireGuardError::InvalidAeadTag)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn aead_chacha20_open_inner(
|
||||
buffer: &mut [u8],
|
||||
key: &[u8],
|
||||
nonce: [u8; 12],
|
||||
data: &[u8],
|
||||
aad: &[u8],
|
||||
) -> Result<(), ring::error::Unspecified> {
|
||||
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap());
|
||||
|
||||
let mut inner_buffer = data.to_owned();
|
||||
|
||||
let plaintext = key.open_in_place(
|
||||
Nonce::assume_unique_for_key(nonce),
|
||||
Aad::from(aad),
|
||||
&mut inner_buffer,
|
||||
)?;
|
||||
|
||||
buffer.copy_from_slice(plaintext);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// This struct represents a 12 byte [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp
|
||||
struct Tai64N {
|
||||
secs: u64,
|
||||
nano: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// This struct computes a [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp from current system time
|
||||
struct TimeStamper {
|
||||
duration_at_start: Duration,
|
||||
instant_at_start: Instant,
|
||||
}
|
||||
|
||||
impl TimeStamper {
|
||||
/// Create a new TimeStamper
|
||||
pub fn new() -> TimeStamper {
|
||||
TimeStamper {
|
||||
duration_at_start: SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap(),
|
||||
instant_at_start: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Take time reading and generate a 12 byte timestamp
|
||||
pub fn stamp(&self) -> [u8; 12] {
|
||||
const TAI64_BASE: u64 = (1u64 << 62) + 37;
|
||||
let mut ext_stamp = [0u8; 12];
|
||||
let stamp = Instant::now().duration_since(self.instant_at_start) + self.duration_at_start;
|
||||
ext_stamp[0..8].copy_from_slice(&(stamp.as_secs() + TAI64_BASE).to_be_bytes());
|
||||
ext_stamp[8..12].copy_from_slice(&stamp.subsec_nanos().to_be_bytes());
|
||||
ext_stamp
|
||||
}
|
||||
}
|
||||
|
||||
impl Tai64N {
|
||||
/// A zeroed out timestamp
|
||||
fn zero() -> Tai64N {
|
||||
Tai64N { secs: 0, nano: 0 }
|
||||
}
|
||||
|
||||
/// Parse a timestamp from a 12 byte u8 slice
|
||||
fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> {
|
||||
if buf.len() < 12 {
|
||||
return Err(WireGuardError::InvalidTai64nTimestamp);
|
||||
}
|
||||
|
||||
let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>());
|
||||
let secs = u64::from_be_bytes(sec_bytes.try_into().unwrap());
|
||||
let nano = u32::from_be_bytes(nano_bytes.try_into().unwrap());
|
||||
|
||||
// WireGuard does not actually expect tai64n timestamp, just monotonically increasing one
|
||||
//if secs < (1u64 << 62) || secs >= (1u64 << 63) {
|
||||
// return Err(WireGuardError::InvalidTai64nTimestamp);
|
||||
//};
|
||||
//if nano >= 1_000_000_000 {
|
||||
// return Err(WireGuardError::InvalidTai64nTimestamp);
|
||||
//}
|
||||
|
||||
Ok(Tai64N { secs, nano })
|
||||
}
|
||||
|
||||
/// Check if this timestamp represents a time that is chronologically after the time represented
|
||||
/// by the other timestamp
|
||||
pub fn after(&self, other: &Tai64N) -> bool {
|
||||
(self.secs > other.secs) || ((self.secs == other.secs) && (self.nano > other.nano))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters used by the noise protocol
|
||||
struct NoiseParams {
|
||||
/// Our static public key
|
||||
static_public: x25519::PublicKey,
|
||||
/// Our static private key
|
||||
static_private: x25519::StaticSecret,
|
||||
/// Static public key of the other party
|
||||
peer_static_public: x25519::PublicKey,
|
||||
/// A shared key = DH(static_private, peer_static_public)
|
||||
static_shared: x25519::SharedSecret,
|
||||
/// A pre-computation of HASH("mac1----", peer_static_public) for this peer
|
||||
sending_mac1_key: [u8; KEY_LEN],
|
||||
/// An optional preshared key
|
||||
preshared_key: Option<[u8; KEY_LEN]>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for NoiseParams {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("NoiseParams")
|
||||
.field("static_public", &self.static_public)
|
||||
.field("static_private", &"<redacted>")
|
||||
.field("peer_static_public", &self.peer_static_public)
|
||||
.field("static_shared", &"<redacted>")
|
||||
.field("sending_mac1_key", &self.sending_mac1_key)
|
||||
.field("preshared_key", &self.preshared_key)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
struct HandshakeInitSentState {
|
||||
local_index: u32,
|
||||
hash: [u8; KEY_LEN],
|
||||
chaining_key: [u8; KEY_LEN],
|
||||
ephemeral_private: x25519::ReusableSecret,
|
||||
time_sent: Instant,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for HandshakeInitSentState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("HandshakeInitSentState")
|
||||
.field("local_index", &self.local_index)
|
||||
.field("hash", &self.hash)
|
||||
.field("chaining_key", &self.chaining_key)
|
||||
.field("ephemeral_private", &"<redacted>")
|
||||
.field("time_sent", &self.time_sent)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum HandshakeState {
|
||||
/// No handshake in process
|
||||
None,
|
||||
/// We initiated the handshake
|
||||
InitSent(HandshakeInitSentState),
|
||||
/// Handshake initiated by peer
|
||||
InitReceived {
|
||||
hash: [u8; KEY_LEN],
|
||||
chaining_key: [u8; KEY_LEN],
|
||||
peer_ephemeral_public: x25519::PublicKey,
|
||||
peer_index: u32,
|
||||
},
|
||||
/// Handshake was established too long ago (implies no handshake is in progress)
|
||||
Expired,
|
||||
}
|
||||
|
||||
pub struct Handshake {
|
||||
params: NoiseParams,
|
||||
/// Index of the next session
|
||||
next_index: u32,
|
||||
/// Allow to have two outgoing handshakes in flight, because sometimes we may receive a delayed response to a handshake with bad networks
|
||||
previous: HandshakeState,
|
||||
/// Current handshake state
|
||||
state: HandshakeState,
|
||||
cookies: Cookies,
|
||||
/// The timestamp of the last handshake we received
|
||||
last_handshake_timestamp: Tai64N,
|
||||
// TODO: make TimeStamper a singleton
|
||||
stamper: TimeStamper,
|
||||
pub(super) last_rtt: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Cookies {
|
||||
last_mac1: Option<[u8; 16]>,
|
||||
index: u32,
|
||||
write_cookie: Option<[u8; 16]>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HalfHandshake {
|
||||
pub peer_index: u32,
|
||||
pub peer_static_public: [u8; 32],
|
||||
}
|
||||
|
||||
pub fn parse_handshake_anon(
|
||||
static_private: &x25519::StaticSecret,
|
||||
static_public: &x25519::PublicKey,
|
||||
packet: &HandshakeInit,
|
||||
) -> Result<HalfHandshake, WireGuardError> {
|
||||
let peer_index = packet.sender_idx;
|
||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
||||
let mut chaining_key = INITIAL_CHAIN_KEY;
|
||||
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
|
||||
let mut hash = INITIAL_CHAIN_HASH;
|
||||
hash = b2s_hash(&hash, static_public.as_bytes());
|
||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
||||
let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
|
||||
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
|
||||
hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes());
|
||||
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(
|
||||
&b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()),
|
||||
&[0x01],
|
||||
);
|
||||
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
|
||||
let ephemeral_shared = static_private.diffie_hellman(&peer_ephemeral_public);
|
||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
||||
|
||||
let mut peer_static_public = [0u8; KEY_LEN];
|
||||
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
|
||||
aead_chacha20_open(
|
||||
&mut peer_static_public,
|
||||
&key,
|
||||
0,
|
||||
packet.encrypted_static,
|
||||
&hash,
|
||||
)?;
|
||||
|
||||
Ok(HalfHandshake {
|
||||
peer_index,
|
||||
peer_static_public,
|
||||
})
|
||||
}
|
||||
|
||||
impl NoiseParams {
|
||||
/// New noise params struct from our secret key, peers public key, and optional preshared key
|
||||
fn new(
|
||||
static_private: x25519::StaticSecret,
|
||||
static_public: x25519::PublicKey,
|
||||
peer_static_public: x25519::PublicKey,
|
||||
preshared_key: Option<[u8; 32]>,
|
||||
) -> NoiseParams {
|
||||
let static_shared = static_private.diffie_hellman(&peer_static_public);
|
||||
|
||||
let initial_sending_mac_key = b2s_hash(LABEL_MAC1, peer_static_public.as_bytes());
|
||||
|
||||
NoiseParams {
|
||||
static_public,
|
||||
static_private,
|
||||
peer_static_public,
|
||||
static_shared,
|
||||
sending_mac1_key: initial_sending_mac_key,
|
||||
preshared_key,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a new private key
|
||||
fn set_static_private(
|
||||
&mut self,
|
||||
static_private: x25519::StaticSecret,
|
||||
static_public: x25519::PublicKey,
|
||||
) {
|
||||
// Check that the public key indeed matches the private key
|
||||
let check_key = x25519::PublicKey::from(&static_private);
|
||||
assert_eq!(check_key.as_bytes(), static_public.as_bytes());
|
||||
|
||||
self.static_private = static_private;
|
||||
self.static_public = static_public;
|
||||
|
||||
self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public);
|
||||
}
|
||||
}
|
||||
|
||||
impl Handshake {
|
||||
pub(crate) fn new(
|
||||
static_private: x25519::StaticSecret,
|
||||
static_public: x25519::PublicKey,
|
||||
peer_static_public: x25519::PublicKey,
|
||||
global_idx: u32,
|
||||
preshared_key: Option<[u8; 32]>,
|
||||
) -> Handshake {
|
||||
let params = NoiseParams::new(
|
||||
static_private,
|
||||
static_public,
|
||||
peer_static_public,
|
||||
preshared_key,
|
||||
);
|
||||
|
||||
Handshake {
|
||||
params,
|
||||
next_index: global_idx,
|
||||
previous: HandshakeState::None,
|
||||
state: HandshakeState::None,
|
||||
last_handshake_timestamp: Tai64N::zero(),
|
||||
stamper: TimeStamper::new(),
|
||||
cookies: Default::default(),
|
||||
last_rtt: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_in_progress(&self) -> bool {
|
||||
!matches!(self.state, HandshakeState::None | HandshakeState::Expired)
|
||||
}
|
||||
|
||||
pub(crate) fn timer(&self) -> Option<Instant> {
|
||||
match self.state {
|
||||
HandshakeState::InitSent(HandshakeInitSentState { time_sent, .. }) => Some(time_sent),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_expired(&mut self) {
|
||||
self.previous = HandshakeState::Expired;
|
||||
self.state = HandshakeState::Expired;
|
||||
}
|
||||
|
||||
pub(crate) fn is_expired(&self) -> bool {
|
||||
matches!(self.state, HandshakeState::Expired)
|
||||
}
|
||||
|
||||
pub(crate) fn has_cookie(&self) -> bool {
|
||||
self.cookies.write_cookie.is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn clear_cookie(&mut self) {
|
||||
self.cookies.write_cookie = None;
|
||||
}
|
||||
|
||||
// The index used is 24 bits for peer index, allowing for 16M active peers per server and 8 bits for cyclic session index
|
||||
fn inc_index(&mut self) -> u32 {
|
||||
let index = self.next_index;
|
||||
let idx8 = index as u8;
|
||||
self.next_index = (index & !0xff) | u32::from(idx8.wrapping_add(1));
|
||||
self.next_index
|
||||
}
|
||||
|
||||
pub(crate) fn set_static_private(
|
||||
&mut self,
|
||||
private_key: x25519::StaticSecret,
|
||||
public_key: x25519::PublicKey,
|
||||
) {
|
||||
self.params.set_static_private(private_key, public_key)
|
||||
}
|
||||
|
||||
pub(super) fn receive_handshake_initialization<'a>(
|
||||
&mut self,
|
||||
packet: HandshakeInit,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<(&'a mut [u8], Session), WireGuardError> {
|
||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
||||
let mut chaining_key = INITIAL_CHAIN_KEY;
|
||||
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
|
||||
let mut hash = INITIAL_CHAIN_HASH;
|
||||
hash = b2s_hash(&hash, self.params.static_public.as_bytes());
|
||||
// msg.sender_index = little_endian(initiator.sender_index)
|
||||
let peer_index = packet.sender_idx;
|
||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
||||
let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
|
||||
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
|
||||
hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes());
|
||||
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(
|
||||
&b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()),
|
||||
&[0x01],
|
||||
);
|
||||
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
|
||||
let ephemeral_shared = self
|
||||
.params
|
||||
.static_private
|
||||
.diffie_hellman(&peer_ephemeral_public);
|
||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
||||
|
||||
let mut peer_static_public_decrypted = [0u8; KEY_LEN];
|
||||
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
|
||||
aead_chacha20_open(
|
||||
&mut peer_static_public_decrypted,
|
||||
&key,
|
||||
0,
|
||||
packet.encrypted_static,
|
||||
&hash,
|
||||
)?;
|
||||
|
||||
ring::constant_time::verify_slices_are_equal(
|
||||
self.params.peer_static_public.as_bytes(),
|
||||
&peer_static_public_decrypted,
|
||||
)
|
||||
.map_err(|_| WireGuardError::WrongKey)?;
|
||||
|
||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_static)
|
||||
hash = b2s_hash(&hash, packet.encrypted_static);
|
||||
// temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public))
|
||||
let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes());
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
||||
// msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash)
|
||||
let mut timestamp = [0u8; TIMESTAMP_LEN];
|
||||
aead_chacha20_open(&mut timestamp, &key, 0, packet.encrypted_timestamp, &hash)?;
|
||||
|
||||
let timestamp = Tai64N::parse(×tamp)?;
|
||||
if !timestamp.after(&self.last_handshake_timestamp) {
|
||||
// Possibly a replay
|
||||
return Err(WireGuardError::WrongTai64nTimestamp);
|
||||
}
|
||||
self.last_handshake_timestamp = timestamp;
|
||||
|
||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
|
||||
hash = b2s_hash(&hash, packet.encrypted_timestamp);
|
||||
|
||||
self.previous = std::mem::replace(
|
||||
&mut self.state,
|
||||
HandshakeState::InitReceived {
|
||||
chaining_key,
|
||||
hash,
|
||||
peer_ephemeral_public,
|
||||
peer_index,
|
||||
},
|
||||
);
|
||||
|
||||
self.format_handshake_response(dst)
|
||||
}
|
||||
|
||||
pub(super) fn receive_handshake_response(
|
||||
&mut self,
|
||||
packet: HandshakeResponse,
|
||||
) -> Result<Session, WireGuardError> {
|
||||
// Check if there is a handshake awaiting a response and return the correct one
|
||||
let (state, is_previous) = match (&self.state, &self.previous) {
|
||||
(HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx => (s, false),
|
||||
(_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx => (s, true),
|
||||
_ => return Err(WireGuardError::UnexpectedPacket),
|
||||
};
|
||||
|
||||
let peer_index = packet.sender_idx;
|
||||
let local_index = state.local_index;
|
||||
|
||||
let unencrypted_ephemeral = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
|
||||
// msg.unencrypted_ephemeral = DH_PUBKEY(responder.ephemeral_private)
|
||||
// responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral)
|
||||
let mut hash = b2s_hash(&state.hash, unencrypted_ephemeral.as_bytes());
|
||||
// temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral)
|
||||
let temp = b2s_hmac(&state.chaining_key, unencrypted_ephemeral.as_bytes());
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
let mut chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public))
|
||||
let ephemeral_shared = state
|
||||
.ephemeral_private
|
||||
.diffie_hellman(&unencrypted_ephemeral);
|
||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public))
|
||||
let temp = b2s_hmac(
|
||||
&chaining_key,
|
||||
&self
|
||||
.params
|
||||
.static_private
|
||||
.diffie_hellman(&unencrypted_ephemeral)
|
||||
.to_bytes(),
|
||||
);
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp = HMAC(responder.chaining_key, preshared_key)
|
||||
let temp = b2s_hmac(
|
||||
&chaining_key,
|
||||
&self.params.preshared_key.unwrap_or([0u8; 32])[..],
|
||||
);
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp2 = HMAC(temp, responder.chaining_key || 0x2)
|
||||
let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
||||
// key = HMAC(temp, temp2 || 0x3)
|
||||
let key = b2s_hmac2(&temp, &temp2, &[0x03]);
|
||||
// responder.hash = HASH(responder.hash || temp2)
|
||||
hash = b2s_hash(&hash, &temp2);
|
||||
// msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash)
|
||||
aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing, &hash)?;
|
||||
|
||||
// responder.hash = HASH(responder.hash || msg.encrypted_nothing)
|
||||
// hash = b2s_hash(hash, buf[ENC_NOTHING_OFF..ENC_NOTHING_OFF + ENC_NOTHING_SZ]);
|
||||
|
||||
// Derive keys
|
||||
// temp1 = HMAC(initiator.chaining_key, [empty])
|
||||
// temp2 = HMAC(temp1, 0x1)
|
||||
// temp3 = HMAC(temp1, temp2 || 0x2)
|
||||
// initiator.sending_key = temp2
|
||||
// initiator.receiving_key = temp3
|
||||
// initiator.sending_key_counter = 0
|
||||
// initiator.receiving_key_counter = 0
|
||||
let temp1 = b2s_hmac(&chaining_key, &[]);
|
||||
let temp2 = b2s_hmac(&temp1, &[0x01]);
|
||||
let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]);
|
||||
|
||||
let rtt_time = Instant::now().duration_since(state.time_sent);
|
||||
self.last_rtt = Some(rtt_time.as_millis() as u32);
|
||||
|
||||
if is_previous {
|
||||
self.previous = HandshakeState::None;
|
||||
} else {
|
||||
self.state = HandshakeState::None;
|
||||
}
|
||||
Ok(Session::new(local_index, peer_index, temp3, temp2))
|
||||
}
|
||||
|
||||
pub(super) fn receive_cookie_reply(
|
||||
&mut self,
|
||||
packet: PacketCookieReply,
|
||||
) -> Result<(), WireGuardError> {
|
||||
let mac1 = match self.cookies.last_mac1 {
|
||||
Some(mac) => mac,
|
||||
None => {
|
||||
return Err(WireGuardError::UnexpectedPacket);
|
||||
}
|
||||
};
|
||||
|
||||
let local_index = self.cookies.index;
|
||||
if packet.receiver_idx != local_index {
|
||||
return Err(WireGuardError::WrongIndex);
|
||||
}
|
||||
// msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), msg.nonce, cookie, last_received_msg.mac1)
|
||||
let key = b2s_hash(LABEL_COOKIE, self.params.peer_static_public.as_bytes()); // TODO: pre-compute
|
||||
|
||||
let payload = Payload {
|
||||
aad: &mac1[0..16],
|
||||
msg: packet.encrypted_cookie,
|
||||
};
|
||||
let plaintext = XChaCha20Poly1305::new_from_slice(&key)
|
||||
.unwrap()
|
||||
.decrypt(packet.nonce.into(), payload)
|
||||
.map_err(|_| WireGuardError::InvalidAeadTag)?;
|
||||
|
||||
let cookie = plaintext
|
||||
.try_into()
|
||||
.map_err(|_| WireGuardError::InvalidPacket)?;
|
||||
self.cookies.write_cookie = Some(cookie);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Compute and append mac1 and mac2 to a handshake message
|
||||
fn append_mac1_and_mac2<'a>(
|
||||
&mut self,
|
||||
local_index: u32,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<&'a mut [u8], WireGuardError> {
|
||||
let mac1_off = dst.len() - 32;
|
||||
let mac2_off = dst.len() - 16;
|
||||
|
||||
// msg.mac1 = MAC(HASH(LABEL_MAC1 || responder.static_public), msg[0:offsetof(msg.mac1)])
|
||||
let msg_mac1 = b2s_keyed_mac_16(&self.params.sending_mac1_key, &dst[..mac1_off]);
|
||||
|
||||
dst[mac1_off..mac2_off].copy_from_slice(&msg_mac1[..]);
|
||||
|
||||
//msg.mac2 = MAC(initiator.last_received_cookie, msg[0:offsetof(msg.mac2)])
|
||||
let msg_mac2: [u8; 16] = if let Some(cookie) = self.cookies.write_cookie {
|
||||
b2s_keyed_mac_16(&cookie, &dst[..mac2_off])
|
||||
} else {
|
||||
[0u8; 16]
|
||||
};
|
||||
|
||||
dst[mac2_off..].copy_from_slice(&msg_mac2[..]);
|
||||
|
||||
self.cookies.index = local_index;
|
||||
self.cookies.last_mac1 = Some(msg_mac1);
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
pub(super) fn format_handshake_initiation<'a>(
|
||||
&mut self,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<&'a mut [u8], WireGuardError> {
|
||||
if dst.len() < super::HANDSHAKE_INIT_SZ {
|
||||
return Err(WireGuardError::DestinationBufferTooSmall);
|
||||
}
|
||||
|
||||
let (message_type, rest) = dst.split_at_mut(4);
|
||||
let (sender_index, rest) = rest.split_at_mut(4);
|
||||
let (unencrypted_ephemeral, rest) = rest.split_at_mut(32);
|
||||
let (encrypted_static, rest) = rest.split_at_mut(32 + 16);
|
||||
let (encrypted_timestamp, _) = rest.split_at_mut(12 + 16);
|
||||
|
||||
let local_index = self.inc_index();
|
||||
|
||||
// initiator.chaining_key = HASH(CONSTRUCTION)
|
||||
let mut chaining_key = INITIAL_CHAIN_KEY;
|
||||
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
|
||||
let mut hash = INITIAL_CHAIN_HASH;
|
||||
hash = b2s_hash(&hash, self.params.peer_static_public.as_bytes());
|
||||
// initiator.ephemeral_private = DH_GENERATE()
|
||||
let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng);
|
||||
// msg.message_type = 1
|
||||
// msg.reserved_zero = { 0, 0, 0 }
|
||||
message_type.copy_from_slice(&super::HANDSHAKE_INIT.to_le_bytes());
|
||||
// msg.sender_index = little_endian(initiator.sender_index)
|
||||
sender_index.copy_from_slice(&local_index.to_le_bytes());
|
||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
||||
unencrypted_ephemeral
|
||||
.copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes());
|
||||
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
|
||||
hash = b2s_hash(&hash, unencrypted_ephemeral);
|
||||
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&b2s_hmac(&chaining_key, unencrypted_ephemeral), &[0x01]);
|
||||
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
|
||||
let ephemeral_shared = ephemeral_private.diffie_hellman(&self.params.peer_static_public);
|
||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
||||
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
|
||||
aead_chacha20_seal(
|
||||
encrypted_static,
|
||||
&key,
|
||||
0,
|
||||
self.params.static_public.as_bytes(),
|
||||
&hash,
|
||||
);
|
||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_static)
|
||||
hash = b2s_hash(&hash, encrypted_static);
|
||||
// temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public))
|
||||
let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes());
|
||||
// initiator.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// key = HMAC(temp, initiator.chaining_key || 0x2)
|
||||
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
||||
// msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash)
|
||||
let timestamp = self.stamper.stamp();
|
||||
aead_chacha20_seal(encrypted_timestamp, &key, 0, ×tamp, &hash);
|
||||
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
|
||||
hash = b2s_hash(&hash, encrypted_timestamp);
|
||||
|
||||
let time_now = Instant::now();
|
||||
self.previous = std::mem::replace(
|
||||
&mut self.state,
|
||||
HandshakeState::InitSent(HandshakeInitSentState {
|
||||
local_index,
|
||||
chaining_key,
|
||||
hash,
|
||||
ephemeral_private,
|
||||
time_sent: time_now,
|
||||
}),
|
||||
);
|
||||
|
||||
self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_INIT_SZ])
|
||||
}
|
||||
|
||||
fn format_handshake_response<'a>(
|
||||
&mut self,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<(&'a mut [u8], Session), WireGuardError> {
|
||||
if dst.len() < super::HANDSHAKE_RESP_SZ {
|
||||
return Err(WireGuardError::DestinationBufferTooSmall);
|
||||
}
|
||||
|
||||
let state = std::mem::replace(&mut self.state, HandshakeState::None);
|
||||
let (mut chaining_key, mut hash, peer_ephemeral_public, peer_index) = match state {
|
||||
HandshakeState::InitReceived {
|
||||
chaining_key,
|
||||
hash,
|
||||
peer_ephemeral_public,
|
||||
peer_index,
|
||||
} => (chaining_key, hash, peer_ephemeral_public, peer_index),
|
||||
_ => {
|
||||
panic!("Unexpected attempt to call send_handshake_response");
|
||||
}
|
||||
};
|
||||
|
||||
let (message_type, rest) = dst.split_at_mut(4);
|
||||
let (sender_index, rest) = rest.split_at_mut(4);
|
||||
let (receiver_index, rest) = rest.split_at_mut(4);
|
||||
let (unencrypted_ephemeral, rest) = rest.split_at_mut(32);
|
||||
let (encrypted_nothing, _) = rest.split_at_mut(16);
|
||||
|
||||
// responder.ephemeral_private = DH_GENERATE()
|
||||
let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng);
|
||||
let local_index = self.inc_index();
|
||||
// msg.message_type = 2
|
||||
// msg.reserved_zero = { 0, 0, 0 }
|
||||
message_type.copy_from_slice(&super::HANDSHAKE_RESP.to_le_bytes());
|
||||
// msg.sender_index = little_endian(responder.sender_index)
|
||||
sender_index.copy_from_slice(&local_index.to_le_bytes());
|
||||
// msg.receiver_index = little_endian(initiator.sender_index)
|
||||
receiver_index.copy_from_slice(&peer_index.to_le_bytes());
|
||||
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
|
||||
unencrypted_ephemeral
|
||||
.copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes());
|
||||
// responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral)
|
||||
hash = b2s_hash(&hash, unencrypted_ephemeral);
|
||||
// temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral)
|
||||
let temp = b2s_hmac(&chaining_key, unencrypted_ephemeral);
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public))
|
||||
let ephemeral_shared = ephemeral_private.diffie_hellman(&peer_ephemeral_public);
|
||||
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public))
|
||||
let temp = b2s_hmac(
|
||||
&chaining_key,
|
||||
&ephemeral_private
|
||||
.diffie_hellman(&self.params.peer_static_public)
|
||||
.to_bytes(),
|
||||
);
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp = HMAC(responder.chaining_key, preshared_key)
|
||||
let temp = b2s_hmac(
|
||||
&chaining_key,
|
||||
&self.params.preshared_key.unwrap_or([0u8; 32])[..],
|
||||
);
|
||||
// responder.chaining_key = HMAC(temp, 0x1)
|
||||
chaining_key = b2s_hmac(&temp, &[0x01]);
|
||||
// temp2 = HMAC(temp, responder.chaining_key || 0x2)
|
||||
let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]);
|
||||
// key = HMAC(temp, temp2 || 0x3)
|
||||
let key = b2s_hmac2(&temp, &temp2, &[0x03]);
|
||||
// responder.hash = HASH(responder.hash || temp2)
|
||||
hash = b2s_hash(&hash, &temp2);
|
||||
// msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash)
|
||||
aead_chacha20_seal(encrypted_nothing, &key, 0, &[], &hash);
|
||||
|
||||
// Derive keys
|
||||
// temp1 = HMAC(initiator.chaining_key, [empty])
|
||||
// temp2 = HMAC(temp1, 0x1)
|
||||
// temp3 = HMAC(temp1, temp2 || 0x2)
|
||||
// initiator.sending_key = temp2
|
||||
// initiator.receiving_key = temp3
|
||||
// initiator.sending_key_counter = 0
|
||||
// initiator.receiving_key_counter = 0
|
||||
let temp1 = b2s_hmac(&chaining_key, &[]);
|
||||
let temp2 = b2s_hmac(&temp1, &[0x01]);
|
||||
let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]);
|
||||
|
||||
let dst = self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_RESP_SZ])?;
|
||||
|
||||
Ok((dst, Session::new(local_index, peer_index, temp2, temp3)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn chacha20_seal_rfc7530_test_vector() {
|
||||
let plaintext = b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it.";
|
||||
let aad: [u8; 12] = [
|
||||
0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
|
||||
];
|
||||
let key: [u8; 32] = [
|
||||
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d,
|
||||
0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
|
||||
0x9c, 0x9d, 0x9e, 0x9f,
|
||||
];
|
||||
let nonce: [u8; 12] = [
|
||||
0x07, 0x00, 0x00, 0x00, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
|
||||
];
|
||||
let mut buffer = vec![0; plaintext.len() + 16];
|
||||
|
||||
aead_chacha20_seal_inner(&mut buffer, &key, nonce, plaintext, &aad);
|
||||
|
||||
const EXPECTED_CIPHERTEXT: [u8; 114] = [
|
||||
0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, 0x7b, 0x86, 0xaf, 0xbc, 0x53, 0xef,
|
||||
0x7e, 0xc2, 0xa4, 0xad, 0xed, 0x51, 0x29, 0x6e, 0x08, 0xfe, 0xa9, 0xe2, 0xb5, 0xa7,
|
||||
0x36, 0xee, 0x62, 0xd6, 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12, 0x82, 0xfa,
|
||||
0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, 0x1a, 0x71, 0xde, 0x0a, 0x9e, 0x06, 0x0b, 0x29,
|
||||
0x05, 0xd6, 0xa5, 0xb6, 0x7e, 0xcd, 0x3b, 0x36, 0x92, 0xdd, 0xbd, 0x7f, 0x2d, 0x77,
|
||||
0x8b, 0x8c, 0x98, 0x03, 0xae, 0xe3, 0x28, 0x09, 0x1b, 0x58, 0xfa, 0xb3, 0x24, 0xe4,
|
||||
0xfa, 0xd6, 0x75, 0x94, 0x55, 0x85, 0x80, 0x8b, 0x48, 0x31, 0xd7, 0xbc, 0x3f, 0xf4,
|
||||
0xde, 0xf0, 0x8e, 0x4b, 0x7a, 0x9d, 0xe5, 0x76, 0xd2, 0x65, 0x86, 0xce, 0xc6, 0x4b,
|
||||
0x61, 0x16,
|
||||
];
|
||||
const EXPECTED_TAG: [u8; 16] = [
|
||||
0x1a, 0xe1, 0x0b, 0x59, 0x4f, 0x09, 0xe2, 0x6a, 0x7e, 0x90, 0x2e, 0xcb, 0xd0, 0x60,
|
||||
0x06, 0x91,
|
||||
];
|
||||
|
||||
assert_eq!(buffer[..plaintext.len()], EXPECTED_CIPHERTEXT);
|
||||
assert_eq!(buffer[plaintext.len()..], EXPECTED_TAG);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn symmetric_chacha20_seal_open() {
|
||||
let aad: [u8; 32] = Default::default();
|
||||
let key: [u8; 32] = Default::default();
|
||||
let counter = 0;
|
||||
|
||||
let mut encrypted_nothing: [u8; 16] = Default::default();
|
||||
|
||||
aead_chacha20_seal(&mut encrypted_nothing, &key, counter, &[], &aad);
|
||||
|
||||
eprintln!("encrypted_nothing: {:?}", encrypted_nothing);
|
||||
|
||||
aead_chacha20_open(&mut [], &key, counter, &encrypted_nothing, &aad)
|
||||
.expect("Should open what we just sealed");
|
||||
}
|
||||
}
|
||||
794
lib/boringtun/src/noise/mod.rs
Normal file
794
lib/boringtun/src/noise/mod.rs
Normal file
@@ -0,0 +1,794 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
pub mod errors;
|
||||
pub mod handshake;
|
||||
pub mod rate_limiter;
|
||||
|
||||
mod session;
|
||||
mod timers;
|
||||
|
||||
use crate::noise::errors::WireGuardError;
|
||||
use crate::noise::handshake::Handshake;
|
||||
use crate::noise::rate_limiter::RateLimiter;
|
||||
use crate::noise::timers::{TimerName, Timers};
|
||||
use crate::x25519;
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// The default value to use for rate limiting, when no other rate limiter is defined
|
||||
const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10;
|
||||
|
||||
const IPV4_MIN_HEADER_SIZE: usize = 20;
|
||||
const IPV4_LEN_OFF: usize = 2;
|
||||
const IPV4_SRC_IP_OFF: usize = 12;
|
||||
const IPV4_DST_IP_OFF: usize = 16;
|
||||
const IPV4_IP_SZ: usize = 4;
|
||||
|
||||
const IPV6_MIN_HEADER_SIZE: usize = 40;
|
||||
const IPV6_LEN_OFF: usize = 4;
|
||||
const IPV6_SRC_IP_OFF: usize = 8;
|
||||
const IPV6_DST_IP_OFF: usize = 24;
|
||||
const IPV6_IP_SZ: usize = 16;
|
||||
|
||||
const IP_LEN_SZ: usize = 2;
|
||||
|
||||
const MAX_QUEUE_DEPTH: usize = 256;
|
||||
/// number of sessions in the ring, better keep a PoT
|
||||
const N_SESSIONS: usize = 8;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TunnResult<'a> {
|
||||
Done,
|
||||
Err(WireGuardError),
|
||||
WriteToNetwork(&'a mut [u8]),
|
||||
WriteToTunnelV4(&'a mut [u8], Ipv4Addr),
|
||||
WriteToTunnelV6(&'a mut [u8], Ipv6Addr),
|
||||
}
|
||||
|
||||
impl<'a> From<WireGuardError> for TunnResult<'a> {
|
||||
fn from(err: WireGuardError) -> TunnResult<'a> {
|
||||
TunnResult::Err(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tunnel represents a point-to-point WireGuard connection
|
||||
pub struct Tunn {
|
||||
/// The handshake currently in progress
|
||||
handshake: handshake::Handshake,
|
||||
/// The N_SESSIONS most recent sessions, index is session id modulo N_SESSIONS
|
||||
sessions: [Option<session::Session>; N_SESSIONS],
|
||||
/// Index of most recently used session
|
||||
current: usize,
|
||||
/// Queue to store blocked packets
|
||||
packet_queue: VecDeque<Vec<u8>>,
|
||||
/// Keeps tabs on the expiring timers
|
||||
timers: timers::Timers,
|
||||
tx_bytes: usize,
|
||||
rx_bytes: usize,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
}
|
||||
|
||||
type MessageType = u32;
|
||||
const HANDSHAKE_INIT: MessageType = 1;
|
||||
const HANDSHAKE_RESP: MessageType = 2;
|
||||
const COOKIE_REPLY: MessageType = 3;
|
||||
const DATA: MessageType = 4;
|
||||
|
||||
const HANDSHAKE_INIT_SZ: usize = 148;
|
||||
const HANDSHAKE_RESP_SZ: usize = 92;
|
||||
const COOKIE_REPLY_SZ: usize = 64;
|
||||
const DATA_OVERHEAD_SZ: usize = 32;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HandshakeInit<'a> {
|
||||
sender_idx: u32,
|
||||
pub unencrypted_ephemeral: &'a [u8; 32],
|
||||
encrypted_static: &'a [u8],
|
||||
encrypted_timestamp: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HandshakeResponse<'a> {
|
||||
sender_idx: u32,
|
||||
pub receiver_idx: u32,
|
||||
pub unencrypted_ephemeral: &'a [u8; 32],
|
||||
encrypted_nothing: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PacketCookieReply<'a> {
|
||||
pub receiver_idx: u32,
|
||||
nonce: &'a [u8],
|
||||
encrypted_cookie: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PacketData<'a> {
|
||||
pub receiver_idx: u32,
|
||||
counter: u64,
|
||||
encrypted_encapsulated_packet: &'a [u8],
|
||||
}
|
||||
|
||||
/// Describes a packet from network
|
||||
#[derive(Debug)]
|
||||
pub enum Packet<'a> {
|
||||
HandshakeInit(HandshakeInit<'a>),
|
||||
HandshakeResponse(HandshakeResponse<'a>),
|
||||
PacketCookieReply(PacketCookieReply<'a>),
|
||||
PacketData(PacketData<'a>),
|
||||
}
|
||||
|
||||
impl Tunn {
|
||||
#[inline(always)]
|
||||
pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> {
|
||||
if src.len() < 4 {
|
||||
return Err(WireGuardError::InvalidPacket);
|
||||
}
|
||||
|
||||
// Checks the type, as well as the reserved zero fields
|
||||
let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap());
|
||||
|
||||
Ok(match (packet_type, src.len()) {
|
||||
(HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit {
|
||||
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
||||
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40])
|
||||
.expect("length already checked above"),
|
||||
encrypted_static: &src[40..88],
|
||||
encrypted_timestamp: &src[88..116],
|
||||
}),
|
||||
(HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse {
|
||||
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
||||
receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()),
|
||||
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44])
|
||||
.expect("length already checked above"),
|
||||
encrypted_nothing: &src[44..60],
|
||||
}),
|
||||
(COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply {
|
||||
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
||||
nonce: &src[8..32],
|
||||
encrypted_cookie: &src[32..64],
|
||||
}),
|
||||
(DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData {
|
||||
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
|
||||
counter: u64::from_le_bytes(src[8..16].try_into().unwrap()),
|
||||
encrypted_encapsulated_packet: &src[16..],
|
||||
}),
|
||||
_ => return Err(WireGuardError::InvalidPacket),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_expired(&self) -> bool {
|
||||
self.handshake.is_expired()
|
||||
}
|
||||
|
||||
pub fn dst_address(packet: &[u8]) -> Option<IpAddr> {
|
||||
if packet.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
match packet[0] >> 4 {
|
||||
4 if packet.len() >= IPV4_MIN_HEADER_SIZE => {
|
||||
let addr_bytes: [u8; IPV4_IP_SZ] = packet
|
||||
[IPV4_DST_IP_OFF..IPV4_DST_IP_OFF + IPV4_IP_SZ]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
Some(IpAddr::from(addr_bytes))
|
||||
}
|
||||
6 if packet.len() >= IPV6_MIN_HEADER_SIZE => {
|
||||
let addr_bytes: [u8; IPV6_IP_SZ] = packet
|
||||
[IPV6_DST_IP_OFF..IPV6_DST_IP_OFF + IPV6_IP_SZ]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
Some(IpAddr::from(addr_bytes))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new tunnel using own private key and the peer public key
|
||||
pub fn new(
|
||||
static_private: x25519::StaticSecret,
|
||||
peer_static_public: x25519::PublicKey,
|
||||
preshared_key: Option<[u8; 32]>,
|
||||
persistent_keepalive: Option<u16>,
|
||||
index: u32,
|
||||
rate_limiter: Option<Arc<RateLimiter>>,
|
||||
) -> Self {
|
||||
let static_public = x25519::PublicKey::from(&static_private);
|
||||
|
||||
Tunn {
|
||||
handshake: Handshake::new(
|
||||
static_private,
|
||||
static_public,
|
||||
peer_static_public,
|
||||
index << 8,
|
||||
preshared_key,
|
||||
),
|
||||
sessions: Default::default(),
|
||||
current: Default::default(),
|
||||
tx_bytes: Default::default(),
|
||||
rx_bytes: Default::default(),
|
||||
|
||||
packet_queue: VecDeque::new(),
|
||||
timers: Timers::new(persistent_keepalive, rate_limiter.is_none()),
|
||||
|
||||
rate_limiter: rate_limiter.unwrap_or_else(|| {
|
||||
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the private key and clear existing sessions
|
||||
pub fn set_static_private(
|
||||
&mut self,
|
||||
static_private: x25519::StaticSecret,
|
||||
static_public: x25519::PublicKey,
|
||||
rate_limiter: Option<Arc<RateLimiter>>,
|
||||
) {
|
||||
self.timers.should_reset_rr = rate_limiter.is_none();
|
||||
self.rate_limiter = rate_limiter.unwrap_or_else(|| {
|
||||
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
|
||||
});
|
||||
self.handshake
|
||||
.set_static_private(static_private, static_public);
|
||||
for s in &mut self.sessions {
|
||||
*s = None;
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulate a single packet from the tunnel interface.
|
||||
/// Returns TunnResult.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if dst buffer is too small.
|
||||
/// Size of dst should be at least src.len() + 32, and no less than 148 bytes.
|
||||
pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> {
|
||||
let current = self.current;
|
||||
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
|
||||
// Send the packet using an established session
|
||||
let packet = session.format_packet_data(src, dst);
|
||||
self.timer_tick(TimerName::TimeLastPacketSent);
|
||||
// Exclude Keepalive packets from timer update.
|
||||
if !src.is_empty() {
|
||||
self.timer_tick(TimerName::TimeLastDataPacketSent);
|
||||
}
|
||||
self.tx_bytes += src.len();
|
||||
return TunnResult::WriteToNetwork(packet);
|
||||
}
|
||||
|
||||
// If there is no session, queue the packet for future retry
|
||||
self.queue_packet(src);
|
||||
// Initiate a new handshake if none is in progress
|
||||
self.format_handshake_initiation(dst, false)
|
||||
}
|
||||
|
||||
/// Receives a UDP datagram from the network and parses it.
|
||||
/// Returns TunnResult.
|
||||
///
|
||||
/// If the result is of type TunnResult::WriteToNetwork, should repeat the call with empty datagram,
|
||||
/// until TunnResult::Done is returned. If batch processing packets, it is OK to defer until last
|
||||
/// packet is processed.
|
||||
pub fn decapsulate<'a>(
|
||||
&mut self,
|
||||
src_addr: Option<IpAddr>,
|
||||
datagram: &[u8],
|
||||
dst: &'a mut [u8],
|
||||
) -> TunnResult<'a> {
|
||||
if datagram.is_empty() {
|
||||
// Indicates a repeated call
|
||||
return self.send_queued_packet(dst);
|
||||
}
|
||||
|
||||
let mut cookie = [0u8; COOKIE_REPLY_SZ];
|
||||
let packet = match self
|
||||
.rate_limiter
|
||||
.verify_packet(src_addr, datagram, &mut cookie)
|
||||
{
|
||||
Ok(packet) => packet,
|
||||
Err(TunnResult::WriteToNetwork(cookie)) => {
|
||||
dst[..cookie.len()].copy_from_slice(cookie);
|
||||
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]);
|
||||
}
|
||||
Err(TunnResult::Err(e)) => return TunnResult::Err(e),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
self.handle_verified_packet(packet, dst)
|
||||
}
|
||||
|
||||
pub(crate) fn handle_verified_packet<'a>(
|
||||
&mut self,
|
||||
packet: Packet,
|
||||
dst: &'a mut [u8],
|
||||
) -> TunnResult<'a> {
|
||||
match packet {
|
||||
Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst),
|
||||
Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst),
|
||||
Packet::PacketCookieReply(p) => self.handle_cookie_reply(p),
|
||||
Packet::PacketData(p) => self.handle_data(p, dst),
|
||||
}
|
||||
.unwrap_or_else(TunnResult::from)
|
||||
}
|
||||
|
||||
fn handle_handshake_init<'a>(
|
||||
&mut self,
|
||||
p: HandshakeInit,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
||||
tracing::debug!(
|
||||
message = "Received handshake_initiation",
|
||||
remote_idx = p.sender_idx
|
||||
);
|
||||
|
||||
let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?;
|
||||
|
||||
// Store new session in ring buffer
|
||||
let index = session.local_index();
|
||||
self.sessions[index % N_SESSIONS] = Some(session);
|
||||
|
||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
||||
self.timer_tick(TimerName::TimeLastPacketSent);
|
||||
self.timer_tick_session_established(false, index); // New session established, we are not the initiator
|
||||
|
||||
tracing::debug!(message = "Sending handshake_response", local_idx = index);
|
||||
|
||||
Ok(TunnResult::WriteToNetwork(packet))
|
||||
}
|
||||
|
||||
fn handle_handshake_response<'a>(
|
||||
&mut self,
|
||||
p: HandshakeResponse,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
||||
tracing::debug!(
|
||||
message = "Received handshake_response",
|
||||
local_idx = p.receiver_idx,
|
||||
remote_idx = p.sender_idx
|
||||
);
|
||||
|
||||
let session = self.handshake.receive_handshake_response(p)?;
|
||||
|
||||
let keepalive_packet = session.format_packet_data(&[], dst);
|
||||
// Store new session in ring buffer
|
||||
let l_idx = session.local_index();
|
||||
let index = l_idx % N_SESSIONS;
|
||||
self.sessions[index] = Some(session);
|
||||
|
||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
||||
self.timer_tick_session_established(true, index); // New session established, we are the initiator
|
||||
self.set_current_session(l_idx);
|
||||
|
||||
tracing::debug!("Sending keepalive");
|
||||
|
||||
Ok(TunnResult::WriteToNetwork(keepalive_packet)) // Send a keepalive as a response
|
||||
}
|
||||
|
||||
fn handle_cookie_reply<'a>(
|
||||
&mut self,
|
||||
p: PacketCookieReply,
|
||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
||||
tracing::debug!(
|
||||
message = "Received cookie_reply",
|
||||
local_idx = p.receiver_idx
|
||||
);
|
||||
|
||||
self.handshake.receive_cookie_reply(p)?;
|
||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
||||
self.timer_tick(TimerName::TimeCookieReceived);
|
||||
|
||||
tracing::debug!("Did set cookie");
|
||||
|
||||
Ok(TunnResult::Done)
|
||||
}
|
||||
|
||||
/// Update the index of the currently used session, if needed
|
||||
fn set_current_session(&mut self, new_idx: usize) {
|
||||
let cur_idx = self.current;
|
||||
if cur_idx == new_idx {
|
||||
// There is nothing to do, already using this session, this is the common case
|
||||
return;
|
||||
}
|
||||
if self.sessions[cur_idx % N_SESSIONS].is_none()
|
||||
|| self.timers.session_timers[new_idx % N_SESSIONS]
|
||||
>= self.timers.session_timers[cur_idx % N_SESSIONS]
|
||||
{
|
||||
self.current = new_idx;
|
||||
tracing::debug!(message = "New session", session = new_idx);
|
||||
}
|
||||
}
|
||||
|
||||
/// Decrypts a data packet, and stores the decapsulated packet in dst.
|
||||
fn handle_data<'a>(
|
||||
&mut self,
|
||||
packet: PacketData,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<TunnResult<'a>, WireGuardError> {
|
||||
let r_idx = packet.receiver_idx as usize;
|
||||
let idx = r_idx % N_SESSIONS;
|
||||
|
||||
// Get the (probably) right session
|
||||
let decapsulated_packet = {
|
||||
let session = self.sessions[idx].as_ref();
|
||||
let session = session.ok_or_else(|| {
|
||||
tracing::trace!(message = "No current session available", remote_idx = r_idx);
|
||||
WireGuardError::NoCurrentSession
|
||||
})?;
|
||||
session.receive_packet_data(packet, dst)?
|
||||
};
|
||||
|
||||
self.set_current_session(r_idx);
|
||||
|
||||
self.timer_tick(TimerName::TimeLastPacketReceived);
|
||||
|
||||
Ok(self.validate_decapsulated_packet(decapsulated_packet))
|
||||
}
|
||||
|
||||
/// Formats a new handshake initiation message and store it in dst. If force_resend is true will send
|
||||
/// a new handshake, even if a handshake is already in progress (for example when a handshake times out)
|
||||
pub fn format_handshake_initiation<'a>(
|
||||
&mut self,
|
||||
dst: &'a mut [u8],
|
||||
force_resend: bool,
|
||||
) -> TunnResult<'a> {
|
||||
if self.handshake.is_in_progress() && !force_resend {
|
||||
return TunnResult::Done;
|
||||
}
|
||||
|
||||
if self.handshake.is_expired() {
|
||||
self.timers.clear();
|
||||
}
|
||||
|
||||
let starting_new_handshake = !self.handshake.is_in_progress();
|
||||
|
||||
match self.handshake.format_handshake_initiation(dst) {
|
||||
Ok(packet) => {
|
||||
tracing::debug!("Sending handshake_initiation");
|
||||
|
||||
if starting_new_handshake {
|
||||
self.timer_tick(TimerName::TimeLastHandshakeStarted);
|
||||
}
|
||||
self.timer_tick(TimerName::TimeLastPacketSent);
|
||||
TunnResult::WriteToNetwork(packet)
|
||||
}
|
||||
Err(e) => TunnResult::Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an IP packet is v4 or v6, truncate to the length indicated by the length field
|
||||
/// Returns the truncated packet and the source IP as TunnResult
|
||||
fn validate_decapsulated_packet<'a>(&mut self, packet: &'a mut [u8]) -> TunnResult<'a> {
|
||||
let (computed_len, src_ip_address) = match packet.len() {
|
||||
0 => return TunnResult::Done, // This is keepalive, and not an error
|
||||
_ if packet[0] >> 4 == 4 && packet.len() >= IPV4_MIN_HEADER_SIZE => {
|
||||
let len_bytes: [u8; IP_LEN_SZ] = packet[IPV4_LEN_OFF..IPV4_LEN_OFF + IP_LEN_SZ]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let addr_bytes: [u8; IPV4_IP_SZ] = packet
|
||||
[IPV4_SRC_IP_OFF..IPV4_SRC_IP_OFF + IPV4_IP_SZ]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
(
|
||||
u16::from_be_bytes(len_bytes) as usize,
|
||||
IpAddr::from(addr_bytes),
|
||||
)
|
||||
}
|
||||
_ if packet[0] >> 4 == 6 && packet.len() >= IPV6_MIN_HEADER_SIZE => {
|
||||
let len_bytes: [u8; IP_LEN_SZ] = packet[IPV6_LEN_OFF..IPV6_LEN_OFF + IP_LEN_SZ]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let addr_bytes: [u8; IPV6_IP_SZ] = packet
|
||||
[IPV6_SRC_IP_OFF..IPV6_SRC_IP_OFF + IPV6_IP_SZ]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
(
|
||||
u16::from_be_bytes(len_bytes) as usize + IPV6_MIN_HEADER_SIZE,
|
||||
IpAddr::from(addr_bytes),
|
||||
)
|
||||
}
|
||||
_ => return TunnResult::Err(WireGuardError::InvalidPacket),
|
||||
};
|
||||
|
||||
if computed_len > packet.len() {
|
||||
return TunnResult::Err(WireGuardError::InvalidPacket);
|
||||
}
|
||||
|
||||
self.timer_tick(TimerName::TimeLastDataPacketReceived);
|
||||
self.rx_bytes += computed_len;
|
||||
|
||||
match src_ip_address {
|
||||
IpAddr::V4(addr) => TunnResult::WriteToTunnelV4(&mut packet[..computed_len], addr),
|
||||
IpAddr::V6(addr) => TunnResult::WriteToTunnelV6(&mut packet[..computed_len], addr),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a packet from the queue, and try to encapsulate it
|
||||
fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
|
||||
if let Some(packet) = self.dequeue_packet() {
|
||||
match self.encapsulate(&packet, dst) {
|
||||
TunnResult::Err(_) => {
|
||||
// On error, return packet to the queue
|
||||
self.requeue_packet(packet);
|
||||
}
|
||||
r => return r,
|
||||
}
|
||||
}
|
||||
TunnResult::Done
|
||||
}
|
||||
|
||||
/// Push packet to the back of the queue
|
||||
fn queue_packet(&mut self, packet: &[u8]) {
|
||||
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
|
||||
// Drop if too many are already in queue
|
||||
self.packet_queue.push_back(packet.to_vec());
|
||||
}
|
||||
}
|
||||
|
||||
/// Push packet to the front of the queue
|
||||
fn requeue_packet(&mut self, packet: Vec<u8>) {
|
||||
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
|
||||
// Drop if too many are already in queue
|
||||
self.packet_queue.push_front(packet);
|
||||
}
|
||||
}
|
||||
|
||||
fn dequeue_packet(&mut self) -> Option<Vec<u8>> {
|
||||
self.packet_queue.pop_front()
|
||||
}
|
||||
|
||||
fn estimate_loss(&self) -> f32 {
|
||||
let session_idx = self.current;
|
||||
|
||||
let mut weight = 9.0;
|
||||
let mut cur_avg = 0.0;
|
||||
let mut total_weight = 0.0;
|
||||
|
||||
for i in 0..N_SESSIONS {
|
||||
if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] {
|
||||
let (expected, received) = session.current_packet_cnt();
|
||||
|
||||
let loss = if expected == 0 {
|
||||
0.0
|
||||
} else {
|
||||
1.0 - received as f32 / expected as f32
|
||||
};
|
||||
|
||||
cur_avg += loss * weight;
|
||||
total_weight += weight;
|
||||
weight /= 3.0;
|
||||
}
|
||||
}
|
||||
|
||||
if total_weight == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
cur_avg / total_weight
|
||||
}
|
||||
}
|
||||
|
||||
/// Return stats from the tunnel:
|
||||
/// * Time since last handshake in seconds
|
||||
/// * Data bytes sent
|
||||
/// * Data bytes received
|
||||
pub fn stats(&self) -> (Option<Duration>, usize, usize, f32, Option<u32>) {
|
||||
let time = self.time_since_last_handshake();
|
||||
let tx_bytes = self.tx_bytes;
|
||||
let rx_bytes = self.rx_bytes;
|
||||
let loss = self.estimate_loss();
|
||||
let rtt = self.handshake.last_rtt;
|
||||
|
||||
(time, tx_bytes, rx_bytes, loss, rtt)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[cfg(feature = "mock-instant")]
|
||||
use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT};
|
||||
|
||||
use super::*;
|
||||
use rand_core::{OsRng, RngCore};
|
||||
|
||||
fn create_two_tuns() -> (Tunn, Tunn) {
|
||||
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
|
||||
let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key);
|
||||
let my_idx = OsRng.next_u32();
|
||||
|
||||
let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
|
||||
let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key);
|
||||
let their_idx = OsRng.next_u32();
|
||||
|
||||
let my_tun = Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None);
|
||||
|
||||
let their_tun = Tunn::new(their_secret_key, my_public_key, None, None, their_idx, None);
|
||||
|
||||
(my_tun, their_tun)
|
||||
}
|
||||
|
||||
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
|
||||
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
|
||||
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
handshake_init.into()
|
||||
}
|
||||
|
||||
fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst);
|
||||
assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_)));
|
||||
|
||||
let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
handshake_resp.into()
|
||||
}
|
||||
|
||||
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
|
||||
assert!(matches!(keepalive, TunnResult::WriteToNetwork(_)));
|
||||
|
||||
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
keepalive.into()
|
||||
}
|
||||
|
||||
fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let keepalive = tun.decapsulate(None, keepalive, &mut dst);
|
||||
assert!(matches!(keepalive, TunnResult::Done));
|
||||
}
|
||||
|
||||
fn create_two_tuns_and_handshake() -> (Tunn, Tunn) {
|
||||
let (mut my_tun, mut their_tun) = create_two_tuns();
|
||||
let init = create_handshake_init(&mut my_tun);
|
||||
let resp = create_handshake_response(&mut their_tun, &init);
|
||||
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
|
||||
parse_keepalive(&mut their_tun, &keepalive);
|
||||
|
||||
(my_tun, their_tun)
|
||||
}
|
||||
|
||||
fn create_ipv4_udp_packet() -> Vec<u8> {
|
||||
let header =
|
||||
etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23);
|
||||
let payload = [0, 1, 2, 3];
|
||||
let mut packet = Vec::<u8>::with_capacity(header.size(payload.len()));
|
||||
header.write(&mut packet, &payload).unwrap();
|
||||
packet
|
||||
}
|
||||
|
||||
#[cfg(feature = "mock-instant")]
|
||||
fn update_timer_results_in_handshake(tun: &mut Tunn) {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let result = tun.update_timers(&mut dst);
|
||||
assert!(matches!(result, TunnResult::WriteToNetwork(_)));
|
||||
let packet_data = if let TunnResult::WriteToNetwork(data) = result {
|
||||
data
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
let packet = Tunn::parse_incoming_packet(packet_data).unwrap();
|
||||
assert!(matches!(packet, Packet::HandshakeInit(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_two_tunnels_linked_to_eachother() {
|
||||
let (_my_tun, _their_tun) = create_two_tuns();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_init() {
|
||||
let (mut my_tun, _their_tun) = create_two_tuns();
|
||||
let init = create_handshake_init(&mut my_tun);
|
||||
let packet = Tunn::parse_incoming_packet(&init).unwrap();
|
||||
assert!(matches!(packet, Packet::HandshakeInit(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_init_and_response() {
|
||||
let (mut my_tun, mut their_tun) = create_two_tuns();
|
||||
let init = create_handshake_init(&mut my_tun);
|
||||
let resp = create_handshake_response(&mut their_tun, &init);
|
||||
let packet = Tunn::parse_incoming_packet(&resp).unwrap();
|
||||
assert!(matches!(packet, Packet::HandshakeResponse(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_handshake() {
|
||||
let (mut my_tun, mut their_tun) = create_two_tuns();
|
||||
let init = create_handshake_init(&mut my_tun);
|
||||
let resp = create_handshake_response(&mut their_tun, &init);
|
||||
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
|
||||
let packet = Tunn::parse_incoming_packet(&keepalive).unwrap();
|
||||
assert!(matches!(packet, Packet::PacketData(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_handshake_plus_timers() {
|
||||
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
|
||||
// Time has not yet advanced so their is nothing to do
|
||||
assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done));
|
||||
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mock-instant")]
|
||||
fn new_handshake_after_two_mins() {
|
||||
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
|
||||
let mut my_dst = [0u8; 1024];
|
||||
|
||||
// Advance time 1 second and "send" 1 packet so that we send a handshake
|
||||
// after the timeout
|
||||
mock_instant::MockClock::advance(Duration::from_secs(1));
|
||||
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
|
||||
assert!(matches!(
|
||||
my_tun.update_timers(&mut my_dst),
|
||||
TunnResult::Done
|
||||
));
|
||||
let sent_packet_buf = create_ipv4_udp_packet();
|
||||
let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst);
|
||||
assert!(matches!(data, TunnResult::WriteToNetwork(_)));
|
||||
|
||||
//Advance to timeout
|
||||
mock_instant::MockClock::advance(REKEY_AFTER_TIME);
|
||||
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
|
||||
update_timer_results_in_handshake(&mut my_tun);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mock-instant")]
|
||||
fn handshake_no_resp_rekey_timeout() {
|
||||
let (mut my_tun, _their_tun) = create_two_tuns();
|
||||
|
||||
let init = create_handshake_init(&mut my_tun);
|
||||
let packet = Tunn::parse_incoming_packet(&init).unwrap();
|
||||
assert!(matches!(packet, Packet::HandshakeInit(_)));
|
||||
|
||||
mock_instant::MockClock::advance(REKEY_TIMEOUT);
|
||||
update_timer_results_in_handshake(&mut my_tun)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_ip_packet() {
|
||||
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
|
||||
let mut my_dst = [0u8; 1024];
|
||||
let mut their_dst = [0u8; 1024];
|
||||
|
||||
let sent_packet_buf = create_ipv4_udp_packet();
|
||||
|
||||
let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst);
|
||||
assert!(matches!(data, TunnResult::WriteToNetwork(_)));
|
||||
let data = if let TunnResult::WriteToNetwork(sent) = data {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
let data = their_tun.decapsulate(None, data, &mut their_dst);
|
||||
assert!(matches!(data, TunnResult::WriteToTunnelV4(..)));
|
||||
let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data {
|
||||
recv
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
assert_eq!(sent_packet_buf, recv_packet_buf);
|
||||
}
|
||||
}
|
||||
193
lib/boringtun/src/noise/rate_limiter.rs
Normal file
193
lib/boringtun/src/noise/rate_limiter.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24};
|
||||
use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1};
|
||||
use crate::noise::{HandshakeInit, HandshakeResponse, Packet, Tunn, TunnResult, WireGuardError};
|
||||
|
||||
#[cfg(feature = "mock-instant")]
|
||||
use mock_instant::Instant;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[cfg(not(feature = "mock-instant"))]
|
||||
use crate::sleepyinstant::Instant;
|
||||
|
||||
use aead::generic_array::GenericArray;
|
||||
use aead::{AeadInPlace, KeyInit};
|
||||
use chacha20poly1305::{Key, XChaCha20Poly1305};
|
||||
use parking_lot::Mutex;
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use ring::constant_time::verify_slices_are_equal;
|
||||
|
||||
const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division
|
||||
const COOKIE_SIZE: usize = 16;
|
||||
const COOKIE_NONCE_SIZE: usize = 24;
|
||||
|
||||
/// How often should reset count in seconds
|
||||
const RESET_PERIOD: u64 = 1;
|
||||
|
||||
type Cookie = [u8; COOKIE_SIZE];
|
||||
|
||||
/// There are two places where WireGuard requires "randomness" for cookies
|
||||
/// * The 24 byte nonce in the cookie massage - here the only goal is to avoid nonce reuse
|
||||
/// * A secret value that changes every two minutes
|
||||
/// Because the main goal of the cookie is simply for a party to prove ownership of an IP address
|
||||
/// we can relax the randomness definition a bit, in order to avoid locking, because using less
|
||||
/// resources is the main goal of any DoS prevention mechanism.
|
||||
/// In order to avoid locking and calls to rand we derive pseudo random values using the AEAD and
|
||||
/// some counters.
|
||||
pub struct RateLimiter {
|
||||
/// The key we use to derive the nonce
|
||||
nonce_key: [u8; 32],
|
||||
/// The key we use to derive the cookie
|
||||
secret_key: [u8; 16],
|
||||
start_time: Instant,
|
||||
/// A single 64 bit counter (should suffice for many years)
|
||||
nonce_ctr: AtomicUsize,
|
||||
mac1_key: [u8; 32],
|
||||
cookie_key: Key,
|
||||
limit: usize,
|
||||
/// The counter since last reset
|
||||
count: AtomicUsize,
|
||||
/// The time last reset was performed on this rate limiter
|
||||
last_reset: Mutex<Instant>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(public_key: &crate::x25519::PublicKey, limit: u64) -> Self {
|
||||
let mut secret_key = [0u8; 16];
|
||||
OsRng.fill_bytes(&mut secret_key);
|
||||
RateLimiter {
|
||||
nonce_key: Self::rand_bytes(),
|
||||
secret_key,
|
||||
start_time: Instant::now(),
|
||||
nonce_ctr: AtomicUsize::new(0),
|
||||
mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()),
|
||||
cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(),
|
||||
limit: limit as _,
|
||||
count: AtomicUsize::new(0),
|
||||
last_reset: Mutex::new(Instant::now()),
|
||||
}
|
||||
}
|
||||
|
||||
fn rand_bytes() -> [u8; 32] {
|
||||
let mut key = [0u8; 32];
|
||||
OsRng.fill_bytes(&mut key);
|
||||
key
|
||||
}
|
||||
|
||||
/// Reset packet count (ideally should be called with a period of 1 second)
|
||||
pub fn reset_count(&self) {
|
||||
// The rate limiter is not very accurate, but at the scale we care about it doesn't matter much
|
||||
let current_time = Instant::now();
|
||||
let mut last_reset_time = self.last_reset.lock();
|
||||
if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD {
|
||||
self.count.store(0, Ordering::SeqCst);
|
||||
*last_reset_time = current_time;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the correct cookie value based on the current secret value and the source IP
|
||||
fn current_cookie(&self, addr: IpAddr) -> Cookie {
|
||||
let mut addr_bytes = [0u8; 16];
|
||||
|
||||
match addr {
|
||||
IpAddr::V4(a) => addr_bytes[..4].copy_from_slice(&a.octets()[..]),
|
||||
IpAddr::V6(a) => addr_bytes[..].copy_from_slice(&a.octets()[..]),
|
||||
}
|
||||
|
||||
// The current cookie for a given IP is the MAC(responder.changing_secret_every_two_minutes, initiator.ip_address)
|
||||
// First we derive the secret from the current time, the value of cur_counter would change with time.
|
||||
let cur_counter = Instant::now().duration_since(self.start_time).as_secs() / COOKIE_REFRESH;
|
||||
|
||||
// Next we derive the cookie
|
||||
b2s_keyed_mac_16_2(&self.secret_key, &cur_counter.to_le_bytes(), &addr_bytes)
|
||||
}
|
||||
|
||||
fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] {
|
||||
let ctr = self.nonce_ctr.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
b2s_mac_24(&self.nonce_key, &ctr.to_le_bytes())
|
||||
}
|
||||
|
||||
fn is_under_load(&self) -> bool {
|
||||
self.count.fetch_add(1, Ordering::SeqCst) >= self.limit
|
||||
}
|
||||
|
||||
pub(crate) fn format_cookie_reply<'a>(
|
||||
&self,
|
||||
idx: u32,
|
||||
cookie: Cookie,
|
||||
mac1: &[u8],
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<&'a mut [u8], WireGuardError> {
|
||||
if dst.len() < super::COOKIE_REPLY_SZ {
|
||||
return Err(WireGuardError::DestinationBufferTooSmall);
|
||||
}
|
||||
|
||||
let (message_type, rest) = dst.split_at_mut(4);
|
||||
let (receiver_index, rest) = rest.split_at_mut(4);
|
||||
let (nonce, rest) = rest.split_at_mut(24);
|
||||
let (encrypted_cookie, _) = rest.split_at_mut(16 + 16);
|
||||
|
||||
// msg.message_type = 3
|
||||
// msg.reserved_zero = { 0, 0, 0 }
|
||||
message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes());
|
||||
// msg.receiver_index = little_endian(initiator.sender_index)
|
||||
receiver_index.copy_from_slice(&idx.to_le_bytes());
|
||||
nonce.copy_from_slice(&self.nonce()[..]);
|
||||
|
||||
let cipher = XChaCha20Poly1305::new(&self.cookie_key);
|
||||
|
||||
let iv = GenericArray::from_slice(nonce);
|
||||
|
||||
encrypted_cookie[..16].copy_from_slice(&cookie);
|
||||
let tag = cipher
|
||||
.encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16])
|
||||
.map_err(|_| WireGuardError::DestinationBufferTooSmall)?;
|
||||
|
||||
encrypted_cookie[16..].copy_from_slice(&tag);
|
||||
|
||||
Ok(&mut dst[..super::COOKIE_REPLY_SZ])
|
||||
}
|
||||
|
||||
/// Verify the MAC fields on the datagram, and apply rate limiting if needed
|
||||
pub fn verify_packet<'a, 'b>(
|
||||
&self,
|
||||
src_addr: Option<IpAddr>,
|
||||
src: &'a [u8],
|
||||
dst: &'b mut [u8],
|
||||
) -> Result<Packet<'a>, TunnResult<'b>> {
|
||||
let packet = Tunn::parse_incoming_packet(src)?;
|
||||
|
||||
// Verify and rate limit handshake messages only
|
||||
if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
|
||||
| Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
|
||||
{
|
||||
let (msg, macs) = src.split_at(src.len() - 32);
|
||||
let (mac1, mac2) = macs.split_at(16);
|
||||
|
||||
let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg);
|
||||
verify_slices_are_equal(&computed_mac1[..16], mac1)
|
||||
.map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?;
|
||||
|
||||
if self.is_under_load() {
|
||||
let addr = match src_addr {
|
||||
None => return Err(TunnResult::Err(WireGuardError::UnderLoad)),
|
||||
Some(addr) => addr,
|
||||
};
|
||||
|
||||
// Only given an address can we validate mac2
|
||||
let cookie = self.current_cookie(addr);
|
||||
let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1);
|
||||
|
||||
if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() {
|
||||
let cookie_packet = self
|
||||
.format_cookie_reply(sender_idx, cookie, mac1, dst)
|
||||
.map_err(TunnResult::Err)?;
|
||||
return Err(TunnResult::WriteToNetwork(cookie_packet));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
}
|
||||
329
lib/boringtun/src/noise/session.rs
Normal file
329
lib/boringtun/src/noise/session.rs
Normal file
@@ -0,0 +1,329 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::PacketData;
|
||||
use crate::noise::errors::WireGuardError;
|
||||
use parking_lot::Mutex;
|
||||
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
pub struct Session {
|
||||
pub(crate) receiving_index: u32,
|
||||
sending_index: u32,
|
||||
receiver: LessSafeKey,
|
||||
sender: LessSafeKey,
|
||||
sending_key_counter: AtomicUsize,
|
||||
receiving_key_counter: Mutex<ReceivingKeyCounterValidator>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Session {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Session: {}<- ->{}",
|
||||
self.receiving_index, self.sending_index
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Where encrypted data resides in a data packet
|
||||
const DATA_OFFSET: usize = 16;
|
||||
/// The overhead of the AEAD
|
||||
const AEAD_SIZE: usize = 16;
|
||||
|
||||
// Receiving buffer constants
|
||||
const WORD_SIZE: u64 = 64;
|
||||
const N_WORDS: u64 = 16; // Suffice to reorder 64*16 = 1024 packets; can be increased at will
|
||||
const N_BITS: u64 = WORD_SIZE * N_WORDS;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct ReceivingKeyCounterValidator {
|
||||
/// In order to avoid replays while allowing for some reordering of the packets, we keep a
|
||||
/// bitmap of received packets, and the value of the highest counter
|
||||
next: u64,
|
||||
/// Used to estimate packet loss
|
||||
receive_cnt: u64,
|
||||
bitmap: [u64; N_WORDS as usize],
|
||||
}
|
||||
|
||||
impl ReceivingKeyCounterValidator {
|
||||
#[inline(always)]
|
||||
fn set_bit(&mut self, idx: u64) {
|
||||
let bit_idx = idx % N_BITS;
|
||||
let word = (bit_idx / WORD_SIZE) as usize;
|
||||
let bit = (bit_idx % WORD_SIZE) as usize;
|
||||
self.bitmap[word] |= 1 << bit;
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn clear_bit(&mut self, idx: u64) {
|
||||
let bit_idx = idx % N_BITS;
|
||||
let word = (bit_idx / WORD_SIZE) as usize;
|
||||
let bit = (bit_idx % WORD_SIZE) as usize;
|
||||
self.bitmap[word] &= !(1u64 << bit);
|
||||
}
|
||||
|
||||
/// Clear the word that contains idx
|
||||
#[inline(always)]
|
||||
fn clear_word(&mut self, idx: u64) {
|
||||
let bit_idx = idx % N_BITS;
|
||||
let word = (bit_idx / WORD_SIZE) as usize;
|
||||
self.bitmap[word] = 0;
|
||||
}
|
||||
|
||||
/// Returns true if bit is set, false otherwise
|
||||
#[inline(always)]
|
||||
fn check_bit(&self, idx: u64) -> bool {
|
||||
let bit_idx = idx % N_BITS;
|
||||
let word = (bit_idx / WORD_SIZE) as usize;
|
||||
let bit = (bit_idx % WORD_SIZE) as usize;
|
||||
((self.bitmap[word] >> bit) & 1) == 1
|
||||
}
|
||||
|
||||
/// Returns true if the counter was not yet received, and is not too far back
|
||||
#[inline(always)]
|
||||
fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
|
||||
if counter >= self.next {
|
||||
// As long as the counter is growing no replay took place for sure
|
||||
return Ok(());
|
||||
}
|
||||
if counter + N_BITS < self.next {
|
||||
// Drop if too far back
|
||||
return Err(WireGuardError::InvalidCounter);
|
||||
}
|
||||
if !self.check_bit(counter) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(WireGuardError::DuplicateCounter)
|
||||
}
|
||||
}
|
||||
|
||||
/// Marks the counter as received, and returns true if it is still good (in case during
|
||||
/// decryption something changed)
|
||||
#[inline(always)]
|
||||
fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
|
||||
if counter + N_BITS < self.next {
|
||||
// Drop if too far back
|
||||
return Err(WireGuardError::InvalidCounter);
|
||||
}
|
||||
if counter == self.next {
|
||||
// Usually the packets arrive in order, in that case we simply mark the bit and
|
||||
// increment the counter
|
||||
self.set_bit(counter);
|
||||
self.next += 1;
|
||||
return Ok(());
|
||||
}
|
||||
if counter < self.next {
|
||||
// A packet arrived out of order, check if it is valid, and mark
|
||||
if self.check_bit(counter) {
|
||||
return Err(WireGuardError::InvalidCounter);
|
||||
}
|
||||
self.set_bit(counter);
|
||||
return Ok(());
|
||||
}
|
||||
// Packets where dropped, or maybe reordered, skip them and mark unused
|
||||
if counter - self.next >= N_BITS {
|
||||
// Too far ahead, clear all the bits
|
||||
for c in self.bitmap.iter_mut() {
|
||||
*c = 0;
|
||||
}
|
||||
} else {
|
||||
let mut i = self.next;
|
||||
while i % WORD_SIZE != 0 && i < counter {
|
||||
// Clear until i aligned to word size
|
||||
self.clear_bit(i);
|
||||
i += 1;
|
||||
}
|
||||
while i + WORD_SIZE < counter {
|
||||
// Clear whole word at a time
|
||||
self.clear_word(i);
|
||||
i = (i + WORD_SIZE) & 0u64.wrapping_sub(WORD_SIZE);
|
||||
}
|
||||
while i < counter {
|
||||
// Clear any remaining bits
|
||||
self.clear_bit(i);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
self.set_bit(counter);
|
||||
self.next = counter + 1;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub(super) fn new(
|
||||
local_index: u32,
|
||||
peer_index: u32,
|
||||
receiving_key: [u8; 32],
|
||||
sending_key: [u8; 32],
|
||||
) -> Session {
|
||||
Session {
|
||||
receiving_index: local_index,
|
||||
sending_index: peer_index,
|
||||
receiver: LessSafeKey::new(
|
||||
UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(),
|
||||
),
|
||||
sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()),
|
||||
sending_key_counter: AtomicUsize::new(0),
|
||||
receiving_key_counter: Mutex::new(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn local_index(&self) -> usize {
|
||||
self.receiving_index as usize
|
||||
}
|
||||
|
||||
/// Returns true if receiving counter is good to use
|
||||
fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> {
|
||||
let counter_validator = self.receiving_key_counter.lock();
|
||||
counter_validator.will_accept(counter)
|
||||
}
|
||||
|
||||
/// Returns true if receiving counter is good to use, and marks it as used {
|
||||
fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> {
|
||||
let mut counter_validator = self.receiving_key_counter.lock();
|
||||
let ret = counter_validator.mark_did_receive(counter);
|
||||
if ret.is_ok() {
|
||||
counter_validator.receive_cnt += 1;
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
/// src - an IP packet from the interface
|
||||
/// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network
|
||||
/// returns the size of the formatted packet
|
||||
pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] {
|
||||
if dst.len() < src.len() + super::DATA_OVERHEAD_SZ {
|
||||
panic!("The destination buffer is too small");
|
||||
}
|
||||
|
||||
let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64;
|
||||
|
||||
let (message_type, rest) = dst.split_at_mut(4);
|
||||
let (receiver_index, rest) = rest.split_at_mut(4);
|
||||
let (counter, data) = rest.split_at_mut(8);
|
||||
|
||||
message_type.copy_from_slice(&super::DATA.to_le_bytes());
|
||||
receiver_index.copy_from_slice(&self.sending_index.to_le_bytes());
|
||||
counter.copy_from_slice(&sending_key_counter.to_le_bytes());
|
||||
|
||||
// TODO: spec requires padding to 16 bytes, but actually works fine without it
|
||||
let n = {
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes());
|
||||
data[..src.len()].copy_from_slice(src);
|
||||
self.sender
|
||||
.seal_in_place_separate_tag(
|
||||
Nonce::assume_unique_for_key(nonce),
|
||||
Aad::from(&[]),
|
||||
&mut data[..src.len()],
|
||||
)
|
||||
.map(|tag| {
|
||||
data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref());
|
||||
src.len() + AEAD_SIZE
|
||||
})
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
&mut dst[..DATA_OFFSET + n]
|
||||
}
|
||||
|
||||
/// packet - a data packet we received from the network
|
||||
/// dst - pre-allocated space to hold the encapsulated IP packet, to send to the interface
|
||||
/// dst will always take less space than src
|
||||
/// return the size of the encapsulated packet on success
|
||||
pub(super) fn receive_packet_data<'a>(
|
||||
&self,
|
||||
packet: PacketData,
|
||||
dst: &'a mut [u8],
|
||||
) -> Result<&'a mut [u8], WireGuardError> {
|
||||
let ct_len = packet.encrypted_encapsulated_packet.len();
|
||||
if dst.len() < ct_len {
|
||||
// This is a very incorrect use of the library, therefore panic and not error
|
||||
panic!("The destination buffer is too small");
|
||||
}
|
||||
if packet.receiver_idx != self.receiving_index {
|
||||
return Err(WireGuardError::WrongIndex);
|
||||
}
|
||||
// Don't reuse counters, in case this is a replay attack we want to quickly check the counter without running expensive decryption
|
||||
self.receiving_counter_quick_check(packet.counter)?;
|
||||
|
||||
let ret = {
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes());
|
||||
dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet);
|
||||
self.receiver
|
||||
.open_in_place(
|
||||
Nonce::assume_unique_for_key(nonce),
|
||||
Aad::from(&[]),
|
||||
&mut dst[..ct_len],
|
||||
)
|
||||
.map_err(|_| WireGuardError::InvalidAeadTag)?
|
||||
};
|
||||
|
||||
// After decryption is done, check counter again, and mark as received
|
||||
self.receiving_counter_mark(packet.counter)?;
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
/// Returns the estimated downstream packet loss for this session
|
||||
pub(super) fn current_packet_cnt(&self) -> (u64, u64) {
|
||||
let counter_validator = self.receiving_key_counter.lock();
|
||||
(counter_validator.next, counter_validator.receive_cnt)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_replay_counter() {
|
||||
let mut c: ReceivingKeyCounterValidator = Default::default();
|
||||
|
||||
assert!(c.mark_did_receive(0).is_ok());
|
||||
assert!(c.mark_did_receive(0).is_err());
|
||||
assert!(c.mark_did_receive(1).is_ok());
|
||||
assert!(c.mark_did_receive(1).is_err());
|
||||
assert!(c.mark_did_receive(63).is_ok());
|
||||
assert!(c.mark_did_receive(63).is_err());
|
||||
assert!(c.mark_did_receive(15).is_ok());
|
||||
assert!(c.mark_did_receive(15).is_err());
|
||||
|
||||
for i in 64..N_BITS + 128 {
|
||||
assert!(c.mark_did_receive(i).is_ok());
|
||||
assert!(c.mark_did_receive(i).is_err());
|
||||
}
|
||||
|
||||
assert!(c.mark_did_receive(N_BITS * 3).is_ok());
|
||||
for i in 0..=N_BITS * 2 {
|
||||
assert!(matches!(
|
||||
c.will_accept(i),
|
||||
Err(WireGuardError::InvalidCounter)
|
||||
));
|
||||
assert!(c.mark_did_receive(i).is_err());
|
||||
}
|
||||
for i in N_BITS * 2 + 1..N_BITS * 3 {
|
||||
assert!(c.will_accept(i).is_ok());
|
||||
}
|
||||
assert!(matches!(
|
||||
c.will_accept(N_BITS * 3),
|
||||
Err(WireGuardError::DuplicateCounter)
|
||||
));
|
||||
|
||||
for i in (N_BITS * 2 + 1..N_BITS * 3).rev() {
|
||||
assert!(c.mark_did_receive(i).is_ok());
|
||||
assert!(c.mark_did_receive(i).is_err());
|
||||
}
|
||||
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_ok());
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_ok());
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_ok());
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 72 + 125).is_ok());
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 63).is_ok());
|
||||
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_err());
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_err());
|
||||
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_err());
|
||||
}
|
||||
}
|
||||
335
lib/boringtun/src/noise/timers.rs
Normal file
335
lib/boringtun/src/noise/timers.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
use super::errors::WireGuardError;
|
||||
use crate::noise::{Tunn, TunnResult};
|
||||
use std::mem;
|
||||
use std::ops::{Index, IndexMut};
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(feature = "mock-instant")]
|
||||
use mock_instant::Instant;
|
||||
|
||||
#[cfg(not(feature = "mock-instant"))]
|
||||
use crate::sleepyinstant::Instant;
|
||||
|
||||
// Some constants, represent time in seconds
|
||||
// https://www.wireguard.com/papers/wireguard.pdf#page=14
|
||||
pub(crate) const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
|
||||
const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
|
||||
const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
|
||||
pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TimerName {
|
||||
/// Current time, updated each call to `update_timers`
|
||||
TimeCurrent,
|
||||
/// Time when last handshake was completed
|
||||
TimeSessionEstablished,
|
||||
/// Time the last attempt for a new handshake began
|
||||
TimeLastHandshakeStarted,
|
||||
/// Time we last received and authenticated a packet
|
||||
TimeLastPacketReceived,
|
||||
/// Time we last send a packet
|
||||
TimeLastPacketSent,
|
||||
/// Time we last received and authenticated a DATA packet
|
||||
TimeLastDataPacketReceived,
|
||||
/// Time we last send a DATA packet
|
||||
TimeLastDataPacketSent,
|
||||
/// Time we last received a cookie
|
||||
TimeCookieReceived,
|
||||
/// Time we last sent persistent keepalive
|
||||
TimePersistentKeepalive,
|
||||
Top,
|
||||
}
|
||||
|
||||
use self::TimerName::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Timers {
|
||||
/// Is the owner of the timer the initiator or the responder for the last handshake?
|
||||
is_initiator: bool,
|
||||
/// Start time of the tunnel
|
||||
time_started: Instant,
|
||||
timers: [Duration; TimerName::Top as usize],
|
||||
pub(super) session_timers: [Duration; super::N_SESSIONS],
|
||||
/// Did we receive data without sending anything back?
|
||||
want_keepalive: bool,
|
||||
/// Did we send data without hearing back?
|
||||
want_handshake: bool,
|
||||
persistent_keepalive: usize,
|
||||
/// Should this timer call reset rr function (if not a shared rr instance)
|
||||
pub(super) should_reset_rr: bool,
|
||||
}
|
||||
|
||||
impl Timers {
|
||||
pub(super) fn new(persistent_keepalive: Option<u16>, reset_rr: bool) -> Timers {
|
||||
Timers {
|
||||
is_initiator: false,
|
||||
time_started: Instant::now(),
|
||||
timers: Default::default(),
|
||||
session_timers: Default::default(),
|
||||
want_keepalive: Default::default(),
|
||||
want_handshake: Default::default(),
|
||||
persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)),
|
||||
should_reset_rr: reset_rr,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_initiator(&self) -> bool {
|
||||
self.is_initiator
|
||||
}
|
||||
|
||||
// We don't really clear the timers, but we set them to the current time to
|
||||
// so the reference time frame is the same
|
||||
pub(super) fn clear(&mut self) {
|
||||
let now = Instant::now().duration_since(self.time_started);
|
||||
for t in &mut self.timers[..] {
|
||||
*t = now;
|
||||
}
|
||||
self.want_handshake = false;
|
||||
self.want_keepalive = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl Index<TimerName> for Timers {
|
||||
type Output = Duration;
|
||||
fn index(&self, index: TimerName) -> &Duration {
|
||||
&self.timers[index as usize]
|
||||
}
|
||||
}
|
||||
|
||||
impl IndexMut<TimerName> for Timers {
|
||||
fn index_mut(&mut self, index: TimerName) -> &mut Duration {
|
||||
&mut self.timers[index as usize]
|
||||
}
|
||||
}
|
||||
|
||||
impl Tunn {
|
||||
pub(super) fn timer_tick(&mut self, timer_name: TimerName) {
|
||||
match timer_name {
|
||||
TimeLastPacketReceived => {
|
||||
self.timers.want_keepalive = true;
|
||||
self.timers.want_handshake = false;
|
||||
}
|
||||
TimeLastPacketSent => {
|
||||
self.timers.want_handshake = true;
|
||||
self.timers.want_keepalive = false;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let time = self.timers[TimeCurrent];
|
||||
self.timers[timer_name] = time;
|
||||
}
|
||||
|
||||
pub(super) fn timer_tick_session_established(
|
||||
&mut self,
|
||||
is_initiator: bool,
|
||||
session_idx: usize,
|
||||
) {
|
||||
self.timer_tick(TimeSessionEstablished);
|
||||
self.timers.session_timers[session_idx % crate::noise::N_SESSIONS] =
|
||||
self.timers[TimeCurrent];
|
||||
self.timers.is_initiator = is_initiator;
|
||||
}
|
||||
|
||||
// We don't really clear the timers, but we set them to the current time to
|
||||
// so the reference time frame is the same
|
||||
fn clear_all(&mut self) {
|
||||
for session in &mut self.sessions {
|
||||
*session = None;
|
||||
}
|
||||
|
||||
self.packet_queue.clear();
|
||||
|
||||
self.timers.clear();
|
||||
}
|
||||
|
||||
fn update_session_timers(&mut self, time_now: Duration) {
|
||||
let timers = &mut self.timers;
|
||||
|
||||
for (i, t) in timers.session_timers.iter_mut().enumerate() {
|
||||
if time_now - *t > REJECT_AFTER_TIME {
|
||||
if let Some(session) = self.sessions[i].take() {
|
||||
tracing::debug!(
|
||||
message = "SESSION_EXPIRED(REJECT_AFTER_TIME)",
|
||||
session = session.receiving_index
|
||||
);
|
||||
}
|
||||
*t = time_now;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
|
||||
let mut handshake_initiation_required = false;
|
||||
let mut keepalive_required = false;
|
||||
|
||||
let time = Instant::now();
|
||||
|
||||
if self.timers.should_reset_rr {
|
||||
self.rate_limiter.reset_count();
|
||||
}
|
||||
|
||||
// All the times are counted from tunnel initiation, for efficiency our timers are rounded
|
||||
// to a second, as there is no real benefit to having highly accurate timers.
|
||||
let now = time.duration_since(self.timers.time_started);
|
||||
self.timers[TimeCurrent] = now;
|
||||
|
||||
self.update_session_timers(now);
|
||||
|
||||
// Load timers only once:
|
||||
let session_established = self.timers[TimeSessionEstablished];
|
||||
let handshake_started = self.timers[TimeLastHandshakeStarted];
|
||||
let aut_packet_received = self.timers[TimeLastPacketReceived];
|
||||
let aut_packet_sent = self.timers[TimeLastPacketSent];
|
||||
let data_packet_received = self.timers[TimeLastDataPacketReceived];
|
||||
let data_packet_sent = self.timers[TimeLastDataPacketSent];
|
||||
let persistent_keepalive = self.timers.persistent_keepalive;
|
||||
|
||||
{
|
||||
if self.handshake.is_expired() {
|
||||
return TunnResult::Err(WireGuardError::ConnectionExpired);
|
||||
}
|
||||
|
||||
// Clear cookie after COOKIE_EXPIRATION_TIME
|
||||
if self.handshake.has_cookie()
|
||||
&& now - self.timers[TimeCookieReceived] >= COOKIE_EXPIRATION_TIME
|
||||
{
|
||||
self.handshake.clear_cookie();
|
||||
}
|
||||
|
||||
// All ephemeral private keys and symmetric session keys are zeroed out after
|
||||
// (REJECT_AFTER_TIME * 3) ms if no new keys have been exchanged.
|
||||
if now - session_established >= REJECT_AFTER_TIME * 3 {
|
||||
tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)");
|
||||
self.handshake.set_expired();
|
||||
self.clear_all();
|
||||
return TunnResult::Err(WireGuardError::ConnectionExpired);
|
||||
}
|
||||
|
||||
if let Some(time_init_sent) = self.handshake.timer() {
|
||||
// Handshake Initiation Retransmission
|
||||
if now - handshake_started >= REKEY_ATTEMPT_TIME {
|
||||
// After REKEY_ATTEMPT_TIME ms of trying to initiate a new handshake,
|
||||
// the retries give up and cease, and clear all existing packets queued
|
||||
// up to be sent. If a packet is explicitly queued up to be sent, then
|
||||
// this timer is reset.
|
||||
tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)");
|
||||
self.handshake.set_expired();
|
||||
self.clear_all();
|
||||
return TunnResult::Err(WireGuardError::ConnectionExpired);
|
||||
}
|
||||
|
||||
if time_init_sent.elapsed() >= REKEY_TIMEOUT {
|
||||
// We avoid using `time` here, because it can be earlier than `time_init_sent`.
|
||||
// Once `checked_duration_since` is stable we can use that.
|
||||
// A handshake initiation is retried after REKEY_TIMEOUT + jitter ms,
|
||||
// if a response has not been received, where jitter is some random
|
||||
// value between 0 and 333 ms.
|
||||
tracing::warn!("HANDSHAKE(REKEY_TIMEOUT)");
|
||||
handshake_initiation_required = true;
|
||||
}
|
||||
} else {
|
||||
if self.timers.is_initiator() {
|
||||
// After sending a packet, if the sender was the original initiator
|
||||
// of the handshake and if the current session key is REKEY_AFTER_TIME
|
||||
// ms old, we initiate a new handshake. If the sender was the original
|
||||
// responder of the handshake, it does not re-initiate a new handshake
|
||||
// after REKEY_AFTER_TIME ms like the original initiator does.
|
||||
if session_established < data_packet_sent
|
||||
&& now - session_established >= REKEY_AFTER_TIME
|
||||
{
|
||||
tracing::debug!("HANDSHAKE(REKEY_AFTER_TIME (on send))");
|
||||
handshake_initiation_required = true;
|
||||
}
|
||||
|
||||
// After receiving a packet, if the receiver was the original initiator
|
||||
// of the handshake and if the current session key is REJECT_AFTER_TIME
|
||||
// - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT ms old, we initiate a new
|
||||
// handshake.
|
||||
if session_established < data_packet_received
|
||||
&& now - session_established
|
||||
>= REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
|
||||
{
|
||||
tracing::warn!(
|
||||
"HANDSHAKE(REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - \
|
||||
REKEY_TIMEOUT \
|
||||
(on receive))"
|
||||
);
|
||||
handshake_initiation_required = true;
|
||||
}
|
||||
}
|
||||
|
||||
// If we have sent a packet to a given peer but have not received a
|
||||
// packet after from that peer for (KEEPALIVE + REKEY_TIMEOUT) ms,
|
||||
// we initiate a new handshake.
|
||||
if data_packet_sent > aut_packet_received
|
||||
&& now - aut_packet_received >= KEEPALIVE_TIMEOUT + REKEY_TIMEOUT
|
||||
&& mem::replace(&mut self.timers.want_handshake, false)
|
||||
{
|
||||
tracing::warn!("HANDSHAKE(KEEPALIVE + REKEY_TIMEOUT)");
|
||||
handshake_initiation_required = true;
|
||||
}
|
||||
|
||||
if !handshake_initiation_required {
|
||||
// If a packet has been received from a given peer, but we have not sent one back
|
||||
// to the given peer in KEEPALIVE ms, we send an empty packet.
|
||||
if data_packet_received > aut_packet_sent
|
||||
&& now - aut_packet_sent >= KEEPALIVE_TIMEOUT
|
||||
&& mem::replace(&mut self.timers.want_keepalive, false)
|
||||
{
|
||||
tracing::debug!("KEEPALIVE(KEEPALIVE_TIMEOUT)");
|
||||
keepalive_required = true;
|
||||
}
|
||||
|
||||
// Persistent KEEPALIVE
|
||||
if persistent_keepalive > 0
|
||||
&& (now - self.timers[TimePersistentKeepalive]
|
||||
>= Duration::from_secs(persistent_keepalive as _))
|
||||
{
|
||||
tracing::debug!("KEEPALIVE(PERSISTENT_KEEPALIVE)");
|
||||
self.timer_tick(TimePersistentKeepalive);
|
||||
keepalive_required = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if handshake_initiation_required {
|
||||
return self.format_handshake_initiation(dst, true);
|
||||
}
|
||||
|
||||
if keepalive_required {
|
||||
return self.encapsulate(&[], dst);
|
||||
}
|
||||
|
||||
TunnResult::Done
|
||||
}
|
||||
|
||||
pub fn time_since_last_handshake(&self) -> Option<Duration> {
|
||||
let current_session = self.current;
|
||||
if self.sessions[current_session % super::N_SESSIONS].is_some() {
|
||||
let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started);
|
||||
let duration_since_session_established = self.timers[TimeSessionEstablished];
|
||||
|
||||
Some(duration_since_tun_start - duration_since_session_established)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn persistent_keepalive(&self) -> Option<u16> {
|
||||
let keepalive = self.timers.persistent_keepalive;
|
||||
|
||||
if keepalive > 0 {
|
||||
Some(keepalive as u16)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
33
lib/boringtun/src/serialization.rs
Normal file
33
lib/boringtun/src/serialization.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
pub(crate) struct KeyBytes(pub [u8; 32]);
|
||||
|
||||
impl std::str::FromStr for KeyBytes {
|
||||
type Err = &'static str;
|
||||
|
||||
/// Can parse a secret key from a hex or base64 encoded string.
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut internal = [0u8; 32];
|
||||
|
||||
match s.len() {
|
||||
64 => {
|
||||
// Try to parse as hex
|
||||
for i in 0..32 {
|
||||
internal[i] = u8::from_str_radix(&s[i * 2..=i * 2 + 1], 16)
|
||||
.map_err(|_| "Illegal character in key")?;
|
||||
}
|
||||
}
|
||||
43 | 44 => {
|
||||
// Try to parse as base64
|
||||
if let Ok(decoded_key) = base64::decode(s) {
|
||||
if decoded_key.len() == internal.len() {
|
||||
internal[..].copy_from_slice(&decoded_key);
|
||||
} else {
|
||||
return Err("Illegal character in key");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => return Err("Illegal key size"),
|
||||
}
|
||||
|
||||
Ok(KeyBytes(internal))
|
||||
}
|
||||
}
|
||||
77
lib/boringtun/src/sleepyinstant/mod.rs
Normal file
77
lib/boringtun/src/sleepyinstant/mod.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
#![forbid(unsafe_code)]
|
||||
//! Attempts to provide the same functionality as std::time::Instant, except it
|
||||
//! uses a timer which accounts for time when the system is asleep
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
mod windows;
|
||||
#[cfg(target_os = "windows")]
|
||||
use windows as inner;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
#[cfg(unix)]
|
||||
use unix as inner;
|
||||
|
||||
/// A measurement of a monotonically nondecreasing clock.
|
||||
/// Opaque and useful only with [`Duration`].
|
||||
///
|
||||
/// Instants are always guaranteed, barring [platform bugs], to be no less than any previously
|
||||
/// measured instant when created, and are often useful for tasks such as measuring
|
||||
/// benchmarks or timing how long an operation takes.
|
||||
///
|
||||
/// Note, however, that instants are **not** guaranteed to be **steady**. In other
|
||||
/// words, each tick of the underlying clock might not be the same length (e.g.
|
||||
/// some seconds may be longer than others). An instant may jump forwards or
|
||||
/// experience time dilation (slow down or speed up), but it will never go
|
||||
/// backwards.
|
||||
///
|
||||
/// Instants are opaque types that can only be compared to one another. There is
|
||||
/// no method to get "the number of seconds" from an instant. Instead, it only
|
||||
/// allows measuring the duration between two instants (or comparing two
|
||||
/// instants).
|
||||
///
|
||||
/// The size of an `Instant` struct may vary depending on the target operating
|
||||
/// system.
|
||||
///
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Instant {
|
||||
t: inner::Instant,
|
||||
}
|
||||
|
||||
impl Instant {
|
||||
/// Returns an instant corresponding to "now".
|
||||
pub fn now() -> Self {
|
||||
Self {
|
||||
t: inner::Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the amount of time elapsed from another instant to this one,
|
||||
/// or zero duration if that instant is later than this one.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// panics when `earlier` was later than `self`.
|
||||
pub fn duration_since(&self, earlier: Instant) -> Duration {
|
||||
self.t.duration_since(earlier.t)
|
||||
}
|
||||
|
||||
/// Returns the amount of time elapsed since this instant was created.
|
||||
pub fn elapsed(&self) -> Duration {
|
||||
Self::now().duration_since(*self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn time_increments_after_sleep() {
|
||||
let sleep_time = Duration::from_millis(10);
|
||||
let start = Instant::now();
|
||||
std::thread::sleep(sleep_time);
|
||||
assert!(start.elapsed() >= sleep_time);
|
||||
}
|
||||
}
|
||||
48
lib/boringtun/src/sleepyinstant/unix.rs
Normal file
48
lib/boringtun/src/sleepyinstant/unix.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use nix::sys::time::TimeSpec;
|
||||
use nix::time::{clock_gettime, ClockId};
|
||||
|
||||
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
|
||||
const CLOCK_ID: ClockId = ClockId::CLOCK_MONOTONIC;
|
||||
#[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "tvos")))]
|
||||
const CLOCK_ID: ClockId = ClockId::CLOCK_BOOTTIME;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) struct Instant {
|
||||
t: TimeSpec,
|
||||
}
|
||||
|
||||
impl Instant {
|
||||
pub(crate) fn now() -> Self {
|
||||
// std::time::Instant unwraps as well, so feel safe doing so here
|
||||
let t = clock_gettime(CLOCK_ID).unwrap();
|
||||
Self { t }
|
||||
}
|
||||
|
||||
fn checked_duration_since(&self, earlier: Instant) -> Option<Duration> {
|
||||
const NANOSECOND: nix::libc::c_long = 1_000_000_000;
|
||||
let (tv_sec, tv_nsec) = if self.t.tv_nsec() < earlier.t.tv_nsec() {
|
||||
(
|
||||
self.t.tv_sec() - earlier.t.tv_sec() - 1,
|
||||
self.t.tv_nsec() - earlier.t.tv_nsec() + NANOSECOND,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
self.t.tv_sec() - earlier.t.tv_sec(),
|
||||
self.t.tv_nsec() - earlier.t.tv_nsec(),
|
||||
)
|
||||
};
|
||||
|
||||
if tv_sec < 0 {
|
||||
None
|
||||
} else {
|
||||
Some(Duration::new(tv_sec as _, tv_nsec as _))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn duration_since(&self, earlier: Instant) -> Duration {
|
||||
self.checked_duration_since(earlier)
|
||||
.unwrap_or(Duration::ZERO)
|
||||
}
|
||||
}
|
||||
1
lib/boringtun/src/sleepyinstant/windows.rs
Normal file
1
lib/boringtun/src/sleepyinstant/windows.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub(crate) use std::time::Instant;
|
||||
106
lib/boringtun/src/wireguard_ffi.h
Normal file
106
lib/boringtun/src/wireguard_ffi.h
Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
struct wireguard_tunnel; // This corresponds to the Rust type
|
||||
|
||||
enum
|
||||
{
|
||||
MAX_WIREGUARD_PACKET_SIZE = 65536 + 64,
|
||||
};
|
||||
|
||||
enum result_type
|
||||
{
|
||||
WIREGUARD_DONE = 0,
|
||||
WRITE_TO_NETWORK = 1,
|
||||
WIREGUARD_ERROR = 2,
|
||||
WRITE_TO_TUNNEL_IPV4 = 4,
|
||||
WRITE_TO_TUNNEL_IPV6 = 6,
|
||||
};
|
||||
|
||||
struct wireguard_result
|
||||
{
|
||||
enum result_type op;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
struct stats
|
||||
{
|
||||
int64_t time_since_last_handshake;
|
||||
size_t tx_bytes;
|
||||
size_t rx_bytes;
|
||||
float estimated_loss;
|
||||
int32_t estimated_rtt; // rtt estimated on time it took to complete latest initiated handshake in ms
|
||||
uint8_t reserved[56]; // decrement appropriately when adding new fields
|
||||
};
|
||||
|
||||
struct x25519_key
|
||||
{
|
||||
uint8_t key[32];
|
||||
};
|
||||
|
||||
// Generates a fresh x25519 secret key
|
||||
struct x25519_key x25519_secret_key();
|
||||
// Computes an x25519 public key from a secret key
|
||||
struct x25519_key x25519_public_key(struct x25519_key private_key);
|
||||
// Encodes a public or private x25519 key to base64. Must be freed with x25519_key_to_str_free.
|
||||
const char *x25519_key_to_base64(struct x25519_key key);
|
||||
// Encodes a public or private x25519 key to hex. Must be freed with x25519_key_to_str_free.
|
||||
const char *x25519_key_to_hex(struct x25519_key key);
|
||||
// Free string pointer obtained from either x25519_key_to_base64 or x25519_key_to_hex
|
||||
void x25519_key_to_str_free(const char *key_str);
|
||||
// Check if a null terminated string represents a valid x25519 key
|
||||
// Returns 0 if not
|
||||
int check_base64_encoded_x25519_key(const char *key);
|
||||
|
||||
/// Sets the default tracing_subscriber to write to `log_func`.
|
||||
///
|
||||
/// Uses Compact format without level, target, thread ids, thread names, or ansi control characters.
|
||||
/// Subscribes to TRACE level events.
|
||||
///
|
||||
/// This function should only be called once as setting the default tracing_subscriber
|
||||
/// more than once will result in an error.
|
||||
///
|
||||
/// Returns false on failure.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `c_char` will be freed by the library after calling `log_func`. If the value needs
|
||||
/// to be stored then `log_func` needs to create a copy, e.g. `strcpy`.
|
||||
bool set_logging_function(void (*log_func)(const char *));
|
||||
|
||||
// Allocate a new tunnel
|
||||
struct wireguard_tunnel *new_tunnel(const char *static_private,
|
||||
const char *server_static_public,
|
||||
const char *preshared_key,
|
||||
uint16_t keep_alive, // Keep alive interval in seconds
|
||||
uint32_t index); // The 24bit index prefix to be used for session indexes
|
||||
|
||||
// Deallocate the tunnel
|
||||
void tunnel_free(struct wireguard_tunnel *);
|
||||
|
||||
struct wireguard_result wireguard_write(const struct wireguard_tunnel *tunnel,
|
||||
const uint8_t *src,
|
||||
uint32_t src_size,
|
||||
uint8_t *dst,
|
||||
uint32_t dst_size);
|
||||
|
||||
struct wireguard_result wireguard_read(const struct wireguard_tunnel *tunnel,
|
||||
const uint8_t *src,
|
||||
uint32_t src_size,
|
||||
uint8_t *dst,
|
||||
uint32_t dst_size);
|
||||
|
||||
struct wireguard_result wireguard_tick(const struct wireguard_tunnel *tunnel,
|
||||
uint8_t *dst,
|
||||
uint32_t dst_size);
|
||||
|
||||
struct wireguard_result wireguard_force_handshake(const struct wireguard_tunnel *tunnel,
|
||||
uint8_t *dst,
|
||||
uint32_t dst_size);
|
||||
|
||||
struct stats wireguard_stats(const struct wireguard_tunnel *tunnel);
|
||||
@@ -24,6 +24,7 @@ message RegistrationRequest {
|
||||
fixed32 virtual_ip = 6;
|
||||
bool allow_ip_change = 7;
|
||||
bool client_secret = 8;
|
||||
bytes client_secret_hash = 9;
|
||||
}
|
||||
|
||||
message RegistrationResponse {
|
||||
@@ -41,6 +42,8 @@ message DeviceInfo {
|
||||
fixed32 virtual_ip = 2;
|
||||
uint32 device_status = 3;
|
||||
bool client_secret = 4;
|
||||
bytes client_secret_hash = 5;
|
||||
bool wireguard = 6;
|
||||
}
|
||||
|
||||
message DeviceList {
|
||||
|
||||
@@ -89,7 +89,7 @@ impl RsaCipher {
|
||||
})
|
||||
}
|
||||
pub fn finger_(public_key_der: &[u8]) -> io::Result<String> {
|
||||
match rsa::pkcs8::SubjectPublicKeyInfo::from_der(public_key_der) {
|
||||
match spki::SubjectPublicKeyInfoOwned::from_der(public_key_der) {
|
||||
Ok(spki) => match spki.fingerprint_base64() {
|
||||
Ok(finger) => Ok(finger),
|
||||
Err(e) => Err(io::Error::new(
|
||||
@@ -120,7 +120,7 @@ impl RsaCipher {
|
||||
match self
|
||||
.inner
|
||||
.private_key
|
||||
.decrypt(rsa::PaddingScheme::PKCS1v15Encrypt, net_packet.payload())
|
||||
.decrypt(rsa::Pkcs1v15Encrypt, net_packet.payload())
|
||||
{
|
||||
Ok(rs) => {
|
||||
let mut nonce_raw = [0; 12];
|
||||
|
||||
@@ -1,10 +1,23 @@
|
||||
use chrono::{DateTime, Local};
|
||||
use std::collections::HashMap;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WireGuardConfig {
|
||||
pub vnts_endpoint: String,
|
||||
pub vnts_allowed_ips: String,
|
||||
pub group_id: String,
|
||||
pub device_id: String,
|
||||
pub ip: Ipv4Addr,
|
||||
pub prefix: u8,
|
||||
pub persistent_keepalive: u16,
|
||||
pub secret_key: [u8; 32],
|
||||
pub public_key: [u8; 32],
|
||||
}
|
||||
/// 网段信息
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Debug)]
|
||||
pub struct NetworkInfo {
|
||||
// 组网编号
|
||||
// pub group: String,
|
||||
@@ -33,6 +46,7 @@ impl NetworkInfo {
|
||||
}
|
||||
|
||||
/// 客户端信息
|
||||
#[derive(Debug)]
|
||||
pub struct ClientInfo {
|
||||
// 设备ID
|
||||
pub device_id: String,
|
||||
@@ -42,6 +56,8 @@ pub struct ClientInfo {
|
||||
pub name: String,
|
||||
// 客户端间是否加密
|
||||
pub client_secret: bool,
|
||||
// 加密hash
|
||||
pub client_secret_hash: Vec<u8>,
|
||||
// 和服务端是否加密
|
||||
pub server_secret: bool,
|
||||
// 链接服务器的来源地址
|
||||
@@ -52,11 +68,51 @@ pub struct ClientInfo {
|
||||
pub virtual_ip: u32,
|
||||
// 建立的tcp连接发送端
|
||||
pub tcp_sender: Option<Sender<Vec<u8>>>,
|
||||
// wireguard客户端公钥
|
||||
pub wireguard: Option<[u8; 32]>,
|
||||
pub wg_sender: Option<Sender<(Vec<u8>, Ipv4Addr)>>,
|
||||
pub client_status: Option<ClientStatusInfo>,
|
||||
pub last_join_time: DateTime<Local>,
|
||||
pub timestamp: i64,
|
||||
}
|
||||
|
||||
/// 客户端简要信息
|
||||
#[derive(Debug)]
|
||||
pub struct SimpleClientInfo {
|
||||
// 分配的ip
|
||||
pub virtual_ip: u32,
|
||||
// 版本
|
||||
pub version: String,
|
||||
// 名称
|
||||
pub name: String,
|
||||
// 客户端间是否加密
|
||||
pub client_secret: bool,
|
||||
// 加密hash
|
||||
pub client_secret_hash: Vec<u8>,
|
||||
// 和服务端是否加密
|
||||
pub server_secret: bool,
|
||||
// 是否在线
|
||||
pub online: bool,
|
||||
// 是wg客户端
|
||||
pub wireguard: bool,
|
||||
}
|
||||
impl From<&ClientInfo> for SimpleClientInfo {
|
||||
fn from(value: &ClientInfo) -> Self {
|
||||
Self {
|
||||
virtual_ip: value.virtual_ip,
|
||||
version: value.version.clone(),
|
||||
name: value.name.clone(),
|
||||
client_secret: value.client_secret,
|
||||
client_secret_hash: if value.online {
|
||||
value.client_secret_hash.clone()
|
||||
} else {
|
||||
vec![]
|
||||
},
|
||||
server_secret: value.server_secret,
|
||||
online: value.online,
|
||||
wireguard: value.wireguard.is_some(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Default for ClientInfo {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -64,18 +120,21 @@ impl Default for ClientInfo {
|
||||
version: "".to_string(),
|
||||
name: "".to_string(),
|
||||
client_secret: false,
|
||||
client_secret_hash: vec![],
|
||||
server_secret: false,
|
||||
address: "0.0.0.0:0".parse().unwrap(),
|
||||
online: false,
|
||||
virtual_ip: 0,
|
||||
tcp_sender: None,
|
||||
wireguard: None,
|
||||
wg_sender: None,
|
||||
client_status: None,
|
||||
last_join_time: Local::now(),
|
||||
timestamp: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientStatusInfo {
|
||||
pub p2p_list: Vec<Ipv4Addr>,
|
||||
pub up_stream: u64,
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||
use tokio::net::{TcpListener, UdpSocket};
|
||||
|
||||
use crate::cipher::RsaCipher;
|
||||
use crate::core::server::wire_guard::WireGuardGroup;
|
||||
use crate::core::service::PacketHandler;
|
||||
use crate::core::store::cache::AppCache;
|
||||
use crate::ConfigInfo;
|
||||
@@ -13,6 +14,7 @@ mod udp;
|
||||
#[cfg(feature = "web")]
|
||||
mod web;
|
||||
mod websocket;
|
||||
mod wire_guard;
|
||||
|
||||
pub async fn start(
|
||||
udp: std::net::UdpSocket,
|
||||
@@ -29,8 +31,9 @@ pub async fn start(
|
||||
rsa_cipher.clone(),
|
||||
udp.clone(),
|
||||
);
|
||||
let wg = WireGuardGroup::new(cache.clone(), config.clone(), udp.clone());
|
||||
let tcp_handle = tokio::spawn(tcp::start(TcpListener::from_std(tcp)?, handler.clone()));
|
||||
let udp_handle = tokio::spawn(udp::start(udp, handler.clone()));
|
||||
let udp_handle = tokio::spawn(udp::start(udp, handler.clone(), wg));
|
||||
#[cfg(not(feature = "web"))]
|
||||
let _ = tokio::try_join!(tcp_handle, udp_handle);
|
||||
#[cfg(feature = "web")]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::core::service::PacketHandler;
|
||||
use crate::core::store::cache::VntContext;
|
||||
use crate::protocol::NetPacket;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
@@ -73,21 +74,29 @@ async fn stream_handle(stream: TcpStream, addr: SocketAddr, handler: PacketHandl
|
||||
let _ = w.shutdown().await;
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = tcp_read(r, addr, sender, handler).await {
|
||||
let mut context = VntContext {
|
||||
link_context: None,
|
||||
server_cipher: None,
|
||||
link_address: addr,
|
||||
};
|
||||
if let Err(e) = tcp_read(&mut context, r, addr, sender, &handler).await {
|
||||
log::warn!("tcp_read {:?}", e)
|
||||
}
|
||||
handler.leave(context).await;
|
||||
});
|
||||
}
|
||||
|
||||
async fn tcp_read(
|
||||
context: &mut VntContext,
|
||||
mut read: OwnedReadHalf,
|
||||
addr: SocketAddr,
|
||||
sender: Sender<Vec<u8>>,
|
||||
handler: PacketHandler,
|
||||
handler: &PacketHandler,
|
||||
) -> io::Result<()> {
|
||||
let mut head = [0; 4];
|
||||
let mut buf = [0; 65536];
|
||||
let sender = Some(sender);
|
||||
|
||||
loop {
|
||||
read.read_exact(&mut head).await?;
|
||||
if head[0] != 0 {
|
||||
@@ -103,7 +112,7 @@ async fn tcp_read(
|
||||
}
|
||||
read.read_exact(&mut buf[..len]).await?;
|
||||
let packet = NetPacket::new0(len, &mut buf)?;
|
||||
if let Some(rs) = handler.handle(packet, addr, &sender).await {
|
||||
if let Some(rs) = handler.handle(context, packet, addr, &sender).await {
|
||||
if sender
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
|
||||
@@ -1,21 +1,91 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use crate::core::server::wire_guard::WireGuardGroup;
|
||||
use crate::core::service::PacketHandler;
|
||||
use crate::core::store::cache::VntContext;
|
||||
use crate::protocol::NetPacket;
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc::{channel, Sender};
|
||||
|
||||
pub async fn start(main_udp: Arc<UdpSocket>, handler: PacketHandler, mut wg: WireGuardGroup) {
|
||||
let mut udp_group = UdpGroup::new(main_udp.clone(), handler);
|
||||
let mut buf = [0u8; 65536];
|
||||
|
||||
pub async fn start(main_udp: Arc<UdpSocket>, handler: PacketHandler) {
|
||||
loop {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
match main_udp.recv_from(&mut buf).await {
|
||||
Ok((len, addr)) => {
|
||||
let handler = handler.clone();
|
||||
let udp = main_udp.clone();
|
||||
tokio::spawn(async move {
|
||||
match NetPacket::new(&mut buf[..len]) {
|
||||
if len == 0 {
|
||||
log::warn!("UnexpectedEof {}", addr);
|
||||
continue;
|
||||
}
|
||||
let buf = buf[..len].to_vec();
|
||||
if WireGuardGroup::maybe_wg(&buf) {
|
||||
// 可能是wg协议
|
||||
wg.handle(buf, addr);
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = udp_group.handle(buf, addr) {
|
||||
log::error!("{} {:?}", addr, e);
|
||||
}
|
||||
}
|
||||
#[cfg(windows)]
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::ConnectionReset => {
|
||||
// 忽略 ConnectionReset 错误
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("{:?}", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpGroup {
|
||||
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
|
||||
udp: Arc<UdpSocket>,
|
||||
handler: PacketHandler,
|
||||
}
|
||||
|
||||
impl UdpGroup {
|
||||
pub fn new(udp: Arc<UdpSocket>, handler: PacketHandler) -> Self {
|
||||
Self {
|
||||
data_channel_map: Default::default(),
|
||||
udp,
|
||||
handler,
|
||||
}
|
||||
}
|
||||
pub fn handle(&mut self, buf: Vec<u8>, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
if let Some(sender) = self.data_channel_map.lock().get(&addr) {
|
||||
sender.try_send(buf)?;
|
||||
return Ok(());
|
||||
}
|
||||
let (udp_sender, mut udp_receiver) = channel(64);
|
||||
udp_sender.try_send(buf)?;
|
||||
let data_channel_map = self.data_channel_map.clone();
|
||||
data_channel_map.lock().insert(addr, udp_sender);
|
||||
let handler = self.handler.clone();
|
||||
let udp = self.udp.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut context = VntContext {
|
||||
link_context: None,
|
||||
server_cipher: None,
|
||||
link_address: addr,
|
||||
};
|
||||
loop {
|
||||
let data = match tokio::time::timeout(Duration::from_secs(60), udp_receiver.recv())
|
||||
.await
|
||||
{
|
||||
Ok(data) => data,
|
||||
Err(_) => break,
|
||||
};
|
||||
if let Some(data) = data {
|
||||
match NetPacket::new(data) {
|
||||
Ok(net_packet) => {
|
||||
if let Some(rs) = handler.handle(net_packet, addr, &None).await {
|
||||
if let Some(rs) =
|
||||
handler.handle(&mut context, net_packet, addr, &None).await
|
||||
{
|
||||
if let Err(e) = udp.send_to(rs.buffer(), addr).await {
|
||||
log::error!("{:?} {}", e, addr)
|
||||
}
|
||||
@@ -25,11 +95,13 @@ pub async fn start(main_udp: Arc<UdpSocket>, handler: PacketHandler) {
|
||||
log::error!("{:?} {}", e, addr)
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("{:?}", e)
|
||||
}
|
||||
}
|
||||
handler.leave(context).await;
|
||||
data_channel_map.lock().remove(&addr);
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use std::net;
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_web::dev::Service;
|
||||
use actix_web::web::Data;
|
||||
@@ -9,7 +8,9 @@ use actix_web::{middleware, post, web, App, HttpRequest, HttpResponse, HttpServe
|
||||
use actix_web_static_files::ResourceFiles;
|
||||
|
||||
use crate::core::server::web::service::VntsWebService;
|
||||
use crate::core::server::web::vo::{LoginData, ResponseMessage};
|
||||
use crate::core::server::web::vo::req::{CreateWGData, LoginData, RemoveClientReq};
|
||||
|
||||
use crate::core::server::web::vo::ResponseMessage;
|
||||
use crate::core::store::cache::AppCache;
|
||||
use crate::ConfigInfo;
|
||||
|
||||
@@ -18,7 +19,7 @@ mod vo;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/generated.rs"));
|
||||
|
||||
#[post("/login")]
|
||||
#[post("/api/login")]
|
||||
async fn login(service: Data<VntsWebService>, data: web::Json<LoginData>) -> HttpResponse {
|
||||
match service.login(data.0).await {
|
||||
Ok(auth) => HttpResponse::Ok().json(ResponseMessage::success(auth)),
|
||||
@@ -26,13 +27,37 @@ async fn login(service: Data<VntsWebService>, data: web::Json<LoginData>) -> Htt
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/group_list")]
|
||||
#[post("/api/group_list")]
|
||||
async fn group_list(_req: HttpRequest, service: Data<VntsWebService>) -> HttpResponse {
|
||||
let info = service.group_list();
|
||||
HttpResponse::Ok().json(ResponseMessage::success(info))
|
||||
}
|
||||
|
||||
#[post("/group_info")]
|
||||
#[post("/api/remove_client")]
|
||||
async fn remove_client(
|
||||
_req: HttpRequest,
|
||||
service: Data<VntsWebService>,
|
||||
data: web::Json<RemoveClientReq>,
|
||||
) -> HttpResponse {
|
||||
service.remove_client(data.0);
|
||||
HttpResponse::Ok().json(ResponseMessage::success("success"))
|
||||
}
|
||||
#[post("/api/private_key")]
|
||||
async fn private_key(_req: HttpRequest, service: Data<VntsWebService>) -> HttpResponse {
|
||||
let private_key = service.gen_wg_private_key();
|
||||
HttpResponse::Ok().json(ResponseMessage::success(private_key))
|
||||
}
|
||||
#[post("/api/create_wg_config")]
|
||||
async fn create_wg_config(
|
||||
_req: HttpRequest,
|
||||
service: Data<VntsWebService>,
|
||||
data: web::Json<CreateWGData>,
|
||||
) -> HttpResponse {
|
||||
match service.create_wg_config(data.0).await {
|
||||
Ok(wg_config) => HttpResponse::Ok().json(ResponseMessage::success(wg_config)),
|
||||
Err(e) => HttpResponse::Ok().json(ResponseMessage::fail(e.to_string())),
|
||||
}
|
||||
}
|
||||
#[post("/api/group_info")]
|
||||
async fn group_info(
|
||||
_req: HttpRequest,
|
||||
service: Data<VntsWebService>,
|
||||
@@ -46,36 +71,19 @@ async fn group_info(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AuthApi {
|
||||
api_set: Arc<HashSet<String>>,
|
||||
}
|
||||
|
||||
fn auth_api_set() -> AuthApi {
|
||||
let mut api_set = HashSet::new();
|
||||
api_set.insert("/group_info".to_string());
|
||||
api_set.insert("/group_list".to_string());
|
||||
AuthApi {
|
||||
api_set: Arc::new(api_set),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start(
|
||||
lst: net::TcpListener,
|
||||
cache: AppCache,
|
||||
config: ConfigInfo,
|
||||
) -> std::io::Result<()> {
|
||||
let web_service = VntsWebService::new(cache, config);
|
||||
let auth_api = auth_api_set();
|
||||
HttpServer::new(move || {
|
||||
let generated = generate();
|
||||
App::new()
|
||||
.app_data(Data::new(web_service.clone()))
|
||||
.app_data(Data::new(auth_api.clone()))
|
||||
.wrap_fn(|request, srv| {
|
||||
let auth_api: &Data<AuthApi> = request.app_data().unwrap();
|
||||
let path = request.path();
|
||||
if path == "/login" || !auth_api.api_set.contains(path) {
|
||||
if path == "/api/login" || !path.contains("/api/") {
|
||||
return srv.call(request);
|
||||
}
|
||||
let service: &Data<VntsWebService> = request.app_data().unwrap();
|
||||
@@ -96,6 +104,9 @@ pub async fn start(
|
||||
})
|
||||
.wrap(middleware::Compress::default())
|
||||
.service(login)
|
||||
.service(remove_client)
|
||||
.service(private_key)
|
||||
.service(create_wg_config)
|
||||
.service(group_list)
|
||||
.service(group_info)
|
||||
.service(ResourceFiles::new("/", generated))
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use std::net::{SocketAddr, SocketAddrV4};
|
||||
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::core::server::web::vo::{
|
||||
ClientInfo, ClientStatusInfo, GroupList, LoginData, NetworkInfo,
|
||||
use anyhow::{anyhow, Context};
|
||||
use base64::engine::general_purpose;
|
||||
use base64::Engine;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use ipnetwork::Ipv4Network;
|
||||
use rsa::rand_core::RngCore;
|
||||
|
||||
use crate::core::entity::WireGuardConfig;
|
||||
|
||||
use crate::core::server::web::vo::req::{CreateWGData, CreateWgConfig, LoginData, RemoveClientReq};
|
||||
use crate::core::server::web::vo::res::{
|
||||
ClientInfo, ClientStatusInfo, GroupList, NetworkInfo, WGData, WgConfig,
|
||||
};
|
||||
use crate::core::service::server::{generate_ip, RegisterClientRequest};
|
||||
use crate::core::store::cache::AppCache;
|
||||
use crate::ConfigInfo;
|
||||
|
||||
@@ -60,6 +71,113 @@ impl VntsWebService {
|
||||
.collect();
|
||||
GroupList { group_list }
|
||||
}
|
||||
pub fn remove_client(&self, req: RemoveClientReq) {
|
||||
if let Some(ip) = req.virtual_ip {
|
||||
if let Some(network_info) = self.cache.virtual_network.get(&req.group_id) {
|
||||
if let Some(client_info) = network_info.write().clients.remove(&ip.into()) {
|
||||
if let Some(key) = client_info.wireguard {
|
||||
self.cache.wg_group_map.remove(&key);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if let Some(network_info) = self.cache.virtual_network.remove(&req.group_id) {
|
||||
for (_, client_info) in network_info.write().clients.drain() {
|
||||
if let Some(key) = client_info.wireguard {
|
||||
self.cache.wg_group_map.remove(&key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn gen_wg_private_key(&self) -> String {
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::thread_rng().fill_bytes(&mut bytes);
|
||||
return general_purpose::STANDARD.encode(bytes);
|
||||
}
|
||||
pub async fn create_wg_config(&self, wg_data: CreateWGData) -> anyhow::Result<WGData> {
|
||||
let device_id = wg_data.device_id.trim().to_string();
|
||||
let group_id = wg_data.group_id.trim().to_string();
|
||||
if group_id.is_empty() {
|
||||
Err(anyhow!("组网id不能为空"))?;
|
||||
}
|
||||
if device_id.is_empty() {
|
||||
Err(anyhow!("设备id不能为空"))?;
|
||||
}
|
||||
let cache = &self.cache;
|
||||
let (secret_key, public_key) = Self::check_wg_config(&wg_data.config)?;
|
||||
let gateway = self.config.gateway;
|
||||
let netmask = self.config.netmask;
|
||||
let network = Ipv4Network::with_netmask(gateway, netmask)?;
|
||||
let network = Ipv4Network::with_netmask(network.network(), netmask)?;
|
||||
let virtual_ip = if wg_data.virtual_ip.trim().is_empty() {
|
||||
Ipv4Addr::UNSPECIFIED
|
||||
} else {
|
||||
Ipv4Addr::from_str(&wg_data.virtual_ip).context("虚拟IP错误")?
|
||||
};
|
||||
let register_client_request = RegisterClientRequest {
|
||||
group_id: group_id.clone(),
|
||||
virtual_ip,
|
||||
gateway,
|
||||
netmask,
|
||||
allow_ip_change: false,
|
||||
device_id: device_id.clone(),
|
||||
version: String::from("wg"),
|
||||
name: wg_data.name.clone(),
|
||||
client_secret: true,
|
||||
client_secret_hash: vec![],
|
||||
server_secret: true,
|
||||
address: SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(),
|
||||
tcp_sender: None,
|
||||
online: false,
|
||||
wireguard: Some(public_key),
|
||||
};
|
||||
let response = generate_ip(cache, register_client_request).await?;
|
||||
let wireguard_config = WireGuardConfig {
|
||||
vnts_endpoint: wg_data.config.vnts_endpoint.clone(),
|
||||
vnts_allowed_ips: network.to_string(),
|
||||
group_id: group_id.clone(),
|
||||
device_id: device_id.clone(),
|
||||
ip: response.virtual_ip,
|
||||
prefix: network.prefix(),
|
||||
persistent_keepalive: wg_data.config.persistent_keepalive,
|
||||
secret_key,
|
||||
public_key,
|
||||
};
|
||||
cache.wg_group_map.insert(public_key, wireguard_config);
|
||||
let config = WgConfig {
|
||||
vnts_endpoint: wg_data.config.vnts_endpoint,
|
||||
vnts_public_key: general_purpose::STANDARD.encode(&self.config.wg_public_key),
|
||||
vnts_allowed_ips: network.to_string(),
|
||||
public_key: general_purpose::STANDARD.encode(public_key),
|
||||
private_key: general_purpose::STANDARD.encode(secret_key),
|
||||
ip: response.virtual_ip,
|
||||
prefix: network.prefix(),
|
||||
persistent_keepalive: wg_data.config.persistent_keepalive,
|
||||
};
|
||||
let wg_data = WGData {
|
||||
group_id,
|
||||
virtual_ip: response.virtual_ip,
|
||||
device_id,
|
||||
name: wg_data.name,
|
||||
config,
|
||||
};
|
||||
Ok(wg_data)
|
||||
}
|
||||
fn check_wg_config(config: &CreateWgConfig) -> anyhow::Result<([u8; 32], [u8; 32])> {
|
||||
let addr = SocketAddr::from_str(&config.vnts_endpoint).context("服务器地址错误")?;
|
||||
if addr.ip().is_unspecified() || addr.port() == 0 {
|
||||
Err(anyhow!("服务端地址错误"))?
|
||||
}
|
||||
let private_key = general_purpose::STANDARD
|
||||
.decode(&config.private_key)
|
||||
.context("私钥错误")?;
|
||||
let private_key: [u8; 32] = private_key.try_into().map_err(|_| anyhow!("私钥错误"))?;
|
||||
let secret_key = boringtun::x25519::StaticSecret::from(private_key);
|
||||
let public_key = *boringtun::x25519::PublicKey::from(&secret_key).as_bytes();
|
||||
|
||||
Ok((private_key, public_key))
|
||||
}
|
||||
pub fn group_info(&self, group: String) -> Option<NetworkInfo> {
|
||||
if let Some(info) = self.cache.virtual_network.get(&group) {
|
||||
let guard = info.read();
|
||||
@@ -67,19 +185,20 @@ impl VntsWebService {
|
||||
guard.network_ip.into(),
|
||||
guard.mask_ip.into(),
|
||||
guard.gateway_ip.into(),
|
||||
general_purpose::STANDARD.encode(&self.config.wg_public_key),
|
||||
);
|
||||
for into in guard.clients.values() {
|
||||
let address = match into.address {
|
||||
SocketAddr::V4(_) => into.address,
|
||||
for info in guard.clients.values() {
|
||||
let address = match info.address {
|
||||
SocketAddr::V4(_) => info.address,
|
||||
SocketAddr::V6(ipv6) => {
|
||||
if let Some(ipv4) = ipv6.ip().to_ipv4_mapped() {
|
||||
SocketAddr::V4(SocketAddrV4::new(ipv4, ipv6.port()))
|
||||
} else {
|
||||
into.address
|
||||
info.address
|
||||
}
|
||||
}
|
||||
};
|
||||
let status_info = if let Some(client_status) = &into.client_status {
|
||||
let status_info = if let Some(client_status) = &info.client_status {
|
||||
Some(ClientStatusInfo {
|
||||
p2p_list: client_status.p2p_list.clone(),
|
||||
up_stream: client_status.up_stream,
|
||||
@@ -93,18 +212,24 @@ impl VntsWebService {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut wg_config = None;
|
||||
if let Some(key) = &info.wireguard {
|
||||
if let Some(v) = self.cache.wg_group_map.get(key) {
|
||||
wg_config.replace(v.clone());
|
||||
}
|
||||
}
|
||||
let client_info = ClientInfo {
|
||||
device_id: into.device_id.clone(),
|
||||
version: into.version.clone(),
|
||||
name: into.name.clone(),
|
||||
client_secret: into.client_secret,
|
||||
server_secret: into.server_secret,
|
||||
device_id: info.device_id.clone(),
|
||||
version: info.version.clone(),
|
||||
name: info.name.clone(),
|
||||
client_secret: info.client_secret,
|
||||
server_secret: info.server_secret,
|
||||
address,
|
||||
online: into.online,
|
||||
virtual_ip: into.virtual_ip.into(),
|
||||
online: info.online,
|
||||
virtual_ip: info.virtual_ip.into(),
|
||||
status_info,
|
||||
last_join_time: into.last_join_time.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||
last_join_time: info.last_join_time.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||
wg_config: wg_config.map(|v| v.into()),
|
||||
};
|
||||
network.clients.push(client_info);
|
||||
}
|
||||
@@ -116,32 +241,4 @@ impl VntsWebService {
|
||||
None
|
||||
}
|
||||
}
|
||||
// pub fn groups_info(&self) -> GroupsInfo {
|
||||
// let mut data = GroupsInfo::new();
|
||||
// for (group, info) in self.cache.virtual_network.key_values() {
|
||||
// let guard = info.read();
|
||||
// let mut network = NetworkInfo::new(
|
||||
// guard.network_ip.into(),
|
||||
// guard.mask_ip.into(),
|
||||
// guard.gateway_ip.into(),
|
||||
// );
|
||||
// for (_ip, into) in guard.clients.iter() {
|
||||
// let client_info = ClientInfo {
|
||||
// device_id: into.device_id.clone(),
|
||||
// name: into.name.clone(),
|
||||
// client_secret: into.client_secret,
|
||||
// server_secret: into.server_secret.is_some(),
|
||||
// address: into.address,
|
||||
// online: into.online,
|
||||
// virtual_ip: into.virtual_ip.into(),
|
||||
// };
|
||||
// network.clients.push(client_info);
|
||||
// }
|
||||
// network
|
||||
// .clients
|
||||
// .sort_by(|v1, v2| v1.virtual_ip.cmp(&v2.virtual_ip));
|
||||
// data.data.insert(group.to_string(), network);
|
||||
// }
|
||||
// data
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub mod req;
|
||||
pub mod res;
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ResponseMessage<V> {
|
||||
data: V,
|
||||
@@ -38,73 +37,3 @@ impl ResponseMessage<Option<()>> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ClientInfo {
|
||||
// 设备ID
|
||||
pub device_id: String,
|
||||
// 客户端版本
|
||||
pub version: String,
|
||||
// 名称
|
||||
pub name: String,
|
||||
// 客户端间是否加密
|
||||
pub client_secret: bool,
|
||||
// 客户端和服务端是否加密
|
||||
pub server_secret: bool,
|
||||
// 链接服务器的来源地址
|
||||
pub address: SocketAddr,
|
||||
// 是否在线
|
||||
pub online: bool,
|
||||
// 分配的ip
|
||||
pub virtual_ip: Ipv4Addr,
|
||||
pub status_info: Option<ClientStatusInfo>,
|
||||
pub last_join_time: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ClientStatusInfo {
|
||||
pub p2p_list: Vec<Ipv4Addr>,
|
||||
pub up_stream: u64,
|
||||
pub down_stream: u64,
|
||||
pub is_cone: bool,
|
||||
pub update_time: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct NetworkInfo {
|
||||
// 网段
|
||||
pub network_ip: Ipv4Addr,
|
||||
// 掩码
|
||||
pub mask_ip: Ipv4Addr,
|
||||
// 网关
|
||||
pub gateway_ip: Ipv4Addr,
|
||||
// 网段下的客户端列表
|
||||
pub clients: Vec<ClientInfo>,
|
||||
}
|
||||
|
||||
impl NetworkInfo {
|
||||
pub fn new(network_ip: Ipv4Addr, mask_ip: Ipv4Addr, gateway_ip: Ipv4Addr) -> Self {
|
||||
Self {
|
||||
network_ip,
|
||||
mask_ip,
|
||||
gateway_ip,
|
||||
clients: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GroupList {
|
||||
pub group_list: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GroupsInfo {
|
||||
pub data: HashMap<String, NetworkInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginData {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
29
src/core/server/web/vo/req.rs
Normal file
29
src/core/server/web/vo/req.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CreateWGData {
|
||||
pub group_id: String,
|
||||
pub virtual_ip: String,
|
||||
pub device_id: String,
|
||||
pub name: String,
|
||||
pub config: CreateWgConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CreateWgConfig {
|
||||
pub vnts_endpoint: String,
|
||||
pub private_key: String,
|
||||
pub persistent_keepalive: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginData {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RemoveClientReq {
|
||||
pub group_id: String,
|
||||
pub virtual_ip: Option<Ipv4Addr>,
|
||||
}
|
||||
123
src/core/server/web/vo/res.rs
Normal file
123
src/core/server/web/vo/res.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use crate::core::entity::WireGuardConfig;
|
||||
use base64::engine::general_purpose;
|
||||
use base64::Engine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct WGData {
|
||||
pub group_id: String,
|
||||
pub virtual_ip: Ipv4Addr,
|
||||
pub device_id: String,
|
||||
pub name: String,
|
||||
pub config: WgConfig,
|
||||
}
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct WgConfig {
|
||||
pub vnts_endpoint: String,
|
||||
pub vnts_public_key: String,
|
||||
pub vnts_allowed_ips: String,
|
||||
|
||||
pub public_key: String,
|
||||
pub private_key: String,
|
||||
// 合一起是 Address = ip/prefix
|
||||
pub ip: Ipv4Addr,
|
||||
pub prefix: u8,
|
||||
pub persistent_keepalive: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ClientInfo {
|
||||
// 设备ID
|
||||
pub device_id: String,
|
||||
// 客户端版本
|
||||
pub version: String,
|
||||
// 名称
|
||||
pub name: String,
|
||||
// 客户端间是否加密
|
||||
pub client_secret: bool,
|
||||
// 客户端和服务端是否加密
|
||||
pub server_secret: bool,
|
||||
// 链接服务器的来源地址
|
||||
pub address: SocketAddr,
|
||||
// 是否在线
|
||||
pub online: bool,
|
||||
// 分配的ip
|
||||
pub virtual_ip: Ipv4Addr,
|
||||
pub status_info: Option<ClientStatusInfo>,
|
||||
pub last_join_time: String,
|
||||
// wg配置
|
||||
pub wg_config: Option<WireGuardConfigRes>,
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct WireGuardConfigRes {
|
||||
pub vnts_endpoint: String,
|
||||
pub vnts_allowed_ips: String,
|
||||
pub group_id: String,
|
||||
pub device_id: String,
|
||||
pub ip: Ipv4Addr,
|
||||
pub prefix: u8,
|
||||
pub persistent_keepalive: u16,
|
||||
pub secret_key: String,
|
||||
pub public_key: String,
|
||||
}
|
||||
impl From<WireGuardConfig> for WireGuardConfigRes {
|
||||
fn from(value: WireGuardConfig) -> Self {
|
||||
Self {
|
||||
vnts_endpoint: value.vnts_endpoint,
|
||||
vnts_allowed_ips: value.vnts_allowed_ips,
|
||||
group_id: value.group_id,
|
||||
device_id: value.device_id,
|
||||
ip: value.ip,
|
||||
prefix: value.prefix,
|
||||
persistent_keepalive: value.persistent_keepalive,
|
||||
secret_key: general_purpose::STANDARD.encode(&value.secret_key),
|
||||
public_key: general_purpose::STANDARD.encode(&value.public_key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ClientStatusInfo {
|
||||
pub p2p_list: Vec<Ipv4Addr>,
|
||||
pub up_stream: u64,
|
||||
pub down_stream: u64,
|
||||
pub is_cone: bool,
|
||||
pub update_time: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct NetworkInfo {
|
||||
// 网段
|
||||
pub network_ip: Ipv4Addr,
|
||||
// 掩码
|
||||
pub mask_ip: Ipv4Addr,
|
||||
// 网关
|
||||
pub gateway_ip: Ipv4Addr,
|
||||
// vnts的公钥
|
||||
pub vnts_public_key: String,
|
||||
// 网段下的客户端列表
|
||||
pub clients: Vec<ClientInfo>,
|
||||
}
|
||||
|
||||
impl NetworkInfo {
|
||||
pub fn new(
|
||||
network_ip: Ipv4Addr,
|
||||
mask_ip: Ipv4Addr,
|
||||
gateway_ip: Ipv4Addr,
|
||||
vnts_public_key: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
network_ip,
|
||||
mask_ip,
|
||||
gateway_ip,
|
||||
vnts_public_key,
|
||||
clients: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GroupList {
|
||||
pub group_list: Vec<String>,
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::core::service::PacketHandler;
|
||||
use crate::core::store::cache::VntContext;
|
||||
use crate::protocol::NetPacket;
|
||||
use anyhow::Context;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
@@ -14,16 +15,23 @@ pub async fn handle_websocket_connection(
|
||||
handler: PacketHandler,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_websocket_connection0(stream, addr, handler).await {
|
||||
let mut context = VntContext {
|
||||
link_context: None,
|
||||
server_cipher: None,
|
||||
link_address: addr,
|
||||
};
|
||||
if let Err(e) = handle_websocket_connection0(&mut context, stream, addr, &handler).await {
|
||||
log::warn!("websocket err {:?} {}", e, addr);
|
||||
}
|
||||
handler.leave(context).await;
|
||||
});
|
||||
}
|
||||
|
||||
async fn handle_websocket_connection0(
|
||||
context: &mut VntContext,
|
||||
stream: TcpStream,
|
||||
addr: SocketAddr,
|
||||
handler: PacketHandler,
|
||||
handler: &PacketHandler,
|
||||
) -> anyhow::Result<()> {
|
||||
let ws_stream = accept_async(stream)
|
||||
.await
|
||||
@@ -42,13 +50,14 @@ async fn handle_websocket_connection0(
|
||||
let _ = ws_write.close().await;
|
||||
});
|
||||
let sender = Some(sender);
|
||||
|
||||
while let Some(msg) = ws_read.next().await {
|
||||
let msg = msg.with_context(|| format!("Error during WebSocket read {}", addr))?;
|
||||
match msg {
|
||||
Message::Text(txt) => log::info!("Received text message: {} {}", txt, addr),
|
||||
Message::Binary(mut data) => {
|
||||
let packet = NetPacket::new0(data.len(), &mut data)?;
|
||||
if let Some(rs) = handler.handle(packet, addr, &sender).await {
|
||||
if let Some(rs) = handler.handle(context, packet, addr, &sender).await {
|
||||
if sender
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
|
||||
470
src/core/server/wire_guard/mod.rs
Normal file
470
src/core/server/wire_guard/mod.rs
Normal file
@@ -0,0 +1,470 @@
|
||||
use crate::core::entity::{NetworkInfo, WireGuardConfig};
|
||||
use crate::core::store::cache::AppCache;
|
||||
use crate::protocol::{ip_turn_packet, NetPacket, Protocol, HEAD_LEN, MAX_TTL};
|
||||
use crate::ConfigInfo;
|
||||
use anyhow::{anyhow, Context};
|
||||
use boringtun::noise::errors::WireGuardError;
|
||||
use boringtun::noise::{handshake, Packet, Tunn, TunnResult};
|
||||
use boringtun::x25519::StaticSecret;
|
||||
use chrono::Local;
|
||||
use packet::icmp::{icmp, Kind};
|
||||
use packet::ip::ipv4;
|
||||
use packet::ip::ipv4::packet::IpV4Packet;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use rand::RngCore;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
|
||||
pub struct WireGuardGroup {
|
||||
cache: AppCache,
|
||||
config: ConfigInfo,
|
||||
udp: Arc<UdpSocket>,
|
||||
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
|
||||
}
|
||||
|
||||
impl WireGuardGroup {
|
||||
pub fn new(cache: AppCache, config: ConfigInfo, udp: Arc<UdpSocket>) -> Self {
|
||||
Self {
|
||||
cache,
|
||||
config,
|
||||
udp,
|
||||
data_channel_map: Default::default(),
|
||||
}
|
||||
}
|
||||
pub fn handle(&mut self, buf: Vec<u8>, addr: SocketAddr) {
|
||||
if let Err(e) = self.handle0(buf, addr) {
|
||||
log::warn!("{},{}", addr, e);
|
||||
}
|
||||
}
|
||||
fn handle0(&mut self, buf: Vec<u8>, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
if let Some(sender) = self.data_channel_map.lock().get(&addr) {
|
||||
sender.try_send(buf)?;
|
||||
return Ok(());
|
||||
}
|
||||
let config = self.handshake(&buf)?;
|
||||
let network_info = self
|
||||
.cache
|
||||
.virtual_network
|
||||
.get(&config.group_id)
|
||||
.context("wg配置已过期")?;
|
||||
let (network_receiver, broadcast_ip, mask_ip, gateway_ip) = {
|
||||
let mut guard = network_info.write();
|
||||
let broadcast_ip = guard.network_ip | (!guard.mask_ip);
|
||||
|
||||
let client_info = guard
|
||||
.clients
|
||||
.get_mut(&config.ip.into())
|
||||
.context("wg配置已过期")?;
|
||||
if client_info.wireguard.is_none() {
|
||||
Err(anyhow!("不是wg配置"))?;
|
||||
}
|
||||
let (network_sender, network_receiver) = channel(64);
|
||||
client_info.wg_sender = Some(network_sender);
|
||||
client_info.last_join_time = Local::now();
|
||||
client_info.timestamp = client_info.last_join_time.timestamp();
|
||||
client_info.address = addr;
|
||||
client_info.online = true;
|
||||
guard.epoch += 1;
|
||||
(
|
||||
network_receiver,
|
||||
broadcast_ip,
|
||||
guard.mask_ip,
|
||||
guard.gateway_ip,
|
||||
)
|
||||
};
|
||||
let wg = WireGuard::new(
|
||||
network_info.clone(),
|
||||
broadcast_ip.into(),
|
||||
mask_ip.into(),
|
||||
gateway_ip.into(),
|
||||
self.cache.clone(),
|
||||
self.config.wg_secret_key.clone(),
|
||||
self.udp.clone(),
|
||||
addr,
|
||||
config,
|
||||
self.data_channel_map.clone(),
|
||||
);
|
||||
let (udp_sender, udp_receiver) = channel(64);
|
||||
udp_sender.try_send(buf)?;
|
||||
self.data_channel_map.lock().insert(addr, udp_sender);
|
||||
tokio::spawn(wg.start(udp_receiver, network_receiver));
|
||||
Ok(())
|
||||
}
|
||||
#[inline]
|
||||
pub fn maybe_wg(buf: &[u8]) -> bool {
|
||||
if buf.len() < 4 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Checks the type, as well as the reserved zero fields
|
||||
let packet_type = u32::from_le_bytes(buf[0..4].try_into().unwrap());
|
||||
(1..=4).contains(&packet_type)
|
||||
}
|
||||
pub fn handshake(&mut self, buf: &[u8]) -> anyhow::Result<WireGuardConfig> {
|
||||
let packet = match Tunn::parse_incoming_packet(buf) {
|
||||
Ok(packet) => packet,
|
||||
Err(e) => Err(anyhow!("{:?}", e))?,
|
||||
};
|
||||
match packet {
|
||||
Packet::HandshakeInit(data) => {
|
||||
let half_handshake = handshake::parse_handshake_anon(
|
||||
&self.config.wg_secret_key,
|
||||
&self.config.wg_public_key,
|
||||
&data,
|
||||
)
|
||||
.map_err(|e| anyhow!("HandshakeInit {:?}", e))?;
|
||||
let config = self
|
||||
.cache
|
||||
.wg_group_map
|
||||
.get(&half_handshake.peer_static_public)
|
||||
.context("需要先在vnts配置wg信息")?
|
||||
.clone();
|
||||
Ok(config)
|
||||
}
|
||||
_ => Err(anyhow!("非握手包")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WireGuard {
|
||||
network_info: Arc<RwLock<NetworkInfo>>,
|
||||
ip: Ipv4Addr,
|
||||
broadcast_ip: Ipv4Addr,
|
||||
mask_ip: Ipv4Addr,
|
||||
gateway_ip: Ipv4Addr,
|
||||
|
||||
group_id: String,
|
||||
tunn: Tunn,
|
||||
cache: AppCache,
|
||||
wg_source_addr: SocketAddr,
|
||||
udp: Arc<UdpSocket>,
|
||||
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
|
||||
}
|
||||
|
||||
impl WireGuard {
|
||||
pub fn new(
|
||||
network_info: Arc<RwLock<NetworkInfo>>,
|
||||
broadcast_ip: Ipv4Addr,
|
||||
mask_ip: Ipv4Addr,
|
||||
gateway_ip: Ipv4Addr,
|
||||
cache: AppCache,
|
||||
vnts_secret_key: StaticSecret,
|
||||
udp: Arc<UdpSocket>,
|
||||
wg_source_addr: SocketAddr,
|
||||
config: WireGuardConfig,
|
||||
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
|
||||
) -> Self {
|
||||
let tunn = Tunn::new(
|
||||
vnts_secret_key,
|
||||
config.public_key.into(),
|
||||
None,
|
||||
Some(config.persistent_keepalive),
|
||||
rand::thread_rng().next_u32(),
|
||||
None,
|
||||
);
|
||||
Self {
|
||||
network_info,
|
||||
ip: config.ip,
|
||||
broadcast_ip,
|
||||
mask_ip,
|
||||
gateway_ip,
|
||||
group_id: config.group_id,
|
||||
tunn,
|
||||
cache,
|
||||
wg_source_addr,
|
||||
udp,
|
||||
data_channel_map,
|
||||
}
|
||||
}
|
||||
pub async fn start(
|
||||
mut self,
|
||||
udp_receiver: Receiver<Vec<u8>>,
|
||||
ipv4_receiver: Receiver<(Vec<u8>, Ipv4Addr)>,
|
||||
) {
|
||||
if let Err(e) = self.start0(udp_receiver, ipv4_receiver).await {
|
||||
log::warn!(
|
||||
"wg连接异常断开 {:?},{:?},{:?},{:?}",
|
||||
self.group_id,
|
||||
self.ip,
|
||||
self.wg_source_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
self.offline();
|
||||
}
|
||||
fn offline(&self) {
|
||||
if let Some(v) = self.cache.virtual_network.get(&self.group_id) {
|
||||
if let Some(v) = v.write().clients.get_mut(&self.ip.into()) {
|
||||
if v.address == self.wg_source_addr {
|
||||
v.online = false;
|
||||
v.wg_sender = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
self.data_channel_map.lock().remove(&self.wg_source_addr);
|
||||
}
|
||||
pub async fn start0(
|
||||
&mut self,
|
||||
mut udp_receiver: Receiver<Vec<u8>>,
|
||||
mut ipv4_receiver: Receiver<(Vec<u8>, Ipv4Addr)>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut interval = tokio::time::interval(Duration::from_millis(200));
|
||||
let mut dst_buf = [0; 65535];
|
||||
let mut dst_buf2 = [0; 65535];
|
||||
log::info!(
|
||||
"处理wg链接 {},{}/{},{}",
|
||||
self.group_id,
|
||||
self.ip,
|
||||
self.mask_ip,
|
||||
self.wg_source_addr
|
||||
);
|
||||
loop {
|
||||
tokio::select! {
|
||||
rs = udp_receiver.recv()=>{
|
||||
if let Some(mut data) = rs{
|
||||
self.handle_wg_data(&mut data,&mut dst_buf,&mut dst_buf2).await?;
|
||||
}else{
|
||||
break;
|
||||
}
|
||||
}
|
||||
rs = ipv4_receiver.recv()=>{
|
||||
if let Some((data,ip)) = rs{
|
||||
if let Err(e) = self.handle_ipv4_data(&data,&mut dst_buf).await{
|
||||
log::warn!("来源{},发送到wg失败,{:?}",ip,e)
|
||||
}
|
||||
}else{
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = interval.tick()=>{
|
||||
self.update_timers(&mut dst_buf,&mut dst_buf2).await?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub async fn handle_ipv4_data(&mut self, buf: &[u8], dst_buf: &mut [u8]) -> anyhow::Result<()> {
|
||||
let result = self.tunn.encapsulate(buf, dst_buf);
|
||||
match result {
|
||||
TunnResult::Done => {}
|
||||
TunnResult::WriteToNetwork(data) => {
|
||||
self.udp.send_to(data, self.wg_source_addr).await?;
|
||||
}
|
||||
e => Err(anyhow!("{:?}", e))?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_wg_data(
|
||||
&mut self,
|
||||
mut buf: &mut [u8],
|
||||
dst_buf: &mut [u8],
|
||||
dst_buf2: &mut [u8],
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let mut result = self.tunn.decapsulate(None, buf, dst_buf);
|
||||
if !self.handle_tunn_result(&mut result, dst_buf2).await? {
|
||||
break;
|
||||
}
|
||||
buf = &mut [];
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
async fn handle_tunn_result(
|
||||
&mut self,
|
||||
result: &mut TunnResult<'_>,
|
||||
dst_buf: &mut [u8],
|
||||
) -> anyhow::Result<bool> {
|
||||
match result {
|
||||
TunnResult::Done => {}
|
||||
TunnResult::Err(WireGuardError::ConnectionExpired) => {
|
||||
// 超时了直接断开,vnts不重连,等对端重连
|
||||
return Err(anyhow!("链接超时"));
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
log::warn!("WireGuard数据异常 {:?}", e);
|
||||
}
|
||||
TunnResult::WriteToNetwork(data) => {
|
||||
self.udp.send_to(data, self.wg_source_addr).await?;
|
||||
return Ok(true);
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(data, _source_ip) => {
|
||||
let mut packet = IpV4Packet::new(data)?;
|
||||
let source_ip = packet.source_ip();
|
||||
let destination_ip = packet.destination_ip();
|
||||
if let Err(e) = self
|
||||
.turn_data(source_ip, destination_ip, &mut packet.buffer, dst_buf)
|
||||
.await
|
||||
{
|
||||
log::warn!("wg数据转发失败 {}->{} {:?}", source_ip, destination_ip, e);
|
||||
}
|
||||
}
|
||||
TunnResult::WriteToTunnelV6(_packet, ip) => {
|
||||
return Err(anyhow!("不支持ipv6连接 {:?}", ip))
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
/// from 'wireguard_tick':
|
||||
/// This is a state keeping function, that need to be called periodically.
|
||||
/// Recommended interval: 100ms.
|
||||
pub async fn update_timers(
|
||||
&mut self,
|
||||
dst_buf: &mut [u8],
|
||||
dst_buf2: &mut [u8],
|
||||
) -> anyhow::Result<()> {
|
||||
let mut result = self.tunn.update_timers(dst_buf);
|
||||
self.handle_tunn_result(&mut result, dst_buf2).await?;
|
||||
Ok(())
|
||||
}
|
||||
async fn turn_data(
|
||||
&mut self,
|
||||
src_ip: Ipv4Addr,
|
||||
dest_ip: Ipv4Addr,
|
||||
data: &mut [u8],
|
||||
dst_buf: &mut [u8],
|
||||
) -> anyhow::Result<()> {
|
||||
if dest_ip == self.gateway_ip {
|
||||
if self.ping(data, src_ip, dest_ip).is_ok() {
|
||||
if let Err(e) = self.handle_ipv4_data(&data, dst_buf).await {
|
||||
log::warn!("发送ping回应到wg失败,{:?}", e)
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
if dest_ip.is_broadcast() || dest_ip == self.broadcast_ip {
|
||||
// 广播
|
||||
let x: Vec<_> = self
|
||||
.network_info
|
||||
.read()
|
||||
.clients
|
||||
.values()
|
||||
.filter(|v| v.online && v.virtual_ip != u32::from(self.ip))
|
||||
.map(|v| {
|
||||
(
|
||||
v.address,
|
||||
v.tcp_sender.clone(),
|
||||
v.server_secret,
|
||||
v.wg_sender.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
for (peer_addr, peer_tcp_sender, server_secret, peer_wg_sender) in x {
|
||||
if let Err(e) = self
|
||||
.send_one(
|
||||
peer_addr,
|
||||
peer_tcp_sender,
|
||||
peer_wg_sender,
|
||||
server_secret,
|
||||
src_ip,
|
||||
dest_ip,
|
||||
data,
|
||||
dst_buf,
|
||||
)
|
||||
.await
|
||||
{
|
||||
log::warn!("wg广播失败 {} {} {:?}", src_ip, peer_addr, e);
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (server_secret, peer_addr, peer_tcp_sender, peer_wg_sender) = {
|
||||
let guard = self.network_info.read();
|
||||
if let Some(dest_client_info) = guard.clients.get(&dest_ip.into()) {
|
||||
if !dest_client_info.online {
|
||||
Err(anyhow!("目标不在线"))?
|
||||
}
|
||||
if !dest_client_info.virtual_ip == u32::from(self.ip) {
|
||||
Err(anyhow!("阻止回路"))?
|
||||
}
|
||||
let dest_link_addr = dest_client_info.address;
|
||||
let server_secret = dest_client_info.server_secret;
|
||||
(
|
||||
server_secret,
|
||||
dest_link_addr,
|
||||
dest_client_info.tcp_sender.clone(),
|
||||
dest_client_info.wg_sender.clone(),
|
||||
)
|
||||
} else {
|
||||
Err(anyhow!("目标未注册"))?
|
||||
}
|
||||
};
|
||||
|
||||
self.send_one(
|
||||
peer_addr,
|
||||
peer_tcp_sender,
|
||||
peer_wg_sender,
|
||||
server_secret,
|
||||
src_ip,
|
||||
dest_ip,
|
||||
data,
|
||||
dst_buf,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
async fn send_one(
|
||||
&self,
|
||||
peer_addr: SocketAddr,
|
||||
peer_tcp_sender: Option<Sender<Vec<u8>>>,
|
||||
peer_wg_sender: Option<Sender<(Vec<u8>, Ipv4Addr)>>,
|
||||
server_secret: bool,
|
||||
src_ip: Ipv4Addr,
|
||||
dest_ip: Ipv4Addr,
|
||||
data: &mut [u8],
|
||||
dst_buf: &mut [u8],
|
||||
) -> anyhow::Result<()> {
|
||||
if let Some(peer_wg_sender) = peer_wg_sender {
|
||||
if let Err(e) = peer_wg_sender.send((data.to_vec(), self.ip)).await {
|
||||
Err(anyhow!("发送到对端wg失败 {}", e))?
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
let mut net_packet = NetPacket::new0(HEAD_LEN + data.len(), dst_buf)?;
|
||||
net_packet.set_default_version();
|
||||
// 把wg的转发当成是服务端来源的数据,因为服务端没有客户端密钥对数据进行加密
|
||||
net_packet.set_gateway_flag(true);
|
||||
net_packet.set_protocol(Protocol::IpTurn);
|
||||
net_packet.set_transport_protocol(ip_turn_packet::Protocol::WGIpv4.into());
|
||||
net_packet.first_set_ttl(MAX_TTL);
|
||||
net_packet.set_source(src_ip);
|
||||
net_packet.set_destination(dest_ip);
|
||||
net_packet.set_payload(data)?;
|
||||
if server_secret {
|
||||
let cipher = self
|
||||
.cache
|
||||
.cipher_session
|
||||
.get(&peer_addr)
|
||||
.context("加密信息不存在")?;
|
||||
cipher.encrypt_ipv4(&mut net_packet)?;
|
||||
}
|
||||
if let Some(tcp_sender) = peer_tcp_sender {
|
||||
tcp_sender.send(net_packet.buffer().to_vec()).await?;
|
||||
} else {
|
||||
self.udp.send_to(net_packet.buffer(), peer_addr).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn ping(&self, data: &mut [u8], src_ip: Ipv4Addr, dest_ip: Ipv4Addr) -> anyhow::Result<()> {
|
||||
let mut ipv4 = IpV4Packet::new(data)?;
|
||||
if let ipv4::protocol::Protocol::Icmp = ipv4.protocol() {
|
||||
let mut icmp_packet = icmp::IcmpPacket::new(ipv4.payload_mut())?;
|
||||
if icmp_packet.kind() == Kind::EchoRequest {
|
||||
//开启ping
|
||||
icmp_packet.set_kind(Kind::EchoReply);
|
||||
icmp_packet.update_checksum();
|
||||
ipv4.set_source_ip(dest_ip);
|
||||
ipv4.set_destination_ip(src_ip);
|
||||
ipv4.update_checksum();
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(anyhow!("非ping Echo 不处理"))
|
||||
}
|
||||
}
|
||||
@@ -3,14 +3,13 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use crate::cipher::RsaCipher;
|
||||
use crate::core::entity::ClientInfo;
|
||||
use crate::core::store::cache::{AppCache, Context};
|
||||
use crate::core::store::cache::{AppCache, LinkVntContext, VntContext};
|
||||
use crate::error::*;
|
||||
use crate::protocol::NetPacket;
|
||||
use crate::ConfigInfo;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClientPacketHandler {
|
||||
@@ -37,25 +36,26 @@ impl ClientPacketHandler {
|
||||
}
|
||||
|
||||
impl ClientPacketHandler {
|
||||
pub fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
pub async fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
&self,
|
||||
context: &VntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
_addr: SocketAddr,
|
||||
) -> Result<()> {
|
||||
if let Some(context) = self.cache.get_context(&addr) {
|
||||
self.handle0(net_packet, context)
|
||||
if let Some(context) = &context.link_context {
|
||||
self.handle0(context, net_packet).await
|
||||
} else {
|
||||
Err(Error::Disconnect)
|
||||
Err(Error::Disconnect)?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientPacketHandler {
|
||||
/// 转发到目标地址
|
||||
fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
async fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
&self,
|
||||
context: &LinkVntContext,
|
||||
mut net_packet: NetPacket<B>,
|
||||
context: Context,
|
||||
) -> Result<()> {
|
||||
if net_packet.incr_ttl() > 1 {
|
||||
if self.config.check_finger {
|
||||
@@ -65,33 +65,65 @@ impl ClientPacketHandler {
|
||||
let destination = net_packet.destination();
|
||||
if destination.is_broadcast() || self.config.broadcast == destination {
|
||||
//处理广播
|
||||
broadcast(&self.udp, context, net_packet);
|
||||
} else if let Some(client_info) =
|
||||
context.network_info.read().clients.get(&destination.into())
|
||||
{
|
||||
send_one(&self.udp, client_info, &net_packet);
|
||||
broadcast(context, &self.udp, net_packet).await;
|
||||
} else {
|
||||
let is_encrypt = net_packet.is_encrypt();
|
||||
let source_ip = u32::from(net_packet.source());
|
||||
let rs = context
|
||||
.network_info
|
||||
.read()
|
||||
.clients
|
||||
.get(&destination.into())
|
||||
.filter(|v| {
|
||||
v.wireguard.is_none()
|
||||
&& v.online
|
||||
&& v.client_secret == is_encrypt
|
||||
&& v.virtual_ip != source_ip
|
||||
})
|
||||
.map(|v| (v.address, v.tcp_sender.clone()));
|
||||
if let Some((peer_addr, peer_tcp_sender)) = rs {
|
||||
send_one(&self.udp, peer_addr, peer_tcp_sender, &net_packet).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn broadcast<B: AsRef<[u8]>>(udp_socket: &UdpSocket, context: Context, net_packet: NetPacket<B>) {
|
||||
for client_info in context.network_info.read().clients.values() {
|
||||
send_one(udp_socket, client_info, &net_packet);
|
||||
async fn broadcast<B: AsRef<[u8]>>(
|
||||
context: &LinkVntContext,
|
||||
udp_socket: &UdpSocket,
|
||||
net_packet: NetPacket<B>,
|
||||
) {
|
||||
let is_encrypt = net_packet.is_encrypt();
|
||||
let source_ip = u32::from(net_packet.source());
|
||||
let x: Vec<_> = context
|
||||
.network_info
|
||||
.read()
|
||||
.clients
|
||||
.values()
|
||||
.filter(|v| {
|
||||
v.wireguard.is_none()
|
||||
&& v.online
|
||||
&& v.client_secret == is_encrypt
|
||||
&& v.virtual_ip != source_ip
|
||||
})
|
||||
.map(|v| (v.address, v.tcp_sender.clone()))
|
||||
.collect();
|
||||
for (peer_addr, peer_tcp_sender) in x {
|
||||
send_one(udp_socket, peer_addr, peer_tcp_sender, &net_packet).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn send_one<B: AsRef<[u8]>>(
|
||||
async fn send_one<B: AsRef<[u8]>>(
|
||||
udp_socket: &UdpSocket,
|
||||
client_info: &ClientInfo,
|
||||
peer_addr: SocketAddr,
|
||||
peer_tcp_sender: Option<Sender<Vec<u8>>>,
|
||||
net_packet: &NetPacket<B>,
|
||||
) {
|
||||
if client_info.online && client_info.client_secret == net_packet.is_encrypt() {
|
||||
if let Some(sender) = &client_info.tcp_sender {
|
||||
let _ = sender.try_send(net_packet.buffer().to_vec());
|
||||
} else {
|
||||
let _ = udp_socket.try_send_to(net_packet.buffer(), client_info.address);
|
||||
}
|
||||
if let Some(sender) = &peer_tcp_sender {
|
||||
let _ = sender.send(net_packet.buffer().to_vec()).await;
|
||||
} else {
|
||||
let _ = udp_socket.send_to(net_packet.buffer(), peer_addr).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use tokio::sync::mpsc::Sender;
|
||||
use crate::cipher::RsaCipher;
|
||||
use crate::core::service::client::ClientPacketHandler;
|
||||
use crate::core::service::server::ServerPacketHandler;
|
||||
use crate::core::store::cache::AppCache;
|
||||
use crate::core::store::cache::{AppCache, VntContext};
|
||||
use crate::error::*;
|
||||
use crate::protocol::NetPacket;
|
||||
use crate::ConfigInfo;
|
||||
@@ -41,13 +41,17 @@ impl PacketHandler {
|
||||
}
|
||||
|
||||
impl PacketHandler {
|
||||
pub async fn leave(&self, context: VntContext) {
|
||||
self.server.leave(context).await;
|
||||
}
|
||||
pub async fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
&self,
|
||||
context: &mut VntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
tcp_sender: &Option<Sender<Vec<u8>>>,
|
||||
) -> Option<NetPacket<Vec<u8>>> {
|
||||
self.handle0(net_packet, addr, tcp_sender)
|
||||
self.handle0(context, net_packet, addr, tcp_sender)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
log::error!("addr={},{:?}", addr, e);
|
||||
@@ -56,14 +60,17 @@ impl PacketHandler {
|
||||
}
|
||||
async fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
&self,
|
||||
context: &mut VntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
tcp_sender: &Option<Sender<Vec<u8>>>,
|
||||
) -> Result<Option<NetPacket<Vec<u8>>>> {
|
||||
if net_packet.is_gateway() {
|
||||
self.server.handle(net_packet, addr, tcp_sender).await
|
||||
self.server
|
||||
.handle(context, net_packet, addr, tcp_sender)
|
||||
.await
|
||||
} else {
|
||||
self.client.handle(net_packet, addr)?;
|
||||
self.client.handle(context, net_packet, addr).await?;
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
use anyhow::{anyhow, Context};
|
||||
use chrono::Local;
|
||||
use packet::icmp::{icmp, Kind};
|
||||
use packet::ip::ipv4;
|
||||
use packet::ip::ipv4::packet::IpV4Packet;
|
||||
use protobuf::Message;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{io, result};
|
||||
|
||||
use protobuf::Message;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::cipher::{Aes256GcmCipher, Finger, RsaCipher};
|
||||
use crate::core::entity::{ClientInfo, ClientStatusInfo, NetworkInfo};
|
||||
use crate::core::store::cache::{AppCache, Context};
|
||||
use crate::core::entity::{ClientInfo, ClientStatusInfo, NetworkInfo, SimpleClientInfo};
|
||||
use crate::core::store::cache::{AppCache, LinkVntContext, VntContext};
|
||||
use crate::error::*;
|
||||
use crate::proto::message;
|
||||
use crate::proto::message::{DeviceList, RegistrationRequest, RegistrationResponse};
|
||||
@@ -48,8 +48,12 @@ impl ServerPacketHandler {
|
||||
}
|
||||
|
||||
impl ServerPacketHandler {
|
||||
pub async fn leave(&self, context: VntContext) {
|
||||
context.leave(&self.cache).await;
|
||||
}
|
||||
pub async fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
&self,
|
||||
context: &mut VntContext,
|
||||
mut net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
tcp_sender: &Option<Sender<Vec<u8>>>,
|
||||
@@ -66,26 +70,24 @@ impl ServerPacketHandler {
|
||||
}
|
||||
service_packet::Protocol::SecretHandshakeRequest => {
|
||||
// 加密握手
|
||||
let rs = self.secret_handshake(net_packet, addr).await?;
|
||||
let rs = self.secret_handshake(context, net_packet, addr).await?;
|
||||
return Ok(Some(rs));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// 解密
|
||||
let aes = if net_packet.is_encrypt() {
|
||||
if let Some(aes) = self.cache.cipher_session.get(&addr) {
|
||||
let server_secret = net_packet.is_encrypt();
|
||||
if server_secret {
|
||||
if let Some(aes) = &context.server_cipher {
|
||||
aes.decrypt_ipv4(&mut net_packet)?;
|
||||
Some(aes)
|
||||
} else {
|
||||
log::info!("没有密钥:{},head={:?}", addr, net_packet.head());
|
||||
return Ok(Some(self.handle_err(addr, source, Error::NoKey)?));
|
||||
return Ok(Some(self.handle_err(addr, source, &Error::NoKey)?));
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
let mut packet = match self
|
||||
.handle0(net_packet, addr, tcp_sender, aes.is_some())
|
||||
.handle0(context, net_packet, addr, tcp_sender, server_secret)
|
||||
.await
|
||||
{
|
||||
Ok(rs) => {
|
||||
@@ -95,11 +97,13 @@ impl ServerPacketHandler {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
Err(e) => self.handle_err(addr, source, e)?,
|
||||
Err(e) => self.handle_anyhow_err(addr, source, e)?,
|
||||
};
|
||||
self.common_param(&mut packet, source);
|
||||
if let Some(aes) = aes {
|
||||
aes.encrypt_ipv4(&mut packet)?;
|
||||
if server_secret {
|
||||
if let Some(aes) = &context.server_cipher {
|
||||
aes.encrypt_ipv4(&mut packet)?;
|
||||
}
|
||||
}
|
||||
Ok(Some(packet))
|
||||
}
|
||||
@@ -115,20 +119,28 @@ impl ServerPacketHandler {
|
||||
net_packet.first_set_ttl(MAX_TTL);
|
||||
net_packet.set_gateway_flag(true);
|
||||
}
|
||||
fn handle_anyhow_err(
|
||||
&self,
|
||||
addr: SocketAddr,
|
||||
source: Ipv4Addr,
|
||||
e: anyhow::Error,
|
||||
) -> Result<NetPacket<Vec<u8>>> {
|
||||
if let Some(e) = e.downcast_ref() {
|
||||
self.handle_err(addr, source, e)
|
||||
} else {
|
||||
self.handle_err(addr, source, &Error::Other(format!("{}", e)))
|
||||
}
|
||||
}
|
||||
fn handle_err(
|
||||
&self,
|
||||
addr: SocketAddr,
|
||||
source: Ipv4Addr,
|
||||
e: Error,
|
||||
e: &Error,
|
||||
) -> Result<NetPacket<Vec<u8>>> {
|
||||
log::warn!("addr={},source={},{:?}", addr, source, e);
|
||||
let rs = vec![0u8; 12 + ENCRYPTION_RESERVED];
|
||||
let mut packet = NetPacket::new_encrypt(rs)?;
|
||||
match e {
|
||||
Error::Io(_) => {}
|
||||
Error::Channel(_) => {}
|
||||
Error::Protobuf(_) => {}
|
||||
|
||||
Error::AddressExhausted => {
|
||||
packet.set_transport_protocol(error_packet::Protocol::AddressExhausted.into());
|
||||
}
|
||||
@@ -161,6 +173,7 @@ impl ServerPacketHandler {
|
||||
}
|
||||
async fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
&self,
|
||||
context: &mut VntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
tcp_sender: &Option<Sender<Vec<u8>>>,
|
||||
@@ -168,7 +181,7 @@ impl ServerPacketHandler {
|
||||
) -> Result<Option<NetPacket<Vec<u8>>>> {
|
||||
// 处理不需要连接上下文的请求
|
||||
let mut net_packet = match self
|
||||
.not_context(net_packet, addr, tcp_sender, server_secret)
|
||||
.not_context(context, net_packet, addr, tcp_sender, server_secret)
|
||||
.await
|
||||
{
|
||||
Ok(rs) => {
|
||||
@@ -177,10 +190,10 @@ impl ServerPacketHandler {
|
||||
Err(net_packet) => net_packet,
|
||||
};
|
||||
// 需要连接的上下文
|
||||
let context = if let Some(context) = self.cache.get_context(&addr) {
|
||||
context
|
||||
let link_context = if let Some(link_context) = &context.link_context {
|
||||
link_context
|
||||
} else {
|
||||
return Err(Error::Disconnect);
|
||||
return Err(Error::Disconnect)?;
|
||||
};
|
||||
|
||||
match net_packet.protocol() {
|
||||
@@ -188,13 +201,13 @@ impl ServerPacketHandler {
|
||||
match protocol::service_packet::Protocol::from(net_packet.transport_protocol()) {
|
||||
service_packet::Protocol::PullDeviceList => {
|
||||
//拉取网段设备信息
|
||||
return self.poll_device_list(net_packet, addr, &context);
|
||||
return self.poll_device_list(net_packet, addr, &link_context);
|
||||
}
|
||||
service_packet::Protocol::ClientStatusInfo => {
|
||||
//客户端上报信息
|
||||
let client_status_info =
|
||||
message::ClientStatusInfo::parse_from_bytes(net_packet.payload())?;
|
||||
self.up_client_status_info(client_status_info, &context);
|
||||
self.up_client_status_info(client_status_info, &link_context);
|
||||
return Ok(None);
|
||||
}
|
||||
_ => {}
|
||||
@@ -205,17 +218,22 @@ impl ServerPacketHandler {
|
||||
if let control_packet::Protocol::Ping =
|
||||
protocol::control_packet::Protocol::from(net_packet.transport_protocol())
|
||||
{
|
||||
return self.control_ping(net_packet, &context);
|
||||
return self.control_ping(net_packet, &link_context);
|
||||
}
|
||||
}
|
||||
Protocol::IpTurn => {
|
||||
match protocol::ip_turn_packet::Protocol::from(net_packet.transport_protocol()) {
|
||||
protocol::ip_turn_packet::Protocol::WGIpv4 => {
|
||||
//wg数据转发
|
||||
self.wg_ipv4(&link_context, net_packet).await?;
|
||||
return Ok(None);
|
||||
}
|
||||
protocol::ip_turn_packet::Protocol::Ipv4Broadcast => {
|
||||
//处理选择性广播,进过网关还原成原始广播
|
||||
let broadcast_packet = BroadcastPacket::new(net_packet.payload())?;
|
||||
let exclude = broadcast_packet.addresses();
|
||||
let broadcast_net_packet = NetPacket::new(broadcast_packet.data()?)?;
|
||||
self.broadcast(&context, broadcast_net_packet, &exclude)?;
|
||||
self.broadcast(&link_context, broadcast_net_packet, &exclude)?;
|
||||
return Ok(None);
|
||||
}
|
||||
protocol::ip_turn_packet::Protocol::Ipv4 => {
|
||||
@@ -258,6 +276,7 @@ impl ServerPacketHandler {
|
||||
impl ServerPacketHandler {
|
||||
async fn not_context<B: AsRef<[u8]>>(
|
||||
&self,
|
||||
context: &mut VntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
tcp_sender: &Option<Sender<Vec<u8>>>,
|
||||
@@ -269,7 +288,7 @@ impl ServerPacketHandler {
|
||||
{
|
||||
//注册
|
||||
return Ok(self
|
||||
.register(net_packet, addr, tcp_sender, server_secret)
|
||||
.register(context, net_packet, addr, tcp_sender, server_secret)
|
||||
.await);
|
||||
}
|
||||
} else if net_packet.protocol() == Protocol::Control {
|
||||
@@ -287,7 +306,7 @@ impl ServerPacketHandler {
|
||||
fn control_ping<B: AsRef<[u8]>>(
|
||||
&self,
|
||||
net_packet: NetPacket<B>,
|
||||
context: &Context,
|
||||
context: &LinkVntContext,
|
||||
) -> Result<Option<NetPacket<Vec<u8>>>> {
|
||||
let vec = vec![0u8; 12 + 4 + ENCRYPTION_RESERVED];
|
||||
let mut packet = NetPacket::new_encrypt(vec)?;
|
||||
@@ -324,13 +343,13 @@ impl ServerPacketHandler {
|
||||
impl ServerPacketHandler {
|
||||
async fn register<B: AsRef<[u8]>>(
|
||||
&self,
|
||||
context: &mut VntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
tcp_sender: &Option<Sender<Vec<u8>>>,
|
||||
server_secret: bool,
|
||||
) -> Result<Option<NetPacket<Vec<u8>>>> {
|
||||
let config = &self.config;
|
||||
let cache = &self.cache;
|
||||
let request = RegistrationRequest::parse_from_bytes(net_packet.payload())?;
|
||||
check_reg(&request)?;
|
||||
log::info!(
|
||||
@@ -346,6 +365,8 @@ impl ServerPacketHandler {
|
||||
tcp_sender.is_some()
|
||||
);
|
||||
let group_id = request.token.clone();
|
||||
let gateway = config.gateway;
|
||||
let netmask = config.netmask;
|
||||
if let Some(white_token) = &config.white_token {
|
||||
if !white_token.contains(&group_id) {
|
||||
log::info!(
|
||||
@@ -353,7 +374,7 @@ impl ServerPacketHandler {
|
||||
white_token,
|
||||
group_id
|
||||
);
|
||||
return Err(Error::TokenError);
|
||||
Err(Error::TokenError)?
|
||||
}
|
||||
}
|
||||
let mut response = RegistrationResponse::new();
|
||||
@@ -371,119 +392,46 @@ impl ServerPacketHandler {
|
||||
}
|
||||
}
|
||||
}
|
||||
//固定网段
|
||||
let gateway: u32 = config.gateway.into();
|
||||
let netmask: u32 = config.netmask.into();
|
||||
let network: u32 = gateway & netmask;
|
||||
let register_client_request = RegisterClientRequest {
|
||||
group_id: group_id.clone(),
|
||||
virtual_ip: request.virtual_ip.into(),
|
||||
gateway,
|
||||
netmask,
|
||||
allow_ip_change: request.allow_ip_change,
|
||||
device_id: request.device_id,
|
||||
version: request.version,
|
||||
name: request.name,
|
||||
client_secret: request.client_secret,
|
||||
client_secret_hash: request.client_secret_hash,
|
||||
server_secret,
|
||||
address: addr,
|
||||
tcp_sender: tcp_sender.clone(),
|
||||
online: true,
|
||||
wireguard: None,
|
||||
};
|
||||
let register_response = generate_ip(&self.cache, register_client_request).await?;
|
||||
let virtual_ip = register_response.virtual_ip.into();
|
||||
response.virtual_gateway = gateway.into();
|
||||
response.virtual_netmask = netmask.into();
|
||||
response.virtual_ip = virtual_ip;
|
||||
response.epoch = register_response.epoch as u32;
|
||||
response.device_info_list = register_response
|
||||
.client_list
|
||||
.into_iter()
|
||||
.map(|v| v.into())
|
||||
.collect();
|
||||
context.link_context.replace(LinkVntContext {
|
||||
network_info: self
|
||||
.cache
|
||||
.virtual_network
|
||||
.get(&group_id)
|
||||
.context("virtual_network is none")?,
|
||||
group: group_id.clone(),
|
||||
virtual_ip,
|
||||
broadcast: config.broadcast,
|
||||
timestamp: register_response.timestamp,
|
||||
});
|
||||
|
||||
response.virtual_netmask = netmask;
|
||||
response.virtual_gateway = gateway;
|
||||
|
||||
let v = cache
|
||||
.virtual_network
|
||||
.optionally_get_with(group_id.clone(), || {
|
||||
(
|
||||
Duration::from_secs(7 * 24 * 3600),
|
||||
Arc::new(parking_lot::const_rwlock(NetworkInfo::new(
|
||||
network, netmask, gateway,
|
||||
))),
|
||||
)
|
||||
})
|
||||
.await;
|
||||
let mut virtual_ip = request.virtual_ip;
|
||||
// 可分配的ip段
|
||||
let ip_range = network + 1..gateway | (!netmask);
|
||||
let timestamp = Local::now().timestamp();
|
||||
{
|
||||
let mut lock = v.write();
|
||||
let mut insert = true;
|
||||
if virtual_ip != 0 {
|
||||
if u32::from(config.gateway) == virtual_ip
|
||||
|| u32::from(config.broadcast) == virtual_ip
|
||||
|| !ip_range.contains(&virtual_ip)
|
||||
{
|
||||
log::warn!("手动指定的ip无效: {:?}", request);
|
||||
return Err(Error::InvalidIp);
|
||||
}
|
||||
//指定了ip
|
||||
if let Some(info) = lock.clients.get_mut(&request.virtual_ip) {
|
||||
if info.device_id != request.device_id {
|
||||
//ip被占用了,并且不能更改ip
|
||||
if !request.allow_ip_change {
|
||||
log::warn!("手动指定的ip已经存在:{:?}", request);
|
||||
return Err(Error::IpAlreadyExists);
|
||||
}
|
||||
// 重新挑选ip
|
||||
virtual_ip = 0;
|
||||
} else {
|
||||
insert = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut old_ip = 0;
|
||||
if insert {
|
||||
// 找到上一次用的ip
|
||||
for (ip, x) in &lock.clients {
|
||||
if x.device_id == request.device_id {
|
||||
if virtual_ip == 0 {
|
||||
virtual_ip = *ip;
|
||||
} else {
|
||||
old_ip = *ip;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if virtual_ip == 0 {
|
||||
// 从小到大找一个未使用的ip
|
||||
for ip in ip_range {
|
||||
if ip == lock.gateway_ip {
|
||||
continue;
|
||||
}
|
||||
if !lock.clients.contains_key(&ip) {
|
||||
virtual_ip = ip;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if virtual_ip == 0 {
|
||||
log::error!("地址使用完:{:?}", request);
|
||||
return Err(Error::AddressExhausted);
|
||||
}
|
||||
let info = if old_ip == 0 {
|
||||
lock.clients
|
||||
.entry(virtual_ip)
|
||||
.or_insert_with(ClientInfo::default)
|
||||
} else {
|
||||
let client_info = lock.clients.remove(&old_ip).unwrap();
|
||||
lock.clients
|
||||
.entry(virtual_ip)
|
||||
.or_insert_with(|| client_info)
|
||||
};
|
||||
info.name = request.name;
|
||||
info.device_id = request.device_id;
|
||||
info.version = request.version;
|
||||
info.client_secret = request.client_secret;
|
||||
info.server_secret = server_secret;
|
||||
info.address = addr;
|
||||
info.online = true;
|
||||
info.virtual_ip = virtual_ip;
|
||||
info.tcp_sender = tcp_sender.clone();
|
||||
info.last_join_time = Local::now();
|
||||
info.timestamp = timestamp;
|
||||
lock.epoch += 1;
|
||||
response.virtual_ip = virtual_ip;
|
||||
response.epoch = lock.epoch as u32;
|
||||
response.device_info_list = Self::clients_info(&lock.clients, virtual_ip);
|
||||
drop(lock);
|
||||
}
|
||||
cache
|
||||
.insert_ip_session((group_id.clone(), virtual_ip), addr)
|
||||
.await;
|
||||
cache
|
||||
.insert_addr_session(addr, (group_id, virtual_ip, timestamp))
|
||||
.await;
|
||||
let bytes = response.write_to_bytes()?;
|
||||
let rs = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED];
|
||||
let mut packet = NetPacket::new_encrypt(rs)?;
|
||||
@@ -496,13 +444,16 @@ impl ServerPacketHandler {
|
||||
|
||||
fn check_reg(request: &RegistrationRequest) -> Result<()> {
|
||||
if request.token.is_empty() || request.token.len() > 128 {
|
||||
return Err(Error::Other("group length error".into()));
|
||||
Err(anyhow!("group length error"))?
|
||||
}
|
||||
if request.device_id.is_empty() || request.device_id.len() > 128 {
|
||||
return Err(Error::Other("device_id length error".into()));
|
||||
Err(anyhow!("device_id length error"))?
|
||||
}
|
||||
if request.name.is_empty() || request.name.len() > 128 {
|
||||
return Err(Error::Other("name length error".into()));
|
||||
Err(anyhow!("name length error"))?
|
||||
}
|
||||
if request.client_secret_hash.len() > 128 {
|
||||
Err(anyhow!("client_secret_hash length error"))?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -535,6 +486,7 @@ impl ServerPacketHandler {
|
||||
}
|
||||
async fn secret_handshake<B: AsRef<[u8]>>(
|
||||
&self,
|
||||
context: &mut VntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
addr: SocketAddr,
|
||||
) -> Result<NetPacket<Vec<u8>>> {
|
||||
@@ -545,10 +497,7 @@ impl ServerPacketHandler {
|
||||
let sync_secret =
|
||||
message::SecretHandshakeRequest::parse_from_bytes(rsa_secret_body.data())?;
|
||||
let c = Aes256GcmCipher::new(
|
||||
sync_secret
|
||||
.key
|
||||
.try_into()
|
||||
.map_err(|_| Error::Other("key err".into()))?,
|
||||
sync_secret.key.try_into().map_err(|_| anyhow!("key err"))?,
|
||||
Finger::new(&sync_secret.token),
|
||||
);
|
||||
let rs = vec![0u8; 12 + ENCRYPTION_RESERVED];
|
||||
@@ -557,10 +506,11 @@ impl ServerPacketHandler {
|
||||
packet.set_transport_protocol(service_packet::Protocol::SecretHandshakeResponse.into());
|
||||
self.common_param(&mut packet, source);
|
||||
c.encrypt_ipv4(&mut packet)?;
|
||||
context.server_cipher.replace(c.clone());
|
||||
self.cache.insert_cipher_session(addr, c).await;
|
||||
return Ok(packet);
|
||||
}
|
||||
Err(Error::Other("no encryption".into()))
|
||||
Err(anyhow!("no encryption"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -569,15 +519,15 @@ impl ServerPacketHandler {
|
||||
&self,
|
||||
_net_packet: NetPacket<B>,
|
||||
_addr: SocketAddr,
|
||||
context: &Context,
|
||||
context: &LinkVntContext,
|
||||
) -> Result<Option<NetPacket<Vec<u8>>>> {
|
||||
let guard = context.network_info.read();
|
||||
let ips = Self::clients_info(&guard.clients, context.virtual_ip);
|
||||
let ips = clients_info(&guard.clients, context.virtual_ip);
|
||||
let epoch = guard.epoch;
|
||||
drop(guard);
|
||||
let mut device_list = DeviceList::new();
|
||||
device_list.epoch = epoch as u32;
|
||||
device_list.device_info_list = ips;
|
||||
device_list.device_info_list = ips.into_iter().map(|v| v.into()).collect();
|
||||
let bytes = device_list.write_to_bytes()?;
|
||||
let vec = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED];
|
||||
let mut device_list_packet = NetPacket::new_encrypt(vec)?;
|
||||
@@ -589,7 +539,7 @@ impl ServerPacketHandler {
|
||||
fn up_client_status_info(
|
||||
&self,
|
||||
client_status_info: message::ClientStatusInfo,
|
||||
context: &Context,
|
||||
context: &LinkVntContext,
|
||||
) {
|
||||
let mut status_info = ClientStatusInfo::default();
|
||||
let iplist = &mut status_info.p2p_list;
|
||||
@@ -612,26 +562,40 @@ impl ServerPacketHandler {
|
||||
v.client_status = Some(status_info);
|
||||
}
|
||||
}
|
||||
fn clients_info(
|
||||
clients: &HashMap<u32, ClientInfo>,
|
||||
current_ip: u32,
|
||||
) -> Vec<message::DeviceInfo> {
|
||||
clients
|
||||
.iter()
|
||||
.filter(|&(_, dev)| dev.virtual_ip != current_ip)
|
||||
.map(|(_, device_info)| {
|
||||
let mut dev = message::DeviceInfo::new();
|
||||
dev.virtual_ip = device_info.virtual_ip;
|
||||
dev.name = device_info.name.clone();
|
||||
dev.device_status = if device_info.online { 0 } else { 1 };
|
||||
dev.client_secret = device_info.client_secret;
|
||||
dev
|
||||
})
|
||||
.collect()
|
||||
async fn wg_ipv4<B: AsRef<[u8]>>(
|
||||
&self,
|
||||
context: &LinkVntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
) -> anyhow::Result<()> {
|
||||
let source = net_packet.source();
|
||||
let dest = net_packet.destination();
|
||||
if dest.is_broadcast() || dest == context.broadcast {
|
||||
// 广播
|
||||
for peer in context.network_info.read().clients.values() {
|
||||
if !peer.online {
|
||||
continue;
|
||||
}
|
||||
if let Some(sender) = &peer.wg_sender {
|
||||
if let Err(e) = sender.try_send((net_packet.payload().to_vec(), source)) {
|
||||
log::info!("广播到对端wg失败 {}->{},{}", source, dest, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(peer) = context.network_info.read().clients.get(&dest.into()) {
|
||||
// 点对点
|
||||
if peer.online {
|
||||
if let Some(sender) = &peer.wg_sender {
|
||||
if let Err(e) = sender.try_send((net_packet.payload().to_vec(), source)) {
|
||||
log::info!("发送到对端wg失败 {}->{},{}", source, dest, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn broadcast<B: AsRef<[u8]>>(
|
||||
&self,
|
||||
context: &Context,
|
||||
context: &LinkVntContext,
|
||||
net_packet: NetPacket<B>,
|
||||
exclude: &[Ipv4Addr],
|
||||
) -> io::Result<()> {
|
||||
@@ -640,6 +604,7 @@ impl ServerPacketHandler {
|
||||
if client_info.online
|
||||
&& !exclude.contains(&(*ip).into())
|
||||
&& client_info.client_secret == client_secret
|
||||
&& client_info.wireguard.is_none()
|
||||
{
|
||||
if let Some(sender) = &client_info.tcp_sender {
|
||||
let _ = sender.try_send(net_packet.buffer().to_vec());
|
||||
@@ -653,3 +618,171 @@ impl ServerPacketHandler {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RegisterClientRequest {
|
||||
pub group_id: String,
|
||||
// ip 0表示自动分配
|
||||
pub virtual_ip: Ipv4Addr,
|
||||
pub gateway: Ipv4Addr,
|
||||
pub netmask: Ipv4Addr,
|
||||
|
||||
// 允许分配不一样的ip
|
||||
pub allow_ip_change: bool,
|
||||
// 设备ID
|
||||
pub device_id: String,
|
||||
// 版本
|
||||
pub version: String,
|
||||
// 名称
|
||||
pub name: String,
|
||||
// 客户端间是否加密
|
||||
pub client_secret: bool,
|
||||
// 加密hash
|
||||
pub client_secret_hash: Vec<u8>,
|
||||
// 和服务端是否加密
|
||||
pub server_secret: bool,
|
||||
// 链接服务器的来源地址
|
||||
pub address: SocketAddr,
|
||||
pub tcp_sender: Option<Sender<Vec<u8>>>,
|
||||
// 是否在线
|
||||
pub online: bool,
|
||||
// wireguard客户端公钥
|
||||
pub wireguard: Option<[u8; 32]>,
|
||||
}
|
||||
|
||||
pub struct RegisterClientResponse {
|
||||
timestamp: i64,
|
||||
pub virtual_ip: Ipv4Addr,
|
||||
// 纪元号
|
||||
pub epoch: u64,
|
||||
pub client_list: Vec<SimpleClientInfo>,
|
||||
}
|
||||
|
||||
pub async fn generate_ip(
|
||||
cache: &AppCache,
|
||||
register_request: RegisterClientRequest,
|
||||
) -> anyhow::Result<RegisterClientResponse> {
|
||||
let gateway: u32 = register_request.gateway.into();
|
||||
let netmask: u32 = register_request.netmask.into();
|
||||
let network: u32 = gateway & netmask;
|
||||
let mut virtual_ip: u32 = register_request.virtual_ip.into();
|
||||
let device_id = register_request.device_id;
|
||||
let allow_ip_change = register_request.allow_ip_change;
|
||||
let group_id = register_request.group_id;
|
||||
let v = cache
|
||||
.virtual_network
|
||||
.optionally_get_with(group_id, || {
|
||||
(
|
||||
Duration::from_secs(7 * 24 * 3600),
|
||||
Arc::new(parking_lot::const_rwlock(NetworkInfo::new(
|
||||
network, netmask, gateway,
|
||||
))),
|
||||
)
|
||||
})
|
||||
.await;
|
||||
// 可分配的ip段
|
||||
let ip_range = network + 1..gateway | (!netmask);
|
||||
let timestamp = Local::now().timestamp();
|
||||
let mut lock = v.write();
|
||||
let mut insert = true;
|
||||
if virtual_ip != 0 {
|
||||
if gateway == virtual_ip || !ip_range.contains(&virtual_ip) {
|
||||
Err(Error::InvalidIp)?
|
||||
}
|
||||
//指定了ip
|
||||
if let Some(info) = lock.clients.get_mut(&virtual_ip) {
|
||||
if info.device_id != device_id {
|
||||
//ip被占用了,并且不能更改ip
|
||||
if !allow_ip_change {
|
||||
Err(Error::IpAlreadyExists)?
|
||||
}
|
||||
// 重新挑选ip
|
||||
virtual_ip = 0;
|
||||
} else {
|
||||
insert = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut old_ip = 0;
|
||||
if insert {
|
||||
// 找到上一次用的ip
|
||||
for (ip, x) in &lock.clients {
|
||||
if x.device_id == device_id {
|
||||
if virtual_ip == 0 {
|
||||
virtual_ip = *ip;
|
||||
} else {
|
||||
old_ip = *ip;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if virtual_ip == 0 {
|
||||
// 从小到大找一个未使用的ip
|
||||
for ip in ip_range {
|
||||
if ip == lock.gateway_ip {
|
||||
continue;
|
||||
}
|
||||
if !lock.clients.contains_key(&ip) {
|
||||
virtual_ip = ip;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if virtual_ip == 0 {
|
||||
log::error!("地址使用完:{:?}", lock);
|
||||
Err(Error::AddressExhausted)?
|
||||
}
|
||||
let info = if old_ip == 0 {
|
||||
lock.clients
|
||||
.entry(virtual_ip)
|
||||
.or_insert_with(ClientInfo::default)
|
||||
} else {
|
||||
let client_info = lock.clients.remove(&old_ip).unwrap();
|
||||
lock.clients
|
||||
.entry(virtual_ip)
|
||||
.or_insert_with(|| client_info)
|
||||
};
|
||||
info.name = register_request.name;
|
||||
info.device_id = device_id;
|
||||
info.version = register_request.version;
|
||||
info.client_secret = register_request.client_secret;
|
||||
info.client_secret_hash = register_request.client_secret_hash;
|
||||
info.server_secret = register_request.server_secret;
|
||||
info.address = register_request.address;
|
||||
info.online = register_request.online;
|
||||
info.wireguard = register_request.wireguard;
|
||||
info.virtual_ip = virtual_ip;
|
||||
info.tcp_sender = register_request.tcp_sender;
|
||||
info.last_join_time = Local::now();
|
||||
info.timestamp = timestamp;
|
||||
lock.epoch += 1;
|
||||
let response = RegisterClientResponse {
|
||||
timestamp,
|
||||
virtual_ip: virtual_ip.into(),
|
||||
epoch: lock.epoch,
|
||||
client_list: clients_info(&lock.clients, virtual_ip),
|
||||
};
|
||||
Ok(response)
|
||||
}
|
||||
fn clients_info(clients: &HashMap<u32, ClientInfo>, current_ip: u32) -> Vec<SimpleClientInfo> {
|
||||
clients
|
||||
.iter()
|
||||
.filter(|&(_, dev)| dev.virtual_ip != current_ip)
|
||||
.map(|(_, device_info)| device_info.into())
|
||||
.collect()
|
||||
}
|
||||
impl From<SimpleClientInfo> for message::DeviceInfo {
|
||||
fn from(value: SimpleClientInfo) -> Self {
|
||||
let mut dev = message::DeviceInfo::new();
|
||||
dev.virtual_ip = value.virtual_ip;
|
||||
dev.name = value.name;
|
||||
dev.device_status = if value.online { 0 } else { 1 };
|
||||
dev.client_secret = value.client_secret;
|
||||
if value.online {
|
||||
dev.client_secret_hash = value.client_secret_hash;
|
||||
}
|
||||
dev.wireguard = value.wireguard;
|
||||
dev
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,130 +1,121 @@
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::cipher::Aes256GcmCipher;
|
||||
use crate::core::entity::NetworkInfo;
|
||||
use crate::core::entity::{NetworkInfo, WireGuardConfig};
|
||||
use crate::core::store::expire_map::ExpireMap;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppCache {
|
||||
// group -> NetworkInfo
|
||||
pub virtual_network: ExpireMap<String, Arc<RwLock<NetworkInfo>>>,
|
||||
// (group,ip) -> addr
|
||||
// (group,ip) -> addr 用于客户端过期,只有客户端离线才设置
|
||||
pub ip_session: ExpireMap<(String, u32), SocketAddr>,
|
||||
// addr -> (group,ip)
|
||||
pub addr_session: ExpireMap<SocketAddr, (String, u32, i64)>,
|
||||
pub cipher_session: ExpireMap<SocketAddr, Arc<Aes256GcmCipher>>,
|
||||
// 加密密钥
|
||||
pub cipher_session: Arc<DashMap<SocketAddr, Arc<Aes256GcmCipher>>>,
|
||||
// web登录状态
|
||||
pub auth_map: ExpireMap<String, ()>,
|
||||
// wg公钥 -> wg配置
|
||||
pub wg_group_map: Arc<DashMap<[u8; 32], WireGuardConfig>>,
|
||||
}
|
||||
|
||||
pub struct Context {
|
||||
pub struct VntContext {
|
||||
pub link_context: Option<LinkVntContext>,
|
||||
pub server_cipher: Option<Aes256GcmCipher>,
|
||||
pub link_address: SocketAddr,
|
||||
}
|
||||
pub struct LinkVntContext {
|
||||
pub network_info: Arc<RwLock<NetworkInfo>>,
|
||||
pub group: String,
|
||||
pub virtual_ip: u32,
|
||||
pub broadcast: Ipv4Addr,
|
||||
pub timestamp: i64,
|
||||
}
|
||||
impl VntContext {
|
||||
pub async fn leave(self, cache: &AppCache) {
|
||||
if self.server_cipher.is_some() {
|
||||
cache.cipher_session.remove(&self.link_address);
|
||||
}
|
||||
if let Some(context) = self.link_context {
|
||||
if let Some(network_info) = cache.virtual_network.get(&context.group) {
|
||||
{
|
||||
let mut guard = network_info.write();
|
||||
if let Some(client_info) = guard.clients.get_mut(&context.virtual_ip) {
|
||||
if client_info.address != self.link_address
|
||||
&& client_info.timestamp != context.timestamp
|
||||
{
|
||||
return;
|
||||
}
|
||||
client_info.online = false;
|
||||
client_info.tcp_sender = None;
|
||||
guard.epoch += 1;
|
||||
}
|
||||
drop(guard);
|
||||
}
|
||||
cache
|
||||
.insert_ip_session((context.group, context.virtual_ip), self.link_address)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AppCache {
|
||||
pub fn new() -> Self {
|
||||
let wg_group_map: Arc<DashMap<[u8; 32], WireGuardConfig>> = Default::default();
|
||||
// 网段7天未使用则回收
|
||||
let virtual_network: ExpireMap<String, Arc<RwLock<NetworkInfo>>> =
|
||||
ExpireMap::new(|_k, _v| {});
|
||||
let virtual_network_ = virtual_network.clone();
|
||||
// ip一天未使用则回收
|
||||
let ip_session: ExpireMap<(String, u32), SocketAddr> =
|
||||
ExpireMap::new(move |(group_id, ip), addr: SocketAddr| {
|
||||
log::info!(
|
||||
"ip_session eviction group_id={},ip={},addr={}",
|
||||
group_id,
|
||||
Ipv4Addr::from(ip),
|
||||
addr
|
||||
);
|
||||
if let Some(v) = virtual_network_.get(&group_id) {
|
||||
let mut lock = v.write();
|
||||
if let Some(dev) = lock.clients.get(&ip) {
|
||||
if dev.address == addr {
|
||||
lock.clients.remove(&ip);
|
||||
lock.epoch += 1;
|
||||
}
|
||||
}
|
||||
ExpireMap::new(|_k, v: &Arc<RwLock<NetworkInfo>>| {
|
||||
let lock = v.read();
|
||||
if !lock.clients.is_empty() {
|
||||
// 存在客户端的不过期
|
||||
return Some(Duration::from_secs(7 * 24 * 3600));
|
||||
}
|
||||
None
|
||||
});
|
||||
let virtual_network_ = virtual_network.clone();
|
||||
// 20秒钟没有收到消息则判定为掉线
|
||||
let addr_session = ExpireMap::new(
|
||||
move |addr: SocketAddr, (group, virtual_ip, timestamp)| {
|
||||
log::info!(
|
||||
"addr_session eviction group={},virtual_ip={},addr={},timestamp={}",
|
||||
group,
|
||||
Ipv4Addr::from(virtual_ip),
|
||||
addr,
|
||||
timestamp
|
||||
);
|
||||
|
||||
if let Some(v) = virtual_network_.get(&group) {
|
||||
let mut lock = v.write();
|
||||
if let Some(item) = lock.clients.get_mut(&virtual_ip) {
|
||||
if item.address != addr || item.timestamp != timestamp {
|
||||
log::info!(
|
||||
"无效信息 addr_session eviction group={},virtual_ip={},addr={},timestamp={}",
|
||||
group,
|
||||
Ipv4Addr::from(virtual_ip),
|
||||
addr,
|
||||
timestamp
|
||||
);
|
||||
return;
|
||||
}
|
||||
item.online = false;
|
||||
// ip一天未使用则回收
|
||||
let ip_session: ExpireMap<(String, u32), SocketAddr> = ExpireMap::new(move |key, addr| {
|
||||
let (group_id, ip) = &key;
|
||||
log::info!(
|
||||
"ip_session eviction group_id={},ip={},addr={}",
|
||||
group_id,
|
||||
Ipv4Addr::from(*ip),
|
||||
addr
|
||||
);
|
||||
if let Some(v) = virtual_network_.get(group_id) {
|
||||
let mut lock = v.write();
|
||||
if let Some(dev) = lock.clients.get(ip) {
|
||||
if !dev.online && &dev.address == addr {
|
||||
lock.clients.remove(ip);
|
||||
lock.epoch += 1;
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
let cipher_session = ExpireMap::new(|_k, _v| {});
|
||||
let auth_map = ExpireMap::new(|_k, _v| {});
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
let auth_map = ExpireMap::new(|_k, _v| None);
|
||||
Self {
|
||||
virtual_network,
|
||||
ip_session,
|
||||
addr_session,
|
||||
cipher_session,
|
||||
cipher_session: Default::default(),
|
||||
auth_map,
|
||||
wg_group_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AppCache {
|
||||
pub fn get_context(&self, addr: &SocketAddr) -> Option<Context> {
|
||||
if let Some((group, virtual_ip, _)) = self.addr_session.get(addr) {
|
||||
let k = (group, virtual_ip);
|
||||
self.ip_session.get(&k)?;
|
||||
let (group, virtual_ip) = k;
|
||||
return self
|
||||
.virtual_network
|
||||
.get(&group)
|
||||
.map(|network_info| Context {
|
||||
network_info,
|
||||
group,
|
||||
virtual_ip,
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub async fn insert_cipher_session(&self, key: SocketAddr, value: Aes256GcmCipher) {
|
||||
self.cipher_session
|
||||
.insert(key, Arc::new(value), Duration::from_secs(120))
|
||||
.await
|
||||
self.cipher_session.insert(key, Arc::new(value));
|
||||
}
|
||||
pub async fn insert_ip_session(&self, key: (String, u32), value: SocketAddr) {
|
||||
self.ip_session
|
||||
.insert(key, value, Duration::from_secs(24 * 3600))
|
||||
.await
|
||||
}
|
||||
pub async fn insert_addr_session(&self, key: SocketAddr, value: (String, u32, i64)) {
|
||||
self.addr_session
|
||||
.insert(key, value, Duration::from_secs(20))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ struct Value<V> {
|
||||
impl<K, V> ExpireMap<K, V> {
|
||||
pub fn new<F>(call: F) -> ExpireMap<K, V>
|
||||
where
|
||||
F: Fn(K, V) + Send + 'static,
|
||||
F: Fn(&K, &V) -> Option<Duration> + Send + 'static,
|
||||
K: Eq + Hash + Clone + Sync + Send + 'static,
|
||||
V: Clone + Sync + Send + 'static,
|
||||
{
|
||||
@@ -66,6 +66,14 @@ where
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
/// remove出去的不会执行过期回调
|
||||
pub fn remove(&self, k: &K) -> Option<V> {
|
||||
if let Some(v) = self.base.write().remove(k) {
|
||||
Some(v.val)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
pub fn get(&self, k: &K) -> Option<V> {
|
||||
if let Some(v) = self.base.read().get(k) {
|
||||
// 延长过期时间
|
||||
@@ -78,7 +86,10 @@ where
|
||||
pub fn get_val(&self, k: &K) -> Option<V> {
|
||||
self.base.read().get(k).map(|v| v.val.clone())
|
||||
}
|
||||
fn expire_call(&self, k: &K) -> Op<K, V> {
|
||||
fn expire_call<F>(&self, k: &K, f: &F) -> Op<K, V>
|
||||
where
|
||||
F: Fn(&K, &V) -> Option<Duration>,
|
||||
{
|
||||
let mut write_guard = self.base.write();
|
||||
if let Some(v) = write_guard.get(k) {
|
||||
let now = Instant::now();
|
||||
@@ -87,10 +98,14 @@ where
|
||||
// 过期时间更新了
|
||||
return Op::Reset(instant);
|
||||
} else {
|
||||
//删除key
|
||||
if let Some((k, v)) = write_guard.remove_entry(k) {
|
||||
//执行回调
|
||||
return Op::Remove(k, v.val);
|
||||
//执行过期回调
|
||||
if let Some(v) = f(k, &v.val) {
|
||||
return Op::Reset(now.add(v));
|
||||
} else {
|
||||
//删除key
|
||||
if let Some((k, v)) = write_guard.remove_entry(k) {
|
||||
return Op::Remove(k, v.val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -142,7 +157,7 @@ async fn expire_task<K, V, F>(mut receiver: Receiver<DelayedTask<K>>, map: Expir
|
||||
where
|
||||
K: Eq + Hash + Clone,
|
||||
V: Clone,
|
||||
F: Fn(K, V),
|
||||
F: Fn(&K, &V) -> Option<Duration>,
|
||||
{
|
||||
let mut binary_heap = BinaryHeap::<DelayedTask<K>>::with_capacity(32);
|
||||
loop {
|
||||
@@ -165,17 +180,13 @@ where
|
||||
}
|
||||
} else if let Some(mut task) = binary_heap.pop() {
|
||||
//执行过期逻辑
|
||||
match map.expire_call(&task.k) {
|
||||
match map.expire_call(&task.k, &f) {
|
||||
Op::Reset(time) => {
|
||||
//没有过期,重新加入监听
|
||||
|
||||
task.time = time;
|
||||
binary_heap.push(task);
|
||||
}
|
||||
Op::Remove(k, v) => {
|
||||
//执行回调
|
||||
f(k, v)
|
||||
}
|
||||
Op::Remove(_, _) => {}
|
||||
Op::None => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,9 @@
|
||||
#![allow(dead_code, clippy::enum_variant_names)]
|
||||
|
||||
use std::io;
|
||||
|
||||
use crossbeam::channel::RecvError;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Io error")]
|
||||
Io(#[from] io::Error),
|
||||
#[error("Channel error")]
|
||||
Channel(#[from] RecvError),
|
||||
#[error("Protobuf error")]
|
||||
Protobuf(#[from] protobuf::Error),
|
||||
#[error("Disconnect")]
|
||||
Disconnect,
|
||||
#[error("No Key")]
|
||||
@@ -29,4 +20,4 @@ pub enum Error {
|
||||
Other(String),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
pub type Result<T> = anyhow::Result<T>;
|
||||
|
||||
56
src/main.rs
56
src/main.rs
@@ -1,12 +1,16 @@
|
||||
use aes_gcm::aead::rand_core::RngCore;
|
||||
use anyhow::{anyhow, Context};
|
||||
use base64::engine::general_purpose;
|
||||
use base64::Engine;
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use clap::Parser;
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::Display;
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use crate::cipher::RsaCipher;
|
||||
|
||||
mod cipher;
|
||||
@@ -15,6 +19,7 @@ mod error;
|
||||
mod generated_serial_number;
|
||||
mod proto;
|
||||
mod protocol;
|
||||
|
||||
pub const VNT_VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// 默认网关信息
|
||||
@@ -56,9 +61,12 @@ pub struct StartArgs {
|
||||
/// web后台用户密码,默认为admin
|
||||
#[arg(short = 'W', long)]
|
||||
password: Option<String>,
|
||||
/// wg私钥,使用base64编码
|
||||
#[arg(long = "wg")]
|
||||
wg_secret_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone)]
|
||||
pub struct ConfigInfo {
|
||||
pub port: u16,
|
||||
pub white_token: Option<HashSet<String>>,
|
||||
@@ -70,6 +78,28 @@ pub struct ConfigInfo {
|
||||
pub username: String,
|
||||
#[cfg(feature = "web")]
|
||||
pub password: String,
|
||||
pub wg_secret_key: StaticSecret,
|
||||
pub wg_public_key: PublicKey,
|
||||
}
|
||||
impl Debug for ConfigInfo {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConfigInfo")
|
||||
.field("port", &self.port)
|
||||
.field("white_token", &self.white_token)
|
||||
.field("gateway", &self.gateway)
|
||||
.field("broadcast", &self.broadcast)
|
||||
.field("netmask", &self.netmask)
|
||||
.field("check_finger", &self.check_finger)
|
||||
.field(
|
||||
"wg_secret_key",
|
||||
&general_purpose::STANDARD.encode(&self.wg_secret_key),
|
||||
)
|
||||
.field(
|
||||
"wg_public_key",
|
||||
&general_purpose::STANDARD.encode(&self.wg_public_key),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn log_init(root_path: PathBuf, log_path: Option<String>) {
|
||||
@@ -231,6 +261,22 @@ async fn main() {
|
||||
if check_finger {
|
||||
println!("转发校验数据指纹,客户端必须增加--finger参数");
|
||||
}
|
||||
let wg_secret_key: [u8; 32] = if let Some(wg_secret_key) = args.wg_secret_key {
|
||||
let wg_secret_key = general_purpose::STANDARD
|
||||
.decode(wg_secret_key)
|
||||
.context("wg私钥错误")
|
||||
.unwrap();
|
||||
wg_secret_key
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("wg私钥错误"))
|
||||
.unwrap()
|
||||
} else {
|
||||
let mut wg_secret_key = [0u8; 32];
|
||||
rand::thread_rng().fill_bytes(&mut wg_secret_key);
|
||||
wg_secret_key
|
||||
};
|
||||
let wg_secret_key = boringtun::x25519::StaticSecret::from(wg_secret_key);
|
||||
let wg_public_key = boringtun::x25519::PublicKey::from(&wg_secret_key);
|
||||
let config = ConfigInfo {
|
||||
port,
|
||||
white_token,
|
||||
@@ -242,6 +288,8 @@ async fn main() {
|
||||
username: args.username.unwrap_or_else(|| "admin".into()),
|
||||
#[cfg(feature = "web")]
|
||||
password: args.password.unwrap_or_else(|| "admin".into()),
|
||||
wg_secret_key,
|
||||
wg_public_key,
|
||||
};
|
||||
let rsa = match RsaCipher::new(root_path) {
|
||||
Ok(rsa) => {
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::net::Ipv4Addr;
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
|
||||
pub enum Protocol {
|
||||
Ipv4,
|
||||
WGIpv4,
|
||||
Ipv4Broadcast,
|
||||
Unknown(u8),
|
||||
}
|
||||
@@ -14,6 +15,7 @@ impl From<u8> for Protocol {
|
||||
fn from(value: u8) -> Self {
|
||||
match value {
|
||||
4 => Protocol::Ipv4,
|
||||
5 => Protocol::WGIpv4,
|
||||
201 => Protocol::Ipv4Broadcast,
|
||||
val => Protocol::Unknown(val),
|
||||
}
|
||||
@@ -24,6 +26,7 @@ impl From<Protocol> for u8 {
|
||||
fn from(val: Protocol) -> Self {
|
||||
match val {
|
||||
Protocol::Ipv4 => 4,
|
||||
Protocol::WGIpv4 => 5,
|
||||
Protocol::Ipv4Broadcast => 201,
|
||||
Protocol::Unknown(val) => val,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
.option-cell {
|
||||
display: flex;
|
||||
gap: 10px; /* 间隔 */
|
||||
}
|
||||
|
||||
.option-cell button {
|
||||
padding: 5px 10px;
|
||||
font-size: 14px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.3s ease;
|
||||
}
|
||||
|
||||
.option-cell button.delete-button {
|
||||
background-color: #f44336; /* 红色 */
|
||||
color: white;
|
||||
}
|
||||
|
||||
.option-cell button.delete-button:hover {
|
||||
background-color: #d32f2f; /* 深红色 */
|
||||
}
|
||||
|
||||
.option-cell button.view-button {
|
||||
background-color: #4CAF50; /* 绿色 */
|
||||
color: white;
|
||||
}
|
||||
|
||||
.option-cell button.view-button:hover {
|
||||
background-color: #388E3C; /* 深绿色 */
|
||||
}
|
||||
|
||||
/* wg弹窗样式 */
|
||||
.modal {
|
||||
display: none; /* 默认隐藏 */
|
||||
position: fixed;
|
||||
z-index: 1;
|
||||
left: 0;
|
||||
top: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
overflow: auto;
|
||||
background-color: rgb(0,0,0);
|
||||
background-color: rgba(0,0,0,0.4);
|
||||
padding-top: 60px;
|
||||
}
|
||||
|
||||
.modal-content {
|
||||
position: relative;
|
||||
background-color: #fefefe;
|
||||
width: 380px;
|
||||
height: 380px;
|
||||
margin: 5% auto;
|
||||
padding: 50px;
|
||||
border: 1px solid #888;
|
||||
text-align: center;
|
||||
box-sizing: border-box;
|
||||
border-radius: 5px;
|
||||
}
|
||||
.add-modal-content{
|
||||
position: relative;
|
||||
background-color: #fefefe;
|
||||
margin: 5% auto;
|
||||
padding: 20px;
|
||||
border: 1px solid #888;
|
||||
width: 80%;
|
||||
max-width: 600px;
|
||||
box-sizing: border-box;
|
||||
border-radius: 5px;
|
||||
}
|
||||
.form-group {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.form-group label {
|
||||
flex: 1;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.form-group input {
|
||||
flex: 2;
|
||||
padding: 10px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.modal button {
|
||||
padding: 10px 20px;
|
||||
margin: 10px 5px;
|
||||
}
|
||||
|
||||
.button-container {
|
||||
text-align: center;
|
||||
}
|
||||
.button-container button {
|
||||
padding: 10px 20px;
|
||||
margin: 10px 5px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
transition: background-color 0.3s, box-shadow 0.3s;
|
||||
}
|
||||
|
||||
.error {
|
||||
color: red;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
#confirmButton {
|
||||
background-color: #4CAF50; /* 绿色背景 */
|
||||
color: white;
|
||||
}
|
||||
|
||||
#confirmButton:hover {
|
||||
background-color: #45a049; /* 鼠标悬停时变暗 */
|
||||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */
|
||||
}
|
||||
|
||||
#cancelButton {
|
||||
background-color: #f44336; /* 红色背景 */
|
||||
color: white;
|
||||
}
|
||||
|
||||
#cancelButton:hover {
|
||||
background-color: #e53935; /* 鼠标悬停时变暗 */
|
||||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */
|
||||
}
|
||||
.modal-content .title{
|
||||
position: absolute;
|
||||
left: 20px;
|
||||
top: 10px;
|
||||
}
|
||||
|
||||
.close {
|
||||
position: absolute;
|
||||
right: 20px;
|
||||
top: 10px;
|
||||
color: #aaa;
|
||||
font-size: 28px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.close:hover,
|
||||
.close:focus {
|
||||
color: black;
|
||||
text-decoration: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.visible {
|
||||
display: block;
|
||||
}
|
||||
#qrcode {
|
||||
margin: 0 auto; /* 居中 */
|
||||
width: 260px; /* 固定宽度 */
|
||||
height: 260px; /* 固定高度 */
|
||||
}
|
||||
pre {
|
||||
text-align: left; /* 左对齐 */
|
||||
white-space: pre-wrap; /* 自动换行 */
|
||||
word-break: break-all;
|
||||
user-select: auto;
|
||||
}
|
||||
#toggleButton{
|
||||
position: absolute;
|
||||
bottom: 10px;
|
||||
left: 116px;
|
||||
}
|
||||
@@ -123,7 +123,27 @@ body {
|
||||
padding-top: 20px;
|
||||
margin-left: 300px;
|
||||
}
|
||||
button{
|
||||
padding: 10px 20px;
|
||||
margin: 10px 5px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
transition: background-color 0.3s, box-shadow 0.3s;
|
||||
}
|
||||
#addWireGuard {
|
||||
background-color: #4CAF50; /* 绿色背景 */
|
||||
color: white;
|
||||
position: absolute;
|
||||
right: 50px;
|
||||
top:20px;
|
||||
}
|
||||
|
||||
#addWireGuard:hover {
|
||||
background-color: #45a049; /* 鼠标悬停时变暗 */
|
||||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */
|
||||
}
|
||||
/* 下拉菜单按钮 */
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
<meta charset="UTF-8">
|
||||
<script src="./js/jquery-3.7.1.min.js"></script>
|
||||
<script src="./js/g6.min.js"></script>
|
||||
<script src="./js/qrcode.min.js"></script>
|
||||
<script src="./js/api-post.js"></script>
|
||||
|
||||
<link rel="stylesheet" href="./css/select.css">
|
||||
<link rel="stylesheet" href="./css/index.css">
|
||||
<title>vnts-web</title>
|
||||
<style>
|
||||
|
||||
@@ -91,7 +93,7 @@
|
||||
<div class="topBox">
|
||||
<div class="select-content">
|
||||
<!-- <label>组网标识:</label>-->
|
||||
<input type="hidden" name="newMachineId">
|
||||
<input type="hidden" name="groupId">
|
||||
<input type="text" name="select_input" id="select_input" class="select-input" value="" autocomplete="off"
|
||||
placeholder="Search..."/>
|
||||
<div id="search_select" class="search-select">
|
||||
@@ -101,6 +103,7 @@
|
||||
</div>
|
||||
<span class="group_len"></span>
|
||||
</div>
|
||||
<button id="addWireGuard">接入WireGuard客户端</button>
|
||||
</div>
|
||||
<div id="container_content">
|
||||
<div id="group_info">
|
||||
@@ -125,6 +128,7 @@
|
||||
<th>连接时间</th>
|
||||
<th>链接地址</th>
|
||||
<th>设备 ID</th>
|
||||
<th>操作</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
@@ -134,7 +138,51 @@
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- wg信息弹窗 -->
|
||||
<div id="wgConfigModal" class="modal">
|
||||
<div class="modal-content">
|
||||
<span class="title">使用WireGuard客户端接入</span>
|
||||
<span class="close">×</span>
|
||||
<div id="qrcode" class="visible"></div>
|
||||
<pre id="textConfig" class="hidden"></pre>
|
||||
<button id="toggleButton">显示文本配置</button>
|
||||
</div>
|
||||
</div>
|
||||
<!-- 添加wg弹窗 -->
|
||||
<div id="addModal" class="modal">
|
||||
<div class="add-modal-content">
|
||||
<span class="close">×</span>
|
||||
<h2>WireGuard配置</h2>
|
||||
<div class="form-group">
|
||||
<label for="groupId">组网编号</label>
|
||||
<input type="text" id="groupId" placeholder="组网编号">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="virtualIP">虚拟IP</label>
|
||||
<input type="text" id="virtualIP" placeholder="为空则自动分配">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="deviceName">设备名称</label>
|
||||
<input type="text" id="deviceName" placeholder="设备名称">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="privateKey">PrivateKey</label>
|
||||
<input type="text" id="privateKey" placeholder="PrivateKey">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="endpoint">Endpoint</label>
|
||||
<input type="text" id="endpoint" placeholder="服务器UDP链接地址">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="persistentKeepalive">PersistentKeepalive</label>
|
||||
<input type="number" id="persistentKeepalive" value="10" placeholder="PersistentKeepalive">
|
||||
</div>
|
||||
<div class="button-container">
|
||||
<button id="confirmButton">确认</button>
|
||||
</div>
|
||||
<div class="error" id="addWGError"></div>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
<script src="js/group-node.js"></script>
|
||||
|
||||
@@ -158,7 +206,7 @@
|
||||
}
|
||||
var options = '';
|
||||
for (var i = 0; i < listArr.length; i++) {
|
||||
opt = '<li class="li-select" data-newMachineId="' + listArr[i] + '">' + listArr[i] + '</li>';
|
||||
opt = '<li class="li-select" data-groupId="' + listArr[i] + '">' + listArr[i] + '</li>';
|
||||
options += opt;
|
||||
}
|
||||
if (options == '') {
|
||||
@@ -196,12 +244,11 @@
|
||||
$('#select_ul').delegate('.li-select', 'click', function () {
|
||||
$('#select_ul .li-select').removeClass('li-hover');
|
||||
var selectText = $(this).html();
|
||||
var newMachineIdVal = $($(this)[0]).attr("data-newMachineId");
|
||||
var groupIdVal = $($(this)[0]).attr("data-groupId");
|
||||
$('#select_input').val(selectText);
|
||||
$('#search_select').hide();
|
||||
$("input[name='newMachineId']").val(newMachineIdVal);
|
||||
console.log(newMachineIdVal);
|
||||
getGroupInfoFunc(newMachineIdVal)
|
||||
$("input[name='groupId']").val(groupIdVal);
|
||||
getGroupInfoFunc(groupIdVal)
|
||||
|
||||
});
|
||||
$('#select_ul').delegate('.li-select', 'mouseover', function () {
|
||||
@@ -215,8 +262,10 @@
|
||||
$('.group_len').html('(' + tempArr.length + ')')
|
||||
}
|
||||
|
||||
function displayDeviceInfo(devices) {
|
||||
function displayDeviceInfo(groupId, data) {
|
||||
let devices = data.clients;
|
||||
const tableBody = document.getElementById('deviceTable').getElementsByTagName('tbody')[0];
|
||||
tableBody.innerHTML = "";
|
||||
devices.forEach(device => {
|
||||
const row = document.createElement('tr');
|
||||
|
||||
@@ -247,7 +296,7 @@
|
||||
const lastJoinTimeCell = document.createElement('td');
|
||||
lastJoinTimeCell.textContent = device.last_join_time;
|
||||
row.appendChild(lastJoinTimeCell);
|
||||
|
||||
|
||||
const addressCell = document.createElement('td');
|
||||
addressCell.textContent = device.address;
|
||||
row.appendChild(addressCell);
|
||||
@@ -255,12 +304,132 @@
|
||||
const deviceIdCell = document.createElement('td');
|
||||
deviceIdCell.textContent = device.device_id;
|
||||
row.appendChild(deviceIdCell);
|
||||
// 操作栏
|
||||
const optionCell = document.createElement('td');
|
||||
optionCell.className = 'option-cell';
|
||||
const deleteButton = document.createElement('button');
|
||||
optionCell.appendChild(deleteButton);
|
||||
deleteButton.className = 'delete-button';
|
||||
deleteButton.textContent = '删除';
|
||||
deleteButton.onclick = function () {
|
||||
const confirmed = window.confirm('你确定要删除这条记录吗?');
|
||||
if (confirmed) {
|
||||
postRemoveClient({'group_id': groupId, 'virtual_ip': device.virtual_ip}, function () {
|
||||
location.reload();
|
||||
});
|
||||
row.remove();
|
||||
}
|
||||
|
||||
};
|
||||
let wg_config = device.wg_config;
|
||||
if (wg_config) {
|
||||
function generateWireguardConfig(privateKey, ip, prefix, publicKey, allowedIPs, endpoint, persistentKeepalive) {
|
||||
return `[Interface]
|
||||
PrivateKey = ${privateKey}
|
||||
Address = ${ip}/${prefix}
|
||||
[Peer]
|
||||
PublicKey = ${publicKey}
|
||||
AllowedIPs = ${allowedIPs}
|
||||
Endpoint = ${endpoint}
|
||||
PersistentKeepalive = ${persistentKeepalive}`;
|
||||
}
|
||||
|
||||
const wireguardConfig = generateWireguardConfig(wg_config.secret_key, wg_config.ip, wg_config.prefix,
|
||||
data.vnts_public_key, wg_config.vnts_allowed_ips, wg_config.vnts_endpoint, wg_config.persistent_keepalive);
|
||||
console.log(wireguardConfig);
|
||||
const viewButton = document.createElement('button');
|
||||
viewButton.textContent = '接入';
|
||||
viewButton.className = 'view-button';
|
||||
viewButton.onclick = function () {
|
||||
// 显示弹窗
|
||||
$('#wgConfigModal').show();
|
||||
// 设置文本配置
|
||||
$('#textConfig').text(wireguardConfig);
|
||||
$('#qrcode').html('');
|
||||
// 生成二维码
|
||||
new QRCode(document.getElementById("qrcode"), {
|
||||
correctLevel: 3,
|
||||
text: wireguardConfig,
|
||||
width: 256,
|
||||
height: 256
|
||||
});
|
||||
|
||||
};
|
||||
optionCell.appendChild(viewButton);
|
||||
}
|
||||
row.appendChild(optionCell);
|
||||
tableBody.appendChild(row);
|
||||
});
|
||||
}
|
||||
</script>
|
||||
<script>
|
||||
let groupId;
|
||||
let deviceId = 'wg' + new Date().getTime();
|
||||
$('#addWireGuard').on('click', function () {
|
||||
$('#groupId').val(groupId);
|
||||
$('#deviceName').val('WireGuard');
|
||||
postWgPrivateKey({}, function (data) {
|
||||
$('#privateKey').val(data.data);
|
||||
$('.error').text('');
|
||||
$('#addModal').show();
|
||||
})
|
||||
});
|
||||
$('#confirmButton').on('click', function () {
|
||||
// 清除以前的错误信息
|
||||
$('.error').text('');
|
||||
// 获取输入值
|
||||
const groupId = $('#groupId').val().trim();
|
||||
const virtualIP = $('#virtualIP').val().trim();
|
||||
const deviceName = $('#deviceName').val().trim();
|
||||
const privateKey = $('#privateKey').val().trim();
|
||||
const endpoint = $('#endpoint').val().trim();
|
||||
const persistentKeepalive = $('#persistentKeepalive').val().trim();
|
||||
// 校验输入值
|
||||
if (!groupId) {
|
||||
$('#addWGError').text('组网编号不能为空');
|
||||
return
|
||||
}
|
||||
|
||||
if (!deviceName) {
|
||||
$('#addWGError').text('设备名称不能为空');
|
||||
return;
|
||||
}
|
||||
|
||||
if (!privateKey) {
|
||||
$('#addWGError').text('PrivateKey不能为空');
|
||||
return;
|
||||
}
|
||||
|
||||
if (!endpoint) {
|
||||
$('#addWGError').text('Endpoint不能为空');
|
||||
return;
|
||||
}
|
||||
|
||||
if (persistentKeepalive === '') {
|
||||
$('#addWGError').text('PersistentKeepalive不能为空');
|
||||
return;
|
||||
} else if (isNaN(persistentKeepalive) || persistentKeepalive < 0 || persistentKeepalive > 65535) {
|
||||
$('#addWGError').text('PersistentKeepalive必须是0~65535');
|
||||
return;
|
||||
}
|
||||
postCreateWG({
|
||||
'group_id': groupId, 'virtual_ip': virtualIP, 'device_id': deviceId, 'name': deviceName,
|
||||
'config': {
|
||||
'vnts_endpoint': endpoint,
|
||||
'private_key': privateKey,
|
||||
'persistent_keepalive': parseInt(persistentKeepalive),
|
||||
}
|
||||
}, function (data) {
|
||||
if (data && data.code === 200) {
|
||||
location.reload();
|
||||
$('#addModal').hide(); // 关闭模态框
|
||||
} else {
|
||||
$('#addWGError').text('添加失败:' + data.message);
|
||||
}
|
||||
|
||||
})
|
||||
});
|
||||
|
||||
function formatBytes(bytes) {
|
||||
if (!bytes) {
|
||||
return ''
|
||||
@@ -361,6 +530,7 @@
|
||||
graph.render();
|
||||
}
|
||||
let getGroupInfoFunc = function (group) {
|
||||
groupId = group;
|
||||
postGroupInfo({'group': group}, function (response) {
|
||||
if (response && response.code === 200) {
|
||||
let data = response.data;
|
||||
@@ -368,7 +538,7 @@
|
||||
$('.mask_ip').html(data.mask_ip);
|
||||
$('.network_ip').html(data.network_ip);
|
||||
nodeInitFunc(data.gateway_ip, data.clients);
|
||||
displayDeviceInfo(data.clients);
|
||||
displayDeviceInfo(group, data);
|
||||
} else {
|
||||
window.alert("调用服务失败")
|
||||
}
|
||||
@@ -380,6 +550,20 @@
|
||||
console.log(group_list);
|
||||
searchInput(group_list);
|
||||
});
|
||||
|
||||
$('.close').on('click', function () {
|
||||
$('#wgConfigModal').hide();
|
||||
$('#addModal').hide();
|
||||
});
|
||||
$('#toggleButton').on('click', function () {
|
||||
if ($('#qrcode').hasClass('visible')) {
|
||||
$('#qrcode').removeClass('visible').addClass('hidden');
|
||||
$('#textConfig').removeClass('hidden').addClass('visible');
|
||||
$('#toggleButton').text('显示二维码');
|
||||
} else {
|
||||
$('#qrcode').removeClass('hidden').addClass('visible');
|
||||
$('#textConfig').removeClass('visible').addClass('hidden');
|
||||
$('#toggleButton').text('显示文本配置');
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</html>
|
||||
@@ -64,13 +64,25 @@ function post(url, data, success, error,) {
|
||||
}
|
||||
|
||||
function postLogin(requestData, success, error) {
|
||||
post("login", requestData, success, error)
|
||||
post("api/login", requestData, success, error)
|
||||
}
|
||||
|
||||
function postGroupList(requestData, success, error) {
|
||||
post("group_list", requestData, success, error)
|
||||
post("api/group_list", requestData, success, error)
|
||||
}
|
||||
|
||||
function postGroupInfo(requestData, success, error) {
|
||||
post("group_info", requestData, success, error)
|
||||
post("api/group_info", requestData, success, error)
|
||||
}
|
||||
|
||||
function postWgPrivateKey(requestData, success, error) {
|
||||
post("api/private_key", requestData, success, error)
|
||||
}
|
||||
|
||||
function postCreateWG(requestData, success, error) {
|
||||
post("api/create_wg_config", requestData, success, error)
|
||||
}
|
||||
|
||||
function postRemoveClient(requestData, success, error) {
|
||||
post("api/remove_client", requestData, success, error)
|
||||
}
|
||||
1
static/js/qrcode.min.js
vendored
Normal file
1
static/js/qrcode.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user