tests: add more upgrade tests
This commit is contained in:
parent
87f1ed675a
commit
16ce317c7e
60
src/lib.rs
60
src/lib.rs
@ -150,6 +150,7 @@ pub enum ProxyError {
|
|||||||
InvalidUri(InvalidUri),
|
InvalidUri(InvalidUri),
|
||||||
HyperError(Error),
|
HyperError(Error),
|
||||||
ForwardHeaderError,
|
ForwardHeaderError,
|
||||||
|
UpgradeError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Error> for ProxyError {
|
impl From<Error> for ProxyError {
|
||||||
@ -314,6 +315,7 @@ fn create_proxied_request<B>(
|
|||||||
client_ip: IpAddr,
|
client_ip: IpAddr,
|
||||||
forward_url: &str,
|
forward_url: &str,
|
||||||
mut request: Request<B>,
|
mut request: Request<B>,
|
||||||
|
upgrade_type: Option<&String>,
|
||||||
) -> Result<Request<B>, ProxyError> {
|
) -> Result<Request<B>, ProxyError> {
|
||||||
info!("Creating proxied request");
|
info!("Creating proxied request");
|
||||||
|
|
||||||
@ -328,7 +330,6 @@ fn create_proxied_request<B>(
|
|||||||
.any(|e| e.to_lowercase() == "trailers")
|
.any(|e| e.to_lowercase() == "trailers")
|
||||||
})
|
})
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
let upgrade_type = get_upgrade_type(request.headers());
|
|
||||||
|
|
||||||
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
|
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
|
||||||
|
|
||||||
@ -400,32 +401,51 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
|
|||||||
client_ip
|
client_ip
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let request_upgrade_type = get_upgrade_type(request.headers());
|
||||||
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
|
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
|
||||||
|
|
||||||
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
|
let proxied_request = create_proxied_request(
|
||||||
|
client_ip,
|
||||||
|
forward_uri,
|
||||||
|
request,
|
||||||
|
request_upgrade_type.as_ref(),
|
||||||
|
)?;
|
||||||
let mut response = client.request(proxied_request).await?;
|
let mut response = client.request(proxied_request).await?;
|
||||||
|
|
||||||
if response.status() == StatusCode::SWITCHING_PROTOCOLS {
|
if response.status() == StatusCode::SWITCHING_PROTOCOLS {
|
||||||
let mut response_upgraded = response
|
let response_upgrade_type = get_upgrade_type(response.headers());
|
||||||
.extensions_mut()
|
|
||||||
.remove::<OnUpgrade>()
|
|
||||||
.expect("response does not have an upgrade extension")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
debug!("Responding to a connection upgrade response");
|
if request_upgrade_type == response_upgrade_type {
|
||||||
|
if let Some(request_upgraded) = request_upgraded {
|
||||||
|
let mut response_upgraded = response
|
||||||
|
.extensions_mut()
|
||||||
|
.remove::<OnUpgrade>()
|
||||||
|
.expect("response does not have an upgrade extension")
|
||||||
|
.await?;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
debug!("Responding to a connection upgrade response");
|
||||||
let mut request_upgraded = request_upgraded
|
|
||||||
.expect("request does not have an upgrade extension")
|
|
||||||
.await
|
|
||||||
.expect("failed to upgrade request");
|
|
||||||
|
|
||||||
copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
|
tokio::spawn(async move {
|
||||||
.await
|
let mut request_upgraded =
|
||||||
.expect("coping between upgraded connections failed");
|
request_upgraded.await.expect("failed to upgrade request");
|
||||||
});
|
|
||||||
|
|
||||||
Ok(response)
|
copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
|
||||||
|
.await
|
||||||
|
.expect("coping between upgraded connections failed");
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
} else {
|
||||||
|
Err(ProxyError::UpgradeError(
|
||||||
|
"request does not have an upgrade extension".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(ProxyError::UpgradeError(format!(
|
||||||
|
"backend tried to switch to protocol {:?} when {:?} was requested",
|
||||||
|
response_upgrade_type, request_upgrade_type
|
||||||
|
)))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let proxied_response = create_proxied_response(response);
|
let proxied_response = create_proxied_response(response);
|
||||||
|
|
||||||
@ -616,7 +636,7 @@ mod tests {
|
|||||||
|
|
||||||
*request.headers_mut().unwrap() = headers_map.clone();
|
*request.headers_mut().unwrap() = headers_map.clone();
|
||||||
|
|
||||||
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -636,7 +656,7 @@ mod tests {
|
|||||||
|
|
||||||
*request.headers_mut().unwrap() = headers_map.clone();
|
*request.headers_mut().unwrap() = headers_map.clone();
|
||||||
|
|
||||||
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use hyper::client::connect::dns::GaiResolver;
|
use hyper::client::connect::dns::GaiResolver;
|
||||||
use hyper::client::HttpConnector;
|
use hyper::client::HttpConnector;
|
||||||
|
use hyper::header::{CONNECTION, UPGRADE};
|
||||||
use hyper::server::conn::AddrStream;
|
use hyper::server::conn::AddrStream;
|
||||||
use hyper::service::{make_service_fn, service_fn};
|
use hyper::service::{make_service_fn, service_fn};
|
||||||
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
|
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
|
||||||
@ -46,6 +47,50 @@ async fn test_get_error_500(ctx: &mut ProxyTestContext) {
|
|||||||
assert_eq!(500, resp.status());
|
assert_eq!(500, resp.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test_context(ProxyTestContext)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_upgrade_mismatch(ctx: &mut ProxyTestContext) {
|
||||||
|
ctx.http_back.add(
|
||||||
|
HandlerBuilder::new("/normal")
|
||||||
|
.status_code(StatusCode::SWITCHING_PROTOCOLS)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
|
let resp = Client::new()
|
||||||
|
.request(
|
||||||
|
Request::builder()
|
||||||
|
.header(CONNECTION, "Upgrade")
|
||||||
|
.header(UPGRADE, "websocket")
|
||||||
|
.method("GET")
|
||||||
|
.uri(ctx.uri("/normal"))
|
||||||
|
.body(Body::from(""))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 502);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test_context(ProxyTestContext)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_upgrade_unrequested(ctx: &mut ProxyTestContext) {
|
||||||
|
ctx.http_back.add(
|
||||||
|
HandlerBuilder::new("/normal")
|
||||||
|
.status_code(StatusCode::SWITCHING_PROTOCOLS)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
|
let resp = Client::new()
|
||||||
|
.request(
|
||||||
|
Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri(ctx.uri("/normal"))
|
||||||
|
.body(Body::from(""))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 502);
|
||||||
|
}
|
||||||
|
|
||||||
#[test_context(ProxyTestContext)]
|
#[test_context(ProxyTestContext)]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get(ctx: &mut ProxyTestContext) {
|
async fn test_get(ctx: &mut ProxyTestContext) {
|
||||||
|
Loading…
Reference in New Issue
Block a user