This commit is contained in:
lbl8603
2024-07-20 10:34:55 +08:00
parent 57b3a61596
commit 92347c30c5
59 changed files with 10078 additions and 527 deletions

254
Cargo.lock generated
View File

@@ -53,7 +53,7 @@ dependencies = [
"actix-service",
"actix-utils",
"ahash",
"base64",
"base64 0.21.7",
"bitflags 2.5.0",
"brotli",
"bytes",
@@ -368,12 +368,24 @@ dependencies = [
"rustc-demangle",
]
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.6.0"
@@ -392,6 +404,15 @@ version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@@ -401,6 +422,28 @@ dependencies = [
"generic-array",
]
[[package]]
name = "boringtun"
version = "0.6.0"
dependencies = [
"aead",
"base64 0.13.1",
"blake2",
"chacha20poly1305",
"hex",
"hmac",
"ip_network",
"ip_network_table",
"libc",
"nix",
"parking_lot",
"rand_core",
"ring",
"tracing",
"untrusted",
"x25519-dalek",
]
[[package]]
name = "brotli"
version = "3.5.0"
@@ -466,6 +509,30 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chacha20"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "chacha20poly1305"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35"
dependencies = [
"aead",
"chacha20",
"cipher",
"poly1305",
"zeroize",
]
[[package]]
name = "change-detection"
version = "1.2.0"
@@ -498,6 +565,7 @@ checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
"zeroize",
]
[[package]]
@@ -671,12 +739,39 @@ dependencies = [
]
[[package]]
name = "dashmap"
version = "5.5.3"
name = "curve25519-dalek"
version = "4.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be"
dependencies = [
"cfg-if",
"cpufeatures",
"curve25519-dalek-derive",
"fiat-crypto",
"rustc_version",
"subtle",
"zeroize",
]
[[package]]
name = "curve25519-dalek-derive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "dashmap"
version = "6.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.3",
"lock_api",
"once_cell",
@@ -691,9 +786,9 @@ checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
[[package]]
name = "der"
version = "0.6.1"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1a467a65c5e759bce6e65eaf91cc29f466cdc57cb65777bd646872a8a1fd4de"
checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0"
dependencies = [
"const-oid",
"pem-rfc7468",
@@ -748,6 +843,7 @@ dependencies = [
"block-buffer",
"const-oid",
"crypto-common",
"subtle",
]
[[package]]
@@ -808,6 +904,12 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]]
name = "fiat-crypto"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "flate2"
version = "1.0.29"
@@ -973,6 +1075,21 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.9"
@@ -1090,6 +1207,37 @@ dependencies = [
"generic-array",
]
[[package]]
name = "ip_network"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa2f047c0a98b2f299aa5d6d7088443570faae494e9ae1305e48be000c9e0eb1"
[[package]]
name = "ip_network_table"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4099b7cfc5c5e2fe8c5edf3f6f7adf7a714c9cc697534f63a5a5da30397cb2c0"
dependencies = [
"ip_network",
"ip_network_table-deps-treebitmap",
]
[[package]]
name = "ip_network_table-deps-treebitmap"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e537132deb99c0eb4b752f0346b6a836200eaaa3516dd7e5514b63930a09e5d"
[[package]]
name = "ipnetwork"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e"
dependencies = [
"serde",
]
[[package]]
name = "is-terminal"
version = "0.4.12"
@@ -1300,6 +1448,18 @@ dependencies = [
"uuid",
]
[[package]]
name = "nix"
version = "0.25.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f346ff70e7dbfd675fe90590b92d59ef2de15a8779ae305ebcbfd3f0caf59be4"
dependencies = [
"autocfg",
"bitflags 1.3.2",
"cfg-if",
"libc",
]
[[package]]
name = "num-bigint-dig"
version = "0.8.4"
@@ -1458,9 +1618,9 @@ checksum = "498a099351efa4becc6a19c72aa9270598e8fd274ca47052e37455241c88b696"
[[package]]
name = "pem-rfc7468"
version = "0.6.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d159833a9105500e0398934e205e0773f0b27529557134ecfc51c27646adac"
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
dependencies = [
"base64ct",
]
@@ -1485,21 +1645,20 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkcs1"
version = "0.4.1"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eff33bdbdfc54cc98a2eca766ebdec3e1b8fb7387523d5c9c9a2891da856f719"
checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f"
dependencies = [
"der",
"pkcs8",
"spki",
"zeroize",
]
[[package]]
name = "pkcs8"
version = "0.9.0"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba"
checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7"
dependencies = [
"der",
"spki",
@@ -1511,6 +1670,17 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
[[package]]
name = "poly1305"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf"
dependencies = [
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "polyval"
version = "0.6.2"
@@ -1774,21 +1944,20 @@ dependencies = [
[[package]]
name = "rsa"
version = "0.7.2"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "094052d5470cbcef561cb848a7209968c9f12dfa6d668f4bca048ac5de51099c"
checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc"
dependencies = [
"byteorder",
"const-oid",
"digest",
"num-bigint-dig",
"num-integer",
"num-iter",
"num-traits",
"pkcs1",
"pkcs8",
"rand_core",
"signature",
"smallvec",
"spki",
"subtle",
"zeroize",
]
@@ -1938,9 +2107,9 @@ dependencies = [
[[package]]
name = "signature"
version = "1.6.4"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c"
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
dependencies = [
"digest",
"rand_core",
@@ -1985,9 +2154,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "spki"
version = "0.6.0"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b"
checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d"
dependencies = [
"base64ct",
"der",
@@ -2210,9 +2379,21 @@ checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "tracing-core"
version = "0.1.32"
@@ -2370,6 +2551,8 @@ dependencies = [
"aes-gcm",
"anyhow",
"async-trait",
"base64 0.22.1",
"boringtun",
"chrono",
"clap",
"colored",
@@ -2378,6 +2561,7 @@ dependencies = [
"dashmap",
"dirs",
"futures-util",
"ipnetwork",
"lazy_static",
"log",
"log4rs",
@@ -2653,6 +2837,18 @@ version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "x25519-dalek"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277"
dependencies = [
"curve25519-dalek",
"rand_core",
"serde",
"zeroize",
]
[[package]]
name = "zerocopy"
version = "0.7.32"
@@ -2678,6 +2874,20 @@ name = "zeroize"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
dependencies = [
"zeroize_derive",
]
[[package]]
name = "zeroize_derive"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "zstd"

View File

@@ -13,10 +13,10 @@ log4rs = "1.3"
dirs = "5"
crossbeam = "0.8"
parking_lot = "0.12"
dashmap = "5.5"
dashmap = "6.0.1"
rsa = { version = "0.7.2", features = [] }
spki = { version = "0.6.0", features = ["fingerprint", "alloc"] }
rsa = { version = "0.9.6", features = [] }
spki = { version = "0.7.3", features = ["fingerprint", "alloc", "base64"] }
aes-gcm = { version = "0.10.2", optional = true }
ring = { version = "0.17", optional = true }
rand = "0.8"
@@ -39,6 +39,10 @@ actix-files = { version = "0.6", optional = true }
actix-web-static-files = { version = "4.0.1", optional = true }
tokio-tungstenite = "0.23.1"
boringtun = { path = "lib/boringtun", features = [] }
ipnetwork = "0.20.0"
base64 = "0.22.1"
serde = { version = "1", features = ["derive"] }
crossbeam-utils = "0.8"
futures-util = "0.3"

64
lib/boringtun/Cargo.toml Normal file
View File

@@ -0,0 +1,64 @@
[package]
name = "boringtun"
description = "an implementation of the WireGuard® protocol designed for portability and speed"
version = "0.6.0"
authors = [
"Noah Kennedy <nkennedy@cloudflare.com>",
"Andy Grover <agrover@cloudflare.com>",
"Jeff Hiner <jhiner@cloudflare.com>",
]
license = "BSD-3-Clause"
repository = "https://github.com/cloudflare/boringtun"
documentation = "https://docs.rs/boringtun/0.5.2/boringtun/"
edition = "2018"
[features]
default = []
device = ["socket2", "thiserror"]
jni-bindings = ["ffi-bindings", "jni"]
ffi-bindings = ["tracing-subscriber"]
# mocks std::time::Instant with mock_instant
mock-instant = ["mock_instant"]
[dependencies]
base64 = "0.13"
hex = "0.4"
untrusted = "0.9.0"
libc = "0.2"
parking_lot = "0.12"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3", features = ["fmt"], optional = true }
ip_network = "0.4.1"
ip_network_table = "0.2.0"
ring = "0.17"
x25519-dalek = { version = "2.0.0", features = [
"reusable_secrets",
"static_secrets",
] }
rand_core = { version = "0.6.4", features = ["getrandom"] }
chacha20poly1305 = "0.10.0-pre.1"
aead = "0.5.0-pre.2"
blake2 = "0.10"
hmac = "0.12"
jni = { version = "0.19.0", optional = true }
mock_instant = { version = "0.3", optional = true }
socket2 = { version = "0.4.7", features = ["all"], optional = true }
thiserror = { version = "1", optional = true }
[target.'cfg(unix)'.dependencies]
nix = { version = "0.25", default-features = false, features = [
"time",
"user",
] }
[dev-dependencies]
etherparse = "0.13"
tracing-subscriber = "0.3"
criterion = { version = "0.3.5", features = ["html_reports"] }
[lib]
crate-type = ["staticlib", "cdylib", "rlib"]
[[bench]]
name = "crypto_benches"
harness = false

View File

@@ -0,0 +1,90 @@
use blake2::digest::{FixedOutput, KeyInit};
use blake2::{Blake2s256, Blake2sMac, Digest};
use criterion::{BenchmarkId, Criterion, Throughput};
use ring::rand::{SecureRandom, SystemRandom};
pub fn bench_blake2s_hash(c: &mut Criterion) {
let mut group = c.benchmark_group("blake2s_hash");
group.sample_size(1000);
for size in [32, 64, 128] {
group.throughput(Throughput::Bytes(size as u64));
group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| {
let buf_in = vec![0u8; size];
b.iter(|| {
let mut hasher = Blake2s256::new();
hasher.update(&buf_in);
hasher.finalize();
});
});
}
group.finish();
}
pub fn bench_blake2s_hmac(c: &mut Criterion) {
let mut group = c.benchmark_group("blake2s_hmac");
group.sample_size(1000);
for size in [16, 32] {
group.throughput(Throughput::Bytes(size as u64));
group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| {
let buf_in = vec![0u8; size];
let rng = SystemRandom::new();
b.iter_batched(
|| {
let mut key = [0u8; 32];
rng.fill(&mut key).unwrap();
key
},
|key| {
use blake2::digest::Update;
type HmacBlake2s = hmac::SimpleHmac<blake2::Blake2s256>;
let mut hmac = HmacBlake2s::new_from_slice(&key).unwrap();
hmac.update(&buf_in);
hmac.finalize_fixed();
},
criterion::BatchSize::SmallInput,
);
});
}
group.finish();
}
pub fn bench_blake2s_keyed(c: &mut Criterion) {
let mut group = c.benchmark_group("blake2s_keyed_mac");
group.sample_size(1000);
for size in [128, 1024] {
group.throughput(Throughput::Bytes(size as u64));
group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| {
let buf_in = vec![0u8; size];
let rng = SystemRandom::new();
b.iter_batched(
|| {
let mut key = [0u8; 16];
rng.fill(&mut key).unwrap();
key
},
|key| -> [u8; 16] {
let mut hmac = Blake2sMac::new_from_slice(&key).unwrap();
blake2::digest::Update::update(&mut hmac, &buf_in);
hmac.finalize_fixed().into()
},
criterion::BatchSize::SmallInput,
);
});
}
group.finish();
}

View File

@@ -0,0 +1,79 @@
use aead::{AeadInPlace, KeyInit};
use criterion::{BenchmarkId, Criterion, Throughput};
use rand_core::{OsRng, RngCore};
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
fn chacha20poly1305_ring(key_bytes: &[u8], buf: &mut [u8]) {
let len = buf.len();
let n = len - 16;
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key_bytes).unwrap());
let tag = key
.seal_in_place_separate_tag(
Nonce::assume_unique_for_key([0u8; 12]),
Aad::from(&[]),
&mut buf[..n],
)
.unwrap();
buf[n..].copy_from_slice(tag.as_ref())
}
fn chacha20poly1305_non_ring(key_bytes: &[u8], buf: &mut [u8]) {
let len = buf.len();
let n = len - 16;
let aead = chacha20poly1305::ChaCha20Poly1305::new_from_slice(key_bytes).unwrap();
let nonce = chacha20poly1305::Nonce::default();
let tag = aead
.encrypt_in_place_detached(&nonce, &[], &mut buf[..n])
.unwrap();
buf[n..].copy_from_slice(tag.as_ref());
}
pub fn bench_chacha20poly1305(c: &mut Criterion) {
let mut group = c.benchmark_group("chacha20poly1305");
group.sample_size(1000);
for size in [128, 192, 1400, 8192] {
group.throughput(Throughput::Bytes(size as u64));
group.bench_with_input(
BenchmarkId::new("chacha20poly1305_ring", size),
&size,
|b, i| {
let mut key = [0; 32];
let mut buf = vec![0; i + 16];
let mut rng = OsRng::default();
rng.fill_bytes(&mut key);
rng.fill_bytes(&mut buf);
b.iter(|| chacha20poly1305_ring(&key, &mut buf));
},
);
group.bench_with_input(
BenchmarkId::new("chacha20poly1305_non_ring", size),
&size,
|b, i| {
let mut key = [0; 32];
let mut buf = vec![0; i + 16];
let mut rng = OsRng::default();
rng.fill_bytes(&mut key);
rng.fill_bytes(&mut buf);
b.iter(|| chacha20poly1305_non_ring(&key, &mut buf));
},
);
}
group.finish();
}

View File

@@ -0,0 +1,20 @@
use blake2s_benching::{bench_blake2s_hash, bench_blake2s_hmac, bench_blake2s_keyed};
use chacha20poly1305_benching::bench_chacha20poly1305;
use x25519_public_key_benching::bench_x25519_public_key;
use x25519_shared_key_benching::bench_x25519_shared_key;
mod blake2s_benching;
mod chacha20poly1305_benching;
mod x25519_public_key_benching;
mod x25519_shared_key_benching;
criterion::criterion_group!(
crypto_benches,
bench_chacha20poly1305,
bench_blake2s_hash,
bench_blake2s_hmac,
bench_blake2s_keyed,
bench_x25519_shared_key,
bench_x25519_public_key
);
criterion::criterion_main!(crypto_benches);

View File

@@ -0,0 +1,30 @@
use criterion::Criterion;
use rand_core::OsRng;
pub fn bench_x25519_public_key(c: &mut Criterion) {
let mut group = c.benchmark_group("x25519_public_key");
group.sample_size(1000);
group.bench_function("x25519_public_key_dalek", |b| {
b.iter(|| {
let secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let public_key = x25519_dalek::PublicKey::from(&secret_key);
(secret_key, public_key)
});
});
group.bench_function("x25519_public_key_ring", |b| {
let rng = ring::rand::SystemRandom::new();
b.iter(|| {
let my_private_key =
ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng)
.unwrap();
my_private_key.compute_public_key().unwrap()
});
});
group.finish();
}

View File

@@ -0,0 +1,48 @@
use criterion::{BatchSize, Criterion};
use rand_core::OsRng;
pub fn bench_x25519_shared_key(c: &mut Criterion) {
let mut group = c.benchmark_group("x25519_shared_key");
group.sample_size(1000);
group.bench_function("x25519_shared_key_dalek", |b| {
let public_key =
x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random_from_rng(OsRng));
b.iter_batched(
|| x25519_dalek::StaticSecret::random_from_rng(OsRng),
|secret_key| secret_key.diffie_hellman(&public_key),
BatchSize::SmallInput,
);
});
group.bench_function("x25519_shared_key_ring", |b| {
let rng = ring::rand::SystemRandom::new();
let peer_public_key = {
let peer_private_key =
ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng)
.unwrap();
peer_private_key.compute_public_key().unwrap()
};
let peer_public_key_alg = &ring::agreement::X25519;
let my_public_key =
ring::agreement::UnparsedPublicKey::new(peer_public_key_alg, &peer_public_key);
b.iter_batched(
|| {
ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng)
.unwrap()
},
|my_private_key| {
ring::agreement::agree_ephemeral(my_private_key, &my_public_key, |_key_material| ())
.unwrap()
},
BatchSize::SmallInput,
);
});
group.finish();
}

View File

@@ -0,0 +1,389 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use crate::device::peer::AllowedIP;
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use std::collections::VecDeque;
use std::iter::FromIterator;
use std::net::IpAddr;
/// A trie of IP/cidr addresses
#[derive(Default)]
pub struct AllowedIps<D> {
ips: IpNetworkTable<D>,
}
impl<'a, D> FromIterator<(&'a AllowedIP, D)> for AllowedIps<D> {
fn from_iter<I: IntoIterator<Item = (&'a AllowedIP, D)>>(iter: I) -> Self {
let mut allowed_ips = AllowedIps::new();
for (ip, data) in iter {
allowed_ips.insert(ip.addr, ip.cidr as u32, data);
}
allowed_ips
}
}
impl<D> AllowedIps<D> {
pub fn new() -> Self {
Self {
ips: IpNetworkTable::new(),
}
}
pub fn clear(&mut self) {
self.ips = IpNetworkTable::new();
}
pub fn insert(&mut self, key: IpAddr, cidr: u32, data: D) -> Option<D> {
// These are networks, it doesn't make sense for host bits to be set, so
// use new_truncate().
self.ips.insert(
IpNetwork::new_truncate(key, cidr as u8).expect("cidr is valid length"),
data,
)
}
pub fn find(&self, key: IpAddr) -> Option<&D> {
self.ips.longest_match(key).map(|(_net, data)| data)
}
pub fn remove(&mut self, predicate: &dyn Fn(&D) -> bool) {
self.ips.retain(|_, v| !predicate(v));
}
pub fn iter(&self) -> Iter<D> {
Iter(
self.ips
.iter()
.map(|(ipa, d)| (d, ipa.network_address(), ipa.netmask()))
.collect(),
)
}
}
pub struct Iter<'a, D: 'a>(VecDeque<(&'a D, IpAddr, u8)>);
impl<'a, D> Iterator for Iter<'a, D> {
type Item = (&'a D, IpAddr, u8);
fn next(&mut self) -> Option<Self::Item> {
self.0.pop_front()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn build_allowed_ips() -> AllowedIps<char> {
let mut map: AllowedIps<char> = Default::default();
map.insert(IpAddr::from([127, 0, 0, 1]), 32, '1');
map.insert(IpAddr::from([45, 25, 15, 1]), 30, '6');
map.insert(IpAddr::from([127, 0, 15, 1]), 16, '2');
map.insert(IpAddr::from([127, 1, 15, 1]), 24, '3');
map.insert(IpAddr::from([255, 1, 15, 1]), 24, '4');
map.insert(IpAddr::from([60, 25, 15, 1]), 32, '5');
map.insert(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0]), 128, '7');
map
}
#[test]
fn test_allowed_ips_insert_find() {
let map = build_allowed_ips();
assert_eq!(map.find(IpAddr::from([127, 0, 0, 1])), Some(&'1'));
assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2'));
assert_eq!(map.find(IpAddr::from([127, 1, 255, 255])), None);
assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2'));
assert_eq!(map.find(IpAddr::from([127, 1, 15, 255])), Some(&'3'));
assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2'));
assert_eq!(map.find(IpAddr::from([127, 1, 15, 255])), Some(&'3'));
assert_eq!(map.find(IpAddr::from([255, 1, 15, 2])), Some(&'4'));
assert_eq!(map.find(IpAddr::from([60, 25, 15, 1])), Some(&'5'));
assert_eq!(map.find(IpAddr::from([20, 0, 0, 100])), None);
assert_eq!(
map.find(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0])),
Some(&'7')
);
assert_eq!(map.find(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 1])), None);
assert_eq!(map.find(IpAddr::from([45, 25, 15, 1])), Some(&'6'));
}
#[test]
fn test_allowed_ips_remove() {
let mut map = build_allowed_ips();
map.remove(&|c| *c == '5' || *c == '1' || *c == '7');
let mut map_iter = map.iter();
assert_eq!(
map_iter.next(),
Some((&'6', IpAddr::from([45, 25, 15, 0]), 30))
);
assert_eq!(
map_iter.next(),
Some((&'2', IpAddr::from([127, 0, 0, 0]), 16))
);
assert_eq!(
map_iter.next(),
Some((&'3', IpAddr::from([127, 1, 15, 0]), 24))
);
assert_eq!(
map_iter.next(),
Some((&'4', IpAddr::from([255, 1, 15, 0]), 24))
);
assert_eq!(map_iter.next(), None);
}
#[test]
fn test_allowed_ips_iter() {
let map = build_allowed_ips();
let mut map_iter = map.iter();
assert_eq!(
map_iter.next(),
Some((&'6', IpAddr::from([45, 25, 15, 0]), 30))
);
assert_eq!(
map_iter.next(),
Some((&'5', IpAddr::from([60, 25, 15, 1]), 32))
);
assert_eq!(
map_iter.next(),
Some((&'2', IpAddr::from([127, 0, 0, 0]), 16))
);
assert_eq!(
map_iter.next(),
Some((&'1', IpAddr::from([127, 0, 0, 1]), 32))
);
assert_eq!(
map_iter.next(),
Some((&'3', IpAddr::from([127, 1, 15, 0]), 24))
);
assert_eq!(
map_iter.next(),
Some((&'4', IpAddr::from([255, 1, 15, 0]), 24))
);
assert_eq!(
map_iter.next(),
Some((&'7', IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0]), 128))
);
assert_eq!(map_iter.next(), None);
}
#[test]
fn test_allowed_ips_v4_kernel_compatibility() {
// Test case from wireguard-go
let mut map: AllowedIps<char> = Default::default();
map.insert(IpAddr::from([192, 168, 4, 0]), 24, 'a');
map.insert(IpAddr::from([192, 168, 4, 4]), 32, 'b');
map.insert(IpAddr::from([192, 168, 0, 0]), 16, 'c');
map.insert(IpAddr::from([192, 95, 5, 64]), 27, 'd');
map.insert(IpAddr::from([192, 95, 5, 65]), 27, 'c');
map.insert(IpAddr::from([0, 0, 0, 0]), 0, 'e');
map.insert(IpAddr::from([64, 15, 112, 0]), 20, 'g');
map.insert(IpAddr::from([64, 15, 123, 211]), 25, 'h');
map.insert(IpAddr::from([10, 0, 0, 0]), 25, 'a');
map.insert(IpAddr::from([10, 0, 0, 128]), 25, 'b');
map.insert(IpAddr::from([10, 1, 0, 0]), 30, 'a');
map.insert(IpAddr::from([10, 1, 0, 4]), 30, 'b');
map.insert(IpAddr::from([10, 1, 0, 8]), 29, 'c');
map.insert(IpAddr::from([10, 1, 0, 16]), 29, 'd');
assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 168, 4, 20])));
assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 168, 4, 0])));
assert_eq!(Some(&'b'), map.find(IpAddr::from([192, 168, 4, 4])));
assert_eq!(Some(&'c'), map.find(IpAddr::from([192, 168, 200, 182])));
assert_eq!(Some(&'c'), map.find(IpAddr::from([192, 95, 5, 68])));
assert_eq!(Some(&'e'), map.find(IpAddr::from([192, 95, 5, 96])));
assert_eq!(Some(&'g'), map.find(IpAddr::from([64, 15, 116, 26])));
assert_eq!(Some(&'g'), map.find(IpAddr::from([64, 15, 127, 3])));
map.insert(IpAddr::from([1, 0, 0, 0]), 32, 'a');
map.insert(IpAddr::from([64, 0, 0, 0]), 32, 'a');
map.insert(IpAddr::from([128, 0, 0, 0]), 32, 'a');
map.insert(IpAddr::from([192, 0, 0, 0]), 32, 'a');
map.insert(IpAddr::from([255, 0, 0, 0]), 32, 'a');
assert_eq!(Some(&'a'), map.find(IpAddr::from([1, 0, 0, 0])));
assert_eq!(Some(&'a'), map.find(IpAddr::from([64, 0, 0, 0])));
assert_eq!(Some(&'a'), map.find(IpAddr::from([128, 0, 0, 0])));
assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 0, 0, 0])));
assert_eq!(Some(&'a'), map.find(IpAddr::from([255, 0, 0, 0])));
map.remove(&|c| *c == 'a');
assert_ne!(Some(&'a'), map.find(IpAddr::from([1, 0, 0, 0])));
assert_ne!(Some(&'a'), map.find(IpAddr::from([64, 0, 0, 0])));
assert_ne!(Some(&'a'), map.find(IpAddr::from([128, 0, 0, 0])));
assert_ne!(Some(&'a'), map.find(IpAddr::from([192, 0, 0, 0])));
assert_ne!(Some(&'a'), map.find(IpAddr::from([255, 0, 0, 0])));
map.clear();
map.insert(IpAddr::from([192, 168, 0, 0]), 16, 'a');
map.insert(IpAddr::from([192, 168, 0, 0]), 24, 'a');
map.remove(&|c| *c == 'a');
assert_ne!(Some(&'a'), map.find(IpAddr::from([192, 168, 0, 1])));
}
#[test]
fn test_allowed_ips_v6_kernel_compatibility() {
// Test case from wireguard-go
let mut map: AllowedIps<char> = Default::default();
map.insert(
IpAddr::from([
0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0xc05f, 0x0543,
]),
128,
'd',
);
map.insert(
IpAddr::from([
0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0x0000, 0x0000,
]),
64,
'c',
);
map.insert(
IpAddr::from([
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
]),
0,
'e',
);
map.insert(
IpAddr::from([
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
]),
0,
'f',
);
map.insert(
IpAddr::from([
0x2404, 0x6800, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
]),
32,
'g',
);
map.insert(
IpAddr::from([
0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef,
]),
64,
'h',
);
map.insert(
IpAddr::from([
0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef,
]),
128,
'a',
);
map.insert(
IpAddr::from([
0x2444, 0x6800, 0x40e4, 0x0800, 0xdeae, 0xbeef, 0x0def, 0xbeef,
]),
128,
'c',
);
map.insert(
IpAddr::from([
0x2444, 0x6800, 0xf0e4, 0x0800, 0xeeae, 0xbeef, 0x0000, 0x0000,
]),
98,
'b',
);
assert_eq!(
Some(&'d'),
map.find(IpAddr::from([
0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0xc05f, 0x0543
]))
);
assert_eq!(
Some(&'c'),
map.find(IpAddr::from([
0x2607, 0x5300, 0x6000, 0x6b00, 0, 0, 0xc02e, 0x01ee
]))
);
assert_eq!(
Some(&'f'),
map.find(IpAddr::from([0x2607, 0x5300, 0x6000, 0x6b01, 0, 0, 0, 0]))
);
assert_eq!(
Some(&'g'),
map.find(IpAddr::from([
0x2404, 0x6800, 0x4004, 0x0806, 0, 0, 0, 0x1006
]))
);
assert_eq!(
Some(&'g'),
map.find(IpAddr::from([
0x2404, 0x6800, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678
]))
);
assert_eq!(
Some(&'f'),
map.find(IpAddr::from([
0x2404, 0x67ff, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678
]))
);
assert_eq!(
Some(&'f'),
map.find(IpAddr::from([
0x2404, 0x6801, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678
]))
);
assert_eq!(
Some(&'h'),
map.find(IpAddr::from([
0x2404, 0x6800, 0x4004, 0x0800, 0, 0x1234, 0, 0x5678
]))
);
assert_eq!(
Some(&'h'),
map.find(IpAddr::from([0x2404, 0x6800, 0x4004, 0x0800, 0, 0, 0, 0]))
);
assert_eq!(
Some(&'h'),
map.find(IpAddr::from([
0x2404, 0x6800, 0x4004, 0x0800, 0x1010, 0x1010, 0x1010, 0x1010
]))
);
assert_eq!(
Some(&'a'),
map.find(IpAddr::from([
0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef
]))
);
}
#[test]
fn test_allowed_ips_iter_zero_leaf_bits() {
let mut map: AllowedIps<char> = Default::default();
map.insert(IpAddr::from([10, 111, 0, 1]), 32, '1');
map.insert(IpAddr::from([10, 111, 0, 2]), 32, '2');
map.insert(IpAddr::from([10, 111, 0, 3]), 32, '3');
let mut map_iter = map.iter();
assert_eq!(
map_iter.next(),
Some((&'1', IpAddr::from([10, 111, 0, 1]), 32))
);
assert_eq!(
map_iter.next(),
Some((&'2', IpAddr::from([10, 111, 0, 2]), 32))
);
assert_eq!(
map_iter.next(),
Some((&'3', IpAddr::from([10, 111, 0, 3]), 32))
);
assert_eq!(map_iter.next(), None);
}
}

View File

@@ -0,0 +1,368 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::dev_lock::LockReadGuard;
use super::drop_privileges::get_saved_ids;
use super::{AllowedIP, Device, Error, SocketAddr};
use crate::device::Action;
use crate::serialization::KeyBytes;
use crate::x25519;
use hex::encode as encode_hex;
use libc::*;
use std::fs::{create_dir, remove_file};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::sync::atomic::Ordering;
const SOCK_DIR: &str = "/var/run/wireguard/";
fn create_sock_dir() {
let _ = create_dir(SOCK_DIR); // Create the directory if it does not exist
if let Ok((saved_uid, saved_gid)) = get_saved_ids() {
unsafe {
let c_path = std::ffi::CString::new(SOCK_DIR).unwrap();
// The directory is under the root user, but we want to be able to
// delete the files there when we exit, so we need to change the owner
chown(
c_path.as_bytes_with_nul().as_ptr() as _,
saved_uid,
saved_gid,
);
}
}
}
impl Device {
/// Register the api handler for this Device. The api handler receives stream connections on a Unix socket
/// with a known path: /var/run/wireguard/{tun_name}.sock.
pub fn register_api_handler(&mut self) -> Result<(), Error> {
let path = format!("{}/{}.sock", SOCK_DIR, self.iface.name()?);
create_sock_dir();
let _ = remove_file(&path); // Attempt to remove the socket if already exists
let api_listener = UnixListener::bind(&path).map_err(Error::ApiSocket)?; // Bind a new socket to the path
self.cleanup_paths.push(path.clone());
self.queue.new_event(
api_listener.as_raw_fd(),
Box::new(move |d, _| {
// This is the closure that listens on the api unix socket
let (api_conn, _) = match api_listener.accept() {
Ok(conn) => conn,
_ => return Action::Continue,
};
let mut reader = BufReader::new(&api_conn);
let mut writer = BufWriter::new(&api_conn);
let mut cmd = String::new();
if reader.read_line(&mut cmd).is_ok() {
cmd.pop(); // pop the new line character
let status = match cmd.as_ref() {
// Only two commands are legal according to the protocol, get=1 and set=1.
"get=1" => api_get(&mut writer, d),
"set=1" => api_set(&mut reader, d),
_ => EIO,
};
// The protocol requires to return an error code as the response, or zero on success
writeln!(writer, "errno={}\n", status).ok();
}
Action::Continue // Indicates the worker thread should continue as normal
}),
)?;
self.register_monitor(path)?;
self.register_api_signal_handlers()
}
pub fn register_api_fd(&mut self, fd: i32) -> Result<(), Error> {
let io_file = unsafe { UnixStream::from_raw_fd(fd) };
self.queue.new_event(
io_file.as_raw_fd(),
Box::new(move |d, _| {
// This is the closure that listens on the api file descriptor
let mut reader = BufReader::new(&io_file);
let mut writer = BufWriter::new(&io_file);
let mut cmd = String::new();
if reader.read_line(&mut cmd).is_ok() {
cmd.pop(); // pop the new line character
let status = match cmd.as_ref() {
// Only two commands are legal according to the protocol, get=1 and set=1.
"get=1" => api_get(&mut writer, d),
"set=1" => api_set(&mut reader, d),
_ => EIO,
};
// The protocol requires to return an error code as the response, or zero on success
writeln!(writer, "errno={}\n", status).ok();
} else {
// The remote side is likely closed; we should trigger an exit.
d.trigger_exit();
return Action::Exit;
}
Action::Continue // Indicates the worker thread should continue as normal
}),
)?;
Ok(())
}
fn register_monitor(&self, path: String) -> Result<(), Error> {
self.queue.new_periodic_event(
Box::new(move |d, _| {
// This is not a very nice hack to detect if the control socket was removed
// and exiting nicely as a result. We check every 3 seconds in a loop if the
// file was deleted by stating it.
// The problem is that on linux inotify can be used quite beautifully to detect
// deletion, and kqueue EVFILT_VNODE can be used for the same purpose, but that
// will require introducing new events, for no measurable benefit.
// TODO: Could this be an issue if we restart the service too quickly?
let path = std::path::Path::new(&path);
if !path.exists() {
d.trigger_exit();
return Action::Exit;
}
// Periodically read the mtu of the interface in case it changes
if let Ok(mtu) = d.iface.mtu() {
d.mtu.store(mtu, Ordering::Relaxed);
}
Action::Continue
}),
std::time::Duration::from_millis(1000),
)?;
Ok(())
}
fn register_api_signal_handlers(&self) -> Result<(), Error> {
self.queue
.new_signal_event(SIGINT, Box::new(move |_, _| Action::Exit))?;
self.queue
.new_signal_event(SIGTERM, Box::new(move |_, _| Action::Exit))?;
Ok(())
}
}
#[allow(unused_must_use)]
fn api_get(writer: &mut BufWriter<&UnixStream>, d: &Device) -> i32 {
// get command requires an empty line, but there is no reason to be religious about it
if let Some(ref k) = d.key_pair {
writeln!(writer, "own_public_key={}", encode_hex(k.1.as_bytes()));
}
if d.listen_port != 0 {
writeln!(writer, "listen_port={}", d.listen_port);
}
if let Some(fwmark) = d.fwmark {
writeln!(writer, "fwmark={}", fwmark);
}
for (k, p) in d.peers.iter() {
let p = p.lock();
writeln!(writer, "public_key={}", encode_hex(k.as_bytes()));
if let Some(ref key) = p.preshared_key() {
writeln!(writer, "preshared_key={}", encode_hex(key));
}
if let Some(keepalive) = p.persistent_keepalive() {
writeln!(writer, "persistent_keepalive_interval={}", keepalive);
}
if let Some(ref addr) = p.endpoint().addr {
writeln!(writer, "endpoint={}", addr);
}
for (ip, cidr) in p.allowed_ips() {
writeln!(writer, "allowed_ip={}/{}", ip, cidr);
}
if let Some(time) = p.time_since_last_handshake() {
writeln!(writer, "last_handshake_time_sec={}", time.as_secs());
writeln!(writer, "last_handshake_time_nsec={}", time.subsec_nanos());
}
let (_, tx_bytes, rx_bytes, ..) = p.tunnel.stats();
writeln!(writer, "rx_bytes={}", rx_bytes);
writeln!(writer, "tx_bytes={}", tx_bytes);
}
0
}
fn api_set(reader: &mut BufReader<&UnixStream>, d: &mut LockReadGuard<Device>) -> i32 {
d.try_writeable(
|device| device.trigger_yield(),
|device| {
device.cancel_yield();
let mut cmd = String::new();
while reader.read_line(&mut cmd).is_ok() {
cmd.pop(); // remove newline if any
if cmd.is_empty() {
return 0; // Done
}
{
let parsed_cmd: Vec<&str> = cmd.split('=').collect();
if parsed_cmd.len() != 2 {
return EPROTO;
}
let (key, val) = (parsed_cmd[0], parsed_cmd[1]);
match key {
"private_key" => match val.parse::<KeyBytes>() {
Ok(key_bytes) => {
device.set_key(x25519::StaticSecret::from(key_bytes.0))
}
Err(_) => return EINVAL,
},
"listen_port" => match val.parse::<u16>() {
Ok(port) => match device.open_listen_socket(port) {
Ok(()) => {}
Err(_) => return EADDRINUSE,
},
Err(_) => return EINVAL,
},
#[cfg(any(
target_os = "android",
target_os = "fuchsia",
target_os = "linux"
))]
"fwmark" => match val.parse::<u32>() {
Ok(mark) => match device.set_fwmark(mark) {
Ok(()) => {}
Err(_) => return EADDRINUSE,
},
Err(_) => return EINVAL,
},
"replace_peers" => match val.parse::<bool>() {
Ok(true) => device.clear_peers(),
Ok(false) => {}
Err(_) => return EINVAL,
},
"public_key" => match val.parse::<KeyBytes>() {
// Indicates a new peer section
Ok(key_bytes) => {
return api_set_peer(
reader,
device,
x25519::PublicKey::from(key_bytes.0),
)
}
Err(_) => return EINVAL,
},
_ => return EINVAL,
}
}
cmd.clear();
}
0
},
)
.unwrap_or(EIO)
}
fn api_set_peer(
reader: &mut BufReader<&UnixStream>,
d: &mut Device,
pub_key: x25519::PublicKey,
) -> i32 {
let mut cmd = String::new();
let mut remove = false;
let mut replace_ips = false;
let mut endpoint = None;
let mut keepalive = None;
let mut public_key = pub_key;
let mut preshared_key = None;
let mut allowed_ips: Vec<AllowedIP> = vec![];
while reader.read_line(&mut cmd).is_ok() {
cmd.pop(); // remove newline if any
if cmd.is_empty() {
d.update_peer(
public_key,
remove,
replace_ips,
endpoint,
allowed_ips.as_slice(),
keepalive,
preshared_key,
);
allowed_ips.clear(); //clear the vector content after update
return 0; // Done
}
{
let parsed_cmd: Vec<&str> = cmd.splitn(2, '=').collect();
if parsed_cmd.len() != 2 {
return EPROTO;
}
let (key, val) = (parsed_cmd[0], parsed_cmd[1]);
match key {
"remove" => match val.parse::<bool>() {
Ok(true) => remove = true,
Ok(false) => remove = false,
Err(_) => return EINVAL,
},
"preshared_key" => match val.parse::<KeyBytes>() {
Ok(key_bytes) => preshared_key = Some(key_bytes.0),
Err(_) => return EINVAL,
},
"endpoint" => match val.parse::<SocketAddr>() {
Ok(addr) => endpoint = Some(addr),
Err(_) => return EINVAL,
},
"persistent_keepalive_interval" => match val.parse::<u16>() {
Ok(interval) => keepalive = Some(interval),
Err(_) => return EINVAL,
},
"replace_allowed_ips" => match val.parse::<bool>() {
Ok(true) => replace_ips = true,
Ok(false) => replace_ips = false,
Err(_) => return EINVAL,
},
"allowed_ip" => match val.parse::<AllowedIP>() {
Ok(ip) => allowed_ips.push(ip),
Err(_) => return EINVAL,
},
"public_key" => {
// Indicates a new peer section. Commit changes for current peer, and continue to next peer
d.update_peer(
public_key,
remove,
replace_ips,
endpoint,
allowed_ips.as_slice(),
keepalive,
preshared_key,
);
allowed_ips.clear(); //clear the vector content after update
match val.parse::<KeyBytes>() {
Ok(key_bytes) => public_key = key_bytes.0.into(),
Err(_) => return EINVAL,
}
}
"protocol_version" => match val.parse::<u32>() {
Ok(1) => {} // Only version 1 is legal
_ => return EINVAL,
},
_ => return EINVAL,
}
}
cmd.clear();
}
0
}

View File

@@ -0,0 +1,108 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard};
use std::ops::Deref;
/// A special type of read/write lock, that makes the following assumptions:
/// a) Read access is frequent, and has to be very fast, so we want to hold it indefinitely
/// b) Write access is very rare (think less than once per second) and can be a bit slower
/// c) A thread that holds a read lock, can ask for an upgrade to a write lock, cooperatively asking other threads to yield their locks
pub struct Lock<T: ?Sized> {
wants_write: (Mutex<bool>, Condvar),
inner: RwLock<T>, // Although parking lot lock is upgradable, it does not allow a two staged mark + lock upgrade
}
impl<T> Lock<T> {
/// New lock
pub fn new(user_data: T) -> Lock<T> {
Lock {
wants_write: (Mutex::new(false), Condvar::new()),
inner: RwLock::new(user_data),
}
}
}
impl<T: ?Sized> Lock<T> {
/// Acquire a read lock
pub fn read(&self) -> LockReadGuard<T> {
let (ref lock, ref cvar) = &self.wants_write;
let mut wants_write = lock.lock();
while *wants_write {
// We have a writer and we want to wait for it to go away
cvar.wait(&mut wants_write);
}
LockReadGuard {
wants_write: &self.wants_write,
inner: self.inner.read(),
}
}
}
pub struct LockReadGuard<'a, T: 'a + ?Sized> {
wants_write: &'a (Mutex<bool>, Condvar),
inner: RwLockReadGuard<'a, T>,
}
impl<'a, T: ?Sized> LockReadGuard<'a, T> {
/// Perform a closure on a mutable reference of the inner locked value.
///
/// # Parameters
///
/// `prep_func` - A closure that will run once, after the lock marks its intention to write,
/// this can be used to tell other threads to yield their read locks temporarily. It will be passed
/// an immutable reference to the inner value.
///
/// `mut_func` - A closure that will run once write access is gained. It iwll be passed a mutable reference
/// to the inner value.
///
pub fn try_writeable<U, P: FnOnce(&T), F: FnOnce(&mut T) -> U>(
&mut self,
prep_func: P,
mut_func: F,
) -> Option<U> {
// First tell everyone that we want to write now, this will prevent any new reader from starting until we are done.
{
let &(ref lock, cvar) = &self.wants_write;
let mut wants_write = lock.lock();
RwLockReadGuard::unlocked(&mut self.inner, move || {
while *wants_write {
// We have a writer and we want to wait for it to go away
cvar.wait(&mut wants_write);
}
*wants_write = true;
});
}
// Second stage is to run the prep function
prep_func(&*self.inner);
let lock = RwLockReadGuard::rwlock(&self.inner);
// Third stage is to perform our op under a write lock
let ret = Some(RwLockReadGuard::unlocked(&mut self.inner, move || {
// There is no race here because wants_write blocks other threads
let mut write_access = lock.write();
mut_func(&mut *write_access)
}));
// Finally signal other threads
let (ref lock, ref cvar) = &self.wants_write;
let mut wants_write = lock.lock();
*wants_write = false;
cvar.notify_all();
ret
}
}
impl<'a, T: ?Sized> Deref for LockReadGuard<'a, T> {
type Target = T;
fn deref(&self) -> &T {
&self.inner
}
}

View File

@@ -0,0 +1,75 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use crate::device::Error;
use libc::{gid_t, setgid, setuid, uid_t};
use std::io;
#[cfg(target_os = "macos")]
use nix::unistd::User;
pub fn get_saved_ids() -> Result<(uid_t, gid_t), Error> {
// Get the user name of the sudoer
#[cfg(target_os = "macos")]
match std::env::var("USER") {
Ok(uname) => match User::from_name(&uname) {
Ok(Some(user)) => Ok((uid_t::from(user.uid), gid_t::from(user.gid))),
Err(e) => Err(Error::DropPrivileges(format!(
"Failed parse user; err: {:?}",
e
))),
Ok(None) => Err(Error::DropPrivileges("Failed to find user".to_owned())),
},
Err(e) => Err(Error::DropPrivileges(format!(
"Could not get environment variable for user; err: {:?}",
e
))),
}
#[cfg(not(target_os = "macos"))]
{
use libc::{getlogin, getpwnam};
let uname = unsafe { getlogin() };
if uname.is_null() {
return Err(Error::DropPrivileges("NULL from getlogin".to_owned()));
}
let userinfo = unsafe { getpwnam(uname) };
if userinfo.is_null() {
return Err(Error::DropPrivileges("NULL from getpwnam".to_owned()));
}
// Saved group ID
let saved_gid = unsafe { (*userinfo).pw_gid };
// Saved user ID
let saved_uid = unsafe { (*userinfo).pw_uid };
Ok((saved_uid, saved_gid))
}
}
pub fn drop_privileges() -> Result<(), Error> {
let (saved_uid, saved_gid) = get_saved_ids()?;
if -1 == unsafe { setgid(saved_gid) } {
// Set real and effective group ID
return Err(Error::DropPrivileges(
io::Error::last_os_error().to_string(),
));
}
if -1 == unsafe { setuid(saved_uid) } {
// Set real and effective user ID
return Err(Error::DropPrivileges(
io::Error::last_os_error().to_string(),
));
}
// Validated we can't get sudo back again
if unsafe { (setgid(0) != -1) || (setuid(0) != -1) } {
Err(Error::DropPrivileges(
"Failed to permanently drop privileges".to_owned(),
))
} else {
Ok(())
}
}

View File

@@ -0,0 +1,416 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::Error;
use libc::*;
use parking_lot::Mutex;
use std::io;
use std::ops::Deref;
use std::os::unix::io::RawFd;
use std::ptr::null_mut;
use std::time::Duration;
/// A return type for the EventPoll::wait() function
pub enum WaitResult<'a, H> {
/// Event triggered normally
Ok(EventGuard<'a, H>),
/// Event triggered due to End of File conditions
EoF(EventGuard<'a, H>),
/// There was an error
Error(String),
}
/// Implements a registry of pollable events
pub struct EventPoll<H: Sized> {
events: Mutex<Vec<Option<Box<Event<H>>>>>,
epoll: RawFd, // The OS epoll
}
/// A type that hold a reference to a triggered Event
/// While an EventGuard exists for a given Event, it will not be triggered by any other thread
/// Once the EventGuard goes out of scope, the underlying Event will be re-enabled
pub struct EventGuard<'a, H> {
epoll: RawFd,
event: &'a mut Event<H>,
poll: &'a EventPoll<H>,
}
/// A reference to a single event in an EventPoll
pub struct EventRef {
trigger: RawFd,
}
struct Event<H> {
event: epoll_event, // The epoll event description
fd: RawFd, // The associated fd
handler: H, // The associated data
notifier: bool, // Is a notification event
needs_read: bool, // This event needs to be read to be cleared
}
impl<H> Drop for EventPoll<H> {
fn drop(&mut self) {
unsafe { close(self.epoll) };
}
}
impl<H: Sync + Send> EventPoll<H> {
/// Create a new event registry
pub fn new() -> Result<EventPoll<H>, Error> {
let epoll = match unsafe { epoll_create(1) } {
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
epoll => epoll,
};
Ok(EventPoll {
events: Mutex::new(vec![]),
epoll,
})
}
/// Add and enable a new event with the factory.
/// The event is triggered when a Read operation on the provided trigger becomes available
/// If the trigger fd is closed, the event won't be triggered anymore, but it's data won't be
/// automatically released.
/// The safe way to delete an event, is using the cancel method of an EventGuard.
/// If the same trigger is used with multiple events in the same EventPoll, the last added
/// event overrides all previous events. In case the same trigger is used with multiple polls,
/// each event will be triggered independently.
/// The event will keep triggering until a Read operation is no longer possible on the trigger.
/// When triggered, one of the threads waiting on the poll will receive the handler via an
/// appropriate EventGuard. It is guaranteed that only a single thread can have a reference to
/// the handler at any given time.
pub fn new_event(&self, trigger: RawFd, handler: H) -> Result<EventRef, Error> {
// Create an event descriptor
let flags = EPOLLIN | EPOLLONESHOT;
let ev = Event {
event: epoll_event {
events: flags as _,
u64: 0,
},
fd: trigger,
handler,
notifier: false,
needs_read: false,
};
self.register_event(ev)
}
/// Add and enable a new write event with the factory.
/// The event is triggered when a Write operation on the provided trigger becomes possible
/// For TCP sockets it means that the socket was succesfully connected
#[allow(dead_code)]
pub fn new_write_event(&self, trigger: RawFd, handler: H) -> Result<EventRef, Error> {
// Create an event descriptor
let flags = EPOLLOUT | EPOLLET | EPOLLONESHOT;
let ev = Event {
event: epoll_event {
events: flags as _,
u64: 0,
},
fd: trigger,
handler,
notifier: false,
needs_read: false,
};
self.register_event(ev)
}
/// Add and enable a new timed event with the factory.
/// The even will be triggered for the first time after period time, and henceforth triggered
/// every period time. Period is counted from the moment the appropriate EventGuard is released.
pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result<EventRef, Error> {
// The periodic event on Linux uses the timerfd
let tfd = match unsafe { timerfd_create(CLOCK_BOOTTIME, TFD_NONBLOCK) } {
-1 => match unsafe { timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK) } {
// A fallback for kernels < 3.15
-1 => return Err(Error::Timer(io::Error::last_os_error())),
efd => efd,
},
efd => efd,
};
let ts = timespec {
tv_sec: period.as_secs() as _,
tv_nsec: i64::from(period.subsec_nanos()) as _,
};
let spec = itimerspec {
it_value: ts,
it_interval: ts,
};
if unsafe { timerfd_settime(tfd, 0, &spec, std::ptr::null_mut()) } == -1 {
unsafe { close(tfd) };
return Err(Error::Timer(io::Error::last_os_error()));
}
let ev = Event {
event: epoll_event {
events: (EPOLLIN | EPOLLONESHOT) as _,
u64: 0,
},
fd: tfd,
handler,
notifier: false,
needs_read: true,
};
self.register_event(ev)
}
/// Add and enable a new notification event with the factory.
/// The event can only be triggered manually, using the trigger_notification method.
/// The event will remain in a triggered state until the stop_notification method is
/// called. Both methods should only be called with the producing EventPoll.
pub fn new_notifier(&self, handler: H) -> Result<EventRef, Error> {
// The notifier on Linux uses the eventfd for notifications.
// The way it works is when a non zero value is written into the eventfd it will trigger
// the EPOLLIN event. Since we don't enable ONESHOT it will keep triggering until
// canceled.
// When we want to stop the event, we read something once from the file descriptor.
let efd = match unsafe { eventfd(0, EFD_NONBLOCK) } {
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
efd => efd,
};
let ev = Event {
event: epoll_event {
events: (EPOLLIN) as _,
u64: 0,
},
fd: efd,
handler,
notifier: true,
needs_read: false,
};
self.register_event(ev)
}
/// Add and enable a new signal handler
pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result<EventRef, Error> {
let sfd = match unsafe {
let mut sigset = std::mem::zeroed();
sigemptyset(&mut sigset);
sigaddset(&mut sigset, signal);
sigprocmask(SIG_BLOCK, &sigset, null_mut());
signalfd(-1, &sigset, SFD_NONBLOCK)
} {
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
sfd => sfd,
};
let ev = Event {
event: epoll_event {
events: (EPOLLIN | EPOLLONESHOT) as _,
u64: 0,
},
fd: sfd,
handler,
notifier: false,
needs_read: true,
};
self.register_event(ev)
}
/// Wait until one of the registered events becomes triggered. Once an event
/// is triggered, a single caller thread gets the handler for that event.
/// In case a notifier is triggered, all waiting threads will receive the same
/// handler.
pub fn wait(&self) -> WaitResult<'_, H> {
let mut event = epoll_event { events: 0, u64: 0 };
match unsafe { epoll_wait(self.epoll, &mut event, 1, -1) } {
-1 => return WaitResult::Error(io::Error::last_os_error().to_string()),
1 => {}
_ => return WaitResult::Error("unexpected number of events returned".to_string()),
}
let event_data = unsafe { (event.u64 as *mut Event<H>).as_mut().unwrap() };
let guard = EventGuard {
epoll: self.epoll,
event: event_data,
poll: self,
};
if event.events & EPOLLHUP as u32 != 0 {
// End of file flag
WaitResult::EoF(guard)
} else {
WaitResult::Ok(guard)
}
}
// Register an event with this poll.
fn register_event(&self, ev: Event<H>) -> Result<EventRef, Error> {
// To register an event we
// * Create a reference to self in the inner event
// * Store the Event in the events vector
// * Dispose of a previous Event under same fd if any
// * Add the Event to epoll
let trigger = ev.fd;
let mut ev = Box::new(ev);
// The inner event points back to the wrapper
ev.event.u64 = ev.as_mut() as *mut Event<H> as _;
let mut event_desc = ev.event;
// Now add the pointer to the events vector, this is a place from which we can drop the event
self.insert_at(trigger as _, ev);
// Add the event to epoll
if unsafe { epoll_ctl(self.epoll, EPOLL_CTL_ADD, trigger, &mut event_desc) } == -1 {
return Err(Error::EventQueue(io::Error::last_os_error()));
}
Ok(EventRef { trigger })
}
// Insert an event into the events vector
fn insert_at(&self, index: usize, data: Box<Event<H>>) {
let mut events = self.events.lock();
while events.len() <= index {
// Resize the vector to be able to fit the new index
// We trust the OS to allocate file descriptors in a sane order
events.push(None); // resize doesn't work because Clone is not satisfied
}
if events[index].take().is_some() {
// Properly remove the previous event first
unsafe {
epoll_ctl(self.epoll, EPOLL_CTL_DEL, index as _, null_mut());
};
}
events[index] = Some(data);
}
/// Trigger a notification
pub fn trigger_notification(&self, notification_event: &EventRef) {
let events = self.events.lock();
let event_ref = &(*events)[notification_event.trigger as usize];
let event_data = event_ref.as_ref().expect("Expected an event");
if !event_data.notifier {
panic!("Can only trigger a notification event");
}
// Write some data to the eventfd to trigger an EPOLLIN event
unsafe {
write(
notification_event.trigger,
&(std::u64::MAX - 1).to_ne_bytes()[0] as *const u8 as _,
8,
)
};
}
/// Stop a notification
pub fn stop_notification(&self, notification_event: &EventRef) {
let events = self.events.lock();
let event_ref = &(*events)[notification_event.trigger as usize];
let event_data = event_ref.as_ref().expect("Expected an event");
if !event_data.notifier {
panic!("Can only trigger a notification event");
}
let mut buf = [0u8; 8];
unsafe {
read(
notification_event.trigger,
buf.as_mut_ptr() as _,
buf.len() as _,
)
};
}
}
impl<H> EventPoll<H> {
/// Disable and remove the event and associated handler, using the fd that
/// was used to register it.
///
/// # Safety
///
/// This function is only safe to call when the event loop is not running,
/// otherwise the memory of the handler may get freed while in use.
pub unsafe fn clear_event_by_fd(&self, index: RawFd) {
let mut events = self.events.lock();
assert!(index >= 0);
if events[index as usize].take().is_some() {
epoll_ctl(self.epoll, EPOLL_CTL_DEL, index, null_mut());
}
}
}
impl<'a, H> Deref for EventGuard<'a, H> {
type Target = H;
fn deref(&self) -> &H {
&self.event.handler
}
}
impl<'a, H> Drop for EventGuard<'a, H> {
fn drop(&mut self) {
if self.event.needs_read {
// Must read from the event to reset it before we enable it
let mut buf: [std::mem::MaybeUninit<u8>; 256] =
unsafe { std::mem::MaybeUninit::uninit().assume_init() };
while unsafe { read(self.event.fd, buf.as_mut_ptr() as _, buf.len() as _) } != -1 {}
}
unsafe {
epoll_ctl(
self.epoll,
EPOLL_CTL_MOD,
self.event.fd,
&mut self.event.event,
);
}
}
}
impl<'a, H> EventGuard<'a, H> {
/// Get a mutable reference to the stored value
#[allow(dead_code)]
pub fn get_mut(&mut self) -> &mut H {
&mut self.event.handler
}
/// Cancel and remove the event referenced by this guard
pub fn cancel(self) {
unsafe { self.poll.clear_event_by_fd(self.event.fd) };
std::mem::forget(self); // Don't call the regular drop that would enable the event
}
pub fn fd(&self) -> i32 {
self.event.fd
}
/// Change the event flags to enable or disable notifying when the fd is writable
pub fn notify_writable(&mut self, enabled: bool) {
let flags = if enabled {
EPOLLOUT | EPOLLIN | EPOLLET | EPOLLONESHOT
} else {
EPOLLIN | EPOLLONESHOT
};
self.event.event.events = flags as _;
}
}
pub fn block_signal(signal: c_int) -> Result<sigset_t, String> {
unsafe {
let mut sigset = std::mem::zeroed();
sigemptyset(&mut sigset);
if sigaddset(&mut sigset, signal) == -1 {
return Err(io::Error::last_os_error().to_string());
}
if sigprocmask(SIG_BLOCK, &sigset, null_mut()) == -1 {
return Err(io::Error::last_os_error().to_string());
}
Ok(sigset)
}
}

View File

@@ -0,0 +1,849 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This module contains some integration tests for boringtun
// Those tests require docker and sudo privileges to run
#[cfg(all(test, not(target_os = "macos")))]
mod tests {
use crate::device::{DeviceConfig, DeviceHandle};
use crate::x25519::{PublicKey, StaticSecret};
use base64::encode as base64encode;
use hex::encode;
use rand_core::OsRng;
use ring::rand::{SecureRandom, SystemRandom};
use std::fmt::Write as _;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::os::unix::net::UnixStream;
use std::process::Command;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
static NEXT_IFACE_IDX: AtomicUsize = AtomicUsize::new(100); // utun 100+ should be vacant during testing on CI
static NEXT_PORT: AtomicUsize = AtomicUsize::new(61111); // Use ports starting with 61111, hoping we don't run into a taken port 🤷
static NEXT_IP: AtomicUsize = AtomicUsize::new(0xc0000200); // Use 192.0.2.0/24 for those tests, we might use more than 256 addresses though, usize must be >=32 bits on all supported platforms
static NEXT_IP_V6: AtomicUsize = AtomicUsize::new(0); // Use the 2001:db8:: address space, append this atomic counter for bottom 32 bits
fn next_ip() -> IpAddr {
IpAddr::V4(Ipv4Addr::from(
NEXT_IP.fetch_add(1, Ordering::Relaxed) as u32
))
}
fn next_ip_v6() -> IpAddr {
let addr = 0x2001_0db8_0000_0000_0000_0000_0000_0000_u128
+ u128::from(NEXT_IP_V6.fetch_add(1, Ordering::Relaxed) as u32);
IpAddr::V6(Ipv6Addr::from(addr))
}
fn next_port() -> u16 {
NEXT_PORT.fetch_add(1, Ordering::Relaxed) as u16
}
/// Represents an allowed IP and cidr for a peer
struct AllowedIp {
ip: IpAddr,
cidr: u8,
}
/// Represents a single peer running in a container
struct Peer {
key: StaticSecret,
endpoint: SocketAddr,
allowed_ips: Vec<AllowedIp>,
container_name: Option<String>,
}
/// Represents a single WireGuard interface on local machine
struct WGHandle {
_device: DeviceHandle,
name: String,
addr_v4: IpAddr,
addr_v6: IpAddr,
started: bool,
peers: Vec<Arc<Peer>>,
}
impl Drop for Peer {
fn drop(&mut self) {
if let Some(name) = &self.container_name {
Command::new("docker")
.args([
"stop", // Run docker
&name[5..],
])
.status()
.ok();
std::fs::remove_file(name).ok();
std::fs::remove_file(format!("{}.ngx", name)).ok();
}
}
}
impl Peer {
/// Create a new peer with a given endpoint and a list of allowed IPs
fn new(endpoint: SocketAddr, allowed_ips: Vec<AllowedIp>) -> Peer {
Peer {
key: StaticSecret::random_from_rng(OsRng),
endpoint,
allowed_ips,
container_name: None,
}
}
/// Creates a new configuration file that can be used by wg-quick
fn gen_wg_conf(
&self,
local_key: &PublicKey,
local_addr: &IpAddr,
local_port: u16,
) -> String {
let mut conf = String::from("[Interface]\n");
// Each allowed ip, becomes a possible address in the config
for ip in &self.allowed_ips {
let _ = writeln!(conf, "Address = {}/{}", ip.ip, ip.cidr);
}
// The local endpoint port is the remote listen port
let _ = writeln!(conf, "ListenPort = {}", self.endpoint.port());
// HACK: this should consume the key so it can't be reused instead of cloning and serializing
let _ = writeln!(conf, "PrivateKey = {}", base64encode(self.key.to_bytes()));
// We are the peer
let _ = writeln!(conf, "[Peer]");
let _ = writeln!(conf, "PublicKey = {}", base64encode(local_key.as_bytes()));
let _ = writeln!(conf, "AllowedIPs = {}", local_addr);
let _ = write!(conf, "Endpoint = 127.0.0.1:{}", local_port);
conf
}
/// Create a simple nginx config, that will respond with the peer public key
fn gen_nginx_conf(&self) -> String {
format!(
"server {{\n\
listen 80;\n\
listen [::]:80;\n\
location / {{\n\
return 200 '{}';\n\
}}\n\
}}",
encode(PublicKey::from(&self.key).as_bytes())
)
}
fn start_in_container(
&mut self,
local_key: &PublicKey,
local_addr: &IpAddr,
local_port: u16,
) {
let peer_config = self.gen_wg_conf(local_key, local_addr, local_port);
let peer_config_file = temp_path();
std::fs::write(&peer_config_file, peer_config).unwrap();
let nginx_config = self.gen_nginx_conf();
let nginx_config_file = format!("{}.ngx", peer_config_file);
std::fs::write(&nginx_config_file, nginx_config).unwrap();
Command::new("docker")
.args([
"run", // Run docker
"-d", // In detached mode
"--cap-add=NET_ADMIN", // Grant permissions to open a tunnel
"--device=/dev/net/tun",
"--sysctl", // Enable ipv6
"net.ipv6.conf.all.disable_ipv6=0",
"--sysctl",
"net.ipv6.conf.default.disable_ipv6=0",
"-p", // Open port for the endpoint
&format!("{0}:{0}/udp", self.endpoint.port()),
"-v", // Map the generated WireGuard config file
&format!("{}:/wireguard/wg.conf", peer_config_file),
"-v", // Map the nginx config file
&format!("{}:/etc/nginx/conf.d/default.conf", nginx_config_file),
"--rm", // Cleanup
"--name",
&peer_config_file[5..],
"vkrasnov/wireguard-test",
])
.status()
.expect("Failed to run docker");
self.container_name = Some(peer_config_file);
}
fn connect(&self) -> std::net::TcpStream {
let http_addr = SocketAddr::new(self.allowed_ips[0].ip, 80);
for _i in 0..5 {
let res = std::net::TcpStream::connect(http_addr);
if let Err(err) = res {
println!("failed to connect: {:?}", err);
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
return res.unwrap();
}
panic!("failed to connect");
}
fn get_request(&self) -> String {
let mut tcp_conn = self.connect();
write!(
tcp_conn,
"GET / HTTP/1.1\nHost: localhost\nAccept: */*\nConnection: close\n\n"
)
.unwrap();
tcp_conn
.set_read_timeout(Some(std::time::Duration::from_secs(60)))
.ok();
let mut reader = BufReader::new(tcp_conn);
let mut line = String::new();
let mut response = String::new();
let mut len = 0usize;
// Read response code
if reader.read_line(&mut line).is_ok() && !line.starts_with("HTTP/1.1 200") {
return response;
}
line.clear();
// Read headers
while reader.read_line(&mut line).is_ok() {
if line.trim() == "" {
break;
}
{
let parsed_line: Vec<&str> = line.split(':').collect();
if parsed_line.len() < 2 {
return response;
}
let (key, val) = (parsed_line[0], parsed_line[1]);
if key.to_lowercase() == "content-length" {
len = match val.trim().parse() {
Err(_) => return response,
Ok(len) => len,
};
}
}
line.clear();
}
// Read body
let mut buf = [0u8; 256];
while len > 0 {
let to_read = len.min(buf.len());
if reader.read_exact(&mut buf[..to_read]).is_err() {
return response;
}
response.push_str(&String::from_utf8_lossy(&buf[..to_read]));
len -= to_read;
}
response
}
}
impl WGHandle {
/// Create a new interface for the tunnel with the given address
fn init(addr_v4: IpAddr, addr_v6: IpAddr) -> WGHandle {
WGHandle::init_with_config(
addr_v4,
addr_v6,
DeviceConfig {
n_threads: 2,
use_connected_socket: true,
#[cfg(target_os = "linux")]
use_multi_queue: true,
#[cfg(target_os = "linux")]
uapi_fd: -1,
},
)
}
/// Create a new interface for the tunnel with the given address
fn init_with_config(addr_v4: IpAddr, addr_v6: IpAddr, config: DeviceConfig) -> WGHandle {
// Generate a new name, utun100+ should work on macOS and Linux
let name = format!("utun{}", NEXT_IFACE_IDX.fetch_add(1, Ordering::Relaxed));
let _device = DeviceHandle::new(&name, config).unwrap();
WGHandle {
_device,
name,
addr_v4,
addr_v6,
started: false,
peers: vec![],
}
}
#[cfg(target_os = "macos")]
/// Starts the tunnel
fn start(&mut self) {
// Assign the ipv4 address to the interface
Command::new("ifconfig")
.args(&[
&self.name,
&self.addr_v4.to_string(),
&self.addr_v4.to_string(),
"alias",
])
.status()
.expect("failed to assign ip to tunnel");
// Assign the ipv6 address to the interface
Command::new("ifconfig")
.args(&[
&self.name,
"inet6",
&self.addr_v6.to_string(),
"prefixlen",
"128",
"alias",
])
.status()
.expect("failed to assign ipv6 to tunnel");
// Start the tunnel
Command::new("ifconfig")
.args(&[&self.name, "up"])
.status()
.expect("failed to start the tunnel");
self.started = true;
// Add each peer to the routing table
for p in &self.peers {
for r in &p.allowed_ips {
let inet_flag = match r.ip {
IpAddr::V4(_) => "-inet",
IpAddr::V6(_) => "-inet6",
};
Command::new("route")
.args(&[
"-q",
"-n",
"add",
inet_flag,
&format!("{}/{}", r.ip, r.cidr),
"-interface",
&self.name,
])
.status()
.expect("failed to add route");
}
}
}
#[cfg(target_os = "linux")]
/// Starts the tunnel
fn start(&mut self) {
Command::new("ip")
.args([
"address",
"add",
&self.addr_v4.to_string(),
"dev",
&self.name,
])
.status()
.expect("failed to assign ip to tunnel");
Command::new("ip")
.args([
"address",
"add",
&self.addr_v6.to_string(),
"dev",
&self.name,
])
.status()
.expect("failed to assign ipv6 to tunnel");
// Start the tunnel
Command::new("ip")
.args(["link", "set", "mtu", "1400", "up", "dev", &self.name])
.status()
.expect("failed to start the tunnel");
self.started = true;
// Add each peer to the routing table
for p in &self.peers {
for r in &p.allowed_ips {
Command::new("ip")
.args([
"route",
"add",
&format!("{}/{}", r.ip, r.cidr),
"dev",
&self.name,
])
.status()
.expect("failed to add route");
}
}
}
/// Issue a get command on the interface
fn wg_get(&self) -> String {
let path = format!("/var/run/wireguard/{}.sock", self.name);
let mut socket = UnixStream::connect(path).unwrap();
write!(socket, "get=1\n\n").unwrap();
let mut ret = String::new();
socket.read_to_string(&mut ret).unwrap();
ret
}
/// Issue a set command on the interface
fn wg_set(&self, setting: &str) -> String {
let path = format!("/var/run/wireguard/{}.sock", self.name);
let mut socket = UnixStream::connect(path).unwrap();
write!(socket, "set=1\n{}\n\n", setting).unwrap();
let mut ret = String::new();
socket.read_to_string(&mut ret).unwrap();
ret
}
/// Assign a listen_port to the interface
fn wg_set_port(&self, port: u16) -> String {
self.wg_set(&format!("listen_port={}", port))
}
/// Assign a private_key to the interface
fn wg_set_key(&self, key: StaticSecret) -> String {
self.wg_set(&format!("private_key={}", encode(key.to_bytes())))
}
/// Assign a peer to the interface (with public_key, endpoint and a series of nallowed_ip)
fn wg_set_peer(
&self,
key: &PublicKey,
ep: &SocketAddr,
allowed_ips: &[AllowedIp],
) -> String {
let mut req = format!("public_key={}\nendpoint={}", encode(key.as_bytes()), ep);
for AllowedIp { ip, cidr } in allowed_ips {
let _ = write!(req, "\nallowed_ip={}/{}", ip, cidr);
}
self.wg_set(&req)
}
/// Add a new known peer
fn add_peer(&mut self, peer: Arc<Peer>) {
self.wg_set_peer(
&PublicKey::from(&peer.key),
&peer.endpoint,
&peer.allowed_ips,
);
self.peers.push(peer);
}
}
/// Create a new filename in the /tmp dir
fn temp_path() -> String {
let mut path = String::from("/tmp/");
let mut buf = [0u8; 32];
SystemRandom::new().fill(&mut buf[..]).unwrap();
path.push_str(&encode(buf));
path
}
#[test]
#[ignore]
/// Test if wireguard starts and creates a unix socket that we can read from
fn test_wireguard_get() {
let wg = WGHandle::init("192.0.2.0".parse().unwrap(), "::2".parse().unwrap());
let response = wg.wg_get();
assert!(response.ends_with("errno=0\n\n"));
}
#[test]
#[ignore]
/// Test if wireguard starts and creates a unix socket that we can use to set settings
fn test_wireguard_set() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let own_public_key = PublicKey::from(&private_key);
let wg = WGHandle::init("192.0.2.0".parse().unwrap(), "::2".parse().unwrap());
assert!(wg.wg_get().ends_with("errno=0\n\n"));
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
// Check that the response matches what we expect
assert_eq!(
wg.wg_get(),
format!(
"own_public_key={}\nlisten_port={}\nerrno=0\n\n",
encode(own_public_key.as_bytes()),
port
)
);
let peer_key = StaticSecret::random_from_rng(OsRng);
let peer_pub_key = PublicKey::from(&peer_key);
let endpoint = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 0, 0, 1)), 50001);
let allowed_ips = [
AllowedIp {
ip: IpAddr::V4(Ipv4Addr::new(172, 0, 0, 2)),
cidr: 32,
},
AllowedIp {
ip: IpAddr::V6(Ipv6Addr::new(0xf120, 0, 0, 2, 2, 2, 0, 0)),
cidr: 100,
},
];
assert_eq!(
wg.wg_set_peer(&peer_pub_key, &endpoint, &allowed_ips),
"errno=0\n\n"
);
// Check that the response matches what we expect
assert_eq!(
wg.wg_get(),
format!(
"own_public_key={}\n\
listen_port={}\n\
public_key={}\n\
endpoint={}\n\
allowed_ip={}/{}\n\
allowed_ip={}/{}\n\
rx_bytes=0\n\
tx_bytes=0\n\
errno=0\n\n",
encode(own_public_key.as_bytes()),
port,
encode(peer_pub_key.as_bytes()),
endpoint,
allowed_ips[0].ip,
allowed_ips[0].cidr,
allowed_ips[1].ip,
allowed_ips[1].cidr
)
);
}
/// Test if wireguard can handle simple ipv4 connections, don't use a connected socket
#[test]
#[ignore]
fn test_wg_start_ipv4_non_connected() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
let addr_v4 = next_ip();
let addr_v6 = next_ip_v6();
let mut wg = WGHandle::init_with_config(
addr_v4,
addr_v6,
DeviceConfig {
n_threads: 2,
use_connected_socket: false,
#[cfg(target_os = "linux")]
use_multi_queue: true,
#[cfg(target_os = "linux")]
uapi_fd: -1,
},
);
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
// Create a new peer whose endpoint is on this machine
let mut peer = Peer::new(
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
vec![AllowedIp {
ip: next_ip(),
cidr: 32,
}],
);
peer.start_in_container(&public_key, &addr_v4, port);
let peer = Arc::new(peer);
wg.add_peer(Arc::clone(&peer));
wg.start();
let response = peer.get_request();
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
}
/// Test if wireguard can handle simple ipv4 connections
#[test]
#[ignore]
fn test_wg_start_ipv4() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
let addr_v4 = next_ip();
let addr_v6 = next_ip_v6();
let mut wg = WGHandle::init(addr_v4, addr_v6);
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
// Create a new peer whose endpoint is on this machine
let mut peer = Peer::new(
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
vec![AllowedIp {
ip: next_ip(),
cidr: 32,
}],
);
peer.start_in_container(&public_key, &addr_v4, port);
let peer = Arc::new(peer);
wg.add_peer(Arc::clone(&peer));
wg.start();
let response = peer.get_request();
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
}
#[test]
#[ignore]
/// Test if wireguard can handle simple ipv6 connections
fn test_wg_start_ipv6() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
let addr_v4 = next_ip();
let addr_v6 = next_ip_v6();
let mut wg = WGHandle::init(addr_v4, addr_v6);
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
let mut peer = Peer::new(
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
vec![AllowedIp {
ip: next_ip_v6(),
cidr: 128,
}],
);
peer.start_in_container(&public_key, &addr_v6, port);
let peer = Arc::new(peer);
wg.add_peer(Arc::clone(&peer));
wg.start();
let response = peer.get_request();
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
}
/// Test if wireguard can handle connection with an ipv6 endpoint
#[test]
#[ignore]
#[cfg(target_os = "linux")] // Can't make docker work with ipv6 on macOS ATM
fn test_wg_start_ipv6_endpoint() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
let addr_v4 = next_ip();
let addr_v6 = next_ip_v6();
let mut wg = WGHandle::init(addr_v4, addr_v6);
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
let mut peer = Peer::new(
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
next_port(),
),
vec![AllowedIp {
ip: next_ip_v6(),
cidr: 128,
}],
);
peer.start_in_container(&public_key, &addr_v6, port);
let peer = Arc::new(peer);
wg.add_peer(Arc::clone(&peer));
wg.start();
let response = peer.get_request();
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
}
/// Test if wireguard can handle connection with an ipv6 endpoint
#[test]
#[ignore]
#[cfg(target_os = "linux")] // Can't make docker work with ipv6 on macOS ATM
fn test_wg_start_ipv6_endpoint_not_connected() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
let addr_v4 = next_ip();
let addr_v6 = next_ip_v6();
let mut wg = WGHandle::init_with_config(
addr_v4,
addr_v6,
DeviceConfig {
n_threads: 2,
use_connected_socket: false,
#[cfg(target_os = "linux")]
use_multi_queue: true,
#[cfg(target_os = "linux")]
uapi_fd: -1,
},
);
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
let mut peer = Peer::new(
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
next_port(),
),
vec![AllowedIp {
ip: next_ip_v6(),
cidr: 128,
}],
);
peer.start_in_container(&public_key, &addr_v6, port);
let peer = Arc::new(peer);
wg.add_peer(Arc::clone(&peer));
wg.start();
let response = peer.get_request();
assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes()));
}
/// Test many concurrent connections
#[test]
#[ignore]
fn test_wg_concurrent() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
let addr_v4 = next_ip();
let addr_v6 = next_ip_v6();
let mut wg = WGHandle::init(addr_v4, addr_v6);
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
for _ in 0..5 {
// Create a new peer whose endpoint is on this machine
let mut peer = Peer::new(
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
vec![AllowedIp {
ip: next_ip(),
cidr: 32,
}],
);
peer.start_in_container(&public_key, &addr_v4, port);
let peer = Arc::new(peer);
wg.add_peer(Arc::clone(&peer));
}
wg.start();
let mut threads = vec![];
for p in wg.peers {
let pub_key = PublicKey::from(&p.key);
threads.push(thread::spawn(move || {
for _ in 0..100 {
let response = p.get_request();
assert_eq!(response, encode(pub_key.as_bytes()));
}
}));
}
for t in threads {
t.join().unwrap();
}
}
/// Test many concurrent connections
#[test]
#[ignore]
fn test_wg_concurrent_v6() {
let port = next_port();
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
let addr_v4 = next_ip();
let addr_v6 = next_ip_v6();
let mut wg = WGHandle::init(addr_v4, addr_v6);
assert_eq!(wg.wg_set_port(port), "errno=0\n\n");
assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n");
for _ in 0..5 {
// Create a new peer whose endpoint is on this machine
let mut peer = Peer::new(
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()),
vec![AllowedIp {
ip: next_ip_v6(),
cidr: 128,
}],
);
peer.start_in_container(&public_key, &addr_v6, port);
let peer = Arc::new(peer);
wg.add_peer(Arc::clone(&peer));
}
wg.start();
let mut threads = vec![];
for p in wg.peers {
let pub_key = PublicKey::from(&p.key);
threads.push(thread::spawn(move || {
for _ in 0..100 {
let response = p.get_request();
assert_eq!(response, encode(pub_key.as_bytes()));
}
}));
}
for t in threads {
t.join().unwrap();
}
}
}

View File

@@ -0,0 +1,337 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::Error;
use libc::*;
use parking_lot::Mutex;
use std::io;
use std::ops::Deref;
use std::os::unix::io::RawFd;
use std::ptr::{null, null_mut};
use std::time::Duration;
/// A return type for the EventPoll::wait() function
pub enum WaitResult<'a, H> {
/// Event triggered normally
Ok(EventGuard<'a, H>),
/// Event triggered due to End of File conditions
EoF(EventGuard<'a, H>),
/// There was an error
Error(String),
}
/// Implements a registry of pollable events
pub struct EventPoll<H: Sized> {
events: Mutex<Vec<Option<Box<Event<H>>>>>, // Events with a file descriptor
custom: Mutex<Vec<Option<Box<Event<H>>>>>, // Other events (i.e. timers & notifiers)
signals: Mutex<Vec<Option<Box<Event<H>>>>>, // Signal handlers
kqueue: RawFd, // The OS kqueue
}
/// A type that hold a reference to a triggered Event
/// While an EventGuard exists for a given Event, it will not be triggered by any other thread
/// Once the EventGuard goes out of scope, the underlying Event will be re-enabled
pub struct EventGuard<'a, H> {
kqueue: RawFd,
event: &'a Event<H>,
poll: &'a EventPoll<H>,
}
/// A reference to a single event in an EventPoll
pub struct EventRef {
trigger: RawFd,
}
#[derive(PartialEq)]
enum EventKind {
FD,
Notifier,
Signal,
Timer,
}
// A single event
struct Event<H> {
event: kevent, // The kqueue event description
handler: H, // The associated data
kind: EventKind,
}
impl<H> Drop for EventPoll<H> {
fn drop(&mut self) {
unsafe { close(self.kqueue) };
}
}
unsafe impl<H> Send for EventPoll<H> {}
unsafe impl<H> Sync for EventPoll<H> {}
impl<H: Send + Sync> EventPoll<H> {
/// Create a new event registry
pub fn new() -> Result<EventPoll<H>, Error> {
let kqueue = match unsafe { kqueue() } {
-1 => return Err(Error::EventQueue(io::Error::last_os_error())),
kqueue => kqueue,
};
Ok(EventPoll {
events: Mutex::new(vec![]),
custom: Mutex::new(vec![]),
signals: Mutex::new(vec![]),
kqueue,
})
}
/// Add and enable a new event with the factory.
/// The event is triggered when a Read operation on the provided trigger becomes available
/// If the trigger fd is closed, the event won't be triggered anymore, but it's data won't be
/// automatically released.
/// The safe way to delete an event, is using the cancel method of an EventGuard.
/// If the same trigger is used with multiple events in the same EventPoll, the last added
/// event overrides all previous events. In case the same trigger is used with multiple polls,
/// each event will be triggered independently.
/// The event will keep triggering until a Read operation is no longer possible on the trigger.
/// When triggered, one of the threads waiting on the poll will receive the handler via an
/// appropriate EventGuard. It is guaranteed that only a single thread can have a reference to
/// the handler at any given time.
pub fn new_event(&self, trigger: RawFd, handler: H) -> Result<EventRef, Error> {
// Create an event descriptor
let flags = EV_ENABLE | EV_DISPATCH;
let ev = Event {
event: kevent {
ident: trigger as _,
filter: EVFILT_READ,
flags,
fflags: 0,
data: 0,
udata: null_mut(),
},
handler,
kind: EventKind::FD,
};
self.register_event(ev)
}
pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result<EventRef, Error> {
// The periodic event in BSD uses EVFILT_TIMER
let ev = Event {
event: kevent {
ident: 0,
filter: EVFILT_TIMER,
flags: EV_ENABLE | EV_DISPATCH,
fflags: NOTE_NSECONDS,
data: period
.as_secs()
.checked_mul(1_000_000_000)
.unwrap()
.checked_add(u64::from(period.subsec_nanos()))
.unwrap() as _,
udata: null_mut(),
},
handler,
kind: EventKind::Timer,
};
self.register_event(ev)
}
pub fn new_notifier(&self, handler: H) -> Result<EventRef, Error> {
// The notifier in BSD uses EVFILT_USER for notifications.
let ev = Event {
event: kevent {
ident: 0,
filter: EVFILT_USER,
flags: EV_ENABLE,
fflags: 0,
data: 0,
udata: null_mut(),
},
handler,
kind: EventKind::Notifier,
};
self.register_event(ev)
}
/// Add and enable a new signal handler
pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result<EventRef, Error> {
let ev = Event {
event: kevent {
ident: signal as _,
filter: EVFILT_SIGNAL,
flags: EV_ENABLE | EV_DISPATCH,
fflags: 0,
data: 0,
udata: null_mut(),
},
handler,
kind: EventKind::Signal,
};
self.register_event(ev)
}
/// Wait until one of the registered events becomes triggered. Once an event
/// is triggered, a single caller thread gets the handler for that event.
/// In case a notifier is triggered, all waiting threads will receive the same
/// handler.
pub fn wait(&'_ self) -> WaitResult<'_, H> {
let mut event = kevent {
ident: 0,
filter: 0,
flags: 0,
fflags: 0,
data: 0,
udata: null_mut(),
};
if unsafe { kevent(self.kqueue, null(), 0, &mut event, 1, null()) } == -1 {
return WaitResult::Error(io::Error::last_os_error().to_string());
}
let event_data = unsafe { (event.udata as *mut Event<H>).as_ref().unwrap() };
let guard = EventGuard {
kqueue: self.kqueue,
event: event_data,
poll: self,
};
if event.flags & EV_EOF != 0 {
WaitResult::EoF(guard)
} else {
WaitResult::Ok(guard)
}
}
// Register an event with this poll.
fn register_event(&self, ev: Event<H>) -> Result<EventRef, Error> {
let mut events = match ev.kind {
EventKind::FD => self.events.lock(),
EventKind::Timer | EventKind::Notifier => self.custom.lock(),
EventKind::Signal => self.signals.lock(),
};
let (trigger, index) = match ev.kind {
EventKind::FD | EventKind::Signal => (ev.event.ident as RawFd, ev.event.ident as usize),
EventKind::Timer | EventKind::Notifier => (-(events.len() as RawFd) - 1, events.len()), // Custom events get negative identifiers, hopefully we will never have more than 2^31 events of each type
};
// Expand events vector if needed
while events.len() <= index {
// Resize the vector to be able to fit the new index
// We trust the OS to allocate file descriptors in a sane order
events.push(None); // resize doesn't work because Clone is not satisfied
}
let mut ev = Box::new(ev);
// The inner event points back to the wrapper
ev.event.ident = trigger as _;
ev.event.udata = ev.as_mut() as *mut Event<H> as _;
let mut kev = ev.event;
kev.flags |= EV_ADD;
if unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) } == -1 {
return Err(Error::EventQueue(io::Error::last_os_error()));
}
if let Some(mut event) = events[index].take() {
// Properly remove any previous event first
event.event.flags = EV_DELETE;
unsafe { kevent(self.kqueue, &event.event, 1, null_mut(), 0, null()) };
}
if ev.kind == EventKind::Signal {
// Mask the signal if successfully added to kqueue
unsafe { signal(trigger, SIG_IGN) };
}
events[index] = Some(ev);
Ok(EventRef { trigger })
}
pub fn trigger_notification(&self, notification_event: &EventRef) {
let events = self.custom.lock();
let ev_index = -notification_event.trigger - 1; // Custom events have negative index from -1
let event_ref = &(*events)[ev_index as usize];
let event_data = event_ref.as_ref().expect("Expected an event");
if event_data.kind != EventKind::Notifier {
panic!("Can only trigger a notification event");
}
let mut kev = event_data.event;
kev.fflags = NOTE_TRIGGER;
unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) };
}
pub fn stop_notification(&self, notification_event: &EventRef) {
let events = self.custom.lock();
let ev_index = -notification_event.trigger - 1; // Custom events have negative index from -1
let event_ref = &(*events)[ev_index as usize];
let event_data = event_ref.as_ref().expect("Expected an event");
if event_data.kind != EventKind::Notifier {
panic!("Can only stop a notification event");
}
let mut kev = event_data.event;
kev.flags = EV_DISABLE;
kev.fflags = 0;
unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) };
}
}
impl<H> EventPoll<H> {
// This function is only safe to call when the event loop is not running
pub unsafe fn clear_event_by_fd(&self, index: RawFd) {
let (mut events, index) = if index >= 0 {
(self.events.lock(), index as usize)
} else {
(self.custom.lock(), (-index - 1) as usize)
};
if let Some(mut event) = events[index].take() {
// Properly remove any previous event first
event.event.flags = EV_DELETE;
kevent(self.kqueue, &event.event, 1, null_mut(), 0, null());
}
}
}
impl<'a, H> Deref for EventGuard<'a, H> {
type Target = H;
fn deref(&self) -> &H {
&self.event.handler
}
}
impl<'a, H> Drop for EventGuard<'a, H> {
fn drop(&mut self) {
unsafe {
// Re-enable the event once EventGuard goes out of scope
kevent(self.kqueue, &self.event.event, 1, null_mut(), 0, null());
}
}
}
impl<'a, H> EventGuard<'a, H> {
/// Cancel and remove the event represented by this guard
pub fn cancel(self) {
unsafe { self.poll.clear_event_by_fd(self.event.event.ident as RawFd) };
std::mem::forget(self); // Don't call the regular drop that would enable the event
}
/// Stub: only used for Linux-specific features.
pub fn fd(&self) -> i32 {
-1
}
}

View File

@@ -0,0 +1,884 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
pub mod allowed_ips;
pub mod api;
mod dev_lock;
pub mod drop_privileges;
#[cfg(test)]
mod integration_tests;
pub mod peer;
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
#[path = "kqueue.rs"]
pub mod poll;
#[cfg(target_os = "linux")]
#[path = "epoll.rs"]
pub mod poll;
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
#[path = "tun_darwin.rs"]
pub mod tun;
#[cfg(target_os = "linux")]
#[path = "tun_linux.rs"]
pub mod tun;
use std::collections::HashMap;
use std::io::{self, Write as _};
use std::mem::MaybeUninit;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::unix::io::AsRawFd;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use crate::noise::errors::WireGuardError;
use crate::noise::handshake::parse_handshake_anon;
use crate::noise::rate_limiter::RateLimiter;
use crate::noise::{Packet, Tunn, TunnResult};
use crate::x25519;
use allowed_ips::AllowedIps;
use parking_lot::Mutex;
use peer::{AllowedIP, Peer};
use poll::{EventPoll, EventRef, WaitResult};
use rand_core::{OsRng, RngCore};
use socket2::{Domain, Protocol, Type};
use tun::TunSocket;
use dev_lock::{Lock, LockReadGuard};
const HANDSHAKE_RATE_LIMIT: u64 = 100; // The number of handshakes per second we can tolerate before using cookies
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
const MAX_ITR: usize = 100; // Number of packets to handle per handler call
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("i/o error: {0}")]
IoError(#[from] io::Error),
#[error("{0}")]
Socket(io::Error),
#[error("{0}")]
Bind(String),
#[error("{0}")]
FCntl(io::Error),
#[error("{0}")]
EventQueue(io::Error),
#[error("{0}")]
IOCtl(io::Error),
#[error("{0}")]
Connect(String),
#[error("{0}")]
SetSockOpt(String),
#[error("Invalid tunnel name")]
InvalidTunnelName,
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
#[error("{0}")]
GetSockOpt(io::Error),
#[error("{0}")]
GetSockName(String),
#[cfg(target_os = "linux")]
#[error("{0}")]
Timer(io::Error),
#[error("iface read: {0}")]
IfaceRead(io::Error),
#[error("{0}")]
DropPrivileges(String),
#[error("API socket error: {0}")]
ApiSocket(io::Error),
}
// What the event loop should do after a handler returns
enum Action {
Continue, // Continue the loop
Yield, // Yield the read lock and acquire it again
Exit, // Stop the loop
}
// Event handler function
type Handler = Box<dyn Fn(&mut LockReadGuard<Device>, &mut ThreadData) -> Action + Send + Sync>;
pub struct DeviceHandle {
device: Arc<Lock<Device>>, // The interface this handle owns
threads: Vec<JoinHandle<()>>,
}
#[derive(Debug, Clone, Copy)]
pub struct DeviceConfig {
pub n_threads: usize,
pub use_connected_socket: bool,
#[cfg(target_os = "linux")]
pub use_multi_queue: bool,
#[cfg(target_os = "linux")]
pub uapi_fd: i32,
}
impl Default for DeviceConfig {
fn default() -> Self {
DeviceConfig {
n_threads: 4,
use_connected_socket: true,
#[cfg(target_os = "linux")]
use_multi_queue: true,
#[cfg(target_os = "linux")]
uapi_fd: -1,
}
}
}
pub struct Device {
key_pair: Option<(x25519::StaticSecret, x25519::PublicKey)>,
queue: Arc<EventPoll<Handler>>,
listen_port: u16,
fwmark: Option<u32>,
iface: Arc<TunSocket>,
udp4: Option<socket2::Socket>,
udp6: Option<socket2::Socket>,
yield_notice: Option<EventRef>,
exit_notice: Option<EventRef>,
peers: HashMap<x25519::PublicKey, Arc<Mutex<Peer>>>,
peers_by_ip: AllowedIps<Arc<Mutex<Peer>>>,
peers_by_idx: HashMap<u32, Arc<Mutex<Peer>>>,
next_index: IndexLfsr,
config: DeviceConfig,
cleanup_paths: Vec<String>,
mtu: AtomicUsize,
rate_limiter: Option<Arc<RateLimiter>>,
#[cfg(target_os = "linux")]
uapi_fd: i32,
}
struct ThreadData {
iface: Arc<TunSocket>,
src_buf: [u8; MAX_UDP_SIZE],
dst_buf: [u8; MAX_UDP_SIZE],
}
impl DeviceHandle {
pub fn new(name: &str, config: DeviceConfig) -> Result<DeviceHandle, Error> {
let n_threads = config.n_threads;
let mut wg_interface = Device::new(name, config)?;
wg_interface.open_listen_socket(0)?; // Start listening on a random port
let interface_lock = Arc::new(Lock::new(wg_interface));
let mut threads = vec![];
for i in 0..n_threads {
threads.push({
let dev = Arc::clone(&interface_lock);
thread::spawn(move || DeviceHandle::event_loop(i, &dev))
});
}
Ok(DeviceHandle {
device: interface_lock,
threads,
})
}
pub fn wait(&mut self) {
while let Some(thread) = self.threads.pop() {
thread.join().unwrap();
}
}
pub fn clean(&mut self) {
for path in &self.device.read().cleanup_paths {
// attempt to remove any file we created in the work dir
let _ = std::fs::remove_file(path);
}
}
fn event_loop(_i: usize, device: &Lock<Device>) {
#[cfg(target_os = "linux")]
let mut thread_local = ThreadData {
src_buf: [0u8; MAX_UDP_SIZE],
dst_buf: [0u8; MAX_UDP_SIZE],
iface: if _i == 0 || !device.read().config.use_multi_queue {
// For the first thread use the original iface
Arc::clone(&device.read().iface)
} else {
// For for the rest create a new iface queue
let iface_local = Arc::new(
TunSocket::new(&device.read().iface.name().unwrap())
.unwrap()
.set_non_blocking()
.unwrap(),
);
device
.read()
.register_iface_handler(Arc::clone(&iface_local))
.ok();
iface_local
},
};
#[cfg(not(target_os = "linux"))]
let mut thread_local = ThreadData {
src_buf: [0u8; MAX_UDP_SIZE],
dst_buf: [0u8; MAX_UDP_SIZE],
iface: Arc::clone(&device.read().iface),
};
#[cfg(not(target_os = "linux"))]
let uapi_fd = -1;
#[cfg(target_os = "linux")]
let uapi_fd = device.read().uapi_fd;
loop {
// The event loop keeps a read lock on the device, because we assume write access is rarely needed
let mut device_lock = device.read();
let queue = Arc::clone(&device_lock.queue);
loop {
match queue.wait() {
WaitResult::Ok(handler) => {
let action = (*handler)(&mut device_lock, &mut thread_local);
match action {
Action::Continue => {}
Action::Yield => break,
Action::Exit => {
device_lock.trigger_exit();
return;
}
}
}
WaitResult::EoF(handler) => {
if uapi_fd >= 0 && uapi_fd == handler.fd() {
device_lock.trigger_exit();
return;
}
handler.cancel();
}
WaitResult::Error(e) => tracing::error!(message = "Poll error", error = ?e),
}
}
}
}
}
impl Drop for DeviceHandle {
fn drop(&mut self) {
self.device.read().trigger_exit();
self.clean();
}
}
impl Device {
fn next_index(&mut self) -> u32 {
self.next_index.next()
}
fn remove_peer(&mut self, pub_key: &x25519::PublicKey) {
if let Some(peer) = self.peers.remove(pub_key) {
// Found a peer to remove, now purge all references to it:
{
let p = peer.lock();
p.shutdown_endpoint(); // close open udp socket and free the closure
self.peers_by_idx.remove(&p.index());
}
self.peers_by_ip
.remove(&|p: &Arc<Mutex<Peer>>| Arc::ptr_eq(&peer, p));
tracing::info!("Peer removed");
}
}
#[allow(clippy::too_many_arguments)]
fn update_peer(
&mut self,
pub_key: x25519::PublicKey,
remove: bool,
_replace_ips: bool,
endpoint: Option<SocketAddr>,
allowed_ips: &[AllowedIP],
keepalive: Option<u16>,
preshared_key: Option<[u8; 32]>,
) {
if remove {
// Completely remove a peer
return self.remove_peer(&pub_key);
}
// Update an existing peer
if self.peers.get(&pub_key).is_some() {
// We already have a peer, we need to merge the existing config into the newly created one
panic!("Modifying existing peers is not yet supported. Remove and add again instead.");
}
let next_index = self.next_index();
let device_key_pair = self
.key_pair
.as_ref()
.expect("Private key must be set first");
let tunn = Tunn::new(
device_key_pair.0.clone(),
pub_key,
preshared_key,
keepalive,
next_index,
None,
);
let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, preshared_key);
let peer = Arc::new(Mutex::new(peer));
self.peers.insert(pub_key, Arc::clone(&peer));
self.peers_by_idx.insert(next_index, Arc::clone(&peer));
for AllowedIP { addr, cidr } in allowed_ips {
self.peers_by_ip
.insert(*addr, *cidr as _, Arc::clone(&peer));
}
tracing::info!("Peer added");
}
pub fn new(name: &str, config: DeviceConfig) -> Result<Device, Error> {
let poll = EventPoll::<Handler>::new()?;
// Create a tunnel device
let iface = Arc::new(TunSocket::new(name)?.set_non_blocking()?);
let mtu = iface.mtu()?;
#[cfg(not(target_os = "linux"))]
let uapi_fd = -1;
#[cfg(target_os = "linux")]
let uapi_fd = config.uapi_fd;
let mut device = Device {
queue: Arc::new(poll),
iface,
config,
exit_notice: Default::default(),
yield_notice: Default::default(),
fwmark: Default::default(),
key_pair: Default::default(),
listen_port: Default::default(),
next_index: Default::default(),
peers: Default::default(),
peers_by_idx: Default::default(),
peers_by_ip: AllowedIps::new(),
udp4: Default::default(),
udp6: Default::default(),
cleanup_paths: Default::default(),
mtu: AtomicUsize::new(mtu),
rate_limiter: None,
#[cfg(target_os = "linux")]
uapi_fd,
};
if uapi_fd >= 0 {
device.register_api_fd(uapi_fd)?;
} else {
device.register_api_handler()?;
}
device.register_iface_handler(Arc::clone(&device.iface))?;
device.register_notifiers()?;
device.register_timers()?;
#[cfg(target_os = "macos")]
{
// Only for macOS write the actual socket name into WG_TUN_NAME_FILE
if let Ok(name_file) = std::env::var("WG_TUN_NAME_FILE") {
if name == "utun" {
std::fs::write(&name_file, device.iface.name().unwrap().as_bytes()).unwrap();
device.cleanup_paths.push(name_file);
}
}
}
Ok(device)
}
fn open_listen_socket(&mut self, mut port: u16) -> Result<(), Error> {
// Binds the network facing interfaces
// First close any existing open socket, and remove them from the event loop
if let Some(s) = self.udp4.take() {
unsafe {
// This is safe because the event loop is not running yet
self.queue.clear_event_by_fd(s.as_raw_fd())
}
};
if let Some(s) = self.udp6.take() {
unsafe { self.queue.clear_event_by_fd(s.as_raw_fd()) };
}
for peer in self.peers.values() {
peer.lock().shutdown_endpoint();
}
// Then open new sockets and bind to the port
let udp_sock4 = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
udp_sock4.set_reuse_address(true)?;
udp_sock4.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into())?;
udp_sock4.set_nonblocking(true)?;
if port == 0 {
// Random port was assigned
port = udp_sock4.local_addr()?.as_socket().unwrap().port();
}
let udp_sock6 = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?;
udp_sock6.set_reuse_address(true)?;
udp_sock6.bind(&SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into())?;
udp_sock6.set_nonblocking(true)?;
self.register_udp_handler(udp_sock4.try_clone().unwrap())?;
self.register_udp_handler(udp_sock6.try_clone().unwrap())?;
self.udp4 = Some(udp_sock4);
self.udp6 = Some(udp_sock6);
self.listen_port = port;
Ok(())
}
fn set_key(&mut self, private_key: x25519::StaticSecret) {
let public_key = x25519::PublicKey::from(&private_key);
let key_pair = Some((private_key.clone(), public_key));
// x25519 (rightly) doesn't let us expose secret keys for comparison.
// If the public keys are the same, then the private keys are the same.
if Some(&public_key) == self.key_pair.as_ref().map(|p| &p.1) {
return;
}
let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT));
for peer in self.peers.values_mut() {
peer.lock().tunnel.set_static_private(
private_key.clone(),
public_key,
Some(Arc::clone(&rate_limiter)),
)
}
self.key_pair = key_pair;
self.rate_limiter = Some(rate_limiter);
}
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
fn set_fwmark(&mut self, mark: u32) -> Result<(), Error> {
self.fwmark = Some(mark);
// First set fwmark on listeners
if let Some(ref sock) = self.udp4 {
sock.set_mark(mark)?;
}
if let Some(ref sock) = self.udp6 {
sock.set_mark(mark)?;
}
// Then on all currently connected sockets
for peer in self.peers.values() {
if let Some(ref sock) = peer.lock().endpoint().conn {
sock.set_mark(mark)?
}
}
Ok(())
}
fn clear_peers(&mut self) {
self.peers.clear();
self.peers_by_idx.clear();
self.peers_by_ip.clear();
}
fn register_notifiers(&mut self) -> Result<(), Error> {
let yield_ev = self
.queue
// The notification event handler simply returns Action::Yield
.new_notifier(Box::new(|_, _| Action::Yield))?;
self.yield_notice = Some(yield_ev);
let exit_ev = self
.queue
// The exit event handler simply returns Action::Exit
.new_notifier(Box::new(|_, _| Action::Exit))?;
self.exit_notice = Some(exit_ev);
Ok(())
}
fn register_timers(&self) -> Result<(), Error> {
self.queue.new_periodic_event(
// Reset the rate limiter every second give or take
Box::new(|d, _| {
if let Some(r) = d.rate_limiter.as_ref() {
r.reset_count()
}
Action::Continue
}),
std::time::Duration::from_secs(1),
)?;
self.queue.new_periodic_event(
// Execute the timed function of every peer in the list
Box::new(|d, t| {
let peer_map = &d.peers;
let (udp4, udp6) = match (d.udp4.as_ref(), d.udp6.as_ref()) {
(Some(udp4), Some(udp6)) => (udp4, udp6),
_ => return Action::Continue,
};
// Go over each peer and invoke the timer function
for peer in peer_map.values() {
let mut p = peer.lock();
let endpoint_addr = match p.endpoint().addr {
Some(addr) => addr,
None => continue,
};
match p.update_timers(&mut t.dst_buf[..]) {
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
p.shutdown_endpoint(); // close open udp socket
}
TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e),
TunnResult::WriteToNetwork(packet) => {
match endpoint_addr {
SocketAddr::V4(_) => {
udp4.send_to(packet, &endpoint_addr.into()).ok()
}
SocketAddr::V6(_) => {
udp6.send_to(packet, &endpoint_addr.into()).ok()
}
};
}
_ => panic!("Unexpected result from update_timers"),
};
}
Action::Continue
}),
std::time::Duration::from_millis(250),
)?;
Ok(())
}
pub(crate) fn trigger_yield(&self) {
self.queue
.trigger_notification(self.yield_notice.as_ref().unwrap())
}
pub(crate) fn trigger_exit(&self) {
self.queue
.trigger_notification(self.exit_notice.as_ref().unwrap())
}
pub(crate) fn cancel_yield(&self) {
self.queue
.stop_notification(self.yield_notice.as_ref().unwrap())
}
fn register_udp_handler(&self, udp: socket2::Socket) -> Result<(), Error> {
self.queue.new_event(
udp.as_raw_fd(),
Box::new(move |d, t| {
// Handler that handles anonymous packets over UDP
let mut iter = MAX_ITR;
let (private_key, public_key) = d.key_pair.as_ref().expect("Key not set");
let rate_limiter = d.rate_limiter.as_ref().unwrap();
// Loop while we have packets on the anonymous connection
// Safety: the `recv_from` implementation promises not to write uninitialised
// bytes to the buffer, so this casting is safe.
let src_buf =
unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit<u8>]) };
while let Ok((packet_len, addr)) = udp.recv_from(src_buf) {
let packet = &t.src_buf[..packet_len];
// The rate limiter initially checks mac1 and mac2, and optionally asks to send a cookie
let parsed_packet = match rate_limiter.verify_packet(
Some(addr.as_socket().unwrap().ip()),
packet,
&mut t.dst_buf,
) {
Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => {
let _: Result<_, _> = udp.send_to(cookie, &addr);
continue;
}
Err(_) => continue,
};
let peer = match &parsed_packet {
Packet::HandshakeInit(p) => {
parse_handshake_anon(private_key, public_key, p)
.ok()
.and_then(|hh| {
d.peers.get(&x25519::PublicKey::from(hh.peer_static_public))
})
}
Packet::HandshakeResponse(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
Packet::PacketCookieReply(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
Packet::PacketData(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)),
};
let peer = match peer {
None => continue,
Some(peer) => peer,
};
let mut p = peer.lock();
// We found a peer, use it to decapsulate the message+
let mut flush = false; // Are there packets to send from the queue?
match p
.tunnel
.handle_verified_packet(parsed_packet, &mut t.dst_buf[..])
{
TunnResult::Done => {}
TunnResult::Err(_) => continue,
TunnResult::WriteToNetwork(packet) => {
flush = true;
let _: Result<_, _> = udp.send_to(packet, &addr);
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if p.is_allowed_ip(addr) {
t.iface.write4(packet);
}
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if p.is_allowed_ip(addr) {
t.iface.write6(packet);
}
}
};
if flush {
// Flush pending queue
while let TunnResult::WriteToNetwork(packet) =
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
{
let _: Result<_, _> = udp.send_to(packet, &addr);
}
}
// This packet was OK, that means we want to create a connected socket for this peer
let addr = addr.as_socket().unwrap();
let ip_addr = addr.ip();
p.set_endpoint(addr);
if d.config.use_connected_socket {
if let Ok(sock) = p.connect_endpoint(d.listen_port, d.fwmark) {
d.register_conn_handler(Arc::clone(peer), sock, ip_addr)
.unwrap();
}
}
iter -= 1;
if iter == 0 {
break;
}
}
Action::Continue
}),
)?;
Ok(())
}
fn register_conn_handler(
&self,
peer: Arc<Mutex<Peer>>,
udp: socket2::Socket,
peer_addr: IpAddr,
) -> Result<(), Error> {
self.queue.new_event(
udp.as_raw_fd(),
Box::new(move |_, t| {
// The conn_handler handles packet received from a connected UDP socket, associated
// with a known peer, this saves us the hustle of finding the right peer. If another
// peer gets the same ip, it will be ignored until the socket does not expire.
let iface = &t.iface;
let mut iter = MAX_ITR;
// Safety: the `recv_from` implementation promises not to write uninitialised
// bytes to the buffer, so this casting is safe.
let src_buf =
unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit<u8>]) };
while let Ok(read_bytes) = udp.recv(src_buf) {
let mut flush = false;
let mut p = peer.lock();
match p.tunnel.decapsulate(
Some(peer_addr),
&t.src_buf[..read_bytes],
&mut t.dst_buf[..],
) {
TunnResult::Done => {}
TunnResult::Err(e) => eprintln!("Decapsulate error {:?}", e),
TunnResult::WriteToNetwork(packet) => {
flush = true;
let _: Result<_, _> = udp.send(packet);
}
TunnResult::WriteToTunnelV4(packet, addr) => {
if p.is_allowed_ip(addr) {
iface.write4(packet);
}
}
TunnResult::WriteToTunnelV6(packet, addr) => {
if p.is_allowed_ip(addr) {
iface.write6(packet);
}
}
};
if flush {
// Flush pending queue
while let TunnResult::WriteToNetwork(packet) =
p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..])
{
let _: Result<_, _> = udp.send(packet);
}
}
iter -= 1;
if iter == 0 {
break;
}
}
Action::Continue
}),
)?;
Ok(())
}
fn register_iface_handler(&self, iface: Arc<TunSocket>) -> Result<(), Error> {
self.queue.new_event(
iface.as_raw_fd(),
Box::new(move |d, t| {
// The iface_handler handles packets received from the WireGuard virtual network
// interface. The flow is as follows:
// * Read a packet
// * Determine peer based on packet destination ip
// * Encapsulate the packet for the given peer
// * Send encapsulated packet to the peer's endpoint
let mtu = d.mtu.load(Ordering::Relaxed);
let udp4 = d.udp4.as_ref().expect("Not connected");
let udp6 = d.udp6.as_ref().expect("Not connected");
let peers = &d.peers_by_ip;
for _ in 0..MAX_ITR {
let src = match iface.read(&mut t.src_buf[..mtu]) {
Ok(src) => src,
Err(Error::IfaceRead(e)) => {
let ek = e.kind();
if ek == io::ErrorKind::Interrupted || ek == io::ErrorKind::WouldBlock {
break;
}
eprintln!("Fatal read error on tun interface: {:?}", e);
return Action::Exit;
}
Err(e) => {
eprintln!("Unexpected error on tun interface: {:?}", e);
return Action::Exit;
}
};
let dst_addr = match Tunn::dst_address(src) {
Some(addr) => addr,
None => continue,
};
let mut peer = match peers.find(dst_addr) {
Some(peer) => peer.lock(),
None => continue,
};
match peer.tunnel.encapsulate(src, &mut t.dst_buf[..]) {
TunnResult::Done => {}
TunnResult::Err(e) => {
tracing::error!(message = "Encapsulate error", error = ?e)
}
TunnResult::WriteToNetwork(packet) => {
let mut endpoint = peer.endpoint_mut();
if let Some(conn) = endpoint.conn.as_mut() {
// Prefer to send using the connected socket
let _: Result<_, _> = conn.write(packet);
} else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr {
let _: Result<_, _> = udp4.send_to(packet, &addr.into());
} else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr {
let _: Result<_, _> = udp6.send_to(packet, &addr.into());
} else {
tracing::error!("No endpoint");
}
}
_ => panic!("Unexpected result from encapsulate"),
};
}
Action::Continue
}),
)?;
Ok(())
}
}
/// A basic linear-feedback shift register implemented as xorshift, used to
/// distribute peer indexes across the 24-bit address space reserved for peer
/// identification.
/// The purpose is to obscure the total number of peers using the system and to
/// ensure it requires a non-trivial amount of processing power and/or samples
/// to guess other peers' indices. Anything more ambitious than this is wasted
/// with only 24 bits of space.
struct IndexLfsr {
initial: u32,
lfsr: u32,
mask: u32,
}
impl IndexLfsr {
/// Generate a random 24-bit nonzero integer
fn random_index() -> u32 {
const LFSR_MAX: u32 = 0xffffff; // 24-bit seed
loop {
let i = OsRng.next_u32() & LFSR_MAX;
if i > 0 {
// LFSR seed must be non-zero
return i;
}
}
}
/// Generate the next value in the pseudorandom sequence
fn next(&mut self) -> u32 {
// 24-bit polynomial for randomness. This is arbitrarily chosen to
// inject bitflips into the value.
const LFSR_POLY: u32 = 0xd80000; // 24-bit polynomial
let value = self.lfsr - 1; // lfsr will never have value of 0
self.lfsr = (self.lfsr >> 1) ^ ((0u32.wrapping_sub(self.lfsr & 1u32)) & LFSR_POLY);
assert!(self.lfsr != self.initial, "Too many peers created");
value ^ self.mask
}
}
impl Default for IndexLfsr {
fn default() -> Self {
let seed = Self::random_index();
IndexLfsr {
initial: seed,
lfsr: seed,
mask: Self::random_index(),
}
}
}

View File

@@ -0,0 +1,170 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use parking_lot::RwLock;
use socket2::{Domain, Protocol, Type};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::str::FromStr;
use crate::device::{AllowedIps, Error};
use crate::noise::{Tunn, TunnResult};
#[derive(Default, Debug)]
pub struct Endpoint {
pub addr: Option<SocketAddr>,
pub conn: Option<socket2::Socket>,
}
pub struct Peer {
/// The associated tunnel struct
pub(crate) tunnel: Tunn,
/// The index the tunnel uses
index: u32,
endpoint: RwLock<Endpoint>,
allowed_ips: AllowedIps<()>,
preshared_key: Option<[u8; 32]>,
}
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub struct AllowedIP {
pub addr: IpAddr,
pub cidr: u8,
}
impl FromStr for AllowedIP {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let ip: Vec<&str> = s.split('/').collect();
if ip.len() != 2 {
return Err("Invalid IP format".to_owned());
}
let (addr, cidr) = (ip[0].parse::<IpAddr>(), ip[1].parse::<u8>());
match (addr, cidr) {
(Ok(addr @ IpAddr::V4(_)), Ok(cidr)) if cidr <= 32 => Ok(AllowedIP { addr, cidr }),
(Ok(addr @ IpAddr::V6(_)), Ok(cidr)) if cidr <= 128 => Ok(AllowedIP { addr, cidr }),
_ => Err("Invalid IP format".to_owned()),
}
}
}
impl Peer {
pub fn new(
tunnel: Tunn,
index: u32,
endpoint: Option<SocketAddr>,
allowed_ips: &[AllowedIP],
preshared_key: Option<[u8; 32]>,
) -> Peer {
Peer {
tunnel,
index,
endpoint: RwLock::new(Endpoint {
addr: endpoint,
conn: None,
}),
allowed_ips: allowed_ips.iter().map(|ip| (ip, ())).collect(),
preshared_key,
}
}
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
self.tunnel.update_timers(dst)
}
pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> {
self.endpoint.read()
}
pub(crate) fn endpoint_mut(&self) -> parking_lot::RwLockWriteGuard<'_, Endpoint> {
self.endpoint.write()
}
pub fn shutdown_endpoint(&self) {
if let Some(conn) = self.endpoint.write().conn.take() {
tracing::info!("Disconnecting from endpoint");
conn.shutdown(Shutdown::Both).unwrap();
}
}
pub fn set_endpoint(&self, addr: SocketAddr) {
let mut endpoint = self.endpoint.write();
if endpoint.addr != Some(addr) {
// We only need to update the endpoint if it differs from the current one
if let Some(conn) = endpoint.conn.take() {
conn.shutdown(Shutdown::Both).unwrap();
}
endpoint.addr = Some(addr);
}
}
pub fn connect_endpoint(
&self,
port: u16,
fwmark: Option<u32>,
) -> Result<socket2::Socket, Error> {
let mut endpoint = self.endpoint.write();
if endpoint.conn.is_some() {
return Err(Error::Connect("Connected".to_owned()));
}
let addr = endpoint
.addr
.expect("Attempt to connect to undefined endpoint");
let udp_conn =
socket2::Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::UDP))?;
udp_conn.set_reuse_address(true)?;
let bind_addr = if addr.is_ipv4() {
SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into()
} else {
SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into()
};
udp_conn.bind(&bind_addr)?;
udp_conn.connect(&addr.into())?;
udp_conn.set_nonblocking(true)?;
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(fwmark) = fwmark {
udp_conn.set_mark(fwmark)?;
}
tracing::info!(
message="Connected endpoint",
port=port,
endpoint=?endpoint.addr.unwrap()
);
endpoint.conn = Some(udp_conn.try_clone().unwrap());
Ok(udp_conn)
}
pub fn is_allowed_ip<I: Into<IpAddr>>(&self, addr: I) -> bool {
self.allowed_ips.find(addr.into()).is_some()
}
pub fn allowed_ips(&self) -> impl Iterator<Item = (IpAddr, u8)> + '_ {
self.allowed_ips.iter().map(|(_, ip, cidr)| (ip, cidr))
}
pub fn time_since_last_handshake(&self) -> Option<std::time::Duration> {
self.tunnel.time_since_last_handshake()
}
pub fn persistent_keepalive(&self) -> Option<u16> {
self.tunnel.persistent_keepalive()
}
pub fn preshared_key(&self) -> Option<&[u8; 32]> {
self.preshared_key.as_ref()
}
pub fn index(&self) -> u32 {
self.index
}
}

View File

@@ -0,0 +1,256 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::Error;
use libc::*;
use std::io;
use std::mem::size_of;
use std::mem::size_of_val;
use std::os::unix::io::{AsRawFd, RawFd};
use std::ptr::null_mut;
const CTRL_NAME: &[u8] = b"com.apple.net.utun_control";
#[repr(C)]
pub struct ctl_info {
pub ctl_id: u32,
pub ctl_name: [c_uchar; 96],
}
#[repr(C)]
union IfrIfru {
ifru_addr: sockaddr,
ifru_addr_v4: sockaddr_in,
ifru_addr_v6: sockaddr_in,
ifru_dstaddr: sockaddr,
ifru_broadaddr: sockaddr,
ifru_flags: c_short,
ifru_metric: c_int,
ifru_mtu: c_int,
ifru_phys: c_int,
ifru_media: c_int,
ifru_intval: c_int,
//ifru_data: caddr_t,
//ifru_devmtu: ifdevmtu,
//ifru_kpi: ifkpi,
ifru_wake_flags: u32,
ifru_route_refcnt: u32,
ifru_cap: [c_int; 2],
ifru_functional_type: u32,
}
#[repr(C)]
pub struct ifreq {
ifr_name: [c_uchar; IF_NAMESIZE],
ifr_ifru: IfrIfru,
}
const CTLIOCGINFO: u64 = 0x0000_0000_c064_4e03;
const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933;
#[derive(Default, Debug)]
pub struct TunSocket {
fd: RawFd,
}
impl Drop for TunSocket {
fn drop(&mut self) {
unsafe { close(self.fd) };
}
}
impl AsRawFd for TunSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
// On Darwin tunnel can only be named utunXXX
pub fn parse_utun_name(name: &str) -> Result<u32, Error> {
if !name.starts_with("utun") {
return Err(Error::InvalidTunnelName);
}
match name.get(4..) {
None | Some("") => {
// The name is simply "utun"
Ok(0)
}
Some(idx) => {
// Everything past utun should represent an integer index
idx.parse::<u32>()
.map_err(|_| Error::InvalidTunnelName)
.map(|x| x + 1)
}
}
}
impl TunSocket {
fn write(&self, src: &[u8], af: u8) -> usize {
let mut hdr = [0u8, 0u8, 0u8, af as u8];
let mut iov = [
iovec {
iov_base: hdr.as_mut_ptr() as _,
iov_len: hdr.len(),
},
iovec {
iov_base: src.as_ptr() as _,
iov_len: src.len(),
},
];
let msg_hdr = msghdr {
msg_name: null_mut(),
msg_namelen: 0,
msg_iov: &mut iov[0],
msg_iovlen: iov.len() as _,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
match unsafe { sendmsg(self.fd, &msg_hdr, 0) } {
-1 => 0,
n => n as usize,
}
}
pub fn new(name: &str) -> Result<TunSocket, Error> {
let idx = parse_utun_name(name)?;
let fd = match unsafe { socket(PF_SYSTEM, SOCK_DGRAM, SYSPROTO_CONTROL) } {
-1 => return Err(Error::Socket(io::Error::last_os_error())),
fd => fd,
};
let mut info = ctl_info {
ctl_id: 0,
ctl_name: [0u8; 96],
};
info.ctl_name[..CTRL_NAME.len()].copy_from_slice(CTRL_NAME);
if unsafe { ioctl(fd, CTLIOCGINFO, &mut info as *mut ctl_info) } < 0 {
unsafe { close(fd) };
return Err(Error::IOCtl(io::Error::last_os_error()));
}
let addr = sockaddr_ctl {
sc_len: size_of::<sockaddr_ctl>() as u8,
sc_family: AF_SYSTEM as u8,
ss_sysaddr: AF_SYS_CONTROL as u16,
sc_id: info.ctl_id,
sc_unit: idx,
sc_reserved: Default::default(),
};
if unsafe {
connect(
fd,
&addr as *const sockaddr_ctl as _,
size_of_val(&addr) as _,
)
} < 0
{
unsafe { close(fd) };
let mut err_string = io::Error::last_os_error().to_string();
err_string.push_str("(did you run with sudo?)");
return Err(Error::Connect(err_string));
}
Ok(TunSocket { fd })
}
pub fn set_non_blocking(self) -> Result<TunSocket, Error> {
match unsafe { fcntl(self.fd, F_GETFL) } {
-1 => Err(Error::FCntl(io::Error::last_os_error())),
flags => match unsafe { fcntl(self.fd, F_SETFL, flags | O_NONBLOCK) } {
-1 => Err(Error::FCntl(io::Error::last_os_error())),
_ => Ok(self),
},
}
}
pub fn name(&self) -> Result<String, Error> {
let mut tunnel_name = [0u8; 256];
let mut tunnel_name_len: socklen_t = tunnel_name.len() as u32;
if unsafe {
getsockopt(
self.fd,
SYSPROTO_CONTROL,
UTUN_OPT_IFNAME,
tunnel_name.as_mut_ptr() as _,
&mut tunnel_name_len,
)
} < 0
|| tunnel_name_len == 0
{
return Err(Error::GetSockOpt(io::Error::last_os_error()));
}
Ok(String::from_utf8_lossy(&tunnel_name[..(tunnel_name_len - 1) as usize]).to_string())
}
/// Get the current MTU value
pub fn mtu(&self) -> Result<usize, Error> {
let fd = match unsafe { socket(AF_INET, SOCK_STREAM, IPPROTO_IP) } {
-1 => return Err(Error::Socket(io::Error::last_os_error())),
fd => fd,
};
let name = self.name()?;
let iface_name: &[u8] = name.as_ref();
let mut ifr = ifreq {
ifr_name: [0; IF_NAMESIZE],
ifr_ifru: IfrIfru { ifru_mtu: 0 },
};
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
if unsafe { ioctl(fd, SIOCGIFMTU, &ifr) } < 0 {
return Err(Error::IOCtl(io::Error::last_os_error()));
}
unsafe { close(fd) };
Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _)
}
pub fn write4(&self, src: &[u8]) -> usize {
self.write(src, AF_INET as u8)
}
pub fn write6(&self, src: &[u8]) -> usize {
self.write(src, AF_INET6 as u8)
}
pub fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error> {
let mut hdr = [0u8; 4];
let mut iov = [
iovec {
iov_base: hdr.as_mut_ptr() as _,
iov_len: hdr.len(),
},
iovec {
iov_base: dst.as_mut_ptr() as _,
iov_len: dst.len(),
},
];
let mut msg_hdr = msghdr {
msg_name: null_mut(),
msg_namelen: 0,
msg_iov: &mut iov[0],
msg_iovlen: iov.len() as _,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
match unsafe { recvmsg(self.fd, &mut msg_hdr, 0) } {
-1 => Err(Error::IfaceRead(io::Error::last_os_error())),
0..=4 => Ok(&mut dst[..0]),
n => Ok(&mut dst[..(n - 4) as usize]),
}
}
}

View File

@@ -0,0 +1,159 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::Error;
use libc::*;
use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
const TUNSETIFF: u64 = 0x4004_54ca;
#[repr(C)]
union IfrIfru {
ifru_addr: sockaddr,
ifru_addr_v4: sockaddr_in,
ifru_addr_v6: sockaddr_in,
ifru_dstaddr: sockaddr,
ifru_broadaddr: sockaddr,
ifru_flags: c_short,
ifru_metric: c_int,
ifru_mtu: c_int,
ifru_phys: c_int,
ifru_media: c_int,
ifru_intval: c_int,
//ifru_data: caddr_t,
//ifru_devmtu: ifdevmtu,
//ifru_kpi: ifkpi,
ifru_wake_flags: u32,
ifru_route_refcnt: u32,
ifru_cap: [c_int; 2],
ifru_functional_type: u32,
}
#[repr(C)]
pub struct ifreq {
ifr_name: [c_uchar; IFNAMSIZ],
ifr_ifru: IfrIfru,
}
#[derive(Default, Debug)]
pub struct TunSocket {
fd: RawFd,
name: String,
}
impl Drop for TunSocket {
fn drop(&mut self) {
unsafe { close(self.fd) };
}
}
impl AsRawFd for TunSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl TunSocket {
fn write(&self, buf: &[u8]) -> usize {
match unsafe { write(self.fd, buf.as_ptr() as _, buf.len() as _) } {
-1 => 0,
n => n as usize,
}
}
pub fn new(name: &str) -> Result<TunSocket, Error> {
// If the provided name appears to be a FD, use that.
let provided_fd = name.parse::<i32>();
if let Ok(fd) = provided_fd {
return Ok(TunSocket {
fd,
name: name.to_string(),
});
}
let fd = match unsafe { open(b"/dev/net/tun\0".as_ptr() as _, O_RDWR) } {
-1 => return Err(Error::Socket(io::Error::last_os_error())),
fd => fd,
};
let iface_name = name.as_bytes();
let mut ifr = ifreq {
ifr_name: [0; IFNAMSIZ],
ifr_ifru: IfrIfru {
ifru_flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as _,
},
};
if iface_name.len() >= ifr.ifr_name.len() {
return Err(Error::InvalidTunnelName);
}
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
if unsafe { ioctl(fd, TUNSETIFF as _, &ifr) } < 0 {
return Err(Error::IOCtl(io::Error::last_os_error()));
}
let name = name.to_string();
Ok(TunSocket { fd, name })
}
pub fn set_non_blocking(self) -> Result<TunSocket, Error> {
match unsafe { fcntl(self.fd, F_GETFL) } {
-1 => Err(Error::FCntl(io::Error::last_os_error())),
flags => match unsafe { fcntl(self.fd, F_SETFL, flags | O_NONBLOCK) } {
-1 => Err(Error::FCntl(io::Error::last_os_error())),
_ => Ok(self),
},
}
}
pub fn name(&self) -> Result<String, Error> {
Ok(self.name.clone())
}
/// Get the current MTU value
pub fn mtu(&self) -> Result<usize, Error> {
let provided_fd = self.name.parse::<i32>();
if provided_fd.is_ok() {
return Ok(1500);
}
let fd = match unsafe { socket(AF_INET, SOCK_STREAM, IPPROTO_IP) } {
-1 => return Err(Error::Socket(io::Error::last_os_error())),
fd => fd,
};
let name = self.name()?;
let iface_name: &[u8] = name.as_ref();
let mut ifr = ifreq {
ifr_name: [0; IF_NAMESIZE],
ifr_ifru: IfrIfru { ifru_mtu: 0 },
};
ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name);
if unsafe { ioctl(fd, SIOCGIFMTU as _, &ifr) } < 0 {
return Err(Error::IOCtl(io::Error::last_os_error()));
}
unsafe { close(fd) };
Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _)
}
pub fn write4(&self, src: &[u8]) -> usize {
self.write(src)
}
pub fn write6(&self, src: &[u8]) -> usize {
self.write(src)
}
pub fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error> {
match unsafe { read(self.fd, dst.as_mut_ptr() as _, dst.len()) } {
-1 => Err(Error::IfaceRead(io::Error::last_os_error())),
n => Ok(&mut dst[..n as usize]),
}
}
}

View File

@@ -0,0 +1,397 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// Requiring explicit per-fn "Safety" docs not worth it. Just pass in valid
// pointers and buffers/lengths to these, ok?
#![allow(clippy::missing_safety_doc)]
//! C bindings for the BoringTun library
use super::noise::{Tunn, TunnResult};
use crate::x25519::{PublicKey, StaticSecret};
use base64::{decode, encode};
use hex::encode as encode_hex;
use libc::{raise, SIGSEGV};
use parking_lot::Mutex;
use rand_core::OsRng;
use tracing;
use tracing_subscriber::fmt;
use crate::serialization::KeyBytes;
use std::ffi::{CStr, CString};
use std::io::{Error, ErrorKind, Write};
use std::os::raw::c_char;
use std::panic;
use std::ptr;
use std::ptr::null_mut;
use std::slice;
use std::sync::Once;
static PANIC_HOOK: Once = Once::new();
#[allow(non_camel_case_types)]
#[repr(C)]
/// Indicates the operation required from the caller
pub enum result_type {
/// No operation is required.
WIREGUARD_DONE = 0,
/// Write dst buffer to network. Size indicates the number of bytes to write.
WRITE_TO_NETWORK = 1,
/// Some error occurred, no operation is required. Size indicates error code.
WIREGUARD_ERROR = 2,
/// Write dst buffer to the interface as an ipv4 packet. Size indicates the number of bytes to write.
WRITE_TO_TUNNEL_IPV4 = 4,
/// Write dst buffer to the interface as an ipv6 packet. Size indicates the number of bytes to write.
WRITE_TO_TUNNEL_IPV6 = 6,
}
/// The return type of WireGuard functions
#[repr(C)]
pub struct wireguard_result {
/// The operation to be performed by the caller
pub op: result_type,
/// Additional information, required to perform the operation
pub size: usize,
}
#[repr(C)]
pub struct stats {
pub time_since_last_handshake: i64,
pub tx_bytes: usize,
pub rx_bytes: usize,
pub estimated_loss: f32,
pub estimated_rtt: i32,
reserved: [u8; 56], // Make sure to add new fields in this space, keeping total size constant
}
impl<'a> From<TunnResult<'a>> for wireguard_result {
fn from(res: TunnResult<'a>) -> wireguard_result {
match res {
TunnResult::Done => wireguard_result {
op: result_type::WIREGUARD_DONE,
size: 0,
},
TunnResult::Err(e) => wireguard_result {
op: result_type::WIREGUARD_ERROR,
size: e as _,
},
TunnResult::WriteToNetwork(b) => wireguard_result {
op: result_type::WRITE_TO_NETWORK,
size: b.len(),
},
TunnResult::WriteToTunnelV4(b, _) => wireguard_result {
op: result_type::WRITE_TO_TUNNEL_IPV4,
size: b.len(),
},
TunnResult::WriteToTunnelV6(b, _) => wireguard_result {
op: result_type::WRITE_TO_TUNNEL_IPV6,
size: b.len(),
},
}
}
}
#[repr(C)]
pub struct x25519_key {
pub key: [u8; 32],
}
/// Generates a new x25519 secret key.
#[no_mangle]
pub extern "C" fn x25519_secret_key() -> x25519_key {
x25519_key {
key: StaticSecret::random_from_rng(OsRng).to_bytes(),
}
}
/// Computes a public x25519 key from a secret key.
#[no_mangle]
pub extern "C" fn x25519_public_key(private_key: x25519_key) -> x25519_key {
let private = StaticSecret::from(private_key.key);
let public = PublicKey::from(&private);
x25519_key {
key: public.to_bytes(),
}
}
/// Returns the base64 encoding of a key as a UTF8 C-string.
///
/// The memory has to be freed by calling `x25519_key_to_str_free`
#[no_mangle]
pub extern "C" fn x25519_key_to_base64(key: x25519_key) -> *const c_char {
let encoded_key = encode(key.key);
CString::into_raw(CString::new(encoded_key).unwrap())
}
/// Returns the hex encoding of a key as a UTF8 C-string.
///
/// The memory has to be freed by calling `x25519_key_to_str_free`
#[no_mangle]
pub extern "C" fn x25519_key_to_hex(key: x25519_key) -> *const c_char {
let encoded_key = encode_hex(key.key);
CString::into_raw(CString::new(encoded_key).unwrap())
}
/// Frees memory of the string given by `x25519_key_to_hex` or `x25519_key_to_base64`
#[no_mangle]
pub unsafe extern "C" fn x25519_key_to_str_free(stringified_key: *mut c_char) {
drop(CString::from_raw(stringified_key));
}
/// Check if the input C-string represents a valid base64 encoded x25519 key.
/// Return 1 if valid 0 otherwise.
#[no_mangle]
pub unsafe extern "C" fn check_base64_encoded_x25519_key(key: *const c_char) -> i32 {
let c_str = CStr::from_ptr(key);
let utf8_key = match c_str.to_str() {
Err(_) => return 0,
Ok(string) => string,
};
if let Ok(key) = decode(utf8_key) {
let len = key.len();
let mut zero = 0u8;
for b in key {
zero |= b
}
if len == 32 && zero != 0 {
1
} else {
0
}
} else {
0
}
}
/// Custom tracing_subscriber writer to an external function pointer
struct FFIFunctionPointerWriter {
log_func: unsafe extern "C" fn(*const c_char),
}
/// Implements Write trait for use with tracing_subscriber
impl Write for FFIFunctionPointerWriter {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
let out_str = String::from_utf8_lossy(buf).to_string();
if let Ok(c_string) = CString::new(out_str) {
unsafe { (self.log_func)(c_string.as_ptr()) }
Ok(buf.len())
} else {
Err(Error::new(
ErrorKind::Other,
"Failed to create CString from buffer.",
))
}
}
fn flush(&mut self) -> Result<(), std::io::Error> {
// no-op
Ok(())
}
}
/// Sets the default tracing_subscriber to write to `log_func`.
///
/// Uses Compact format without level, target, thread ids, thread names, or ansi control characters.
/// Subscribes to TRACE level events.
///
/// This function should only be called once as setting the default tracing_subscriber
/// more than once will result in an error.
///
/// Returns false on failure.
///
/// # Safety
///
/// `c_char` will be freed by the library after calling `log_func`. If the value needs
/// to be stored then `log_func` needs to create a copy, e.g. `strcpy`.
#[no_mangle]
pub unsafe extern "C" fn set_logging_function(
log_func: unsafe extern "C" fn(*const c_char),
) -> bool {
let result = std::panic::catch_unwind(|| -> bool {
let writer = FFIFunctionPointerWriter { log_func };
let format = fmt::format()
// don't include levels in formatted output
.with_level(false)
// don't include targets
.with_target(false)
// don't 'include the thread ID of the current thread
.with_thread_ids(false)
// don't 'include the name of the current thread
.with_thread_names(false)
// use the `Compact` formatting style.
.compact()
// disable terminal escape codes
.with_ansi(false);
fmt()
.event_format(format)
.with_writer(std::sync::Mutex::new(writer))
.with_max_level(tracing::Level::TRACE)
.with_ansi(false)
.try_init()
.is_ok()
});
if let Ok(value) = result {
value
} else {
false
}
}
/// Allocate a new tunnel, return NULL on failure.
/// Keys must be valid base64 encoded 32-byte keys.
#[no_mangle]
pub unsafe extern "C" fn new_tunnel(
static_private: *const c_char,
server_static_public: *const c_char,
preshared_key: *const c_char,
keep_alive: u16,
index: u32,
) -> *mut Mutex<Tunn> {
let c_str = CStr::from_ptr(static_private);
let static_private = match c_str.to_str() {
Err(_) => return ptr::null_mut(),
Ok(string) => string,
};
let c_str = CStr::from_ptr(server_static_public);
let server_static_public = match c_str.to_str() {
Err(_) => return ptr::null_mut(),
Ok(string) => string,
};
let preshared_key = if preshared_key.is_null() {
None
} else {
let c_str = CStr::from_ptr(preshared_key);
if let Ok(string) = c_str.to_str() {
if let Ok(key) = string.parse::<KeyBytes>() {
Some(key.0)
} else {
return null_mut();
}
} else {
return null_mut();
}
};
let private_key = match static_private.parse::<KeyBytes>() {
Err(_) => return ptr::null_mut(),
Ok(key) => StaticSecret::from(key.0),
};
let public_key = match server_static_public.parse::<KeyBytes>() {
Err(_) => return ptr::null_mut(),
Ok(key) => PublicKey::from(key.0),
};
let keep_alive = if keep_alive == 0 {
None
} else {
Some(keep_alive)
};
let tunnel = Box::new(Mutex::new(Tunn::new(
private_key,
public_key,
preshared_key,
keep_alive,
index,
None,
)));
PANIC_HOOK.call_once(|| {
// FFI won't properly unwind on panic, but it will if we cause a segmentation fault
panic::set_hook(Box::new(move |_| {
raise(SIGSEGV);
}));
});
Box::into_raw(tunnel)
}
/// Drops the Tunn object
#[no_mangle]
pub unsafe extern "C" fn tunnel_free(tunnel: *mut Mutex<Tunn>) {
drop(Box::from_raw(tunnel));
}
/// Write an IP packet from the tunnel interface.
/// For more details check noise::tunnel_to_network functions.
#[no_mangle]
pub unsafe extern "C" fn wireguard_write(
tunnel: *const Mutex<Tunn>,
src: *const u8,
src_size: u32,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let src = slice::from_raw_parts(src, src_size as usize);
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.encapsulate(src, dst))
}
/// Read a UDP packet from the server.
/// For more details check noise::network_to_tunnel functions.
#[no_mangle]
pub unsafe extern "C" fn wireguard_read(
tunnel: *const Mutex<Tunn>,
src: *const u8,
src_size: u32,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let src = slice::from_raw_parts(src, src_size as usize);
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.decapsulate(None, src, dst))
}
/// This is a state keeping function, that need to be called periodically.
/// Recommended interval: 100ms.
#[no_mangle]
pub unsafe extern "C" fn wireguard_tick(
tunnel: *const Mutex<Tunn>,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.update_timers(dst))
}
/// Force the tunnel to initiate a new handshake, dst buffer must be at least 148 byte long.
#[no_mangle]
pub unsafe extern "C" fn wireguard_force_handshake(
tunnel: *const Mutex<Tunn>,
dst: *mut u8,
dst_size: u32,
) -> wireguard_result {
let mut tunnel = tunnel.as_ref().unwrap().lock();
// Slices are not owned, and therefore will not be freed by Rust
let dst = slice::from_raw_parts_mut(dst, dst_size as usize);
wireguard_result::from(tunnel.format_handshake_initiation(dst, true))
}
/// Returns stats from the tunnel:
/// Time of last handshake in seconds (or -1 if no handshake occurred)
/// Number of data bytes encapsulated
/// Number of data bytes decapsulated
#[no_mangle]
pub unsafe extern "C" fn wireguard_stats(tunnel: *const Mutex<Tunn>) -> stats {
let tunnel = tunnel.as_ref().unwrap().lock();
let (time, tx_bytes, rx_bytes, estimated_loss, estimated_rtt) = tunnel.stats();
stats {
time_since_last_handshake: time.map(|t| t.as_secs() as i64).unwrap_or(-1),
tx_bytes,
rx_bytes,
estimated_loss,
estimated_rtt: estimated_rtt.map(|r| r as i32).unwrap_or(-1),
reserved: [0u8; 56],
}
}

271
lib/boringtun/src/jni.rs Normal file
View File

@@ -0,0 +1,271 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// temporary, we need to do some verification around these bindings later
#![allow(clippy::missing_safety_doc)]
/// JNI bindings for BoringTun library
use std::os::raw::c_char;
use std::ptr;
use jni::objects::{JByteBuffer, JClass, JString};
use jni::strings::JNIStr;
use jni::sys::{jbyteArray, jint, jlong, jshort, jstring};
use jni::JNIEnv;
use parking_lot::Mutex;
use crate::ffi::new_tunnel;
use crate::ffi::wireguard_read;
use crate::ffi::wireguard_result;
use crate::ffi::wireguard_tick;
use crate::ffi::wireguard_write;
use crate::ffi::x25519_key;
use crate::ffi::x25519_key_to_base64;
use crate::ffi::x25519_key_to_hex;
use crate::ffi::x25519_public_key;
use crate::ffi::x25519_secret_key;
use crate::noise::Tunn;
pub extern "C" fn log_print(_log_string: *const c_char) {
/*
XXX:
Define callback function in app.
*/
}
/// Generates new x25519 secret key and converts into java byte array.
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1secret_1key"]
pub extern "C" fn generate_secret_key(env: JNIEnv, _class: JClass) -> jbyteArray {
match env.byte_array_from_slice(&x25519_secret_key().key) {
Ok(v) => v,
Err(_) => ptr::null_mut(),
}
}
/// Computes public x25519 key from secret key and converts into java byte array.
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1public_1key"]
pub unsafe extern "C" fn generate_public_key1(
env: JNIEnv,
_class: JClass,
arg_secret_key: jbyteArray,
) -> jbyteArray {
let mut key_inner = [0; 32];
if env
.get_byte_array_region(arg_secret_key, 0, &mut key_inner)
.is_err()
{
return ptr::null_mut();
}
let secret_key = x25519_key {
key: std::mem::transmute::<[i8; 32], [u8; 32]>(key_inner),
};
match env.byte_array_from_slice(&x25519_public_key(secret_key).key) {
Ok(v) => v,
Err(_) => ptr::null_mut(),
}
}
/// Converts x25519 key to hex string.
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1key_1to_1hex"]
pub unsafe extern "C" fn convert_x25519_key_to_hex(
env: JNIEnv,
_class: JClass,
arg_key: jbyteArray,
) -> jstring {
let mut key = [0; 32];
if env.get_byte_array_region(arg_key, 0, &mut key).is_err() {
return ptr::null_mut();
}
let x25519_key = x25519_key {
key: std::mem::transmute::<[i8; 32], [u8; 32]>(key),
};
let output = match env.new_string(JNIStr::from_ptr(x25519_key_to_hex(x25519_key)).to_owned()) {
Ok(v) => v,
Err(_) => return ptr::null_mut(),
};
output.into_inner()
}
/// Converts x25519 key to base64 string.
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1key_1to_1base64"]
pub unsafe extern "C" fn convert_x25519_key_to_base64(
env: JNIEnv,
_class: JClass,
arg_key: jbyteArray,
) -> jstring {
let mut key = [0; 32];
if env.get_byte_array_region(arg_key, 0, &mut key).is_err() {
return ptr::null_mut();
}
let x25519_key = x25519_key {
key: std::mem::transmute::<[i8; 32], [u8; 32]>(key),
};
let output = match env.new_string(JNIStr::from_ptr(x25519_key_to_base64(x25519_key)).to_owned())
{
Ok(v) => v,
Err(_) => return ptr::null_mut(),
};
output.into_inner()
}
/// Creates new tunnel
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_new_1tunnel"]
pub unsafe extern "C" fn create_new_tunnel(
env: JNIEnv,
_class: JClass,
arg_secret_key: JString,
arg_public_key: JString,
arg_preshared_key: JString,
keep_alive: jshort,
index: jint,
) -> jlong {
let secret_key = match env.get_string_utf_chars(arg_secret_key) {
Ok(v) => v,
Err(_) => return 0,
};
let public_key = match env.get_string_utf_chars(arg_public_key) {
Ok(v) => v,
Err(_) => return 0,
};
let preshared_key = if arg_preshared_key.is_null() {
ptr::null_mut()
} else {
match env.get_string_utf_chars(arg_preshared_key) {
Ok(v) => v,
Err(_) => return 0,
}
};
let tunnel = new_tunnel(
secret_key,
public_key,
preshared_key,
keep_alive as u16,
index as u32,
);
if tunnel.is_null() {
return 0;
}
tunnel as jlong
}
/// Encrypts raw IP packets into WG formatted packets.
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1write"]
pub unsafe extern "C" fn encrypt_raw_packet(
env: JNIEnv,
_class: JClass,
tunnel: jlong,
src: jbyteArray,
src_size: jint,
dst: JByteBuffer,
dst_size: jint,
op: JByteBuffer,
) -> jint {
let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) {
Ok(v) => v.as_mut_ptr(),
Err(_) => return 0,
};
let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) {
Ok(v) => v.as_mut_ptr(),
Err(_) => return 0,
};
let output: wireguard_result = wireguard_write(
tunnel as *const Mutex<Tunn>,
env.convert_byte_array(src).unwrap().as_mut_ptr(),
src_size as u32,
dst_ptr,
dst_size as u32,
);
*op_ptr = output.op as u8;
output.size as i32
}
/// Decrypts WG formatted packets into raw IP packets.
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1read"]
pub unsafe extern "C" fn decrypt_to_raw_packet(
env: JNIEnv,
_class: JClass,
tunnel: jlong,
src: jbyteArray,
src_size: jint,
dst: JByteBuffer,
dst_size: jint,
op: JByteBuffer,
) -> jint {
let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) {
Ok(v) => v.as_mut_ptr(),
Err(_) => return 0,
};
let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) {
Ok(v) => v.as_mut_ptr(),
Err(_) => return 0,
};
let output: wireguard_result = wireguard_read(
tunnel as *const Mutex<Tunn>,
env.convert_byte_array(src).unwrap().as_mut_ptr(),
src_size as u32,
dst_ptr,
dst_size as u32,
);
*op_ptr = output.op as u8;
output.size as i32
}
/// Periodic function that writes WG formatted packets into destination buffer
#[no_mangle]
#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1tick"]
pub unsafe extern "C" fn run_periodic_task(
env: JNIEnv,
_class: JClass,
tunnel: jlong,
dst: JByteBuffer,
dst_size: jint,
op: JByteBuffer,
) -> jint {
let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) {
Ok(v) => v.as_mut_ptr(),
Err(_) => return 0,
};
let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) {
Ok(v) => v.as_mut_ptr(),
Err(_) => return 0,
};
let output: wireguard_result =
wireguard_tick(tunnel as *const Mutex<Tunn>, dst_ptr, dst_size as u32);
*op_ptr = output.op as u8;
output.size as i32
}

27
lib/boringtun/src/lib.rs Normal file
View File

@@ -0,0 +1,27 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//! Simple implementation of the client-side of the WireGuard protocol.
//!
//! <code>git clone https://github.com/cloudflare/boringtun.git</code>
#[cfg(feature = "device")]
pub mod device;
#[cfg(feature = "ffi-bindings")]
pub mod ffi;
#[cfg(feature = "jni-bindings")]
pub mod jni;
pub mod noise;
#[cfg(not(feature = "mock-instant"))]
pub(crate) mod sleepyinstant;
pub(crate) mod serialization;
/// Re-export of the x25519 types
pub mod x25519 {
pub use x25519_dalek::{
EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret,
};
}

View File

@@ -0,0 +1,23 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
#[derive(Debug)]
pub enum WireGuardError {
DestinationBufferTooSmall,
IncorrectPacketLength,
UnexpectedPacket,
WrongPacketType,
WrongIndex,
WrongKey,
InvalidTai64nTimestamp,
WrongTai64nTimestamp,
InvalidMac,
InvalidAeadTag,
InvalidCounter,
DuplicateCounter,
InvalidPacket,
NoCurrentSession,
LockFailed,
ConnectionExpired,
UnderLoad,
}

View File

@@ -0,0 +1,940 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::{HandshakeInit, HandshakeResponse, PacketCookieReply};
use crate::noise::errors::WireGuardError;
use crate::noise::session::Session;
#[cfg(not(feature = "mock-instant"))]
use crate::sleepyinstant::Instant;
use crate::x25519;
use aead::{Aead, Payload};
use blake2::digest::{FixedOutput, KeyInit};
use blake2::{Blake2s256, Blake2sMac, Digest};
use chacha20poly1305::XChaCha20Poly1305;
use rand_core::OsRng;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use std::convert::TryInto;
use std::time::{Duration, SystemTime};
#[cfg(feature = "mock-instant")]
use mock_instant::Instant;
pub(crate) const LABEL_MAC1: &[u8; 8] = b"mac1----";
pub(crate) const LABEL_COOKIE: &[u8; 8] = b"cookie--";
const KEY_LEN: usize = 32;
const TIMESTAMP_LEN: usize = 12;
// initiator.chaining_key = HASH(CONSTRUCTION)
const INITIAL_CHAIN_KEY: [u8; KEY_LEN] = [
96, 226, 109, 174, 243, 39, 239, 192, 46, 195, 53, 226, 160, 37, 210, 208, 22, 235, 66, 6, 248,
114, 119, 245, 45, 56, 209, 152, 139, 120, 205, 54,
];
// initiator.chaining_hash = HASH(initiator.chaining_key || IDENTIFIER)
const INITIAL_CHAIN_HASH: [u8; KEY_LEN] = [
34, 17, 179, 97, 8, 26, 197, 102, 105, 18, 67, 219, 69, 138, 213, 50, 45, 156, 108, 102, 34,
147, 232, 183, 14, 225, 156, 101, 186, 7, 158, 243,
];
#[inline]
pub(crate) fn b2s_hash(data1: &[u8], data2: &[u8]) -> [u8; 32] {
let mut hash = Blake2s256::new();
hash.update(data1);
hash.update(data2);
hash.finalize().into()
}
#[inline]
/// RFC 2401 HMAC+Blake2s, not to be confused with *keyed* Blake2s
pub(crate) fn b2s_hmac(key: &[u8], data1: &[u8]) -> [u8; 32] {
use blake2::digest::Update;
type HmacBlake2s = hmac::SimpleHmac<Blake2s256>;
let mut hmac = HmacBlake2s::new_from_slice(key).unwrap();
hmac.update(data1);
hmac.finalize_fixed().into()
}
#[inline]
/// Like b2s_hmac, but chain data1 and data2 together
pub(crate) fn b2s_hmac2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 32] {
use blake2::digest::Update;
type HmacBlake2s = hmac::SimpleHmac<Blake2s256>;
let mut hmac = HmacBlake2s::new_from_slice(key).unwrap();
hmac.update(data1);
hmac.update(data2);
hmac.finalize_fixed().into()
}
#[inline]
pub(crate) fn b2s_keyed_mac_16(key: &[u8], data1: &[u8]) -> [u8; 16] {
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
blake2::digest::Update::update(&mut hmac, data1);
hmac.finalize_fixed().into()
}
#[inline]
pub(crate) fn b2s_keyed_mac_16_2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 16] {
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
blake2::digest::Update::update(&mut hmac, data1);
blake2::digest::Update::update(&mut hmac, data2);
hmac.finalize_fixed().into()
}
pub(crate) fn b2s_mac_24(key: &[u8], data1: &[u8]) -> [u8; 24] {
let mut hmac = Blake2sMac::new_from_slice(key).unwrap();
blake2::digest::Update::update(&mut hmac, data1);
hmac.finalize_fixed().into()
}
#[inline]
/// This wrapper involves an extra copy and MAY BE SLOWER
fn aead_chacha20_seal(ciphertext: &mut [u8], key: &[u8], counter: u64, data: &[u8], aad: &[u8]) {
let mut nonce: [u8; 12] = [0; 12];
nonce[4..12].copy_from_slice(&counter.to_le_bytes());
aead_chacha20_seal_inner(ciphertext, key, nonce, data, aad)
}
#[inline]
fn aead_chacha20_seal_inner(
ciphertext: &mut [u8],
key: &[u8],
nonce: [u8; 12],
data: &[u8],
aad: &[u8],
) {
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap());
ciphertext[..data.len()].copy_from_slice(data);
let tag = key
.seal_in_place_separate_tag(
Nonce::assume_unique_for_key(nonce),
Aad::from(aad),
&mut ciphertext[..data.len()],
)
.unwrap();
ciphertext[data.len()..].copy_from_slice(tag.as_ref());
}
#[inline]
/// This wrapper involves an extra copy and MAY BE SLOWER
fn aead_chacha20_open(
buffer: &mut [u8],
key: &[u8],
counter: u64,
data: &[u8],
aad: &[u8],
) -> Result<(), WireGuardError> {
let mut nonce: [u8; 12] = [0; 12];
nonce[4..].copy_from_slice(&counter.to_le_bytes());
aead_chacha20_open_inner(buffer, key, nonce, data, aad)
.map_err(|_| WireGuardError::InvalidAeadTag)?;
Ok(())
}
#[inline]
fn aead_chacha20_open_inner(
buffer: &mut [u8],
key: &[u8],
nonce: [u8; 12],
data: &[u8],
aad: &[u8],
) -> Result<(), ring::error::Unspecified> {
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap());
let mut inner_buffer = data.to_owned();
let plaintext = key.open_in_place(
Nonce::assume_unique_for_key(nonce),
Aad::from(aad),
&mut inner_buffer,
)?;
buffer.copy_from_slice(plaintext);
Ok(())
}
#[derive(Debug)]
/// This struct represents a 12 byte [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp
struct Tai64N {
secs: u64,
nano: u32,
}
#[derive(Debug)]
/// This struct computes a [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp from current system time
struct TimeStamper {
duration_at_start: Duration,
instant_at_start: Instant,
}
impl TimeStamper {
/// Create a new TimeStamper
pub fn new() -> TimeStamper {
TimeStamper {
duration_at_start: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap(),
instant_at_start: Instant::now(),
}
}
/// Take time reading and generate a 12 byte timestamp
pub fn stamp(&self) -> [u8; 12] {
const TAI64_BASE: u64 = (1u64 << 62) + 37;
let mut ext_stamp = [0u8; 12];
let stamp = Instant::now().duration_since(self.instant_at_start) + self.duration_at_start;
ext_stamp[0..8].copy_from_slice(&(stamp.as_secs() + TAI64_BASE).to_be_bytes());
ext_stamp[8..12].copy_from_slice(&stamp.subsec_nanos().to_be_bytes());
ext_stamp
}
}
impl Tai64N {
/// A zeroed out timestamp
fn zero() -> Tai64N {
Tai64N { secs: 0, nano: 0 }
}
/// Parse a timestamp from a 12 byte u8 slice
fn parse(buf: &[u8; 12]) -> Result<Tai64N, WireGuardError> {
if buf.len() < 12 {
return Err(WireGuardError::InvalidTai64nTimestamp);
}
let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::<u64>());
let secs = u64::from_be_bytes(sec_bytes.try_into().unwrap());
let nano = u32::from_be_bytes(nano_bytes.try_into().unwrap());
// WireGuard does not actually expect tai64n timestamp, just monotonically increasing one
//if secs < (1u64 << 62) || secs >= (1u64 << 63) {
// return Err(WireGuardError::InvalidTai64nTimestamp);
//};
//if nano >= 1_000_000_000 {
// return Err(WireGuardError::InvalidTai64nTimestamp);
//}
Ok(Tai64N { secs, nano })
}
/// Check if this timestamp represents a time that is chronologically after the time represented
/// by the other timestamp
pub fn after(&self, other: &Tai64N) -> bool {
(self.secs > other.secs) || ((self.secs == other.secs) && (self.nano > other.nano))
}
}
/// Parameters used by the noise protocol
struct NoiseParams {
/// Our static public key
static_public: x25519::PublicKey,
/// Our static private key
static_private: x25519::StaticSecret,
/// Static public key of the other party
peer_static_public: x25519::PublicKey,
/// A shared key = DH(static_private, peer_static_public)
static_shared: x25519::SharedSecret,
/// A pre-computation of HASH("mac1----", peer_static_public) for this peer
sending_mac1_key: [u8; KEY_LEN],
/// An optional preshared key
preshared_key: Option<[u8; KEY_LEN]>,
}
impl std::fmt::Debug for NoiseParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NoiseParams")
.field("static_public", &self.static_public)
.field("static_private", &"<redacted>")
.field("peer_static_public", &self.peer_static_public)
.field("static_shared", &"<redacted>")
.field("sending_mac1_key", &self.sending_mac1_key)
.field("preshared_key", &self.preshared_key)
.finish()
}
}
struct HandshakeInitSentState {
local_index: u32,
hash: [u8; KEY_LEN],
chaining_key: [u8; KEY_LEN],
ephemeral_private: x25519::ReusableSecret,
time_sent: Instant,
}
impl std::fmt::Debug for HandshakeInitSentState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HandshakeInitSentState")
.field("local_index", &self.local_index)
.field("hash", &self.hash)
.field("chaining_key", &self.chaining_key)
.field("ephemeral_private", &"<redacted>")
.field("time_sent", &self.time_sent)
.finish()
}
}
#[derive(Debug)]
enum HandshakeState {
/// No handshake in process
None,
/// We initiated the handshake
InitSent(HandshakeInitSentState),
/// Handshake initiated by peer
InitReceived {
hash: [u8; KEY_LEN],
chaining_key: [u8; KEY_LEN],
peer_ephemeral_public: x25519::PublicKey,
peer_index: u32,
},
/// Handshake was established too long ago (implies no handshake is in progress)
Expired,
}
pub struct Handshake {
params: NoiseParams,
/// Index of the next session
next_index: u32,
/// Allow to have two outgoing handshakes in flight, because sometimes we may receive a delayed response to a handshake with bad networks
previous: HandshakeState,
/// Current handshake state
state: HandshakeState,
cookies: Cookies,
/// The timestamp of the last handshake we received
last_handshake_timestamp: Tai64N,
// TODO: make TimeStamper a singleton
stamper: TimeStamper,
pub(super) last_rtt: Option<u32>,
}
#[derive(Default)]
struct Cookies {
last_mac1: Option<[u8; 16]>,
index: u32,
write_cookie: Option<[u8; 16]>,
}
#[derive(Debug)]
pub struct HalfHandshake {
pub peer_index: u32,
pub peer_static_public: [u8; 32],
}
pub fn parse_handshake_anon(
static_private: &x25519::StaticSecret,
static_public: &x25519::PublicKey,
packet: &HandshakeInit,
) -> Result<HalfHandshake, WireGuardError> {
let peer_index = packet.sender_idx;
// initiator.chaining_key = HASH(CONSTRUCTION)
let mut chaining_key = INITIAL_CHAIN_KEY;
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
let mut hash = INITIAL_CHAIN_HASH;
hash = b2s_hash(&hash, static_public.as_bytes());
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes());
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(
&b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()),
&[0x01],
);
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
let ephemeral_shared = static_private.diffie_hellman(&peer_ephemeral_public);
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// key = HMAC(temp, initiator.chaining_key || 0x2)
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
let mut peer_static_public = [0u8; KEY_LEN];
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
aead_chacha20_open(
&mut peer_static_public,
&key,
0,
packet.encrypted_static,
&hash,
)?;
Ok(HalfHandshake {
peer_index,
peer_static_public,
})
}
impl NoiseParams {
/// New noise params struct from our secret key, peers public key, and optional preshared key
fn new(
static_private: x25519::StaticSecret,
static_public: x25519::PublicKey,
peer_static_public: x25519::PublicKey,
preshared_key: Option<[u8; 32]>,
) -> NoiseParams {
let static_shared = static_private.diffie_hellman(&peer_static_public);
let initial_sending_mac_key = b2s_hash(LABEL_MAC1, peer_static_public.as_bytes());
NoiseParams {
static_public,
static_private,
peer_static_public,
static_shared,
sending_mac1_key: initial_sending_mac_key,
preshared_key,
}
}
/// Set a new private key
fn set_static_private(
&mut self,
static_private: x25519::StaticSecret,
static_public: x25519::PublicKey,
) {
// Check that the public key indeed matches the private key
let check_key = x25519::PublicKey::from(&static_private);
assert_eq!(check_key.as_bytes(), static_public.as_bytes());
self.static_private = static_private;
self.static_public = static_public;
self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public);
}
}
impl Handshake {
pub(crate) fn new(
static_private: x25519::StaticSecret,
static_public: x25519::PublicKey,
peer_static_public: x25519::PublicKey,
global_idx: u32,
preshared_key: Option<[u8; 32]>,
) -> Handshake {
let params = NoiseParams::new(
static_private,
static_public,
peer_static_public,
preshared_key,
);
Handshake {
params,
next_index: global_idx,
previous: HandshakeState::None,
state: HandshakeState::None,
last_handshake_timestamp: Tai64N::zero(),
stamper: TimeStamper::new(),
cookies: Default::default(),
last_rtt: None,
}
}
pub(crate) fn is_in_progress(&self) -> bool {
!matches!(self.state, HandshakeState::None | HandshakeState::Expired)
}
pub(crate) fn timer(&self) -> Option<Instant> {
match self.state {
HandshakeState::InitSent(HandshakeInitSentState { time_sent, .. }) => Some(time_sent),
_ => None,
}
}
pub(crate) fn set_expired(&mut self) {
self.previous = HandshakeState::Expired;
self.state = HandshakeState::Expired;
}
pub(crate) fn is_expired(&self) -> bool {
matches!(self.state, HandshakeState::Expired)
}
pub(crate) fn has_cookie(&self) -> bool {
self.cookies.write_cookie.is_some()
}
pub(crate) fn clear_cookie(&mut self) {
self.cookies.write_cookie = None;
}
// The index used is 24 bits for peer index, allowing for 16M active peers per server and 8 bits for cyclic session index
fn inc_index(&mut self) -> u32 {
let index = self.next_index;
let idx8 = index as u8;
self.next_index = (index & !0xff) | u32::from(idx8.wrapping_add(1));
self.next_index
}
pub(crate) fn set_static_private(
&mut self,
private_key: x25519::StaticSecret,
public_key: x25519::PublicKey,
) {
self.params.set_static_private(private_key, public_key)
}
pub(super) fn receive_handshake_initialization<'a>(
&mut self,
packet: HandshakeInit,
dst: &'a mut [u8],
) -> Result<(&'a mut [u8], Session), WireGuardError> {
// initiator.chaining_key = HASH(CONSTRUCTION)
let mut chaining_key = INITIAL_CHAIN_KEY;
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
let mut hash = INITIAL_CHAIN_HASH;
hash = b2s_hash(&hash, self.params.static_public.as_bytes());
// msg.sender_index = little_endian(initiator.sender_index)
let peer_index = packet.sender_idx;
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes());
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(
&b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()),
&[0x01],
);
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
let ephemeral_shared = self
.params
.static_private
.diffie_hellman(&peer_ephemeral_public);
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// key = HMAC(temp, initiator.chaining_key || 0x2)
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
let mut peer_static_public_decrypted = [0u8; KEY_LEN];
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
aead_chacha20_open(
&mut peer_static_public_decrypted,
&key,
0,
packet.encrypted_static,
&hash,
)?;
ring::constant_time::verify_slices_are_equal(
self.params.peer_static_public.as_bytes(),
&peer_static_public_decrypted,
)
.map_err(|_| WireGuardError::WrongKey)?;
// initiator.hash = HASH(initiator.hash || msg.encrypted_static)
hash = b2s_hash(&hash, packet.encrypted_static);
// temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public))
let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes());
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// key = HMAC(temp, initiator.chaining_key || 0x2)
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
// msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash)
let mut timestamp = [0u8; TIMESTAMP_LEN];
aead_chacha20_open(&mut timestamp, &key, 0, packet.encrypted_timestamp, &hash)?;
let timestamp = Tai64N::parse(&timestamp)?;
if !timestamp.after(&self.last_handshake_timestamp) {
// Possibly a replay
return Err(WireGuardError::WrongTai64nTimestamp);
}
self.last_handshake_timestamp = timestamp;
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
hash = b2s_hash(&hash, packet.encrypted_timestamp);
self.previous = std::mem::replace(
&mut self.state,
HandshakeState::InitReceived {
chaining_key,
hash,
peer_ephemeral_public,
peer_index,
},
);
self.format_handshake_response(dst)
}
pub(super) fn receive_handshake_response(
&mut self,
packet: HandshakeResponse,
) -> Result<Session, WireGuardError> {
// Check if there is a handshake awaiting a response and return the correct one
let (state, is_previous) = match (&self.state, &self.previous) {
(HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx => (s, false),
(_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx => (s, true),
_ => return Err(WireGuardError::UnexpectedPacket),
};
let peer_index = packet.sender_idx;
let local_index = state.local_index;
let unencrypted_ephemeral = x25519::PublicKey::from(*packet.unencrypted_ephemeral);
// msg.unencrypted_ephemeral = DH_PUBKEY(responder.ephemeral_private)
// responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral)
let mut hash = b2s_hash(&state.hash, unencrypted_ephemeral.as_bytes());
// temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral)
let temp = b2s_hmac(&state.chaining_key, unencrypted_ephemeral.as_bytes());
// responder.chaining_key = HMAC(temp, 0x1)
let mut chaining_key = b2s_hmac(&temp, &[0x01]);
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public))
let ephemeral_shared = state
.ephemeral_private
.diffie_hellman(&unencrypted_ephemeral);
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
// responder.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public))
let temp = b2s_hmac(
&chaining_key,
&self
.params
.static_private
.diffie_hellman(&unencrypted_ephemeral)
.to_bytes(),
);
// responder.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// temp = HMAC(responder.chaining_key, preshared_key)
let temp = b2s_hmac(
&chaining_key,
&self.params.preshared_key.unwrap_or([0u8; 32])[..],
);
// responder.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// temp2 = HMAC(temp, responder.chaining_key || 0x2)
let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]);
// key = HMAC(temp, temp2 || 0x3)
let key = b2s_hmac2(&temp, &temp2, &[0x03]);
// responder.hash = HASH(responder.hash || temp2)
hash = b2s_hash(&hash, &temp2);
// msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash)
aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing, &hash)?;
// responder.hash = HASH(responder.hash || msg.encrypted_nothing)
// hash = b2s_hash(hash, buf[ENC_NOTHING_OFF..ENC_NOTHING_OFF + ENC_NOTHING_SZ]);
// Derive keys
// temp1 = HMAC(initiator.chaining_key, [empty])
// temp2 = HMAC(temp1, 0x1)
// temp3 = HMAC(temp1, temp2 || 0x2)
// initiator.sending_key = temp2
// initiator.receiving_key = temp3
// initiator.sending_key_counter = 0
// initiator.receiving_key_counter = 0
let temp1 = b2s_hmac(&chaining_key, &[]);
let temp2 = b2s_hmac(&temp1, &[0x01]);
let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]);
let rtt_time = Instant::now().duration_since(state.time_sent);
self.last_rtt = Some(rtt_time.as_millis() as u32);
if is_previous {
self.previous = HandshakeState::None;
} else {
self.state = HandshakeState::None;
}
Ok(Session::new(local_index, peer_index, temp3, temp2))
}
pub(super) fn receive_cookie_reply(
&mut self,
packet: PacketCookieReply,
) -> Result<(), WireGuardError> {
let mac1 = match self.cookies.last_mac1 {
Some(mac) => mac,
None => {
return Err(WireGuardError::UnexpectedPacket);
}
};
let local_index = self.cookies.index;
if packet.receiver_idx != local_index {
return Err(WireGuardError::WrongIndex);
}
// msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), msg.nonce, cookie, last_received_msg.mac1)
let key = b2s_hash(LABEL_COOKIE, self.params.peer_static_public.as_bytes()); // TODO: pre-compute
let payload = Payload {
aad: &mac1[0..16],
msg: packet.encrypted_cookie,
};
let plaintext = XChaCha20Poly1305::new_from_slice(&key)
.unwrap()
.decrypt(packet.nonce.into(), payload)
.map_err(|_| WireGuardError::InvalidAeadTag)?;
let cookie = plaintext
.try_into()
.map_err(|_| WireGuardError::InvalidPacket)?;
self.cookies.write_cookie = Some(cookie);
Ok(())
}
// Compute and append mac1 and mac2 to a handshake message
fn append_mac1_and_mac2<'a>(
&mut self,
local_index: u32,
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
let mac1_off = dst.len() - 32;
let mac2_off = dst.len() - 16;
// msg.mac1 = MAC(HASH(LABEL_MAC1 || responder.static_public), msg[0:offsetof(msg.mac1)])
let msg_mac1 = b2s_keyed_mac_16(&self.params.sending_mac1_key, &dst[..mac1_off]);
dst[mac1_off..mac2_off].copy_from_slice(&msg_mac1[..]);
//msg.mac2 = MAC(initiator.last_received_cookie, msg[0:offsetof(msg.mac2)])
let msg_mac2: [u8; 16] = if let Some(cookie) = self.cookies.write_cookie {
b2s_keyed_mac_16(&cookie, &dst[..mac2_off])
} else {
[0u8; 16]
};
dst[mac2_off..].copy_from_slice(&msg_mac2[..]);
self.cookies.index = local_index;
self.cookies.last_mac1 = Some(msg_mac1);
Ok(dst)
}
pub(super) fn format_handshake_initiation<'a>(
&mut self,
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::HANDSHAKE_INIT_SZ {
return Err(WireGuardError::DestinationBufferTooSmall);
}
let (message_type, rest) = dst.split_at_mut(4);
let (sender_index, rest) = rest.split_at_mut(4);
let (unencrypted_ephemeral, rest) = rest.split_at_mut(32);
let (encrypted_static, rest) = rest.split_at_mut(32 + 16);
let (encrypted_timestamp, _) = rest.split_at_mut(12 + 16);
let local_index = self.inc_index();
// initiator.chaining_key = HASH(CONSTRUCTION)
let mut chaining_key = INITIAL_CHAIN_KEY;
// initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public)
let mut hash = INITIAL_CHAIN_HASH;
hash = b2s_hash(&hash, self.params.peer_static_public.as_bytes());
// initiator.ephemeral_private = DH_GENERATE()
let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng);
// msg.message_type = 1
// msg.reserved_zero = { 0, 0, 0 }
message_type.copy_from_slice(&super::HANDSHAKE_INIT.to_le_bytes());
// msg.sender_index = little_endian(initiator.sender_index)
sender_index.copy_from_slice(&local_index.to_le_bytes());
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
unencrypted_ephemeral
.copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes());
// initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral)
hash = b2s_hash(&hash, unencrypted_ephemeral);
// temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral)
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&b2s_hmac(&chaining_key, unencrypted_ephemeral), &[0x01]);
// temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public))
let ephemeral_shared = ephemeral_private.diffie_hellman(&self.params.peer_static_public);
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// key = HMAC(temp, initiator.chaining_key || 0x2)
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
// msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash)
aead_chacha20_seal(
encrypted_static,
&key,
0,
self.params.static_public.as_bytes(),
&hash,
);
// initiator.hash = HASH(initiator.hash || msg.encrypted_static)
hash = b2s_hash(&hash, encrypted_static);
// temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public))
let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes());
// initiator.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// key = HMAC(temp, initiator.chaining_key || 0x2)
let key = b2s_hmac2(&temp, &chaining_key, &[0x02]);
// msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash)
let timestamp = self.stamper.stamp();
aead_chacha20_seal(encrypted_timestamp, &key, 0, &timestamp, &hash);
// initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp)
hash = b2s_hash(&hash, encrypted_timestamp);
let time_now = Instant::now();
self.previous = std::mem::replace(
&mut self.state,
HandshakeState::InitSent(HandshakeInitSentState {
local_index,
chaining_key,
hash,
ephemeral_private,
time_sent: time_now,
}),
);
self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_INIT_SZ])
}
fn format_handshake_response<'a>(
&mut self,
dst: &'a mut [u8],
) -> Result<(&'a mut [u8], Session), WireGuardError> {
if dst.len() < super::HANDSHAKE_RESP_SZ {
return Err(WireGuardError::DestinationBufferTooSmall);
}
let state = std::mem::replace(&mut self.state, HandshakeState::None);
let (mut chaining_key, mut hash, peer_ephemeral_public, peer_index) = match state {
HandshakeState::InitReceived {
chaining_key,
hash,
peer_ephemeral_public,
peer_index,
} => (chaining_key, hash, peer_ephemeral_public, peer_index),
_ => {
panic!("Unexpected attempt to call send_handshake_response");
}
};
let (message_type, rest) = dst.split_at_mut(4);
let (sender_index, rest) = rest.split_at_mut(4);
let (receiver_index, rest) = rest.split_at_mut(4);
let (unencrypted_ephemeral, rest) = rest.split_at_mut(32);
let (encrypted_nothing, _) = rest.split_at_mut(16);
// responder.ephemeral_private = DH_GENERATE()
let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng);
let local_index = self.inc_index();
// msg.message_type = 2
// msg.reserved_zero = { 0, 0, 0 }
message_type.copy_from_slice(&super::HANDSHAKE_RESP.to_le_bytes());
// msg.sender_index = little_endian(responder.sender_index)
sender_index.copy_from_slice(&local_index.to_le_bytes());
// msg.receiver_index = little_endian(initiator.sender_index)
receiver_index.copy_from_slice(&peer_index.to_le_bytes());
// msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private)
unencrypted_ephemeral
.copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes());
// responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral)
hash = b2s_hash(&hash, unencrypted_ephemeral);
// temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral)
let temp = b2s_hmac(&chaining_key, unencrypted_ephemeral);
// responder.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public))
let ephemeral_shared = ephemeral_private.diffie_hellman(&peer_ephemeral_public);
let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes());
// responder.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public))
let temp = b2s_hmac(
&chaining_key,
&ephemeral_private
.diffie_hellman(&self.params.peer_static_public)
.to_bytes(),
);
// responder.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// temp = HMAC(responder.chaining_key, preshared_key)
let temp = b2s_hmac(
&chaining_key,
&self.params.preshared_key.unwrap_or([0u8; 32])[..],
);
// responder.chaining_key = HMAC(temp, 0x1)
chaining_key = b2s_hmac(&temp, &[0x01]);
// temp2 = HMAC(temp, responder.chaining_key || 0x2)
let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]);
// key = HMAC(temp, temp2 || 0x3)
let key = b2s_hmac2(&temp, &temp2, &[0x03]);
// responder.hash = HASH(responder.hash || temp2)
hash = b2s_hash(&hash, &temp2);
// msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash)
aead_chacha20_seal(encrypted_nothing, &key, 0, &[], &hash);
// Derive keys
// temp1 = HMAC(initiator.chaining_key, [empty])
// temp2 = HMAC(temp1, 0x1)
// temp3 = HMAC(temp1, temp2 || 0x2)
// initiator.sending_key = temp2
// initiator.receiving_key = temp3
// initiator.sending_key_counter = 0
// initiator.receiving_key_counter = 0
let temp1 = b2s_hmac(&chaining_key, &[]);
let temp2 = b2s_hmac(&temp1, &[0x01]);
let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]);
let dst = self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_RESP_SZ])?;
Ok((dst, Session::new(local_index, peer_index, temp2, temp3)))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn chacha20_seal_rfc7530_test_vector() {
let plaintext = b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it.";
let aad: [u8; 12] = [
0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
];
let key: [u8; 32] = [
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d,
0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
0x9c, 0x9d, 0x9e, 0x9f,
];
let nonce: [u8; 12] = [
0x07, 0x00, 0x00, 0x00, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
];
let mut buffer = vec![0; plaintext.len() + 16];
aead_chacha20_seal_inner(&mut buffer, &key, nonce, plaintext, &aad);
const EXPECTED_CIPHERTEXT: [u8; 114] = [
0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, 0x7b, 0x86, 0xaf, 0xbc, 0x53, 0xef,
0x7e, 0xc2, 0xa4, 0xad, 0xed, 0x51, 0x29, 0x6e, 0x08, 0xfe, 0xa9, 0xe2, 0xb5, 0xa7,
0x36, 0xee, 0x62, 0xd6, 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12, 0x82, 0xfa,
0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, 0x1a, 0x71, 0xde, 0x0a, 0x9e, 0x06, 0x0b, 0x29,
0x05, 0xd6, 0xa5, 0xb6, 0x7e, 0xcd, 0x3b, 0x36, 0x92, 0xdd, 0xbd, 0x7f, 0x2d, 0x77,
0x8b, 0x8c, 0x98, 0x03, 0xae, 0xe3, 0x28, 0x09, 0x1b, 0x58, 0xfa, 0xb3, 0x24, 0xe4,
0xfa, 0xd6, 0x75, 0x94, 0x55, 0x85, 0x80, 0x8b, 0x48, 0x31, 0xd7, 0xbc, 0x3f, 0xf4,
0xde, 0xf0, 0x8e, 0x4b, 0x7a, 0x9d, 0xe5, 0x76, 0xd2, 0x65, 0x86, 0xce, 0xc6, 0x4b,
0x61, 0x16,
];
const EXPECTED_TAG: [u8; 16] = [
0x1a, 0xe1, 0x0b, 0x59, 0x4f, 0x09, 0xe2, 0x6a, 0x7e, 0x90, 0x2e, 0xcb, 0xd0, 0x60,
0x06, 0x91,
];
assert_eq!(buffer[..plaintext.len()], EXPECTED_CIPHERTEXT);
assert_eq!(buffer[plaintext.len()..], EXPECTED_TAG);
}
#[test]
fn symmetric_chacha20_seal_open() {
let aad: [u8; 32] = Default::default();
let key: [u8; 32] = Default::default();
let counter = 0;
let mut encrypted_nothing: [u8; 16] = Default::default();
aead_chacha20_seal(&mut encrypted_nothing, &key, counter, &[], &aad);
eprintln!("encrypted_nothing: {:?}", encrypted_nothing);
aead_chacha20_open(&mut [], &key, counter, &encrypted_nothing, &aad)
.expect("Should open what we just sealed");
}
}

View File

@@ -0,0 +1,794 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
pub mod errors;
pub mod handshake;
pub mod rate_limiter;
mod session;
mod timers;
use crate::noise::errors::WireGuardError;
use crate::noise::handshake::Handshake;
use crate::noise::rate_limiter::RateLimiter;
use crate::noise::timers::{TimerName, Timers};
use crate::x25519;
use std::collections::VecDeque;
use std::convert::{TryFrom, TryInto};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use std::time::Duration;
/// The default value to use for rate limiting, when no other rate limiter is defined
const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10;
const IPV4_MIN_HEADER_SIZE: usize = 20;
const IPV4_LEN_OFF: usize = 2;
const IPV4_SRC_IP_OFF: usize = 12;
const IPV4_DST_IP_OFF: usize = 16;
const IPV4_IP_SZ: usize = 4;
const IPV6_MIN_HEADER_SIZE: usize = 40;
const IPV6_LEN_OFF: usize = 4;
const IPV6_SRC_IP_OFF: usize = 8;
const IPV6_DST_IP_OFF: usize = 24;
const IPV6_IP_SZ: usize = 16;
const IP_LEN_SZ: usize = 2;
const MAX_QUEUE_DEPTH: usize = 256;
/// number of sessions in the ring, better keep a PoT
const N_SESSIONS: usize = 8;
#[derive(Debug)]
pub enum TunnResult<'a> {
Done,
Err(WireGuardError),
WriteToNetwork(&'a mut [u8]),
WriteToTunnelV4(&'a mut [u8], Ipv4Addr),
WriteToTunnelV6(&'a mut [u8], Ipv6Addr),
}
impl<'a> From<WireGuardError> for TunnResult<'a> {
fn from(err: WireGuardError) -> TunnResult<'a> {
TunnResult::Err(err)
}
}
/// Tunnel represents a point-to-point WireGuard connection
pub struct Tunn {
/// The handshake currently in progress
handshake: handshake::Handshake,
/// The N_SESSIONS most recent sessions, index is session id modulo N_SESSIONS
sessions: [Option<session::Session>; N_SESSIONS],
/// Index of most recently used session
current: usize,
/// Queue to store blocked packets
packet_queue: VecDeque<Vec<u8>>,
/// Keeps tabs on the expiring timers
timers: timers::Timers,
tx_bytes: usize,
rx_bytes: usize,
rate_limiter: Arc<RateLimiter>,
}
type MessageType = u32;
const HANDSHAKE_INIT: MessageType = 1;
const HANDSHAKE_RESP: MessageType = 2;
const COOKIE_REPLY: MessageType = 3;
const DATA: MessageType = 4;
const HANDSHAKE_INIT_SZ: usize = 148;
const HANDSHAKE_RESP_SZ: usize = 92;
const COOKIE_REPLY_SZ: usize = 64;
const DATA_OVERHEAD_SZ: usize = 32;
#[derive(Debug)]
pub struct HandshakeInit<'a> {
sender_idx: u32,
pub unencrypted_ephemeral: &'a [u8; 32],
encrypted_static: &'a [u8],
encrypted_timestamp: &'a [u8],
}
#[derive(Debug)]
pub struct HandshakeResponse<'a> {
sender_idx: u32,
pub receiver_idx: u32,
pub unencrypted_ephemeral: &'a [u8; 32],
encrypted_nothing: &'a [u8],
}
#[derive(Debug)]
pub struct PacketCookieReply<'a> {
pub receiver_idx: u32,
nonce: &'a [u8],
encrypted_cookie: &'a [u8],
}
#[derive(Debug)]
pub struct PacketData<'a> {
pub receiver_idx: u32,
counter: u64,
encrypted_encapsulated_packet: &'a [u8],
}
/// Describes a packet from network
#[derive(Debug)]
pub enum Packet<'a> {
HandshakeInit(HandshakeInit<'a>),
HandshakeResponse(HandshakeResponse<'a>),
PacketCookieReply(PacketCookieReply<'a>),
PacketData(PacketData<'a>),
}
impl Tunn {
#[inline(always)]
pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> {
if src.len() < 4 {
return Err(WireGuardError::InvalidPacket);
}
// Checks the type, as well as the reserved zero fields
let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap());
Ok(match (packet_type, src.len()) {
(HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit {
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40])
.expect("length already checked above"),
encrypted_static: &src[40..88],
encrypted_timestamp: &src[88..116],
}),
(HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse {
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()),
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44])
.expect("length already checked above"),
encrypted_nothing: &src[44..60],
}),
(COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply {
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
nonce: &src[8..32],
encrypted_cookie: &src[32..64],
}),
(DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData {
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
counter: u64::from_le_bytes(src[8..16].try_into().unwrap()),
encrypted_encapsulated_packet: &src[16..],
}),
_ => return Err(WireGuardError::InvalidPacket),
})
}
pub fn is_expired(&self) -> bool {
self.handshake.is_expired()
}
pub fn dst_address(packet: &[u8]) -> Option<IpAddr> {
if packet.is_empty() {
return None;
}
match packet[0] >> 4 {
4 if packet.len() >= IPV4_MIN_HEADER_SIZE => {
let addr_bytes: [u8; IPV4_IP_SZ] = packet
[IPV4_DST_IP_OFF..IPV4_DST_IP_OFF + IPV4_IP_SZ]
.try_into()
.unwrap();
Some(IpAddr::from(addr_bytes))
}
6 if packet.len() >= IPV6_MIN_HEADER_SIZE => {
let addr_bytes: [u8; IPV6_IP_SZ] = packet
[IPV6_DST_IP_OFF..IPV6_DST_IP_OFF + IPV6_IP_SZ]
.try_into()
.unwrap();
Some(IpAddr::from(addr_bytes))
}
_ => None,
}
}
/// Create a new tunnel using own private key and the peer public key
pub fn new(
static_private: x25519::StaticSecret,
peer_static_public: x25519::PublicKey,
preshared_key: Option<[u8; 32]>,
persistent_keepalive: Option<u16>,
index: u32,
rate_limiter: Option<Arc<RateLimiter>>,
) -> Self {
let static_public = x25519::PublicKey::from(&static_private);
Tunn {
handshake: Handshake::new(
static_private,
static_public,
peer_static_public,
index << 8,
preshared_key,
),
sessions: Default::default(),
current: Default::default(),
tx_bytes: Default::default(),
rx_bytes: Default::default(),
packet_queue: VecDeque::new(),
timers: Timers::new(persistent_keepalive, rate_limiter.is_none()),
rate_limiter: rate_limiter.unwrap_or_else(|| {
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
}),
}
}
/// Update the private key and clear existing sessions
pub fn set_static_private(
&mut self,
static_private: x25519::StaticSecret,
static_public: x25519::PublicKey,
rate_limiter: Option<Arc<RateLimiter>>,
) {
self.timers.should_reset_rr = rate_limiter.is_none();
self.rate_limiter = rate_limiter.unwrap_or_else(|| {
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
});
self.handshake
.set_static_private(static_private, static_public);
for s in &mut self.sessions {
*s = None;
}
}
/// Encapsulate a single packet from the tunnel interface.
/// Returns TunnResult.
///
/// # Panics
/// Panics if dst buffer is too small.
/// Size of dst should be at least src.len() + 32, and no less than 148 bytes.
pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> {
let current = self.current;
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
// Send the packet using an established session
let packet = session.format_packet_data(src, dst);
self.timer_tick(TimerName::TimeLastPacketSent);
// Exclude Keepalive packets from timer update.
if !src.is_empty() {
self.timer_tick(TimerName::TimeLastDataPacketSent);
}
self.tx_bytes += src.len();
return TunnResult::WriteToNetwork(packet);
}
// If there is no session, queue the packet for future retry
self.queue_packet(src);
// Initiate a new handshake if none is in progress
self.format_handshake_initiation(dst, false)
}
/// Receives a UDP datagram from the network and parses it.
/// Returns TunnResult.
///
/// If the result is of type TunnResult::WriteToNetwork, should repeat the call with empty datagram,
/// until TunnResult::Done is returned. If batch processing packets, it is OK to defer until last
/// packet is processed.
pub fn decapsulate<'a>(
&mut self,
src_addr: Option<IpAddr>,
datagram: &[u8],
dst: &'a mut [u8],
) -> TunnResult<'a> {
if datagram.is_empty() {
// Indicates a repeated call
return self.send_queued_packet(dst);
}
let mut cookie = [0u8; COOKIE_REPLY_SZ];
let packet = match self
.rate_limiter
.verify_packet(src_addr, datagram, &mut cookie)
{
Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => {
dst[..cookie.len()].copy_from_slice(cookie);
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]);
}
Err(TunnResult::Err(e)) => return TunnResult::Err(e),
_ => unreachable!(),
};
self.handle_verified_packet(packet, dst)
}
pub(crate) fn handle_verified_packet<'a>(
&mut self,
packet: Packet,
dst: &'a mut [u8],
) -> TunnResult<'a> {
match packet {
Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst),
Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst),
Packet::PacketCookieReply(p) => self.handle_cookie_reply(p),
Packet::PacketData(p) => self.handle_data(p, dst),
}
.unwrap_or_else(TunnResult::from)
}
fn handle_handshake_init<'a>(
&mut self,
p: HandshakeInit,
dst: &'a mut [u8],
) -> Result<TunnResult<'a>, WireGuardError> {
tracing::debug!(
message = "Received handshake_initiation",
remote_idx = p.sender_idx
);
let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?;
// Store new session in ring buffer
let index = session.local_index();
self.sessions[index % N_SESSIONS] = Some(session);
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeLastPacketSent);
self.timer_tick_session_established(false, index); // New session established, we are not the initiator
tracing::debug!(message = "Sending handshake_response", local_idx = index);
Ok(TunnResult::WriteToNetwork(packet))
}
fn handle_handshake_response<'a>(
&mut self,
p: HandshakeResponse,
dst: &'a mut [u8],
) -> Result<TunnResult<'a>, WireGuardError> {
tracing::debug!(
message = "Received handshake_response",
local_idx = p.receiver_idx,
remote_idx = p.sender_idx
);
let session = self.handshake.receive_handshake_response(p)?;
let keepalive_packet = session.format_packet_data(&[], dst);
// Store new session in ring buffer
let l_idx = session.local_index();
let index = l_idx % N_SESSIONS;
self.sessions[index] = Some(session);
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick_session_established(true, index); // New session established, we are the initiator
self.set_current_session(l_idx);
tracing::debug!("Sending keepalive");
Ok(TunnResult::WriteToNetwork(keepalive_packet)) // Send a keepalive as a response
}
fn handle_cookie_reply<'a>(
&mut self,
p: PacketCookieReply,
) -> Result<TunnResult<'a>, WireGuardError> {
tracing::debug!(
message = "Received cookie_reply",
local_idx = p.receiver_idx
);
self.handshake.receive_cookie_reply(p)?;
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeCookieReceived);
tracing::debug!("Did set cookie");
Ok(TunnResult::Done)
}
/// Update the index of the currently used session, if needed
fn set_current_session(&mut self, new_idx: usize) {
let cur_idx = self.current;
if cur_idx == new_idx {
// There is nothing to do, already using this session, this is the common case
return;
}
if self.sessions[cur_idx % N_SESSIONS].is_none()
|| self.timers.session_timers[new_idx % N_SESSIONS]
>= self.timers.session_timers[cur_idx % N_SESSIONS]
{
self.current = new_idx;
tracing::debug!(message = "New session", session = new_idx);
}
}
/// Decrypts a data packet, and stores the decapsulated packet in dst.
fn handle_data<'a>(
&mut self,
packet: PacketData,
dst: &'a mut [u8],
) -> Result<TunnResult<'a>, WireGuardError> {
let r_idx = packet.receiver_idx as usize;
let idx = r_idx % N_SESSIONS;
// Get the (probably) right session
let decapsulated_packet = {
let session = self.sessions[idx].as_ref();
let session = session.ok_or_else(|| {
tracing::trace!(message = "No current session available", remote_idx = r_idx);
WireGuardError::NoCurrentSession
})?;
session.receive_packet_data(packet, dst)?
};
self.set_current_session(r_idx);
self.timer_tick(TimerName::TimeLastPacketReceived);
Ok(self.validate_decapsulated_packet(decapsulated_packet))
}
/// Formats a new handshake initiation message and store it in dst. If force_resend is true will send
/// a new handshake, even if a handshake is already in progress (for example when a handshake times out)
pub fn format_handshake_initiation<'a>(
&mut self,
dst: &'a mut [u8],
force_resend: bool,
) -> TunnResult<'a> {
if self.handshake.is_in_progress() && !force_resend {
return TunnResult::Done;
}
if self.handshake.is_expired() {
self.timers.clear();
}
let starting_new_handshake = !self.handshake.is_in_progress();
match self.handshake.format_handshake_initiation(dst) {
Ok(packet) => {
tracing::debug!("Sending handshake_initiation");
if starting_new_handshake {
self.timer_tick(TimerName::TimeLastHandshakeStarted);
}
self.timer_tick(TimerName::TimeLastPacketSent);
TunnResult::WriteToNetwork(packet)
}
Err(e) => TunnResult::Err(e),
}
}
/// Check if an IP packet is v4 or v6, truncate to the length indicated by the length field
/// Returns the truncated packet and the source IP as TunnResult
fn validate_decapsulated_packet<'a>(&mut self, packet: &'a mut [u8]) -> TunnResult<'a> {
let (computed_len, src_ip_address) = match packet.len() {
0 => return TunnResult::Done, // This is keepalive, and not an error
_ if packet[0] >> 4 == 4 && packet.len() >= IPV4_MIN_HEADER_SIZE => {
let len_bytes: [u8; IP_LEN_SZ] = packet[IPV4_LEN_OFF..IPV4_LEN_OFF + IP_LEN_SZ]
.try_into()
.unwrap();
let addr_bytes: [u8; IPV4_IP_SZ] = packet
[IPV4_SRC_IP_OFF..IPV4_SRC_IP_OFF + IPV4_IP_SZ]
.try_into()
.unwrap();
(
u16::from_be_bytes(len_bytes) as usize,
IpAddr::from(addr_bytes),
)
}
_ if packet[0] >> 4 == 6 && packet.len() >= IPV6_MIN_HEADER_SIZE => {
let len_bytes: [u8; IP_LEN_SZ] = packet[IPV6_LEN_OFF..IPV6_LEN_OFF + IP_LEN_SZ]
.try_into()
.unwrap();
let addr_bytes: [u8; IPV6_IP_SZ] = packet
[IPV6_SRC_IP_OFF..IPV6_SRC_IP_OFF + IPV6_IP_SZ]
.try_into()
.unwrap();
(
u16::from_be_bytes(len_bytes) as usize + IPV6_MIN_HEADER_SIZE,
IpAddr::from(addr_bytes),
)
}
_ => return TunnResult::Err(WireGuardError::InvalidPacket),
};
if computed_len > packet.len() {
return TunnResult::Err(WireGuardError::InvalidPacket);
}
self.timer_tick(TimerName::TimeLastDataPacketReceived);
self.rx_bytes += computed_len;
match src_ip_address {
IpAddr::V4(addr) => TunnResult::WriteToTunnelV4(&mut packet[..computed_len], addr),
IpAddr::V6(addr) => TunnResult::WriteToTunnelV6(&mut packet[..computed_len], addr),
}
}
/// Get a packet from the queue, and try to encapsulate it
fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
if let Some(packet) = self.dequeue_packet() {
match self.encapsulate(&packet, dst) {
TunnResult::Err(_) => {
// On error, return packet to the queue
self.requeue_packet(packet);
}
r => return r,
}
}
TunnResult::Done
}
/// Push packet to the back of the queue
fn queue_packet(&mut self, packet: &[u8]) {
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
// Drop if too many are already in queue
self.packet_queue.push_back(packet.to_vec());
}
}
/// Push packet to the front of the queue
fn requeue_packet(&mut self, packet: Vec<u8>) {
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
// Drop if too many are already in queue
self.packet_queue.push_front(packet);
}
}
fn dequeue_packet(&mut self) -> Option<Vec<u8>> {
self.packet_queue.pop_front()
}
fn estimate_loss(&self) -> f32 {
let session_idx = self.current;
let mut weight = 9.0;
let mut cur_avg = 0.0;
let mut total_weight = 0.0;
for i in 0..N_SESSIONS {
if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] {
let (expected, received) = session.current_packet_cnt();
let loss = if expected == 0 {
0.0
} else {
1.0 - received as f32 / expected as f32
};
cur_avg += loss * weight;
total_weight += weight;
weight /= 3.0;
}
}
if total_weight == 0.0 {
0.0
} else {
cur_avg / total_weight
}
}
/// Return stats from the tunnel:
/// * Time since last handshake in seconds
/// * Data bytes sent
/// * Data bytes received
pub fn stats(&self) -> (Option<Duration>, usize, usize, f32, Option<u32>) {
let time = self.time_since_last_handshake();
let tx_bytes = self.tx_bytes;
let rx_bytes = self.rx_bytes;
let loss = self.estimate_loss();
let rtt = self.handshake.last_rtt;
(time, tx_bytes, rx_bytes, loss, rtt)
}
}
#[cfg(test)]
mod tests {
#[cfg(feature = "mock-instant")]
use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT};
use super::*;
use rand_core::{OsRng, RngCore};
fn create_two_tuns() -> (Tunn, Tunn) {
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key);
let my_idx = OsRng.next_u32();
let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key);
let their_idx = OsRng.next_u32();
let my_tun = Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None);
let their_tun = Tunn::new(their_secret_key, my_public_key, None, None, their_idx, None);
(my_tun, their_tun)
}
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
sent
} else {
unreachable!();
};
handshake_init.into()
}
fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst);
assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_)));
let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp {
sent
} else {
unreachable!();
};
handshake_resp.into()
}
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
assert!(matches!(keepalive, TunnResult::WriteToNetwork(_)));
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
sent
} else {
unreachable!();
};
keepalive.into()
}
fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) {
let mut dst = vec![0u8; 2048];
let keepalive = tun.decapsulate(None, keepalive, &mut dst);
assert!(matches!(keepalive, TunnResult::Done));
}
fn create_two_tuns_and_handshake() -> (Tunn, Tunn) {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, &init);
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
parse_keepalive(&mut their_tun, &keepalive);
(my_tun, their_tun)
}
fn create_ipv4_udp_packet() -> Vec<u8> {
let header =
etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23);
let payload = [0, 1, 2, 3];
let mut packet = Vec::<u8>::with_capacity(header.size(payload.len()));
header.write(&mut packet, &payload).unwrap();
packet
}
#[cfg(feature = "mock-instant")]
fn update_timer_results_in_handshake(tun: &mut Tunn) {
let mut dst = vec![0u8; 2048];
let result = tun.update_timers(&mut dst);
assert!(matches!(result, TunnResult::WriteToNetwork(_)));
let packet_data = if let TunnResult::WriteToNetwork(data) = result {
data
} else {
unreachable!();
};
let packet = Tunn::parse_incoming_packet(packet_data).unwrap();
assert!(matches!(packet, Packet::HandshakeInit(_)));
}
#[test]
fn create_two_tunnels_linked_to_eachother() {
let (_my_tun, _their_tun) = create_two_tuns();
}
#[test]
fn handshake_init() {
let (mut my_tun, _their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let packet = Tunn::parse_incoming_packet(&init).unwrap();
assert!(matches!(packet, Packet::HandshakeInit(_)));
}
#[test]
fn handshake_init_and_response() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, &init);
let packet = Tunn::parse_incoming_packet(&resp).unwrap();
assert!(matches!(packet, Packet::HandshakeResponse(_)));
}
#[test]
fn full_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, &init);
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
let packet = Tunn::parse_incoming_packet(&keepalive).unwrap();
assert!(matches!(packet, Packet::PacketData(_)));
}
#[test]
fn full_handshake_plus_timers() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
// Time has not yet advanced so their is nothing to do
assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done));
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
}
#[test]
#[cfg(feature = "mock-instant")]
fn new_handshake_after_two_mins() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
let mut my_dst = [0u8; 1024];
// Advance time 1 second and "send" 1 packet so that we send a handshake
// after the timeout
mock_instant::MockClock::advance(Duration::from_secs(1));
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
assert!(matches!(
my_tun.update_timers(&mut my_dst),
TunnResult::Done
));
let sent_packet_buf = create_ipv4_udp_packet();
let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst);
assert!(matches!(data, TunnResult::WriteToNetwork(_)));
//Advance to timeout
mock_instant::MockClock::advance(REKEY_AFTER_TIME);
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
update_timer_results_in_handshake(&mut my_tun);
}
#[test]
#[cfg(feature = "mock-instant")]
fn handshake_no_resp_rekey_timeout() {
let (mut my_tun, _their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let packet = Tunn::parse_incoming_packet(&init).unwrap();
assert!(matches!(packet, Packet::HandshakeInit(_)));
mock_instant::MockClock::advance(REKEY_TIMEOUT);
update_timer_results_in_handshake(&mut my_tun)
}
#[test]
fn one_ip_packet() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
let mut my_dst = [0u8; 1024];
let mut their_dst = [0u8; 1024];
let sent_packet_buf = create_ipv4_udp_packet();
let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst);
assert!(matches!(data, TunnResult::WriteToNetwork(_)));
let data = if let TunnResult::WriteToNetwork(sent) = data {
sent
} else {
unreachable!();
};
let data = their_tun.decapsulate(None, data, &mut their_dst);
assert!(matches!(data, TunnResult::WriteToTunnelV4(..)));
let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data {
recv
} else {
unreachable!();
};
assert_eq!(sent_packet_buf, recv_packet_buf);
}
}

View File

@@ -0,0 +1,193 @@
use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24};
use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1};
use crate::noise::{HandshakeInit, HandshakeResponse, Packet, Tunn, TunnResult, WireGuardError};
#[cfg(feature = "mock-instant")]
use mock_instant::Instant;
use std::net::IpAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(not(feature = "mock-instant"))]
use crate::sleepyinstant::Instant;
use aead::generic_array::GenericArray;
use aead::{AeadInPlace, KeyInit};
use chacha20poly1305::{Key, XChaCha20Poly1305};
use parking_lot::Mutex;
use rand_core::{OsRng, RngCore};
use ring::constant_time::verify_slices_are_equal;
const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division
const COOKIE_SIZE: usize = 16;
const COOKIE_NONCE_SIZE: usize = 24;
/// How often should reset count in seconds
const RESET_PERIOD: u64 = 1;
type Cookie = [u8; COOKIE_SIZE];
/// There are two places where WireGuard requires "randomness" for cookies
/// * The 24 byte nonce in the cookie massage - here the only goal is to avoid nonce reuse
/// * A secret value that changes every two minutes
/// Because the main goal of the cookie is simply for a party to prove ownership of an IP address
/// we can relax the randomness definition a bit, in order to avoid locking, because using less
/// resources is the main goal of any DoS prevention mechanism.
/// In order to avoid locking and calls to rand we derive pseudo random values using the AEAD and
/// some counters.
pub struct RateLimiter {
/// The key we use to derive the nonce
nonce_key: [u8; 32],
/// The key we use to derive the cookie
secret_key: [u8; 16],
start_time: Instant,
/// A single 64 bit counter (should suffice for many years)
nonce_ctr: AtomicUsize,
mac1_key: [u8; 32],
cookie_key: Key,
limit: usize,
/// The counter since last reset
count: AtomicUsize,
/// The time last reset was performed on this rate limiter
last_reset: Mutex<Instant>,
}
impl RateLimiter {
pub fn new(public_key: &crate::x25519::PublicKey, limit: u64) -> Self {
let mut secret_key = [0u8; 16];
OsRng.fill_bytes(&mut secret_key);
RateLimiter {
nonce_key: Self::rand_bytes(),
secret_key,
start_time: Instant::now(),
nonce_ctr: AtomicUsize::new(0),
mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()),
cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(),
limit: limit as _,
count: AtomicUsize::new(0),
last_reset: Mutex::new(Instant::now()),
}
}
fn rand_bytes() -> [u8; 32] {
let mut key = [0u8; 32];
OsRng.fill_bytes(&mut key);
key
}
/// Reset packet count (ideally should be called with a period of 1 second)
pub fn reset_count(&self) {
// The rate limiter is not very accurate, but at the scale we care about it doesn't matter much
let current_time = Instant::now();
let mut last_reset_time = self.last_reset.lock();
if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD {
self.count.store(0, Ordering::SeqCst);
*last_reset_time = current_time;
}
}
/// Compute the correct cookie value based on the current secret value and the source IP
fn current_cookie(&self, addr: IpAddr) -> Cookie {
let mut addr_bytes = [0u8; 16];
match addr {
IpAddr::V4(a) => addr_bytes[..4].copy_from_slice(&a.octets()[..]),
IpAddr::V6(a) => addr_bytes[..].copy_from_slice(&a.octets()[..]),
}
// The current cookie for a given IP is the MAC(responder.changing_secret_every_two_minutes, initiator.ip_address)
// First we derive the secret from the current time, the value of cur_counter would change with time.
let cur_counter = Instant::now().duration_since(self.start_time).as_secs() / COOKIE_REFRESH;
// Next we derive the cookie
b2s_keyed_mac_16_2(&self.secret_key, &cur_counter.to_le_bytes(), &addr_bytes)
}
fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] {
let ctr = self.nonce_ctr.fetch_add(1, Ordering::Relaxed);
b2s_mac_24(&self.nonce_key, &ctr.to_le_bytes())
}
fn is_under_load(&self) -> bool {
self.count.fetch_add(1, Ordering::SeqCst) >= self.limit
}
pub(crate) fn format_cookie_reply<'a>(
&self,
idx: u32,
cookie: Cookie,
mac1: &[u8],
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::COOKIE_REPLY_SZ {
return Err(WireGuardError::DestinationBufferTooSmall);
}
let (message_type, rest) = dst.split_at_mut(4);
let (receiver_index, rest) = rest.split_at_mut(4);
let (nonce, rest) = rest.split_at_mut(24);
let (encrypted_cookie, _) = rest.split_at_mut(16 + 16);
// msg.message_type = 3
// msg.reserved_zero = { 0, 0, 0 }
message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes());
// msg.receiver_index = little_endian(initiator.sender_index)
receiver_index.copy_from_slice(&idx.to_le_bytes());
nonce.copy_from_slice(&self.nonce()[..]);
let cipher = XChaCha20Poly1305::new(&self.cookie_key);
let iv = GenericArray::from_slice(nonce);
encrypted_cookie[..16].copy_from_slice(&cookie);
let tag = cipher
.encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16])
.map_err(|_| WireGuardError::DestinationBufferTooSmall)?;
encrypted_cookie[16..].copy_from_slice(&tag);
Ok(&mut dst[..super::COOKIE_REPLY_SZ])
}
/// Verify the MAC fields on the datagram, and apply rate limiting if needed
pub fn verify_packet<'a, 'b>(
&self,
src_addr: Option<IpAddr>,
src: &'a [u8],
dst: &'b mut [u8],
) -> Result<Packet<'a>, TunnResult<'b>> {
let packet = Tunn::parse_incoming_packet(src)?;
// Verify and rate limit handshake messages only
if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
| Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
{
let (msg, macs) = src.split_at(src.len() - 32);
let (mac1, mac2) = macs.split_at(16);
let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg);
verify_slices_are_equal(&computed_mac1[..16], mac1)
.map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?;
if self.is_under_load() {
let addr = match src_addr {
None => return Err(TunnResult::Err(WireGuardError::UnderLoad)),
Some(addr) => addr,
};
// Only given an address can we validate mac2
let cookie = self.current_cookie(addr);
let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1);
if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() {
let cookie_packet = self
.format_cookie_reply(sender_idx, cookie, mac1, dst)
.map_err(TunnResult::Err)?;
return Err(TunnResult::WriteToNetwork(cookie_packet));
}
}
}
Ok(packet)
}
}

View File

@@ -0,0 +1,329 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::PacketData;
use crate::noise::errors::WireGuardError;
use parking_lot::Mutex;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use std::sync::atomic::{AtomicUsize, Ordering};
pub struct Session {
pub(crate) receiving_index: u32,
sending_index: u32,
receiver: LessSafeKey,
sender: LessSafeKey,
sending_key_counter: AtomicUsize,
receiving_key_counter: Mutex<ReceivingKeyCounterValidator>,
}
impl std::fmt::Debug for Session {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"Session: {}<- ->{}",
self.receiving_index, self.sending_index
)
}
}
/// Where encrypted data resides in a data packet
const DATA_OFFSET: usize = 16;
/// The overhead of the AEAD
const AEAD_SIZE: usize = 16;
// Receiving buffer constants
const WORD_SIZE: u64 = 64;
const N_WORDS: u64 = 16; // Suffice to reorder 64*16 = 1024 packets; can be increased at will
const N_BITS: u64 = WORD_SIZE * N_WORDS;
#[derive(Debug, Clone, Default)]
struct ReceivingKeyCounterValidator {
/// In order to avoid replays while allowing for some reordering of the packets, we keep a
/// bitmap of received packets, and the value of the highest counter
next: u64,
/// Used to estimate packet loss
receive_cnt: u64,
bitmap: [u64; N_WORDS as usize],
}
impl ReceivingKeyCounterValidator {
#[inline(always)]
fn set_bit(&mut self, idx: u64) {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
let bit = (bit_idx % WORD_SIZE) as usize;
self.bitmap[word] |= 1 << bit;
}
#[inline(always)]
fn clear_bit(&mut self, idx: u64) {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
let bit = (bit_idx % WORD_SIZE) as usize;
self.bitmap[word] &= !(1u64 << bit);
}
/// Clear the word that contains idx
#[inline(always)]
fn clear_word(&mut self, idx: u64) {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
self.bitmap[word] = 0;
}
/// Returns true if bit is set, false otherwise
#[inline(always)]
fn check_bit(&self, idx: u64) -> bool {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
let bit = (bit_idx % WORD_SIZE) as usize;
((self.bitmap[word] >> bit) & 1) == 1
}
/// Returns true if the counter was not yet received, and is not too far back
#[inline(always)]
fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
if counter >= self.next {
// As long as the counter is growing no replay took place for sure
return Ok(());
}
if counter + N_BITS < self.next {
// Drop if too far back
return Err(WireGuardError::InvalidCounter);
}
if !self.check_bit(counter) {
Ok(())
} else {
Err(WireGuardError::DuplicateCounter)
}
}
/// Marks the counter as received, and returns true if it is still good (in case during
/// decryption something changed)
#[inline(always)]
fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
if counter + N_BITS < self.next {
// Drop if too far back
return Err(WireGuardError::InvalidCounter);
}
if counter == self.next {
// Usually the packets arrive in order, in that case we simply mark the bit and
// increment the counter
self.set_bit(counter);
self.next += 1;
return Ok(());
}
if counter < self.next {
// A packet arrived out of order, check if it is valid, and mark
if self.check_bit(counter) {
return Err(WireGuardError::InvalidCounter);
}
self.set_bit(counter);
return Ok(());
}
// Packets where dropped, or maybe reordered, skip them and mark unused
if counter - self.next >= N_BITS {
// Too far ahead, clear all the bits
for c in self.bitmap.iter_mut() {
*c = 0;
}
} else {
let mut i = self.next;
while i % WORD_SIZE != 0 && i < counter {
// Clear until i aligned to word size
self.clear_bit(i);
i += 1;
}
while i + WORD_SIZE < counter {
// Clear whole word at a time
self.clear_word(i);
i = (i + WORD_SIZE) & 0u64.wrapping_sub(WORD_SIZE);
}
while i < counter {
// Clear any remaining bits
self.clear_bit(i);
i += 1;
}
}
self.set_bit(counter);
self.next = counter + 1;
Ok(())
}
}
impl Session {
pub(super) fn new(
local_index: u32,
peer_index: u32,
receiving_key: [u8; 32],
sending_key: [u8; 32],
) -> Session {
Session {
receiving_index: local_index,
sending_index: peer_index,
receiver: LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(),
),
sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()),
sending_key_counter: AtomicUsize::new(0),
receiving_key_counter: Mutex::new(Default::default()),
}
}
pub(super) fn local_index(&self) -> usize {
self.receiving_index as usize
}
/// Returns true if receiving counter is good to use
fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> {
let counter_validator = self.receiving_key_counter.lock();
counter_validator.will_accept(counter)
}
/// Returns true if receiving counter is good to use, and marks it as used {
fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> {
let mut counter_validator = self.receiving_key_counter.lock();
let ret = counter_validator.mark_did_receive(counter);
if ret.is_ok() {
counter_validator.receive_cnt += 1;
}
ret
}
/// src - an IP packet from the interface
/// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network
/// returns the size of the formatted packet
pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] {
if dst.len() < src.len() + super::DATA_OVERHEAD_SZ {
panic!("The destination buffer is too small");
}
let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64;
let (message_type, rest) = dst.split_at_mut(4);
let (receiver_index, rest) = rest.split_at_mut(4);
let (counter, data) = rest.split_at_mut(8);
message_type.copy_from_slice(&super::DATA.to_le_bytes());
receiver_index.copy_from_slice(&self.sending_index.to_le_bytes());
counter.copy_from_slice(&sending_key_counter.to_le_bytes());
// TODO: spec requires padding to 16 bytes, but actually works fine without it
let n = {
let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes());
data[..src.len()].copy_from_slice(src);
self.sender
.seal_in_place_separate_tag(
Nonce::assume_unique_for_key(nonce),
Aad::from(&[]),
&mut data[..src.len()],
)
.map(|tag| {
data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref());
src.len() + AEAD_SIZE
})
.unwrap()
};
&mut dst[..DATA_OFFSET + n]
}
/// packet - a data packet we received from the network
/// dst - pre-allocated space to hold the encapsulated IP packet, to send to the interface
/// dst will always take less space than src
/// return the size of the encapsulated packet on success
pub(super) fn receive_packet_data<'a>(
&self,
packet: PacketData,
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
let ct_len = packet.encrypted_encapsulated_packet.len();
if dst.len() < ct_len {
// This is a very incorrect use of the library, therefore panic and not error
panic!("The destination buffer is too small");
}
if packet.receiver_idx != self.receiving_index {
return Err(WireGuardError::WrongIndex);
}
// Don't reuse counters, in case this is a replay attack we want to quickly check the counter without running expensive decryption
self.receiving_counter_quick_check(packet.counter)?;
let ret = {
let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes());
dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet);
self.receiver
.open_in_place(
Nonce::assume_unique_for_key(nonce),
Aad::from(&[]),
&mut dst[..ct_len],
)
.map_err(|_| WireGuardError::InvalidAeadTag)?
};
// After decryption is done, check counter again, and mark as received
self.receiving_counter_mark(packet.counter)?;
Ok(ret)
}
/// Returns the estimated downstream packet loss for this session
pub(super) fn current_packet_cnt(&self) -> (u64, u64) {
let counter_validator = self.receiving_key_counter.lock();
(counter_validator.next, counter_validator.receive_cnt)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_replay_counter() {
let mut c: ReceivingKeyCounterValidator = Default::default();
assert!(c.mark_did_receive(0).is_ok());
assert!(c.mark_did_receive(0).is_err());
assert!(c.mark_did_receive(1).is_ok());
assert!(c.mark_did_receive(1).is_err());
assert!(c.mark_did_receive(63).is_ok());
assert!(c.mark_did_receive(63).is_err());
assert!(c.mark_did_receive(15).is_ok());
assert!(c.mark_did_receive(15).is_err());
for i in 64..N_BITS + 128 {
assert!(c.mark_did_receive(i).is_ok());
assert!(c.mark_did_receive(i).is_err());
}
assert!(c.mark_did_receive(N_BITS * 3).is_ok());
for i in 0..=N_BITS * 2 {
assert!(matches!(
c.will_accept(i),
Err(WireGuardError::InvalidCounter)
));
assert!(c.mark_did_receive(i).is_err());
}
for i in N_BITS * 2 + 1..N_BITS * 3 {
assert!(c.will_accept(i).is_ok());
}
assert!(matches!(
c.will_accept(N_BITS * 3),
Err(WireGuardError::DuplicateCounter)
));
for i in (N_BITS * 2 + 1..N_BITS * 3).rev() {
assert!(c.mark_did_receive(i).is_ok());
assert!(c.mark_did_receive(i).is_err());
}
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 72 + 125).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 63).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_err());
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_err());
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_err());
}
}

View File

@@ -0,0 +1,335 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
use super::errors::WireGuardError;
use crate::noise::{Tunn, TunnResult};
use std::mem;
use std::ops::{Index, IndexMut};
use std::time::Duration;
#[cfg(feature = "mock-instant")]
use mock_instant::Instant;
#[cfg(not(feature = "mock-instant"))]
use crate::sleepyinstant::Instant;
// Some constants, represent time in seconds
// https://www.wireguard.com/papers/wireguard.pdf#page=14
pub(crate) const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120);
#[derive(Debug)]
pub enum TimerName {
/// Current time, updated each call to `update_timers`
TimeCurrent,
/// Time when last handshake was completed
TimeSessionEstablished,
/// Time the last attempt for a new handshake began
TimeLastHandshakeStarted,
/// Time we last received and authenticated a packet
TimeLastPacketReceived,
/// Time we last send a packet
TimeLastPacketSent,
/// Time we last received and authenticated a DATA packet
TimeLastDataPacketReceived,
/// Time we last send a DATA packet
TimeLastDataPacketSent,
/// Time we last received a cookie
TimeCookieReceived,
/// Time we last sent persistent keepalive
TimePersistentKeepalive,
Top,
}
use self::TimerName::*;
#[derive(Debug)]
pub struct Timers {
/// Is the owner of the timer the initiator or the responder for the last handshake?
is_initiator: bool,
/// Start time of the tunnel
time_started: Instant,
timers: [Duration; TimerName::Top as usize],
pub(super) session_timers: [Duration; super::N_SESSIONS],
/// Did we receive data without sending anything back?
want_keepalive: bool,
/// Did we send data without hearing back?
want_handshake: bool,
persistent_keepalive: usize,
/// Should this timer call reset rr function (if not a shared rr instance)
pub(super) should_reset_rr: bool,
}
impl Timers {
pub(super) fn new(persistent_keepalive: Option<u16>, reset_rr: bool) -> Timers {
Timers {
is_initiator: false,
time_started: Instant::now(),
timers: Default::default(),
session_timers: Default::default(),
want_keepalive: Default::default(),
want_handshake: Default::default(),
persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)),
should_reset_rr: reset_rr,
}
}
fn is_initiator(&self) -> bool {
self.is_initiator
}
// We don't really clear the timers, but we set them to the current time to
// so the reference time frame is the same
pub(super) fn clear(&mut self) {
let now = Instant::now().duration_since(self.time_started);
for t in &mut self.timers[..] {
*t = now;
}
self.want_handshake = false;
self.want_keepalive = false;
}
}
impl Index<TimerName> for Timers {
type Output = Duration;
fn index(&self, index: TimerName) -> &Duration {
&self.timers[index as usize]
}
}
impl IndexMut<TimerName> for Timers {
fn index_mut(&mut self, index: TimerName) -> &mut Duration {
&mut self.timers[index as usize]
}
}
impl Tunn {
pub(super) fn timer_tick(&mut self, timer_name: TimerName) {
match timer_name {
TimeLastPacketReceived => {
self.timers.want_keepalive = true;
self.timers.want_handshake = false;
}
TimeLastPacketSent => {
self.timers.want_handshake = true;
self.timers.want_keepalive = false;
}
_ => {}
}
let time = self.timers[TimeCurrent];
self.timers[timer_name] = time;
}
pub(super) fn timer_tick_session_established(
&mut self,
is_initiator: bool,
session_idx: usize,
) {
self.timer_tick(TimeSessionEstablished);
self.timers.session_timers[session_idx % crate::noise::N_SESSIONS] =
self.timers[TimeCurrent];
self.timers.is_initiator = is_initiator;
}
// We don't really clear the timers, but we set them to the current time to
// so the reference time frame is the same
fn clear_all(&mut self) {
for session in &mut self.sessions {
*session = None;
}
self.packet_queue.clear();
self.timers.clear();
}
fn update_session_timers(&mut self, time_now: Duration) {
let timers = &mut self.timers;
for (i, t) in timers.session_timers.iter_mut().enumerate() {
if time_now - *t > REJECT_AFTER_TIME {
if let Some(session) = self.sessions[i].take() {
tracing::debug!(
message = "SESSION_EXPIRED(REJECT_AFTER_TIME)",
session = session.receiving_index
);
}
*t = time_now;
}
}
}
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
let mut handshake_initiation_required = false;
let mut keepalive_required = false;
let time = Instant::now();
if self.timers.should_reset_rr {
self.rate_limiter.reset_count();
}
// All the times are counted from tunnel initiation, for efficiency our timers are rounded
// to a second, as there is no real benefit to having highly accurate timers.
let now = time.duration_since(self.timers.time_started);
self.timers[TimeCurrent] = now;
self.update_session_timers(now);
// Load timers only once:
let session_established = self.timers[TimeSessionEstablished];
let handshake_started = self.timers[TimeLastHandshakeStarted];
let aut_packet_received = self.timers[TimeLastPacketReceived];
let aut_packet_sent = self.timers[TimeLastPacketSent];
let data_packet_received = self.timers[TimeLastDataPacketReceived];
let data_packet_sent = self.timers[TimeLastDataPacketSent];
let persistent_keepalive = self.timers.persistent_keepalive;
{
if self.handshake.is_expired() {
return TunnResult::Err(WireGuardError::ConnectionExpired);
}
// Clear cookie after COOKIE_EXPIRATION_TIME
if self.handshake.has_cookie()
&& now - self.timers[TimeCookieReceived] >= COOKIE_EXPIRATION_TIME
{
self.handshake.clear_cookie();
}
// All ephemeral private keys and symmetric session keys are zeroed out after
// (REJECT_AFTER_TIME * 3) ms if no new keys have been exchanged.
if now - session_established >= REJECT_AFTER_TIME * 3 {
tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)");
self.handshake.set_expired();
self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired);
}
if let Some(time_init_sent) = self.handshake.timer() {
// Handshake Initiation Retransmission
if now - handshake_started >= REKEY_ATTEMPT_TIME {
// After REKEY_ATTEMPT_TIME ms of trying to initiate a new handshake,
// the retries give up and cease, and clear all existing packets queued
// up to be sent. If a packet is explicitly queued up to be sent, then
// this timer is reset.
tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)");
self.handshake.set_expired();
self.clear_all();
return TunnResult::Err(WireGuardError::ConnectionExpired);
}
if time_init_sent.elapsed() >= REKEY_TIMEOUT {
// We avoid using `time` here, because it can be earlier than `time_init_sent`.
// Once `checked_duration_since` is stable we can use that.
// A handshake initiation is retried after REKEY_TIMEOUT + jitter ms,
// if a response has not been received, where jitter is some random
// value between 0 and 333 ms.
tracing::warn!("HANDSHAKE(REKEY_TIMEOUT)");
handshake_initiation_required = true;
}
} else {
if self.timers.is_initiator() {
// After sending a packet, if the sender was the original initiator
// of the handshake and if the current session key is REKEY_AFTER_TIME
// ms old, we initiate a new handshake. If the sender was the original
// responder of the handshake, it does not re-initiate a new handshake
// after REKEY_AFTER_TIME ms like the original initiator does.
if session_established < data_packet_sent
&& now - session_established >= REKEY_AFTER_TIME
{
tracing::debug!("HANDSHAKE(REKEY_AFTER_TIME (on send))");
handshake_initiation_required = true;
}
// After receiving a packet, if the receiver was the original initiator
// of the handshake and if the current session key is REJECT_AFTER_TIME
// - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT ms old, we initiate a new
// handshake.
if session_established < data_packet_received
&& now - session_established
>= REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
{
tracing::warn!(
"HANDSHAKE(REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - \
REKEY_TIMEOUT \
(on receive))"
);
handshake_initiation_required = true;
}
}
// If we have sent a packet to a given peer but have not received a
// packet after from that peer for (KEEPALIVE + REKEY_TIMEOUT) ms,
// we initiate a new handshake.
if data_packet_sent > aut_packet_received
&& now - aut_packet_received >= KEEPALIVE_TIMEOUT + REKEY_TIMEOUT
&& mem::replace(&mut self.timers.want_handshake, false)
{
tracing::warn!("HANDSHAKE(KEEPALIVE + REKEY_TIMEOUT)");
handshake_initiation_required = true;
}
if !handshake_initiation_required {
// If a packet has been received from a given peer, but we have not sent one back
// to the given peer in KEEPALIVE ms, we send an empty packet.
if data_packet_received > aut_packet_sent
&& now - aut_packet_sent >= KEEPALIVE_TIMEOUT
&& mem::replace(&mut self.timers.want_keepalive, false)
{
tracing::debug!("KEEPALIVE(KEEPALIVE_TIMEOUT)");
keepalive_required = true;
}
// Persistent KEEPALIVE
if persistent_keepalive > 0
&& (now - self.timers[TimePersistentKeepalive]
>= Duration::from_secs(persistent_keepalive as _))
{
tracing::debug!("KEEPALIVE(PERSISTENT_KEEPALIVE)");
self.timer_tick(TimePersistentKeepalive);
keepalive_required = true;
}
}
}
}
if handshake_initiation_required {
return self.format_handshake_initiation(dst, true);
}
if keepalive_required {
return self.encapsulate(&[], dst);
}
TunnResult::Done
}
pub fn time_since_last_handshake(&self) -> Option<Duration> {
let current_session = self.current;
if self.sessions[current_session % super::N_SESSIONS].is_some() {
let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started);
let duration_since_session_established = self.timers[TimeSessionEstablished];
Some(duration_since_tun_start - duration_since_session_established)
} else {
None
}
}
pub fn persistent_keepalive(&self) -> Option<u16> {
let keepalive = self.timers.persistent_keepalive;
if keepalive > 0 {
Some(keepalive as u16)
} else {
None
}
}
}

View File

@@ -0,0 +1,33 @@
pub(crate) struct KeyBytes(pub [u8; 32]);
impl std::str::FromStr for KeyBytes {
type Err = &'static str;
/// Can parse a secret key from a hex or base64 encoded string.
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut internal = [0u8; 32];
match s.len() {
64 => {
// Try to parse as hex
for i in 0..32 {
internal[i] = u8::from_str_radix(&s[i * 2..=i * 2 + 1], 16)
.map_err(|_| "Illegal character in key")?;
}
}
43 | 44 => {
// Try to parse as base64
if let Ok(decoded_key) = base64::decode(s) {
if decoded_key.len() == internal.len() {
internal[..].copy_from_slice(&decoded_key);
} else {
return Err("Illegal character in key");
}
}
}
_ => return Err("Illegal key size"),
}
Ok(KeyBytes(internal))
}
}

View File

@@ -0,0 +1,77 @@
#![forbid(unsafe_code)]
//! Attempts to provide the same functionality as std::time::Instant, except it
//! uses a timer which accounts for time when the system is asleep
use std::time::Duration;
#[cfg(target_os = "windows")]
mod windows;
#[cfg(target_os = "windows")]
use windows as inner;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
use unix as inner;
/// A measurement of a monotonically nondecreasing clock.
/// Opaque and useful only with [`Duration`].
///
/// Instants are always guaranteed, barring [platform bugs], to be no less than any previously
/// measured instant when created, and are often useful for tasks such as measuring
/// benchmarks or timing how long an operation takes.
///
/// Note, however, that instants are **not** guaranteed to be **steady**. In other
/// words, each tick of the underlying clock might not be the same length (e.g.
/// some seconds may be longer than others). An instant may jump forwards or
/// experience time dilation (slow down or speed up), but it will never go
/// backwards.
///
/// Instants are opaque types that can only be compared to one another. There is
/// no method to get "the number of seconds" from an instant. Instead, it only
/// allows measuring the duration between two instants (or comparing two
/// instants).
///
/// The size of an `Instant` struct may vary depending on the target operating
/// system.
///
#[derive(Clone, Copy, Debug)]
pub struct Instant {
t: inner::Instant,
}
impl Instant {
/// Returns an instant corresponding to "now".
pub fn now() -> Self {
Self {
t: inner::Instant::now(),
}
}
/// Returns the amount of time elapsed from another instant to this one,
/// or zero duration if that instant is later than this one.
///
/// # Panics
///
/// panics when `earlier` was later than `self`.
pub fn duration_since(&self, earlier: Instant) -> Duration {
self.t.duration_since(earlier.t)
}
/// Returns the amount of time elapsed since this instant was created.
pub fn elapsed(&self) -> Duration {
Self::now().duration_since(*self)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn time_increments_after_sleep() {
let sleep_time = Duration::from_millis(10);
let start = Instant::now();
std::thread::sleep(sleep_time);
assert!(start.elapsed() >= sleep_time);
}
}

View File

@@ -0,0 +1,48 @@
use std::time::Duration;
use nix::sys::time::TimeSpec;
use nix::time::{clock_gettime, ClockId};
#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))]
const CLOCK_ID: ClockId = ClockId::CLOCK_MONOTONIC;
#[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "tvos")))]
const CLOCK_ID: ClockId = ClockId::CLOCK_BOOTTIME;
#[derive(Clone, Copy, Debug)]
pub(crate) struct Instant {
t: TimeSpec,
}
impl Instant {
pub(crate) fn now() -> Self {
// std::time::Instant unwraps as well, so feel safe doing so here
let t = clock_gettime(CLOCK_ID).unwrap();
Self { t }
}
fn checked_duration_since(&self, earlier: Instant) -> Option<Duration> {
const NANOSECOND: nix::libc::c_long = 1_000_000_000;
let (tv_sec, tv_nsec) = if self.t.tv_nsec() < earlier.t.tv_nsec() {
(
self.t.tv_sec() - earlier.t.tv_sec() - 1,
self.t.tv_nsec() - earlier.t.tv_nsec() + NANOSECOND,
)
} else {
(
self.t.tv_sec() - earlier.t.tv_sec(),
self.t.tv_nsec() - earlier.t.tv_nsec(),
)
};
if tv_sec < 0 {
None
} else {
Some(Duration::new(tv_sec as _, tv_nsec as _))
}
}
pub(crate) fn duration_since(&self, earlier: Instant) -> Duration {
self.checked_duration_since(earlier)
.unwrap_or(Duration::ZERO)
}
}

View File

@@ -0,0 +1 @@
pub(crate) use std::time::Instant;

View File

@@ -0,0 +1,106 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
#pragma once
#include <stdint.h>
#include <stdbool.h>
struct wireguard_tunnel; // This corresponds to the Rust type
enum
{
MAX_WIREGUARD_PACKET_SIZE = 65536 + 64,
};
enum result_type
{
WIREGUARD_DONE = 0,
WRITE_TO_NETWORK = 1,
WIREGUARD_ERROR = 2,
WRITE_TO_TUNNEL_IPV4 = 4,
WRITE_TO_TUNNEL_IPV6 = 6,
};
struct wireguard_result
{
enum result_type op;
size_t size;
};
struct stats
{
int64_t time_since_last_handshake;
size_t tx_bytes;
size_t rx_bytes;
float estimated_loss;
int32_t estimated_rtt; // rtt estimated on time it took to complete latest initiated handshake in ms
uint8_t reserved[56]; // decrement appropriately when adding new fields
};
struct x25519_key
{
uint8_t key[32];
};
// Generates a fresh x25519 secret key
struct x25519_key x25519_secret_key();
// Computes an x25519 public key from a secret key
struct x25519_key x25519_public_key(struct x25519_key private_key);
// Encodes a public or private x25519 key to base64. Must be freed with x25519_key_to_str_free.
const char *x25519_key_to_base64(struct x25519_key key);
// Encodes a public or private x25519 key to hex. Must be freed with x25519_key_to_str_free.
const char *x25519_key_to_hex(struct x25519_key key);
// Free string pointer obtained from either x25519_key_to_base64 or x25519_key_to_hex
void x25519_key_to_str_free(const char *key_str);
// Check if a null terminated string represents a valid x25519 key
// Returns 0 if not
int check_base64_encoded_x25519_key(const char *key);
/// Sets the default tracing_subscriber to write to `log_func`.
///
/// Uses Compact format without level, target, thread ids, thread names, or ansi control characters.
/// Subscribes to TRACE level events.
///
/// This function should only be called once as setting the default tracing_subscriber
/// more than once will result in an error.
///
/// Returns false on failure.
///
/// # Safety
///
/// `c_char` will be freed by the library after calling `log_func`. If the value needs
/// to be stored then `log_func` needs to create a copy, e.g. `strcpy`.
bool set_logging_function(void (*log_func)(const char *));
// Allocate a new tunnel
struct wireguard_tunnel *new_tunnel(const char *static_private,
const char *server_static_public,
const char *preshared_key,
uint16_t keep_alive, // Keep alive interval in seconds
uint32_t index); // The 24bit index prefix to be used for session indexes
// Deallocate the tunnel
void tunnel_free(struct wireguard_tunnel *);
struct wireguard_result wireguard_write(const struct wireguard_tunnel *tunnel,
const uint8_t *src,
uint32_t src_size,
uint8_t *dst,
uint32_t dst_size);
struct wireguard_result wireguard_read(const struct wireguard_tunnel *tunnel,
const uint8_t *src,
uint32_t src_size,
uint8_t *dst,
uint32_t dst_size);
struct wireguard_result wireguard_tick(const struct wireguard_tunnel *tunnel,
uint8_t *dst,
uint32_t dst_size);
struct wireguard_result wireguard_force_handshake(const struct wireguard_tunnel *tunnel,
uint8_t *dst,
uint32_t dst_size);
struct stats wireguard_stats(const struct wireguard_tunnel *tunnel);

View File

@@ -24,6 +24,7 @@ message RegistrationRequest {
fixed32 virtual_ip = 6;
bool allow_ip_change = 7;
bool client_secret = 8;
bytes client_secret_hash = 9;
}
message RegistrationResponse {
@@ -41,6 +42,8 @@ message DeviceInfo {
fixed32 virtual_ip = 2;
uint32 device_status = 3;
bool client_secret = 4;
bytes client_secret_hash = 5;
bool wireguard = 6;
}
message DeviceList {

View File

@@ -89,7 +89,7 @@ impl RsaCipher {
})
}
pub fn finger_(public_key_der: &[u8]) -> io::Result<String> {
match rsa::pkcs8::SubjectPublicKeyInfo::from_der(public_key_der) {
match spki::SubjectPublicKeyInfoOwned::from_der(public_key_der) {
Ok(spki) => match spki.fingerprint_base64() {
Ok(finger) => Ok(finger),
Err(e) => Err(io::Error::new(
@@ -120,7 +120,7 @@ impl RsaCipher {
match self
.inner
.private_key
.decrypt(rsa::PaddingScheme::PKCS1v15Encrypt, net_packet.payload())
.decrypt(rsa::Pkcs1v15Encrypt, net_packet.payload())
{
Ok(rs) => {
let mut nonce_raw = [0; 12];

View File

@@ -1,10 +1,23 @@
use chrono::{DateTime, Local};
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr};
use chrono::{DateTime, Local};
use tokio::sync::mpsc::Sender;
#[derive(Clone, Debug)]
pub struct WireGuardConfig {
pub vnts_endpoint: String,
pub vnts_allowed_ips: String,
pub group_id: String,
pub device_id: String,
pub ip: Ipv4Addr,
pub prefix: u8,
pub persistent_keepalive: u16,
pub secret_key: [u8; 32],
pub public_key: [u8; 32],
}
/// 网段信息
#[derive(Default)]
#[derive(Default, Debug)]
pub struct NetworkInfo {
// 组网编号
// pub group: String,
@@ -33,6 +46,7 @@ impl NetworkInfo {
}
/// 客户端信息
#[derive(Debug)]
pub struct ClientInfo {
// 设备ID
pub device_id: String,
@@ -42,6 +56,8 @@ pub struct ClientInfo {
pub name: String,
// 客户端间是否加密
pub client_secret: bool,
// 加密hash
pub client_secret_hash: Vec<u8>,
// 和服务端是否加密
pub server_secret: bool,
// 链接服务器的来源地址
@@ -52,11 +68,51 @@ pub struct ClientInfo {
pub virtual_ip: u32,
// 建立的tcp连接发送端
pub tcp_sender: Option<Sender<Vec<u8>>>,
// wireguard客户端公钥
pub wireguard: Option<[u8; 32]>,
pub wg_sender: Option<Sender<(Vec<u8>, Ipv4Addr)>>,
pub client_status: Option<ClientStatusInfo>,
pub last_join_time: DateTime<Local>,
pub timestamp: i64,
}
/// 客户端简要信息
#[derive(Debug)]
pub struct SimpleClientInfo {
// 分配的ip
pub virtual_ip: u32,
// 版本
pub version: String,
// 名称
pub name: String,
// 客户端间是否加密
pub client_secret: bool,
// 加密hash
pub client_secret_hash: Vec<u8>,
// 和服务端是否加密
pub server_secret: bool,
// 是否在线
pub online: bool,
// 是wg客户端
pub wireguard: bool,
}
impl From<&ClientInfo> for SimpleClientInfo {
fn from(value: &ClientInfo) -> Self {
Self {
virtual_ip: value.virtual_ip,
version: value.version.clone(),
name: value.name.clone(),
client_secret: value.client_secret,
client_secret_hash: if value.online {
value.client_secret_hash.clone()
} else {
vec![]
},
server_secret: value.server_secret,
online: value.online,
wireguard: value.wireguard.is_some(),
}
}
}
impl Default for ClientInfo {
fn default() -> Self {
Self {
@@ -64,18 +120,21 @@ impl Default for ClientInfo {
version: "".to_string(),
name: "".to_string(),
client_secret: false,
client_secret_hash: vec![],
server_secret: false,
address: "0.0.0.0:0".parse().unwrap(),
online: false,
virtual_ip: 0,
tcp_sender: None,
wireguard: None,
wg_sender: None,
client_status: None,
last_join_time: Local::now(),
timestamp: 0,
}
}
}
#[derive(Debug)]
pub struct ClientStatusInfo {
pub p2p_list: Vec<Ipv4Addr>,
pub up_stream: u64,

View File

@@ -4,6 +4,7 @@ use std::sync::Arc;
use tokio::net::{TcpListener, UdpSocket};
use crate::cipher::RsaCipher;
use crate::core::server::wire_guard::WireGuardGroup;
use crate::core::service::PacketHandler;
use crate::core::store::cache::AppCache;
use crate::ConfigInfo;
@@ -13,6 +14,7 @@ mod udp;
#[cfg(feature = "web")]
mod web;
mod websocket;
mod wire_guard;
pub async fn start(
udp: std::net::UdpSocket,
@@ -29,8 +31,9 @@ pub async fn start(
rsa_cipher.clone(),
udp.clone(),
);
let wg = WireGuardGroup::new(cache.clone(), config.clone(), udp.clone());
let tcp_handle = tokio::spawn(tcp::start(TcpListener::from_std(tcp)?, handler.clone()));
let udp_handle = tokio::spawn(udp::start(udp, handler.clone()));
let udp_handle = tokio::spawn(udp::start(udp, handler.clone(), wg));
#[cfg(not(feature = "web"))]
let _ = tokio::try_join!(tcp_handle, udp_handle);
#[cfg(feature = "web")]

View File

@@ -1,4 +1,5 @@
use crate::core::service::PacketHandler;
use crate::core::store::cache::VntContext;
use crate::protocol::NetPacket;
use std::io;
use std::net::SocketAddr;
@@ -73,21 +74,29 @@ async fn stream_handle(stream: TcpStream, addr: SocketAddr, handler: PacketHandl
let _ = w.shutdown().await;
});
tokio::spawn(async move {
if let Err(e) = tcp_read(r, addr, sender, handler).await {
let mut context = VntContext {
link_context: None,
server_cipher: None,
link_address: addr,
};
if let Err(e) = tcp_read(&mut context, r, addr, sender, &handler).await {
log::warn!("tcp_read {:?}", e)
}
handler.leave(context).await;
});
}
async fn tcp_read(
context: &mut VntContext,
mut read: OwnedReadHalf,
addr: SocketAddr,
sender: Sender<Vec<u8>>,
handler: PacketHandler,
handler: &PacketHandler,
) -> io::Result<()> {
let mut head = [0; 4];
let mut buf = [0; 65536];
let sender = Some(sender);
loop {
read.read_exact(&mut head).await?;
if head[0] != 0 {
@@ -103,7 +112,7 @@ async fn tcp_read(
}
read.read_exact(&mut buf[..len]).await?;
let packet = NetPacket::new0(len, &mut buf)?;
if let Some(rs) = handler.handle(packet, addr, &sender).await {
if let Some(rs) = handler.handle(context, packet, addr, &sender).await {
if sender
.as_ref()
.unwrap()

View File

@@ -1,21 +1,91 @@
use std::sync::Arc;
use tokio::net::UdpSocket;
use crate::core::server::wire_guard::WireGuardGroup;
use crate::core::service::PacketHandler;
use crate::core::store::cache::VntContext;
use crate::protocol::NetPacket;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::sync::mpsc::{channel, Sender};
pub async fn start(main_udp: Arc<UdpSocket>, handler: PacketHandler, mut wg: WireGuardGroup) {
let mut udp_group = UdpGroup::new(main_udp.clone(), handler);
let mut buf = [0u8; 65536];
pub async fn start(main_udp: Arc<UdpSocket>, handler: PacketHandler) {
loop {
let mut buf = vec![0u8; 65536];
match main_udp.recv_from(&mut buf).await {
Ok((len, addr)) => {
let handler = handler.clone();
let udp = main_udp.clone();
tokio::spawn(async move {
match NetPacket::new(&mut buf[..len]) {
if len == 0 {
log::warn!("UnexpectedEof {}", addr);
continue;
}
let buf = buf[..len].to_vec();
if WireGuardGroup::maybe_wg(&buf) {
// 可能是wg协议
wg.handle(buf, addr);
continue;
}
if let Err(e) = udp_group.handle(buf, addr) {
log::error!("{} {:?}", addr, e);
}
}
#[cfg(windows)]
Err(ref e) if e.kind() == std::io::ErrorKind::ConnectionReset => {
// 忽略 ConnectionReset 错误
}
Err(e) => {
log::error!("{:?}", e)
}
}
}
}
pub struct UdpGroup {
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
udp: Arc<UdpSocket>,
handler: PacketHandler,
}
impl UdpGroup {
pub fn new(udp: Arc<UdpSocket>, handler: PacketHandler) -> Self {
Self {
data_channel_map: Default::default(),
udp,
handler,
}
}
pub fn handle(&mut self, buf: Vec<u8>, addr: SocketAddr) -> anyhow::Result<()> {
if let Some(sender) = self.data_channel_map.lock().get(&addr) {
sender.try_send(buf)?;
return Ok(());
}
let (udp_sender, mut udp_receiver) = channel(64);
udp_sender.try_send(buf)?;
let data_channel_map = self.data_channel_map.clone();
data_channel_map.lock().insert(addr, udp_sender);
let handler = self.handler.clone();
let udp = self.udp.clone();
tokio::spawn(async move {
let mut context = VntContext {
link_context: None,
server_cipher: None,
link_address: addr,
};
loop {
let data = match tokio::time::timeout(Duration::from_secs(60), udp_receiver.recv())
.await
{
Ok(data) => data,
Err(_) => break,
};
if let Some(data) = data {
match NetPacket::new(data) {
Ok(net_packet) => {
if let Some(rs) = handler.handle(net_packet, addr, &None).await {
if let Some(rs) =
handler.handle(&mut context, net_packet, addr, &None).await
{
if let Err(e) = udp.send_to(rs.buffer(), addr).await {
log::error!("{:?} {}", e, addr)
}
@@ -25,11 +95,13 @@ pub async fn start(main_udp: Arc<UdpSocket>, handler: PacketHandler) {
log::error!("{:?} {}", e, addr)
}
}
});
} else {
break;
}
}
Err(e) => {
log::error!("{:?}", e)
}
}
handler.leave(context).await;
data_channel_map.lock().remove(&addr);
});
Ok(())
}
}

View File

@@ -1,6 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::net;
use std::sync::Arc;
use actix_web::dev::Service;
use actix_web::web::Data;
@@ -9,7 +8,9 @@ use actix_web::{middleware, post, web, App, HttpRequest, HttpResponse, HttpServe
use actix_web_static_files::ResourceFiles;
use crate::core::server::web::service::VntsWebService;
use crate::core::server::web::vo::{LoginData, ResponseMessage};
use crate::core::server::web::vo::req::{CreateWGData, LoginData, RemoveClientReq};
use crate::core::server::web::vo::ResponseMessage;
use crate::core::store::cache::AppCache;
use crate::ConfigInfo;
@@ -18,7 +19,7 @@ mod vo;
include!(concat!(env!("OUT_DIR"), "/generated.rs"));
#[post("/login")]
#[post("/api/login")]
async fn login(service: Data<VntsWebService>, data: web::Json<LoginData>) -> HttpResponse {
match service.login(data.0).await {
Ok(auth) => HttpResponse::Ok().json(ResponseMessage::success(auth)),
@@ -26,13 +27,37 @@ async fn login(service: Data<VntsWebService>, data: web::Json<LoginData>) -> Htt
}
}
#[post("/group_list")]
#[post("/api/group_list")]
async fn group_list(_req: HttpRequest, service: Data<VntsWebService>) -> HttpResponse {
let info = service.group_list();
HttpResponse::Ok().json(ResponseMessage::success(info))
}
#[post("/group_info")]
#[post("/api/remove_client")]
async fn remove_client(
_req: HttpRequest,
service: Data<VntsWebService>,
data: web::Json<RemoveClientReq>,
) -> HttpResponse {
service.remove_client(data.0);
HttpResponse::Ok().json(ResponseMessage::success("success"))
}
#[post("/api/private_key")]
async fn private_key(_req: HttpRequest, service: Data<VntsWebService>) -> HttpResponse {
let private_key = service.gen_wg_private_key();
HttpResponse::Ok().json(ResponseMessage::success(private_key))
}
#[post("/api/create_wg_config")]
async fn create_wg_config(
_req: HttpRequest,
service: Data<VntsWebService>,
data: web::Json<CreateWGData>,
) -> HttpResponse {
match service.create_wg_config(data.0).await {
Ok(wg_config) => HttpResponse::Ok().json(ResponseMessage::success(wg_config)),
Err(e) => HttpResponse::Ok().json(ResponseMessage::fail(e.to_string())),
}
}
#[post("/api/group_info")]
async fn group_info(
_req: HttpRequest,
service: Data<VntsWebService>,
@@ -46,36 +71,19 @@ async fn group_info(
}
}
#[derive(Clone)]
struct AuthApi {
api_set: Arc<HashSet<String>>,
}
fn auth_api_set() -> AuthApi {
let mut api_set = HashSet::new();
api_set.insert("/group_info".to_string());
api_set.insert("/group_list".to_string());
AuthApi {
api_set: Arc::new(api_set),
}
}
pub async fn start(
lst: net::TcpListener,
cache: AppCache,
config: ConfigInfo,
) -> std::io::Result<()> {
let web_service = VntsWebService::new(cache, config);
let auth_api = auth_api_set();
HttpServer::new(move || {
let generated = generate();
App::new()
.app_data(Data::new(web_service.clone()))
.app_data(Data::new(auth_api.clone()))
.wrap_fn(|request, srv| {
let auth_api: &Data<AuthApi> = request.app_data().unwrap();
let path = request.path();
if path == "/login" || !auth_api.api_set.contains(path) {
if path == "/api/login" || !path.contains("/api/") {
return srv.call(request);
}
let service: &Data<VntsWebService> = request.app_data().unwrap();
@@ -96,6 +104,9 @@ pub async fn start(
})
.wrap(middleware::Compress::default())
.service(login)
.service(remove_client)
.service(private_key)
.service(create_wg_config)
.service(group_list)
.service(group_info)
.service(ResourceFiles::new("/", generated))

View File

@@ -1,11 +1,22 @@
use crossbeam_utils::atomic::AtomicCell;
use std::net::{SocketAddr, SocketAddrV4};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::core::server::web::vo::{
ClientInfo, ClientStatusInfo, GroupList, LoginData, NetworkInfo,
use anyhow::{anyhow, Context};
use base64::engine::general_purpose;
use base64::Engine;
use crossbeam_utils::atomic::AtomicCell;
use ipnetwork::Ipv4Network;
use rsa::rand_core::RngCore;
use crate::core::entity::WireGuardConfig;
use crate::core::server::web::vo::req::{CreateWGData, CreateWgConfig, LoginData, RemoveClientReq};
use crate::core::server::web::vo::res::{
ClientInfo, ClientStatusInfo, GroupList, NetworkInfo, WGData, WgConfig,
};
use crate::core::service::server::{generate_ip, RegisterClientRequest};
use crate::core::store::cache::AppCache;
use crate::ConfigInfo;
@@ -60,6 +71,113 @@ impl VntsWebService {
.collect();
GroupList { group_list }
}
pub fn remove_client(&self, req: RemoveClientReq) {
if let Some(ip) = req.virtual_ip {
if let Some(network_info) = self.cache.virtual_network.get(&req.group_id) {
if let Some(client_info) = network_info.write().clients.remove(&ip.into()) {
if let Some(key) = client_info.wireguard {
self.cache.wg_group_map.remove(&key);
}
}
}
} else {
if let Some(network_info) = self.cache.virtual_network.remove(&req.group_id) {
for (_, client_info) in network_info.write().clients.drain() {
if let Some(key) = client_info.wireguard {
self.cache.wg_group_map.remove(&key);
}
}
}
}
}
pub fn gen_wg_private_key(&self) -> String {
let mut bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut bytes);
return general_purpose::STANDARD.encode(bytes);
}
pub async fn create_wg_config(&self, wg_data: CreateWGData) -> anyhow::Result<WGData> {
let device_id = wg_data.device_id.trim().to_string();
let group_id = wg_data.group_id.trim().to_string();
if group_id.is_empty() {
Err(anyhow!("组网id不能为空"))?;
}
if device_id.is_empty() {
Err(anyhow!("设备id不能为空"))?;
}
let cache = &self.cache;
let (secret_key, public_key) = Self::check_wg_config(&wg_data.config)?;
let gateway = self.config.gateway;
let netmask = self.config.netmask;
let network = Ipv4Network::with_netmask(gateway, netmask)?;
let network = Ipv4Network::with_netmask(network.network(), netmask)?;
let virtual_ip = if wg_data.virtual_ip.trim().is_empty() {
Ipv4Addr::UNSPECIFIED
} else {
Ipv4Addr::from_str(&wg_data.virtual_ip).context("虚拟IP错误")?
};
let register_client_request = RegisterClientRequest {
group_id: group_id.clone(),
virtual_ip,
gateway,
netmask,
allow_ip_change: false,
device_id: device_id.clone(),
version: String::from("wg"),
name: wg_data.name.clone(),
client_secret: true,
client_secret_hash: vec![],
server_secret: true,
address: SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(),
tcp_sender: None,
online: false,
wireguard: Some(public_key),
};
let response = generate_ip(cache, register_client_request).await?;
let wireguard_config = WireGuardConfig {
vnts_endpoint: wg_data.config.vnts_endpoint.clone(),
vnts_allowed_ips: network.to_string(),
group_id: group_id.clone(),
device_id: device_id.clone(),
ip: response.virtual_ip,
prefix: network.prefix(),
persistent_keepalive: wg_data.config.persistent_keepalive,
secret_key,
public_key,
};
cache.wg_group_map.insert(public_key, wireguard_config);
let config = WgConfig {
vnts_endpoint: wg_data.config.vnts_endpoint,
vnts_public_key: general_purpose::STANDARD.encode(&self.config.wg_public_key),
vnts_allowed_ips: network.to_string(),
public_key: general_purpose::STANDARD.encode(public_key),
private_key: general_purpose::STANDARD.encode(secret_key),
ip: response.virtual_ip,
prefix: network.prefix(),
persistent_keepalive: wg_data.config.persistent_keepalive,
};
let wg_data = WGData {
group_id,
virtual_ip: response.virtual_ip,
device_id,
name: wg_data.name,
config,
};
Ok(wg_data)
}
fn check_wg_config(config: &CreateWgConfig) -> anyhow::Result<([u8; 32], [u8; 32])> {
let addr = SocketAddr::from_str(&config.vnts_endpoint).context("服务器地址错误")?;
if addr.ip().is_unspecified() || addr.port() == 0 {
Err(anyhow!("服务端地址错误"))?
}
let private_key = general_purpose::STANDARD
.decode(&config.private_key)
.context("私钥错误")?;
let private_key: [u8; 32] = private_key.try_into().map_err(|_| anyhow!("私钥错误"))?;
let secret_key = boringtun::x25519::StaticSecret::from(private_key);
let public_key = *boringtun::x25519::PublicKey::from(&secret_key).as_bytes();
Ok((private_key, public_key))
}
pub fn group_info(&self, group: String) -> Option<NetworkInfo> {
if let Some(info) = self.cache.virtual_network.get(&group) {
let guard = info.read();
@@ -67,19 +185,20 @@ impl VntsWebService {
guard.network_ip.into(),
guard.mask_ip.into(),
guard.gateway_ip.into(),
general_purpose::STANDARD.encode(&self.config.wg_public_key),
);
for into in guard.clients.values() {
let address = match into.address {
SocketAddr::V4(_) => into.address,
for info in guard.clients.values() {
let address = match info.address {
SocketAddr::V4(_) => info.address,
SocketAddr::V6(ipv6) => {
if let Some(ipv4) = ipv6.ip().to_ipv4_mapped() {
SocketAddr::V4(SocketAddrV4::new(ipv4, ipv6.port()))
} else {
into.address
info.address
}
}
};
let status_info = if let Some(client_status) = &into.client_status {
let status_info = if let Some(client_status) = &info.client_status {
Some(ClientStatusInfo {
p2p_list: client_status.p2p_list.clone(),
up_stream: client_status.up_stream,
@@ -93,18 +212,24 @@ impl VntsWebService {
} else {
None
};
let mut wg_config = None;
if let Some(key) = &info.wireguard {
if let Some(v) = self.cache.wg_group_map.get(key) {
wg_config.replace(v.clone());
}
}
let client_info = ClientInfo {
device_id: into.device_id.clone(),
version: into.version.clone(),
name: into.name.clone(),
client_secret: into.client_secret,
server_secret: into.server_secret,
device_id: info.device_id.clone(),
version: info.version.clone(),
name: info.name.clone(),
client_secret: info.client_secret,
server_secret: info.server_secret,
address,
online: into.online,
virtual_ip: into.virtual_ip.into(),
online: info.online,
virtual_ip: info.virtual_ip.into(),
status_info,
last_join_time: into.last_join_time.format("%Y-%m-%d %H:%M:%S").to_string(),
last_join_time: info.last_join_time.format("%Y-%m-%d %H:%M:%S").to_string(),
wg_config: wg_config.map(|v| v.into()),
};
network.clients.push(client_info);
}
@@ -116,32 +241,4 @@ impl VntsWebService {
None
}
}
// pub fn groups_info(&self) -> GroupsInfo {
// let mut data = GroupsInfo::new();
// for (group, info) in self.cache.virtual_network.key_values() {
// let guard = info.read();
// let mut network = NetworkInfo::new(
// guard.network_ip.into(),
// guard.mask_ip.into(),
// guard.gateway_ip.into(),
// );
// for (_ip, into) in guard.clients.iter() {
// let client_info = ClientInfo {
// device_id: into.device_id.clone(),
// name: into.name.clone(),
// client_secret: into.client_secret,
// server_secret: into.server_secret.is_some(),
// address: into.address,
// online: into.online,
// virtual_ip: into.virtual_ip.into(),
// };
// network.clients.push(client_info);
// }
// network
// .clients
// .sort_by(|v1, v2| v1.virtual_ip.cmp(&v2.virtual_ip));
// data.data.insert(group.to_string(), network);
// }
// data
// }
}

View File

@@ -1,8 +1,7 @@
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr};
use serde::{Deserialize, Serialize};
pub mod req;
pub mod res;
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseMessage<V> {
data: V,
@@ -38,73 +37,3 @@ impl ResponseMessage<Option<()>> {
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ClientInfo {
// 设备ID
pub device_id: String,
// 客户端版本
pub version: String,
// 名称
pub name: String,
// 客户端间是否加密
pub client_secret: bool,
// 客户端和服务端是否加密
pub server_secret: bool,
// 链接服务器的来源地址
pub address: SocketAddr,
// 是否在线
pub online: bool,
// 分配的ip
pub virtual_ip: Ipv4Addr,
pub status_info: Option<ClientStatusInfo>,
pub last_join_time: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ClientStatusInfo {
pub p2p_list: Vec<Ipv4Addr>,
pub up_stream: u64,
pub down_stream: u64,
pub is_cone: bool,
pub update_time: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct NetworkInfo {
// 网段
pub network_ip: Ipv4Addr,
// 掩码
pub mask_ip: Ipv4Addr,
// 网关
pub gateway_ip: Ipv4Addr,
// 网段下的客户端列表
pub clients: Vec<ClientInfo>,
}
impl NetworkInfo {
pub fn new(network_ip: Ipv4Addr, mask_ip: Ipv4Addr, gateway_ip: Ipv4Addr) -> Self {
Self {
network_ip,
mask_ip,
gateway_ip,
clients: Default::default(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GroupList {
pub group_list: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GroupsInfo {
pub data: HashMap<String, NetworkInfo>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginData {
pub username: String,
pub password: String,
}

View File

@@ -0,0 +1,29 @@
use serde::{Deserialize, Serialize};
use std::net::Ipv4Addr;
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateWGData {
pub group_id: String,
pub virtual_ip: String,
pub device_id: String,
pub name: String,
pub config: CreateWgConfig,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateWgConfig {
pub vnts_endpoint: String,
pub private_key: String,
pub persistent_keepalive: u16,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginData {
pub username: String,
pub password: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RemoveClientReq {
pub group_id: String,
pub virtual_ip: Option<Ipv4Addr>,
}

View File

@@ -0,0 +1,123 @@
use crate::core::entity::WireGuardConfig;
use base64::engine::general_purpose;
use base64::Engine;
use serde::{Deserialize, Serialize};
use std::net::{Ipv4Addr, SocketAddr};
#[derive(Debug, Serialize, Deserialize)]
pub struct WGData {
pub group_id: String,
pub virtual_ip: Ipv4Addr,
pub device_id: String,
pub name: String,
pub config: WgConfig,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct WgConfig {
pub vnts_endpoint: String,
pub vnts_public_key: String,
pub vnts_allowed_ips: String,
pub public_key: String,
pub private_key: String,
// 合一起是 Address = ip/prefix
pub ip: Ipv4Addr,
pub prefix: u8,
pub persistent_keepalive: u16,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ClientInfo {
// 设备ID
pub device_id: String,
// 客户端版本
pub version: String,
// 名称
pub name: String,
// 客户端间是否加密
pub client_secret: bool,
// 客户端和服务端是否加密
pub server_secret: bool,
// 链接服务器的来源地址
pub address: SocketAddr,
// 是否在线
pub online: bool,
// 分配的ip
pub virtual_ip: Ipv4Addr,
pub status_info: Option<ClientStatusInfo>,
pub last_join_time: String,
// wg配置
pub wg_config: Option<WireGuardConfigRes>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WireGuardConfigRes {
pub vnts_endpoint: String,
pub vnts_allowed_ips: String,
pub group_id: String,
pub device_id: String,
pub ip: Ipv4Addr,
pub prefix: u8,
pub persistent_keepalive: u16,
pub secret_key: String,
pub public_key: String,
}
impl From<WireGuardConfig> for WireGuardConfigRes {
fn from(value: WireGuardConfig) -> Self {
Self {
vnts_endpoint: value.vnts_endpoint,
vnts_allowed_ips: value.vnts_allowed_ips,
group_id: value.group_id,
device_id: value.device_id,
ip: value.ip,
prefix: value.prefix,
persistent_keepalive: value.persistent_keepalive,
secret_key: general_purpose::STANDARD.encode(&value.secret_key),
public_key: general_purpose::STANDARD.encode(&value.public_key),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ClientStatusInfo {
pub p2p_list: Vec<Ipv4Addr>,
pub up_stream: u64,
pub down_stream: u64,
pub is_cone: bool,
pub update_time: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct NetworkInfo {
// 网段
pub network_ip: Ipv4Addr,
// 掩码
pub mask_ip: Ipv4Addr,
// 网关
pub gateway_ip: Ipv4Addr,
// vnts的公钥
pub vnts_public_key: String,
// 网段下的客户端列表
pub clients: Vec<ClientInfo>,
}
impl NetworkInfo {
pub fn new(
network_ip: Ipv4Addr,
mask_ip: Ipv4Addr,
gateway_ip: Ipv4Addr,
vnts_public_key: String,
) -> Self {
Self {
network_ip,
mask_ip,
gateway_ip,
vnts_public_key,
clients: Default::default(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GroupList {
pub group_list: Vec<String>,
}

View File

@@ -1,4 +1,5 @@
use crate::core::service::PacketHandler;
use crate::core::store::cache::VntContext;
use crate::protocol::NetPacket;
use anyhow::Context;
use futures_util::{SinkExt, StreamExt};
@@ -14,16 +15,23 @@ pub async fn handle_websocket_connection(
handler: PacketHandler,
) {
tokio::spawn(async move {
if let Err(e) = handle_websocket_connection0(stream, addr, handler).await {
let mut context = VntContext {
link_context: None,
server_cipher: None,
link_address: addr,
};
if let Err(e) = handle_websocket_connection0(&mut context, stream, addr, &handler).await {
log::warn!("websocket err {:?} {}", e, addr);
}
handler.leave(context).await;
});
}
async fn handle_websocket_connection0(
context: &mut VntContext,
stream: TcpStream,
addr: SocketAddr,
handler: PacketHandler,
handler: &PacketHandler,
) -> anyhow::Result<()> {
let ws_stream = accept_async(stream)
.await
@@ -42,13 +50,14 @@ async fn handle_websocket_connection0(
let _ = ws_write.close().await;
});
let sender = Some(sender);
while let Some(msg) = ws_read.next().await {
let msg = msg.with_context(|| format!("Error during WebSocket read {}", addr))?;
match msg {
Message::Text(txt) => log::info!("Received text message: {} {}", txt, addr),
Message::Binary(mut data) => {
let packet = NetPacket::new0(data.len(), &mut data)?;
if let Some(rs) = handler.handle(packet, addr, &sender).await {
if let Some(rs) = handler.handle(context, packet, addr, &sender).await {
if sender
.as_ref()
.unwrap()

View File

@@ -0,0 +1,470 @@
use crate::core::entity::{NetworkInfo, WireGuardConfig};
use crate::core::store::cache::AppCache;
use crate::protocol::{ip_turn_packet, NetPacket, Protocol, HEAD_LEN, MAX_TTL};
use crate::ConfigInfo;
use anyhow::{anyhow, Context};
use boringtun::noise::errors::WireGuardError;
use boringtun::noise::{handshake, Packet, Tunn, TunnResult};
use boringtun::x25519::StaticSecret;
use chrono::Local;
use packet::icmp::{icmp, Kind};
use packet::ip::ipv4;
use packet::ip::ipv4::packet::IpV4Packet;
use parking_lot::{Mutex, RwLock};
use rand::RngCore;
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::sync::mpsc::{channel, Receiver, Sender};
pub struct WireGuardGroup {
cache: AppCache,
config: ConfigInfo,
udp: Arc<UdpSocket>,
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
}
impl WireGuardGroup {
pub fn new(cache: AppCache, config: ConfigInfo, udp: Arc<UdpSocket>) -> Self {
Self {
cache,
config,
udp,
data_channel_map: Default::default(),
}
}
pub fn handle(&mut self, buf: Vec<u8>, addr: SocketAddr) {
if let Err(e) = self.handle0(buf, addr) {
log::warn!("{},{}", addr, e);
}
}
fn handle0(&mut self, buf: Vec<u8>, addr: SocketAddr) -> anyhow::Result<()> {
if let Some(sender) = self.data_channel_map.lock().get(&addr) {
sender.try_send(buf)?;
return Ok(());
}
let config = self.handshake(&buf)?;
let network_info = self
.cache
.virtual_network
.get(&config.group_id)
.context("wg配置已过期")?;
let (network_receiver, broadcast_ip, mask_ip, gateway_ip) = {
let mut guard = network_info.write();
let broadcast_ip = guard.network_ip | (!guard.mask_ip);
let client_info = guard
.clients
.get_mut(&config.ip.into())
.context("wg配置已过期")?;
if client_info.wireguard.is_none() {
Err(anyhow!("不是wg配置"))?;
}
let (network_sender, network_receiver) = channel(64);
client_info.wg_sender = Some(network_sender);
client_info.last_join_time = Local::now();
client_info.timestamp = client_info.last_join_time.timestamp();
client_info.address = addr;
client_info.online = true;
guard.epoch += 1;
(
network_receiver,
broadcast_ip,
guard.mask_ip,
guard.gateway_ip,
)
};
let wg = WireGuard::new(
network_info.clone(),
broadcast_ip.into(),
mask_ip.into(),
gateway_ip.into(),
self.cache.clone(),
self.config.wg_secret_key.clone(),
self.udp.clone(),
addr,
config,
self.data_channel_map.clone(),
);
let (udp_sender, udp_receiver) = channel(64);
udp_sender.try_send(buf)?;
self.data_channel_map.lock().insert(addr, udp_sender);
tokio::spawn(wg.start(udp_receiver, network_receiver));
Ok(())
}
#[inline]
pub fn maybe_wg(buf: &[u8]) -> bool {
if buf.len() < 4 {
return false;
}
// Checks the type, as well as the reserved zero fields
let packet_type = u32::from_le_bytes(buf[0..4].try_into().unwrap());
(1..=4).contains(&packet_type)
}
pub fn handshake(&mut self, buf: &[u8]) -> anyhow::Result<WireGuardConfig> {
let packet = match Tunn::parse_incoming_packet(buf) {
Ok(packet) => packet,
Err(e) => Err(anyhow!("{:?}", e))?,
};
match packet {
Packet::HandshakeInit(data) => {
let half_handshake = handshake::parse_handshake_anon(
&self.config.wg_secret_key,
&self.config.wg_public_key,
&data,
)
.map_err(|e| anyhow!("HandshakeInit {:?}", e))?;
let config = self
.cache
.wg_group_map
.get(&half_handshake.peer_static_public)
.context("需要先在vnts配置wg信息")?
.clone();
Ok(config)
}
_ => Err(anyhow!("非握手包")),
}
}
}
pub struct WireGuard {
network_info: Arc<RwLock<NetworkInfo>>,
ip: Ipv4Addr,
broadcast_ip: Ipv4Addr,
mask_ip: Ipv4Addr,
gateway_ip: Ipv4Addr,
group_id: String,
tunn: Tunn,
cache: AppCache,
wg_source_addr: SocketAddr,
udp: Arc<UdpSocket>,
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
}
impl WireGuard {
pub fn new(
network_info: Arc<RwLock<NetworkInfo>>,
broadcast_ip: Ipv4Addr,
mask_ip: Ipv4Addr,
gateway_ip: Ipv4Addr,
cache: AppCache,
vnts_secret_key: StaticSecret,
udp: Arc<UdpSocket>,
wg_source_addr: SocketAddr,
config: WireGuardConfig,
data_channel_map: Arc<Mutex<HashMap<SocketAddr, Sender<Vec<u8>>>>>,
) -> Self {
let tunn = Tunn::new(
vnts_secret_key,
config.public_key.into(),
None,
Some(config.persistent_keepalive),
rand::thread_rng().next_u32(),
None,
);
Self {
network_info,
ip: config.ip,
broadcast_ip,
mask_ip,
gateway_ip,
group_id: config.group_id,
tunn,
cache,
wg_source_addr,
udp,
data_channel_map,
}
}
pub async fn start(
mut self,
udp_receiver: Receiver<Vec<u8>>,
ipv4_receiver: Receiver<(Vec<u8>, Ipv4Addr)>,
) {
if let Err(e) = self.start0(udp_receiver, ipv4_receiver).await {
log::warn!(
"wg连接异常断开 {:?},{:?},{:?},{:?}",
self.group_id,
self.ip,
self.wg_source_addr,
e
);
}
self.offline();
}
fn offline(&self) {
if let Some(v) = self.cache.virtual_network.get(&self.group_id) {
if let Some(v) = v.write().clients.get_mut(&self.ip.into()) {
if v.address == self.wg_source_addr {
v.online = false;
v.wg_sender = None;
}
}
}
self.data_channel_map.lock().remove(&self.wg_source_addr);
}
pub async fn start0(
&mut self,
mut udp_receiver: Receiver<Vec<u8>>,
mut ipv4_receiver: Receiver<(Vec<u8>, Ipv4Addr)>,
) -> anyhow::Result<()> {
let mut interval = tokio::time::interval(Duration::from_millis(200));
let mut dst_buf = [0; 65535];
let mut dst_buf2 = [0; 65535];
log::info!(
"处理wg链接 {},{}/{},{}",
self.group_id,
self.ip,
self.mask_ip,
self.wg_source_addr
);
loop {
tokio::select! {
rs = udp_receiver.recv()=>{
if let Some(mut data) = rs{
self.handle_wg_data(&mut data,&mut dst_buf,&mut dst_buf2).await?;
}else{
break;
}
}
rs = ipv4_receiver.recv()=>{
if let Some((data,ip)) = rs{
if let Err(e) = self.handle_ipv4_data(&data,&mut dst_buf).await{
log::warn!("来源{},发送到wg失败,{:?}",ip,e)
}
}else{
break;
}
}
_ = interval.tick()=>{
self.update_timers(&mut dst_buf,&mut dst_buf2).await?
}
}
}
Ok(())
}
pub async fn handle_ipv4_data(&mut self, buf: &[u8], dst_buf: &mut [u8]) -> anyhow::Result<()> {
let result = self.tunn.encapsulate(buf, dst_buf);
match result {
TunnResult::Done => {}
TunnResult::WriteToNetwork(data) => {
self.udp.send_to(data, self.wg_source_addr).await?;
}
e => Err(anyhow!("{:?}", e))?,
}
Ok(())
}
pub async fn handle_wg_data(
&mut self,
mut buf: &mut [u8],
dst_buf: &mut [u8],
dst_buf2: &mut [u8],
) -> anyhow::Result<()> {
loop {
let mut result = self.tunn.decapsulate(None, buf, dst_buf);
if !self.handle_tunn_result(&mut result, dst_buf2).await? {
break;
}
buf = &mut [];
}
Ok(())
}
async fn handle_tunn_result(
&mut self,
result: &mut TunnResult<'_>,
dst_buf: &mut [u8],
) -> anyhow::Result<bool> {
match result {
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
// 超时了直接断开vnts不重连等对端重连
return Err(anyhow!("链接超时"));
}
TunnResult::Err(e) => {
log::warn!("WireGuard数据异常 {:?}", e);
}
TunnResult::WriteToNetwork(data) => {
self.udp.send_to(data, self.wg_source_addr).await?;
return Ok(true);
}
TunnResult::WriteToTunnelV4(data, _source_ip) => {
let mut packet = IpV4Packet::new(data)?;
let source_ip = packet.source_ip();
let destination_ip = packet.destination_ip();
if let Err(e) = self
.turn_data(source_ip, destination_ip, &mut packet.buffer, dst_buf)
.await
{
log::warn!("wg数据转发失败 {}->{} {:?}", source_ip, destination_ip, e);
}
}
TunnResult::WriteToTunnelV6(_packet, ip) => {
return Err(anyhow!("不支持ipv6连接 {:?}", ip))
}
}
Ok(false)
}
/// from 'wireguard_tick':
/// This is a state keeping function, that need to be called periodically.
/// Recommended interval: 100ms.
pub async fn update_timers(
&mut self,
dst_buf: &mut [u8],
dst_buf2: &mut [u8],
) -> anyhow::Result<()> {
let mut result = self.tunn.update_timers(dst_buf);
self.handle_tunn_result(&mut result, dst_buf2).await?;
Ok(())
}
async fn turn_data(
&mut self,
src_ip: Ipv4Addr,
dest_ip: Ipv4Addr,
data: &mut [u8],
dst_buf: &mut [u8],
) -> anyhow::Result<()> {
if dest_ip == self.gateway_ip {
if self.ping(data, src_ip, dest_ip).is_ok() {
if let Err(e) = self.handle_ipv4_data(&data, dst_buf).await {
log::warn!("发送ping回应到wg失败,{:?}", e)
}
}
return Ok(());
}
if dest_ip.is_broadcast() || dest_ip == self.broadcast_ip {
// 广播
let x: Vec<_> = self
.network_info
.read()
.clients
.values()
.filter(|v| v.online && v.virtual_ip != u32::from(self.ip))
.map(|v| {
(
v.address,
v.tcp_sender.clone(),
v.server_secret,
v.wg_sender.clone(),
)
})
.collect();
for (peer_addr, peer_tcp_sender, server_secret, peer_wg_sender) in x {
if let Err(e) = self
.send_one(
peer_addr,
peer_tcp_sender,
peer_wg_sender,
server_secret,
src_ip,
dest_ip,
data,
dst_buf,
)
.await
{
log::warn!("wg广播失败 {} {} {:?}", src_ip, peer_addr, e);
}
}
return Ok(());
}
let (server_secret, peer_addr, peer_tcp_sender, peer_wg_sender) = {
let guard = self.network_info.read();
if let Some(dest_client_info) = guard.clients.get(&dest_ip.into()) {
if !dest_client_info.online {
Err(anyhow!("目标不在线"))?
}
if !dest_client_info.virtual_ip == u32::from(self.ip) {
Err(anyhow!("阻止回路"))?
}
let dest_link_addr = dest_client_info.address;
let server_secret = dest_client_info.server_secret;
(
server_secret,
dest_link_addr,
dest_client_info.tcp_sender.clone(),
dest_client_info.wg_sender.clone(),
)
} else {
Err(anyhow!("目标未注册"))?
}
};
self.send_one(
peer_addr,
peer_tcp_sender,
peer_wg_sender,
server_secret,
src_ip,
dest_ip,
data,
dst_buf,
)
.await?;
Ok(())
}
async fn send_one(
&self,
peer_addr: SocketAddr,
peer_tcp_sender: Option<Sender<Vec<u8>>>,
peer_wg_sender: Option<Sender<(Vec<u8>, Ipv4Addr)>>,
server_secret: bool,
src_ip: Ipv4Addr,
dest_ip: Ipv4Addr,
data: &mut [u8],
dst_buf: &mut [u8],
) -> anyhow::Result<()> {
if let Some(peer_wg_sender) = peer_wg_sender {
if let Err(e) = peer_wg_sender.send((data.to_vec(), self.ip)).await {
Err(anyhow!("发送到对端wg失败 {}", e))?
}
return Ok(());
}
let mut net_packet = NetPacket::new0(HEAD_LEN + data.len(), dst_buf)?;
net_packet.set_default_version();
// 把wg的转发当成是服务端来源的数据因为服务端没有客户端密钥对数据进行加密
net_packet.set_gateway_flag(true);
net_packet.set_protocol(Protocol::IpTurn);
net_packet.set_transport_protocol(ip_turn_packet::Protocol::WGIpv4.into());
net_packet.first_set_ttl(MAX_TTL);
net_packet.set_source(src_ip);
net_packet.set_destination(dest_ip);
net_packet.set_payload(data)?;
if server_secret {
let cipher = self
.cache
.cipher_session
.get(&peer_addr)
.context("加密信息不存在")?;
cipher.encrypt_ipv4(&mut net_packet)?;
}
if let Some(tcp_sender) = peer_tcp_sender {
tcp_sender.send(net_packet.buffer().to_vec()).await?;
} else {
self.udp.send_to(net_packet.buffer(), peer_addr).await?;
}
Ok(())
}
fn ping(&self, data: &mut [u8], src_ip: Ipv4Addr, dest_ip: Ipv4Addr) -> anyhow::Result<()> {
let mut ipv4 = IpV4Packet::new(data)?;
if let ipv4::protocol::Protocol::Icmp = ipv4.protocol() {
let mut icmp_packet = icmp::IcmpPacket::new(ipv4.payload_mut())?;
if icmp_packet.kind() == Kind::EchoRequest {
//开启ping
icmp_packet.set_kind(Kind::EchoReply);
icmp_packet.update_checksum();
ipv4.set_source_ip(dest_ip);
ipv4.set_destination_ip(src_ip);
ipv4.update_checksum();
return Ok(());
}
}
Err(anyhow!("非ping Echo 不处理"))
}
}

View File

@@ -3,14 +3,13 @@
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use crate::cipher::RsaCipher;
use crate::core::entity::ClientInfo;
use crate::core::store::cache::{AppCache, Context};
use crate::core::store::cache::{AppCache, LinkVntContext, VntContext};
use crate::error::*;
use crate::protocol::NetPacket;
use crate::ConfigInfo;
use tokio::net::UdpSocket;
use tokio::sync::mpsc::Sender;
#[derive(Clone)]
pub struct ClientPacketHandler {
@@ -37,25 +36,26 @@ impl ClientPacketHandler {
}
impl ClientPacketHandler {
pub fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
pub async fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
&self,
context: &VntContext,
net_packet: NetPacket<B>,
addr: SocketAddr,
_addr: SocketAddr,
) -> Result<()> {
if let Some(context) = self.cache.get_context(&addr) {
self.handle0(net_packet, context)
if let Some(context) = &context.link_context {
self.handle0(context, net_packet).await
} else {
Err(Error::Disconnect)
Err(Error::Disconnect)?
}
}
}
impl ClientPacketHandler {
/// 转发到目标地址
fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
async fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
&self,
context: &LinkVntContext,
mut net_packet: NetPacket<B>,
context: Context,
) -> Result<()> {
if net_packet.incr_ttl() > 1 {
if self.config.check_finger {
@@ -65,33 +65,65 @@ impl ClientPacketHandler {
let destination = net_packet.destination();
if destination.is_broadcast() || self.config.broadcast == destination {
//处理广播
broadcast(&self.udp, context, net_packet);
} else if let Some(client_info) =
context.network_info.read().clients.get(&destination.into())
{
send_one(&self.udp, client_info, &net_packet);
broadcast(context, &self.udp, net_packet).await;
} else {
let is_encrypt = net_packet.is_encrypt();
let source_ip = u32::from(net_packet.source());
let rs = context
.network_info
.read()
.clients
.get(&destination.into())
.filter(|v| {
v.wireguard.is_none()
&& v.online
&& v.client_secret == is_encrypt
&& v.virtual_ip != source_ip
})
.map(|v| (v.address, v.tcp_sender.clone()));
if let Some((peer_addr, peer_tcp_sender)) = rs {
send_one(&self.udp, peer_addr, peer_tcp_sender, &net_packet).await;
}
}
}
Ok(())
}
}
fn broadcast<B: AsRef<[u8]>>(udp_socket: &UdpSocket, context: Context, net_packet: NetPacket<B>) {
for client_info in context.network_info.read().clients.values() {
send_one(udp_socket, client_info, &net_packet);
async fn broadcast<B: AsRef<[u8]>>(
context: &LinkVntContext,
udp_socket: &UdpSocket,
net_packet: NetPacket<B>,
) {
let is_encrypt = net_packet.is_encrypt();
let source_ip = u32::from(net_packet.source());
let x: Vec<_> = context
.network_info
.read()
.clients
.values()
.filter(|v| {
v.wireguard.is_none()
&& v.online
&& v.client_secret == is_encrypt
&& v.virtual_ip != source_ip
})
.map(|v| (v.address, v.tcp_sender.clone()))
.collect();
for (peer_addr, peer_tcp_sender) in x {
send_one(udp_socket, peer_addr, peer_tcp_sender, &net_packet).await;
}
}
fn send_one<B: AsRef<[u8]>>(
async fn send_one<B: AsRef<[u8]>>(
udp_socket: &UdpSocket,
client_info: &ClientInfo,
peer_addr: SocketAddr,
peer_tcp_sender: Option<Sender<Vec<u8>>>,
net_packet: &NetPacket<B>,
) {
if client_info.online && client_info.client_secret == net_packet.is_encrypt() {
if let Some(sender) = &client_info.tcp_sender {
let _ = sender.try_send(net_packet.buffer().to_vec());
} else {
let _ = udp_socket.try_send_to(net_packet.buffer(), client_info.address);
}
if let Some(sender) = &peer_tcp_sender {
let _ = sender.send(net_packet.buffer().to_vec()).await;
} else {
let _ = udp_socket.send_to(net_packet.buffer(), peer_addr).await;
}
}

View File

@@ -7,7 +7,7 @@ use tokio::sync::mpsc::Sender;
use crate::cipher::RsaCipher;
use crate::core::service::client::ClientPacketHandler;
use crate::core::service::server::ServerPacketHandler;
use crate::core::store::cache::AppCache;
use crate::core::store::cache::{AppCache, VntContext};
use crate::error::*;
use crate::protocol::NetPacket;
use crate::ConfigInfo;
@@ -41,13 +41,17 @@ impl PacketHandler {
}
impl PacketHandler {
pub async fn leave(&self, context: VntContext) {
self.server.leave(context).await;
}
pub async fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
&self,
context: &mut VntContext,
net_packet: NetPacket<B>,
addr: SocketAddr,
tcp_sender: &Option<Sender<Vec<u8>>>,
) -> Option<NetPacket<Vec<u8>>> {
self.handle0(net_packet, addr, tcp_sender)
self.handle0(context, net_packet, addr, tcp_sender)
.await
.unwrap_or_else(|e| {
log::error!("addr={},{:?}", addr, e);
@@ -56,14 +60,17 @@ impl PacketHandler {
}
async fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
&self,
context: &mut VntContext,
net_packet: NetPacket<B>,
addr: SocketAddr,
tcp_sender: &Option<Sender<Vec<u8>>>,
) -> Result<Option<NetPacket<Vec<u8>>>> {
if net_packet.is_gateway() {
self.server.handle(net_packet, addr, tcp_sender).await
self.server
.handle(context, net_packet, addr, tcp_sender)
.await
} else {
self.client.handle(net_packet, addr)?;
self.client.handle(context, net_packet, addr).await?;
Ok(None)
}
}

View File

@@ -1,20 +1,20 @@
use anyhow::{anyhow, Context};
use chrono::Local;
use packet::icmp::{icmp, Kind};
use packet::ip::ipv4;
use packet::ip::ipv4::packet::IpV4Packet;
use protobuf::Message;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use std::{io, result};
use protobuf::Message;
use tokio::net::UdpSocket;
use tokio::sync::mpsc::Sender;
use crate::cipher::{Aes256GcmCipher, Finger, RsaCipher};
use crate::core::entity::{ClientInfo, ClientStatusInfo, NetworkInfo};
use crate::core::store::cache::{AppCache, Context};
use crate::core::entity::{ClientInfo, ClientStatusInfo, NetworkInfo, SimpleClientInfo};
use crate::core::store::cache::{AppCache, LinkVntContext, VntContext};
use crate::error::*;
use crate::proto::message;
use crate::proto::message::{DeviceList, RegistrationRequest, RegistrationResponse};
@@ -48,8 +48,12 @@ impl ServerPacketHandler {
}
impl ServerPacketHandler {
pub async fn leave(&self, context: VntContext) {
context.leave(&self.cache).await;
}
pub async fn handle<B: AsRef<[u8]> + AsMut<[u8]>>(
&self,
context: &mut VntContext,
mut net_packet: NetPacket<B>,
addr: SocketAddr,
tcp_sender: &Option<Sender<Vec<u8>>>,
@@ -66,26 +70,24 @@ impl ServerPacketHandler {
}
service_packet::Protocol::SecretHandshakeRequest => {
// 加密握手
let rs = self.secret_handshake(net_packet, addr).await?;
let rs = self.secret_handshake(context, net_packet, addr).await?;
return Ok(Some(rs));
}
_ => {}
}
}
// 解密
let aes = if net_packet.is_encrypt() {
if let Some(aes) = self.cache.cipher_session.get(&addr) {
let server_secret = net_packet.is_encrypt();
if server_secret {
if let Some(aes) = &context.server_cipher {
aes.decrypt_ipv4(&mut net_packet)?;
Some(aes)
} else {
log::info!("没有密钥:{},head={:?}", addr, net_packet.head());
return Ok(Some(self.handle_err(addr, source, Error::NoKey)?));
return Ok(Some(self.handle_err(addr, source, &Error::NoKey)?));
}
} else {
None
};
}
let mut packet = match self
.handle0(net_packet, addr, tcp_sender, aes.is_some())
.handle0(context, net_packet, addr, tcp_sender, server_secret)
.await
{
Ok(rs) => {
@@ -95,11 +97,13 @@ impl ServerPacketHandler {
return Ok(None);
}
}
Err(e) => self.handle_err(addr, source, e)?,
Err(e) => self.handle_anyhow_err(addr, source, e)?,
};
self.common_param(&mut packet, source);
if let Some(aes) = aes {
aes.encrypt_ipv4(&mut packet)?;
if server_secret {
if let Some(aes) = &context.server_cipher {
aes.encrypt_ipv4(&mut packet)?;
}
}
Ok(Some(packet))
}
@@ -115,20 +119,28 @@ impl ServerPacketHandler {
net_packet.first_set_ttl(MAX_TTL);
net_packet.set_gateway_flag(true);
}
fn handle_anyhow_err(
&self,
addr: SocketAddr,
source: Ipv4Addr,
e: anyhow::Error,
) -> Result<NetPacket<Vec<u8>>> {
if let Some(e) = e.downcast_ref() {
self.handle_err(addr, source, e)
} else {
self.handle_err(addr, source, &Error::Other(format!("{}", e)))
}
}
fn handle_err(
&self,
addr: SocketAddr,
source: Ipv4Addr,
e: Error,
e: &Error,
) -> Result<NetPacket<Vec<u8>>> {
log::warn!("addr={},source={},{:?}", addr, source, e);
let rs = vec![0u8; 12 + ENCRYPTION_RESERVED];
let mut packet = NetPacket::new_encrypt(rs)?;
match e {
Error::Io(_) => {}
Error::Channel(_) => {}
Error::Protobuf(_) => {}
Error::AddressExhausted => {
packet.set_transport_protocol(error_packet::Protocol::AddressExhausted.into());
}
@@ -161,6 +173,7 @@ impl ServerPacketHandler {
}
async fn handle0<B: AsRef<[u8]> + AsMut<[u8]>>(
&self,
context: &mut VntContext,
net_packet: NetPacket<B>,
addr: SocketAddr,
tcp_sender: &Option<Sender<Vec<u8>>>,
@@ -168,7 +181,7 @@ impl ServerPacketHandler {
) -> Result<Option<NetPacket<Vec<u8>>>> {
// 处理不需要连接上下文的请求
let mut net_packet = match self
.not_context(net_packet, addr, tcp_sender, server_secret)
.not_context(context, net_packet, addr, tcp_sender, server_secret)
.await
{
Ok(rs) => {
@@ -177,10 +190,10 @@ impl ServerPacketHandler {
Err(net_packet) => net_packet,
};
// 需要连接的上下文
let context = if let Some(context) = self.cache.get_context(&addr) {
context
let link_context = if let Some(link_context) = &context.link_context {
link_context
} else {
return Err(Error::Disconnect);
return Err(Error::Disconnect)?;
};
match net_packet.protocol() {
@@ -188,13 +201,13 @@ impl ServerPacketHandler {
match protocol::service_packet::Protocol::from(net_packet.transport_protocol()) {
service_packet::Protocol::PullDeviceList => {
//拉取网段设备信息
return self.poll_device_list(net_packet, addr, &context);
return self.poll_device_list(net_packet, addr, &link_context);
}
service_packet::Protocol::ClientStatusInfo => {
//客户端上报信息
let client_status_info =
message::ClientStatusInfo::parse_from_bytes(net_packet.payload())?;
self.up_client_status_info(client_status_info, &context);
self.up_client_status_info(client_status_info, &link_context);
return Ok(None);
}
_ => {}
@@ -205,17 +218,22 @@ impl ServerPacketHandler {
if let control_packet::Protocol::Ping =
protocol::control_packet::Protocol::from(net_packet.transport_protocol())
{
return self.control_ping(net_packet, &context);
return self.control_ping(net_packet, &link_context);
}
}
Protocol::IpTurn => {
match protocol::ip_turn_packet::Protocol::from(net_packet.transport_protocol()) {
protocol::ip_turn_packet::Protocol::WGIpv4 => {
//wg数据转发
self.wg_ipv4(&link_context, net_packet).await?;
return Ok(None);
}
protocol::ip_turn_packet::Protocol::Ipv4Broadcast => {
//处理选择性广播,进过网关还原成原始广播
let broadcast_packet = BroadcastPacket::new(net_packet.payload())?;
let exclude = broadcast_packet.addresses();
let broadcast_net_packet = NetPacket::new(broadcast_packet.data()?)?;
self.broadcast(&context, broadcast_net_packet, &exclude)?;
self.broadcast(&link_context, broadcast_net_packet, &exclude)?;
return Ok(None);
}
protocol::ip_turn_packet::Protocol::Ipv4 => {
@@ -258,6 +276,7 @@ impl ServerPacketHandler {
impl ServerPacketHandler {
async fn not_context<B: AsRef<[u8]>>(
&self,
context: &mut VntContext,
net_packet: NetPacket<B>,
addr: SocketAddr,
tcp_sender: &Option<Sender<Vec<u8>>>,
@@ -269,7 +288,7 @@ impl ServerPacketHandler {
{
//注册
return Ok(self
.register(net_packet, addr, tcp_sender, server_secret)
.register(context, net_packet, addr, tcp_sender, server_secret)
.await);
}
} else if net_packet.protocol() == Protocol::Control {
@@ -287,7 +306,7 @@ impl ServerPacketHandler {
fn control_ping<B: AsRef<[u8]>>(
&self,
net_packet: NetPacket<B>,
context: &Context,
context: &LinkVntContext,
) -> Result<Option<NetPacket<Vec<u8>>>> {
let vec = vec![0u8; 12 + 4 + ENCRYPTION_RESERVED];
let mut packet = NetPacket::new_encrypt(vec)?;
@@ -324,13 +343,13 @@ impl ServerPacketHandler {
impl ServerPacketHandler {
async fn register<B: AsRef<[u8]>>(
&self,
context: &mut VntContext,
net_packet: NetPacket<B>,
addr: SocketAddr,
tcp_sender: &Option<Sender<Vec<u8>>>,
server_secret: bool,
) -> Result<Option<NetPacket<Vec<u8>>>> {
let config = &self.config;
let cache = &self.cache;
let request = RegistrationRequest::parse_from_bytes(net_packet.payload())?;
check_reg(&request)?;
log::info!(
@@ -346,6 +365,8 @@ impl ServerPacketHandler {
tcp_sender.is_some()
);
let group_id = request.token.clone();
let gateway = config.gateway;
let netmask = config.netmask;
if let Some(white_token) = &config.white_token {
if !white_token.contains(&group_id) {
log::info!(
@@ -353,7 +374,7 @@ impl ServerPacketHandler {
white_token,
group_id
);
return Err(Error::TokenError);
Err(Error::TokenError)?
}
}
let mut response = RegistrationResponse::new();
@@ -371,119 +392,46 @@ impl ServerPacketHandler {
}
}
}
//固定网段
let gateway: u32 = config.gateway.into();
let netmask: u32 = config.netmask.into();
let network: u32 = gateway & netmask;
let register_client_request = RegisterClientRequest {
group_id: group_id.clone(),
virtual_ip: request.virtual_ip.into(),
gateway,
netmask,
allow_ip_change: request.allow_ip_change,
device_id: request.device_id,
version: request.version,
name: request.name,
client_secret: request.client_secret,
client_secret_hash: request.client_secret_hash,
server_secret,
address: addr,
tcp_sender: tcp_sender.clone(),
online: true,
wireguard: None,
};
let register_response = generate_ip(&self.cache, register_client_request).await?;
let virtual_ip = register_response.virtual_ip.into();
response.virtual_gateway = gateway.into();
response.virtual_netmask = netmask.into();
response.virtual_ip = virtual_ip;
response.epoch = register_response.epoch as u32;
response.device_info_list = register_response
.client_list
.into_iter()
.map(|v| v.into())
.collect();
context.link_context.replace(LinkVntContext {
network_info: self
.cache
.virtual_network
.get(&group_id)
.context("virtual_network is none")?,
group: group_id.clone(),
virtual_ip,
broadcast: config.broadcast,
timestamp: register_response.timestamp,
});
response.virtual_netmask = netmask;
response.virtual_gateway = gateway;
let v = cache
.virtual_network
.optionally_get_with(group_id.clone(), || {
(
Duration::from_secs(7 * 24 * 3600),
Arc::new(parking_lot::const_rwlock(NetworkInfo::new(
network, netmask, gateway,
))),
)
})
.await;
let mut virtual_ip = request.virtual_ip;
// 可分配的ip段
let ip_range = network + 1..gateway | (!netmask);
let timestamp = Local::now().timestamp();
{
let mut lock = v.write();
let mut insert = true;
if virtual_ip != 0 {
if u32::from(config.gateway) == virtual_ip
|| u32::from(config.broadcast) == virtual_ip
|| !ip_range.contains(&virtual_ip)
{
log::warn!("手动指定的ip无效: {:?}", request);
return Err(Error::InvalidIp);
}
//指定了ip
if let Some(info) = lock.clients.get_mut(&request.virtual_ip) {
if info.device_id != request.device_id {
//ip被占用了,并且不能更改ip
if !request.allow_ip_change {
log::warn!("手动指定的ip已经存在:{:?}", request);
return Err(Error::IpAlreadyExists);
}
// 重新挑选ip
virtual_ip = 0;
} else {
insert = false;
}
}
}
let mut old_ip = 0;
if insert {
// 找到上一次用的ip
for (ip, x) in &lock.clients {
if x.device_id == request.device_id {
if virtual_ip == 0 {
virtual_ip = *ip;
} else {
old_ip = *ip;
}
break;
}
}
}
if virtual_ip == 0 {
// 从小到大找一个未使用的ip
for ip in ip_range {
if ip == lock.gateway_ip {
continue;
}
if !lock.clients.contains_key(&ip) {
virtual_ip = ip;
break;
}
}
}
if virtual_ip == 0 {
log::error!("地址使用完:{:?}", request);
return Err(Error::AddressExhausted);
}
let info = if old_ip == 0 {
lock.clients
.entry(virtual_ip)
.or_insert_with(ClientInfo::default)
} else {
let client_info = lock.clients.remove(&old_ip).unwrap();
lock.clients
.entry(virtual_ip)
.or_insert_with(|| client_info)
};
info.name = request.name;
info.device_id = request.device_id;
info.version = request.version;
info.client_secret = request.client_secret;
info.server_secret = server_secret;
info.address = addr;
info.online = true;
info.virtual_ip = virtual_ip;
info.tcp_sender = tcp_sender.clone();
info.last_join_time = Local::now();
info.timestamp = timestamp;
lock.epoch += 1;
response.virtual_ip = virtual_ip;
response.epoch = lock.epoch as u32;
response.device_info_list = Self::clients_info(&lock.clients, virtual_ip);
drop(lock);
}
cache
.insert_ip_session((group_id.clone(), virtual_ip), addr)
.await;
cache
.insert_addr_session(addr, (group_id, virtual_ip, timestamp))
.await;
let bytes = response.write_to_bytes()?;
let rs = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED];
let mut packet = NetPacket::new_encrypt(rs)?;
@@ -496,13 +444,16 @@ impl ServerPacketHandler {
fn check_reg(request: &RegistrationRequest) -> Result<()> {
if request.token.is_empty() || request.token.len() > 128 {
return Err(Error::Other("group length error".into()));
Err(anyhow!("group length error"))?
}
if request.device_id.is_empty() || request.device_id.len() > 128 {
return Err(Error::Other("device_id length error".into()));
Err(anyhow!("device_id length error"))?
}
if request.name.is_empty() || request.name.len() > 128 {
return Err(Error::Other("name length error".into()));
Err(anyhow!("name length error"))?
}
if request.client_secret_hash.len() > 128 {
Err(anyhow!("client_secret_hash length error"))?
}
Ok(())
}
@@ -535,6 +486,7 @@ impl ServerPacketHandler {
}
async fn secret_handshake<B: AsRef<[u8]>>(
&self,
context: &mut VntContext,
net_packet: NetPacket<B>,
addr: SocketAddr,
) -> Result<NetPacket<Vec<u8>>> {
@@ -545,10 +497,7 @@ impl ServerPacketHandler {
let sync_secret =
message::SecretHandshakeRequest::parse_from_bytes(rsa_secret_body.data())?;
let c = Aes256GcmCipher::new(
sync_secret
.key
.try_into()
.map_err(|_| Error::Other("key err".into()))?,
sync_secret.key.try_into().map_err(|_| anyhow!("key err"))?,
Finger::new(&sync_secret.token),
);
let rs = vec![0u8; 12 + ENCRYPTION_RESERVED];
@@ -557,10 +506,11 @@ impl ServerPacketHandler {
packet.set_transport_protocol(service_packet::Protocol::SecretHandshakeResponse.into());
self.common_param(&mut packet, source);
c.encrypt_ipv4(&mut packet)?;
context.server_cipher.replace(c.clone());
self.cache.insert_cipher_session(addr, c).await;
return Ok(packet);
}
Err(Error::Other("no encryption".into()))
Err(anyhow!("no encryption"))
}
}
@@ -569,15 +519,15 @@ impl ServerPacketHandler {
&self,
_net_packet: NetPacket<B>,
_addr: SocketAddr,
context: &Context,
context: &LinkVntContext,
) -> Result<Option<NetPacket<Vec<u8>>>> {
let guard = context.network_info.read();
let ips = Self::clients_info(&guard.clients, context.virtual_ip);
let ips = clients_info(&guard.clients, context.virtual_ip);
let epoch = guard.epoch;
drop(guard);
let mut device_list = DeviceList::new();
device_list.epoch = epoch as u32;
device_list.device_info_list = ips;
device_list.device_info_list = ips.into_iter().map(|v| v.into()).collect();
let bytes = device_list.write_to_bytes()?;
let vec = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED];
let mut device_list_packet = NetPacket::new_encrypt(vec)?;
@@ -589,7 +539,7 @@ impl ServerPacketHandler {
fn up_client_status_info(
&self,
client_status_info: message::ClientStatusInfo,
context: &Context,
context: &LinkVntContext,
) {
let mut status_info = ClientStatusInfo::default();
let iplist = &mut status_info.p2p_list;
@@ -612,26 +562,40 @@ impl ServerPacketHandler {
v.client_status = Some(status_info);
}
}
fn clients_info(
clients: &HashMap<u32, ClientInfo>,
current_ip: u32,
) -> Vec<message::DeviceInfo> {
clients
.iter()
.filter(|&(_, dev)| dev.virtual_ip != current_ip)
.map(|(_, device_info)| {
let mut dev = message::DeviceInfo::new();
dev.virtual_ip = device_info.virtual_ip;
dev.name = device_info.name.clone();
dev.device_status = if device_info.online { 0 } else { 1 };
dev.client_secret = device_info.client_secret;
dev
})
.collect()
async fn wg_ipv4<B: AsRef<[u8]>>(
&self,
context: &LinkVntContext,
net_packet: NetPacket<B>,
) -> anyhow::Result<()> {
let source = net_packet.source();
let dest = net_packet.destination();
if dest.is_broadcast() || dest == context.broadcast {
// 广播
for peer in context.network_info.read().clients.values() {
if !peer.online {
continue;
}
if let Some(sender) = &peer.wg_sender {
if let Err(e) = sender.try_send((net_packet.payload().to_vec(), source)) {
log::info!("广播到对端wg失败 {}->{},{}", source, dest, e);
}
}
}
} else if let Some(peer) = context.network_info.read().clients.get(&dest.into()) {
// 点对点
if peer.online {
if let Some(sender) = &peer.wg_sender {
if let Err(e) = sender.try_send((net_packet.payload().to_vec(), source)) {
log::info!("发送到对端wg失败 {}->{},{}", source, dest, e);
}
}
}
}
Ok(())
}
fn broadcast<B: AsRef<[u8]>>(
&self,
context: &Context,
context: &LinkVntContext,
net_packet: NetPacket<B>,
exclude: &[Ipv4Addr],
) -> io::Result<()> {
@@ -640,6 +604,7 @@ impl ServerPacketHandler {
if client_info.online
&& !exclude.contains(&(*ip).into())
&& client_info.client_secret == client_secret
&& client_info.wireguard.is_none()
{
if let Some(sender) = &client_info.tcp_sender {
let _ = sender.try_send(net_packet.buffer().to_vec());
@@ -653,3 +618,171 @@ impl ServerPacketHandler {
Ok(())
}
}
pub struct RegisterClientRequest {
pub group_id: String,
// ip 0表示自动分配
pub virtual_ip: Ipv4Addr,
pub gateway: Ipv4Addr,
pub netmask: Ipv4Addr,
// 允许分配不一样的ip
pub allow_ip_change: bool,
// 设备ID
pub device_id: String,
// 版本
pub version: String,
// 名称
pub name: String,
// 客户端间是否加密
pub client_secret: bool,
// 加密hash
pub client_secret_hash: Vec<u8>,
// 和服务端是否加密
pub server_secret: bool,
// 链接服务器的来源地址
pub address: SocketAddr,
pub tcp_sender: Option<Sender<Vec<u8>>>,
// 是否在线
pub online: bool,
// wireguard客户端公钥
pub wireguard: Option<[u8; 32]>,
}
pub struct RegisterClientResponse {
timestamp: i64,
pub virtual_ip: Ipv4Addr,
// 纪元号
pub epoch: u64,
pub client_list: Vec<SimpleClientInfo>,
}
pub async fn generate_ip(
cache: &AppCache,
register_request: RegisterClientRequest,
) -> anyhow::Result<RegisterClientResponse> {
let gateway: u32 = register_request.gateway.into();
let netmask: u32 = register_request.netmask.into();
let network: u32 = gateway & netmask;
let mut virtual_ip: u32 = register_request.virtual_ip.into();
let device_id = register_request.device_id;
let allow_ip_change = register_request.allow_ip_change;
let group_id = register_request.group_id;
let v = cache
.virtual_network
.optionally_get_with(group_id, || {
(
Duration::from_secs(7 * 24 * 3600),
Arc::new(parking_lot::const_rwlock(NetworkInfo::new(
network, netmask, gateway,
))),
)
})
.await;
// 可分配的ip段
let ip_range = network + 1..gateway | (!netmask);
let timestamp = Local::now().timestamp();
let mut lock = v.write();
let mut insert = true;
if virtual_ip != 0 {
if gateway == virtual_ip || !ip_range.contains(&virtual_ip) {
Err(Error::InvalidIp)?
}
//指定了ip
if let Some(info) = lock.clients.get_mut(&virtual_ip) {
if info.device_id != device_id {
//ip被占用了,并且不能更改ip
if !allow_ip_change {
Err(Error::IpAlreadyExists)?
}
// 重新挑选ip
virtual_ip = 0;
} else {
insert = false;
}
}
}
let mut old_ip = 0;
if insert {
// 找到上一次用的ip
for (ip, x) in &lock.clients {
if x.device_id == device_id {
if virtual_ip == 0 {
virtual_ip = *ip;
} else {
old_ip = *ip;
}
break;
}
}
}
if virtual_ip == 0 {
// 从小到大找一个未使用的ip
for ip in ip_range {
if ip == lock.gateway_ip {
continue;
}
if !lock.clients.contains_key(&ip) {
virtual_ip = ip;
break;
}
}
}
if virtual_ip == 0 {
log::error!("地址使用完:{:?}", lock);
Err(Error::AddressExhausted)?
}
let info = if old_ip == 0 {
lock.clients
.entry(virtual_ip)
.or_insert_with(ClientInfo::default)
} else {
let client_info = lock.clients.remove(&old_ip).unwrap();
lock.clients
.entry(virtual_ip)
.or_insert_with(|| client_info)
};
info.name = register_request.name;
info.device_id = device_id;
info.version = register_request.version;
info.client_secret = register_request.client_secret;
info.client_secret_hash = register_request.client_secret_hash;
info.server_secret = register_request.server_secret;
info.address = register_request.address;
info.online = register_request.online;
info.wireguard = register_request.wireguard;
info.virtual_ip = virtual_ip;
info.tcp_sender = register_request.tcp_sender;
info.last_join_time = Local::now();
info.timestamp = timestamp;
lock.epoch += 1;
let response = RegisterClientResponse {
timestamp,
virtual_ip: virtual_ip.into(),
epoch: lock.epoch,
client_list: clients_info(&lock.clients, virtual_ip),
};
Ok(response)
}
fn clients_info(clients: &HashMap<u32, ClientInfo>, current_ip: u32) -> Vec<SimpleClientInfo> {
clients
.iter()
.filter(|&(_, dev)| dev.virtual_ip != current_ip)
.map(|(_, device_info)| device_info.into())
.collect()
}
impl From<SimpleClientInfo> for message::DeviceInfo {
fn from(value: SimpleClientInfo) -> Self {
let mut dev = message::DeviceInfo::new();
dev.virtual_ip = value.virtual_ip;
dev.name = value.name;
dev.device_status = if value.online { 0 } else { 1 };
dev.client_secret = value.client_secret;
if value.online {
dev.client_secret_hash = value.client_secret_hash;
}
dev.wireguard = value.wireguard;
dev
}
}

View File

@@ -1,130 +1,121 @@
use dashmap::DashMap;
use parking_lot::RwLock;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use parking_lot::RwLock;
use crate::cipher::Aes256GcmCipher;
use crate::core::entity::NetworkInfo;
use crate::core::entity::{NetworkInfo, WireGuardConfig};
use crate::core::store::expire_map::ExpireMap;
#[derive(Clone)]
pub struct AppCache {
// group -> NetworkInfo
pub virtual_network: ExpireMap<String, Arc<RwLock<NetworkInfo>>>,
// (group,ip) -> addr
// (group,ip) -> addr 用于客户端过期,只有客户端离线才设置
pub ip_session: ExpireMap<(String, u32), SocketAddr>,
// addr -> (groupip)
pub addr_session: ExpireMap<SocketAddr, (String, u32, i64)>,
pub cipher_session: ExpireMap<SocketAddr, Arc<Aes256GcmCipher>>,
// 加密密钥
pub cipher_session: Arc<DashMap<SocketAddr, Arc<Aes256GcmCipher>>>,
// web登录状态
pub auth_map: ExpireMap<String, ()>,
// wg公钥 -> wg配置
pub wg_group_map: Arc<DashMap<[u8; 32], WireGuardConfig>>,
}
pub struct Context {
pub struct VntContext {
pub link_context: Option<LinkVntContext>,
pub server_cipher: Option<Aes256GcmCipher>,
pub link_address: SocketAddr,
}
pub struct LinkVntContext {
pub network_info: Arc<RwLock<NetworkInfo>>,
pub group: String,
pub virtual_ip: u32,
pub broadcast: Ipv4Addr,
pub timestamp: i64,
}
impl VntContext {
pub async fn leave(self, cache: &AppCache) {
if self.server_cipher.is_some() {
cache.cipher_session.remove(&self.link_address);
}
if let Some(context) = self.link_context {
if let Some(network_info) = cache.virtual_network.get(&context.group) {
{
let mut guard = network_info.write();
if let Some(client_info) = guard.clients.get_mut(&context.virtual_ip) {
if client_info.address != self.link_address
&& client_info.timestamp != context.timestamp
{
return;
}
client_info.online = false;
client_info.tcp_sender = None;
guard.epoch += 1;
}
drop(guard);
}
cache
.insert_ip_session((context.group, context.virtual_ip), self.link_address)
.await;
}
}
}
}
impl AppCache {
pub fn new() -> Self {
let wg_group_map: Arc<DashMap<[u8; 32], WireGuardConfig>> = Default::default();
// 网段7天未使用则回收
let virtual_network: ExpireMap<String, Arc<RwLock<NetworkInfo>>> =
ExpireMap::new(|_k, _v| {});
let virtual_network_ = virtual_network.clone();
// ip一天未使用则回收
let ip_session: ExpireMap<(String, u32), SocketAddr> =
ExpireMap::new(move |(group_id, ip), addr: SocketAddr| {
log::info!(
"ip_session eviction group_id={},ip={},addr={}",
group_id,
Ipv4Addr::from(ip),
addr
);
if let Some(v) = virtual_network_.get(&group_id) {
let mut lock = v.write();
if let Some(dev) = lock.clients.get(&ip) {
if dev.address == addr {
lock.clients.remove(&ip);
lock.epoch += 1;
}
}
ExpireMap::new(|_k, v: &Arc<RwLock<NetworkInfo>>| {
let lock = v.read();
if !lock.clients.is_empty() {
// 存在客户端的不过期
return Some(Duration::from_secs(7 * 24 * 3600));
}
None
});
let virtual_network_ = virtual_network.clone();
// 20秒钟没有收到消息则判定为掉线
let addr_session = ExpireMap::new(
move |addr: SocketAddr, (group, virtual_ip, timestamp)| {
log::info!(
"addr_session eviction group={},virtual_ip={},addr={},timestamp={}",
group,
Ipv4Addr::from(virtual_ip),
addr,
timestamp
);
if let Some(v) = virtual_network_.get(&group) {
let mut lock = v.write();
if let Some(item) = lock.clients.get_mut(&virtual_ip) {
if item.address != addr || item.timestamp != timestamp {
log::info!(
"无效信息 addr_session eviction group={},virtual_ip={},addr={},timestamp={}",
group,
Ipv4Addr::from(virtual_ip),
addr,
timestamp
);
return;
}
item.online = false;
// ip一天未使用则回收
let ip_session: ExpireMap<(String, u32), SocketAddr> = ExpireMap::new(move |key, addr| {
let (group_id, ip) = &key;
log::info!(
"ip_session eviction group_id={},ip={},addr={}",
group_id,
Ipv4Addr::from(*ip),
addr
);
if let Some(v) = virtual_network_.get(group_id) {
let mut lock = v.write();
if let Some(dev) = lock.clients.get(ip) {
if !dev.online && &dev.address == addr {
lock.clients.remove(ip);
lock.epoch += 1;
}
}
},
);
let cipher_session = ExpireMap::new(|_k, _v| {});
let auth_map = ExpireMap::new(|_k, _v| {});
}
None
});
let auth_map = ExpireMap::new(|_k, _v| None);
Self {
virtual_network,
ip_session,
addr_session,
cipher_session,
cipher_session: Default::default(),
auth_map,
wg_group_map,
}
}
}
impl AppCache {
pub fn get_context(&self, addr: &SocketAddr) -> Option<Context> {
if let Some((group, virtual_ip, _)) = self.addr_session.get(addr) {
let k = (group, virtual_ip);
self.ip_session.get(&k)?;
let (group, virtual_ip) = k;
return self
.virtual_network
.get(&group)
.map(|network_info| Context {
network_info,
group,
virtual_ip,
});
}
None
}
pub async fn insert_cipher_session(&self, key: SocketAddr, value: Aes256GcmCipher) {
self.cipher_session
.insert(key, Arc::new(value), Duration::from_secs(120))
.await
self.cipher_session.insert(key, Arc::new(value));
}
pub async fn insert_ip_session(&self, key: (String, u32), value: SocketAddr) {
self.ip_session
.insert(key, value, Duration::from_secs(24 * 3600))
.await
}
pub async fn insert_addr_session(&self, key: SocketAddr, value: (String, u32, i64)) {
self.addr_session
.insert(key, value, Duration::from_secs(20))
.await
}
}

View File

@@ -26,7 +26,7 @@ struct Value<V> {
impl<K, V> ExpireMap<K, V> {
pub fn new<F>(call: F) -> ExpireMap<K, V>
where
F: Fn(K, V) + Send + 'static,
F: Fn(&K, &V) -> Option<Duration> + Send + 'static,
K: Eq + Hash + Clone + Sync + Send + 'static,
V: Clone + Sync + Send + 'static,
{
@@ -66,6 +66,14 @@ where
.await
.unwrap();
}
/// remove出去的不会执行过期回调
pub fn remove(&self, k: &K) -> Option<V> {
if let Some(v) = self.base.write().remove(k) {
Some(v.val)
} else {
None
}
}
pub fn get(&self, k: &K) -> Option<V> {
if let Some(v) = self.base.read().get(k) {
// 延长过期时间
@@ -78,7 +86,10 @@ where
pub fn get_val(&self, k: &K) -> Option<V> {
self.base.read().get(k).map(|v| v.val.clone())
}
fn expire_call(&self, k: &K) -> Op<K, V> {
fn expire_call<F>(&self, k: &K, f: &F) -> Op<K, V>
where
F: Fn(&K, &V) -> Option<Duration>,
{
let mut write_guard = self.base.write();
if let Some(v) = write_guard.get(k) {
let now = Instant::now();
@@ -87,10 +98,14 @@ where
// 过期时间更新了
return Op::Reset(instant);
} else {
//删除key
if let Some((k, v)) = write_guard.remove_entry(k) {
//执行回调
return Op::Remove(k, v.val);
//执行过期回调
if let Some(v) = f(k, &v.val) {
return Op::Reset(now.add(v));
} else {
//删除key
if let Some((k, v)) = write_guard.remove_entry(k) {
return Op::Remove(k, v.val);
}
}
}
}
@@ -142,7 +157,7 @@ async fn expire_task<K, V, F>(mut receiver: Receiver<DelayedTask<K>>, map: Expir
where
K: Eq + Hash + Clone,
V: Clone,
F: Fn(K, V),
F: Fn(&K, &V) -> Option<Duration>,
{
let mut binary_heap = BinaryHeap::<DelayedTask<K>>::with_capacity(32);
loop {
@@ -165,17 +180,13 @@ where
}
} else if let Some(mut task) = binary_heap.pop() {
//执行过期逻辑
match map.expire_call(&task.k) {
match map.expire_call(&task.k, &f) {
Op::Reset(time) => {
//没有过期,重新加入监听
task.time = time;
binary_heap.push(task);
}
Op::Remove(k, v) => {
//执行回调
f(k, v)
}
Op::Remove(_, _) => {}
Op::None => {}
}
}

View File

@@ -1,18 +1,9 @@
#![allow(dead_code, clippy::enum_variant_names)]
use std::io;
use crossbeam::channel::RecvError;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Io error")]
Io(#[from] io::Error),
#[error("Channel error")]
Channel(#[from] RecvError),
#[error("Protobuf error")]
Protobuf(#[from] protobuf::Error),
#[error("Disconnect")]
Disconnect,
#[error("No Key")]
@@ -29,4 +20,4 @@ pub enum Error {
Other(String),
}
pub type Result<T> = std::result::Result<T, Error>;
pub type Result<T> = anyhow::Result<T>;

View File

@@ -1,12 +1,16 @@
use aes_gcm::aead::rand_core::RngCore;
use anyhow::{anyhow, Context};
use base64::engine::general_purpose;
use base64::Engine;
use boringtun::x25519::{PublicKey, StaticSecret};
use clap::Parser;
use std::collections::HashSet;
use std::fmt::Display;
use std::fmt::{Debug, Display, Formatter};
use std::io;
use std::io::Write;
use std::net::Ipv4Addr;
use std::path::PathBuf;
use clap::Parser;
use crate::cipher::RsaCipher;
mod cipher;
@@ -15,6 +19,7 @@ mod error;
mod generated_serial_number;
mod proto;
mod protocol;
pub const VNT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// 默认网关信息
@@ -56,9 +61,12 @@ pub struct StartArgs {
/// web后台用户密码默认为admin
#[arg(short = 'W', long)]
password: Option<String>,
/// wg私钥使用base64编码
#[arg(long = "wg")]
wg_secret_key: Option<String>,
}
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct ConfigInfo {
pub port: u16,
pub white_token: Option<HashSet<String>>,
@@ -70,6 +78,28 @@ pub struct ConfigInfo {
pub username: String,
#[cfg(feature = "web")]
pub password: String,
pub wg_secret_key: StaticSecret,
pub wg_public_key: PublicKey,
}
impl Debug for ConfigInfo {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConfigInfo")
.field("port", &self.port)
.field("white_token", &self.white_token)
.field("gateway", &self.gateway)
.field("broadcast", &self.broadcast)
.field("netmask", &self.netmask)
.field("check_finger", &self.check_finger)
.field(
"wg_secret_key",
&general_purpose::STANDARD.encode(&self.wg_secret_key),
)
.field(
"wg_public_key",
&general_purpose::STANDARD.encode(&self.wg_public_key),
)
.finish()
}
}
fn log_init(root_path: PathBuf, log_path: Option<String>) {
@@ -231,6 +261,22 @@ async fn main() {
if check_finger {
println!("转发校验数据指纹,客户端必须增加--finger参数");
}
let wg_secret_key: [u8; 32] = if let Some(wg_secret_key) = args.wg_secret_key {
let wg_secret_key = general_purpose::STANDARD
.decode(wg_secret_key)
.context("wg私钥错误")
.unwrap();
wg_secret_key
.try_into()
.map_err(|_| anyhow!("wg私钥错误"))
.unwrap()
} else {
let mut wg_secret_key = [0u8; 32];
rand::thread_rng().fill_bytes(&mut wg_secret_key);
wg_secret_key
};
let wg_secret_key = boringtun::x25519::StaticSecret::from(wg_secret_key);
let wg_public_key = boringtun::x25519::PublicKey::from(&wg_secret_key);
let config = ConfigInfo {
port,
white_token,
@@ -242,6 +288,8 @@ async fn main() {
username: args.username.unwrap_or_else(|| "admin".into()),
#[cfg(feature = "web")]
password: args.password.unwrap_or_else(|| "admin".into()),
wg_secret_key,
wg_public_key,
};
let rsa = match RsaCipher::new(root_path) {
Ok(rsa) => {

View File

@@ -6,6 +6,7 @@ use std::net::Ipv4Addr;
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum Protocol {
Ipv4,
WGIpv4,
Ipv4Broadcast,
Unknown(u8),
}
@@ -14,6 +15,7 @@ impl From<u8> for Protocol {
fn from(value: u8) -> Self {
match value {
4 => Protocol::Ipv4,
5 => Protocol::WGIpv4,
201 => Protocol::Ipv4Broadcast,
val => Protocol::Unknown(val),
}
@@ -24,6 +26,7 @@ impl From<Protocol> for u8 {
fn from(val: Protocol) -> Self {
match val {
Protocol::Ipv4 => 4,
Protocol::WGIpv4 => 5,
Protocol::Ipv4Broadcast => 201,
Protocol::Unknown(val) => val,
}

View File

@@ -0,0 +1,174 @@
.option-cell {
display: flex;
gap: 10px; /* 间隔 */
}
.option-cell button {
padding: 5px 10px;
font-size: 14px;
border: none;
border-radius: 4px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.option-cell button.delete-button {
background-color: #f44336; /* 红色 */
color: white;
}
.option-cell button.delete-button:hover {
background-color: #d32f2f; /* 深红色 */
}
.option-cell button.view-button {
background-color: #4CAF50; /* 绿色 */
color: white;
}
.option-cell button.view-button:hover {
background-color: #388E3C; /* 深绿色 */
}
/* wg弹窗样式 */
.modal {
display: none; /* 默认隐藏 */
position: fixed;
z-index: 1;
left: 0;
top: 0;
width: 100%;
height: 100%;
overflow: auto;
background-color: rgb(0,0,0);
background-color: rgba(0,0,0,0.4);
padding-top: 60px;
}
.modal-content {
position: relative;
background-color: #fefefe;
width: 380px;
height: 380px;
margin: 5% auto;
padding: 50px;
border: 1px solid #888;
text-align: center;
box-sizing: border-box;
border-radius: 5px;
}
.add-modal-content{
position: relative;
background-color: #fefefe;
margin: 5% auto;
padding: 20px;
border: 1px solid #888;
width: 80%;
max-width: 600px;
box-sizing: border-box;
border-radius: 5px;
}
.form-group {
display: flex;
align-items: center;
margin: 10px 0;
}
.form-group label {
flex: 1;
margin-right: 10px;
}
.form-group input {
flex: 2;
padding: 10px;
box-sizing: border-box;
}
.modal button {
padding: 10px 20px;
margin: 10px 5px;
}
.button-container {
text-align: center;
}
.button-container button {
padding: 10px 20px;
margin: 10px 5px;
border: none;
border-radius: 5px;
cursor: pointer;
font-size: 16px;
transition: background-color 0.3s, box-shadow 0.3s;
}
.error {
color: red;
font-size: 14px;
}
#confirmButton {
background-color: #4CAF50; /* 绿色背景 */
color: white;
}
#confirmButton:hover {
background-color: #45a049; /* 鼠标悬停时变暗 */
box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */
}
#cancelButton {
background-color: #f44336; /* 红色背景 */
color: white;
}
#cancelButton:hover {
background-color: #e53935; /* 鼠标悬停时变暗 */
box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */
}
.modal-content .title{
position: absolute;
left: 20px;
top: 10px;
}
.close {
position: absolute;
right: 20px;
top: 10px;
color: #aaa;
font-size: 28px;
font-weight: bold;
}
.close:hover,
.close:focus {
color: black;
text-decoration: none;
cursor: pointer;
}
.hidden {
display: none;
}
.visible {
display: block;
}
#qrcode {
margin: 0 auto; /* 居中 */
width: 260px; /* 固定宽度 */
height: 260px; /* 固定高度 */
}
pre {
text-align: left; /* 左对齐 */
white-space: pre-wrap; /* 自动换行 */
word-break: break-all;
user-select: auto;
}
#toggleButton{
position: absolute;
bottom: 10px;
left: 116px;
}

View File

@@ -123,7 +123,27 @@ body {
padding-top: 20px;
margin-left: 300px;
}
button{
padding: 10px 20px;
margin: 10px 5px;
border: none;
border-radius: 5px;
cursor: pointer;
font-size: 16px;
transition: background-color 0.3s, box-shadow 0.3s;
}
#addWireGuard {
background-color: #4CAF50; /* 绿色背景 */
color: white;
position: absolute;
right: 50px;
top:20px;
}
#addWireGuard:hover {
background-color: #45a049; /* 鼠标悬停时变暗 */
box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */
}
/* 下拉菜单按钮 */

View File

@@ -4,9 +4,11 @@
<meta charset="UTF-8">
<script src="./js/jquery-3.7.1.min.js"></script>
<script src="./js/g6.min.js"></script>
<script src="./js/qrcode.min.js"></script>
<script src="./js/api-post.js"></script>
<link rel="stylesheet" href="./css/select.css">
<link rel="stylesheet" href="./css/index.css">
<title>vnts-web</title>
<style>
@@ -91,7 +93,7 @@
<div class="topBox">
<div class="select-content">
<!-- <label>组网标识:</label>-->
<input type="hidden" name="newMachineId">
<input type="hidden" name="groupId">
<input type="text" name="select_input" id="select_input" class="select-input" value="" autocomplete="off"
placeholder="Search..."/>
<div id="search_select" class="search-select">
@@ -101,6 +103,7 @@
</div>
<span class="group_len"></span>
</div>
<button id="addWireGuard">接入WireGuard客户端</button>
</div>
<div id="container_content">
<div id="group_info">
@@ -125,6 +128,7 @@
<th>连接时间</th>
<th>链接地址</th>
<th>设备 ID</th>
<th>操作</th>
</tr>
</thead>
<tbody>
@@ -134,7 +138,51 @@
</div>
</div>
<!-- wg信息弹窗 -->
<div id="wgConfigModal" class="modal">
<div class="modal-content">
<span class="title">使用WireGuard客户端接入</span>
<span class="close">&times;</span>
<div id="qrcode" class="visible"></div>
<pre id="textConfig" class="hidden"></pre>
<button id="toggleButton">显示文本配置</button>
</div>
</div>
<!-- 添加wg弹窗 -->
<div id="addModal" class="modal">
<div class="add-modal-content">
<span class="close">&times;</span>
<h2>WireGuard配置</h2>
<div class="form-group">
<label for="groupId">组网编号</label>
<input type="text" id="groupId" placeholder="组网编号">
</div>
<div class="form-group">
<label for="virtualIP">虚拟IP</label>
<input type="text" id="virtualIP" placeholder="为空则自动分配">
</div>
<div class="form-group">
<label for="deviceName">设备名称</label>
<input type="text" id="deviceName" placeholder="设备名称">
</div>
<div class="form-group">
<label for="privateKey">PrivateKey</label>
<input type="text" id="privateKey" placeholder="PrivateKey">
</div>
<div class="form-group">
<label for="endpoint">Endpoint</label>
<input type="text" id="endpoint" placeholder="服务器UDP链接地址">
</div>
<div class="form-group">
<label for="persistentKeepalive">PersistentKeepalive</label>
<input type="number" id="persistentKeepalive" value="10" placeholder="PersistentKeepalive">
</div>
<div class="button-container">
<button id="confirmButton">确认</button>
</div>
<div class="error" id="addWGError"></div>
</div>
</div>
</body>
<script src="js/group-node.js"></script>
@@ -158,7 +206,7 @@
}
var options = '';
for (var i = 0; i < listArr.length; i++) {
opt = '<li class="li-select" data-newMachineId="' + listArr[i] + '">' + listArr[i] + '</li>';
opt = '<li class="li-select" data-groupId="' + listArr[i] + '">' + listArr[i] + '</li>';
options += opt;
}
if (options == '') {
@@ -196,12 +244,11 @@
$('#select_ul').delegate('.li-select', 'click', function () {
$('#select_ul .li-select').removeClass('li-hover');
var selectText = $(this).html();
var newMachineIdVal = $($(this)[0]).attr("data-newMachineId");
var groupIdVal = $($(this)[0]).attr("data-groupId");
$('#select_input').val(selectText);
$('#search_select').hide();
$("input[name='newMachineId']").val(newMachineIdVal);
console.log(newMachineIdVal);
getGroupInfoFunc(newMachineIdVal)
$("input[name='groupId']").val(groupIdVal);
getGroupInfoFunc(groupIdVal)
});
$('#select_ul').delegate('.li-select', 'mouseover', function () {
@@ -215,8 +262,10 @@
$('.group_len').html('(' + tempArr.length + ')')
}
function displayDeviceInfo(devices) {
function displayDeviceInfo(groupId, data) {
let devices = data.clients;
const tableBody = document.getElementById('deviceTable').getElementsByTagName('tbody')[0];
tableBody.innerHTML = "";
devices.forEach(device => {
const row = document.createElement('tr');
@@ -247,7 +296,7 @@
const lastJoinTimeCell = document.createElement('td');
lastJoinTimeCell.textContent = device.last_join_time;
row.appendChild(lastJoinTimeCell);
const addressCell = document.createElement('td');
addressCell.textContent = device.address;
row.appendChild(addressCell);
@@ -255,12 +304,132 @@
const deviceIdCell = document.createElement('td');
deviceIdCell.textContent = device.device_id;
row.appendChild(deviceIdCell);
// 操作栏
const optionCell = document.createElement('td');
optionCell.className = 'option-cell';
const deleteButton = document.createElement('button');
optionCell.appendChild(deleteButton);
deleteButton.className = 'delete-button';
deleteButton.textContent = '删除';
deleteButton.onclick = function () {
const confirmed = window.confirm('你确定要删除这条记录吗?');
if (confirmed) {
postRemoveClient({'group_id': groupId, 'virtual_ip': device.virtual_ip}, function () {
location.reload();
});
row.remove();
}
};
let wg_config = device.wg_config;
if (wg_config) {
function generateWireguardConfig(privateKey, ip, prefix, publicKey, allowedIPs, endpoint, persistentKeepalive) {
return `[Interface]
PrivateKey = ${privateKey}
Address = ${ip}/${prefix}
[Peer]
PublicKey = ${publicKey}
AllowedIPs = ${allowedIPs}
Endpoint = ${endpoint}
PersistentKeepalive = ${persistentKeepalive}`;
}
const wireguardConfig = generateWireguardConfig(wg_config.secret_key, wg_config.ip, wg_config.prefix,
data.vnts_public_key, wg_config.vnts_allowed_ips, wg_config.vnts_endpoint, wg_config.persistent_keepalive);
console.log(wireguardConfig);
const viewButton = document.createElement('button');
viewButton.textContent = '接入';
viewButton.className = 'view-button';
viewButton.onclick = function () {
// 显示弹窗
$('#wgConfigModal').show();
// 设置文本配置
$('#textConfig').text(wireguardConfig);
$('#qrcode').html('');
// 生成二维码
new QRCode(document.getElementById("qrcode"), {
correctLevel: 3,
text: wireguardConfig,
width: 256,
height: 256
});
};
optionCell.appendChild(viewButton);
}
row.appendChild(optionCell);
tableBody.appendChild(row);
});
}
</script>
<script>
let groupId;
let deviceId = 'wg' + new Date().getTime();
$('#addWireGuard').on('click', function () {
$('#groupId').val(groupId);
$('#deviceName').val('WireGuard');
postWgPrivateKey({}, function (data) {
$('#privateKey').val(data.data);
$('.error').text('');
$('#addModal').show();
})
});
$('#confirmButton').on('click', function () {
// 清除以前的错误信息
$('.error').text('');
// 获取输入值
const groupId = $('#groupId').val().trim();
const virtualIP = $('#virtualIP').val().trim();
const deviceName = $('#deviceName').val().trim();
const privateKey = $('#privateKey').val().trim();
const endpoint = $('#endpoint').val().trim();
const persistentKeepalive = $('#persistentKeepalive').val().trim();
// 校验输入值
if (!groupId) {
$('#addWGError').text('组网编号不能为空');
return
}
if (!deviceName) {
$('#addWGError').text('设备名称不能为空');
return;
}
if (!privateKey) {
$('#addWGError').text('PrivateKey不能为空');
return;
}
if (!endpoint) {
$('#addWGError').text('Endpoint不能为空');
return;
}
if (persistentKeepalive === '') {
$('#addWGError').text('PersistentKeepalive不能为空');
return;
} else if (isNaN(persistentKeepalive) || persistentKeepalive < 0 || persistentKeepalive > 65535) {
$('#addWGError').text('PersistentKeepalive必须是0~65535');
return;
}
postCreateWG({
'group_id': groupId, 'virtual_ip': virtualIP, 'device_id': deviceId, 'name': deviceName,
'config': {
'vnts_endpoint': endpoint,
'private_key': privateKey,
'persistent_keepalive': parseInt(persistentKeepalive),
}
}, function (data) {
if (data && data.code === 200) {
location.reload();
$('#addModal').hide(); // 关闭模态框
} else {
$('#addWGError').text('添加失败:' + data.message);
}
})
});
function formatBytes(bytes) {
if (!bytes) {
return ''
@@ -361,6 +530,7 @@
graph.render();
}
let getGroupInfoFunc = function (group) {
groupId = group;
postGroupInfo({'group': group}, function (response) {
if (response && response.code === 200) {
let data = response.data;
@@ -368,7 +538,7 @@
$('.mask_ip').html(data.mask_ip);
$('.network_ip').html(data.network_ip);
nodeInitFunc(data.gateway_ip, data.clients);
displayDeviceInfo(data.clients);
displayDeviceInfo(group, data);
} else {
window.alert("调用服务失败")
}
@@ -380,6 +550,20 @@
console.log(group_list);
searchInput(group_list);
});
$('.close').on('click', function () {
$('#wgConfigModal').hide();
$('#addModal').hide();
});
$('#toggleButton').on('click', function () {
if ($('#qrcode').hasClass('visible')) {
$('#qrcode').removeClass('visible').addClass('hidden');
$('#textConfig').removeClass('hidden').addClass('visible');
$('#toggleButton').text('显示二维码');
} else {
$('#qrcode').removeClass('hidden').addClass('visible');
$('#textConfig').removeClass('visible').addClass('hidden');
$('#toggleButton').text('显示文本配置');
}
});
</script>
</html>

View File

@@ -64,13 +64,25 @@ function post(url, data, success, error,) {
}
function postLogin(requestData, success, error) {
post("login", requestData, success, error)
post("api/login", requestData, success, error)
}
function postGroupList(requestData, success, error) {
post("group_list", requestData, success, error)
post("api/group_list", requestData, success, error)
}
function postGroupInfo(requestData, success, error) {
post("group_info", requestData, success, error)
post("api/group_info", requestData, success, error)
}
function postWgPrivateKey(requestData, success, error) {
post("api/private_key", requestData, success, error)
}
function postCreateWG(requestData, success, error) {
post("api/create_wg_config", requestData, success, error)
}
function postRemoveClient(requestData, success, error) {
post("api/remove_client", requestData, success, error)
}

1
static/js/qrcode.min.js vendored Normal file

File diff suppressed because one or more lines are too long