diff --git a/src/lib.rs b/src/lib.rs index 7a96b4f..02e9056 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,7 +120,7 @@ use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST}; use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::uri::InvalidUri; use hyper::upgrade::OnUpgrade; -use hyper::{upgrade, Body, Client, Error, Request, Response, StatusCode}; +use hyper::{Body, Client, Error, Request, Response, StatusCode}; use lazy_static::lazy_static; use std::net::IpAddr; use tokio::io::copy_bidirectional; @@ -390,7 +390,7 @@ fn create_proxied_request( pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + 'static>( client_ip: IpAddr, forward_uri: &str, - request: Request, + mut request: Request, client: &'a Client, ) -> Result, ProxyError> { info!( @@ -399,40 +399,34 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + forward_uri, client_ip ); - let mut request = request; let request_upgraded = request.extensions_mut().remove::(); let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; + let mut response = client.request(proxied_request).await?; - let proxied_response = client.request(proxied_request).await?; - - if proxied_response.status() == StatusCode::SWITCHING_PROTOCOLS { - // if response.status() != proxied_request.st - - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - - for (k, v) in proxied_response.headers().into_iter() { - response.headers_mut().append(k, v.clone()); - } - - let mut response_upgraded = upgrade::on(proxied_response) - .await - .expect("failed to upgrade response"); + if response.status() == StatusCode::SWITCHING_PROTOCOLS { + let mut response_upgraded = response + .extensions_mut() + .remove::() + .expect("response does not have an upgrade extension") + .await?; tokio::spawn(async move { let mut request_upgraded = request_upgraded - .expect("test") + .expect("request does not have an upgrade extension") .await .expect("failed to upgrade request"); - copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await; + copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + .await + .expect("coping between upgraded connections failed"); }); return Ok(response); } - let proxied_response = create_proxied_response(proxied_response); + + let proxied_response = create_proxied_response(response); debug!("Responding to call with response"); Ok(proxied_response)