tests: add more upgrade tests

This commit is contained in:
chesedo 2022-05-10 09:02:21 +02:00 committed by Felipe Noronha
parent 87f1ed675a
commit 16ce317c7e
2 changed files with 85 additions and 20 deletions

View File

@ -150,6 +150,7 @@ pub enum ProxyError {
InvalidUri(InvalidUri), InvalidUri(InvalidUri),
HyperError(Error), HyperError(Error),
ForwardHeaderError, ForwardHeaderError,
UpgradeError(String),
} }
impl From<Error> for ProxyError { impl From<Error> for ProxyError {
@ -314,6 +315,7 @@ fn create_proxied_request<B>(
client_ip: IpAddr, client_ip: IpAddr,
forward_url: &str, forward_url: &str,
mut request: Request<B>, mut request: Request<B>,
upgrade_type: Option<&String>,
) -> Result<Request<B>, ProxyError> { ) -> Result<Request<B>, ProxyError> {
info!("Creating proxied request"); info!("Creating proxied request");
@ -328,7 +330,6 @@ fn create_proxied_request<B>(
.any(|e| e.to_lowercase() == "trailers") .any(|e| e.to_lowercase() == "trailers")
}) })
.unwrap_or(false); .unwrap_or(false);
let upgrade_type = get_upgrade_type(request.headers());
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?; let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
@ -400,32 +401,51 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
client_ip client_ip
); );
let request_upgrade_type = get_upgrade_type(request.headers());
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>(); let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; let proxied_request = create_proxied_request(
client_ip,
forward_uri,
request,
request_upgrade_type.as_ref(),
)?;
let mut response = client.request(proxied_request).await?; let mut response = client.request(proxied_request).await?;
if response.status() == StatusCode::SWITCHING_PROTOCOLS { if response.status() == StatusCode::SWITCHING_PROTOCOLS {
let mut response_upgraded = response let response_upgrade_type = get_upgrade_type(response.headers());
.extensions_mut()
.remove::<OnUpgrade>()
.expect("response does not have an upgrade extension")
.await?;
debug!("Responding to a connection upgrade response"); if request_upgrade_type == response_upgrade_type {
if let Some(request_upgraded) = request_upgraded {
let mut response_upgraded = response
.extensions_mut()
.remove::<OnUpgrade>()
.expect("response does not have an upgrade extension")
.await?;
tokio::spawn(async move { debug!("Responding to a connection upgrade response");
let mut request_upgraded = request_upgraded
.expect("request does not have an upgrade extension")
.await
.expect("failed to upgrade request");
copy_bidirectional(&mut response_upgraded, &mut request_upgraded) tokio::spawn(async move {
.await let mut request_upgraded =
.expect("coping between upgraded connections failed"); request_upgraded.await.expect("failed to upgrade request");
});
Ok(response) copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
.await
.expect("coping between upgraded connections failed");
});
Ok(response)
} else {
Err(ProxyError::UpgradeError(
"request does not have an upgrade extension".to_string(),
))
}
} else {
Err(ProxyError::UpgradeError(format!(
"backend tried to switch to protocol {:?} when {:?} was requested",
response_upgrade_type, request_upgrade_type
)))
}
} else { } else {
let proxied_response = create_proxied_response(response); let proxied_response = create_proxied_response(response);
@ -616,7 +636,7 @@ mod tests {
*request.headers_mut().unwrap() = headers_map.clone(); *request.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None)
.unwrap(); .unwrap();
}); });
} }
@ -636,7 +656,7 @@ mod tests {
*request.headers_mut().unwrap() = headers_map.clone(); *request.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None)
.unwrap(); .unwrap();
}); });
} }

View File

@ -1,5 +1,6 @@
use hyper::client::connect::dns::GaiResolver; use hyper::client::connect::dns::GaiResolver;
use hyper::client::HttpConnector; use hyper::client::HttpConnector;
use hyper::header::{CONNECTION, UPGRADE};
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn}; use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
@ -46,6 +47,50 @@ async fn test_get_error_500(ctx: &mut ProxyTestContext) {
assert_eq!(500, resp.status()); assert_eq!(500, resp.status());
} }
#[test_context(ProxyTestContext)]
#[tokio::test]
async fn test_upgrade_mismatch(ctx: &mut ProxyTestContext) {
ctx.http_back.add(
HandlerBuilder::new("/normal")
.status_code(StatusCode::SWITCHING_PROTOCOLS)
.build(),
);
let resp = Client::new()
.request(
Request::builder()
.header(CONNECTION, "Upgrade")
.header(UPGRADE, "websocket")
.method("GET")
.uri(ctx.uri("/normal"))
.body(Body::from(""))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), 502);
}
#[test_context(ProxyTestContext)]
#[tokio::test]
async fn test_upgrade_unrequested(ctx: &mut ProxyTestContext) {
ctx.http_back.add(
HandlerBuilder::new("/normal")
.status_code(StatusCode::SWITCHING_PROTOCOLS)
.build(),
);
let resp = Client::new()
.request(
Request::builder()
.method("GET")
.uri(ctx.uri("/normal"))
.body(Body::from(""))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), 502);
}
#[test_context(ProxyTestContext)] #[test_context(ProxyTestContext)]
#[tokio::test] #[tokio::test]
async fn test_get(ctx: &mut ProxyTestContext) { async fn test_get(ctx: &mut ProxyTestContext) {