diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 89def52..9e816e9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -38,36 +38,53 @@ jobs: include: - TARGET: i686-unknown-linux-musl # test in an alpine container on a mac OS: ubuntu-latest + FEATURES: normal - TARGET: x86_64-unknown-linux-gnu # tested in a debian container on a mac OS: ubuntu-latest + FEATURES: ring-cipher - TARGET: x86_64-unknown-linux-musl # test in an alpine container on a mac OS: ubuntu-latest + FEATURES: ring-cipher - TARGET: aarch64-unknown-linux-gnu # tested on aws t4g.nano OS: ubuntu-latest + FEATURES: ring-cipher - TARGET: aarch64-unknown-linux-musl # tested on aws t4g.nano in alpine container OS: ubuntu-latest + FEATURES: normal - TARGET: armv7-unknown-linux-gnueabihf # raspberry pi 2-3-4, not tested OS: ubuntu-latest + FEATURES: ring-cipher - TARGET: armv7-unknown-linux-musleabihf # raspberry pi 2-3-4, not tested OS: ubuntu-latest + FEATURES: normal - TARGET: arm-unknown-linux-gnueabihf # raspberry pi 0-1, not tested OS: ubuntu-latest + FEATURES: ring-cipher - TARGET: arm-unknown-linux-musleabihf # raspberry pi 0-1, not tested OS: ubuntu-latest + FEATURES: normal - TARGET: x86_64-apple-darwin # tested on a mac, is not properly signed so there are security warnings OS: macos-latest + FEATURES: ring-cipher - TARGET: aarch64-apple-darwin # tested on a mac, is not properly signed so there are security warnings OS: macos-latest + FEATURES: ring-cipher - TARGET: i686-pc-windows-msvc # tested on a windows machine OS: windows-latest + FEATURES: ring-cipher - TARGET: x86_64-pc-windows-msvc # tested on a windows machine OS: windows-latest + FEATURES: ring-cipher + - TARGET: mipsel-unknown-linux-musl # openwrt + OS: ubuntu-latest + FEATURES: normal # needs: test runs-on: ${{ matrix.OS }} env: NAME: vnts # change with the name of your project TARGET: ${{ matrix.TARGET }} OS: ${{ matrix.OS }} + FEATURES: ${{ matrix.FEATURES }} steps: - uses: actions/checkout@v2 - name: Cargo cache @@ -84,28 +101,51 @@ jobs: # dependencies are only needed on ubuntu as that's the only place where # we make cross-compilation if [[ $OS =~ ^ubuntu.*$ ]]; then - sudo apt-get update && sudo apt-get install -qq crossbuild-essential-arm64 crossbuild-essential-armhf + sudo apt-get update && sudo apt-get install -qq crossbuild-essential-arm64 crossbuild-essential-armhf musl-tools gcc-mipsel-linux-gnu fi # some additional configuration for cross-compilation on linux cat >>~/.cargo/config < 子网掩码,例如 --netmask 255.255.255.0 -h, --help Print help ``` -默认情况服务日志输出在 './log/'下,可通过编写'./log/log4rs.yaml'文件自定义日志配置,参考[log4rs](https://github.com/estk/log4rs) + +## 说明 +1. 修改服务端密钥后,客户端要重启才能正常链接(修改密钥后无法自动重连) +2. 服务端密钥用于加密客户端和服务端之间传输的数据(使用rsa+aes256gcm加密),可以防止token被中间人窃取,如果客户端显示的密钥指纹和服务端的不一致,则表示可能有中间人攻击 +3. 服务端密钥在'./key/'目录下,可以替换成自定义的密钥对 +4. 客户端的密码用于加密客户端之间传输的数据 +5. 默认情况服务日志输出在 './log/'下,可通过编写'./log/log4rs.yaml'文件自定义日志配置,参考[log4rs](https://github.com/estk/log4rs) diff --git a/packet/src/arp/arp.rs b/packet/src/arp/arp.rs index 017e5e5..ea027a3 100644 --- a/packet/src/arp/arp.rs +++ b/packet/src/arp/arp.rs @@ -3,12 +3,12 @@ use std::{fmt, io}; /// 地址解析协议,由IP地址找到MAC地址 /// https://www.ietf.org/rfc/rfc6747.txt /* - 0 2 4 5 6 8 10 (字节) - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 硬件类型|协议类型|硬件地址长度|协议地址长度|操作类型| - | 源MAC地址 | 源ip地址 | - | 目的MAC地址 | 目的ip地址 | - */ + 0 2 4 5 6 8 10 (字节) + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 硬件类型|协议类型|硬件地址长度|协议地址长度|操作类型| + | 源MAC地址 | 源ip地址 | + | 目的MAC地址 | 目的ip地址 | +*/ pub struct ArpPacket { buffer: B, @@ -119,4 +119,4 @@ impl> fmt::Debug for ArpPacket { .field("target_protocol_addr", &self.target_protocol_addr()) .finish() } -} \ No newline at end of file +} diff --git a/packet/src/arp/mod.rs b/packet/src/arp/mod.rs index 2dc8ecf..6a5d36a 100644 --- a/packet/src/arp/mod.rs +++ b/packet/src/arp/mod.rs @@ -1 +1 @@ -pub mod arp; \ No newline at end of file +pub mod arp; diff --git a/packet/src/ethernet/mod.rs b/packet/src/ethernet/mod.rs index d0f3a64..c9cf115 100644 --- a/packet/src/ethernet/mod.rs +++ b/packet/src/ethernet/mod.rs @@ -1,2 +1,2 @@ pub mod packet; -pub mod protocol; \ No newline at end of file +pub mod protocol; diff --git a/packet/src/ethernet/packet.rs b/packet/src/ethernet/packet.rs index 8a2a918..f54619e 100644 --- a/packet/src/ethernet/packet.rs +++ b/packet/src/ethernet/packet.rs @@ -1,13 +1,13 @@ -use std::{fmt, io}; use crate::ethernet::protocol::Protocol; +use std::{fmt, io}; /// 以太网帧协议 /// https://www.ietf.org/rfc/rfc894.txt /* - 0 6 12 14 (字节) - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 目的地址 | 源地址 | 类型 | - */ + 0 6 12 14 (字节) + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 目的地址 | 源地址 | 类型 | +*/ pub struct EthernetPacket { pub buffer: B, } @@ -74,4 +74,4 @@ impl> fmt::Debug for EthernetPacket { .field("payload", &self.payload()) .finish() } -} \ No newline at end of file +} diff --git a/packet/src/ethernet/protocol.rs b/packet/src/ethernet/protocol.rs index 6a1df95..e775a28 100644 --- a/packet/src/ethernet/protocol.rs +++ b/packet/src/ethernet/protocol.rs @@ -102,7 +102,7 @@ impl From for Protocol { 0x88f7 => Ptp, 0x8902 => Cfm, 0x9100 => QinQ, - n => Unknown(n), + n => Unknown(n), } } } @@ -112,30 +112,30 @@ impl Into for Protocol { use self::Protocol::*; match self { - Ipv4 => 0x0800, - Arp => 0x0806, - WakeOnLan => 0x0842, - Trill => 0x22f3, - DecNet => 0x6003, - Rarp => 0x8035, - AppleTalk => 0x809b, - Aarp => 0x80f3, - Ipx => 0x8137, - Qnx => 0x8204, - Ipv6 => 0x86dd, - FlowControl => 0x8808, - CobraNet => 0x8819, - Mpls => 0x8847, - MplsMulticast => 0x8848, + Ipv4 => 0x0800, + Arp => 0x0806, + WakeOnLan => 0x0842, + Trill => 0x22f3, + DecNet => 0x6003, + Rarp => 0x8035, + AppleTalk => 0x809b, + Aarp => 0x80f3, + Ipx => 0x8137, + Qnx => 0x8204, + Ipv6 => 0x86dd, + FlowControl => 0x8808, + CobraNet => 0x8819, + Mpls => 0x8847, + MplsMulticast => 0x8848, PppoeDiscovery => 0x8863, - PppoeSession => 0x8864, - Vlan => 0x8100, - PBridge => 0x88a8, - Lldp => 0x88cc, - Ptp => 0x88f7, - Cfm => 0x8902, - QinQ => 0x9100, - Unknown(n) => n, + PppoeSession => 0x8864, + Vlan => 0x8100, + PBridge => 0x88a8, + Lldp => 0x88cc, + Ptp => 0x88f7, + Cfm => 0x8902, + QinQ => 0x9100, + Unknown(n) => n, } } } diff --git a/packet/src/icmp/icmp.rs b/packet/src/icmp/icmp.rs index 5515d83..4ae4ab5 100644 --- a/packet/src/icmp/icmp.rs +++ b/packet/src/icmp/icmp.rs @@ -1,8 +1,8 @@ -use std::{fmt, io}; -use byteorder::{BigEndian, ReadBytesExt}; use crate::cal_checksum; use crate::icmp::{Code, Kind}; use crate::ip::ipv4::packet::IpV4Packet; +use byteorder::{BigEndian, ReadBytesExt}; +use std::{fmt, io}; /// icmp 协议 /* https://www.rfc-editor.org/rfc/rfc792 @@ -67,7 +67,7 @@ impl> IcmpPacket { | Kind::TimestampReply | Kind::InformationRequest | Kind::InformationReply => { - let ide =u16::from_be_bytes(self.buffer.as_ref()[4..6].try_into().unwrap()); + let ide = u16::from_be_bytes(self.buffer.as_ref()[4..6].try_into().unwrap()); let seq = u16::from_be_bytes(self.buffer.as_ref()[6..8].try_into().unwrap()); HeaderOther::Identifier(ide, seq) } @@ -121,11 +121,11 @@ impl> fmt::Debug for IcmpPacket { } else { "icmp::Packet!" }) - .field("kind", &self.kind()) - .field("code", &self.code()) - .field("checksum", &self.checksum()) - .field("payload", &self.payload()) - .finish() + .field("kind", &self.kind()) + .field("code", &self.code()) + .field("checksum", &self.checksum()) + .field("payload", &self.payload()) + .finish() } } diff --git a/packet/src/igmp/igmp_v1.rs b/packet/src/igmp/igmp_v1.rs index 10dcbde..58617f6 100644 --- a/packet/src/igmp/igmp_v1.rs +++ b/packet/src/igmp/igmp_v1.rs @@ -1,17 +1,17 @@ -use std::{fmt, io}; -use std::net::Ipv4Addr; use crate::cal_checksum; +use std::net::Ipv4Addr; +use std::{fmt, io}; /// igmp v1 /* https://datatracker.ietf.org/doc/html/rfc1112 - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |Version| Type | Unused | Checksum | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Group Address | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |Version| Type | Unused | Checksum | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Group Address | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ /// v1版本的报文 pub struct IgmpV1Packet { pub buffer: B, @@ -43,7 +43,7 @@ impl Into for IgmpV1Type { match self { IgmpV1Type::Query => 0x11, IgmpV1Type::ReportV1 => 0x12, - IgmpV1Type::Unknown(v) => v + IgmpV1Type::Unknown(v) => v, } } } @@ -114,4 +114,4 @@ impl> fmt::Debug for IgmpV1Packet { .field("group_address", &self.group_address()) .finish() } -} \ No newline at end of file +} diff --git a/packet/src/igmp/igmp_v2.rs b/packet/src/igmp/igmp_v2.rs index 297b5b8..13e65cd 100644 --- a/packet/src/igmp/igmp_v2.rs +++ b/packet/src/igmp/igmp_v2.rs @@ -1,18 +1,18 @@ -use std::{fmt, io}; -use std::net::Ipv4Addr; use crate::cal_checksum; +use std::net::Ipv4Addr; +use std::{fmt, io}; /// igmp v2 /* https://www.rfc-editor.org/rfc/rfc2236.html - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Type | Max Resp Time | Checksum | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Group Address | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type | Max Resp Time | Checksum | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Group Address | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ /// v2版本的报文 pub struct IgmpV2Packet { @@ -48,7 +48,7 @@ impl Into for IgmpV2Type { IgmpV2Type::Query => 0x11, IgmpV2Type::ReportV2 => 0x16, IgmpV2Type::LeaveV2 => 0x17, - IgmpV2Type::Unknown(v) => v + IgmpV2Type::Unknown(v) => v, } } } diff --git a/packet/src/igmp/igmp_v3.rs b/packet/src/igmp/igmp_v3.rs index f412d0e..0ec2df9 100644 --- a/packet/src/igmp/igmp_v3.rs +++ b/packet/src/igmp/igmp_v3.rs @@ -1,5 +1,5 @@ -use std::{fmt, io}; use std::net::Ipv4Addr; +use std::{fmt, io}; use crate::cal_checksum; @@ -116,7 +116,7 @@ impl Into for IgmpV3Type { match self { IgmpV3Type::Query => 0x11, IgmpV3Type::ReportV3 => 0x22, - IgmpV3Type::Unknown(v) => v + IgmpV3Type::Unknown(v) => v, } } } @@ -203,7 +203,7 @@ impl + AsMut<[u8]>> IgmpV3QueryPacket { self.buffer.as_mut()[2..4].copy_from_slice(&checksum.to_be_bytes()) } pub fn set_qrv(&mut self, qrv: u8) { - self.buffer.as_mut()[8] = (self.buffer.as_ref()[8]&(!0x07)) | (qrv & 0x07) + self.buffer.as_mut()[8] = (self.buffer.as_ref()[8] & (!0x07)) | (qrv & 0x07) } pub fn set_qqic(&mut self, qqic: u8) { self.buffer.as_mut()[9] = qqic @@ -349,7 +349,10 @@ impl> IgmpV3ReportPacket { return None; } if let Ok(record) = IgmpV3RecordPacket::new(&buf[start..]) { - let end = start + 8 + record.aux_data_len() as usize * 4 + record.source_number() as usize * 4; + let end = start + + 8 + + record.aux_data_len() as usize * 4 + + record.source_number() as usize * 4; if end > len { return None; } @@ -364,7 +367,6 @@ impl> IgmpV3ReportPacket { } } - /// group record pub struct IgmpV3RecordPacket { pub buffer: B, @@ -488,4 +490,4 @@ impl> fmt::Debug for IgmpV3RecordPacket { .field("auxiliary_data", &self.auxiliary_data()) .finish() } -} \ No newline at end of file +} diff --git a/packet/src/igmp/mod.rs b/packet/src/igmp/mod.rs index 79d04e1..113ab19 100644 --- a/packet/src/igmp/mod.rs +++ b/packet/src/igmp/mod.rs @@ -2,7 +2,7 @@ pub mod igmp_v1; pub mod igmp_v2; pub mod igmp_v3; -#[derive(Debug,Copy, Clone,Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum IgmpType { /// 0x11 所有组224.0.0.1或者特定组 Query, @@ -40,7 +40,7 @@ impl Into for IgmpType { IgmpType::ReportV2 => 0x16, IgmpType::ReportV3 => 0x22, IgmpType::LeaveV2 => 0x17, - IgmpType::Unknown(v) => v + IgmpType::Unknown(v) => v, } } -} \ No newline at end of file +} diff --git a/packet/src/ip/ipv4/packet.rs b/packet/src/ip/ipv4/packet.rs index 1308f71..ab1ca74 100644 --- a/packet/src/ip/ipv4/packet.rs +++ b/packet/src/ip/ipv4/packet.rs @@ -1,6 +1,5 @@ -use std::{fmt, io}; use std::net::Ipv4Addr; - +use std::{fmt, io}; use crate::cal_checksum; use crate::ip::ipv4::protocol::Protocol; diff --git a/packet/src/ip/ipv4/protocol.rs b/packet/src/ip/ipv4/protocol.rs index e4d7baa..957fa08 100644 --- a/packet/src/ip/ipv4/protocol.rs +++ b/packet/src/ip/ipv4/protocol.rs @@ -1,4 +1,4 @@ -#[derive(Eq, PartialEq,Ord, PartialOrd, Copy, Clone, Debug)] +#[derive(Eq, PartialEq, Ord, PartialOrd, Copy, Clone, Debug)] pub enum Protocol { /// Hopopt, diff --git a/packet/src/ip/mod.rs b/packet/src/ip/mod.rs index b82e44f..7908624 100644 --- a/packet/src/ip/mod.rs +++ b/packet/src/ip/mod.rs @@ -1,5 +1,5 @@ -use std::io; use ipv4::packet::IpV4Packet; +use std::io; pub mod ipv4; diff --git a/packet/src/lib.rs b/packet/src/lib.rs index 1604c79..829c739 100644 --- a/packet/src/lib.rs +++ b/packet/src/lib.rs @@ -3,13 +3,13 @@ use std::net::Ipv4Addr; use byteorder::BigEndian; use byteorder::ReadBytesExt; +pub mod arp; +pub mod ethernet; pub mod icmp; pub mod igmp; pub mod ip; pub mod tcp; pub mod udp; -pub mod ethernet; -pub mod arp; // pub enum IpUpperLayer { // UDP(UdpPacket), // Unknown(B), diff --git a/packet/src/tcp/tcp.rs b/packet/src/tcp/tcp.rs index 1be0235..38139a7 100644 --- a/packet/src/tcp/tcp.rs +++ b/packet/src/tcp/tcp.rs @@ -1,5 +1,5 @@ -use std::{fmt, io}; use std::net::Ipv4Addr; +use std::{fmt, io}; use crate::tcp::Flags; @@ -58,7 +58,11 @@ impl> TcpPacket { buffer, } } - pub fn new(source_ip: Ipv4Addr, destination_ip: Ipv4Addr, buffer: B) -> io::Result> { + pub fn new( + source_ip: Ipv4Addr, + destination_ip: Ipv4Addr, + buffer: B, + ) -> io::Result> { let packet = TcpPacket::unchecked(source_ip, destination_ip, buffer); if packet.buffer.as_ref().len() < 20 { diff --git a/packet/src/udp/udp.rs b/packet/src/udp/udp.rs index 9aedce1..3960b90 100644 --- a/packet/src/udp/udp.rs +++ b/packet/src/udp/udp.rs @@ -1,5 +1,5 @@ -use std::{fmt, io}; use std::net::Ipv4Addr; +use std::{fmt, io}; /// udp协议 /// @@ -60,7 +60,11 @@ impl> UdpPacket { buffer, } } - pub fn new(source_ip: Ipv4Addr, destination_ip: Ipv4Addr, buffer: B) -> io::Result> { + pub fn new( + source_ip: Ipv4Addr, + destination_ip: Ipv4Addr, + buffer: B, + ) -> io::Result> { if buffer.as_ref().len() < 8 { Err(io::Error::from(io::ErrorKind::InvalidData))?; } diff --git a/src/cipher/aes_gcm_cipher.rs b/src/cipher/aes_gcm_cipher.rs new file mode 100644 index 0000000..4417dff --- /dev/null +++ b/src/cipher/aes_gcm_cipher.rs @@ -0,0 +1,109 @@ +use std::io; + +use aes_gcm::aead::consts::{U12, U16}; +use aes_gcm::aead::generic_array::GenericArray; +use aes_gcm::{AeadInPlace, Aes256Gcm, Key, KeyInit, Nonce, Tag}; +use rand::RngCore; + +use crate::cipher::finger::Finger; +use crate::protocol::{body::SecretBody, body::ENCRYPTION_RESERVED, NetPacket}; + +#[derive(Clone)] +pub struct Aes256GcmCipher { + cipher: Aes256Gcm, + finger: Finger, +} + +impl Aes256GcmCipher { + pub fn new(key: [u8; 32], finger: Finger) -> Self { + let key: &Key = &key.into(); + Self { + cipher: Aes256Gcm::new(key), + finger, + } + } + + pub fn decrypt_ipv4 + AsMut<[u8]>>( + &self, + net_packet: &mut NetPacket, + ) -> io::Result<()> { + if !net_packet.is_encrypt() { + //未加密的数据直接丢弃 + return Err(io::Error::new(io::ErrorKind::Other, "not encrypt")); + } + if net_packet.payload().len() < ENCRYPTION_RESERVED { + log::error!("数据异常,长度小于{}", ENCRYPTION_RESERVED); + return Err(io::Error::new(io::ErrorKind::Other, "data err")); + } + let mut nonce_raw = [0; 12]; + nonce_raw[0..4].copy_from_slice(&net_packet.source().octets()); + nonce_raw[4..8].copy_from_slice(&net_packet.destination().octets()); + nonce_raw[8] = net_packet.protocol().into(); + nonce_raw[9] = net_packet.transport_protocol(); + nonce_raw[10] = net_packet.is_gateway() as u8; + nonce_raw[11] = net_packet.source_ttl(); + let nonce: &GenericArray = Nonce::from_slice(&nonce_raw); + + let mut secret_body = SecretBody::new(net_packet.payload_mut())?; + let tag = secret_body.tag(); + if tag.len() != 16 { + return Err(io::Error::new(io::ErrorKind::Other, "tag err")); + } + let finger = self.finger.calculate_finger(&nonce_raw, &secret_body); + if &finger != secret_body.finger() { + return Err(io::Error::new(io::ErrorKind::Other, "finger err")); + } + + let tag: GenericArray = Tag::clone_from_slice(tag); + if let Err(e) = + self.cipher + .decrypt_in_place_detached(nonce, &[], secret_body.body_mut(), &tag) + { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("解密失败:{}", e), + )); + } + net_packet.set_encrypt_flag(false); + net_packet.set_data_len(net_packet.data_len() - ENCRYPTION_RESERVED)?; + return Ok(()); + } + /// net_packet 必须预留足够长度 + /// data_len是有效载荷的长度 + /// 返回加密后载荷的长度 + pub fn encrypt_ipv4 + AsMut<[u8]>>( + &self, + net_packet: &mut NetPacket, + ) -> io::Result<()> { + if net_packet.reserve() < ENCRYPTION_RESERVED { + return Err(io::Error::new(io::ErrorKind::Other, "too short")); + } + let mut nonce_raw = [0; 12]; + nonce_raw[0..4].copy_from_slice(&net_packet.source().octets()); + nonce_raw[4..8].copy_from_slice(&net_packet.destination().octets()); + nonce_raw[8] = net_packet.protocol().into(); + nonce_raw[9] = net_packet.transport_protocol(); + nonce_raw[10] = net_packet.is_gateway() as u8; + nonce_raw[11] = net_packet.source_ttl(); + let nonce: &GenericArray = Nonce::from_slice(&nonce_raw); + net_packet.set_data_len(net_packet.data_len() + ENCRYPTION_RESERVED)?; + let mut secret_body = SecretBody::new(net_packet.payload_mut())?; + secret_body.set_random(rand::thread_rng().next_u32()); + return match self + .cipher + .encrypt_in_place_detached(nonce, &[], secret_body.body_mut()) + { + Ok(tag) => { + secret_body.set_tag(tag.as_slice())?; + let finger = self.finger.calculate_finger(&nonce_raw, &secret_body); + secret_body.set_finger(&finger)?; + net_packet.set_encrypt_flag(true); + Ok(()) + } + Err(e) => Err(io::Error::new( + io::ErrorKind::Other, + format!("加密失败:{}", e), + )), + }; + } +} diff --git a/src/cipher/finger.rs b/src/cipher/finger.rs new file mode 100644 index 0000000..8c30c3d --- /dev/null +++ b/src/cipher/finger.rs @@ -0,0 +1,52 @@ +use std::io; + +use sha2::Digest; + +use crate::protocol::{body::SecretBody, body::ENCRYPTION_RESERVED, NetPacket}; + +#[derive(Clone)] +pub struct Finger { + token: String, +} + +impl Finger { + pub fn new(token: String) -> Self { + Finger { token } + } + pub fn check_finger>(&self, net_packet: &NetPacket) -> io::Result<()> { + if !net_packet.is_encrypt() { + //未加密的数据直接丢弃 + return Err(io::Error::new(io::ErrorKind::Other, "not encrypt")); + } + if net_packet.payload().len() < ENCRYPTION_RESERVED { + log::error!("数据异常,长度小于{}", ENCRYPTION_RESERVED); + return Err(io::Error::new(io::ErrorKind::Other, "data err")); + } + let mut nonce_raw = [0; 12]; + nonce_raw[0..4].copy_from_slice(&net_packet.source().octets()); + nonce_raw[4..8].copy_from_slice(&net_packet.destination().octets()); + nonce_raw[8] = net_packet.protocol().into(); + nonce_raw[9] = net_packet.transport_protocol(); + nonce_raw[10] = net_packet.is_gateway() as u8; + nonce_raw[11] = net_packet.source_ttl(); + let secret_body = SecretBody::new(net_packet.payload())?; + let finger = self.calculate_finger(&nonce_raw, &secret_body); + if &finger != secret_body.finger() { + return Err(io::Error::new(io::ErrorKind::Other, "finger err")); + } + Ok(()) + } + pub fn calculate_finger>( + &self, + nonce_raw: &[u8; 12], + secret_body: &SecretBody, + ) -> [u8; 12] { + let mut hasher = sha2::Sha256::new(); + hasher.update(secret_body.body()); + hasher.update(nonce_raw); + hasher.update(secret_body.tag()); + hasher.update(&self.token); + let key: [u8; 32] = hasher.finalize().into(); + return key[20..].try_into().unwrap(); + } +} diff --git a/src/cipher/mod.rs b/src/cipher/mod.rs new file mode 100644 index 0000000..e4b6966 --- /dev/null +++ b/src/cipher/mod.rs @@ -0,0 +1,13 @@ +#[cfg(not(feature = "ring-cipher"))] +mod aes_gcm_cipher; +mod finger; +#[cfg(feature = "ring-cipher")] +mod ring_cipher; +mod rsa_cipher; + +#[cfg(not(feature = "ring-cipher"))] +pub use aes_gcm_cipher::Aes256GcmCipher; +pub use finger::Finger; +#[cfg(feature = "ring-cipher")] +pub use ring_cipher::Aes256GcmCipher; +pub use rsa_cipher::RsaCipher; diff --git a/src/cipher/ring_cipher.rs b/src/cipher/ring_cipher.rs new file mode 100644 index 0000000..87f71be --- /dev/null +++ b/src/cipher/ring_cipher.rs @@ -0,0 +1,143 @@ +use crate::cipher::Finger; +use rand::RngCore; +use ring::aead; +use ring::aead::{LessSafeKey, UnboundKey}; +use std::io; + +use crate::protocol::body::{SecretBody, ENCRYPTION_RESERVED}; +use crate::protocol::NetPacket; + +#[derive(Clone)] +pub struct Aes256GcmCipher { + pub(crate) cipher: AesGcmEnum, + pub(crate) finger: Finger, +} + +pub enum AesGcmEnum { + AesGCM128(LessSafeKey, [u8; 16]), + AesGCM256(LessSafeKey, [u8; 32]), +} + +impl Clone for AesGcmEnum { + fn clone(&self) -> Self { + match &self { + AesGcmEnum::AesGCM128(_, key) => { + let c = + LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, key.as_slice()).unwrap()); + AesGcmEnum::AesGCM128(c, *key) + } + AesGcmEnum::AesGCM256(_, key) => { + let c = + LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, key.as_slice()).unwrap()); + AesGcmEnum::AesGCM256(c, *key) + } + } + } +} + +impl Aes256GcmCipher { + pub fn new(key: [u8; 32], finger: Finger) -> Self { + let cipher = LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, &key).unwrap()); + Self { + cipher: AesGcmEnum::AesGCM256(cipher, key), + finger, + } + } + pub fn decrypt_ipv4 + AsMut<[u8]>>( + &self, + net_packet: &mut NetPacket, + ) -> io::Result<()> { + if !net_packet.is_encrypt() { + //未加密的数据直接丢弃 + return Err(io::Error::new(io::ErrorKind::Other, "not encrypt")); + } + if net_packet.payload().len() < ENCRYPTION_RESERVED { + log::error!("数据异常,长度小于{}", ENCRYPTION_RESERVED); + return Err(io::Error::new(io::ErrorKind::Other, "data err")); + } + let mut nonce_raw = [0; 12]; + nonce_raw[0..4].copy_from_slice(&net_packet.source().octets()); + nonce_raw[4..8].copy_from_slice(&net_packet.destination().octets()); + nonce_raw[8] = net_packet.protocol().into(); + nonce_raw[9] = net_packet.transport_protocol(); + nonce_raw[10] = net_packet.is_gateway() as u8; + nonce_raw[11] = net_packet.source_ttl(); + let nonce = aead::Nonce::assume_unique_for_key(nonce_raw); + let mut secret_body = SecretBody::new(net_packet.payload_mut())?; + let tag = secret_body.tag(); + if tag.len() != 16 { + return Err(io::Error::new(io::ErrorKind::Other, "tag err")); + } + let finger = self.finger.calculate_finger(&nonce_raw, &secret_body); + if &finger != secret_body.finger() { + return Err(io::Error::new(io::ErrorKind::Other, "finger err")); + } + + let rs = match &self.cipher { + AesGcmEnum::AesGCM128(cipher, _) => { + cipher.open_in_place(nonce, aead::Aad::empty(), secret_body.en_body_mut()) + } + AesGcmEnum::AesGCM256(cipher, _) => { + cipher.open_in_place(nonce, aead::Aad::empty(), secret_body.en_body_mut()) + } + }; + if let Err(e) = rs { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("解密失败:{}", e), + )); + } + net_packet.set_encrypt_flag(false); + net_packet.set_data_len(net_packet.data_len() - ENCRYPTION_RESERVED)?; + return Ok(()); + } + /// net_packet 必须预留足够长度 + /// data_len是有效载荷的长度 + /// 返回加密后载荷的长度 + pub fn encrypt_ipv4 + AsMut<[u8]>>( + &self, + net_packet: &mut NetPacket, + ) -> io::Result<()> { + let mut nonce_raw = [0; 12]; + nonce_raw[0..4].copy_from_slice(&net_packet.source().octets()); + nonce_raw[4..8].copy_from_slice(&net_packet.destination().octets()); + nonce_raw[8] = net_packet.protocol().into(); + nonce_raw[9] = net_packet.transport_protocol(); + nonce_raw[10] = net_packet.is_gateway() as u8; + nonce_raw[11] = net_packet.source_ttl(); + let nonce = aead::Nonce::assume_unique_for_key(nonce_raw); + let data_len = net_packet.data_len() + ENCRYPTION_RESERVED; + net_packet.set_data_len(data_len)?; + let mut secret_body = SecretBody::new(net_packet.payload_mut())?; + secret_body.set_random(rand::thread_rng().next_u32()); + + let rs = match &self.cipher { + AesGcmEnum::AesGCM128(cipher, _) => { + cipher.seal_in_place_separate_tag(nonce, aead::Aad::empty(), secret_body.body_mut()) + } + AesGcmEnum::AesGCM256(cipher, _) => { + cipher.seal_in_place_separate_tag(nonce, aead::Aad::empty(), secret_body.body_mut()) + } + }; + return match rs { + Ok(tag) => { + let tag = tag.as_ref(); + if tag.len() != 16 { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("加密tag长度错误:{}", tag.len()), + )); + } + secret_body.set_tag(tag)?; + let finger = self.finger.calculate_finger(&nonce_raw, &secret_body); + secret_body.set_finger(&finger)?; + net_packet.set_encrypt_flag(true); + Ok(()) + } + Err(e) => Err(io::Error::new( + io::ErrorKind::Other, + format!("加密失败:{}", e), + )), + }; + } +} diff --git a/src/cipher/rsa_cipher.rs b/src/cipher/rsa_cipher.rs new file mode 100644 index 0000000..4bf1fc6 --- /dev/null +++ b/src/cipher/rsa_cipher.rs @@ -0,0 +1,146 @@ +use std::io; +use std::path::PathBuf; +use std::sync::Arc; + +use crate::protocol::body::RsaSecretBody; +use crate::protocol::NetPacket; +use rsa::pkcs8::der::Decode; +use rsa::pkcs8::{DecodePrivateKey, EncodePrivateKey, EncodePublicKey, LineEnding}; +use rsa::{RsaPrivateKey, RsaPublicKey}; +use sha2::Digest; + +#[derive(Clone)] +pub struct RsaCipher { + inner: Arc, +} + +struct Inner { + private_key: RsaPrivateKey, + public_key_der: Vec, +} + +impl RsaCipher { + pub fn new() -> io::Result { + let priv_key_path = PathBuf::from("key/private_key.pem"); + let private_key = if priv_key_path.exists() { + let key = std::fs::read_to_string(priv_key_path)?; + let private_key = match RsaPrivateKey::from_pkcs8_pem(&key) { + Ok(private_key) => private_key, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("'key/private_key.pem' content error {}", e), + )); + } + }; + private_key + } else { + let path = PathBuf::from("key"); + if !path.exists() { + std::fs::create_dir(path)?; + } + let mut rng = rand::thread_rng(); + let bits = 2048; + let private_key = match RsaPrivateKey::new(&mut rng, bits) { + Ok(private_key) => private_key, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("failed to generate a key {}", e), + )); + } + }; + match private_key.write_pkcs8_pem_file(priv_key_path, LineEnding::CRLF) { + Ok(_) => {} + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("failed to write to file 'key/private_key.pem' {}", e), + )); + } + }; + private_key + }; + let public_key = RsaPublicKey::from(&private_key); + match public_key.write_public_key_pem_file("key/public_key.pem", LineEnding::CRLF) { + Ok(_) => {} + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("failed to write to file 'key/public_key.pem' {}", e), + )); + } + }; + let public_key_der = match public_key.to_public_key_der() { + Ok(public_key_der) => public_key_der.to_vec(), + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("to_public_key_der failed {}", e), + )); + } + }; + let inner = Inner { + private_key, + public_key_der, + }; + Ok(Self { + inner: Arc::new(inner), + }) + } + pub fn finger(&self) -> io::Result { + match rsa::pkcs8::SubjectPublicKeyInfo::from_der(&self.inner.public_key_der) { + Ok(spki) => match spki.fingerprint_base64() { + Ok(finger) => Ok(finger), + Err(e) => Err(io::Error::new( + io::ErrorKind::Other, + format!("fingerprint_base64 error {}", e), + )), + }, + Err(e) => Err(io::Error::new( + io::ErrorKind::Other, + format!("from_der error {}", e), + )), + } + } + + pub fn public_key(&self) -> &[u8] { + &self.inner.public_key_der + } +} + +impl RsaCipher { + pub fn decrypt>( + &self, + net_packet: &NetPacket, + ) -> io::Result>> { + match self + .inner + .private_key + .decrypt(rsa::PaddingScheme::PKCS1v15Encrypt, net_packet.payload()) + { + Ok(rs) => { + let mut nonce_raw = [0; 12]; + nonce_raw[0..4].copy_from_slice(&net_packet.source().octets()); + nonce_raw[4..8].copy_from_slice(&net_packet.destination().octets()); + nonce_raw[8] = net_packet.protocol().into(); + nonce_raw[9] = net_packet.transport_protocol(); + nonce_raw[10] = net_packet.is_gateway() as u8; + nonce_raw[11] = net_packet.source_ttl(); + let secret_body = RsaSecretBody::new(rs)?; + let mut hasher = sha2::Sha256::new(); + hasher.update(secret_body.body()); + hasher.update(nonce_raw); + let key: [u8; 32] = hasher.finalize().into(); + if secret_body.finger() != &key[16..] { + return Err(io::Error::new(io::ErrorKind::Other, "finger err")); + } + Ok(secret_body) + } + Err(e) => Err(io::Error::new( + io::ErrorKind::Other, + format!("decrypt failed {}", e), + )), + } + } +} diff --git a/src/main.rs b/src/main.rs index 75ebdfd..5f29c63 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,14 +4,16 @@ use std::net::Ipv4Addr; use std::path::PathBuf; use std::sync::Arc; +use crate::cipher::RsaCipher; +use crate::service::{start_tcp, start_udp}; use clap::Parser; use tokio::net::{TcpListener, UdpSocket}; -use crate::service::{start_tcp, start_udp}; -pub mod error; -pub mod proto; -pub mod protocol; -pub mod service; +mod cipher; +mod error; +mod proto; +mod protocol; +mod service; /// 默认网关信息 const GATEWAY: Ipv4Addr = Ipv4Addr::new(10, 26, 0, 1); @@ -52,7 +54,8 @@ fn log_init() { let log_config = log_path.join("log4rs.yaml"); if !log_config.exists() { if let Ok(mut f) = std::fs::File::create(&log_config) { - let _ = f.write_all(b"refresh_rate: 30 seconds + let _ = f.write_all( + b"refresh_rate: 30 seconds appenders: rolling_file: kind: rolling_file @@ -74,7 +77,8 @@ appenders: root: level: info appenders: - - rolling_file"); + - rolling_file", + ); } } let _ = log4rs::init_file(log_config, Default::default()); @@ -82,49 +86,76 @@ root: #[tokio::main] async fn main() { + log_init(); let args = StartArgs::parse(); - let port = args.port.unwrap_or(29871); - println!("端口:{}", port); + let port = args.port.unwrap_or(29872); + println!("端口: {}", port); let white_token = if let Some(white_token) = args.white_token { Some(HashSet::from_iter(white_token.into_iter())) } else { None }; - println!("token白名单:{:?}", white_token); + println!("token白名单: {:?}", white_token); let gateway = if let Some(gateway) = args.gateway { - gateway.parse::().expect("网关错误,必须为有效的ipv4地址") + match gateway.parse::() { + Ok(ip) => ip, + Err(e) => { + log::error!("网关错误,必须为有效的ipv4地址 gateway={},e={}", gateway, e); + panic!("网关错误,必须为有效的ipv4地址") + } + } } else { GATEWAY }; - println!("网关:{:?}", gateway); + println!("网关: {:?}", gateway); if gateway.is_unspecified() { println!("网关地址无效"); + log::error!("网关错误,必须为有效的ipv4地址 gateway={}", gateway); return; } if gateway.is_broadcast() { println!("网关错误,不能为广播地址"); + log::error!("网关错误,不能为广播地址 gateway={}", gateway); return; } if gateway.is_multicast() { println!("网关错误,不能为组播地址"); + log::error!("网关错误,不能为组播地址 gateway={}", gateway); return; } if !gateway.is_private() { - println!("Warning 不是一个私有地址:{:?},将有可能和公网ip冲突", gateway); + println!( + "Warning 不是一个私有地址:{:?},将有可能和公网ip冲突", + gateway + ); + log::warn!("网关错误,不是一个私有地址 gateway={}", gateway); } let netmask = if let Some(netmask) = args.netmask { - netmask.parse::().expect("子网掩码错误,必须为有效的ipv4地址") + match netmask.parse::() { + Ok(ip) => ip, + Err(e) => { + log::error!( + "子网掩码错误,必须为有效的ipv4地址 netmask={},e={}", + netmask, + e + ); + panic!("子网掩码错误,必须为有效的ipv4地址") + } + } } else { NETMASK }; - println!("子网掩码:{:?}", netmask); - if netmask.is_broadcast() || netmask.is_unspecified() || !(!u32::from_be_bytes(netmask.octets()) + 1).is_power_of_two() { + println!("子网掩码: {:?}", netmask); + if netmask.is_broadcast() + || netmask.is_unspecified() + || !(!u32::from_be_bytes(netmask.octets()) + 1).is_power_of_two() + { println!("子网掩码错误"); + log::error!("子网掩码错误 netmask={}", netmask); return; } - let broadcast = (!u32::from_be_bytes(netmask.octets())) - | u32::from_be_bytes(gateway.octets()); + let broadcast = (!u32::from_be_bytes(netmask.octets())) | u32::from_be_bytes(gateway.octets()); let broadcast = Ipv4Addr::from(broadcast); let config = ConfigInfo { port, @@ -133,33 +164,38 @@ async fn main() { broadcast, netmask, }; - log_init(); - log::info!("config:{:?}",config); + let rsa = match RsaCipher::new() { + Ok(rsa) => { + println!("密钥指纹: {}", rsa.finger().unwrap()); + Some(rsa) + } + Err(e) => { + log::error!("获取密钥错误:{:?}", e); + panic!("获取密钥错误:{}", e); + } + }; + log::info!("config:{:?}", config); let udp = match UdpSocket::bind(format!("0.0.0.0:{}", port)).await { - Ok(udp) => { Arc::new(udp) } + Ok(udp) => Arc::new(udp), Err(e) => { - log::warn!("udp启动失败:{:?}",e); - panic!("{:?}", e); + log::warn!("udp启动失败:{:?}", e); + panic!("udp启动失败:{}", e); } }; - log::info!("监听udp端口:{:?}",udp.local_addr().unwrap()); - println!("监听udp端口:{:?}", udp.local_addr().unwrap()); + log::info!("监听udp端口: {:?}", udp.local_addr().unwrap()); + println!("监听udp端口: {:?}", udp.local_addr().unwrap()); let tcp = match TcpListener::bind(format!("0.0.0.0:{}", port)).await { - Ok(tcp) => { tcp } + Ok(tcp) => tcp, Err(e) => { - log::warn!("tcp启动失败:{:?}",e); - panic!("{:?}", e); + log::warn!("tcp启动失败:{:?}", e); + panic!("tcp启动失败:{:?}", e); } }; - log::info!("监听tcp端口:{:?}",tcp.local_addr().unwrap()); - println!("监听tcp端口:{:?}", tcp.local_addr().unwrap()); + log::info!("监听tcp端口: {:?}", tcp.local_addr().unwrap()); + println!("监听tcp端口: {:?}", tcp.local_addr().unwrap()); let config = config.clone(); let main_udp = udp.clone(); let tcp_config = config.clone(); - tokio::spawn(async move { - if let Err(e) = start_tcp(tcp, main_udp, tcp_config).await { - log::warn!("tcp任务结束:{:?}",e); - } - }); - start_udp(udp, config).await; + tokio::spawn(start_tcp(tcp, main_udp, tcp_config, rsa.clone())); + start_udp(udp, config, rsa.clone()).await; } diff --git a/src/proto/message.rs b/src/proto/message.rs index 2827aaf..99be211 100644 --- a/src/proto/message.rs +++ b/src/proto/message.rs @@ -25,6 +25,462 @@ /// of protobuf runtime. const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_3_2_0; +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:HandshakeRequest) +pub struct HandshakeRequest { + // message fields + // @@protoc_insertion_point(field:HandshakeRequest.version) + pub version: ::std::string::String, + // @@protoc_insertion_point(field:HandshakeRequest.secret) + pub secret: bool, + // special fields + // @@protoc_insertion_point(special_field:HandshakeRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a HandshakeRequest { + fn default() -> &'a HandshakeRequest { + ::default_instance() + } +} + +impl HandshakeRequest { + pub fn new() -> HandshakeRequest { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "version", + |m: &HandshakeRequest| { &m.version }, + |m: &mut HandshakeRequest| { &mut m.version }, + )); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "secret", + |m: &HandshakeRequest| { &m.secret }, + |m: &mut HandshakeRequest| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "HandshakeRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for HandshakeRequest { + const NAME: &'static str = "HandshakeRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.version = is.read_string()?; + }, + 16 => { + self.secret = is.read_bool()?; + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if !self.version.is_empty() { + my_size += ::protobuf::rt::string_size(1, &self.version); + } + if self.secret != false { + my_size += 1 + 1; + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if !self.version.is_empty() { + os.write_string(1, &self.version)?; + } + if self.secret != false { + os.write_bool(2, self.secret)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> HandshakeRequest { + HandshakeRequest::new() + } + + fn clear(&mut self) { + self.version.clear(); + self.secret = false; + self.special_fields.clear(); + } + + fn default_instance() -> &'static HandshakeRequest { + static instance: HandshakeRequest = HandshakeRequest { + version: ::std::string::String::new(), + secret: false, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for HandshakeRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("HandshakeRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for HandshakeRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for HandshakeRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:HandshakeResponse) +pub struct HandshakeResponse { + // message fields + // @@protoc_insertion_point(field:HandshakeResponse.version) + pub version: ::std::string::String, + // @@protoc_insertion_point(field:HandshakeResponse.secret) + pub secret: bool, + // @@protoc_insertion_point(field:HandshakeResponse.public_key) + pub public_key: ::std::vec::Vec, + // @@protoc_insertion_point(field:HandshakeResponse.key_finger) + pub key_finger: ::std::string::String, + // special fields + // @@protoc_insertion_point(special_field:HandshakeResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a HandshakeResponse { + fn default() -> &'a HandshakeResponse { + ::default_instance() + } +} + +impl HandshakeResponse { + pub fn new() -> HandshakeResponse { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(4); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "version", + |m: &HandshakeResponse| { &m.version }, + |m: &mut HandshakeResponse| { &mut m.version }, + )); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "secret", + |m: &HandshakeResponse| { &m.secret }, + |m: &mut HandshakeResponse| { &mut m.secret }, + )); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "public_key", + |m: &HandshakeResponse| { &m.public_key }, + |m: &mut HandshakeResponse| { &mut m.public_key }, + )); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "key_finger", + |m: &HandshakeResponse| { &m.key_finger }, + |m: &mut HandshakeResponse| { &mut m.key_finger }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "HandshakeResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for HandshakeResponse { + const NAME: &'static str = "HandshakeResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.version = is.read_string()?; + }, + 16 => { + self.secret = is.read_bool()?; + }, + 26 => { + self.public_key = is.read_bytes()?; + }, + 34 => { + self.key_finger = is.read_string()?; + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if !self.version.is_empty() { + my_size += ::protobuf::rt::string_size(1, &self.version); + } + if self.secret != false { + my_size += 1 + 1; + } + if !self.public_key.is_empty() { + my_size += ::protobuf::rt::bytes_size(3, &self.public_key); + } + if !self.key_finger.is_empty() { + my_size += ::protobuf::rt::string_size(4, &self.key_finger); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if !self.version.is_empty() { + os.write_string(1, &self.version)?; + } + if self.secret != false { + os.write_bool(2, self.secret)?; + } + if !self.public_key.is_empty() { + os.write_bytes(3, &self.public_key)?; + } + if !self.key_finger.is_empty() { + os.write_string(4, &self.key_finger)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> HandshakeResponse { + HandshakeResponse::new() + } + + fn clear(&mut self) { + self.version.clear(); + self.secret = false; + self.public_key.clear(); + self.key_finger.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static HandshakeResponse { + static instance: HandshakeResponse = HandshakeResponse { + version: ::std::string::String::new(), + secret: false, + public_key: ::std::vec::Vec::new(), + key_finger: ::std::string::String::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for HandshakeResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("HandshakeResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for HandshakeResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for HandshakeResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +#[derive(PartialEq,Clone,Default,Debug)] +// @@protoc_insertion_point(message:SecretHandshakeRequest) +pub struct SecretHandshakeRequest { + // message fields + // @@protoc_insertion_point(field:SecretHandshakeRequest.token) + pub token: ::std::string::String, + // @@protoc_insertion_point(field:SecretHandshakeRequest.key) + pub key: ::std::vec::Vec, + // special fields + // @@protoc_insertion_point(special_field:SecretHandshakeRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a SecretHandshakeRequest { + fn default() -> &'a SecretHandshakeRequest { + ::default_instance() + } +} + +impl SecretHandshakeRequest { + pub fn new() -> SecretHandshakeRequest { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "token", + |m: &SecretHandshakeRequest| { &m.token }, + |m: &mut SecretHandshakeRequest| { &mut m.token }, + )); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "key", + |m: &SecretHandshakeRequest| { &m.key }, + |m: &mut SecretHandshakeRequest| { &mut m.key }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "SecretHandshakeRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for SecretHandshakeRequest { + const NAME: &'static str = "SecretHandshakeRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.token = is.read_string()?; + }, + 18 => { + self.key = is.read_bytes()?; + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if !self.token.is_empty() { + my_size += ::protobuf::rt::string_size(1, &self.token); + } + if !self.key.is_empty() { + my_size += ::protobuf::rt::bytes_size(2, &self.key); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if !self.token.is_empty() { + os.write_string(1, &self.token)?; + } + if !self.key.is_empty() { + os.write_bytes(2, &self.key)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> SecretHandshakeRequest { + SecretHandshakeRequest::new() + } + + fn clear(&mut self) { + self.token.clear(); + self.key.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static SecretHandshakeRequest { + static instance: SecretHandshakeRequest = SecretHandshakeRequest { + token: ::std::string::String::new(), + key: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for SecretHandshakeRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("SecretHandshakeRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for SecretHandshakeRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for SecretHandshakeRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + #[derive(PartialEq,Clone,Default,Debug)] // @@protoc_insertion_point(message:RegistrationRequest) pub struct RegistrationRequest { @@ -43,6 +499,8 @@ pub struct RegistrationRequest { pub virtual_ip: u32, // @@protoc_insertion_point(field:RegistrationRequest.allow_ip_change) pub allow_ip_change: bool, + // @@protoc_insertion_point(field:RegistrationRequest.client_secret) + pub client_secret: bool, // special fields // @@protoc_insertion_point(special_field:RegistrationRequest.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -60,7 +518,7 @@ impl RegistrationRequest { } fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { - let mut fields = ::std::vec::Vec::with_capacity(7); + let mut fields = ::std::vec::Vec::with_capacity(8); let mut oneofs = ::std::vec::Vec::with_capacity(0); fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( "token", @@ -97,6 +555,11 @@ impl RegistrationRequest { |m: &RegistrationRequest| { &m.allow_ip_change }, |m: &mut RegistrationRequest| { &mut m.allow_ip_change }, )); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "client_secret", + |m: &RegistrationRequest| { &m.client_secret }, + |m: &mut RegistrationRequest| { &mut m.client_secret }, + )); ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( "RegistrationRequest", fields, @@ -136,6 +599,9 @@ impl ::protobuf::Message for RegistrationRequest { 56 => { self.allow_ip_change = is.read_bool()?; }, + 64 => { + self.client_secret = is.read_bool()?; + }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; }, @@ -169,6 +635,9 @@ impl ::protobuf::Message for RegistrationRequest { if self.allow_ip_change != false { my_size += 1 + 1; } + if self.client_secret != false { + my_size += 1 + 1; + } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); my_size @@ -196,6 +665,9 @@ impl ::protobuf::Message for RegistrationRequest { if self.allow_ip_change != false { os.write_bool(7, self.allow_ip_change)?; } + if self.client_secret != false { + os.write_bool(8, self.client_secret)?; + } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) } @@ -220,6 +692,7 @@ impl ::protobuf::Message for RegistrationRequest { self.version.clear(); self.virtual_ip = 0; self.allow_ip_change = false; + self.client_secret = false; self.special_fields.clear(); } @@ -232,6 +705,7 @@ impl ::protobuf::Message for RegistrationRequest { version: ::std::string::String::new(), virtual_ip: 0, allow_ip_change: false, + client_secret: false, special_fields: ::protobuf::SpecialFields::new(), }; &instance @@ -514,6 +988,8 @@ pub struct DeviceInfo { pub virtual_ip: u32, // @@protoc_insertion_point(field:DeviceInfo.device_status) pub device_status: u32, + // @@protoc_insertion_point(field:DeviceInfo.client_secret) + pub client_secret: bool, // special fields // @@protoc_insertion_point(special_field:DeviceInfo.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -531,7 +1007,7 @@ impl DeviceInfo { } fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { - let mut fields = ::std::vec::Vec::with_capacity(3); + let mut fields = ::std::vec::Vec::with_capacity(4); let mut oneofs = ::std::vec::Vec::with_capacity(0); fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( "name", @@ -548,6 +1024,11 @@ impl DeviceInfo { |m: &DeviceInfo| { &m.device_status }, |m: &mut DeviceInfo| { &mut m.device_status }, )); + fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>( + "client_secret", + |m: &DeviceInfo| { &m.client_secret }, + |m: &mut DeviceInfo| { &mut m.client_secret }, + )); ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( "DeviceInfo", fields, @@ -575,6 +1056,9 @@ impl ::protobuf::Message for DeviceInfo { 24 => { self.device_status = is.read_uint32()?; }, + 32 => { + self.client_secret = is.read_bool()?; + }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; }, @@ -596,6 +1080,9 @@ impl ::protobuf::Message for DeviceInfo { if self.device_status != 0 { my_size += ::protobuf::rt::uint32_size(3, self.device_status); } + if self.client_secret != false { + my_size += 1 + 1; + } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); my_size @@ -611,6 +1098,9 @@ impl ::protobuf::Message for DeviceInfo { if self.device_status != 0 { os.write_uint32(3, self.device_status)?; } + if self.client_secret != false { + os.write_bool(4, self.client_secret)?; + } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) } @@ -631,6 +1121,7 @@ impl ::protobuf::Message for DeviceInfo { self.name.clear(); self.virtual_ip = 0; self.device_status = 0; + self.client_secret = false; self.special_fields.clear(); } @@ -639,6 +1130,7 @@ impl ::protobuf::Message for DeviceInfo { name: ::std::string::String::new(), virtual_ip: 0, device_status: 0, + client_secret: false, special_fields: ::protobuf::SpecialFields::new(), }; &instance @@ -1107,33 +1599,41 @@ impl PunchNatType { } static file_descriptor_proto_data: &'static [u8] = b"\ - \n\rmessage.proto\"\xd6\x01\n\x13RegistrationRequest\x12\x14\n\x05token\ - \x18\x01\x20\x01(\tR\x05token\x12\x1b\n\tdevice_id\x18\x02\x20\x01(\tR\ - \x08deviceId\x12\x12\n\x04name\x18\x03\x20\x01(\tR\x04name\x12\x17\n\x07\ - is_fast\x18\x04\x20\x01(\x08R\x06isFast\x12\x18\n\x07version\x18\x05\x20\ - \x01(\tR\x07version\x12\x1d\n\nvirtual_ip\x18\x06\x20\x01(\x07R\tvirtual\ - Ip\x12&\n\x0fallow_ip_change\x18\x07\x20\x01(\x08R\rallowIpChange\"\xb3\ - \x02\n\x14RegistrationResponse\x12\x1d\n\nvirtual_ip\x18\x01\x20\x01(\ - \x07R\tvirtualIp\x12'\n\x0fvirtual_gateway\x18\x02\x20\x01(\x07R\x0evirt\ - ualGateway\x12'\n\x0fvirtual_netmask\x18\x03\x20\x01(\x07R\x0evirtualNet\ - mask\x12\x14\n\x05epoch\x18\x04\x20\x01(\rR\x05epoch\x125\n\x10device_in\ - fo_list\x18\x05\x20\x03(\x0b2\x0b.DeviceInfoR\x0edeviceInfoList\x12\x1b\ - \n\tpublic_ip\x18\x06\x20\x01(\x07R\x08publicIp\x12\x1f\n\x0bpublic_port\ - \x18\x07\x20\x01(\rR\npublicPort\x12\x1f\n\x0bpublic_ipv6\x18\x08\x20\ - \x01(\x0cR\npublicIpv6\"d\n\nDeviceInfo\x12\x12\n\x04name\x18\x01\x20\ - \x01(\tR\x04name\x12\x1d\n\nvirtual_ip\x18\x02\x20\x01(\x07R\tvirtualIp\ - \x12#\n\rdevice_status\x18\x03\x20\x01(\rR\x0cdeviceStatus\"Y\n\nDeviceL\ - ist\x12\x14\n\x05epoch\x18\x01\x20\x01(\rR\x05epoch\x125\n\x10device_inf\ - o_list\x18\x02\x20\x03(\x0b2\x0b.DeviceInfoR\x0edeviceInfoList\"\xa2\x02\ - \n\tPunchInfo\x12$\n\x0epublic_ip_list\x18\x02\x20\x03(\x07R\x0cpublicIp\ - List\x12\x1f\n\x0bpublic_port\x18\x03\x20\x01(\rR\npublicPort\x12*\n\x11\ - public_port_range\x18\x04\x20\x01(\rR\x0fpublicPortRange\x12(\n\x08nat_t\ - ype\x18\x05\x20\x01(\x0e2\r.PunchNatTypeR\x07natType\x12\x14\n\x05reply\ - \x18\x06\x20\x01(\x08R\x05reply\x12\x19\n\x08local_ip\x18\x07\x20\x01(\ - \x07R\x07localIp\x12\x1d\n\nlocal_port\x18\x08\x20\x01(\rR\tlocalPort\ - \x12(\n\x10public_ipv6_list\x18\t\x20\x03(\x0cR\x0epublicIpv6List*'\n\ - \x0cPunchNatType\x12\r\n\tSymmetric\x10\0\x12\x08\n\x04Cone\x10\x01b\x06\ - proto3\ + \n\rmessage.proto\"D\n\x10HandshakeRequest\x12\x18\n\x07version\x18\x01\ + \x20\x01(\tR\x07version\x12\x16\n\x06secret\x18\x02\x20\x01(\x08R\x06sec\ + ret\"\x83\x01\n\x11HandshakeResponse\x12\x18\n\x07version\x18\x01\x20\ + \x01(\tR\x07version\x12\x16\n\x06secret\x18\x02\x20\x01(\x08R\x06secret\ + \x12\x1d\n\npublic_key\x18\x03\x20\x01(\x0cR\tpublicKey\x12\x1d\n\nkey_f\ + inger\x18\x04\x20\x01(\tR\tkeyFinger\"@\n\x16SecretHandshakeRequest\x12\ + \x14\n\x05token\x18\x01\x20\x01(\tR\x05token\x12\x10\n\x03key\x18\x02\ + \x20\x01(\x0cR\x03key\"\xfb\x01\n\x13RegistrationRequest\x12\x14\n\x05to\ + ken\x18\x01\x20\x01(\tR\x05token\x12\x1b\n\tdevice_id\x18\x02\x20\x01(\t\ + R\x08deviceId\x12\x12\n\x04name\x18\x03\x20\x01(\tR\x04name\x12\x17\n\ + \x07is_fast\x18\x04\x20\x01(\x08R\x06isFast\x12\x18\n\x07version\x18\x05\ + \x20\x01(\tR\x07version\x12\x1d\n\nvirtual_ip\x18\x06\x20\x01(\x07R\tvir\ + tualIp\x12&\n\x0fallow_ip_change\x18\x07\x20\x01(\x08R\rallowIpChange\ + \x12#\n\rclient_secret\x18\x08\x20\x01(\x08R\x0cclientSecret\"\xb3\x02\n\ + \x14RegistrationResponse\x12\x1d\n\nvirtual_ip\x18\x01\x20\x01(\x07R\tvi\ + rtualIp\x12'\n\x0fvirtual_gateway\x18\x02\x20\x01(\x07R\x0evirtualGatewa\ + y\x12'\n\x0fvirtual_netmask\x18\x03\x20\x01(\x07R\x0evirtualNetmask\x12\ + \x14\n\x05epoch\x18\x04\x20\x01(\rR\x05epoch\x125\n\x10device_info_list\ + \x18\x05\x20\x03(\x0b2\x0b.DeviceInfoR\x0edeviceInfoList\x12\x1b\n\tpubl\ + ic_ip\x18\x06\x20\x01(\x07R\x08publicIp\x12\x1f\n\x0bpublic_port\x18\x07\ + \x20\x01(\rR\npublicPort\x12\x1f\n\x0bpublic_ipv6\x18\x08\x20\x01(\x0cR\ + \npublicIpv6\"\x89\x01\n\nDeviceInfo\x12\x12\n\x04name\x18\x01\x20\x01(\ + \tR\x04name\x12\x1d\n\nvirtual_ip\x18\x02\x20\x01(\x07R\tvirtualIp\x12#\ + \n\rdevice_status\x18\x03\x20\x01(\rR\x0cdeviceStatus\x12#\n\rclient_sec\ + ret\x18\x04\x20\x01(\x08R\x0cclientSecret\"Y\n\nDeviceList\x12\x14\n\x05\ + epoch\x18\x01\x20\x01(\rR\x05epoch\x125\n\x10device_info_list\x18\x02\ + \x20\x03(\x0b2\x0b.DeviceInfoR\x0edeviceInfoList\"\xa2\x02\n\tPunchInfo\ + \x12$\n\x0epublic_ip_list\x18\x02\x20\x03(\x07R\x0cpublicIpList\x12\x1f\ + \n\x0bpublic_port\x18\x03\x20\x01(\rR\npublicPort\x12*\n\x11public_port_\ + range\x18\x04\x20\x01(\rR\x0fpublicPortRange\x12(\n\x08nat_type\x18\x05\ + \x20\x01(\x0e2\r.PunchNatTypeR\x07natType\x12\x14\n\x05reply\x18\x06\x20\ + \x01(\x08R\x05reply\x12\x19\n\x08local_ip\x18\x07\x20\x01(\x07R\x07local\ + Ip\x12\x1d\n\nlocal_port\x18\x08\x20\x01(\rR\tlocalPort\x12(\n\x10public\ + _ipv6_list\x18\t\x20\x03(\x0cR\x0epublicIpv6List*'\n\x0cPunchNatType\x12\ + \r\n\tSymmetric\x10\0\x12\x08\n\x04Cone\x10\x01b\x06proto3\ "; /// `FileDescriptorProto` object which was a source for this generated file @@ -1151,7 +1651,10 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor { file_descriptor.get(|| { let generated_file_descriptor = generated_file_descriptor_lazy.get(|| { let mut deps = ::std::vec::Vec::with_capacity(0); - let mut messages = ::std::vec::Vec::with_capacity(5); + let mut messages = ::std::vec::Vec::with_capacity(8); + messages.push(HandshakeRequest::generated_message_descriptor_data()); + messages.push(HandshakeResponse::generated_message_descriptor_data()); + messages.push(SecretHandshakeRequest::generated_message_descriptor_data()); messages.push(RegistrationRequest::generated_message_descriptor_data()); messages.push(RegistrationResponse::generated_message_descriptor_data()); messages.push(DeviceInfo::generated_message_descriptor_data()); diff --git a/src/protocol/body.rs b/src/protocol/body.rs new file mode 100644 index 0000000..215bdb7 --- /dev/null +++ b/src/protocol/body.rs @@ -0,0 +1,205 @@ +use std::{fmt, io}; + +pub const ENCRYPTION_RESERVED: usize = 32; +/* aes_gcm加密数据体 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 数据体 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | random(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | tag(32) | + | tag(32) | + | tag(32) | + | tag(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | finger(32) | + | finger(32) | + | finger(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + 注:finger用于快速校验数据是否被修改,上层可使用token、协议头参与计算finger, + 确保服务端和客户端都能感知修改(服务端不能解密也能校验指纹) + */ +pub struct SecretBody { + buffer: B, +} + +impl> SecretBody { + pub fn new(buffer: B) -> io::Result> { + let len = buffer.as_ref().len(); + // 不能大于udp最大载荷长度 + if len < 32 || len > 65535 - 20 - 8 - 12 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length overflow", + )); + } + Ok(SecretBody { buffer }) + } + pub fn data(&self) -> &[u8] { + let end = self.buffer.as_ref().len() - 32; + &self.buffer.as_ref()[..end] + } + pub fn random(&self) -> u32 { + let end = self.buffer.as_ref().len() - 16 - 12; + u32::from_be_bytes(self.buffer.as_ref()[end - 4..end].try_into().unwrap()) + } + pub fn body(&self) -> &[u8] { + let end = self.buffer.as_ref().len() - 16 - 12; + &self.buffer.as_ref()[..end] + } + pub fn tag(&self) -> &[u8] { + let end = self.buffer.as_ref().len() - 12; + &self.buffer.as_ref()[end - 16..end] + } + pub fn finger(&self) -> &[u8] { + let end = self.buffer.as_ref().len(); + &self.buffer.as_ref()[end - 12..end] + } + pub fn buffer(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +impl + AsMut<[u8]>> SecretBody { + pub fn set_data(&mut self, data: &[u8]) -> io::Result<()> { + let end = self.buffer.as_ref().len() - 32; + if end - 4 != data.len() { + return Err(io::Error::new(io::ErrorKind::InvalidData, "end-4 != data.len")); + } + self.buffer.as_mut()[..end].copy_from_slice(data); + Ok(()) + } + pub fn set_random(&mut self, random: u32) { + let end = self.buffer.as_ref().len() - 16 - 12; + self.buffer.as_mut()[end - 4..end].copy_from_slice(&random.to_be_bytes()); + } + + pub fn set_tag(&mut self, tag: &[u8]) -> io::Result<()> { + if tag.len() != 16 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "tag.len != 16")); + } + let end = self.buffer.as_ref().len() - 12; + self.buffer.as_mut()[end - 16..end].copy_from_slice(tag); + Ok(()) + } + pub fn set_finger(&mut self, finger: &[u8]) -> io::Result<()> { + if finger.len() != 12 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "finger.len != 12")); + } + let end = self.buffer.as_ref().len(); + self.buffer.as_mut()[end - 12..end].copy_from_slice(finger); + Ok(()) + } + + pub fn data_mut(&mut self) -> &mut [u8] { + let end = self.buffer.as_ref().len() - 32; + &mut self.buffer.as_mut()[..end] + } + /// 数据部分 + pub fn body_mut(&mut self) -> &mut [u8] { + let end = self.buffer.as_ref().len() - 12 - 16; + &mut self.buffer.as_mut()[..end] + } + pub fn tag_mut(&mut self) -> &mut [u8] { + let end = self.buffer.as_ref().len() - 12; + &mut self.buffer.as_mut()[end - 16..end] + } + /// 数据部分+tag部分 + pub fn en_body_mut(&mut self) -> &mut [u8] { + let end = self.buffer.as_ref().len() - 12; + &mut self.buffer.as_mut()[..end] + } + pub fn buffer_mut(&mut self) -> &mut [u8] { + self.buffer.as_mut() + } +} + +impl> fmt::Debug for SecretBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SecretBody") + .field("random", &self.random()) + .field("body", &self.body()) + .field("tag", &self.tag()) + .finish() + } +} + +/* rsa加密数据体 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 数据体(n) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | random(32) | + | random(32) | + | random(32) | + | random(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | finger(32) | + | finger(32) | + | finger(32) | + | finger(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +pub struct RsaSecretBody { + buffer: B, +} + +impl> RsaSecretBody { + pub fn new(buffer: B) -> io::Result> { + let len = buffer.as_ref().len(); + // 不能大于udp最大载荷长度 + if len < 32 || len > 65535 - 20 - 8 - 12 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length overflow", + )); + } + Ok(RsaSecretBody { buffer }) + } + pub fn data(&self) -> &[u8] { + let end = self.buffer.as_ref().len() - 32; + &self.buffer.as_ref()[..end] + } + pub fn random(&self) -> &[u8] { + let end = self.buffer.as_ref().len() - 16; + &self.buffer.as_ref()[end - 16..end] + } + pub fn body(&self) -> &[u8] { + let end = self.buffer.as_ref().len() - 16; + &self.buffer.as_ref()[..end] + } + pub fn finger(&self) -> &[u8] { + let end = self.buffer.as_ref().len() - 16; + &self.buffer.as_ref()[end..] + } + pub fn buffer(&self) -> &[u8] { + &self.buffer.as_ref() + } +} + +impl + AsMut<[u8]>> RsaSecretBody { + pub fn set_random(&mut self, random: &[u8]) -> io::Result<()> { + if random.len() != 16 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "random.len != 16")); + } + let end = self.buffer.as_ref().len() - 16; + self.buffer.as_mut()[end - 16..end].copy_from_slice(random); + Ok(()) + } + pub fn random_mut(&mut self) -> &mut [u8] { + let end = self.buffer.as_ref().len() - 16; + &mut self.buffer.as_mut()[end - 16..end] + } + pub fn set_finger(&mut self, finger: &[u8]) -> io::Result<()> { + if finger.len() != 16 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "finger.len != 16")); + } + let end = self.buffer.as_ref().len(); + self.buffer.as_mut()[end - 16..end].copy_from_slice(finger); + Ok(()) + } +} \ No newline at end of file diff --git a/src/protocol/error_packet.rs b/src/protocol/error_packet.rs index 45be4bc..64971a0 100644 --- a/src/protocol/error_packet.rs +++ b/src/protocol/error_packet.rs @@ -7,6 +7,7 @@ pub enum Protocol { AddressExhausted, IpAlreadyExists, InvalidIp, + NoKey, Other(u8), } @@ -18,6 +19,7 @@ impl From for Protocol { 3 => Self::AddressExhausted, 4 => Self::IpAlreadyExists, 5 => Self::InvalidIp, + 6 => Self::NoKey, val => Self::Other(val), } } @@ -31,6 +33,7 @@ impl Into for Protocol { Protocol::AddressExhausted => 3, Protocol::IpAlreadyExists => 4, Protocol::InvalidIp => 5, + Protocol::NoKey => 6, Protocol::Other(val) => val, } } @@ -42,6 +45,7 @@ pub enum InErrorPacket { AddressExhausted, IpAlreadyExists, InvalidIp, + NoKey, OtherError(ErrorPacket), } @@ -53,6 +57,7 @@ impl> InErrorPacket { Protocol::AddressExhausted => Ok(InErrorPacket::AddressExhausted), Protocol::IpAlreadyExists => Ok(InErrorPacket::IpAlreadyExists), Protocol::InvalidIp => Ok(InErrorPacket::InvalidIp), + Protocol::NoKey => Ok(InErrorPacket::NoKey), Protocol::Other(_) => Ok(InErrorPacket::OtherError(ErrorPacket::new(buffer)?)), } } diff --git a/src/protocol/ip_turn_packet.rs b/src/protocol/ip_turn_packet.rs index c2c5a86..667a44f 100644 --- a/src/protocol/ip_turn_packet.rs +++ b/src/protocol/ip_turn_packet.rs @@ -3,8 +3,6 @@ use std::net::Ipv4Addr; #[derive(Copy, Clone, Eq, PartialEq, Debug)] pub enum Protocol { - Icmp, - Igmp, Ipv4, Ipv4Broadcast, Unknown(u8), @@ -13,8 +11,6 @@ pub enum Protocol { impl From for Protocol { fn from(value: u8) -> Self { match value { - 1 => Protocol::Icmp, - 2 => Protocol::Igmp, 4 => Protocol::Ipv4, 201 => Protocol::Ipv4Broadcast, val => Protocol::Unknown(val), @@ -25,8 +21,6 @@ impl From for Protocol { impl Into for Protocol { fn into(self) -> u8 { match self { - Protocol::Icmp => 1, - Protocol::Igmp => 2, Protocol::Ipv4 => 4, Protocol::Ipv4Broadcast => 201, Protocol::Unknown(val) => val, @@ -34,18 +28,18 @@ impl Into for Protocol { } } -pub struct BroadcastPacketEnd { +pub struct BroadcastPacket { buffer: B, } -impl> BroadcastPacketEnd { +impl> BroadcastPacket { pub fn unchecked(buffer: B) -> Self { Self { buffer } } pub fn new(buffer: B) -> io::Result { let len = buffer.as_ref().len(); let packet = Self::unchecked(buffer); - if len < 1 || packet.len() != len { + if len < 2 + 4 || packet.addr_num() == 0 { Err(io::Error::new( io::ErrorKind::InvalidData, "InvalidData", @@ -56,31 +50,36 @@ impl> BroadcastPacketEnd { } } -impl> BroadcastPacketEnd { - pub fn len(&self) -> usize { - 1 + self.num() as usize * 4 - } - pub fn num(&self) -> u8 { - let len = self.buffer.as_ref().len(); - self.buffer.as_ref()[len - 1] +impl> BroadcastPacket { + pub fn addr_num(&self) -> u8 { + self.buffer.as_ref()[1] } /// 已经发送给了这些地址 - /// 从尾往头拿 pub fn addresses(&self) -> Vec { - let num = self.num() as usize; + let num = self.addr_num() as usize; let mut list = Vec::with_capacity(num); let buf = self.buffer.as_ref(); - let mut offset = buf.len() + 4 - 2; + let mut offset = 1; for _ in 0..num { - offset -= 4; - list.push(Ipv4Addr::new(buf[offset - 3], buf[offset - 2], buf[offset - 1], buf[offset])); + list.push(Ipv4Addr::new(buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3])); + offset += 4; } list } + pub fn data(&self) -> io::Result<&[u8]> { + let start = 1 + self.addr_num() as usize * 4; + if start > self.buffer.as_ref().len() { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "InvalidData", + )) + } else { + Ok(&self.buffer.as_ref()[start..]) + } + } } -impl + AsMut<[u8]>> BroadcastPacketEnd { - /// 从头往尾放 +impl + AsMut<[u8]>> BroadcastPacket { pub fn set_address(&mut self, addr: &[Ipv4Addr]) -> io::Result<()> { let buf = self.buffer.as_mut(); if buf.len() < 1 + addr.len() * 4 || addr.len() > u8::MAX as usize { @@ -89,15 +88,28 @@ impl + AsMut<[u8]>> BroadcastPacketEnd { "InvalidData", )) } else { - let mut offset = 0; + buf[0] = addr.len() as u8; + let mut offset = 1; for ip in addr { buf[offset..offset + 4].copy_from_slice(&ip.octets()); offset += 4; } - self.buffer.as_mut()[offset] = addr.len() as u8; Ok(()) } } + pub fn set_data(&mut self, data: &[u8]) -> io::Result<()> { + let num = self.addr_num() as usize; + let start = 1 + 4 * num; + let buf = self.buffer.as_mut(); + if start > buf.len() || start + data.len() != buf.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "InvalidData", + )); + } + buf[start..].copy_from_slice(data); + Ok(()) + } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 5ee818b..ed6bae5 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,11 +1,12 @@ -use std::net::Ipv4Addr; use std::{fmt, io}; +use std::net::Ipv4Addr; +use crate::protocol::body::ENCRYPTION_RESERVED; /* 0 15 31 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 版本(8) | 协议(8) | 上层协议(8) | 初始ttl(4) | 生存时间(4) | + |e|s|unused| 版本(4) | 协议(8) | 上层协议(8) | 初始ttl(4) | 生存时间(4) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | 源ip地址(32) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -13,8 +14,11 @@ use std::{fmt, io}; +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | 数据体 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 注:e为是否加密标志,s为服务端通信包标志 */ + +pub mod body; pub mod control_packet; pub mod error_packet; pub mod service_packet; @@ -91,32 +95,70 @@ pub const MAX_SOURCE: u8 = 0b11110000; #[derive(Copy, Clone)] pub struct NetPacket { + data_len: usize, buffer: B, } impl> NetPacket { pub fn new(buffer: B) -> io::Result> { - let len = buffer.as_ref().len(); - // 不能大于udp最大载荷长度 - if len < 12 || len > 65535 - 20 - 8 { + let data_len = buffer.as_ref().len(); + Self::new0(data_len, buffer) + } + pub fn new_encrypt(buffer: B) -> io::Result> { + if 12 + ENCRYPTION_RESERVED > buffer.as_ref().len() { return Err(io::Error::new( io::ErrorKind::InvalidData, "length overflow", )); } - Ok(NetPacket { buffer }) + //加密需要预留32字节 + let data_len = buffer.as_ref().len() - ENCRYPTION_RESERVED; + Self::new0(data_len, buffer) + } + pub fn new0(data_len: usize, buffer: B) -> io::Result> { + if data_len > buffer.as_ref().len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length overflow", + )); + } + // 不能大于udp最大载荷长度 + if data_len < 12 || buffer.as_ref().len() > 65535 - 20 - 8 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length overflow", + )); + } + Ok(NetPacket { data_len, buffer }) } pub fn buffer(&self) -> &[u8] { + &self.buffer.as_ref()[..self.data_len] + } + pub fn raw_buffer(&self) -> &[u8] { self.buffer.as_ref() } + pub fn data_len(&self) -> usize { + self.data_len + } + pub fn reserve(&self) -> usize { + self.buffer.as_ref().len() - self.data_len + } pub fn into_buffer(self) -> B { self.buffer } } impl> NetPacket { + /// 数据加密 + pub fn is_encrypt(&self) -> bool { + self.buffer.as_ref()[0] & 0x80 == 0x80 + } + /// 网关通信的标识 + pub fn is_gateway(&self) -> bool { + self.buffer.as_ref()[0] & 0x50 == 0x50 + } pub fn version(&self) -> Version { - Version::from(self.buffer.as_ref()[0]) + Version::from(self.buffer.as_ref()[0] & 0x0F) } pub fn protocol(&self) -> Protocol { Protocol::from(self.buffer.as_ref()[1]) @@ -139,16 +181,31 @@ impl> NetPacket { Ipv4Addr::from(tmp) } pub fn payload(&self) -> &[u8] { - &self.buffer.as_ref()[12..] + &self.buffer.as_ref()[12..self.data_len] } } impl + AsMut<[u8]>> NetPacket { - pub fn buffer_mut(&mut self)->&mut [u8]{ + pub fn buffer_mut(&mut self) -> &mut [u8] { self.buffer.as_mut() } + pub fn set_encrypt_flag(&mut self, is_encrypt: bool) { + if is_encrypt { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x80 + } else { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0x7F + }; + } + pub fn set_gateway_flag(&mut self, is_gateway: bool) { + if is_gateway { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x50 + } else { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0xBF + }; + } pub fn set_version(&mut self, version: Version) { - self.buffer.as_mut()[0] = version.into(); + let v: u8 = version.into(); + self.buffer.as_mut()[0] = (self.buffer.as_ref()[0] & 0xF0) | (0x0F & v); } pub fn set_protocol(&mut self, protocol: Protocol) { self.buffer.as_mut()[1] = protocol.into(); @@ -171,11 +228,22 @@ impl + AsMut<[u8]>> NetPacket { pub fn set_destination(&mut self, destination: Ipv4Addr) { self.buffer.as_mut()[8..12].copy_from_slice(&destination.octets()); } - pub fn set_payload(&mut self, payload: &[u8]) { - self.buffer.as_mut()[12..payload.len() + 12].copy_from_slice(payload); + pub fn set_payload(&mut self, payload: &[u8]) -> io::Result<()> { + if self.data_len - 12 != payload.len() { + return Err(io::Error::new(io::ErrorKind::InvalidData, "data_len - 12 != payload.len")); + } + self.buffer.as_mut()[12..self.data_len].copy_from_slice(payload); + Ok(()) } pub fn payload_mut(&mut self) -> &mut [u8] { - &mut self.buffer.as_mut()[12..] + &mut self.buffer.as_mut()[12..self.data_len] + } + pub fn set_data_len(&mut self, data_len: usize) -> io::Result<()> { + if data_len > self.buffer.as_ref().len() || data_len < 12 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "data_len invalid")); + } + self.data_len = data_len; + Ok(()) } } diff --git a/src/protocol/service_packet.rs b/src/protocol/service_packet.rs index d17a149..0b34699 100644 --- a/src/protocol/service_packet.rs +++ b/src/protocol/service_packet.rs @@ -8,6 +8,11 @@ pub enum Protocol { PollDeviceList, /// 推送设备列表 PushDeviceList, + /// 和服务端握手 + HandshakeRequest, + HandshakeResponse, + SecretHandshakeRequest, + SecretHandshakeResponse, Unknown(u8), } @@ -18,6 +23,10 @@ impl From for Protocol { 2 => Self::RegistrationResponse, 3 => Self::PollDeviceList, 4 => Self::PushDeviceList, + 5 => Self::HandshakeRequest, + 6 => Self::HandshakeResponse, + 7 => Self::SecretHandshakeRequest, + 8 => Self::SecretHandshakeResponse, val => Self::Unknown(val), } } @@ -30,6 +39,10 @@ impl Into for Protocol { Self::RegistrationResponse => 2, Self::PollDeviceList => 3, Self::PushDeviceList => 4, + Self::HandshakeRequest => 5, + Self::HandshakeResponse => 6, + Self::SecretHandshakeRequest => 7, + Self::SecretHandshakeResponse => 8, Self::Unknown(val) => val, } } diff --git a/src/service/igmp_server.rs b/src/service/igmp_server.rs index c4e54f5..db79580 100644 --- a/src/service/igmp_server.rs +++ b/src/service/igmp_server.rs @@ -1,13 +1,13 @@ -use std::collections::{HashMap, HashSet}; -use std::net::Ipv4Addr; -use std::sync::Arc; +use crate::error::*; +use moka::sync::Cache; use packet::igmp::igmp_v2::IgmpV2Packet; use packet::igmp::igmp_v3::{IgmpV3RecordType, IgmpV3ReportPacket}; use packet::igmp::IgmpType; use parking_lot::RwLock; -use moka::sync::Cache; +use std::collections::{HashMap, HashSet}; +use std::net::Ipv4Addr; +use std::sync::Arc; use std::time::Duration; -use crate::error::*; lazy_static::lazy_static! { //组播缓存 30分钟 (token,group_address) -> members @@ -16,9 +16,9 @@ lazy_static::lazy_static! { // (token,group_address,member_ip) static ref MULTICAST_MEMBER:Cache<(String,Ipv4Addr,Ipv4Addr), ()> = Cache::builder() .time_to_idle(Duration::from_secs(20*60)).eviction_listener(|k:Arc<(String,Ipv4Addr,Ipv4Addr)>,_,cause|{ - if cause==moka::notification::RemovalCause::Replaced{ - return; - } + if cause==moka::notification::RemovalCause::Replaced{ + return; + } log::info!("MULTICAST_MEMBER eviction {:?}", k); if let Some(v) = MULTICAST.get(&(k.0.clone(),k.1)){ let mut lock = v.write(); @@ -116,7 +116,8 @@ pub fn handle(buf: &[u8], token: &String, source: Ipv4Addr) -> Result<()> { guard.members.insert(source); guard.map.insert(source, (true, HashSet::from_iter(src))); drop(guard); - MULTICAST_MEMBER.insert((token.clone(), multicast_addr, source), ()); + MULTICAST_MEMBER + .insert((token.clone(), multicast_addr, source), ()); } } } @@ -140,20 +141,18 @@ pub fn handle(buf: &[u8], token: &String, source: Ipv4Addr) -> Result<()> { //在已有源的基础上,接收目标源,如果是排除模式,则删除;是包含模式则添加 match group_record.source_addresses() { None => {} - Some(src) => { - match guard.map.get_mut(&source) { - None => {} - Some((is_include, set)) => { - for ip in src { - if *is_include { - set.insert(ip); - } else { - set.remove(&ip); - } + Some(src) => match guard.map.get_mut(&source) { + None => {} + Some((is_include, set)) => { + for ip in src { + if *is_include { + set.insert(ip); + } else { + set.remove(&ip); } } } - } + }, } drop(guard); MULTICAST_MEMBER.insert((token.clone(), multicast_addr, source), ()); @@ -162,20 +161,18 @@ pub fn handle(buf: &[u8], token: &String, source: Ipv4Addr) -> Result<()> { //在已有源的基础上,不接收目标源 match group_record.source_addresses() { None => {} - Some(src) => { - match guard.map.get_mut(&source) { - None => {} - Some((is_include, set)) => { - for ip in src { - if *is_include { - set.remove(&ip); - } else { - set.insert(ip); - } + Some(src) => match guard.map.get_mut(&source) { + None => {} + Some((is_include, set)) => { + for ip in src { + if *is_include { + set.remove(&ip); + } else { + set.insert(ip); } } } - } + }, } drop(guard); MULTICAST_MEMBER.insert((token.clone(), multicast_addr, source), ()); @@ -188,4 +185,4 @@ pub fn handle(buf: &[u8], token: &String, source: Ipv4Addr) -> Result<()> { IgmpType::Unknown(_) => {} } Ok(()) -} \ No newline at end of file +} diff --git a/src/service/main_service/common.rs b/src/service/main_service/common.rs index 7261e39..c79ac00 100644 --- a/src/service/main_service/common.rs +++ b/src/service/main_service/common.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; + use chrono::Local; use packet::icmp::{icmp, Kind}; use packet::ip::ipv4; @@ -9,21 +10,63 @@ use parking_lot::RwLock; use protobuf::Message; use tokio::net::UdpSocket; use tokio::sync::mpsc::Sender; -use crate::ConfigInfo; + +use crate::cipher::Aes256GcmCipher; +use crate::cipher::{Finger, RsaCipher}; use crate::error::Error; use crate::proto::message; use crate::proto::message::{DeviceList, RegistrationRequest, RegistrationResponse}; -use crate::protocol::ip_turn_packet::BroadcastPacketEnd; -use crate::protocol::{control_packet, error_packet, ip_turn_packet, MAX_TTL, NetPacket, Protocol, service_packet, Version}; +use crate::protocol::ip_turn_packet::BroadcastPacket; +use crate::protocol::{ + body::ENCRYPTION_RESERVED, control_packet, error_packet, ip_turn_packet, service_packet, + NetPacket, Protocol, Version, MAX_TTL, +}; use crate::service::igmp_server::Multicast; -use crate::service::main_service::{Context, DEVICE_ADDRESS, DEVICE_ID_SESSION, DeviceInfo, PeerDeviceStatus, PeerLink, UDP_SESSION, VIRTUAL_NETWORK, VirtualNetwork}; +use crate::service::main_service::{ + Context, DeviceInfo, PeerDeviceStatus, PeerLink, VirtualNetwork, DEVICE_ADDRESS, + DEVICE_ID_SESSION, TCP_AES, UDP_AES, UDP_SESSION, VIRTUAL_NETWORK, +}; +use crate::ConfigInfo; -pub fn register(data: &[u8], config: &ConfigInfo, addr: SocketAddr, link: PeerLink) -> crate::error::Result<(Context, RegistrationResponse)> { +fn check_reg(request: &RegistrationRequest) -> crate::error::Result<()> { + if request.token.len() == 0 || request.token.len() > 64 { + return Err(Error::InvalidPacket); + } + if request.device_id.len() == 0 || request.device_id.len() > 64 { + return Err(Error::InvalidPacket); + } + if request.name.len() == 0 || request.name.len() > 64 { + return Err(Error::InvalidPacket); + } + Ok(()) +} + +fn register0( + context: &mut Option, + data: &[u8], + config: &ConfigInfo, + addr: SocketAddr, + link: PeerLink, +) -> crate::error::Result { let request = RegistrationRequest::parse_from_bytes(data)?; - log::info!("register:{:?}",request); + check_reg(&request)?; + log::info!( + "register,id={:?},name={:?},version={:?},virtual_ip={},client_secret={},allow_ip_change={},is_fast={}", + request.device_id, + request.name, + request.version, + request.virtual_ip, + request.client_secret, + request.allow_ip_change, + request.is_fast + ); if let Some(white_token) = &config.white_token { if !white_token.contains(&request.token) { - log::info!("token不在白名单,white_token={:?},token={:?}",white_token,request.token); + log::info!( + "token不在白名单,white_token={:?},token={:?}", + white_token, + request.token + ); return Err(Error::TokenError); } } @@ -39,28 +82,32 @@ pub fn register(data: &[u8], config: &ConfigInfo, addr: SocketAddr, link: PeerLi } response.virtual_netmask = config.netmask.into(); response.virtual_gateway = config.gateway.into(); - let v = VIRTUAL_NETWORK.optionally_get_with(request.token.clone(), || { - Some(Arc::new(parking_lot::const_rwlock(VirtualNetwork { - epoch: 0, - virtual_ip_map: HashMap::new(), - }))) - }).unwrap(); + let v = VIRTUAL_NETWORK + .optionally_get_with(request.token.clone(), || { + Some(Arc::new(parking_lot::const_rwlock(VirtualNetwork { + epoch: 0, + virtual_ip_map: HashMap::new(), + }))) + }) + .unwrap(); let mut lock = v.write(); lock.epoch += 1; response.epoch = lock.epoch; - let (id, mut virtual_ip) = if let Some(mut device_info) = lock.virtual_ip_map.get_mut(&request.device_id) { - device_info.status = PeerDeviceStatus::Online; - device_info.name = request.name.clone(); - device_info.id = Local::now().timestamp_millis(); - if request.virtual_ip != 0 && device_info.ip != request.virtual_ip { - (Local::now().timestamp_millis(), 0) + let (id, mut virtual_ip) = + if let Some(device_info) = lock.virtual_ip_map.get_mut(&request.device_id) { + device_info.status = PeerDeviceStatus::Online; + device_info.name = request.name.clone(); + device_info.client_secret = request.client_secret; + device_info.id = Local::now().timestamp_millis(); + if request.virtual_ip != 0 && device_info.ip != request.virtual_ip { + (Local::now().timestamp_millis(), 0) + } else { + (Local::now().timestamp_millis(), device_info.ip) + } } else { - (device_info.id, device_info.ip) - } - } else { - (Local::now().timestamp_millis(), 0) - }; + (Local::now().timestamp_millis(), 0) + }; if virtual_ip == 0 { //获取一个未使用的ip let set: HashSet = lock @@ -69,10 +116,14 @@ pub fn register(data: &[u8], config: &ConfigInfo, addr: SocketAddr, link: PeerLi .filter(|(key, _)| *key != &request.device_id) .map(|(_, device_info)| device_info.ip) .collect(); - let ip_range = (response.virtual_gateway & response.virtual_netmask) + 1..response.virtual_gateway | (!response.virtual_netmask); + let ip_range = (response.virtual_gateway & response.virtual_netmask) + 1 + ..response.virtual_gateway | (!response.virtual_netmask); if request.virtual_ip != 0 { let ip = request.virtual_ip; - if u32::from(config.gateway) == ip || u32::from(config.broadcast) == ip || !ip_range.contains(&ip) { + if u32::from(config.gateway) == ip + || u32::from(config.broadcast) == ip + || !ip_range.contains(&ip) + { log::warn!("手动指定的ip无效:{:?}", request); return Err(Error::InvalidIp); } @@ -108,50 +159,123 @@ pub fn register(data: &[u8], config: &ConfigInfo, addr: SocketAddr, link: PeerLi name: request.name.clone(), ip: virtual_ip, status: PeerDeviceStatus::Online, + client_secret: request.client_secret, }, ); + } else { + lock.virtual_ip_map.get_mut(&request.device_id).unwrap().id = id; } for (_device_id, device_info) in &lock.virtual_ip_map { if device_info.ip != virtual_ip { - let mut dev = crate::proto::message::DeviceInfo::new(); + let mut dev = message::DeviceInfo::new(); dev.virtual_ip = device_info.ip; dev.name = device_info.name.clone(); let status: u8 = device_info.status.into(); dev.device_status = status as u32; + dev.client_secret = device_info.client_secret; response.device_info_list.push(dev); } } - DEVICE_ADDRESS.insert((request.token.clone(), virtual_ip), link.clone()); - drop(lock); - DEVICE_ID_SESSION.insert((request.token.clone(), request.device_id.clone()), ()); response.virtual_ip = virtual_ip; - let context = Context { - token: request.token.clone(), - virtual_ip, - id, - device_id: request.device_id.clone(), - }; + let c = context.get_or_insert_with(|| Context::default()); + c.id = id; + c.virtual_ip = virtual_ip; + c.token = request.token.clone(); + c.device_id = request.device_id.clone(); + c.client_secret = request.client_secret; + c.address = addr; + DEVICE_ADDRESS.insert( + (request.token.clone(), virtual_ip), + (link.clone(), c.clone()), + ); + drop(lock); + DEVICE_ID_SESSION.insert((request.token.clone(), request.device_id.clone()), id); match link { PeerLink::Tcp(_) => {} PeerLink::Udp(_) => { - UDP_SESSION.insert( - addr, - context.clone(), - ); + UDP_SESSION.insert(addr, c.clone()); } } - Ok((context, response)) + Ok(response) } -pub async fn broadcast(source_addr: SocketAddr, main_udp: &UdpSocket, context: &Context, buf: &[u8], multicast_info: Option<&RwLock>, exclude: &[Ipv4Addr]) -> crate::error::Result<()> { +async fn register( + context: &mut Option, + aes_gcm_cipher: &Option, + main_udp: &UdpSocket, + net_packet: NetPacket<&mut [u8]>, + config: &ConfigInfo, + addr: SocketAddr, + sender: Option<&Sender>>, +) -> crate::error::Result { + let link = sender + .map(|v| PeerLink::Tcp(v.clone())) + .unwrap_or(PeerLink::Udp(addr)); + match register0(context, net_packet.payload(), config, addr, link.clone()) { + Ok(response) => { + let bytes = response.write_to_bytes()?; + let mut rs = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED]; + let mut packet = NetPacket::new_encrypt(&mut rs)?; + packet.set_version(Version::V1); + packet.set_protocol(Protocol::Service); + packet.set_source(config.gateway); + packet.set_destination(Ipv4Addr::UNSPECIFIED); + packet.set_transport_protocol(service_packet::Protocol::RegistrationResponse.into()); + packet.first_set_ttl(MAX_TTL); + packet.set_payload(&bytes)?; + packet.set_gateway_flag(true); + reply_vec(aes_gcm_cipher, &sender, main_udp, addr, rs).await?; + Ok(response) + } + Err(e) => { + let mut rs = vec![0u8; 12 + ENCRYPTION_RESERVED]; + let mut packet = NetPacket::new_encrypt(&mut rs).unwrap(); + packet.set_version(Version::V1); + packet.set_protocol(Protocol::Error); + packet.first_set_ttl(MAX_TTL); + packet.set_source(config.gateway); + packet.set_gateway_flag(true); + match e { + Error::AddressExhausted => { + packet.set_transport_protocol(error_packet::Protocol::AddressExhausted.into()); + } + Error::TokenError => { + packet.set_transport_protocol(error_packet::Protocol::TokenError.into()); + } + Error::IpAlreadyExists => { + packet.set_transport_protocol(error_packet::Protocol::IpAlreadyExists.into()); + } + Error::InvalidIp => { + packet.set_transport_protocol(error_packet::Protocol::InvalidIp.into()); + } + e => { + log::info!("注册失败:{:?}", e); + return Err(e); + } + } + reply_vec(aes_gcm_cipher, &sender, main_udp, addr, rs).await?; + return Err(e); + } + } +} + +async fn broadcast( + main_udp: &UdpSocket, + context: &Context, + buf: &[u8], + multicast_info: Option<&RwLock>, + exclude: &[Ipv4Addr], +) -> crate::error::Result<()> { if let Some(v) = VIRTUAL_NETWORK.get(&context.token) { - let ips: Vec = v.read() + let ips: Vec = v + .read() .virtual_ip_map .iter() .map(|(_, device_info)| device_info.ip) .filter(|ip| ip != &context.virtual_ip) .collect(); let multicast = multicast_info.map(|v| v.read().clone()); + let client_secret = NetPacket::new(buf)?.is_encrypt(); for ip in ips { let ipv4 = Ipv4Addr::from(ip); if let Some(multicast) = &multicast { @@ -161,14 +285,67 @@ pub async fn broadcast(source_addr: SocketAddr, main_udp: &UdpSocket, context: & } if !exclude.contains(&ipv4) { if let Some(peer) = DEVICE_ADDRESS.get(&(context.token.clone(), ip)) { - match peer.value() { - PeerLink::Tcp(sender) => { - let _ = sender.send(buf.to_vec()).await; - } - PeerLink::Udp(addr) => { - if addr != &source_addr { - let _ = main_udp.send_to(&buf[4..], addr).await; + let (peer_link, peer_context) = peer.value(); + if peer_context.client_secret == client_secret { + match peer_link { + PeerLink::Tcp(sender) => { + let _ = sender.send(buf.to_vec()).await; } + PeerLink::Udp(addr) => { + let _ = main_udp.send_to(buf, addr).await; + } + } + } + } + } + } + } + Ok(()) +} + +async fn broadcast_igmp( + main_udp: &UdpSocket, + context: &Context, + net_packet: NetPacket<&mut [u8]>, +) -> crate::error::Result<()> { + let buf = if net_packet.reserve() != ENCRYPTION_RESERVED { + let mut buf_packet = Vec::with_capacity(net_packet.data_len() + ENCRYPTION_RESERVED); + buf_packet.clone_from_slice(net_packet.buffer()); + buf_packet.resize(net_packet.data_len() + ENCRYPTION_RESERVED, 0); + buf_packet + } else { + net_packet.buffer().to_vec() + }; + if let Some(v) = VIRTUAL_NETWORK.get(&context.token) { + let ips: Vec = v + .read() + .virtual_ip_map + .iter() + .map(|(_, device_info)| device_info.ip) + .filter(|ip| ip != &context.virtual_ip) + .collect(); + for ip in ips { + if let Some(peer) = DEVICE_ADDRESS.get(&(context.token.clone(), ip)) { + let (peer_link, peer_context) = peer.value(); + match peer_link { + PeerLink::Tcp(sender) => { + if let Some(aes) = TCP_AES.get(&peer_context.address) { + let mut packet = NetPacket::new_encrypt(buf.clone())?; + aes.value().encrypt_ipv4(&mut packet)?; + let _ = sender.send(packet.buffer().to_vec()).await; + } else { + let mut packet = NetPacket::new_encrypt(&buf)?; + let _ = sender.send(packet.buffer().to_vec()).await; + } + } + PeerLink::Udp(addr) => { + if let Some(aes) = UDP_AES.get(&peer_context.address) { + let mut packet = NetPacket::new_encrypt(buf.clone())?; + aes.encrypt_ipv4(&mut packet)?; + let _ = main_udp.send_to(packet.buffer(), addr).await; + } else { + let mut packet = NetPacket::new_encrypt(&buf)?; + let _ = main_udp.send_to(packet.buffer(), addr).await; } } } @@ -179,319 +356,503 @@ pub async fn broadcast(source_addr: SocketAddr, main_udp: &UdpSocket, context: & } /// 选择性转发广播/组播,并且去除尾部 -pub async fn change_broadcast(source_addr: SocketAddr, udp: &UdpSocket, context: &Context, broadcast_addr: Ipv4Addr, destination: Ipv4Addr, buf: &[u8]) -> crate::error::Result<()> { - let end_len = buf[buf.len() - 1] as usize * 4 + 1; - if buf.len() <= end_len { - return Err(Error::InvalidPacket); - } - let packet_end = BroadcastPacketEnd::new(&buf[buf.len() - end_len..])?; - let end_len = packet_end.len(); - let exclude = packet_end.addresses(); - let buf = &buf[..buf.len() - end_len]; +async fn change_broadcast( + udp: &UdpSocket, + context: &Context, + broadcast_addr: Ipv4Addr, + destination: Ipv4Addr, + broadcast_packet: BroadcastPacket<&[u8]>, +) -> crate::error::Result<()> { + let exclude = broadcast_packet.addresses(); + let buf = broadcast_packet.data()?; if destination.is_broadcast() || broadcast_addr == destination { - broadcast(source_addr, udp, context, buf, None, &exclude).await?; + broadcast(udp, context, buf, None, &exclude).await?; } else if destination.is_multicast() { - if let Some(multicast_info) = crate::service::igmp_server - ::load(&context.token, destination) { - broadcast(source_addr, udp, context, buf, Some(&multicast_info), &exclude).await?; + if let Some(multicast_info) = crate::service::igmp_server::load(&context.token, destination) + { + broadcast(udp, context, buf, Some(&multicast_info), &exclude).await?; } } Ok(()) } -pub async fn handle(context: &mut Context, main_udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr, config: &ConfigInfo, sender: Option<&Sender>>) -> crate::error::Result<()> { - let reg: u8 = service_packet::Protocol::RegistrationRequest.into(); - let addr_req: u8 = control_packet::Protocol::AddrRequest.into(); - match NetPacket::new(&mut buf[4..]) { - Ok(mut net_packet) => { - let source = net_packet.source(); - let destination = net_packet.destination(); - if net_packet.protocol() == Protocol::Service - && net_packet.transport_protocol() == reg { - let link = sender.map(|v| PeerLink::Tcp(v.clone())).unwrap_or(PeerLink::Udp(addr)); - //注册请求 - match register(net_packet.payload(), config, addr, link.clone()) { - Ok((c, response)) => { - *context = c; - let bytes = response.write_to_bytes()?; - let mut rs = vec![0u8; 4 + 12 + bytes.len()]; - let mut net_packet = NetPacket::new(&mut rs[4..])?; - net_packet.set_version(Version::V1); - net_packet.set_protocol(Protocol::Service); - net_packet.set_source(config.gateway); - net_packet.set_destination(Ipv4Addr::UNSPECIFIED); - net_packet.set_transport_protocol(service_packet::Protocol::RegistrationResponse.into()); - net_packet.first_set_ttl(MAX_TTL); - net_packet.set_payload(&bytes); - match link { - PeerLink::Tcp(sender) => { - let _ = sender.send(rs).await; - } - PeerLink::Udp(addr) => { - main_udp.send_to(net_packet.buffer(), addr).await?; - } - } - } - Err(e) => { - //带上tcp头 - let mut rs = vec![0u8; 4 + 12]; - let mut net_packet = NetPacket::new(&mut rs[4..]).unwrap(); - net_packet.set_version(Version::V1); - net_packet.set_protocol(Protocol::Error); - net_packet.first_set_ttl(MAX_TTL); - net_packet.set_source(config.gateway); - match e { - Error::AddressExhausted => { - net_packet - .set_transport_protocol(error_packet::Protocol::AddressExhausted.into()); - } - Error::TokenError => { - net_packet - .set_transport_protocol(error_packet::Protocol::TokenError.into()); - } - Error::IpAlreadyExists => { - net_packet - .set_transport_protocol(error_packet::Protocol::IpAlreadyExists.into()); - } - Error::InvalidIp => { - net_packet - .set_transport_protocol(error_packet::Protocol::InvalidIp.into()); - } - e => { - log::info!("注册失败:{:?}",e); - return Ok(()); - } - } - match link { - PeerLink::Tcp(sender) => { - let _ = sender.send(rs).await; - } - PeerLink::Udp(_) => { - main_udp.send_to(net_packet.buffer(), addr).await?; - } - } - } +async fn request_addr( + aes_gcm_cipher: &Option, + main_udp: &UdpSocket, + addr: SocketAddr, + net_packet: NetPacket<&mut [u8]>, + sender: Option<&Sender>>, +) -> crate::error::Result<()> { + match addr.ip() { + IpAddr::V4(ipv4) => { + let mut vec = vec![0u8; 12 + 6 + ENCRYPTION_RESERVED]; + let mut packet = NetPacket::new_encrypt(&mut vec)?; + packet.set_version(Version::V1); + packet.set_protocol(Protocol::Control); + packet.set_transport_protocol(control_packet::Protocol::AddrResponse.into()); + packet.first_set_ttl(MAX_TTL); + packet.set_source(net_packet.destination()); + packet.set_destination(net_packet.source()); + packet.set_gateway_flag(true); + let mut addr_packet = control_packet::AddrPacket::new(packet.payload_mut())?; + addr_packet.set_ipv4(ipv4); + addr_packet.set_port(addr.port()); + reply_vec(aes_gcm_cipher, &sender, main_udp, addr, vec).await?; + } + IpAddr::V6(_) => {} + } + Ok(()) +} +async fn server_packet_pre_handle( + context: &mut Option, + rsa_cipher: &Option, + aes_gcm_cipher: &mut Option, + main_udp: &UdpSocket, + net_packet: NetPacket<&mut [u8]>, + config: &ConfigInfo, + addr: SocketAddr, + sender: Option<&Sender>>, +) -> crate::error::Result<()> { + let source = net_packet.source(); + let destination = net_packet.destination(); + match net_packet.protocol() { + Protocol::Service => { + match service_packet::Protocol::from(net_packet.transport_protocol()) { + service_packet::Protocol::RegistrationRequest => { + register( + context, + aes_gcm_cipher, + main_udp, + net_packet, + config, + addr, + sender, + ) + .await?; } - } else if net_packet.protocol() == Protocol::Control - && net_packet.transport_protocol() == addr_req { - match addr.ip() { - IpAddr::V4(ipv4) => { - let mut vec = vec![0u8; 4 + 12 + 6]; - let mut packet = NetPacket::new(&mut vec[4..])?; - packet.set_version(Version::V1); - packet.set_protocol(Protocol::Control); - packet.set_transport_protocol( - control_packet::Protocol::AddrResponse.into(), + service_packet::Protocol::HandshakeRequest => { + // 握手请求,有加密的话回应公钥 + let mut res = message::HandshakeResponse::new(); + res.version = "1.1.3".to_string(); + if let Some(rsp_cipher) = rsa_cipher { + res.public_key.extend_from_slice(rsp_cipher.public_key()); + res.secret = true; + res.key_finger = rsp_cipher.finger()?; + } + let bytes = res.write_to_bytes()?; + let mut vec = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED]; + let mut packet = NetPacket::new_encrypt(&mut vec)?; + packet.set_version(Version::V1); + packet.set_protocol(Protocol::Service); + packet + .set_transport_protocol(service_packet::Protocol::HandshakeResponse.into()); + packet.first_set_ttl(MAX_TTL); + packet.set_source(destination); + packet.set_destination(source); + packet.set_payload(&bytes)?; + packet.set_gateway_flag(true); + reply_vec(&None, &sender, main_udp, addr, vec).await?; + } + service_packet::Protocol::SecretHandshakeRequest => { + // 同步密钥,这个是用公钥加密的,使用私钥解密 + if aes_gcm_cipher.is_none() { + if let Some(rsp_cipher) = rsa_cipher { + let rsa_secret_body = rsp_cipher.decrypt(&net_packet)?; + let sync_secret = message::SecretHandshakeRequest::parse_from_bytes( + rsa_secret_body.data(), + )?; + if sync_secret.key.len() == 32 { + let c = Aes256GcmCipher::new( + sync_secret.key.try_into().unwrap(), + Finger::new(sync_secret.token.clone()), + ); + if sender.is_none() { + UDP_AES.insert(addr, c.clone()); + } else { + TCP_AES.insert(addr, c.clone()); + } + let _ = aes_gcm_cipher.insert(c); + } + } + } + let mut rs = vec![0u8; 12 + ENCRYPTION_RESERVED]; + let mut packet = NetPacket::new_encrypt(&mut rs)?; + packet.set_version(Version::V1); + packet.set_protocol(Protocol::Service); + packet.set_source(config.gateway); + packet.set_destination(source); + packet.set_transport_protocol( + service_packet::Protocol::SecretHandshakeResponse.into(), + ); + packet.first_set_ttl(MAX_TTL); + packet.set_gateway_flag(true); + reply_vec(aes_gcm_cipher, &sender, main_udp, addr, rs).await?; + } + _ => {} + } + } + Protocol::Control => { + match control_packet::Protocol::from(net_packet.transport_protocol()) { + control_packet::Protocol::AddrRequest => { + request_addr(aes_gcm_cipher, main_udp, addr, net_packet, sender).await?; + } + _ => {} + } + } + _ => {} + } + Ok(()) +} +async fn server_packet_handle( + rsa_cipher: &Option, + aes_gcm_cipher: &mut Option, + context: &mut Context, + main_udp: &UdpSocket, + mut net_packet: NetPacket<&mut [u8]>, + config: &ConfigInfo, + addr: SocketAddr, + sender: Option<&Sender>>, +) -> crate::error::Result<()> { + let source = net_packet.source(); + let destination = net_packet.destination(); + match net_packet.protocol() { + Protocol::Service => { + match service_packet::Protocol::from(net_packet.transport_protocol()) { + service_packet::Protocol::PollDeviceList => { + if let Some(v) = VIRTUAL_NETWORK.get(&context.token) { + let (ips, epoch) = { + let lock = v.read(); + let ips: Vec = lock + .virtual_ip_map + .iter() + .filter(|&(_, dev)| dev.ip != context.virtual_ip) + .map(|(_, device_info)| { + let mut dev = message::DeviceInfo::new(); + dev.virtual_ip = device_info.ip; + dev.name = device_info.name.clone(); + let status: u8 = device_info.status.into(); + dev.device_status = status as u32; + dev.client_secret = device_info.client_secret; + dev + }) + .collect(); + let epoch = lock.epoch; + (ips, epoch) + }; + let mut device_list = DeviceList::new(); + device_list.epoch = epoch; + device_list.device_info_list = ips; + let bytes = device_list.write_to_bytes()?; + let mut vec = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED]; + let mut device_list_packet = NetPacket::new_encrypt(&mut vec)?; + device_list_packet.set_version(Version::V1); + device_list_packet.set_protocol(Protocol::Service); + device_list_packet.set_transport_protocol( + service_packet::Protocol::PushDeviceList.into(), ); - packet.first_set_ttl(MAX_TTL); - packet.set_source(destination); - packet.set_destination(source); - let mut addr_packet = control_packet::AddrPacket::new(packet.payload_mut())?; - addr_packet.set_ipv4(ipv4); - addr_packet.set_port(addr.port()); - match sender { - None => { - main_udp.send_to(packet.buffer(), addr).await?; - } - Some(sender) => { - let _ = sender.send(vec).await; - } - } + device_list_packet.first_set_ttl(MAX_TTL); + device_list_packet.set_source(destination); + device_list_packet.set_destination(source); + device_list_packet.set_payload(&bytes)?; + device_list_packet.set_gateway_flag(true); + reply_vec(aes_gcm_cipher, &sender, main_udp, addr, vec).await?; } - IpAddr::V6(_) => {} } - } else if context.virtual_ip != 0 { - if destination == config.gateway { - //给网关的消息 - match net_packet.protocol() { - Protocol::Service => { - if service_packet::Protocol::PollDeviceList == service_packet::Protocol::from(net_packet.transport_protocol()) { - if let Some(v) = VIRTUAL_NETWORK.get(&context.token) { - let (ips, epoch) = { - let lock = v.read(); - let ips: Vec = lock - .virtual_ip_map - .iter() - .filter(|&(_, dev)| { - dev.ip != context.virtual_ip - }) - .map(|(_, device_info)| { - let mut dev = message::DeviceInfo::new(); - dev.virtual_ip = device_info.ip; - dev.name = device_info.name.clone(); - let status: u8 = device_info.status.into(); - dev.device_status = status as u32; - dev - }) - .collect(); - let epoch = lock.epoch; - (ips, epoch) - }; - let mut device_list = DeviceList::new(); - device_list.epoch = epoch; - device_list.device_info_list = ips; - log::info!("context:{:?},device_list:{:?}",context,device_list); - let bytes = device_list.write_to_bytes()?; - let mut vec = vec![0u8; 4 + 12 + bytes.len()]; - let mut device_list_packet = - NetPacket::new(&mut vec[4..])?; - device_list_packet.set_version(Version::V1); - device_list_packet.set_protocol(Protocol::Service); - device_list_packet.set_transport_protocol( - service_packet::Protocol::PushDeviceList.into(), - ); - device_list_packet.first_set_ttl(MAX_TTL); - device_list_packet.set_source(destination); - device_list_packet.set_destination(source); - device_list_packet.set_payload(&bytes); - match sender { - None => { - main_udp.send_to(device_list_packet.buffer(), addr).await?; - } - Some(sender) => { - let _ = sender.send(vec).await; - } - } - } - } - } - Protocol::Control => { - match control_packet::Protocol::from(net_packet.transport_protocol()) { - control_packet::Protocol::Ping => { - let _ = DEVICE_ID_SESSION.get(&(context.token.clone(), context.device_id.clone())); - if let Some(v) = VIRTUAL_NETWORK.get(&context.token) { - let epoch = v.read().epoch; - net_packet.first_set_ttl(MAX_TTL); - net_packet.set_transport_protocol(control_packet::Protocol::Pong.into()); - net_packet.set_source(destination); - net_packet.set_destination(source); - let mut pong_packet = control_packet::PongPacket::new(net_packet.payload_mut())?; - pong_packet.set_epoch(epoch as u16); - match sender { - None => { - main_udp.send_to(net_packet.buffer(), addr).await?; - } - Some(sender) => { - let _ = sender.send(buf.to_vec()).await; - } - } - } - } - _ => {} - } - } - Protocol::IpTurn => { - let mut ipv4 = IpV4Packet::new(net_packet.payload_mut())?; - match ipv4.protocol() { - ipv4::protocol::Protocol::Icmp => { - let mut icmp_packet = icmp::IcmpPacket::new(ipv4.payload_mut())?; - if icmp_packet.kind() == Kind::EchoRequest { - //开启ping - icmp_packet.set_kind(Kind::EchoReply); - icmp_packet.update_checksum(); - ipv4.set_source_ip(destination); - ipv4.set_destination_ip(source); - ipv4.update_checksum(); - net_packet.set_source(destination); - net_packet.set_destination(source); - match sender { - None => { - main_udp.send_to(net_packet.buffer(), addr).await?; - } - Some(sender) => { - let _ = sender.send(buf.to_vec()).await; - } - } - } - } - ipv4::protocol::Protocol::Igmp => { - crate::service::igmp_server::handle(ipv4.payload(), &context.token, source)?; - //Igmp数据也会广播出去,让大家都知道谁加入什么组播 - net_packet.set_destination(Ipv4Addr::new(224, 0, 0, 1)); - broadcast(addr, main_udp, &context, net_packet.buffer(), None, &[]).await?; - } - _ => {} - } - } - _ => { - log::info!("无效数据类型:{:?},Protocol={:?}",addr,net_packet.protocol()) - } + _ => {} + } + } + Protocol::Control => { + match control_packet::Protocol::from(net_packet.transport_protocol()) { + control_packet::Protocol::Ping => { + let _ = + DEVICE_ID_SESSION.get(&(context.token.clone(), context.device_id.clone())); + if let Some(v) = VIRTUAL_NETWORK.get(&context.token) { + let epoch = v.read().epoch; + net_packet.first_set_ttl(MAX_TTL); + net_packet.set_transport_protocol(control_packet::Protocol::Pong.into()); + net_packet.set_source(destination); + net_packet.set_destination(source); + net_packet.set_gateway_flag(true); + let mut pong_packet = + control_packet::PongPacket::new(net_packet.payload_mut())?; + pong_packet.set_epoch(epoch as u16); + reply_buf( + aes_gcm_cipher, + &sender, + main_udp, + addr, + net_packet.into_buffer(), + ) + .await?; } - } else { - //需要转发的数据 - if net_packet.ttl() > 1 { - net_packet.set_ttl(net_packet.ttl() - 1); - if Protocol::IpTurn == net_packet.protocol() { - //处理广播 - match ip_turn_packet::Protocol::from(net_packet.transport_protocol()) { - ip_turn_packet::Protocol::Icmp => {} - ip_turn_packet::Protocol::Igmp => { - let ipv4 = IpV4Packet::new(net_packet.payload())?; - if ipv4.protocol() == ipv4::protocol::Protocol::Igmp { - crate::service::igmp_server::handle(ipv4.payload(), &context.token, source)?; - //Igmp数据也会广播出去,让大家都知道谁加入什么组播 - broadcast(addr, main_udp, &context, buf, None, &[]).await?; - } - return Ok(()); - } - ip_turn_packet::Protocol::Ipv4 => { - //处理广播 - if destination.is_broadcast() || config.broadcast == destination { - broadcast(addr, main_udp, &context, buf, None, &[]).await?; - return Ok(()); - } else if destination.is_multicast() { - if let Some(multicast_info) = crate::service::igmp_server - ::load(&context.token, destination) { - broadcast(addr, main_udp, &context, buf, Some(&multicast_info), &[]).await?; - } - return Ok(()); - } - } - ip_turn_packet::Protocol::Ipv4Broadcast => { - net_packet.set_transport_protocol(ip_turn_packet::Protocol::Ipv4.into()); - return change_broadcast(addr, main_udp, &context, config.broadcast, destination, buf).await; - } - ip_turn_packet::Protocol::Unknown(_) => {} + } + _ => {} + } + } + Protocol::IpTurn => { + match ip_turn_packet::Protocol::from(net_packet.transport_protocol()) { + ip_turn_packet::Protocol::Ipv4Broadcast => { + //处理选择性广播,进过网关还原成原始广播 + let broadcast_packet = BroadcastPacket::new(net_packet.payload())?; + return change_broadcast( + main_udp, + &context, + config.broadcast, + destination, + broadcast_packet, + ) + .await; + } + ip_turn_packet::Protocol::Ipv4 => { + let mut ipv4 = IpV4Packet::new(net_packet.payload_mut())?; + match ipv4.protocol() { + ipv4::protocol::Protocol::Icmp => { + let mut icmp_packet = icmp::IcmpPacket::new(ipv4.payload_mut())?; + if icmp_packet.kind() == Kind::EchoRequest { + //开启ping + icmp_packet.set_kind(Kind::EchoReply); + icmp_packet.update_checksum(); + ipv4.set_source_ip(destination); + ipv4.set_destination_ip(source); + ipv4.update_checksum(); + net_packet.set_source(destination); + net_packet.set_destination(source); + net_packet.set_gateway_flag(true); + reply_buf( + aes_gcm_cipher, + &sender, + main_udp, + addr, + net_packet.into_buffer(), + ) + .await?; } } - //其他的直接转发 - if let Some(peer) = - DEVICE_ADDRESS.get(&(context.token.clone(), destination.into())) + ipv4::protocol::Protocol::Igmp => { + crate::service::igmp_server::handle( + ipv4.payload(), + &context.token, + source, + )?; + //Igmp数据也会广播出去,让大家都知道谁加入什么组播 + net_packet.set_destination(Ipv4Addr::new(224, 0, 0, 1)); + broadcast_igmp(main_udp, &context, net_packet).await?; + } + _ => {} + } + } + _ => {} + } + } + _ => { + log::info!( + "无效数据类型:{:?},Protocol={:?}", + addr, + net_packet.protocol() + ) + } + } + Ok(()) +} + +async fn transmit_handle( + context: &Context, + main_udp: &UdpSocket, + mut net_packet: NetPacket<&mut [u8]>, + config: &ConfigInfo, +) -> crate::error::Result<()> { + let destination = net_packet.destination(); + let client_secret = net_packet.is_encrypt(); + if net_packet.ttl() > 1 { + net_packet.set_ttl(net_packet.ttl() - 1); + if Protocol::IpTurn == net_packet.protocol() { + match ip_turn_packet::Protocol::from(net_packet.transport_protocol()) { + ip_turn_packet::Protocol::Ipv4 => { + //处理广播 + if destination.is_broadcast() || config.broadcast == destination { + broadcast(main_udp, &context, net_packet.buffer(), None, &[]).await?; + return Ok(()); + } else if destination.is_multicast() { + if let Some(multicast_info) = + crate::service::igmp_server::load(&context.token, destination) { - match peer.value() { - PeerLink::Tcp(sender) => { - let _ = sender.send(buf.to_vec()).await; - } - PeerLink::Udp(addr) => { - main_udp.send_to(net_packet.buffer(), addr).await?; - } - } + broadcast( + main_udp, + &context, + net_packet.buffer(), + Some(&multicast_info), + &[], + ) + .await?; } + return Ok(()); } } - } else { - let source = net_packet.source(); - let mut rs = vec![0u8; 4 + 12]; - let mut net_packet = NetPacket::new(&mut rs[4..])?; - net_packet.set_version(Version::V1); - net_packet.set_protocol(Protocol::Error); - net_packet.set_transport_protocol(error_packet::Protocol::Disconnect.into()); - net_packet.first_set_ttl(MAX_TTL); - net_packet.set_source(config.gateway); - net_packet.set_destination(source); - if let Some(sender) = sender { - let _ = sender.send(rs).await; - } else { - main_udp.send_to(net_packet.buffer(), addr).await?; + ip_turn_packet::Protocol::Ipv4Broadcast => {} + ip_turn_packet::Protocol::Unknown(_) => {} + } + } + //其他的直接转发 + if let Some(peer) = DEVICE_ADDRESS.get(&(context.token.clone(), destination.into())) { + let (peer_link, peer_context) = peer.value(); + if peer_context.client_secret == client_secret { + match peer_link { + PeerLink::Tcp(sender) => { + let _ = sender.send(net_packet.buffer().to_vec()).await; + } + PeerLink::Udp(addr) => { + main_udp.send_to(net_packet.buffer(), addr).await?; + } } } } + } + Ok(()) +} + +async fn reply_vec( + aes_gcm_cipher: &Option, + sender: &Option<&Sender>>, + main_udp: &UdpSocket, + addr: SocketAddr, + mut buf: Vec, +) -> crate::error::Result<()> { + if let Some(aes) = aes_gcm_cipher { + let mut packet = NetPacket::new_encrypt(&mut buf)?; + aes.encrypt_ipv4(&mut packet)?; + } else { + let len = buf.len(); + buf.truncate(len - ENCRYPTION_RESERVED); + } + if let Some(sender) = sender { + let _ = sender.send(buf).await; + } else { + main_udp.send_to(&buf, addr).await?; + } + Ok(()) +} + +async fn reply_buf( + aes_gcm_cipher: &Option, + sender: &Option<&Sender>>, + main_udp: &UdpSocket, + addr: SocketAddr, + buf: &mut [u8], +) -> crate::error::Result<()> { + if let Some(aes) = aes_gcm_cipher { + let mut packet = NetPacket::new_encrypt(&mut buf[..])?; + aes.encrypt_ipv4(&mut packet)?; + } + if let Some(sender) = sender { + let _ = sender.send(buf.to_vec()).await; + } else { + main_udp.send_to(buf, addr).await?; + } + Ok(()) +} + +pub async fn handle( + rsa_cipher: &Option, + aes_gcm_cipher: &mut Option, + context: &mut Option, + main_udp: &UdpSocket, + buf: &mut [u8], + addr: SocketAddr, + config: &ConfigInfo, + sender: Option<&Sender>>, +) -> crate::error::Result<()> { + let reg: u8 = service_packet::Protocol::RegistrationRequest.into(); + let handshake: u8 = service_packet::Protocol::HandshakeRequest.into(); + let secret_handshake: u8 = service_packet::Protocol::SecretHandshakeRequest.into(); + let addr_req: u8 = control_packet::Protocol::AddrRequest.into(); + match NetPacket::new(buf) { + Ok(mut net_packet) => { + if net_packet.source_ttl() < net_packet.ttl() { + return Ok(()); + } + let p = net_packet.transport_protocol(); + if net_packet.is_gateway() { + if p != handshake && p != secret_handshake { + //解密数据 + if let Some(aes) = aes_gcm_cipher { + aes.decrypt_ipv4(&mut net_packet)?; + } else if net_packet.is_encrypt() { + log::warn!( + "没有密钥={},src={},dest={}", + addr, + net_packet.source(), + net_packet.destination() + ); + let source = net_packet.source(); + let mut rs = vec![0u8; 12 + ENCRYPTION_RESERVED]; + let mut packet = NetPacket::new_encrypt(&mut rs)?; + packet.set_version(Version::V1); + packet.set_protocol(Protocol::Error); + packet.set_transport_protocol(error_packet::Protocol::NoKey.into()); + packet.first_set_ttl(MAX_TTL); + packet.set_source(config.gateway); + packet.set_destination(source); + packet.set_gateway_flag(true); + reply_vec(&None, &sender, main_udp, addr, rs).await?; + return Ok(()); + } + } + } else if let Some(context) = context { + //不是服务端的包虽然不能解密,但是可以验证数据合法性 + if net_packet.is_encrypt() { + let finger = Finger::new(context.token.clone()); + finger.check_finger(&net_packet)?; + } + } + if net_packet.protocol() == Protocol::Service + && (p == reg || p == handshake || p == secret_handshake || p == addr_req) + { + server_packet_pre_handle( + context, + rsa_cipher, + aes_gcm_cipher, + main_udp, + net_packet, + config, + addr, + sender, + ) + .await?; + } else if let Some(context) = context { + if net_packet.is_gateway() || net_packet.destination() == config.gateway { + //给网关的消息 + server_packet_handle( + rsa_cipher, + aes_gcm_cipher, + context, + main_udp, + net_packet, + config, + addr, + sender, + ) + .await?; + } else { + //需要转发的数据 + transmit_handle(context, main_udp, net_packet, config).await?; + } + } else { + let source = net_packet.source(); + let mut rs = vec![0u8; 12 + ENCRYPTION_RESERVED]; + let mut packet = NetPacket::new_encrypt(&mut rs)?; + packet.set_version(Version::V1); + packet.set_protocol(Protocol::Error); + packet.set_transport_protocol(error_packet::Protocol::Disconnect.into()); + packet.first_set_ttl(MAX_TTL); + packet.set_source(config.gateway); + packet.set_destination(source); + packet.set_gateway_flag(true); + reply_vec(aes_gcm_cipher, &sender, main_udp, addr, rs).await?; + } + } Err(e) => { - log::warn!("数据错误:{},{:?}",addr,e); + log::warn!("数据错误:{},{:?}", addr, e); } } return Ok(()); } - - - diff --git a/src/service/main_service/mod.rs b/src/service/main_service/mod.rs index 4fa62df..192a857 100644 --- a/src/service/main_service/mod.rs +++ b/src/service/main_service/mod.rs @@ -1,56 +1,65 @@ use std::collections::HashMap; -use std::net::SocketAddr; -use moka::sync::Cache; -use std::time::Duration; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; +use std::time::Duration; + use crossbeam_skiplist::SkipMap; +use moka::sync::Cache; use parking_lot::RwLock; use tokio::sync::mpsc::Sender; -mod udp_service; -mod tcp_service; -mod common; -pub use udp_service::start_udp; pub use tcp_service::start_tcp; +pub use udp_service::start_udp; +use crate::cipher::Aes256GcmCipher; +mod common; +mod tcp_service; +mod udp_service; lazy_static::lazy_static! { //七天不连接则回收ip - static ref DEVICE_ID_SESSION:Cache<(String,String),()> = Cache::builder() - .time_to_idle(Duration::from_secs(60*60*24*7)).eviction_listener(|k:Arc<(String,String)>,_,cause|{ - if cause!=moka::notification::RemovalCause::Expired{ - return; - } + static ref DEVICE_ID_SESSION:Cache<(String,String),i64> = Cache::builder() + .time_to_idle(Duration::from_secs(60*60*24*7)).eviction_listener(|k:Arc<(String,String)>,id:i64,cause|{ + if cause!=moka::notification::RemovalCause::Expired{ + return; + } log::info!("DEVICE_ID_SESSION eviction {:?}", k); if let Some(v) = VIRTUAL_NETWORK.get(&k.0){ let mut lock = v.write(); - lock.virtual_ip_map.remove(&k.1); - lock.epoch+=1; + if let Some(dev) = lock.virtual_ip_map.get(&k.1){ + if dev.id==id{ + lock.virtual_ip_map.remove(&k.1); + lock.epoch+=1; + } + } } }).build(); //七天没有用户则回收网段缓存 static ref VIRTUAL_NETWORK:Cache>> = Cache::builder() .time_to_idle(Duration::from_secs(60*60*24*7)).build(); - static ref DEVICE_ADDRESS:SkipMap<(String,u32), PeerLink> = SkipMap::new(); + static ref DEVICE_ADDRESS:SkipMap<(String,u32), (PeerLink,Context)> = SkipMap::new(); + static ref TCP_AES:SkipMap = SkipMap::new(); + static ref UDP_AES:Cache = Cache::builder() + .time_to_idle(Duration::from_secs(20)).build(); //udp专用 10秒钟没有收到消息则判定为掉线 // 地址 -> 注册信息 static ref UDP_SESSION:Cache = Cache::builder() .time_to_idle(Duration::from_secs(10)).eviction_listener(|_,context:Context,cause|{ - if cause!=moka::notification::RemovalCause::Expired{ - return; - } - log::info!("UDP_SESSION eviction {:?}", context); + if cause!=moka::notification::RemovalCause::Expired{ + return; + } + log::info!("UDP_SESSION eviction token={},virtual_ip={},device_id={},id={}", context.token,context.virtual_ip,context.device_id,context.id); if let Some(v) = VIRTUAL_NETWORK.get(&context.token){ let mut lock = v.write(); - if let Some(mut item) = lock.virtual_ip_map.get_mut(&context.device_id){ + if let Some(item) = lock.virtual_ip_map.get_mut(&context.device_id){ if item.id!=context.id{ return; } item.status = PeerDeviceStatus::Offline; DEVICE_ADDRESS.remove(&(context.token,context.virtual_ip)); + lock.epoch+=1; } - lock.epoch+=1; } }).build(); } @@ -61,13 +70,27 @@ pub enum PeerLink { Udp(SocketAddr), } - -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Context { token: String, virtual_ip: u32, id: i64, device_id: String, + client_secret: bool, + address: SocketAddr, +} + +impl Default for Context { + fn default() -> Self { + Context { + token: "".to_string(), + virtual_ip: 0, + id: 0, + device_id: "".to_string(), + client_secret: false, + address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + } + } } #[derive(Clone, Debug)] @@ -83,6 +106,7 @@ pub struct DeviceInfo { ip: u32, name: String, status: PeerDeviceStatus, + client_secret: bool, } #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -104,7 +128,7 @@ impl From for PeerDeviceStatus { fn from(value: u8) -> Self { match value { 0 => PeerDeviceStatus::Online, - _ => PeerDeviceStatus::Offline + _ => PeerDeviceStatus::Offline, } } -} \ No newline at end of file +} diff --git a/src/service/main_service/tcp_service.rs b/src/service/main_service/tcp_service.rs index 6bf3a3a..91c2a5c 100644 --- a/src/service/main_service/tcp_service.rs +++ b/src/service/main_service/tcp_service.rs @@ -3,80 +3,124 @@ use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, UdpSocket}; use tokio::net::tcp::OwnedReadHalf; +use tokio::net::{TcpListener, UdpSocket}; use tokio::sync::mpsc::{channel, Sender}; -use crate::ConfigInfo; -use crate::service::main_service::{Context, DEVICE_ADDRESS, PeerDeviceStatus, VIRTUAL_NETWORK}; +use crate::cipher::{Aes256GcmCipher, RsaCipher}; use crate::service::main_service::common::handle; +use crate::service::main_service::{ + Context, PeerDeviceStatus, DEVICE_ADDRESS, TCP_AES, VIRTUAL_NETWORK, +}; +use crate::ConfigInfo; -pub async fn start_tcp(tcp: TcpListener, main_udp: Arc, config: ConfigInfo) -> io::Result<()> { +pub async fn start_tcp( + tcp: TcpListener, + main_udp: Arc, + config: ConfigInfo, + rsa_cipher: Option, +) { loop { let (stream, addr) = match tcp.accept().await { - Ok(rs) => { rs } + Ok(rs) => rs, Err(e) => { - log::warn!("tcp accept err:{:?}",e); + log::warn!("tcp accept err:{:?}", e); continue; } }; - log::info!("tcp连接 {}",addr); + log::info!("tcp连接 {}", addr); let (r, mut w) = stream.into_split(); let (sender, mut receiver) = channel::>(100); tokio::spawn(async move { - while let Some(mut data) = receiver.recv().await { - if data.len() >= 4 { - let len = data.len() - 4; - data[2] = (len >> 8) as u8; - data[3] = (len & 0xFF) as u8; - if let Err(e) = w.write_all(&data).await { - log::info!("发送失败,链接终止:{:?},{:?}",addr,e); - break; - } + let mut head = [0; 4]; + while let Some(data) = receiver.recv().await { + let len = data.len(); + head[2] = (len >> 8) as u8; + head[3] = (len & 0xFF) as u8; + if let Err(e) = w.write_all(&head).await { + log::info!("发送失败,链接终止:{:?},{:?}", addr, e); + } + if let Err(e) = w.write_all(&data).await { + log::info!("发送失败,链接终止:{:?},{:?}", addr, e); + break; } } let _ = w.shutdown().await; }); let main_udp = main_udp.clone(); let config = config.clone(); + let rsa_cipher = rsa_cipher.clone(); tokio::spawn(async move { - let mut context = Context { - token: "".to_string(), - virtual_ip: 0, - id: 0, - device_id: "".to_string(), - }; - if let Err(e) = tcp_handle(&mut context, config, r, addr, sender, main_udp).await { - log::info!("接收失败,链接终止:{:?},{:?}",addr,e); + let mut context: Option = None; + let mut aes_gcm_cipher: Option = None; + if let Err(e) = tcp_handle( + rsa_cipher, + &mut aes_gcm_cipher, + &mut context, + config, + r, + addr, + sender, + main_udp, + ) + .await + { + log::info!("接收失败,链接终止:{:?},{:?}", addr, e); } - if context.virtual_ip != 0 && context.id != 0 { + TCP_AES.remove(&addr); + if let Some(context) = context { if let Some(v) = VIRTUAL_NETWORK.get(&context.token) { let mut lock = v.write(); - if let Some(mut item) = lock.virtual_ip_map.get_mut(&context.device_id) { + if let Some(item) = lock.virtual_ip_map.get_mut(&context.device_id) { if item.id != context.id { return; } item.status = PeerDeviceStatus::Offline; DEVICE_ADDRESS.remove(&(context.token, context.virtual_ip)); + lock.epoch += 1; } - lock.epoch += 1; } } }); } } - -async fn tcp_handle(context: &mut Context, config: ConfigInfo, mut read: OwnedReadHalf, addr: SocketAddr, sender: Sender>, main_udp: Arc) -> io::Result<()> { +async fn tcp_handle( + rsa_cipher: Option, + aes_gcm_cipher: &mut Option, + context: &mut Option, + config: ConfigInfo, + mut read: OwnedReadHalf, + addr: SocketAddr, + sender: Sender>, + main_udp: Arc, +) -> io::Result<()> { + let mut head = [0; 4]; let mut buf = [0; 10240]; loop { - read.read_exact(&mut buf[..4]).await?; - let len = 4 + (((buf[2] as u16) << 8) | buf[3] as u16) as usize; - read.read_exact(&mut buf[4..len]).await?; - if let Err(e) = handle(context, &main_udp, &mut buf[..len], addr, &config, Some(&sender)).await { - log::info!("tcp数据处理失败:{:?},{:?}",addr,e); + read.read_exact(&mut head).await?; + let len = (((head[2] as u16) << 8) | head[3] as u16) as usize; + if len < 12 || len > buf.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length overflow", + )); + } + read.read_exact(&mut buf[..len]).await?; + if let Err(e) = handle( + &rsa_cipher, + aes_gcm_cipher, + context, + &main_udp, + &mut buf[..len], + addr, + &config, + Some(&sender), + ) + .await + { + log::info!("tcp数据处理失败:{:?},{:?}", addr, e); } } } - diff --git a/src/service/main_service/udp_service.rs b/src/service/main_service/udp_service.rs index 90378b1..5b969b4 100644 --- a/src/service/main_service/udp_service.rs +++ b/src/service/main_service/udp_service.rs @@ -1,31 +1,42 @@ use std::sync::Arc; +use crate::cipher::RsaCipher; use tokio::net::UdpSocket; -use crate::ConfigInfo; -use crate::service::main_service::{Context, UDP_SESSION}; use crate::service::main_service::common::handle; +use crate::service::main_service::{UDP_AES, UDP_SESSION}; +use crate::ConfigInfo; -pub async fn start_udp(main_udp: Arc, config: ConfigInfo) { +pub async fn start_udp( + main_udp: Arc, + config: ConfigInfo, + rsa_cipher: Option, +) { loop { let mut buf = vec![0u8; 10240]; - match main_udp.recv_from(&mut buf[4..]).await { + match main_udp.recv_from(&mut buf).await { Ok((len, addr)) => { let main_udp = main_udp.clone(); let config = config.clone(); - let mut context = UDP_SESSION.get(&addr).unwrap_or_else(|| { - Context { - token: "".to_string(), - virtual_ip: 0, - id: 0, - device_id: "".to_string(), - } - }); + let rsa_cipher = rsa_cipher.clone(); + let mut context = UDP_SESSION.get(&addr); + let mut aes = UDP_AES.get(&addr); tokio::spawn(async move { - match handle(&mut context, &main_udp, &mut buf[..len + 4], addr, &config, None).await { + match handle( + &rsa_cipher, + &mut aes, + &mut context, + &main_udp, + &mut buf[..len], + addr, + &config, + None, + ) + .await + { Ok(_) => {} Err(e) => { - log::info!("udp数据处理失败:{:?},{:?}",addr,e); + log::info!("udp数据处理失败:{:?},{:?}", addr, e); } } }); @@ -35,4 +46,4 @@ pub async fn start_udp(main_udp: Arc, config: ConfigInfo) { } } } -} \ No newline at end of file +}