simplify ech implementation

This commit is contained in:
Σrebe - Romain GERARD
2025-05-29 11:32:50 +02:00
parent 247c239b72
commit 1a4075a319
7 changed files with 42 additions and 64 deletions

View File

@@ -79,14 +79,32 @@ async fn create_client_tunnels(
args.http_upgrade_path_prefix
};
let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?;
let dns_resolver = DnsResolver::new_from_urls(
&args.dns_resolver,
http_proxy.clone(),
SoMark::new(args.socket_so_mark),
!args.dns_resolver_prefer_ipv4,
)
.expect("cannot create dns resolver");
let transport_scheme = TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url");
let tls = match transport_scheme {
TransportScheme::Ws | TransportScheme::Http => None,
TransportScheme::Wss | TransportScheme::Https => {
let (tls_connector, root_store) = tls::tls_connector(
let ech_config = if args.tls_ech_enable {
dns_resolver
.lookup_ech_config(&args.remote_addr.host().unwrap().to_owned())
.await?
} else {
None
};
let tls_connector = tls::tls_connector(
args.tls_verify_certificate,
transport_scheme.alpn_protocols(),
!args.tls_sni_disable,
ech_config,
tls_certificate,
tls_key,
)
@@ -94,7 +112,6 @@ async fn create_client_tunnels(
Some(TlsClientConfig {
tls_connector: Arc::new(RwLock::new(tls_connector)),
root_store,
tls_sni_override: args.tls_sni_override,
tls_verify_certificate: args.tls_verify_certificate,
tls_sni_disabled: args.tls_sni_disable,
@@ -120,7 +137,6 @@ async fn create_client_tunnels(
}
}
let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?;
let client_config = WsClientConfig {
remote_addr: TransportAddr::new(
TransportScheme::from_str(args.remote_addr.scheme()).unwrap(),
@@ -141,14 +157,7 @@ async fn create_client_tunnels(
.or(Some(Duration::from_secs(30)))
.filter(|d| d.as_secs() > 0),
websocket_mask_frame: args.websocket_mask_frame,
dns_resolver: DnsResolver::new_from_urls(
&args.dns_resolver,
http_proxy.clone(),
SoMark::new(args.socket_so_mark),
!args.dns_resolver_prefer_ipv4,
args.tls_ech_enable,
)
.expect("cannot create dns resolver"),
dns_resolver,
http_proxy,
};
@@ -492,7 +501,6 @@ async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::
None,
SoMark::new(args.socket_so_mark),
!args.dns_resolver_prefer_ipv4,
false,
)
.expect("Cannot create DNS resolver"),
restriction_config: args.restrict_config,

View File

@@ -44,7 +44,6 @@ pub enum DnsResolver {
TrustDns {
resolver: Resolver<GenericConnector<TokioRuntimeProviderWithSoMark>>,
prefer_ipv6: bool,
ech_enabled: bool,
},
}
@@ -52,11 +51,7 @@ impl DnsResolver {
pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result<Vec<SocketAddr>> {
let addrs = match self {
Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(),
Self::TrustDns {
resolver,
prefer_ipv6,
ech_enabled: _,
} => {
Self::TrustDns { resolver, prefer_ipv6 } => {
let addrs: Vec<_> = resolver
.lookup_ip(domain)
.await?
@@ -75,9 +70,7 @@ impl DnsResolver {
pub async fn lookup_ech_config(&self, domain: &Host) -> Result<Option<EchConfig>, ResolveError> {
let resolver = match self {
DnsResolver::TrustDns {
resolver, ech_enabled, ..
} if *ech_enabled => resolver,
DnsResolver::TrustDns { resolver, .. } => resolver,
_ => {
return Ok(None);
}
@@ -119,7 +112,6 @@ impl DnsResolver {
proxy: Option<Url>,
so_mark: SoMark,
prefer_ipv6: bool,
ech_enabled: bool,
) -> anyhow::Result<Self> {
fn mk_resolver(
cfg: ResolverConfig,
@@ -195,7 +187,6 @@ impl DnsResolver {
return Ok(Self::TrustDns {
resolver: mk_resolver(cfg, opts, proxy, so_mark),
prefer_ipv6,
ech_enabled,
});
};
@@ -213,7 +204,6 @@ impl DnsResolver {
Ok(Self::TrustDns {
resolver: mk_resolver(cfg, ResolverOpts::default(), proxy, so_mark),
prefer_ipv6,
ech_enabled,
})
}
}

View File

@@ -9,7 +9,7 @@ use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use crate::tunnel::client::{TlsClientConfig, WsClientConfig};
use crate::tunnel::client::WsClientConfig;
use crate::tunnel::server::TlsServerConfig;
use crate::tunnel::transport::TransportAddr;
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
@@ -107,9 +107,10 @@ pub fn tls_connector(
tls_verify_certificate: bool,
alpn_protocols: Vec<Vec<u8>>,
enable_sni: bool,
ech_config: Option<EchConfig>,
tls_client_certificate: Option<Vec<CertificateDer<'static>>>,
tls_client_key: Option<PrivateKeyDer<'static>>,
) -> anyhow::Result<(TlsConnector, RootCertStore)> {
) -> anyhow::Result<TlsConnector> {
let mut root_store = RootCertStore::empty();
// Load system certificates and add them to the root store
@@ -124,7 +125,14 @@ pub fn tls_connector(
}
}
let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone());
let crypto_provider = ClientConfig::builder().crypto_provider().clone();
let config_builder = ClientConfig::builder_with_provider(crypto_provider);
let config_builder = if let Some(ech_config) = ech_config {
config_builder.with_ech(EchMode::Enable(ech_config))?
} else {
config_builder.with_safe_default_protocol_versions()?
};
let config_builder = config_builder.with_root_certificates(root_store);
let mut config = match (tls_client_certificate, tls_client_key) {
(Some(tls_client_certificate), Some(tls_client_key)) => config_builder
@@ -143,7 +151,7 @@ pub fn tls_connector(
config.alpn_protocols = alpn_protocols;
let tls_connector = TlsConnector::from(Arc::new(config));
Ok((tls_connector, root_store))
Ok(tls_connector)
}
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
@@ -174,11 +182,7 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
Ok(TlsAcceptor::from(Arc::new(config)))
}
pub async fn connect(
client_cfg: &WsClientConfig,
tcp_stream: TcpStream,
ech_config: Option<EchConfig>,
) -> anyhow::Result<TlsStream<TcpStream>> {
pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result<TlsStream<TcpStream>> {
let sni = client_cfg.tls_server_name();
let tls_config = match &client_cfg.remote_addr {
TransportAddr::Wss { tls, .. } => tls,
@@ -203,30 +207,7 @@ pub async fn connect(
}
let tls_connector = tls_config.tls_connector();
let tls_stream = if let Some(ech_config) = ech_config {
//FIXME: do not re-create a tls connector every time ?
let tls_connector = tls_connector_with_ech(ech_config, tls_config, tls_connector)?;
tls_connector.connect(sni, tcp_stream).await?
} else {
tls_connector.connect(sni, tcp_stream).await?
};
let tls_stream = tls_connector.connect(sni, tcp_stream).await?;
Ok(tls_stream)
}
fn tls_connector_with_ech(
ech_config: EchConfig,
tls_config: &TlsClientConfig,
original_connector: TlsConnector,
) -> anyhow::Result<TlsConnector> {
let original_config = original_connector.config();
let mut ech_client_config = ClientConfig::builder_with_provider(original_config.crypto_provider().clone())
.with_ech(EchMode::from(ech_config))?
.with_root_certificates(tls_config.root_store.clone())
.with_no_client_auth();
ech_client_config.key_log = original_config.key_log.clone();
ech_client_config.alpn_protocols = original_config.alpn_protocols.clone();
Ok(TlsConnector::from(Arc::new(ech_client_config)))
}

View File

@@ -25,7 +25,7 @@ use url::Host;
#[fixture]
fn dns_resolver() -> DnsResolver {
DnsResolver::new_from_urls(&[], None, SoMark::new(None), true, false).expect("Cannot create DNS resolver")
DnsResolver::new_from_urls(&[], None, SoMark::new(None), true).expect("Cannot create DNS resolver")
}
#[fixture]

View File

@@ -55,8 +55,7 @@ impl ManageConnection for WsConnection {
};
if self.remote_addr.tls().is_some() {
let ech_config = self.dns_resolver.lookup_ech_config(self.remote_addr.host()).await?;
let tls_stream = tls::connect(self, tcp_stream, ech_config).await?;
let tls_stream = tls::connect(self, tcp_stream).await?;
Ok(Some(TransportStream::from_client_tls(tls_stream, Bytes::default())))
} else {
Ok(Some(TransportStream::from_tcp(tcp_stream, Bytes::default())))

View File

@@ -9,7 +9,6 @@ use std::path::PathBuf;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::rustls::pki_types::{DnsName, ServerName};
use url::{Host, Url};
@@ -58,7 +57,6 @@ pub struct TlsClientConfig {
pub tls_connector: Arc<RwLock<TlsConnector>>,
pub tls_certificate_path: Option<PathBuf>,
pub tls_key_path: Option<PathBuf>,
pub root_store: RootCertStore,
}
impl TlsClientConfig {

View File

@@ -290,6 +290,7 @@ impl TlsReloader {
tls.tls_verify_certificate,
this.client_config.remote_addr.scheme().alpn_protocols(),
!tls.tls_sni_disabled,
None,
Some(tls_certs),
Some(tls_key),
);
@@ -300,7 +301,7 @@ impl TlsReloader {
return;
}
};
*tls.tls_connector.write() = tls_connector.0;
*tls.tls_connector.write() = tls_connector;
this.tls_reload_certificate.store(true, Ordering::Relaxed);
}
(Err(err), _) | (_, Err(err)) => {
@@ -333,6 +334,7 @@ impl TlsReloader {
tls.tls_verify_certificate,
this.client_config.remote_addr.scheme().alpn_protocols(),
!tls.tls_sni_disabled,
None,
Some(tls_certs),
Some(tls_key),
);
@@ -343,7 +345,7 @@ impl TlsReloader {
return;
}
};
*tls.tls_connector.write() = tls_connector.0;
*tls.tls_connector.write() = tls_connector;
this.tls_reload_certificate.store(true, Ordering::Relaxed);
}
(Err(err), _) | (_, Err(err)) => {