mirror of
https://github.com/erebe/wstunnel.git
synced 2025-10-05 23:16:55 +08:00
feat(http-proxy): support non HTTP CONNECT method for non TLS connection
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3943,6 +3943,7 @@ dependencies = [
|
|||||||
"get_if_addrs",
|
"get_if_addrs",
|
||||||
"hickory-resolver",
|
"hickory-resolver",
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
|
"httparse",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
"ipnet",
|
"ipnet",
|
||||||
|
@@ -36,6 +36,7 @@ nix = { version = "0.30.1", features = ["socket", "net", "uio"] }
|
|||||||
parking_lot = "0.12.4"
|
parking_lot = "0.12.4"
|
||||||
pin-project = "1"
|
pin-project = "1"
|
||||||
notify = { version = "8.0.0", features = [] }
|
notify = { version = "8.0.0", features = [] }
|
||||||
|
httparse = { version = "1.10.1", features = [] }
|
||||||
|
|
||||||
rustls-native-certs = { version = "0.8.1", features = [] }
|
rustls-native-certs = { version = "0.8.1", features = [] }
|
||||||
rustls-pemfile = { version = "2.2.0", features = [] }
|
rustls-pemfile = { version = "2.2.0", features = [] }
|
||||||
|
@@ -539,10 +539,7 @@ mod parsers {
|
|||||||
let get_proxy_protocol = |options: &BTreeMap<String, String>| options.contains_key("proxy_protocol");
|
let get_proxy_protocol = |options: &BTreeMap<String, String>| options.contains_key("proxy_protocol");
|
||||||
|
|
||||||
let Some((proto, tunnel_info)) = arg.split_once("://") else {
|
let Some((proto, tunnel_info)) = arg.split_once("://") else {
|
||||||
return Err(Error::new(
|
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse protocol from {arg}")));
|
||||||
ErrorKind::InvalidInput,
|
|
||||||
format!("cannot parse protocol from {arg}"),
|
|
||||||
));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match proto {
|
match proto {
|
||||||
@@ -689,10 +686,7 @@ mod parsers {
|
|||||||
pub fn parse_sni_override(arg: &str) -> Result<DnsName<'static>, io::Error> {
|
pub fn parse_sni_override(arg: &str) -> Result<DnsName<'static>, io::Error> {
|
||||||
match DnsName::try_from(arg.to_string()) {
|
match DnsName::try_from(arg.to_string()) {
|
||||||
Ok(val) => Ok(val),
|
Ok(val) => Ok(val),
|
||||||
Err(err) => Err(io::Error::new(
|
Err(err) => Err(io::Error::new(ErrorKind::InvalidInput, format!("Invalid sni override: {err}"))),
|
||||||
ErrorKind::InvalidInput,
|
|
||||||
format!("Invalid sni override: {err}"),
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -3,7 +3,7 @@ use std::future::Future;
|
|||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use log::{debug, error};
|
use log::{debug, error};
|
||||||
use std::net::SocketAddr;
|
use std::net::{Ipv4Addr, SocketAddr};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ use tokio::net::{TcpListener, TcpStream};
|
|||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
use tracing::log::info;
|
use tracing::log::info;
|
||||||
use url::Host;
|
use url::{Host, Url};
|
||||||
|
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
pub struct HttpProxyListener {
|
pub struct HttpProxyListener {
|
||||||
@@ -36,7 +36,7 @@ impl Stream for HttpProxyListener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_request(
|
fn handle_http_connect_request(
|
||||||
credentials: &Option<String>,
|
credentials: &Option<String>,
|
||||||
dest: &Mutex<Option<(Host, u16)>>,
|
dest: &Mutex<Option<(Host, u16)>>,
|
||||||
req: Request<Incoming>,
|
req: Request<Incoming>,
|
||||||
@@ -81,9 +81,7 @@ pub async fn run_server(
|
|||||||
timeout: Option<Duration>,
|
timeout: Option<Duration>,
|
||||||
credentials: Option<(String, String)>,
|
credentials: Option<(String, String)>,
|
||||||
) -> Result<HttpProxyListener, anyhow::Error> {
|
) -> Result<HttpProxyListener, anyhow::Error> {
|
||||||
info!(
|
info!("Starting http proxy server listening cnx on {bind} with credentials {credentials:?}");
|
||||||
"Starting http proxy server listening cnx on {bind} with credentials {credentials:?}"
|
|
||||||
);
|
|
||||||
|
|
||||||
let listener = TcpListener::bind(bind)
|
let listener = TcpListener::bind(bind)
|
||||||
.await
|
.await
|
||||||
@@ -140,12 +138,37 @@ pub async fn run_server(
|
|||||||
let handle_new_cnx = {
|
let handle_new_cnx = {
|
||||||
let proxy_cfg = proxy_cfg.clone();
|
let proxy_cfg = proxy_cfg.clone();
|
||||||
async move {
|
async move {
|
||||||
|
// We need to know if the http request if a CONNECT method or a regular one.
|
||||||
|
// HTTP CONNECT requires doing a handshake with client (which is easier)
|
||||||
|
// While for regular method, we need to replay the request as if it was done by the client.
|
||||||
|
// Non HTTP CONNECT method only works for non TLS connection/request.
|
||||||
|
let forward_to = {
|
||||||
|
let mut buf = [0; 512];
|
||||||
|
let buf_size = stream.peek(&mut buf).await.ok()?;
|
||||||
|
let mut http_parser = httparse::Request::new(&mut []);
|
||||||
|
|
||||||
|
let _ = http_parser.parse(&buf[..buf_size]);
|
||||||
|
if http_parser.method == Some(hyper::Method::CONNECT.as_str()) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let url = Url::parse(http_parser.path.unwrap_or("")).ok()?;
|
||||||
|
let host = url.host().unwrap_or(Host::Ipv4(Ipv4Addr::UNSPECIFIED)).to_owned();
|
||||||
|
let port = url.port_or_known_default().unwrap_or(80);
|
||||||
|
Some((host, port))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle regular http request. Meaning we need to forward it directly as is
|
||||||
|
return if forward_to.is_some() {
|
||||||
|
Some((stream, forward_to))
|
||||||
|
} else {
|
||||||
|
// Handle HTTP CONNECT request
|
||||||
let http1 = &proxy_cfg.1;
|
let http1 = &proxy_cfg.1;
|
||||||
let auth_header = &proxy_cfg.0;
|
let auth_header = &proxy_cfg.0;
|
||||||
let forward_to = Mutex::new(None);
|
let forward_to = Mutex::new(None);
|
||||||
let conn_fut = http1.serve_connection(
|
let conn_fut = http1.serve_connection(
|
||||||
hyper_util::rt::TokioIo::new(&mut stream),
|
hyper_util::rt::TokioIo::new(&mut stream),
|
||||||
service_fn(|req| handle_request(auth_header, &forward_to, req)),
|
service_fn(|req| handle_http_connect_request(auth_header, &forward_to, req)),
|
||||||
);
|
);
|
||||||
|
|
||||||
match conn_fut.await {
|
match conn_fut.await {
|
||||||
@@ -155,6 +178,7 @@ pub async fn run_server(
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
tasks.spawn(handle_new_cnx);
|
tasks.spawn(handle_new_cnx);
|
||||||
|
@@ -51,8 +51,8 @@ pub async fn run_server(socket_path: &Path) -> Result<UnixListenerStream, anyhow
|
|||||||
info!("Starting Unix socket server listening cnx on {socket_path:?}");
|
info!("Starting Unix socket server listening cnx on {socket_path:?}");
|
||||||
|
|
||||||
let path_to_delete = !socket_path.exists();
|
let path_to_delete = !socket_path.exists();
|
||||||
let listener = UnixListener::bind(socket_path)
|
let listener =
|
||||||
.with_context(|| format!("Cannot create Unix socket server {socket_path:?}"))?;
|
UnixListener::bind(socket_path).with_context(|| format!("Cannot create Unix socket server {socket_path:?}"))?;
|
||||||
|
|
||||||
Ok(UnixListenerStream::new(listener, path_to_delete))
|
Ok(UnixListenerStream::new(listener, path_to_delete))
|
||||||
}
|
}
|
||||||
|
@@ -129,9 +129,7 @@ where
|
|||||||
.map(|port_mapping| {
|
.map(|port_mapping| {
|
||||||
let port_mapping_parts: Vec<&str> = port_mapping.split(':').collect();
|
let port_mapping_parts: Vec<&str> = port_mapping.split(':').collect();
|
||||||
if port_mapping_parts.len() != 2 {
|
if port_mapping_parts.len() != 2 {
|
||||||
Err(serde::de::Error::custom(format!(
|
Err(serde::de::Error::custom(format!("Invalid port_mapping entry: {port_mapping}")))
|
||||||
"Invalid port_mapping entry: {port_mapping}"
|
|
||||||
)))
|
|
||||||
} else {
|
} else {
|
||||||
let orig_port = port_mapping_parts[0].parse::<u16>().map_err(serde::de::Error::custom)?;
|
let orig_port = port_mapping_parts[0].parse::<u16>().map_err(serde::de::Error::custom)?;
|
||||||
let target_port = port_mapping_parts[1].parse::<u16>().map_err(serde::de::Error::custom)?;
|
let target_port = port_mapping_parts[1].parse::<u16>().map_err(serde::de::Error::custom)?;
|
||||||
|
Reference in New Issue
Block a user