mirror of
https://github.com/erebe/wstunnel.git
synced 2025-09-27 03:25:53 +08:00
chore: improve library interface to support provided executor
This commit is contained in:
52
src/lib.rs
52
src/lib.rs
@@ -24,6 +24,7 @@ use crate::tunnel::server::{TlsServerConfig, WsServer, WsServerConfig};
|
|||||||
use crate::tunnel::transport::{TransportAddr, TransportScheme};
|
use crate::tunnel::transport::{TransportAddr, TransportScheme};
|
||||||
use crate::tunnel::{RemoteAddr, to_host_port};
|
use crate::tunnel::{RemoteAddr, to_host_port};
|
||||||
use anyhow::{Context, anyhow};
|
use anyhow::{Context, anyhow};
|
||||||
|
use futures_util::future::BoxFuture;
|
||||||
use hyper::header::HOST;
|
use hyper::header::HOST;
|
||||||
use hyper::http::HeaderValue;
|
use hyper::http::HeaderValue;
|
||||||
use log::debug;
|
use log::debug;
|
||||||
@@ -32,11 +33,30 @@ use std::str::FromStr;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::task::AbortHandle;
|
use tokio::sync::oneshot;
|
||||||
|
use tokio::task::JoinSet;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||||
|
let tunnels = create_client_tunnels(args, executor.clone()).await?;
|
||||||
|
|
||||||
|
// Start all tunnels
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
executor.spawn(async move {
|
||||||
|
let _ = JoinSet::from_iter(tunnels).join_all().await;
|
||||||
|
let _ = tx.send(());
|
||||||
|
});
|
||||||
|
|
||||||
|
// wait for all tunnels to finish
|
||||||
|
rx.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_client_tunnels(
|
||||||
|
args: Client,
|
||||||
|
executor: impl TokioExecutor,
|
||||||
|
) -> anyhow::Result<Vec<BoxFuture<'static, ()>>> {
|
||||||
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
|
let (tls_certificate, tls_key) = if let (Some(cert), Some(key)) =
|
||||||
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
|
(args.tls_certificate.as_ref(), args.tls_private_key.as_ref())
|
||||||
{
|
{
|
||||||
@@ -139,11 +159,10 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
|
|||||||
info!("Starting wstunnel client v{}", env!("CARGO_PKG_VERSION"),);
|
info!("Starting wstunnel client v{}", env!("CARGO_PKG_VERSION"),);
|
||||||
|
|
||||||
// Keep track of all spawned tunnels
|
// Keep track of all spawned tunnels
|
||||||
let executor = client.executor.clone();
|
let mut tunnels: Vec<BoxFuture<()>> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
|
||||||
let mut tunnels: Vec<AbortHandle> = Vec::with_capacity(args.remote_to_local.len() + args.local_to_remote.len());
|
|
||||||
macro_rules! spawn_tunnel {
|
macro_rules! spawn_tunnel {
|
||||||
( $($s:stmt);* ) => {
|
( $($s:stmt);* ) => {
|
||||||
tunnels.push(executor.spawn(async move {
|
tunnels.push(Box::pin(async move {
|
||||||
$($s)*
|
$($s)*
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
@@ -383,17 +402,21 @@ pub async fn run_client(args: Client, executor: impl TokioExecutor) -> anyhow::R
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for all tunnels to complete
|
Ok(tunnels)
|
||||||
let mut ticker = tokio::time::interval(Duration::from_secs(1));
|
|
||||||
for tunnel in tunnels.into_iter() {
|
|
||||||
while !tunnel.is_finished() {
|
|
||||||
ticker.tick().await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
let exec = executor.clone();
|
||||||
|
executor.spawn(async move {
|
||||||
|
let ret = run_server_impl(args, exec).await;
|
||||||
|
let _ = tx.send(ret);
|
||||||
|
});
|
||||||
|
|
||||||
|
rx.await?
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_server_impl(args: Server, executor: impl TokioExecutor) -> anyhow::Result<()> {
|
||||||
let tls_config = if args.remote_addr.scheme() == "wss" {
|
let tls_config = if args.remote_addr.scheme() == "wss" {
|
||||||
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
|
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
|
||||||
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
|
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
|
||||||
@@ -479,10 +502,7 @@ pub async fn run_server(args: Server, executor: impl TokioExecutor) -> anyhow::R
|
|||||||
server.config
|
server.config
|
||||||
);
|
);
|
||||||
debug!("Restriction rules: {:#?}", restrictions);
|
debug!("Restriction rules: {:#?}", restrictions);
|
||||||
server.serve(restrictions).await.unwrap_or_else(|err| {
|
server.serve(restrictions).await
|
||||||
panic!("Cannot start wstunnel server: {:?}", err);
|
|
||||||
});
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mk_http_proxy(
|
fn mk_http_proxy(
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
use anyhow::anyhow;
|
use anyhow::{Context, anyhow};
|
||||||
use futures_util::FutureExt;
|
use futures_util::FutureExt;
|
||||||
use http_body_util::Either;
|
use http_body_util::Either;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
@@ -395,7 +395,9 @@ impl<E: crate::TokioExecutor> WsServer<E> {
|
|||||||
|
|
||||||
// Bind server and run forever to serve incoming connections.
|
// Bind server and run forever to serve incoming connections.
|
||||||
let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
|
let restrictions = RestrictionsRulesReloader::new(restrictions, self.config.restriction_config.clone())?;
|
||||||
let listener = TcpListener::bind(&self.config.bind).await?;
|
let listener = TcpListener::bind(&self.config.bind)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to bind to socket on {}", self.config.bind))?;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (stream, peer_addr) = match listener.accept().await {
|
let (stream, peer_addr) = match listener.accept().await {
|
||||||
@@ -406,7 +408,7 @@ impl<E: crate::TokioExecutor> WsServer<E> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string(),);
|
let span = span!(Level::INFO, "cnx", peer = peer_addr.to_string());
|
||||||
info!(parent: &span, "Accepting connection");
|
info!(parent: &span, "Accepting connection");
|
||||||
if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), SoMark::new(None)) {
|
if let Err(err) = protocols::tcp::configure_socket(SockRef::from(&stream), SoMark::new(None)) {
|
||||||
warn!("Error while configuring server socket {:?}", err);
|
warn!("Error while configuring server socket {:?}", err);
|
||||||
|
@@ -94,10 +94,18 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
match args.commands {
|
match args.commands {
|
||||||
Commands::Client(args) => {
|
Commands::Client(args) => {
|
||||||
run_client(*args, DefaultTokioExecutor::default()).await?;
|
run_client(*args, DefaultTokioExecutor::default())
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
panic!("Cannot start wstunnel client: {:?}", err);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
Commands::Server(args) => {
|
Commands::Server(args) => {
|
||||||
run_server(*args, DefaultTokioExecutor::default()).await?;
|
run_server(*args, DefaultTokioExecutor::default())
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
panic!("Cannot start wstunnel server: {:?}", err);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user