chore: improve library interface to support provided executor

This commit is contained in:
Σrebe - Romain GERARD
2025-05-18 12:36:50 +02:00
parent 3fe61659e7
commit 590c5b5572
3 changed files with 51 additions and 21 deletions

View File

@@ -24,6 +24,7 @@ use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig};
use crate::tunnel::transport::{TransportAddr, TransportScheme}; use crate::tunnel::transport::{TransportAddr, TransportScheme};
use crate::tunnel::{RemoteAddr, to_host_port}; use crate::tunnel::{RemoteAddr, to_host_port};
use anyhow::{Context, anyhow}; use anyhow::{Context, anyhow};
use futures_util::future::BoxFuture;
use hyper::header::HOST; use hyper::header::HOST;
use hyper::http::HeaderValue; use hyper::http::HeaderValue;
use log::debug; use log::debug;
@@ -32,11 +33,30 @@ use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::select; use tokio::select;
use tokio::task::AbortHandle; use tokio::sync::oneshot;
use tokio::task::JoinSet;
use tracing::{error, info}; use tracing::{error, info};
use url::Url; use url::Url;
pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> { pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> {
let tunnels = create_client_tunnels(args, executor.clone()).await?;
// Start all tunnels
let (tx, rx) = oneshot::channel();
executor.spawn(async move {
let _ = JoinSet::from_iter(tunnels).join_all().await;
let _ = tx.send(());
});
// wait for all tunnels to finish
rx.await?;
Ok(())
}
async fn create_client_tunnels(
args: Client,
executor: impl TokioExecutor,
) -> anyhow::Result<Vec<BoxFuture<'static, ()>>> {
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) = let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref()) (args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
{ {
@@ -139,11 +159,10 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
info!("Starting wstunnel client v{}", env!("CARGO_PKG_VERSION"),); info!("Starting wstunnel client v{}", env!("CARGO_PKG_VERSION"),);
// Keep track of all spawned tunnels // Keep track of all spawned tunnels
let executor = client.executor.clone(); let mut tunnels: Vec<BoxFuture<()>> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
let mut tunnels: Vec<AbortHandle> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
macro_rules! spawn_tunnel { macro_rules! spawn_tunnel {
( $($s:stmt);* ) => { ( $($s:stmt);* ) => {
tunnels.push(executor.spawn(async move { tunnels.push(Box::pin(async move {
$($s)* $($s)*
})); }));
} }
@@ -383,17 +402,21 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
} }
} }
// wait for all tunnels to complete Ok(tunnels)
let mut ticker = tokio::time::interval(Duration::from_secs(1));
for tunnel in tunnels.into_iter() {
while !tunnel.is_finished() {
ticker.tick().await;
}
}
Ok(())
} }
pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> { pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
let (tx, rx) = oneshot::channel();
let exec = executor.clone();
executor.spawn(async move {
let ret = run_server_impl(args, exec).await;
let _ = tx.send(ret);
});
rx.await?
}
async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
let tls_config = if args.remote_addr.scheme() == "wss" { let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = &args.tls_certificate { let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate") tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
@@ -479,10 +502,7 @@ pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::R
server.config server.config
); );
debug!("Restriction rules: {:#?}", restrictions); debug!("Restriction rules: {:#?}", restrictions);
server.serve(restrictions).await.unwrap_or_else(|err| { server.serve(restrictions).await
panic!("Cannot start wstunnel server: {:?}", err);
});
Ok(())
} }
fn mk_http_proxy( fn mk_http_proxy(

View File

@@ -1,4 +1,4 @@
use anyhow::anyhow; use anyhow::{Context, anyhow};
use futures_util::FutureExt; use futures_util::FutureExt;
use http_body_util::Either; use http_body_util::Either;
use std::fmt; use std::fmt;
@@ -395,7 +395,9 @@ impl<E: crate::TokioExecutor> WsServer<E> {
// Bind server and run forever to serve incoming connections. // Bind server and run forever to serve incoming connections.
let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?; let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
let listener = TcpListener::bind(&self.config.bind).await?; let listener = TcpListener::bind(&self.config.bind)
.await
.with_context(|| format!("Failed to bind to socket on {}", self.config.bind))?;
loop { loop {
let (stream, peer_addr) = match listener.accept().await { let (stream, peer_addr) = match listener.accept().await {
@@ -406,7 +408,7 @@ impl<E: crate::TokioExecutor> WsServer<E> {
} }
}; };
let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string(),); let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string());
info!(parent: &span, "Accepting connection"); info!(parent: &span, "Accepting connection");
if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), SoMark::new(None)) { if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), SoMark::new(None)) {
warn!("Error while configuring server socket {:?}", err); warn!("Error while configuring server socket {:?}", err);

View File

@@ -94,10 +94,18 @@ async fn main() -> anyhow::Result<()> {
match args.commands { match args.commands {
Commands::Client(args) => { Commands::Client(args) => {
run_client(*args, DefaultTokioExecutor::default()).await?; run_client(*args, DefaultTokioExecutor::default())
.await
.unwrap_or_else(|err| {
panic!("Cannot start wstunnel client: {:?}", err);
});
} }
Commands::Server(args) => { Commands::Server(args) => {
run_server(*args, DefaultTokioExecutor::default()).await?; run_server(*args, DefaultTokioExecutor::default())
.await
.unwrap_or_else(|err| {
panic!("Cannot start wstunnel server: {:?}", err);
});
} }
} }