mirror of
https://github.com/erebe/wstunnel.git
synced 2025-09-26 19:21:10 +08:00
chore: improve library interface to support provided executor
This commit is contained in:
52
src/lib.rs
52
src/lib.rs
@@ -24,6 +24,7 @@ use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig};
|
||||
use crate::tunnel::transport::{TransportAddr, TransportScheme};
|
||||
use crate::tunnel::{RemoteAddr, to_host_port};
|
||||
use anyhow::{Context, anyhow};
|
||||
use futures_util::future::BoxFuture;
|
||||
use hyper::header::HOST;
|
||||
use hyper::http::HeaderValue;
|
||||
use log::debug;
|
||||
@@ -32,11 +33,30 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::select;
|
||||
use tokio::task::AbortHandle;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{error, info};
|
||||
use url::Url;
|
||||
|
||||
pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||
let tunnels = create_client_tunnels(args, executor.clone()).await?;
|
||||
|
||||
// Start all tunnels
|
||||
let (tx, rx) = oneshot::channel();
|
||||
executor.spawn(async move {
|
||||
let _ = JoinSet::from_iter(tunnels).join_all().await;
|
||||
let _ = tx.send(());
|
||||
});
|
||||
|
||||
// wait for all tunnels to finish
|
||||
rx.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_client_tunnels(
|
||||
args: Client,
|
||||
executor: impl TokioExecutor,
|
||||
) -> anyhow::Result<Vec<BoxFuture<'static, ()>>> {
|
||||
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
|
||||
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
|
||||
{
|
||||
@@ -139,11 +159,10 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
|
||||
info!("Starting wstunnel client v{}", env!("CARGO_PKG_VERSION"),);
|
||||
|
||||
// Keep track of all spawned tunnels
|
||||
let executor = client.executor.clone();
|
||||
let mut tunnels: Vec<AbortHandle> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
|
||||
let mut tunnels: Vec<BoxFuture<()>> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
|
||||
macro_rules! spawn_tunnel {
|
||||
( $($s:stmt);* ) => {
|
||||
tunnels.push(executor.spawn(async move {
|
||||
tunnels.push(Box::pin(async move {
|
||||
$($s)*
|
||||
}));
|
||||
}
|
||||
@@ -383,17 +402,21 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
|
||||
}
|
||||
}
|
||||
|
||||
// wait for all tunnels to complete
|
||||
let mut ticker = tokio::time::interval(Duration::from_secs(1));
|
||||
for tunnel in tunnels.into_iter() {
|
||||
while !tunnel.is_finished() {
|
||||
ticker.tick().await;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
Ok(tunnels)
|
||||
}
|
||||
|
||||
pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let exec = executor.clone();
|
||||
executor.spawn(async move {
|
||||
let ret = run_server_impl(args, exec).await;
|
||||
let _ = tx.send(ret);
|
||||
});
|
||||
|
||||
rx.await?
|
||||
}
|
||||
|
||||
async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||
let tls_config = if args.remote_addr.scheme() == "wss" {
|
||||
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
|
||||
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
|
||||
@@ -479,10 +502,7 @@ pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::R
|
||||
server.config
|
||||
);
|
||||
debug!("Restriction rules: {:#?}", restrictions);
|
||||
server.serve(restrictions).await.unwrap_or_else(|err| {
|
||||
panic!("Cannot start wstunnel server: {:?}", err);
|
||||
});
|
||||
Ok(())
|
||||
server.serve(restrictions).await
|
||||
}
|
||||
|
||||
fn mk_http_proxy(
|
||||
|
@@ -1,4 +1,4 @@
|
||||
use anyhow::anyhow;
|
||||
use anyhow::{Context, anyhow};
|
||||
use futures_util::FutureExt;
|
||||
use http_body_util::Either;
|
||||
use std::fmt;
|
||||
@@ -395,7 +395,9 @@ impl<E: crate::TokioExecutor> WsServer<E> {
|
||||
|
||||
// Bind server and run forever to serve incoming connections.
|
||||
let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
|
||||
let listener = TcpListener::bind(&self.config.bind).await?;
|
||||
let listener = TcpListener::bind(&self.config.bind)
|
||||
.await
|
||||
.with_context(|| format!("Failed to bind to socket on {}", self.config.bind))?;
|
||||
|
||||
loop {
|
||||
let (stream, peer_addr) = match listener.accept().await {
|
||||
@@ -406,7 +408,7 @@ impl<E: crate::TokioExecutor> WsServer<E> {
|
||||
}
|
||||
};
|
||||
|
||||
let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string(),);
|
||||
let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string());
|
||||
info!(parent: &span, "Accepting connection");
|
||||
if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), SoMark::new(None)) {
|
||||
warn!("Error while configuring server socket {:?}", err);
|
||||
|
@@ -94,10 +94,18 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
match args.commands {
|
||||
Commands::Client(args) => {
|
||||
run_client(*args, DefaultTokioExecutor::default()).await?;
|
||||
run_client(*args, DefaultTokioExecutor::default())
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Cannot start wstunnel client: {:?}", err);
|
||||
});
|
||||
}
|
||||
Commands::Server(args) => {
|
||||
run_server(*args, DefaultTokioExecutor::default()).await?;
|
||||
run_server(*args, DefaultTokioExecutor::default())
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Cannot start wstunnel server: {:?}", err);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user