From 8164878b7c2f7117b28ee245cbc722e6d3aebc8e Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Tue, 5 Mar 2024 20:40:55 +0100 Subject: [PATCH] Fix websocket proxying --- Cargo.toml | 1 + src/lib.rs | 148 +++++++++++++++++++++++++++++++---------------------- 2 files changed, 88 insertions(+), 61 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 55e824d..c6630aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ name="internal" harness = false [dependencies] +http-body-util = "0.1.0" hyper = { version = "1.2.0", features = ["client", "http1"] } hyper-util = { version = "0.1.3", features = ["client-legacy", "http1","tokio"] } lazy_static = "1.4.0" diff --git a/src/lib.rs b/src/lib.rs index d8183b3..cb76112 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,15 +3,15 @@ #[macro_use] extern crate tracing; +use http_body_util::{BodyExt, Empty}; use hyper::header::{HeaderMap, HeaderName, HeaderValue}; use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::uri::InvalidUri; -use hyper::upgrade::OnUpgrade; use hyper::{body::Incoming, Error, Request, Response, StatusCode}; use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError}; use hyper_util::rt::tokio::TokioIo; use lazy_static::lazy_static; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use tokio::io::copy_bidirectional; lazy_static! { @@ -43,6 +43,7 @@ pub enum ProxyError { HyperError(Error), ForwardHeaderError, UpgradeError(String), + UpstreamError(String), } impl From for ProxyError { @@ -132,7 +133,7 @@ fn create_proxied_response(mut response: Response) -> Response { response } -fn forward_uri(forward_url: &str, req: &Request) -> String { +fn create_forward_uri(forward_url: &str, req: &Request) -> String { debug!("Building forward uri"); let split_url = forward_url.split('?').collect::>(); @@ -212,7 +213,6 @@ fn forward_uri(forward_url: &str, req: &Request) -> String { fn create_proxied_request( client_ip: IpAddr, - forward_url: &str, mut request: Request, upgrade_type: Option<&String>, ) -> Result, ProxyError> { @@ -230,16 +230,8 @@ fn create_proxied_request( }) .unwrap_or(false); - let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?; - debug!("Setting headers of proxied request"); - //request - // .headers_mut() - // .insert(HOST, HeaderValue::from_str(uri.host().unwrap())?); - - *request.uri_mut() = uri; - remove_hop_headers(request.headers_mut()); remove_connection_headers(request.headers_mut()); @@ -287,12 +279,29 @@ fn create_proxied_request( Ok(request) } +fn get_upstream_addr(forward_uri: &hyper::Uri) -> Result { + let host = forward_uri.host().ok_or(ProxyError::UpstreamError( + "forward_uri has no host".to_string(), + ))?; + let port = forward_uri.port_u16().ok_or(ProxyError::UpstreamError( + "forward_uri has no port".to_string(), + ))?; + Ok(SocketAddr::new( + host.parse().map_err(|_| { + ProxyError::UpstreamError("forward_uri host must be an IP address".to_string()) + })?, + port, + )) +} + +type ResponseBody = http_body_util::combinators::UnsyncBoxBody; + pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>( client_ip: IpAddr, forward_uri: &str, mut request: Request, client: &'a Client, -) -> Result, ProxyError> { +) -> Result, ProxyError> { debug!( "Received proxy call from {} to {}, client: {}", request.uri().to_string(), @@ -301,57 +310,74 @@ pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>( ); let request_upgrade_type = get_upgrade_type(request.headers()); - let request_upgraded = request.extensions_mut().remove::(); - let proxied_request = create_proxied_request( - client_ip, - forward_uri, - request, - request_upgrade_type.as_ref(), - )?; - let mut response = client.request(proxied_request).await?; + let request_uri: hyper::Uri = create_forward_uri(forward_uri, &request).parse()?; + *request.uri_mut() = request_uri.clone(); - if response.status() == StatusCode::SWITCHING_PROTOCOLS { - let response_upgrade_type = get_upgrade_type(response.headers()); + let request = create_proxied_request(client_ip, request, request_upgrade_type.as_ref())?; - if request_upgrade_type == response_upgrade_type { - if let Some(request_upgraded) = request_upgraded { - let mut response_upgraded = TokioIo::new( - response - .extensions_mut() - .remove::() - .ok_or(ProxyError::UpgradeError( - "Failed to upgrade response".to_string(), - ))? - .await?, - ); - - debug!("Responding to a connection upgrade response"); - - let mut request_upgraded = TokioIo::new(request_upgraded.await?); - - tokio::spawn(async move { - copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await - }); - - Ok(response) - } else { - Err(ProxyError::UpgradeError( - "request does not have an upgrade extension".to_string(), - )) - } - } else { - Err(ProxyError::UpgradeError(format!( - "backend tried to switch to protocol {:?} when {:?} was requested", - response_upgrade_type, request_upgrade_type - ))) - } - } else { - let proxied_response = create_proxied_response(response); + if request_upgrade_type.is_none() { + let response = client.request(request).await?; debug!("Responding to call with response"); - Ok(proxied_response) + return Ok(create_proxied_response( + response.map(|body| body.map_err(std::io::Error::other).boxed_unsync()), + )); } + + let (request_parts, request_body) = request.into_parts(); + let upstream_request = + Request::from_parts(request_parts.clone(), Empty::::new()); + let mut downstream_request = Request::from_parts(request_parts, request_body); + + let (mut upstream_conn, downstream_response) = { + let upstream_addr = get_upstream_addr(&request_uri)?; + let conn = TokioIo::new( + tokio::net::TcpStream::connect(upstream_addr) + .await + .map_err(|e| ProxyError::UpstreamError(e.to_string()))?, + ); + let (mut sender, conn) = hyper::client::conn::http1::handshake(conn).await?; + + tokio::task::spawn(async move { + if let Err(err) = conn.with_upgrades().await { + warn!("Upgrading connection failed: {:?}", err); + } + }); + + let response = sender.send_request(upstream_request).await?; + + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(ProxyError::UpgradeError( + "Server did not response with Switching Protocols status".to_string(), + )); + }; + + let (response_parts, response_body) = response.into_parts(); + let upstream_response = Response::from_parts(response_parts.clone(), response_body); + let downstream_response = Response::from_parts(response_parts, Empty::new()); + + ( + TokioIo::new(hyper::upgrade::on(upstream_response).await?), + downstream_response, + ) + }; + + tokio::task::spawn(async move { + let mut downstream_conn = match hyper::upgrade::on(&mut downstream_request).await { + Ok(upgraded) => TokioIo::new(upgraded), + Err(e) => { + warn!("Failed to upgrade request: {e}"); + return; + } + }; + + if let Err(e) = copy_bidirectional(&mut downstream_conn, &mut upstream_conn).await { + warn!("Bidirectional copy failed: {e}"); + } + }); + + Ok(downstream_response.map(|body| body.map_err(std::io::Error::other).boxed_unsync())) } pub struct ReverseProxy { @@ -368,7 +394,7 @@ impl ReverseProxy { client_ip: IpAddr, forward_uri: &str, request: Request, - ) -> Result, ProxyError> { + ) -> Result, ProxyError> { call::(client_ip, forward_uri, request, &self.client).await } } @@ -383,8 +409,8 @@ pub mod benches { super::create_proxied_response(response); } - pub fn forward_uri(forward_url: &str, req: &crate::Request) { - super::forward_uri(forward_url, req); + pub fn create_forward_uri(forward_url: &str, req: &crate::Request) { + super::create_forward_uri(forward_url, req); } pub fn create_proxied_request(