chore: improve library interface to support provided executor

This commit is contained in:
Σrebe - Romain GERARD
2025-05-18 12:36:50 +02:00
parent 3fe61659e7
commit 590c5b5572
3 changed files with 51 additions and 21 deletions

View File

@@ -24,6 +24,7 @@ use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig};
use crate::tunnel::transport::{TransportAddr, TransportScheme};
use crate::tunnel::{RemoteAddr, to_host_port};
use anyhow::{Context, anyhow};
use futures_util::future::BoxFuture;
use hyper::header::HOST;
use hyper::http::HeaderValue;
use log::debug;
@@ -32,11 +33,30 @@ use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::select;
use tokio::task::AbortHandle;
use tokio::sync::oneshot;
use tokio::task::JoinSet;
use tracing::{error, info};
use url::Url;
pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> {
let tunnels = create_client_tunnels(args, executor.clone()).await?;
// Start all tunnels
let (tx, rx) = oneshot::channel();
executor.spawn(async move {
let _ = JoinSet::from_iter(tunnels).join_all().await;
let _ = tx.send(());
});
// wait for all tunnels to finish
rx.await?;
Ok(())
}
async fn create_client_tunnels(
args: Client,
executor: impl TokioExecutor,
) -> anyhow::Result<Vec<BoxFuture<'static, ()>>> {
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
{
@@ -139,11 +159,10 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
info!("Starting wstunnel client v{}", env!("CARGO_PKG_VERSION"),);
// Keep track of all spawned tunnels
let executor = client.executor.clone();
let mut tunnels: Vec<AbortHandle> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
let mut tunnels: Vec<BoxFuture<()>> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
macro_rules! spawn_tunnel {
( $($s:stmt);* ) => {
tunnels.push(executor.spawn(async move {
tunnels.push(Box::pin(async move {
$($s)*
}));
}
@@ -383,17 +402,21 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
}
}
// wait for all tunnels to complete
let mut ticker = tokio::time::interval(Duration::from_secs(1));
for tunnel in tunnels.into_iter() {
while !tunnel.is_finished() {
ticker.tick().await;
}
}
Ok(())
Ok(tunnels)
}
pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
let (tx, rx) = oneshot::channel();
let exec = executor.clone();
executor.spawn(async move {
let ret = run_server_impl(args, exec).await;
let _ = tx.send(ret);
});
rx.await?
}
async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
@@ -479,10 +502,7 @@ pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::R
server.config
);
debug!("Restriction rules: {:#?}", restrictions);
server.serve(restrictions).await.unwrap_or_else(|err| {
panic!("Cannot start wstunnel server: {:?}", err);
});
Ok(())
server.serve(restrictions).await
}
fn mk_http_proxy(

View File

@@ -1,4 +1,4 @@
use anyhow::anyhow;
use anyhow::{Context, anyhow};
use futures_util::FutureExt;
use http_body_util::Either;
use std::fmt;
@@ -395,7 +395,9 @@ impl<E: crate::TokioExecutor> WsServer<E> {
// Bind server and run forever to serve incoming connections.
let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
let listener = TcpListener::bind(&self.config.bind).await?;
let listener = TcpListener::bind(&self.config.bind)
.await
.with_context(|| format!("Failed to bind to socket on {}", self.config.bind))?;
loop {
let (stream, peer_addr) = match listener.accept().await {
@@ -406,7 +408,7 @@ impl<E: crate::TokioExecutor> WsServer<E> {
}
};
let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string(),);
let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string());
info!(parent: &span, "Accepting connection");
if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), SoMark::new(None)) {
warn!("Error while configuring server socket {:?}", err);

View File

@@ -94,10 +94,18 @@ async fn main() -> anyhow::Result<()> {
match args.commands {
Commands::Client(args) => {
run_client(*args, DefaultTokioExecutor::default()).await?;
run_client(*args, DefaultTokioExecutor::default())
.await
.unwrap_or_else(|err| {
panic!("Cannot start wstunnel client: {:?}", err);
});
}
Commands::Server(args) => {
run_server(*args, DefaultTokioExecutor::default()).await?;
run_server(*args, DefaultTokioExecutor::default())
.await
.unwrap_or_else(|err| {
panic!("Cannot start wstunnel server: {:?}", err);
});
}
}