mirror of
				https://github.com/erebe/wstunnel.git
				synced 2025-10-31 02:16:30 +08:00 
			
		
		
		
	chore: improve library interface to support provided executor
This commit is contained in:
		
							
								
								
									
										52
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								src/lib.rs
									
									
									
									
									
								
							| @@ -24,6 +24,7 @@ use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig}; | ||||
| use crate::tunnel::transport::{TransportAddr, TransportScheme}; | ||||
| use crate::tunnel::{RemoteAddr, to_host_port}; | ||||
| use anyhow::{Context, anyhow}; | ||||
| use futures_util::future::BoxFuture; | ||||
| use hyper::header::HOST; | ||||
| use hyper::http::HeaderValue; | ||||
| use log::debug; | ||||
| @@ -32,11 +33,30 @@ use std::str::FromStr; | ||||
| use std::sync::Arc; | ||||
| use std::time::Duration; | ||||
| use tokio::select; | ||||
| use tokio::task::AbortHandle; | ||||
| use tokio::sync::oneshot; | ||||
| use tokio::task::JoinSet; | ||||
| use tracing::{error, info}; | ||||
| use url::Url; | ||||
|  | ||||
| pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> { | ||||
|     let tunnels = create_client_tunnels(args, executor.clone()).await?; | ||||
|  | ||||
|     // Start all tunnels | ||||
|     let (tx, rx) = oneshot::channel(); | ||||
|     executor.spawn(async move { | ||||
|         let _ = JoinSet::from_iter(tunnels).join_all().await; | ||||
|         let _ = tx.send(()); | ||||
|     }); | ||||
|  | ||||
|     // wait for all tunnels to finish | ||||
|     rx.await?; | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn create_client_tunnels( | ||||
|     args: Client, | ||||
|     executor: impl TokioExecutor, | ||||
| ) -> anyhow::Result<Vec<BoxFuture<'static, ()>>> { | ||||
|     let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) = | ||||
|         (args.tls_certificate.as_ref(), args.tls_private_key.as_ref()) | ||||
|     { | ||||
| @@ -139,11 +159,10 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R | ||||
|     info!("Starting wstunnel client v{}", env!("CARGO_PKG_VERSION"),); | ||||
|  | ||||
|     // Keep track of all spawned tunnels | ||||
|     let executor = client.executor.clone(); | ||||
|     let mut tunnels: Vec<AbortHandle> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len()); | ||||
|     let mut tunnels: Vec<BoxFuture<()>> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len()); | ||||
|     macro_rules! spawn_tunnel { | ||||
|         ( $($s:stmt);* ) => { | ||||
|             tunnels.push(executor.spawn(async move { | ||||
|             tunnels.push(Box::pin(async move { | ||||
|                 $($s)* | ||||
|             })); | ||||
|         } | ||||
| @@ -383,17 +402,21 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // wait for all tunnels to complete | ||||
|     let mut ticker = tokio::time::interval(Duration::from_secs(1)); | ||||
|     for tunnel in tunnels.into_iter() { | ||||
|         while !tunnel.is_finished() { | ||||
|             ticker.tick().await; | ||||
|         } | ||||
|     } | ||||
|     Ok(()) | ||||
|     Ok(tunnels) | ||||
| } | ||||
|  | ||||
| pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> { | ||||
|     let (tx, rx) = oneshot::channel(); | ||||
|     let exec = executor.clone(); | ||||
|     executor.spawn(async move { | ||||
|         let ret = run_server_impl(args, exec).await; | ||||
|         let _ = tx.send(ret); | ||||
|     }); | ||||
|  | ||||
|     rx.await? | ||||
| } | ||||
|  | ||||
| async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> { | ||||
|     let tls_config = if args.remote_addr.scheme() == "wss" { | ||||
|         let tls_certificate = if let Some(cert_path) = &args.tls_certificate { | ||||
|             tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate") | ||||
| @@ -479,10 +502,7 @@ pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::R | ||||
|         server.config | ||||
|     ); | ||||
|     debug!("Restriction rules: {:#?}", restrictions); | ||||
|     server.serve(restrictions).await.unwrap_or_else(|err| { | ||||
|         panic!("Cannot start wstunnel server: {:?}", err); | ||||
|     }); | ||||
|     Ok(()) | ||||
|     server.serve(restrictions).await | ||||
| } | ||||
|  | ||||
| fn mk_http_proxy( | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| use anyhow::anyhow; | ||||
| use anyhow::{Context, anyhow}; | ||||
| use futures_util::FutureExt; | ||||
| use http_body_util::Either; | ||||
| use std::fmt; | ||||
| @@ -395,7 +395,9 @@ impl<E: crate::TokioExecutor> WsServer<E> { | ||||
|  | ||||
|         // Bind server and run forever to serve incoming connections. | ||||
|         let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?; | ||||
|         let listener = TcpListener::bind(&self.config.bind).await?; | ||||
|         let listener = TcpListener::bind(&self.config.bind) | ||||
|             .await | ||||
|             .with_context(|| format!("Failed to bind to socket on {}", self.config.bind))?; | ||||
|  | ||||
|         loop { | ||||
|             let (stream, peer_addr) = match listener.accept().await { | ||||
| @@ -406,7 +408,7 @@ impl<E: crate::TokioExecutor> WsServer<E> { | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string(),); | ||||
|             let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string()); | ||||
|             info!(parent: &span, "Accepting connection"); | ||||
|             if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), SoMark::new(None)) { | ||||
|                 warn!("Error while configuring server socket {:?}", err); | ||||
|   | ||||
| @@ -94,10 +94,18 @@ async fn main() -> anyhow::Result<()> { | ||||
|  | ||||
|     match args.commands { | ||||
|         Commands::Client(args) => { | ||||
|             run_client(*args, DefaultTokioExecutor::default()).await?; | ||||
|             run_client(*args, DefaultTokioExecutor::default()) | ||||
|                 .await | ||||
|                 .unwrap_or_else(|err| { | ||||
|                     panic!("Cannot start wstunnel client: {:?}", err); | ||||
|                 }); | ||||
|         } | ||||
|         Commands::Server(args) => { | ||||
|             run_server(*args, DefaultTokioExecutor::default()).await?; | ||||
|             run_server(*args, DefaultTokioExecutor::default()) | ||||
|                 .await | ||||
|                 .unwrap_or_else(|err| { | ||||
|                     panic!("Cannot start wstunnel server: {:?}", err); | ||||
|                 }); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Σrebe - Romain GERARD
					Σrebe - Romain GERARD