feat(http-proxy): support non HTTP CONNECT method for non TLS connection

This commit is contained in:
Σrebe - Romain GERARD
2025-06-30 09:05:01 +02:00
parent d8cf44b69f
commit f505c633df
6 changed files with 49 additions and 31 deletions

1
Cargo.lock generated
View File

@@ -3943,6 +3943,7 @@ dependencies = [
"get_if_addrs", "get_if_addrs",
"hickory-resolver", "hickory-resolver",
"http-body-util", "http-body-util",
"httparse",
"hyper", "hyper",
"hyper-util", "hyper-util",
"ipnet", "ipnet",

View File

@@ -36,6 +36,7 @@ nix = { version = "0.30.1", features = ["socket", "net", "uio"] }
parking_lot = "0.12.4" parking_lot = "0.12.4"
pin-project = "1" pin-project = "1"
notify = { version = "8.0.0", features = [] } notify = { version = "8.0.0", features = [] }
httparse = { version = "1.10.1", features = [] }
rustls-native-certs = { version = "0.8.1", features = [] } rustls-native-certs = { version = "0.8.1", features = [] }
rustls-pemfile = { version = "2.2.0", features = [] } rustls-pemfile = { version = "2.2.0", features = [] }

View File

@@ -539,10 +539,7 @@ mod parsers {
let get_proxy_protocol = |options: &BTreeMap<String, String>| options.contains_key("proxy_protocol"); let get_proxy_protocol = |options: &BTreeMap<String, String>| options.contains_key("proxy_protocol");
let Some((proto, tunnel_info)) = arg.split_once("://") else { let Some((proto, tunnel_info)) = arg.split_once("://") else {
return Err(Error::new( return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse protocol from {arg}")));
ErrorKind::InvalidInput,
format!("cannot parse protocol from {arg}"),
));
}; };
match proto { match proto {
@@ -689,10 +686,7 @@ mod parsers {
pub fn parse_sni_override(arg: &str) -> Result<DnsName<'static>, io::Error> { pub fn parse_sni_override(arg: &str) -> Result<DnsName<'static>, io::Error> {
match DnsName::try_from(arg.to_string()) { match DnsName::try_from(arg.to_string()) {
Ok(val) => Ok(val), Ok(val) => Ok(val),
Err(err) => Err(io::Error::new( Err(err) => Err(io::Error::new(ErrorKind::InvalidInput, format!("Invalid sni override: {err}"))),
ErrorKind::InvalidInput,
format!("Invalid sni override: {err}"),
)),
} }
} }

View File

@@ -3,7 +3,7 @@ use std::future::Future;
use bytes::Bytes; use bytes::Bytes;
use log::{debug, error}; use log::{debug, error};
use std::net::SocketAddr; use std::net::{Ipv4Addr, SocketAddr};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
@@ -21,7 +21,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::select; use tokio::select;
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tracing::log::info; use tracing::log::info;
use url::Host; use url::{Host, Url};
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub struct HttpProxyListener { pub struct HttpProxyListener {
@@ -36,7 +36,7 @@ impl Stream for HttpProxyListener {
} }
} }
fn handle_request( fn handle_http_connect_request(
credentials: &Option<String>, credentials: &Option<String>,
dest: &Mutex<Option<(Host, u16)>>, dest: &Mutex<Option<(Host, u16)>>,
req: Request<Incoming>, req: Request<Incoming>,
@@ -81,9 +81,7 @@ pub async fn run_server(
timeout: Option<Duration>, timeout: Option<Duration>,
credentials: Option<(String, String)>, credentials: Option<(String, String)>,
) -> Result<HttpProxyListener, anyhow::Error> { ) -> Result<HttpProxyListener, anyhow::Error> {
info!( info!("Starting http proxy server listening cnx on {bind} with credentials {credentials:?}");
"Starting http proxy server listening cnx on {bind} with credentials {credentials:?}"
);
let listener = TcpListener::bind(bind) let listener = TcpListener::bind(bind)
.await .await
@@ -140,12 +138,37 @@ pub async fn run_server(
let handle_new_cnx = { let handle_new_cnx = {
let proxy_cfg = proxy_cfg.clone(); let proxy_cfg = proxy_cfg.clone();
async move { async move {
// We need to know if the http request if a CONNECT method or a regular one.
// HTTP CONNECT requires doing a handshake with client (which is easier)
// While for regular method, we need to replay the request as if it was done by the client.
// Non HTTP CONNECT method only works for non TLS connection/request.
let forward_to = {
let mut buf = [0; 512];
let buf_size = stream.peek(&mut buf).await.ok()?;
let mut http_parser = httparse::Request::new(&mut []);
let _ = http_parser.parse(&buf[..buf_size]);
if http_parser.method == Some(hyper::Method::CONNECT.as_str()) {
None
} else {
let url = Url::parse(http_parser.path.unwrap_or("")).ok()?;
let host = url.host().unwrap_or(Host::Ipv4(Ipv4Addr::UNSPECIFIED)).to_owned();
let port = url.port_or_known_default().unwrap_or(80);
Some((host, port))
}
};
// Handle regular http request. Meaning we need to forward it directly as is
return if forward_to.is_some() {
Some((stream, forward_to))
} else {
// Handle HTTP CONNECT request
let http1 = &proxy_cfg.1; let http1 = &proxy_cfg.1;
let auth_header = &proxy_cfg.0; let auth_header = &proxy_cfg.0;
let forward_to = Mutex::new(None); let forward_to = Mutex::new(None);
let conn_fut = http1.serve_connection( let conn_fut = http1.serve_connection(
hyper_util::rt::TokioIo::new(&mut stream), hyper_util::rt::TokioIo::new(&mut stream),
service_fn(|req| handle_request(auth_header, &forward_to, req)), service_fn(|req| handle_http_connect_request(auth_header, &forward_to, req)),
); );
match conn_fut.await { match conn_fut.await {
@@ -155,6 +178,7 @@ pub async fn run_server(
None None
} }
} }
};
} }
}; };
tasks.spawn(handle_new_cnx); tasks.spawn(handle_new_cnx);

View File

@@ -51,8 +51,8 @@ pub async fn run_server(socket_path: &Path) -> Result<UnixListenerStream, anyhow
info!("Starting Unix socket server listening cnx on {socket_path:?}"); info!("Starting Unix socket server listening cnx on {socket_path:?}");
let path_to_delete = !socket_path.exists(); let path_to_delete = !socket_path.exists();
let listener = UnixListener::bind(socket_path) let listener =
.with_context(|| format!("Cannot create Unix socket server {socket_path:?}"))?; UnixListener::bind(socket_path).with_context(|| format!("Cannot create Unix socket server {socket_path:?}"))?;
Ok(UnixListenerStream::new(listener, path_to_delete)) Ok(UnixListenerStream::new(listener, path_to_delete))
} }

View File

@@ -129,9 +129,7 @@ where
.map(|port_mapping| { .map(|port_mapping| {
let port_mapping_parts: Vec<&str> = port_mapping.split(':').collect(); let port_mapping_parts: Vec<&str> = port_mapping.split(':').collect();
if port_mapping_parts.len() != 2 { if port_mapping_parts.len() != 2 {
Err(serde::de::Error::custom(format!( Err(serde::de::Error::custom(format!("Invalid port_mapping entry: {port_mapping}")))
"Invalid port_mapping entry: {port_mapping}"
)))
} else { } else {
let orig_port = port_mapping_parts[0].parse::<u16>().map_err(serde::de::Error::custom)?; let orig_port = port_mapping_parts[0].parse::<u16>().map_err(serde::de::Error::custom)?;
let target_port = port_mapping_parts[1].parse::<u16>().map_err(serde::de::Error::custom)?; let target_port = port_mapping_parts[1].parse::<u16>().map_err(serde::de::Error::custom)?;