diff --git a/src/lib.rs b/src/lib.rs index d3ff3b1..9a11283 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,7 +97,7 @@ //! use hyper::client::{connect::dns::GaiResolver, HttpConnector}; -use hyper::header::{HeaderMap, HeaderValue}; +use hyper::header::{HeaderMap, HeaderValue, HOST}; use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::uri::InvalidUri; use hyper::{Body, Client, Error, Request, Response, Uri}; @@ -171,7 +171,11 @@ fn remove_hop_headers(headers: &HeaderMap) -> HeaderMap(mut response: Response) -> Response { +fn create_proxied_response(mut response: Response, host: HeaderValue) -> Response { + if host.to_str().unwrap_or("") != "" { + response.headers_mut().insert(HOST, host); + } + *response.headers_mut() = remove_hop_headers(response.headers()); response } @@ -190,8 +194,10 @@ fn create_proxied_request( forward_url: &str, mut request: Request, ) -> Result, ProxyError> { + let uri = forward_uri(forward_url, &request)?; + *request.headers_mut() = remove_hop_headers(request.headers()); - *request.uri_mut() = forward_uri(forward_url, &request)?; + *request.uri_mut() = uri.to_owned(); let x_forwarded_for_header_name = "x-forwarded-for"; @@ -207,6 +213,8 @@ fn create_proxied_request( } } + request.headers_mut().insert(HOST, uri.host().unwrap().parse()?); + Ok(request) } @@ -226,10 +234,12 @@ pub async fn call( forward_uri: &str, request: Request, ) -> Result, ProxyError> { + let host = request.headers().get(HOST).unwrap_or(&HeaderValue::from_str("").unwrap()).to_owned(); + let proxied_request = create_proxied_request(client_ip, &forward_uri, request)?; let client = build_client(); let response = client.request(proxied_request).await?; - let proxied_response = create_proxied_response(response); + let proxied_response = create_proxied_response(response, host); Ok(proxied_response) }