|
|
|
@ -3,15 +3,15 @@ |
|
|
|
|
#[macro_use] |
|
|
|
|
extern crate tracing; |
|
|
|
|
|
|
|
|
|
use http_body_util::{BodyExt, Empty}; |
|
|
|
|
use hyper::header::{HeaderMap, HeaderName, HeaderValue}; |
|
|
|
|
use hyper::http::header::{InvalidHeaderValue, ToStrError}; |
|
|
|
|
use hyper::http::uri::InvalidUri; |
|
|
|
|
use hyper::upgrade::OnUpgrade; |
|
|
|
|
use hyper::{body::Incoming, Error, Request, Response, StatusCode}; |
|
|
|
|
use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError}; |
|
|
|
|
use hyper_util::rt::tokio::TokioIo; |
|
|
|
|
use lazy_static::lazy_static; |
|
|
|
|
use std::net::IpAddr; |
|
|
|
|
use std::net::{IpAddr, SocketAddr}; |
|
|
|
|
use tokio::io::copy_bidirectional; |
|
|
|
|
|
|
|
|
|
lazy_static! { |
|
|
|
@ -43,6 +43,7 @@ pub enum ProxyError { |
|
|
|
|
HyperError(Error), |
|
|
|
|
ForwardHeaderError, |
|
|
|
|
UpgradeError(String), |
|
|
|
|
UpstreamError(String), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl From<LegacyError> for ProxyError { |
|
|
|
@ -132,7 +133,7 @@ fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> { |
|
|
|
|
response |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String { |
|
|
|
|
fn create_forward_uri<B>(forward_url: &str, req: &Request<B>) -> String { |
|
|
|
|
debug!("Building forward uri"); |
|
|
|
|
|
|
|
|
|
let split_url = forward_url.split('?').collect::<Vec<&str>>(); |
|
|
|
@ -212,7 +213,6 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String { |
|
|
|
|
|
|
|
|
|
fn create_proxied_request<B>( |
|
|
|
|
client_ip: IpAddr, |
|
|
|
|
forward_url: &str, |
|
|
|
|
mut request: Request<B>, |
|
|
|
|
upgrade_type: Option<&String>, |
|
|
|
|
) -> Result<Request<B>, ProxyError> { |
|
|
|
@ -230,16 +230,8 @@ fn create_proxied_request<B>( |
|
|
|
|
}) |
|
|
|
|
.unwrap_or(false); |
|
|
|
|
|
|
|
|
|
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?; |
|
|
|
|
|
|
|
|
|
debug!("Setting headers of proxied request"); |
|
|
|
|
|
|
|
|
|
//request
|
|
|
|
|
// .headers_mut()
|
|
|
|
|
// .insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
|
|
|
|
|
|
|
|
|
|
*request.uri_mut() = uri; |
|
|
|
|
|
|
|
|
|
remove_hop_headers(request.headers_mut()); |
|
|
|
|
remove_connection_headers(request.headers_mut()); |
|
|
|
|
|
|
|
|
@ -287,12 +279,29 @@ fn create_proxied_request<B>( |
|
|
|
|
Ok(request) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
fn get_upstream_addr(forward_uri: &hyper::Uri) -> Result<SocketAddr, ProxyError> { |
|
|
|
|
let host = forward_uri.host().ok_or(ProxyError::UpstreamError( |
|
|
|
|
"forward_uri has no host".to_string(), |
|
|
|
|
))?; |
|
|
|
|
let port = forward_uri.port_u16().ok_or(ProxyError::UpstreamError( |
|
|
|
|
"forward_uri has no port".to_string(), |
|
|
|
|
))?; |
|
|
|
|
Ok(SocketAddr::new( |
|
|
|
|
host.parse().map_err(|_| { |
|
|
|
|
ProxyError::UpstreamError("forward_uri host must be an IP address".to_string()) |
|
|
|
|
})?, |
|
|
|
|
port, |
|
|
|
|
)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type ResponseBody = http_body_util::combinators::UnsyncBoxBody<hyper::body::Bytes, std::io::Error>; |
|
|
|
|
|
|
|
|
|
pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>( |
|
|
|
|
client_ip: IpAddr, |
|
|
|
|
forward_uri: &str, |
|
|
|
|
mut request: Request<Incoming>, |
|
|
|
|
client: &'a Client<T, Incoming>, |
|
|
|
|
) -> Result<Response<Incoming>, ProxyError> { |
|
|
|
|
) -> Result<Response<ResponseBody>, ProxyError> { |
|
|
|
|
debug!( |
|
|
|
|
"Received proxy call from {} to {}, client: {}", |
|
|
|
|
request.uri().to_string(), |
|
|
|
@ -301,57 +310,74 @@ pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>( |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
let request_upgrade_type = get_upgrade_type(request.headers()); |
|
|
|
|
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>(); |
|
|
|
|
|
|
|
|
|
let proxied_request = create_proxied_request( |
|
|
|
|
client_ip, |
|
|
|
|
forward_uri, |
|
|
|
|
request, |
|
|
|
|
request_upgrade_type.as_ref(), |
|
|
|
|
)?; |
|
|
|
|
let mut response = client.request(proxied_request).await?; |
|
|
|
|
|
|
|
|
|
if response.status() == StatusCode::SWITCHING_PROTOCOLS { |
|
|
|
|
let response_upgrade_type = get_upgrade_type(response.headers()); |
|
|
|
|
|
|
|
|
|
if request_upgrade_type == response_upgrade_type { |
|
|
|
|
if let Some(request_upgraded) = request_upgraded { |
|
|
|
|
let mut response_upgraded = TokioIo::new( |
|
|
|
|
response |
|
|
|
|
.extensions_mut() |
|
|
|
|
.remove::<OnUpgrade>() |
|
|
|
|
.ok_or(ProxyError::UpgradeError( |
|
|
|
|
"Failed to upgrade response".to_string(), |
|
|
|
|
))? |
|
|
|
|
.await?, |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
debug!("Responding to a connection upgrade response"); |
|
|
|
|
|
|
|
|
|
let mut request_upgraded = TokioIo::new(request_upgraded.await?); |
|
|
|
|
|
|
|
|
|
tokio::spawn(async move { |
|
|
|
|
copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await |
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
Ok(response) |
|
|
|
|
} else { |
|
|
|
|
Err(ProxyError::UpgradeError( |
|
|
|
|
"request does not have an upgrade extension".to_string(), |
|
|
|
|
)) |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
Err(ProxyError::UpgradeError(format!( |
|
|
|
|
"backend tried to switch to protocol {:?} when {:?} was requested", |
|
|
|
|
response_upgrade_type, request_upgrade_type |
|
|
|
|
))) |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
let proxied_response = create_proxied_response(response); |
|
|
|
|
let request_uri: hyper::Uri = create_forward_uri(forward_uri, &request).parse()?; |
|
|
|
|
*request.uri_mut() = request_uri.clone(); |
|
|
|
|
|
|
|
|
|
let request = create_proxied_request(client_ip, request, request_upgrade_type.as_ref())?; |
|
|
|
|
|
|
|
|
|
if request_upgrade_type.is_none() { |
|
|
|
|
let response = client.request(request).await?; |
|
|
|
|
|
|
|
|
|
debug!("Responding to call with response"); |
|
|
|
|
Ok(proxied_response) |
|
|
|
|
return Ok(create_proxied_response( |
|
|
|
|
response.map(|body| body.map_err(std::io::Error::other).boxed_unsync()), |
|
|
|
|
)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
let (request_parts, request_body) = request.into_parts(); |
|
|
|
|
let upstream_request = |
|
|
|
|
Request::from_parts(request_parts.clone(), Empty::<hyper::body::Bytes>::new()); |
|
|
|
|
let mut downstream_request = Request::from_parts(request_parts, request_body); |
|
|
|
|
|
|
|
|
|
let (mut upstream_conn, downstream_response) = { |
|
|
|
|
let upstream_addr = get_upstream_addr(&request_uri)?; |
|
|
|
|
let conn = TokioIo::new( |
|
|
|
|
tokio::net::TcpStream::connect(upstream_addr) |
|
|
|
|
.await |
|
|
|
|
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?, |
|
|
|
|
); |
|
|
|
|
let (mut sender, conn) = hyper::client::conn::http1::handshake(conn).await?; |
|
|
|
|
|
|
|
|
|
tokio::task::spawn(async move { |
|
|
|
|
if let Err(err) = conn.with_upgrades().await { |
|
|
|
|
warn!("Upgrading connection failed: {:?}", err); |
|
|
|
|
} |
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
let response = sender.send_request(upstream_request).await?; |
|
|
|
|
|
|
|
|
|
if response.status() != StatusCode::SWITCHING_PROTOCOLS { |
|
|
|
|
return Err(ProxyError::UpgradeError( |
|
|
|
|
"Server did not response with Switching Protocols status".to_string(), |
|
|
|
|
)); |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
let (response_parts, response_body) = response.into_parts(); |
|
|
|
|
let upstream_response = Response::from_parts(response_parts.clone(), response_body); |
|
|
|
|
let downstream_response = Response::from_parts(response_parts, Empty::new()); |
|
|
|
|
|
|
|
|
|
( |
|
|
|
|
TokioIo::new(hyper::upgrade::on(upstream_response).await?), |
|
|
|
|
downstream_response, |
|
|
|
|
) |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
tokio::task::spawn(async move { |
|
|
|
|
let mut downstream_conn = match hyper::upgrade::on(&mut downstream_request).await { |
|
|
|
|
Ok(upgraded) => TokioIo::new(upgraded), |
|
|
|
|
Err(e) => { |
|
|
|
|
warn!("Failed to upgrade request: {e}"); |
|
|
|
|
return; |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
if let Err(e) = copy_bidirectional(&mut downstream_conn, &mut upstream_conn).await { |
|
|
|
|
warn!("Bidirectional copy failed: {e}"); |
|
|
|
|
} |
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
Ok(downstream_response.map(|body| body.map_err(std::io::Error::other).boxed_unsync())) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct ReverseProxy<T: Connect + Clone + Send + Sync + 'static> { |
|
|
|
@ -368,7 +394,7 @@ impl<T: Connect + Clone + Send + Sync + 'static> ReverseProxy<T> { |
|
|
|
|
client_ip: IpAddr, |
|
|
|
|
forward_uri: &str, |
|
|
|
|
request: Request<Incoming>, |
|
|
|
|
) -> Result<Response<Incoming>, ProxyError> { |
|
|
|
|
) -> Result<Response<ResponseBody>, ProxyError> { |
|
|
|
|
call::<T>(client_ip, forward_uri, request, &self.client).await |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
@ -383,8 +409,8 @@ pub mod benches { |
|
|
|
|
super::create_proxied_response(response); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn forward_uri<B>(forward_url: &str, req: &crate::Request<B>) { |
|
|
|
|
super::forward_uri(forward_url, req); |
|
|
|
|
pub fn create_forward_uri<B>(forward_url: &str, req: &crate::Request<B>) { |
|
|
|
|
super::create_forward_uri(forward_url, req); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn create_proxied_request<B>( |
|
|
|
|