refactor: simplify

This commit is contained in:
chesedo 2022-05-04 13:49:50 +02:00 committed by Felipe Noronha
parent d4fdbf2a2e
commit c3d2183195

View File

@ -120,7 +120,7 @@ use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST};
use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::header::{InvalidHeaderValue, ToStrError};
use hyper::http::uri::InvalidUri; use hyper::http::uri::InvalidUri;
use hyper::upgrade::OnUpgrade; use hyper::upgrade::OnUpgrade;
use hyper::{upgrade, Body, Client, Error, Request, Response, StatusCode}; use hyper::{Body, Client, Error, Request, Response, StatusCode};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use std::net::IpAddr; use std::net::IpAddr;
use tokio::io::copy_bidirectional; use tokio::io::copy_bidirectional;
@ -390,7 +390,7 @@ fn create_proxied_request<B>(
pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + 'static>( pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + 'static>(
client_ip: IpAddr, client_ip: IpAddr,
forward_uri: &str, forward_uri: &str,
request: Request<Body>, mut request: Request<Body>,
client: &'a Client<T>, client: &'a Client<T>,
) -> Result<Response<Body>, ProxyError> { ) -> Result<Response<Body>, ProxyError> {
info!( info!(
@ -399,40 +399,34 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
forward_uri, forward_uri,
client_ip client_ip
); );
let mut request = request;
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>(); let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
let mut response = client.request(proxied_request).await?;
let proxied_response = client.request(proxied_request).await?; if response.status() == StatusCode::SWITCHING_PROTOCOLS {
let mut response_upgraded = response
if proxied_response.status() == StatusCode::SWITCHING_PROTOCOLS { .extensions_mut()
// if response.status() != proxied_request.st .remove::<OnUpgrade>()
.expect("response does not have an upgrade extension")
let mut response = Response::new(Body::empty()); .await?;
*response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
for (k, v) in proxied_response.headers().into_iter() {
response.headers_mut().append(k, v.clone());
}
let mut response_upgraded = upgrade::on(proxied_response)
.await
.expect("failed to upgrade response");
tokio::spawn(async move { tokio::spawn(async move {
let mut request_upgraded = request_upgraded let mut request_upgraded = request_upgraded
.expect("test") .expect("request does not have an upgrade extension")
.await .await
.expect("failed to upgrade request"); .expect("failed to upgrade request");
copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await; copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
.await
.expect("coping between upgraded connections failed");
}); });
return Ok(response); return Ok(response);
} }
let proxied_response = create_proxied_response(proxied_response);
let proxied_response = create_proxied_response(response);
debug!("Responding to call with response"); debug!("Responding to call with response");
Ok(proxied_response) Ok(proxied_response)