mirror of
https://github.com/erebe/wstunnel.git
synced 2025-09-26 19:21:10 +08:00
simplify ech implementation
This commit is contained in:
32
src/lib.rs
32
src/lib.rs
@@ -79,14 +79,32 @@ async fn create_client_tunnels(
|
||||
args.http_upgrade_path_prefix
|
||||
};
|
||||
|
||||
let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?;
|
||||
let dns_resolver = DnsResolver::new_from_urls(
|
||||
&args.dns_resolver,
|
||||
http_proxy.clone(),
|
||||
SoMark::new(args.socket_so_mark),
|
||||
!args.dns_resolver_prefer_ipv4,
|
||||
)
|
||||
.expect("cannot create dns resolver");
|
||||
|
||||
let transport_scheme = TransportScheme::from_str(args.remote_addr.scheme()).expect("invalid scheme in server url");
|
||||
let tls = match transport_scheme {
|
||||
TransportScheme::Ws | TransportScheme::Http => None,
|
||||
TransportScheme::Wss | TransportScheme::Https => {
|
||||
let (tls_connector, root_store) = tls::tls_connector(
|
||||
let ech_config = if args.tls_ech_enable {
|
||||
dns_resolver
|
||||
.lookup_ech_config(&args.remote_addr.host().unwrap().to_owned())
|
||||
.await?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let tls_connector = tls::tls_connector(
|
||||
args.tls_verify_certificate,
|
||||
transport_scheme.alpn_protocols(),
|
||||
!args.tls_sni_disable,
|
||||
ech_config,
|
||||
tls_certificate,
|
||||
tls_key,
|
||||
)
|
||||
@@ -94,7 +112,6 @@ async fn create_client_tunnels(
|
||||
|
||||
Some(TlsClientConfig {
|
||||
tls_connector: Arc::new(RwLock::new(tls_connector)),
|
||||
root_store,
|
||||
tls_sni_override: args.tls_sni_override,
|
||||
tls_verify_certificate: args.tls_verify_certificate,
|
||||
tls_sni_disabled: args.tls_sni_disable,
|
||||
@@ -120,7 +137,6 @@ async fn create_client_tunnels(
|
||||
}
|
||||
}
|
||||
|
||||
let http_proxy = mk_http_proxy(args.http_proxy, args.http_proxy_login, args.http_proxy_password)?;
|
||||
let client_config = WsClientConfig {
|
||||
remote_addr: TransportAddr::new(
|
||||
TransportScheme::from_str(args.remote_addr.scheme()).unwrap(),
|
||||
@@ -141,14 +157,7 @@ async fn create_client_tunnels(
|
||||
.or(Some(Duration::from_secs(30)))
|
||||
.filter(|d| d.as_secs() > 0),
|
||||
websocket_mask_frame: args.websocket_mask_frame,
|
||||
dns_resolver: DnsResolver::new_from_urls(
|
||||
&args.dns_resolver,
|
||||
http_proxy.clone(),
|
||||
SoMark::new(args.socket_so_mark),
|
||||
!args.dns_resolver_prefer_ipv4,
|
||||
args.tls_ech_enable,
|
||||
)
|
||||
.expect("cannot create dns resolver"),
|
||||
dns_resolver,
|
||||
http_proxy,
|
||||
};
|
||||
|
||||
@@ -492,7 +501,6 @@ async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::
|
||||
None,
|
||||
SoMark::new(args.socket_so_mark),
|
||||
!args.dns_resolver_prefer_ipv4,
|
||||
false,
|
||||
)
|
||||
.expect("Cannot create DNS resolver"),
|
||||
restriction_config: args.restrict_config,
|
||||
|
@@ -44,7 +44,6 @@ pub enum DnsResolver {
|
||||
TrustDns {
|
||||
resolver: Resolver<GenericConnector<TokioRuntimeProviderWithSoMark>>,
|
||||
prefer_ipv6: bool,
|
||||
ech_enabled: bool,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -52,11 +51,7 @@ impl DnsResolver {
|
||||
pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result<Vec<SocketAddr>> {
|
||||
let addrs = match self {
|
||||
Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(),
|
||||
Self::TrustDns {
|
||||
resolver,
|
||||
prefer_ipv6,
|
||||
ech_enabled: _,
|
||||
} => {
|
||||
Self::TrustDns { resolver, prefer_ipv6 } => {
|
||||
let addrs: Vec<_> = resolver
|
||||
.lookup_ip(domain)
|
||||
.await?
|
||||
@@ -75,9 +70,7 @@ impl DnsResolver {
|
||||
|
||||
pub async fn lookup_ech_config(&self, domain: &Host) -> Result<Option<EchConfig>, ResolveError> {
|
||||
let resolver = match self {
|
||||
DnsResolver::TrustDns {
|
||||
resolver, ech_enabled, ..
|
||||
} if *ech_enabled => resolver,
|
||||
DnsResolver::TrustDns { resolver, .. } => resolver,
|
||||
_ => {
|
||||
return Ok(None);
|
||||
}
|
||||
@@ -119,7 +112,6 @@ impl DnsResolver {
|
||||
proxy: Option<Url>,
|
||||
so_mark: SoMark,
|
||||
prefer_ipv6: bool,
|
||||
ech_enabled: bool,
|
||||
) -> anyhow::Result<Self> {
|
||||
fn mk_resolver(
|
||||
cfg: ResolverConfig,
|
||||
@@ -195,7 +187,6 @@ impl DnsResolver {
|
||||
return Ok(Self::TrustDns {
|
||||
resolver: mk_resolver(cfg, opts, proxy, so_mark),
|
||||
prefer_ipv6,
|
||||
ech_enabled,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -213,7 +204,6 @@ impl DnsResolver {
|
||||
Ok(Self::TrustDns {
|
||||
resolver: mk_resolver(cfg, ResolverOpts::default(), proxy, so_mark),
|
||||
prefer_ipv6,
|
||||
ech_enabled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -9,7 +9,7 @@ use std::sync::Arc;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::client::TlsStream;
|
||||
|
||||
use crate::tunnel::client::{TlsClientConfig, WsClientConfig};
|
||||
use crate::tunnel::client::WsClientConfig;
|
||||
use crate::tunnel::server::TlsServerConfig;
|
||||
use crate::tunnel::transport::TransportAddr;
|
||||
use tokio_rustls::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||
@@ -107,9 +107,10 @@ pub fn tls_connector(
|
||||
tls_verify_certificate: bool,
|
||||
alpn_protocols: Vec<Vec<u8>>,
|
||||
enable_sni: bool,
|
||||
ech_config: Option<EchConfig>,
|
||||
tls_client_certificate: Option<Vec<CertificateDer<'static>>>,
|
||||
tls_client_key: Option<PrivateKeyDer<'static>>,
|
||||
) -> anyhow::Result<(TlsConnector, RootCertStore)> {
|
||||
) -> anyhow::Result<TlsConnector> {
|
||||
let mut root_store = RootCertStore::empty();
|
||||
|
||||
// Load system certificates and add them to the root store
|
||||
@@ -124,7 +125,14 @@ pub fn tls_connector(
|
||||
}
|
||||
}
|
||||
|
||||
let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone());
|
||||
let crypto_provider = ClientConfig::builder().crypto_provider().clone();
|
||||
let config_builder = ClientConfig::builder_with_provider(crypto_provider);
|
||||
let config_builder = if let Some(ech_config) = ech_config {
|
||||
config_builder.with_ech(EchMode::Enable(ech_config))?
|
||||
} else {
|
||||
config_builder.with_safe_default_protocol_versions()?
|
||||
};
|
||||
let config_builder = config_builder.with_root_certificates(root_store);
|
||||
|
||||
let mut config = match (tls_client_certificate, tls_client_key) {
|
||||
(Some(tls_client_certificate), Some(tls_client_key)) => config_builder
|
||||
@@ -143,7 +151,7 @@ pub fn tls_connector(
|
||||
|
||||
config.alpn_protocols = alpn_protocols;
|
||||
let tls_connector = TlsConnector::from(Arc::new(config));
|
||||
Ok((tls_connector, root_store))
|
||||
Ok(tls_connector)
|
||||
}
|
||||
|
||||
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
|
||||
@@ -174,11 +182,7 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
client_cfg: &WsClientConfig,
|
||||
tcp_stream: TcpStream,
|
||||
ech_config: Option<EchConfig>,
|
||||
) -> anyhow::Result<TlsStream<TcpStream>> {
|
||||
pub async fn connect(client_cfg: &WsClientConfig, tcp_stream: TcpStream) -> anyhow::Result<TlsStream<TcpStream>> {
|
||||
let sni = client_cfg.tls_server_name();
|
||||
let tls_config = match &client_cfg.remote_addr {
|
||||
TransportAddr::Wss { tls, .. } => tls,
|
||||
@@ -203,30 +207,7 @@ pub async fn connect(
|
||||
}
|
||||
|
||||
let tls_connector = tls_config.tls_connector();
|
||||
let tls_stream = if let Some(ech_config) = ech_config {
|
||||
//FIXME: do not re-create a tls connector every time ?
|
||||
let tls_connector = tls_connector_with_ech(ech_config, tls_config, tls_connector)?;
|
||||
tls_connector.connect(sni, tcp_stream).await?
|
||||
} else {
|
||||
tls_connector.connect(sni, tcp_stream).await?
|
||||
};
|
||||
let tls_stream = tls_connector.connect(sni, tcp_stream).await?;
|
||||
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
fn tls_connector_with_ech(
|
||||
ech_config: EchConfig,
|
||||
tls_config: &TlsClientConfig,
|
||||
original_connector: TlsConnector,
|
||||
) -> anyhow::Result<TlsConnector> {
|
||||
let original_config = original_connector.config();
|
||||
let mut ech_client_config = ClientConfig::builder_with_provider(original_config.crypto_provider().clone())
|
||||
.with_ech(EchMode::from(ech_config))?
|
||||
.with_root_certificates(tls_config.root_store.clone())
|
||||
.with_no_client_auth();
|
||||
|
||||
ech_client_config.key_log = original_config.key_log.clone();
|
||||
ech_client_config.alpn_protocols = original_config.alpn_protocols.clone();
|
||||
|
||||
Ok(TlsConnector::from(Arc::new(ech_client_config)))
|
||||
}
|
||||
|
@@ -25,7 +25,7 @@ use url::Host;
|
||||
|
||||
#[fixture]
|
||||
fn dns_resolver() -> DnsResolver {
|
||||
DnsResolver::new_from_urls(&[], None, SoMark::new(None), true, false).expect("Cannot create DNS resolver")
|
||||
DnsResolver::new_from_urls(&[], None, SoMark::new(None), true).expect("Cannot create DNS resolver")
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
|
@@ -55,8 +55,7 @@ impl ManageConnection for WsConnection {
|
||||
};
|
||||
|
||||
if self.remote_addr.tls().is_some() {
|
||||
let ech_config = self.dns_resolver.lookup_ech_config(self.remote_addr.host()).await?;
|
||||
let tls_stream = tls::connect(self, tcp_stream, ech_config).await?;
|
||||
let tls_stream = tls::connect(self, tcp_stream).await?;
|
||||
Ok(Some(TransportStream::from_client_tls(tls_stream, Bytes::default())))
|
||||
} else {
|
||||
Ok(Some(TransportStream::from_tcp(tcp_stream, Bytes::default())))
|
||||
|
@@ -9,7 +9,6 @@ use std::path::PathBuf;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::Duration;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_rustls::rustls::RootCertStore;
|
||||
use tokio_rustls::rustls::pki_types::{DnsName, ServerName};
|
||||
use url::{Host, Url};
|
||||
|
||||
@@ -58,7 +57,6 @@ pub struct TlsClientConfig {
|
||||
pub tls_connector: Arc<RwLock<TlsConnector>>,
|
||||
pub tls_certificate_path: Option<PathBuf>,
|
||||
pub tls_key_path: Option<PathBuf>,
|
||||
pub root_store: RootCertStore,
|
||||
}
|
||||
|
||||
impl TlsClientConfig {
|
||||
|
@@ -290,6 +290,7 @@ impl TlsReloader {
|
||||
tls.tls_verify_certificate,
|
||||
this.client_config.remote_addr.scheme().alpn_protocols(),
|
||||
!tls.tls_sni_disabled,
|
||||
None,
|
||||
Some(tls_certs),
|
||||
Some(tls_key),
|
||||
);
|
||||
@@ -300,7 +301,7 @@ impl TlsReloader {
|
||||
return;
|
||||
}
|
||||
};
|
||||
*tls.tls_connector.write() = tls_connector.0;
|
||||
*tls.tls_connector.write() = tls_connector;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
@@ -333,6 +334,7 @@ impl TlsReloader {
|
||||
tls.tls_verify_certificate,
|
||||
this.client_config.remote_addr.scheme().alpn_protocols(),
|
||||
!tls.tls_sni_disabled,
|
||||
None,
|
||||
Some(tls_certs),
|
||||
Some(tls_key),
|
||||
);
|
||||
@@ -343,7 +345,7 @@ impl TlsReloader {
|
||||
return;
|
||||
}
|
||||
};
|
||||
*tls.tls_connector.write() = tls_connector.0;
|
||||
*tls.tls_connector.write() = tls_connector;
|
||||
this.tls_reload_certificate.store(true, Ordering::Relaxed);
|
||||
}
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
|
Reference in New Issue
Block a user