|
|
|
@ -150,6 +150,7 @@ pub enum ProxyError { |
|
|
|
|
InvalidUri(InvalidUri), |
|
|
|
|
HyperError(Error), |
|
|
|
|
ForwardHeaderError, |
|
|
|
|
UpgradeError(String), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl From<Error> for ProxyError { |
|
|
|
@ -314,6 +315,7 @@ fn create_proxied_request<B>( |
|
|
|
|
client_ip: IpAddr, |
|
|
|
|
forward_url: &str, |
|
|
|
|
mut request: Request<B>, |
|
|
|
|
upgrade_type: Option<&String>, |
|
|
|
|
) -> Result<Request<B>, ProxyError> { |
|
|
|
|
info!("Creating proxied request"); |
|
|
|
|
|
|
|
|
@ -328,7 +330,6 @@ fn create_proxied_request<B>( |
|
|
|
|
.any(|e| e.to_lowercase() == "trailers") |
|
|
|
|
}) |
|
|
|
|
.unwrap_or(false); |
|
|
|
|
let upgrade_type = get_upgrade_type(request.headers()); |
|
|
|
|
|
|
|
|
|
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?; |
|
|
|
|
|
|
|
|
@ -400,32 +401,51 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + |
|
|
|
|
client_ip |
|
|
|
|
); |
|
|
|
|
|
|
|
|
|
let request_upgrade_type = get_upgrade_type(request.headers()); |
|
|
|
|
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>(); |
|
|
|
|
|
|
|
|
|
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; |
|
|
|
|
let proxied_request = create_proxied_request( |
|
|
|
|
client_ip, |
|
|
|
|
forward_uri, |
|
|
|
|
request, |
|
|
|
|
request_upgrade_type.as_ref(), |
|
|
|
|
)?; |
|
|
|
|
let mut response = client.request(proxied_request).await?; |
|
|
|
|
|
|
|
|
|
if response.status() == StatusCode::SWITCHING_PROTOCOLS { |
|
|
|
|
let mut response_upgraded = response |
|
|
|
|
.extensions_mut() |
|
|
|
|
.remove::<OnUpgrade>() |
|
|
|
|
.expect("response does not have an upgrade extension") |
|
|
|
|
.await?; |
|
|
|
|
|
|
|
|
|
debug!("Responding to a connection upgrade response"); |
|
|
|
|
|
|
|
|
|
tokio::spawn(async move { |
|
|
|
|
let mut request_upgraded = request_upgraded |
|
|
|
|
.expect("request does not have an upgrade extension") |
|
|
|
|
.await |
|
|
|
|
.expect("failed to upgrade request"); |
|
|
|
|
|
|
|
|
|
copy_bidirectional(&mut response_upgraded, &mut request_upgraded) |
|
|
|
|
.await |
|
|
|
|
.expect("coping between upgraded connections failed"); |
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
Ok(response) |
|
|
|
|
let response_upgrade_type = get_upgrade_type(response.headers()); |
|
|
|
|
|
|
|
|
|
if request_upgrade_type == response_upgrade_type { |
|
|
|
|
if let Some(request_upgraded) = request_upgraded { |
|
|
|
|
let mut response_upgraded = response |
|
|
|
|
.extensions_mut() |
|
|
|
|
.remove::<OnUpgrade>() |
|
|
|
|
.expect("response does not have an upgrade extension") |
|
|
|
|
.await?; |
|
|
|
|
|
|
|
|
|
debug!("Responding to a connection upgrade response"); |
|
|
|
|
|
|
|
|
|
tokio::spawn(async move { |
|
|
|
|
let mut request_upgraded = |
|
|
|
|
request_upgraded.await.expect("failed to upgrade request"); |
|
|
|
|
|
|
|
|
|
copy_bidirectional(&mut response_upgraded, &mut request_upgraded) |
|
|
|
|
.await |
|
|
|
|
.expect("coping between upgraded connections failed"); |
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
Ok(response) |
|
|
|
|
} else { |
|
|
|
|
Err(ProxyError::UpgradeError( |
|
|
|
|
"request does not have an upgrade extension".to_string(), |
|
|
|
|
)) |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
Err(ProxyError::UpgradeError(format!( |
|
|
|
|
"backend tried to switch to protocol {:?} when {:?} was requested", |
|
|
|
|
response_upgrade_type, request_upgrade_type |
|
|
|
|
))) |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
let proxied_response = create_proxied_response(response); |
|
|
|
|
|
|
|
|
@ -616,7 +636,7 @@ mod tests { |
|
|
|
|
|
|
|
|
|
*request.headers_mut().unwrap() = headers_map.clone(); |
|
|
|
|
|
|
|
|
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) |
|
|
|
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None) |
|
|
|
|
.unwrap(); |
|
|
|
|
}); |
|
|
|
|
} |
|
|
|
@ -636,7 +656,7 @@ mod tests { |
|
|
|
|
|
|
|
|
|
*request.headers_mut().unwrap() = headers_map.clone(); |
|
|
|
|
|
|
|
|
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap()) |
|
|
|
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap(), None) |
|
|
|
|
.unwrap(); |
|
|
|
|
}); |
|
|
|
|
} |
|
|
|
|