Fix websocket proxying
This commit is contained in:
parent
1dc4618994
commit
8164878b7c
@ -23,6 +23,7 @@ name="internal"
|
||||
harness = false
|
||||
|
||||
[dependencies]
|
||||
http-body-util = "0.1.0"
|
||||
hyper = { version = "1.2.0", features = ["client", "http1"] }
|
||||
hyper-util = { version = "0.1.3", features = ["client-legacy", "http1","tokio"] }
|
||||
lazy_static = "1.4.0"
|
||||
|
148
src/lib.rs
148
src/lib.rs
@ -3,15 +3,15 @@
|
||||
#[macro_use]
|
||||
extern crate tracing;
|
||||
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
||||
use hyper::http::uri::InvalidUri;
|
||||
use hyper::upgrade::OnUpgrade;
|
||||
use hyper::{body::Incoming, Error, Request, Response, StatusCode};
|
||||
use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError};
|
||||
use hyper_util::rt::tokio::TokioIo;
|
||||
use lazy_static::lazy_static;
|
||||
use std::net::IpAddr;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use tokio::io::copy_bidirectional;
|
||||
|
||||
lazy_static! {
|
||||
@ -43,6 +43,7 @@ pub enum ProxyError {
|
||||
HyperError(Error),
|
||||
ForwardHeaderError,
|
||||
UpgradeError(String),
|
||||
UpstreamError(String),
|
||||
}
|
||||
|
||||
impl From<LegacyError> for ProxyError {
|
||||
@ -132,7 +133,7 @@ fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
|
||||
response
|
||||
}
|
||||
|
||||
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||
fn create_forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||
debug!("Building forward uri");
|
||||
|
||||
let split_url = forward_url.split('?').collect::<Vec<&str>>();
|
||||
@ -212,7 +213,6 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||
|
||||
fn create_proxied_request<B>(
|
||||
client_ip: IpAddr,
|
||||
forward_url: &str,
|
||||
mut request: Request<B>,
|
||||
upgrade_type: Option<&String>,
|
||||
) -> Result<Request<B>, ProxyError> {
|
||||
@ -230,16 +230,8 @@ fn create_proxied_request<B>(
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
|
||||
|
||||
debug!("Setting headers of proxied request");
|
||||
|
||||
//request
|
||||
// .headers_mut()
|
||||
// .insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
|
||||
|
||||
*request.uri_mut() = uri;
|
||||
|
||||
remove_hop_headers(request.headers_mut());
|
||||
remove_connection_headers(request.headers_mut());
|
||||
|
||||
@ -287,12 +279,29 @@ fn create_proxied_request<B>(
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
fn get_upstream_addr(forward_uri: &hyper::Uri) -> Result<SocketAddr, ProxyError> {
|
||||
let host = forward_uri.host().ok_or(ProxyError::UpstreamError(
|
||||
"forward_uri has no host".to_string(),
|
||||
))?;
|
||||
let port = forward_uri.port_u16().ok_or(ProxyError::UpstreamError(
|
||||
"forward_uri has no port".to_string(),
|
||||
))?;
|
||||
Ok(SocketAddr::new(
|
||||
host.parse().map_err(|_| {
|
||||
ProxyError::UpstreamError("forward_uri host must be an IP address".to_string())
|
||||
})?,
|
||||
port,
|
||||
))
|
||||
}
|
||||
|
||||
type ResponseBody = http_body_util::combinators::UnsyncBoxBody<hyper::body::Bytes, std::io::Error>;
|
||||
|
||||
pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>(
|
||||
client_ip: IpAddr,
|
||||
forward_uri: &str,
|
||||
mut request: Request<Incoming>,
|
||||
client: &'a Client<T, Incoming>,
|
||||
) -> Result<Response<Incoming>, ProxyError> {
|
||||
) -> Result<Response<ResponseBody>, ProxyError> {
|
||||
debug!(
|
||||
"Received proxy call from {} to {}, client: {}",
|
||||
request.uri().to_string(),
|
||||
@ -301,57 +310,74 @@ pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>(
|
||||
);
|
||||
|
||||
let request_upgrade_type = get_upgrade_type(request.headers());
|
||||
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
|
||||
|
||||
let proxied_request = create_proxied_request(
|
||||
client_ip,
|
||||
forward_uri,
|
||||
request,
|
||||
request_upgrade_type.as_ref(),
|
||||
)?;
|
||||
let mut response = client.request(proxied_request).await?;
|
||||
let request_uri: hyper::Uri = create_forward_uri(forward_uri, &request).parse()?;
|
||||
*request.uri_mut() = request_uri.clone();
|
||||
|
||||
if response.status() == StatusCode::SWITCHING_PROTOCOLS {
|
||||
let response_upgrade_type = get_upgrade_type(response.headers());
|
||||
let request = create_proxied_request(client_ip, request, request_upgrade_type.as_ref())?;
|
||||
|
||||
if request_upgrade_type == response_upgrade_type {
|
||||
if let Some(request_upgraded) = request_upgraded {
|
||||
let mut response_upgraded = TokioIo::new(
|
||||
response
|
||||
.extensions_mut()
|
||||
.remove::<OnUpgrade>()
|
||||
.ok_or(ProxyError::UpgradeError(
|
||||
"Failed to upgrade response".to_string(),
|
||||
))?
|
||||
.await?,
|
||||
);
|
||||
|
||||
debug!("Responding to a connection upgrade response");
|
||||
|
||||
let mut request_upgraded = TokioIo::new(request_upgraded.await?);
|
||||
|
||||
tokio::spawn(async move {
|
||||
copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await
|
||||
});
|
||||
|
||||
Ok(response)
|
||||
} else {
|
||||
Err(ProxyError::UpgradeError(
|
||||
"request does not have an upgrade extension".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(ProxyError::UpgradeError(format!(
|
||||
"backend tried to switch to protocol {:?} when {:?} was requested",
|
||||
response_upgrade_type, request_upgrade_type
|
||||
)))
|
||||
}
|
||||
} else {
|
||||
let proxied_response = create_proxied_response(response);
|
||||
if request_upgrade_type.is_none() {
|
||||
let response = client.request(request).await?;
|
||||
|
||||
debug!("Responding to call with response");
|
||||
Ok(proxied_response)
|
||||
return Ok(create_proxied_response(
|
||||
response.map(|body| body.map_err(std::io::Error::other).boxed_unsync()),
|
||||
));
|
||||
}
|
||||
|
||||
let (request_parts, request_body) = request.into_parts();
|
||||
let upstream_request =
|
||||
Request::from_parts(request_parts.clone(), Empty::<hyper::body::Bytes>::new());
|
||||
let mut downstream_request = Request::from_parts(request_parts, request_body);
|
||||
|
||||
let (mut upstream_conn, downstream_response) = {
|
||||
let upstream_addr = get_upstream_addr(&request_uri)?;
|
||||
let conn = TokioIo::new(
|
||||
tokio::net::TcpStream::connect(upstream_addr)
|
||||
.await
|
||||
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?,
|
||||
);
|
||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(conn).await?;
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
if let Err(err) = conn.with_upgrades().await {
|
||||
warn!("Upgrading connection failed: {:?}", err);
|
||||
}
|
||||
});
|
||||
|
||||
let response = sender.send_request(upstream_request).await?;
|
||||
|
||||
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
|
||||
return Err(ProxyError::UpgradeError(
|
||||
"Server did not response with Switching Protocols status".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let (response_parts, response_body) = response.into_parts();
|
||||
let upstream_response = Response::from_parts(response_parts.clone(), response_body);
|
||||
let downstream_response = Response::from_parts(response_parts, Empty::new());
|
||||
|
||||
(
|
||||
TokioIo::new(hyper::upgrade::on(upstream_response).await?),
|
||||
downstream_response,
|
||||
)
|
||||
};
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let mut downstream_conn = match hyper::upgrade::on(&mut downstream_request).await {
|
||||
Ok(upgraded) => TokioIo::new(upgraded),
|
||||
Err(e) => {
|
||||
warn!("Failed to upgrade request: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = copy_bidirectional(&mut downstream_conn, &mut upstream_conn).await {
|
||||
warn!("Bidirectional copy failed: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(downstream_response.map(|body| body.map_err(std::io::Error::other).boxed_unsync()))
|
||||
}
|
||||
|
||||
pub struct ReverseProxy<T: Connect + Clone + Send + Sync + 'static> {
|
||||
@ -368,7 +394,7 @@ impl<T: Connect + Clone + Send + Sync + 'static> ReverseProxy<T> {
|
||||
client_ip: IpAddr,
|
||||
forward_uri: &str,
|
||||
request: Request<Incoming>,
|
||||
) -> Result<Response<Incoming>, ProxyError> {
|
||||
) -> Result<Response<ResponseBody>, ProxyError> {
|
||||
call::<T>(client_ip, forward_uri, request, &self.client).await
|
||||
}
|
||||
}
|
||||
@ -383,8 +409,8 @@ pub mod benches {
|
||||
super::create_proxied_response(response);
|
||||
}
|
||||
|
||||
pub fn forward_uri<B>(forward_url: &str, req: &crate::Request<B>) {
|
||||
super::forward_uri(forward_url, req);
|
||||
pub fn create_forward_uri<B>(forward_url: &str, req: &crate::Request<B>) {
|
||||
super::create_forward_uri(forward_url, req);
|
||||
}
|
||||
|
||||
pub fn create_proxied_request<B>(
|
||||
|
Loading…
Reference in New Issue
Block a user