diff --git a/src/lib.rs b/src/lib.rs index 2a049e9..72c159e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,26 +100,29 @@ extern crate test; use hyper::client::{connect::dns::GaiResolver, HttpConnector}; -use hyper::header::{HeaderMap, HeaderValue, HeaderName, HOST}; +use hyper::header::{HeaderName, HeaderValue, HOST}; use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::uri::InvalidUri; -use hyper::{Body, Client, Error, Request, Response}; +use hyper::{Body, Client, Error, HeaderMap, Request, Response}; use lazy_static::lazy_static; use std::net::IpAddr; lazy_static! { + static ref TE_HEADER: HeaderName = HeaderName::from_static("te"); + static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection"); + static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade"); // A list of the headers, using hypers actual HeaderName comparison static ref HOP_HEADERS: [HeaderName; 8] = [ - HeaderName::from_static("connection"), + CONNECTION_HEADER.clone(), + TE_HEADER.clone(), HeaderName::from_static("keep-alive"), HeaderName::from_static("proxy-authenticate"), HeaderName::from_static("proxy-authorization"), - HeaderName::from_static("te"), - HeaderName::from_static("trailers"), + HeaderName::from_static("trailer"), HeaderName::from_static("transfer-encoding"), HeaderName::from_static("upgrade"), ]; - + static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); } @@ -154,32 +157,54 @@ impl From for ProxyError { } } -fn is_hop_header(name: &str) -> bool { - HOP_HEADERS.iter().any(|h| h == &name) +fn remove_hop_headers(headers: &mut HeaderMap) { + for header in &*HOP_HEADERS { + headers.remove(header); + } } -/// Returns a clone of the headers without the [hop-by-hop headers]. -/// -/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html -fn remove_hop_headers(headers: &HeaderMap) -> HeaderMap { - let mut result = HeaderMap::new(); - for (k, v) in headers.iter() { - if !is_hop_header(k.as_str()) { - result.insert(k.clone(), v.clone()); +fn get_upgrade_type(headers: &HeaderMap) -> Option { + if headers + .get(&*CONNECTION_HEADER) + .map(|value| { + value + .to_str() + .unwrap() + .split(",") + .any(|e| e.to_lowercase() == "upgrade") + }) + .unwrap_or(false) + { + if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) { + return Some(upgrade_value.to_str().unwrap().to_owned()); + } + } + None +} + +fn remove_connection_headers(headers: &mut HeaderMap) { + if headers.get(&*CONNECTION_HEADER).is_some() { + let value = headers.get(&*CONNECTION_HEADER).map(|e| e.clone()).unwrap(); + + for name in value.to_str().unwrap().split(",") { + if !name.trim().is_empty() { + headers.remove(name.trim()); + } } } - result } fn create_proxied_response(mut response: Response) -> Response { - *response.headers_mut() = remove_hop_headers(response.headers()); + remove_hop_headers(response.headers_mut()); + remove_connection_headers(response.headers_mut()); + response } - fn forward_uri(forward_url: &str, req: &Request) -> String { if let Some(query) = req.uri().query() { - let mut forwarding_uri = String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1); + let mut forwarding_uri = + String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1); forwarding_uri.push_str(forward_url); forwarding_uri.push_str(req.uri().path()); @@ -202,15 +227,43 @@ fn create_proxied_request( forward_url: &str, mut request: Request, ) -> Result, ProxyError> { + let contains_te_trailers_value = request + .headers() + .get(&*TE_HEADER) + .map(|value| { + value + .to_str() + .unwrap() + .split(",") + .any(|e| e.to_lowercase() == "trailers") + }) + .unwrap_or(false); + let upgrade_type = get_upgrade_type(request.headers()); + let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?; - request .headers_mut() .insert(HOST, HeaderValue::from_str(uri.host().unwrap())?); *request.uri_mut() = uri; - *request.headers_mut() = remove_hop_headers(request.headers()); + remove_hop_headers(request.headers_mut()); + remove_connection_headers(request.headers_mut()); + + if contains_te_trailers_value { + request + .headers_mut() + .insert(&*TE_HEADER, HeaderValue::from_static("trailers")); + } + + if let Some(value) = upgrade_type { + request + .headers_mut() + .insert(&*UPGRADE_HEADER, value.parse().unwrap()); + request + .headers_mut() + .insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE")); + } // Add forwarding information in the headers match request.headers_mut().entry(&*X_FORWARDED_FOR) { @@ -257,21 +310,19 @@ pub async fn call( Ok(proxied_response) } - #[cfg(all(not(stable), test))] mod tests { - use rand::distributions::Alphanumeric; - use rand::prelude::*; use hyper::header::HeaderName; use hyper::Uri; use hyper::{HeaderMap, Request, Response}; + use rand::distributions::Alphanumeric; + use rand::prelude::*; use std::net::Ipv4Addr; use std::str::FromStr; use test::Bencher; use test_context::AsyncTestContext; use tokiotest_httpserver::HttpTestContext; - fn generate_string() -> String { let take = rand::thread_rng().gen::().into(); rand::thread_rng() @@ -343,7 +394,10 @@ mod tests { *response.headers_mut().unwrap() = headers_map.clone(); - super::create_proxied_response(response.body(()).unwrap(), hyper::header::HeaderValue::from_static("me")); + super::create_proxied_response( + response.body(()).unwrap(), + hyper::header::HeaderValue::from_static("me"), + ); }); } @@ -405,7 +459,6 @@ mod tests { let port = rand::thread_rng().gen::(); let forward_url = &format!("http://0.0.0.0:{}", port); - let mut headers_map = build_headers(); headers_map.insert( @@ -420,12 +473,8 @@ mod tests { *request.headers_mut().unwrap() = headers_map.clone(); - super::create_proxied_request( - client_ip, - forward_url, - request.body(()).unwrap(), - ) - .unwrap(); + super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) + .unwrap(); }); } @@ -444,13 +493,8 @@ mod tests { *request.headers_mut().unwrap() = headers_map.clone(); - super::create_proxied_request( - client_ip, - forward_url, - request.body(()).unwrap(), - ) - .unwrap(); + super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) + .unwrap(); }); } } - diff --git a/tests/test_http.rs b/tests/test_http.rs index 1f30a9f..fe35d6f 100644 --- a/tests/test_http.rs +++ b/tests/test_http.rs @@ -1,50 +1,69 @@ +use hyper::server::conn::AddrStream; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; +use std::convert::Infallible; +use std::net::{IpAddr, SocketAddr}; +use test_context::test_context; +use test_context::AsyncTestContext; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; -use hyper::service::{make_service_fn, service_fn}; -use hyper::server::conn::AddrStream; -use std::convert::Infallible; -use hyper::{Uri, Client, Request, Body, Server, Response, StatusCode}; -use tokiotest_httpserver::{HttpTestContext, take_port}; -use test_context::AsyncTestContext; -use test_context::test_context; -use std::net::{IpAddr, SocketAddr}; use tokiotest_httpserver::handler::HandlerBuilder; +use tokiotest_httpserver::{take_port, HttpTestContext}; struct ProxyTestContext { sender: Sender<()>, proxy_handler: JoinHandle>, http_back: HttpTestContext, - port: u16 + port: u16, } #[test_context(ProxyTestContext)] #[tokio::test] async fn test_get_error_500(ctx: &mut ProxyTestContext) { - let resp = Client::new().get(ctx.uri("/500")).await.unwrap(); + let client = Client::new(); + let resp = client + .request( + Request::builder() + .header("keep-alive", "treu") + .method("GET") + .uri(ctx.uri("/500")) + .body(Body::from("")) + .unwrap(), + ) + .await + .unwrap(); assert_eq!(500, resp.status()); } #[test_context(ProxyTestContext)] #[tokio::test] async fn test_get(ctx: &mut ProxyTestContext) { - ctx.http_back.add(HandlerBuilder::new("/foo").status_code(StatusCode::OK).build()); + ctx.http_back.add( + HandlerBuilder::new("/foo") + .status_code(StatusCode::OK) + .build(), + ); let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); assert_eq!(200, resp.status()); } -async fn handle(client_ip: IpAddr, req: Request, backend_port: u16) -> Result, Infallible> { - match hyper_reverse_proxy::call(client_ip, - format!("http://127.0.0.1:{}", backend_port).as_str(), - req).await { - Ok(response) => {Ok(response)} - Err(_) => {Ok(Response::builder() - .status(502) - .body(Body::empty()) - .unwrap())} +async fn handle( + client_ip: IpAddr, + req: Request, + backend_port: u16, +) -> Result, Infallible> { + match hyper_reverse_proxy::call( + client_ip, + format!("http://127.0.0.1:{}", backend_port).as_str(), + req, + ) + .await + { + Ok(response) => Ok(response), + Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()), } } - #[async_trait::async_trait] impl AsyncTestContext for ProxyTestContext { async fn setup() -> ProxyTestContext { @@ -60,13 +79,17 @@ impl AsyncTestContext for ProxyTestContext { }); let port = take_port(); let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port); - let server = Server::bind(&addr).serve(make_svc).with_graceful_shutdown(async { receiver.await.ok(); }); + let server = Server::bind(&addr) + .serve(make_svc) + .with_graceful_shutdown(async { + receiver.await.ok(); + }); let proxy_handler = tokio::spawn(server); ProxyTestContext { sender, proxy_handler, http_back, - port + port, } } async fn teardown(self) { @@ -77,6 +100,8 @@ impl AsyncTestContext for ProxyTestContext { } impl ProxyTestContext { pub fn uri(&self, path: &str) -> Uri { - format!("http://{}:{}{}", "localhost", self.port, path).parse::().unwrap() + format!("http://{}:{}{}", "localhost", self.port, path) + .parse::() + .unwrap() } -} \ No newline at end of file +}