diff --git a/.gitignore b/.gitignore index 1a05e5d86..18dbcc8a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ /.idea -/target \ No newline at end of file +**/target diff --git a/Cargo.lock b/Cargo.lock index 3ddf503a1..8682d9286 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21,6 +21,7 @@ dependencies = [ "rand", "reqwest", "serde_json", + "stun_proto", "stunclient", "tokio", "tokio-stream", @@ -788,6 +789,10 @@ dependencies = [ "trackable 0.2.24", ] +[[package]] +name = "stun_proto" +version = "0.1.0" + [[package]] name = "stunclient" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index e5a43bddc..f0503c656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,5 +13,6 @@ rand = { version = "0.8.4", default-features = false } reqwest = { version = "0.11.10", default-features = false, features=["json", "rustls-tls"] } serde_json = { version = "1.0.70", default-features = false } stunclient = { version = "0.3.1", default-features = false, features = ["async"] } +stun_proto = { path = "./stun_proto" } tokio = { version = "1.14.0", default-features = false, features = ["fs", "macros", "net", "rt"] } tokio-stream = { version = "0.1.8", default-features = false } \ No newline at end of file diff --git a/rust-toolchain b/rust-toolchain index 07ade694b..2bf5ad044 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly \ No newline at end of file +stable diff --git a/src/main.rs b/src/main.rs index e8645d61b..d94805068 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use crate::utils::join_all_with_semaphore; use crate::outputs::{ValidHosts, ValidIpV4s, ValidIpV6s}; use crate::servers::StunServer; use crate::stun::{StunServerTestResult, StunSocketResponse}; -use crate::stun_codec::{Attribute, NonParsableAttribute}; +// use crate::stun_codec::{Attribute, NonParsableAttribute}; extern crate pretty_env_logger; #[macro_use] extern crate log; @@ -23,7 +23,6 @@ mod utils; mod outputs; mod geoip; mod git; -mod stun_codec; const CONCURRENT_SOCKETS_USED_LIMIT: usize = 64; @@ -44,46 +43,46 @@ async fn get_stun_response(addr: &str) -> io::Result<()> { let bytes_read = bytes_read.unwrap(); - let r = stun_codec::StunMessageReader { bytes: buf[0..bytes_read].as_ref() }; - info!("Method {:?} , Class {:?}", r.get_method().unwrap(), r.get_class()); - r.get_attrs().for_each(|attr| { - match &attr { - Ok(attr) => { - match attr { - Attribute::MappedAddress(r) => info!("MappedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::ResponseAddress(r) => info!("ResponseAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::ChangeAddress(r) => info!("ChangeAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::SourceAddress(r) => info!("SourceAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::ChangedAddress(r) => info!("ChangedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::XorMappedAddress(r) => info!("XorMappedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::OptXorMappedAddress(r) => info!("OptXorMappedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::OtherAddress(r) => info!("OtherAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::ResponseOrigin(r) => info!("ResponseOrigin {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::AlternateServer(r) => info!("AlternateServer {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::Software(r) => info!("Software {}", r.get_software().unwrap()), - Attribute::ReflectedFrom(r) => info!("ReflectedFrom {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), - Attribute::ErrorCode(r) => info!("ErrorCode {:?}", r.get_error().unwrap()), - Attribute::Fingerprint(r) => info!("Fingerprint {}", r.get_checksum()), - Attribute::MessageIntegrity(r) => info!("MessageIntegrity {:?}", r.get_digest()), - Attribute::Realm(r) => info!("Realm {}", r.get_realm().unwrap()), - Attribute::Nonce(r) => info!("Nonce {}", r.get_nonce().unwrap()), - Attribute::Password(r) => info!("Password {}", r.get_password().unwrap()), - Attribute::UnknownAttributes(r) => { - for attr_code in r.get_attr_codes() { - info!("Unknown attribute {}", attr_code) - } - }, - Attribute::Username(r) => info!("Username {}", r.get_username().unwrap()), - } - } - Err(attr) => { - match &attr { - NonParsableAttribute::Unknown(r) => warn!("UnknownAttr type {:04x} len {}", r.get_type_raw(), r.get_total_length()), - NonParsableAttribute::Malformed(r) => warn!("MalformedAttr type {:04x} len {}", r.get_type_raw(), r.get_value_length_raw()), - } - } - } - }); + // let r = stun_codec::StunMessageReader { bytes: buf[0..bytes_read].as_ref() }; + // info!("Method {:?} , Class {:?}", r.get_method().unwrap(), r.get_class()); + // r.get_attrs().for_each(|attr| { + // match &attr { + // Ok(attr) => { + // match attr { + // Attribute::MappedAddress(r) => info!("MappedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::ResponseAddress(r) => info!("ResponseAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::ChangeAddress(r) => info!("ChangeAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::SourceAddress(r) => info!("SourceAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::ChangedAddress(r) => info!("ChangedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::XorMappedAddress(r) => info!("XorMappedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::OptXorMappedAddress(r) => info!("OptXorMappedAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::OtherAddress(r) => info!("OtherAddress {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::ResponseOrigin(r) => info!("ResponseOrigin {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::AlternateServer(r) => info!("AlternateServer {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::Software(r) => info!("Software {}", r.get_software().unwrap()), + // Attribute::ReflectedFrom(r) => info!("ReflectedFrom {:?}", SocketAddr::new(r.get_address().unwrap(), r.get_port())), + // Attribute::ErrorCode(r) => info!("ErrorCode {:?}", r.get_error().unwrap()), + // Attribute::Fingerprint(r) => info!("Fingerprint {}", r.get_checksum()), + // Attribute::MessageIntegrity(r) => info!("MessageIntegrity {:?}", r.get_digest()), + // Attribute::Realm(r) => info!("Realm {}", r.get_realm().unwrap()), + // Attribute::Nonce(r) => info!("Nonce {}", r.get_nonce().unwrap()), + // Attribute::Password(r) => info!("Password {}", r.get_password().unwrap()), + // Attribute::UnknownAttributes(r) => { + // for attr_code in r.get_attr_codes() { + // info!("Unknown attribute {}", attr_code) + // } + // }, + // Attribute::Username(r) => info!("Username {}", r.get_username().unwrap()), + // } + // } + // Err(attr) => { + // match &attr { + // NonParsableAttribute::Unknown(r) => warn!("UnknownAttr type {:04x} len {}", r.get_type_raw(), r.get_total_length()), + // NonParsableAttribute::Malformed(r) => warn!("MalformedAttr type {:04x} len {}", r.get_type_raw(), r.get_value_length_raw()), + // } + // } + // } + // }); Ok(()) } diff --git a/src/stun_codec/mod.rs b/src/stun_codec/mod.rs deleted file mode 100644 index 0683ceaf1..000000000 --- a/src/stun_codec/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod enums; -mod attrs; -mod message; - -pub use attrs::*; -pub use enums::*; -pub use message::*; diff --git a/stun_proto/Cargo.lock b/stun_proto/Cargo.lock new file mode 100644 index 000000000..6c752316d --- /dev/null +++ b/stun_proto/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "stun_proto" +version = "0.1.0" diff --git a/stun_proto/Cargo.toml b/stun_proto/Cargo.toml new file mode 100644 index 000000000..d149a3c47 --- /dev/null +++ b/stun_proto/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "stun_proto" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/src/stun_codec/attrs/alternate_server.rs b/stun_proto/src/attrs/alternate_server.rs similarity index 100% rename from src/stun_codec/attrs/alternate_server.rs rename to stun_proto/src/attrs/alternate_server.rs diff --git a/src/stun_codec/attrs/base/malformed.rs b/stun_proto/src/attrs/base/malformed.rs similarity index 100% rename from src/stun_codec/attrs/base/malformed.rs rename to stun_proto/src/attrs/base/malformed.rs diff --git a/src/stun_codec/attrs/base/mod.rs b/stun_proto/src/attrs/base/mod.rs similarity index 100% rename from src/stun_codec/attrs/base/mod.rs rename to stun_proto/src/attrs/base/mod.rs diff --git a/src/stun_codec/attrs/base/shared.rs b/stun_proto/src/attrs/base/shared.rs similarity index 100% rename from src/stun_codec/attrs/base/shared.rs rename to stun_proto/src/attrs/base/shared.rs diff --git a/src/stun_codec/attrs/base/unknown.rs b/stun_proto/src/attrs/base/unknown.rs similarity index 100% rename from src/stun_codec/attrs/base/unknown.rs rename to stun_proto/src/attrs/base/unknown.rs diff --git a/src/stun_codec/attrs/change_address.rs b/stun_proto/src/attrs/change_address.rs similarity index 100% rename from src/stun_codec/attrs/change_address.rs rename to stun_proto/src/attrs/change_address.rs diff --git a/src/stun_codec/attrs/changed_address.rs b/stun_proto/src/attrs/changed_address.rs similarity index 100% rename from src/stun_codec/attrs/changed_address.rs rename to stun_proto/src/attrs/changed_address.rs diff --git a/src/stun_codec/attrs/error_code.rs b/stun_proto/src/attrs/error_code.rs similarity index 100% rename from src/stun_codec/attrs/error_code.rs rename to stun_proto/src/attrs/error_code.rs diff --git a/src/stun_codec/attrs/fingerprint.rs b/stun_proto/src/attrs/fingerprint.rs similarity index 100% rename from src/stun_codec/attrs/fingerprint.rs rename to stun_proto/src/attrs/fingerprint.rs diff --git a/src/stun_codec/attrs/mapped_address.rs b/stun_proto/src/attrs/mapped_address.rs similarity index 100% rename from src/stun_codec/attrs/mapped_address.rs rename to stun_proto/src/attrs/mapped_address.rs diff --git a/src/stun_codec/attrs/message_integrity.rs b/stun_proto/src/attrs/message_integrity.rs similarity index 100% rename from src/stun_codec/attrs/message_integrity.rs rename to stun_proto/src/attrs/message_integrity.rs diff --git a/src/stun_codec/attrs/mod.rs b/stun_proto/src/attrs/mod.rs similarity index 100% rename from src/stun_codec/attrs/mod.rs rename to stun_proto/src/attrs/mod.rs diff --git a/src/stun_codec/attrs/nonce.rs b/stun_proto/src/attrs/nonce.rs similarity index 100% rename from src/stun_codec/attrs/nonce.rs rename to stun_proto/src/attrs/nonce.rs diff --git a/src/stun_codec/attrs/opt_xor_mapped_address.rs b/stun_proto/src/attrs/opt_xor_mapped_address.rs similarity index 100% rename from src/stun_codec/attrs/opt_xor_mapped_address.rs rename to stun_proto/src/attrs/opt_xor_mapped_address.rs diff --git a/src/stun_codec/attrs/other_address.rs b/stun_proto/src/attrs/other_address.rs similarity index 100% rename from src/stun_codec/attrs/other_address.rs rename to stun_proto/src/attrs/other_address.rs diff --git a/src/stun_codec/attrs/password.rs b/stun_proto/src/attrs/password.rs similarity index 100% rename from src/stun_codec/attrs/password.rs rename to stun_proto/src/attrs/password.rs diff --git a/src/stun_codec/attrs/realm.rs b/stun_proto/src/attrs/realm.rs similarity index 100% rename from src/stun_codec/attrs/realm.rs rename to stun_proto/src/attrs/realm.rs diff --git a/src/stun_codec/attrs/reflected_from.rs b/stun_proto/src/attrs/reflected_from.rs similarity index 100% rename from src/stun_codec/attrs/reflected_from.rs rename to stun_proto/src/attrs/reflected_from.rs diff --git a/src/stun_codec/attrs/response_address.rs b/stun_proto/src/attrs/response_address.rs similarity index 100% rename from src/stun_codec/attrs/response_address.rs rename to stun_proto/src/attrs/response_address.rs diff --git a/src/stun_codec/attrs/response_origin.rs b/stun_proto/src/attrs/response_origin.rs similarity index 100% rename from src/stun_codec/attrs/response_origin.rs rename to stun_proto/src/attrs/response_origin.rs diff --git a/src/stun_codec/attrs/software.rs b/stun_proto/src/attrs/software.rs similarity index 100% rename from src/stun_codec/attrs/software.rs rename to stun_proto/src/attrs/software.rs diff --git a/src/stun_codec/attrs/source_address.rs b/stun_proto/src/attrs/source_address.rs similarity index 100% rename from src/stun_codec/attrs/source_address.rs rename to stun_proto/src/attrs/source_address.rs diff --git a/src/stun_codec/attrs/unknown_attributes.rs b/stun_proto/src/attrs/unknown_attributes.rs similarity index 100% rename from src/stun_codec/attrs/unknown_attributes.rs rename to stun_proto/src/attrs/unknown_attributes.rs diff --git a/src/stun_codec/attrs/username.rs b/stun_proto/src/attrs/username.rs similarity index 100% rename from src/stun_codec/attrs/username.rs rename to stun_proto/src/attrs/username.rs diff --git a/src/stun_codec/attrs/xor_mapped_address.rs b/stun_proto/src/attrs/xor_mapped_address.rs similarity index 100% rename from src/stun_codec/attrs/xor_mapped_address.rs rename to stun_proto/src/attrs/xor_mapped_address.rs diff --git a/src/stun_codec/enums.rs b/stun_proto/src/enums.rs similarity index 100% rename from src/stun_codec/enums.rs rename to stun_proto/src/enums.rs diff --git a/stun_proto/src/lib.rs b/stun_proto/src/lib.rs new file mode 100644 index 000000000..4cb3ba2d9 --- /dev/null +++ b/stun_proto/src/lib.rs @@ -0,0 +1,434 @@ +#![no_std] + +type Result = core::result::Result; + +#[derive(Debug, PartialEq)] +pub enum ReaderErr { + NotEnoughBytes, + UnexpectedValue, +} + +#[derive(Debug, PartialEq)] +pub enum Method { + Binding, +} + +#[derive(Debug, PartialEq)] +pub enum Class { + Request, + Indirection, + SuccessResponse, + ErrorResponse, +} + +pub struct MsgReader<'a> { + bytes: &'a [u8], +} + +impl MsgReader<'_> { + pub fn get_message_type_raw(&self) -> Result<&[u8; 2]> { + self.bytes.get(0..2) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes) + } + + pub fn get_message_length_raw(&self) -> Result<&[u8; 2]> { + self.bytes.get(2..4) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes) + } + + pub fn get_magic_cookie_raw(&self) -> Result<&[u8; 4]> { + self.bytes.get(4..8) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes) + } + + pub fn get_transaction_id_raw(&self) -> Result<&[u8; 12]> { + self.bytes.get(8..20) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes) + } + + pub fn get_attributes_raw(&self) -> Result<&[u8]> { + self.bytes.get(20..) + .ok_or(ReaderErr::NotEnoughBytes) + } +} + +impl MsgReader<'_> { + pub fn new(bytes: &[u8]) -> MsgReader { + MsgReader { + bytes + } + } + + /// Gets the message method. + ///

+ /// + /// Currently the method `Binding` is the only method in the RFC specs. + ///

+ /// + /// Ignores the first two bits of the message header, as they should always be 0. + ///

+ /// + /// Returns + /// - `Result::NotEnoughBytes` if the message is not large enough + /// - `Result::UnexpectedValue` if the value doesn't correspond to a known method + ///

+ /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// use stun_proto::{Method, MsgReader}; + /// let msg = [0x0, 0x1]; + /// let r = MsgReader::new(&msg); + /// assert_eq!(Method::Binding, r.get_method().unwrap()); + /// ``` + /// + /// The message is not large enough: + /// ``` + /// use stun_proto::{MsgReader, ReaderErr}; + /// let msg = []; + /// let r = MsgReader::new(&msg); + /// assert_eq!(ReaderErr::NotEnoughBytes, r.get_method().unwrap_err()); + /// ``` + /// + /// The value does not correspond to a known method: + /// ``` + /// use stun_proto::{MsgReader, ReaderErr}; + /// let msg = [0x0, 0xF]; + /// let r = MsgReader::new(&msg); + /// assert_eq!(ReaderErr::UnexpectedValue, r.get_method().unwrap_err()); + /// ``` + pub fn get_method(&self) -> Result { + let b = self.get_message_type_raw()?; + + // we ignore the first two bits which should always be zero, + // as well as the 5th and 9th bit which correspond to message class + let method_raw = u16::from_be_bytes(*b) & 0b0011111011101111; + + match method_raw { + 1 => Ok(Method::Binding), + _ => Err(ReaderErr::UnexpectedValue) + } + } + + + /// Gets the message class. + ///

+ /// + /// Ignores all header bits except the 5th and the 9th bit. + ///

+ /// + /// Returns + /// - `Result::NotEnoughBytes` if the message is not large enough + ///

+ /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// use stun_proto::{Class, MsgReader}; + /// let msg = [0x0, 0x1]; + /// let r = MsgReader::new(&msg); + /// assert_eq!(Class::Request, r.get_class().unwrap()); + /// ``` + /// + /// The message is not large enough: + /// ``` + /// use stun_proto::{MsgReader, ReaderErr}; + /// let msg = []; + /// let r = MsgReader::new(&msg); + /// assert_eq!(ReaderErr::NotEnoughBytes, r.get_class().unwrap_err()); + /// ``` + pub fn get_class(&self) -> Result { + let b = self.get_message_type_raw()?; + + // we ignore the first two bits which should always be zero, + // as well all bits except the 5th and 9th bit since they all + // correspond to message method + let class_raw = u16::from_be_bytes(*b) & 0b0000000100010000; + + match class_raw { + 0b000000000 => Ok(Class::Request), + 0b000010000 => Ok(Class::Indirection), + 0b100000000 => Ok(Class::SuccessResponse), + 0b100010000 => Ok(Class::ErrorResponse), + _ => Err(ReaderErr::UnexpectedValue) + } + } + + /// Gets the total length of the message in bits as declared in the message header. + ///

+ /// + /// Returns + /// - `Result::NotEnoughBytes` if the message is not large enough + ///

+ /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// use stun_proto::{Class, MsgReader}; + /// let msg = [ + /// 0x0, 0x1, // class-method values + /// 0x0, 0x7 // total length (in big-endian order) + /// ]; + /// let r = MsgReader::new(&msg); + /// assert_eq!(7, r.get_message_length().unwrap()); + /// ``` + /// + /// The message is not large enough: + /// ``` + /// use stun_proto::{MsgReader, ReaderErr}; + /// let msg = []; + /// let r = MsgReader::new(&msg); + /// assert_eq!(ReaderErr::NotEnoughBytes, r.get_message_length().unwrap_err()); + /// ``` + pub fn get_message_length(&self) -> Result { + let b = self.get_message_length_raw()?; + Ok(u16::from_be_bytes(*b)) + } +} + +enum ComprehensionCategory { + Required, + Optional, +} + +trait GenericAttribute<'a> { + fn get_type(&'a self) -> Result { + self.get_type_raw() + .map(|b| u16::from_be_bytes(*b)) + } + + fn get_comprehension_category(&'a self) -> Result { + match self.get_type_raw()?[0] { + 0x00..=0x7F => Ok(ComprehensionCategory::Required), + 0x80..=0xFF => Ok(ComprehensionCategory::Optional), + } + } + + fn get_value_length(&'a self) -> Result { + self.get_value_length_raw() + .map(|b| u16::from_be_bytes(*b)) + } + + fn get_type_raw(&'a self) -> Result<&'a [u8; 2]> { + self.get_bytes_raw().get(0..2) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes) + } + + fn get_value_length_raw(&'a self) -> Result<&'a [u8; 2]> { + self.get_bytes_raw().get(2..4) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes) + } + + fn get_value_raw(&'a self) -> Result<&'a [u8]> { + let value_length = self.get_value_length()? as usize; + self.get_bytes_raw().get(4..4 + value_length) + .ok_or(ReaderErr::NotEnoughBytes) + } + + fn get_bytes_raw(&'a self) -> &'a [u8]; +} + +struct RawAttributeIterator<'a> { + bytes: &'a [u8], + idx: usize, +} + +impl<'a> Iterator for RawAttributeIterator<'a> { + type Item = Result<(&'a [u8; 2], &'a [u8; 2], &'a [u8])>; + + fn next(&mut self) -> Option { + if self.idx >= self.bytes.len() { + None + } else { + let typ_raw = self.bytes.get(self.idx .. self.idx + 2) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes); + + let typ_raw = match typ_raw { + Ok(t) => t, + Err(err) => { + self.idx = self.bytes.len(); + return Some(Err(err)) + } + }; + + let val_len_raw = self.bytes.get(self.idx + 2 .. self.idx + 4) + .map(|b| b.try_into().unwrap()) + .ok_or(ReaderErr::NotEnoughBytes); + + let val_len_raw: &[u8; 2] = match val_len_raw { + Ok(t) => t, + Err(err) => { + self.idx = self.bytes.len(); + return Some(Err(err)) + } + }; + + let val_len = u16::from_be_bytes(*val_len_raw) as usize; + + let val_raw = self.bytes.get(self.idx + 4 .. self.idx + 4 + val_len) + .ok_or(ReaderErr::NotEnoughBytes); + + let val_raw = match val_raw { + Ok(val) => val, + Err(err) => { + self.idx = self.bytes.len(); + return Some(Err(err)) + } + }; + + self.idx += 4 + val_len; + + Some(Ok((typ_raw, val_len_raw, val_raw))) + } + } +} + +struct AttributeIterator<'a> { + raw_iter: RawAttributeIterator<'a> +} + +impl<'a> Iterator for AttributeIterator<'a> { + type Item = Result<(u16, &'a [u8])>; + + fn next(&mut self) -> Option { + match self.raw_iter.next() { + None => None, + Some(Err(err)) => Some(Err(err)), + Some(Ok((typ_raw, _, val_raw))) => { + let typ = u16::from_be_bytes(*typ_raw); + Some(Ok((typ, val_raw))) + } + } + } +} + +impl AttributeIterator<'_> { + fn new(bytes: &'_ [u8]) -> AttributeIterator { + AttributeIterator { + raw_iter: RawAttributeIterator { + bytes, + idx: 0 + } + } + } +} + +struct MappedAddress<'a> { + bytes: &'a [u8], +} + +impl GenericAttribute<'_> for MappedAddress<'_> { + fn get_bytes_raw(&'_ self) -> &'_ [u8] { + self.bytes + } +} + +struct XorMappedAddress<'a> { + bytes: &'a [u8], +} + +impl GenericAttribute<'_> for XorMappedAddress<'_> { + fn get_bytes_raw(&'_ self) -> &'_ [u8] { + self.bytes + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_message_header() { + let msg = [ + 0x00, 0x01, // method: Binding , class: Request + 0x00, 0x14, // total length: 20 + 0x21, 0x12, 0xA4, 0x42, // magic cookie (RFC spec constant) + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, // transaction id (12 bytes total) + ]; + + let r = MsgReader::new(&msg); + + assert_eq!(Method::Binding, r.get_method().unwrap()); + assert_eq!(Class::Request, r.get_class().unwrap()); + assert_eq!(0x2112A442, u32::from_be_bytes(*r.get_magic_cookie_raw().unwrap())); + assert_eq!(12, r.get_transaction_id_raw().unwrap().len()); + assert_eq!(0, r.get_attributes_raw().unwrap().len()); + } + + #[test] + fn test_iter_over_attrs() { + let attr = [ + 0x00, 0x01, // type + 0x00, 0x04, // value length + 0x01, 0x01, 0x01, 0x01, // value + ]; + + assert_eq!(1, AttributeIterator::new(&attr).count()); + + for attr in AttributeIterator::new(&attr) { + match attr { + Ok((typ, val)) => { + assert_eq!(1u16, typ); + assert_eq!([0x01, 0x01, 0x01, 0x01], *val); + }, + Err(_) => assert!(false, "Test attr should be valid") + } + } + } + + #[test] + fn test_iter_over_attrs_invalid_attr_missing_byte() { + let attr = [ + 0x00, 0x01, // type + 0x00, 0x05, // value length (4+1 because we're simulating a missing byte) + 0x01, 0x01, 0x01, 0x01, // value + ]; + + assert_eq!(1, AttributeIterator::new(&attr).count()); + + for attr in AttributeIterator::new(&attr) { + match attr { + Ok(_) => assert!(false, "Test attr should be invalid"), + Err(_) => assert!(true, "Test attr should be valid") + } + } + } + + #[test] + fn test_iter_over_attrs_invalid_attr_extra_byte() { + let attr = [ + 0x00, 0x01, // type + 0x00, 0x03, // value length (4-1 because we're simulating an extra byte) + 0x01, 0x01, 0x01, 0x01, // value + ]; + + assert_eq!(2, AttributeIterator::new(&attr).count()); + + let mut iter = AttributeIterator::new(&attr); + + if let Some(Ok((typ, val))) = iter.next() { + assert_eq!(1u16, typ); + assert_eq!([0x01, 0x01, 0x01], *val); + } else { + assert!(false, "First attr should be valid"); + } + + if let Some(Err(_)) = iter.next() { + assert!(true); + } else { + assert!(false, "Second attr should be an error"); + } + } +} diff --git a/src/stun_codec/message.rs b/stun_proto/src/message.rs similarity index 98% rename from src/stun_codec/message.rs rename to stun_proto/src/message.rs index c61209504..bb0951acf 100644 --- a/src/stun_codec/message.rs +++ b/stun_proto/src/message.rs @@ -11,6 +11,13 @@ impl StunMessageReader<'_> { u16::from_be_bytes(self.bytes[0..2].try_into().unwrap()) } + pub fn get_method_raw(&self) -> u16 { + let msg_type = self.get_message_type(); + let method_code: u16 = msg_type & 0b0011111011101111; + method_code + } + + pub fn get_method(&self) -> Result { let method_code = self.get_method_raw(); match method_code { @@ -19,12 +26,6 @@ impl StunMessageReader<'_> { } } - pub fn get_method_raw(&self) -> u16 { - let msg_type = self.get_message_type(); - let method_code: u16 = msg_type & 0b11111011101111; - method_code - } - pub fn get_class(&self) -> MessageClass { let class = self.get_class_raw(); match class {