diff --git a/src/lib.rs b/src/lib.rs index a651d82..c2edf7c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -150,6 +150,7 @@ pub enum ProxyError { InvalidUri(InvalidUri), HyperError(Error), ForwardHeaderError, + UpgradeError(String), } impl From for ProxyError { @@ -314,6 +315,7 @@ fn create_proxied_request( client_ip: IpAddr, forward_url: &str, mut request: Request, + upgrade_type: Option<&String>, ) -> Result, ProxyError> { info!("Creating proxied request"); @@ -328,7 +330,6 @@ fn create_proxied_request( .any(|e| e.to_lowercase() == "trailers") }) .unwrap_or(false); - let upgrade_type = get_upgrade_type(request.headers()); let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?; @@ -400,32 +401,51 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + client_ip ); + let request_upgrade_type = get_upgrade_type(request.headers()); let request_upgraded = request.extensions_mut().remove::(); - let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; + let proxied_request = create_proxied_request( + client_ip, + forward_uri, + request, + request_upgrade_type.as_ref(), + )?; let mut response = client.request(proxied_request).await?; if response.status() == StatusCode::SWITCHING_PROTOCOLS { - let mut response_upgraded = response - .extensions_mut() - .remove::() - .expect("response does not have an upgrade extension") - .await?; + let response_upgrade_type = get_upgrade_type(response.headers()); - debug!("Responding to a connection upgrade response"); + if request_upgrade_type == response_upgrade_type { + if let Some(request_upgraded) = request_upgraded { + let mut response_upgraded = response + .extensions_mut() + .remove::() + .expect("response does not have an upgrade extension") + .await?; - tokio::spawn(async move { - let mut request_upgraded = request_upgraded - .expect("request does not have an upgrade extension") - .await - .expect("failed to upgrade request"); + debug!("Responding to a connection upgrade response"); - copy_bidirectional(&mut response_upgraded, &mut request_upgraded) - .await - .expect("coping between upgraded connections failed"); - }); + tokio::spawn(async move { + let mut request_upgraded = + request_upgraded.await.expect("failed to upgrade request"); - Ok(response) + copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + .await + .expect("coping between upgraded connections failed"); + }); + + Ok(response) + } else { + Err(ProxyError::UpgradeError( + "request does not have an upgrade extension".to_string(), + )) + } + } else { + Err(ProxyError::UpgradeError(format!( + "backend tried to switch to protocol {:?} when {:?} was requested", + response_upgrade_type, request_upgrade_type + ))) + } } else { let proxied_response = create_proxied_response(response); @@ -616,7 +636,7 @@ mod tests { *request.headers_mut().unwrap() = headers_map.clone(); - super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) + super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None) .unwrap(); }); } @@ -636,7 +656,7 @@ mod tests { *request.headers_mut().unwrap() = headers_map.clone(); - super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) + super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None) .unwrap(); }); } diff --git a/tests/test_http.rs b/tests/test_http.rs index e1f53ed..3c409c9 100644 --- a/tests/test_http.rs +++ b/tests/test_http.rs @@ -1,5 +1,6 @@ use hyper::client::connect::dns::GaiResolver; use hyper::client::HttpConnector; +use hyper::header::{CONNECTION, UPGRADE}; use hyper::server::conn::AddrStream; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; @@ -46,6 +47,50 @@ async fn test_get_error_500(ctx: &mut ProxyTestContext) { assert_eq!(500, resp.status()); } +#[test_context(ProxyTestContext)] +#[tokio::test] +async fn test_upgrade_mismatch(ctx: &mut ProxyTestContext) { + ctx.http_back.add( + HandlerBuilder::new("/normal") + .status_code(StatusCode::SWITCHING_PROTOCOLS) + .build(), + ); + let resp = Client::new() + .request( + Request::builder() + .header(CONNECTION, "Upgrade") + .header(UPGRADE, "websocket") + .method("GET") + .uri(ctx.uri("/normal")) + .body(Body::from("")) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(resp.status(), 502); +} + +#[test_context(ProxyTestContext)] +#[tokio::test] +async fn test_upgrade_unrequested(ctx: &mut ProxyTestContext) { + ctx.http_back.add( + HandlerBuilder::new("/normal") + .status_code(StatusCode::SWITCHING_PROTOCOLS) + .build(), + ); + let resp = Client::new() + .request( + Request::builder() + .method("GET") + .uri(ctx.uri("/normal")) + .body(Body::from("")) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(resp.status(), 502); +} + #[test_context(ProxyTestContext)] #[tokio::test] async fn test_get(ctx: &mut ProxyTestContext) {