From 7adb97ceaae311c8ad116dca66a4cb303aefa3e3 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Tue, 5 Mar 2024 21:40:28 +0100 Subject: [PATCH] Fix websocket forwarding, the SocketAddr was not being properly parsed from forward_uri --- src/lib.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d019bde..2859769 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -279,7 +279,10 @@ fn create_proxied_request( Ok(request) } -fn get_upstream_addr(forward_uri: &hyper::Uri) -> Result { +fn get_upstream_addr(forward_uri: &str) -> Result { + let forward_uri: hyper::Uri = forward_uri.parse().map_err(|e| { + ProxyError::UpstreamError(format!("parsing forward_uri as a Uri: {e}").to_string()) + })?; let host = forward_uri.host().ok_or(ProxyError::UpstreamError( "forward_uri has no host".to_string(), ))?; @@ -296,7 +299,7 @@ type ResponseBody = http_body_util::combinators::UnsyncBoxBody( client_ip: IpAddr, forward_uri: &str, - mut request: Request, + request: Request, client: &'a Client, ) -> Result, ProxyError> { debug!( @@ -308,12 +311,12 @@ pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>( let request_upgrade_type = get_upgrade_type(request.headers()); - let request_uri: hyper::Uri = create_forward_uri(forward_uri, &request).parse()?; - *request.uri_mut() = request_uri.clone(); - - let request = create_proxied_request(client_ip, request, request_upgrade_type.as_ref())?; + let mut request = create_proxied_request(client_ip, request, request_upgrade_type.as_ref())?; if request_upgrade_type.is_none() { + let request_uri: hyper::Uri = create_forward_uri(forward_uri, &request).parse()?; + *request.uri_mut() = request_uri.clone(); + let response = client.request(request).await?; debug!("Responding to call with response"); @@ -322,13 +325,13 @@ pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>( )); } + let upstream_addr = get_upstream_addr(forward_uri)?; let (request_parts, request_body) = request.into_parts(); let upstream_request = Request::from_parts(request_parts.clone(), Empty::::new()); let mut downstream_request = Request::from_parts(request_parts, request_body); let (mut upstream_conn, downstream_response) = { - let upstream_addr = get_upstream_addr(&request_uri)?; let conn = TokioIo::new( tokio::net::TcpStream::connect(upstream_addr) .await