From 1a4075a3196b0788ad103065eb8b8a3d1675a18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Thu, 29 May 2025 11:32:50 +0200 Subject: [PATCH] simplify ech implementation --- src/lib.rs | 32 +++++++++++++++--------- src/protocols/dns/resolver.rs | 14 ++--------- src/protocols/tls/server.rs | 47 +++++++++++------------------------ src/test_integrations.rs | 2 +- src/tunnel/client/cnx_pool.rs | 3 +-- src/tunnel/client/config.rs | 2 -- src/tunnel/tls_reloader.rs | 6 +++-- 7 files changed, 42 insertions(+), 64 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 17f81ac..4b08bca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,14 +79,32 @@ async fn create_client_tunnels( args.http_upgrade_path_prefix }; + let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?; + let dns_resolver = DnsResolver::new_from_urls( + &args.dns_resolver, + http_proxy.clone(), + SoMark::new(args.socket_so_mark), + !args.dns_resolver_prefer_ipv4, + ) + .expect("cannot create dns resolver"); + let transport_scheme = TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url"); let tls = match transport_scheme { TransportScheme::Ws | TransportScheme::Http => None, TransportScheme::Wss | TransportScheme::Https => { - let (tls_connector, root_store) = tls::tls_connector( + let ech_config = if args.tls_ech_enable { + dns_resolver + .lookup_ech_config(&args.remote_addr.host().unwrap().to_owned()) + .await? + } else { + None + }; + + let tls_connector = tls::tls_connector( args.tls_verify_certificate, transport_scheme.alpn_protocols(), !args.tls_sni_disable, + ech_config, tls_certificate, tls_key, ) @@ -94,7 +112,6 @@ async fn create_client_tunnels( Some(TlsClientConfig { tls_connector: Arc::new(RwLock::new(tls_connector)), - root_store, tls_sni_override: args.tls_sni_override, tls_verify_certificate: args.tls_verify_certificate, tls_sni_disabled: args.tls_sni_disable, @@ -120,7 +137,6 @@ async fn create_client_tunnels( } } - let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?; let client_config = WsClientConfig { remote_addr: TransportAddr::new( TransportScheme::from_str(args.remote_addr.scheme()).unwrap(), @@ -141,14 +157,7 @@ async fn create_client_tunnels( .or(Some(Duration::from_secs(30))) .filter(|d| d.as_secs() > 0), websocket_mask_frame: args.websocket_mask_frame, - dns_resolver: DnsResolver::new_from_urls( - &args.dns_resolver, - http_proxy.clone(), - SoMark::new(args.socket_so_mark), - !args.dns_resolver_prefer_ipv4, - args.tls_ech_enable, - ) - .expect("cannot create dns resolver"), + dns_resolver, http_proxy, }; @@ -492,7 +501,6 @@ async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow:: None, SoMark::new(args.socket_so_mark), !args.dns_resolver_prefer_ipv4, - false, ) .expect("Cannot create DNS resolver"), restriction_config: args.restrict_config, diff --git a/src/protocols/dns/resolver.rs b/src/protocols/dns/resolver.rs index e526de7..be1d0ca 100644 --- a/src/protocols/dns/resolver.rs +++ b/src/protocols/dns/resolver.rs @@ -44,7 +44,6 @@ pub enum DnsResolver { TrustDns { resolver: Resolver>, prefer_ipv6: bool, - ech_enabled: bool, }, } @@ -52,11 +51,7 @@ impl DnsResolver { pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result> { let addrs = match self { Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(), - Self::TrustDns { - resolver, - prefer_ipv6, - ech_enabled: _, - } => { + Self::TrustDns { resolver, prefer_ipv6 } => { let addrs: Vec<_> = resolver .lookup_ip(domain) .await? @@ -75,9 +70,7 @@ impl DnsResolver { pub async fn lookup_ech_config(&self, domain: &Host) -> Result, ResolveError> { let resolver = match self { - DnsResolver::TrustDns { - resolver, ech_enabled, .. - } if *ech_enabled => resolver, + DnsResolver::TrustDns { resolver, .. } => resolver, _ => { return Ok(None); } @@ -119,7 +112,6 @@ impl DnsResolver { proxy: Option, so_mark: SoMark, prefer_ipv6: bool, - ech_enabled: bool, ) -> anyhow::Result { fn mk_resolver( cfg: ResolverConfig, @@ -195,7 +187,6 @@ impl DnsResolver { return Ok(Self::TrustDns { resolver: mk_resolver(cfg, opts, proxy, so_mark), prefer_ipv6, - ech_enabled, }); }; @@ -213,7 +204,6 @@ impl DnsResolver { Ok(Self::TrustDns { resolver: mk_resolver(cfg, ResolverOpts::default(), proxy, so_mark), prefer_ipv6, - ech_enabled, }) } } diff --git a/src/protocols/tls/server.rs b/src/protocols/tls/server.rs index 74bcbe3..8166b5d 100644 --- a/src/protocols/tls/server.rs +++ b/src/protocols/tls/server.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use tokio::net::TcpStream; use tokio_rustls::client::TlsStream; -use crate::tunnel::client::{TlsClientConfig, WsClientConfig}; +use crate::tunnel::client::WsClientConfig; use crate::tunnel::server::TlsServerConfig; use crate::tunnel::transport::TransportAddr; use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; @@ -107,9 +107,10 @@ pub fn tls_connector( tls_verify_certificate: bool, alpn_protocols: Vec>, enable_sni: bool, + ech_config: Option, tls_client_certificate: Option>>, tls_client_key: Option>, -) -> anyhow::Result<(TlsConnector, RootCertStore)> { +) -> anyhow::Result { let mut root_store = RootCertStore::empty(); // Load system certificates and add them to the root store @@ -124,7 +125,14 @@ pub fn tls_connector( } } - let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone()); + let crypto_provider = ClientConfig::builder().crypto_provider().clone(); + let config_builder = ClientConfig::builder_with_provider(crypto_provider); + let config_builder = if let Some(ech_config) = ech_config { + config_builder.with_ech(EchMode::Enable(ech_config))? + } else { + config_builder.with_safe_default_protocol_versions()? + }; + let config_builder = config_builder.with_root_certificates(root_store); let mut config = match (tls_client_certificate, tls_client_key) { (Some(tls_client_certificate), Some(tls_client_key)) => config_builder @@ -143,7 +151,7 @@ pub fn tls_connector( config.alpn_protocols = alpn_protocols; let tls_connector = TlsConnector::from(Arc::new(config)); - Ok((tls_connector, root_store)) + Ok(tls_connector) } pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option>>) -> anyhow::Result { @@ -174,11 +182,7 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option, -) -> anyhow::Result> { +pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result> { let sni = client_cfg.tls_server_name(); let tls_config = match &client_cfg.remote_addr { TransportAddr::Wss { tls, .. } => tls, @@ -203,30 +207,7 @@ pub async fn connect( } let tls_connector = tls_config.tls_connector(); - let tls_stream = if let Some(ech_config) = ech_config { - //FIXME: do not re-create a tls connector every time ? - let tls_connector = tls_connector_with_ech(ech_config, tls_config, tls_connector)?; - tls_connector.connect(sni, tcp_stream).await? - } else { - tls_connector.connect(sni, tcp_stream).await? - }; + let tls_stream = tls_connector.connect(sni, tcp_stream).await?; Ok(tls_stream) } - -fn tls_connector_with_ech( - ech_config: EchConfig, - tls_config: &TlsClientConfig, - original_connector: TlsConnector, -) -> anyhow::Result { - let original_config = original_connector.config(); - let mut ech_client_config = ClientConfig::builder_with_provider(original_config.crypto_provider().clone()) - .with_ech(EchMode::from(ech_config))? - .with_root_certificates(tls_config.root_store.clone()) - .with_no_client_auth(); - - ech_client_config.key_log = original_config.key_log.clone(); - ech_client_config.alpn_protocols = original_config.alpn_protocols.clone(); - - Ok(TlsConnector::from(Arc::new(ech_client_config))) -} diff --git a/src/test_integrations.rs b/src/test_integrations.rs index 287e9b9..8b1846a 100644 --- a/src/test_integrations.rs +++ b/src/test_integrations.rs @@ -25,7 +25,7 @@ use url::Host; #[fixture] fn dns_resolver() -> DnsResolver { - DnsResolver::new_from_urls(&[], None, SoMark::new(None), true, false).expect("Cannot create DNS resolver") + DnsResolver::new_from_urls(&[], None, SoMark::new(None), true).expect("Cannot create DNS resolver") } #[fixture] diff --git a/src/tunnel/client/cnx_pool.rs b/src/tunnel/client/cnx_pool.rs index b47935f..480a2cc 100644 --- a/src/tunnel/client/cnx_pool.rs +++ b/src/tunnel/client/cnx_pool.rs @@ -55,8 +55,7 @@ impl ManageConnection for WsConnection { }; if self.remote_addr.tls().is_some() { - let ech_config = self.dns_resolver.lookup_ech_config(self.remote_addr.host()).await?; - let tls_stream = tls::connect(self, tcp_stream, ech_config).await?; + let tls_stream = tls::connect(self, tcp_stream).await?; Ok(Some(TransportStream::from_client_tls(tls_stream, Bytes::default()))) } else { Ok(Some(TransportStream::from_tcp(tcp_stream, Bytes::default()))) diff --git a/src/tunnel/client/config.rs b/src/tunnel/client/config.rs index dda9272..395366a 100644 --- a/src/tunnel/client/config.rs +++ b/src/tunnel/client/config.rs @@ -9,7 +9,6 @@ use std::path::PathBuf; use std::sync::{Arc, LazyLock}; use std::time::Duration; use tokio_rustls::TlsConnector; -use tokio_rustls::rustls::RootCertStore; use tokio_rustls::rustls::pki_types::{DnsName, ServerName}; use url::{Host, Url}; @@ -58,7 +57,6 @@ pub struct TlsClientConfig { pub tls_connector: Arc>, pub tls_certificate_path: Option, pub tls_key_path: Option, - pub root_store: RootCertStore, } impl TlsClientConfig { diff --git a/src/tunnel/tls_reloader.rs b/src/tunnel/tls_reloader.rs index ad5915c..7a20d9f 100644 --- a/src/tunnel/tls_reloader.rs +++ b/src/tunnel/tls_reloader.rs @@ -290,6 +290,7 @@ impl TlsReloader { tls.tls_verify_certificate, this.client_config.remote_addr.scheme().alpn_protocols(), !tls.tls_sni_disabled, + None, Some(tls_certs), Some(tls_key), ); @@ -300,7 +301,7 @@ impl TlsReloader { return; } }; - *tls.tls_connector.write() = tls_connector.0; + *tls.tls_connector.write() = tls_connector; this.tls_reload_certificate.store(true, Ordering::Relaxed); } (Err(err), _) | (_, Err(err)) => { @@ -333,6 +334,7 @@ impl TlsReloader { tls.tls_verify_certificate, this.client_config.remote_addr.scheme().alpn_protocols(), !tls.tls_sni_disabled, + None, Some(tls_certs), Some(tls_key), ); @@ -343,7 +345,7 @@ impl TlsReloader { return; } }; - *tls.tls_connector.write() = tls_connector.0; + *tls.tls_connector.write() = tls_connector; this.tls_reload_certificate.store(true, Ordering::Relaxed); } (Err(err), _) | (_, Err(err)) => {