Compare commits
14 Commits
dont-set-h
...
master
Author | SHA1 | Date | |
---|---|---|---|
d5a6f79918 | |||
|
dbbf9c3cca | ||
|
695f9639ef | ||
|
2ec415ecac | ||
|
598b99252e | ||
|
88e08c98f1 | ||
7adb97ceaa | |||
224f7bef5a | |||
1a9e3430dd | |||
8164878b7c | |||
1dc4618994 | |||
5fe9e29ae4 | |||
29ea682d8f | |||
907ea5b7f4 |
17
Cargo.toml
17
Cargo.toml
@ -23,30 +23,27 @@ name="internal"
|
||||
harness = false
|
||||
|
||||
[dependencies]
|
||||
hyper = { version = "0.14.18", features = ["client"] }
|
||||
lazy_static = "1.4.0"
|
||||
http-body-util = "0.1.0"
|
||||
hyper = { version = "1.2.0", features = ["client", "http1"] }
|
||||
hyper-util = { version = "0.1.3", features = ["client-legacy", "http1","tokio"] }
|
||||
tokio = { version = "1.17.0", features = ["io-util", "rt"] }
|
||||
tracing = "0.1.34"
|
||||
|
||||
[dev-dependencies]
|
||||
hyper = { version = "0.14.18", features = ["server"] }
|
||||
hyper = { version = "1.2.0", features = ["client", "http1", "server"] }
|
||||
futures = "0.3.21"
|
||||
async-trait = "0.1.53"
|
||||
async-tungstenite = { version = "0.17", features = ["tokio-runtime"] }
|
||||
tokio-test = "0.4.2"
|
||||
test-context = "0.1.3"
|
||||
tokiotest-httpserver = "0.2.1"
|
||||
hyper-trust-dns = { version = "0.4.2", features = [
|
||||
"rustls-http2",
|
||||
"dnssec-ring",
|
||||
"dns-over-https-rustls",
|
||||
"rustls-webpki"
|
||||
] }
|
||||
rand = "0.8.5"
|
||||
tungstenite = "0.17"
|
||||
url = "2.2"
|
||||
criterion = "0.3.5"
|
||||
hyper-rustls = "0.27.1"
|
||||
rustls = "0.23.6"
|
||||
|
||||
[features]
|
||||
|
||||
__bench=[]
|
||||
__bench=[]
|
||||
|
155
README.md
155
README.md
@ -1,14 +1,26 @@
|
||||
# This is a fork
|
||||
|
||||
This repo contains a fork of the [original hyper-reverse-proxy
|
||||
codebase][upstream], adding to it a few improvements:
|
||||
|
||||
- Fix to a bug where the `Host` header was getting overwritten on the upstream
|
||||
HTTP request.
|
||||
|
||||
- Upgraded hyper version to 1.x (and fixes related to that upgrade)
|
||||
|
||||
- Logging cleanup
|
||||
|
||||
Plus more as time goes on.
|
||||
|
||||
[upstream]: https://github.com/felipenoris/hyper-reverse-proxy
|
||||
|
||||
# hyper-reverse-proxy
|
||||
|
||||
[![License][license-img]](LICENSE)
|
||||
[![CI][ci-img]][ci-url]
|
||||
[![docs][docs-img]][docs-url]
|
||||
[![version][version-img]][version-url]
|
||||
|
||||
[license-img]: https://img.shields.io/crates/l/hyper-reverse-proxy.svg
|
||||
[ci-img]: https://github.com/felipenoris/hyper-reverse-proxy/workflows/CI/badge.svg
|
||||
[ci-url]: https://github.com/felipenoris/hyper-reverse-proxy/actions/workflows/main.yml
|
||||
[docs-img]: https://docs.rs/hyper-reverse-proxy/badge.svg
|
||||
[docs-url]: https://docs.rs/hyper-reverse-proxy
|
||||
[version-img]: https://img.shields.io/crates/v/hyper-reverse-proxy.svg
|
||||
@ -28,139 +40,16 @@ The implementation is based on Go's [`httputil.ReverseProxy`].
|
||||
|
||||
# Example
|
||||
|
||||
Add these dependencies to your `Cargo.toml` file.
|
||||
Run the example by cloning this repository and running:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
hyper-reverse-proxy = "?"
|
||||
hyper = { version = "?", features = ["full"] }
|
||||
tokio = { version = "?", features = ["full"] }
|
||||
lazy_static = "?"
|
||||
hyper-trust-dns = { version = "?", features = [
|
||||
"rustls-http2",
|
||||
"dnssec-ring",
|
||||
"dns-over-https-rustls",
|
||||
"rustls-webpki",
|
||||
"https-only"
|
||||
] }
|
||||
```shell
|
||||
cargo run --example simple
|
||||
```
|
||||
|
||||
The following example will set up a reverse proxy listening on `127.0.0.1:13900`,
|
||||
and will proxy these calls:
|
||||
The example will set up a reverse proxy listening on `127.0.0.1:8000`, and will proxy these calls:
|
||||
|
||||
* `"/target/first"` will be proxied to `http://127.0.0.1:13901`
|
||||
* `http://service1.localhost:8000` will be proxied to `http://127.0.0.1:13901`
|
||||
|
||||
* `"/target/second"` will be proxied to `http://127.0.0.1:13902`
|
||||
* `http://service2.localhost:8000` will be proxied to `http://127.0.0.1:13902`
|
||||
|
||||
* All other URLs will be handled by `debug_request` function, that will display request information.
|
||||
|
||||
```rust
|
||||
use hyper::server::conn::AddrStream;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Request, Response, Server, StatusCode};
|
||||
use hyper_reverse_proxy::ReverseProxy;
|
||||
use hyper_trust_dns::{RustlsHttpsConnector, TrustDnsResolver};
|
||||
use std::net::IpAddr;
|
||||
use std::{convert::Infallible, net::SocketAddr};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref PROXY_CLIENT: ReverseProxy<RustlsHttpsConnector> = {
|
||||
ReverseProxy::new(
|
||||
hyper::Client::builder().build::<_, hyper::Body>(TrustDnsResolver::default().into_rustls_webpki_https_connector()),
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
fn debug_request(req: &Request<Body>) -> Result<Response<Body>, Infallible> {
|
||||
let body_str = format!("{:?}", req);
|
||||
Ok(Response::new(Body::from(body_str)))
|
||||
}
|
||||
|
||||
async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
|
||||
if req.uri().path().starts_with("/target/first") {
|
||||
match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13901", req)
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
Ok(response)
|
||||
},
|
||||
Err(_error) => {
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.unwrap())},
|
||||
}
|
||||
} else if req.uri().path().starts_with("/target/second") {
|
||||
match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13902", req)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(_error) => Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.unwrap()),
|
||||
}
|
||||
} else {
|
||||
debug_request(&req)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let bind_addr = "127.0.0.1:8000";
|
||||
let addr: SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
|
||||
|
||||
let make_svc = make_service_fn(|conn: &AddrStream| {
|
||||
let remote_addr = conn.remote_addr().ip();
|
||||
async move { Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req))) }
|
||||
});
|
||||
|
||||
let server = Server::bind(&addr).serve(make_svc);
|
||||
|
||||
println!("Running server on {:?}", addr);
|
||||
|
||||
if let Err(e) = server.await {
|
||||
eprintln!("server error: {}", e);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### A word about Security
|
||||
|
||||
Handling outgoing requests can be a security nightmare. This crate does not control the client for the outgoing requests, as it needs to be supplied to the proxy call. The following chapters may give you an overview on how you can secure your client using the `hyper-trust-dns` crate.
|
||||
|
||||
> You can see them being used in the example.
|
||||
|
||||
#### HTTPS
|
||||
|
||||
You should use a secure transport in order to know who you are talking to and so you can trust the connection. By default `hyper-trust-dns` enables the feature flag `https-only` which will panic if you supply a transport scheme which isn't `https`. It is a healthy default as it's not only you needing to trust the source but also everyone else seeing the content on unsecure connections.
|
||||
|
||||
> ATTENTION: if you are running on a host with added certificates in your cert store, make sure to audit them in a interval, so neither old certificates nor malicious certificates are considered as valid by your client.
|
||||
|
||||
#### TLS 1.2
|
||||
|
||||
By default `tls 1.2` is disabled in favor of `tls 1.3`, because many parts of `tls 1.2` can be considered as attach friendly. As not yet all services support it `tls 1.2` can be enabled via the `rustls-tls-12` feature.
|
||||
|
||||
> ATTENTION: make sure to audit the services you connect to on an interval
|
||||
|
||||
#### DNSSEC
|
||||
|
||||
As dns queries and entries aren't "trustworthy" by default from a security standpoint. `DNSSEC` adds a new cryptographic layer for verification. To enable it use the `dnssec-ring` feature.
|
||||
|
||||
#### HTTP/2
|
||||
|
||||
By default only rustlss `http1` feature is enabled for dns queries. While `http/3` might be just around the corner. `http/2` support can be enabled using the `rustls-http2` feature.
|
||||
|
||||
#### DoT & DoH
|
||||
|
||||
DoT and DoH provide you with a secure transport between you and your dns.
|
||||
|
||||
By default none of them are enabled. If you would like to enabled them, you can do so using the features `doh` and `dot`.
|
||||
|
||||
Recommendations:
|
||||
- If you need to monitor network activities in relation to accessed ports, use dot with the `dns-over-rustls` feature flag
|
||||
- If you are out in the wild and have no need to monitor based on ports, doh with the `dns-over-https-rustls` feature flag as it will blend in with other `https` traffic
|
||||
|
||||
It is highly recommended to use one of them.
|
||||
|
||||
> Currently only includes dns queries as `esni` or `ech` is still in draft by the `ietf`
|
||||
* All other URLs will display request information.
|
||||
|
@ -10,16 +10,14 @@ use rand::distributions::Alphanumeric;
|
||||
use rand::prelude::*;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::OnceLock;
|
||||
use test_context::AsyncTestContext;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokiotest_httpserver::HttpTestContext;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = {
|
||||
ReverseProxy::new(
|
||||
hyper::Client::new(),
|
||||
)
|
||||
};
|
||||
fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
|
||||
static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
|
||||
PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
|
||||
}
|
||||
|
||||
fn create_proxied_response(b: &mut Criterion) {
|
||||
@ -46,7 +44,7 @@ fn generate_string() -> String {
|
||||
}
|
||||
|
||||
fn build_headers() -> HeaderMap {
|
||||
let mut headers_map: HeaderMap = (&*internal_benches::hop_headers())
|
||||
let mut headers_map: HeaderMap = (internal_benches::hop_headers())
|
||||
.iter()
|
||||
.map(|el: &'static HeaderName| (el.clone(), generate_string().parse().unwrap()))
|
||||
.collect();
|
||||
@ -86,7 +84,7 @@ fn proxy_call(b: &mut Criterion) {
|
||||
|
||||
*request.headers_mut().unwrap() = headers_map.clone();
|
||||
|
||||
black_box(&PROXY_CLIENT)
|
||||
black_box(&proxy_client())
|
||||
.call(
|
||||
black_box(client_ip),
|
||||
black_box(forward_url),
|
||||
|
@ -1,67 +1,122 @@
|
||||
use hyper::server::conn::AddrStream;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Request, Response, Server, StatusCode};
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use http_body_util::combinators::UnsyncBoxBody;
|
||||
use http_body_util::{BodyExt, Empty, Full};
|
||||
use hyper::body::{Bytes, Incoming};
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use hyper_reverse_proxy::ReverseProxy;
|
||||
use hyper_trust_dns::{RustlsHttpsConnector, TrustDnsResolver};
|
||||
use std::net::IpAddr;
|
||||
use std::{convert::Infallible, net::SocketAddr};
|
||||
use hyper_rustls::{ConfigBuilderExt, HttpsConnector};
|
||||
use hyper_util::client::legacy::connect::HttpConnector;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref PROXY_CLIENT: ReverseProxy<RustlsHttpsConnector> = {
|
||||
type Connector = HttpsConnector<HttpConnector>;
|
||||
type ResponseBody = UnsyncBoxBody<Bytes, std::io::Error>;
|
||||
|
||||
fn proxy_client() -> &'static ReverseProxy<Connector> {
|
||||
static PROXY_CLIENT: OnceLock<ReverseProxy<Connector>> = OnceLock::new();
|
||||
PROXY_CLIENT.get_or_init(|| {
|
||||
let connector: Connector = Connector::builder()
|
||||
.with_tls_config(
|
||||
rustls::ClientConfig::builder()
|
||||
.with_native_roots()
|
||||
.expect("with_native_roots")
|
||||
.with_no_client_auth(),
|
||||
)
|
||||
.https_or_http()
|
||||
.enable_http1()
|
||||
.build();
|
||||
ReverseProxy::new(
|
||||
hyper::Client::builder().build::<_, hyper::Body>(TrustDnsResolver::default().into_rustls_webpki_https_connector()),
|
||||
hyper_util::client::legacy::Builder::new(TokioExecutor::new())
|
||||
.pool_idle_timeout(Duration::from_secs(3))
|
||||
.pool_timer(TokioTimer::new())
|
||||
.build::<_, Incoming>(connector),
|
||||
)
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
fn debug_request(req: &Request<Body>) -> Result<Response<Body>, Infallible> {
|
||||
let body_str = format!("{:?}", req);
|
||||
Ok(Response::new(Body::from(body_str)))
|
||||
}
|
||||
|
||||
async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
|
||||
if req.uri().path().starts_with("/target/first") {
|
||||
match PROXY_CLIENT
|
||||
async fn handle(
|
||||
client_ip: IpAddr,
|
||||
req: Request<Incoming>,
|
||||
) -> Result<Response<ResponseBody>, Infallible> {
|
||||
let host = req.headers().get("host").and_then(|v| v.to_str().ok());
|
||||
if host.is_some_and(|host| host.starts_with("service1.localhost")) {
|
||||
match proxy_client()
|
||||
.call(client_ip, "http://127.0.0.1:13901", req)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(_error) => Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.body(UnsyncBoxBody::new(
|
||||
Empty::<Bytes>::new().map_err(io::Error::other),
|
||||
))
|
||||
.unwrap()),
|
||||
}
|
||||
} else if req.uri().path().starts_with("/target/second") {
|
||||
match PROXY_CLIENT
|
||||
} else if host.is_some_and(|host| host.starts_with("service2.localhost")) {
|
||||
match proxy_client()
|
||||
.call(client_ip, "http://127.0.0.1:13902", req)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(_error) => Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.body(UnsyncBoxBody::new(
|
||||
Empty::<Bytes>::new().map_err(io::Error::other),
|
||||
))
|
||||
.unwrap()),
|
||||
}
|
||||
} else {
|
||||
debug_request(&req)
|
||||
let body_str = format!("{:?}", req);
|
||||
Ok(Response::new(UnsyncBoxBody::new(
|
||||
Full::new(Bytes::from(body_str)).map_err(io::Error::other),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let bind_addr = "127.0.0.1:8000";
|
||||
let addr: SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
|
||||
|
||||
let make_svc = make_service_fn(|conn: &AddrStream| {
|
||||
let remote_addr = conn.remote_addr().ip();
|
||||
async move { Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req))) }
|
||||
});
|
||||
// We create a TcpListener and bind it to the address
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
let server = Server::bind(&addr).serve(make_svc);
|
||||
println!(
|
||||
"Access service1 on http://service1.localhost:{}",
|
||||
addr.port()
|
||||
);
|
||||
println!(
|
||||
"Access service2 on http://service2.localhost:{}",
|
||||
addr.port()
|
||||
);
|
||||
|
||||
println!("Running server on {:?}", addr);
|
||||
// We start a loop to continuously accept incoming connections
|
||||
loop {
|
||||
let (stream, remote_addr) = listener.accept().await?;
|
||||
let client_ip = remote_addr.ip();
|
||||
|
||||
if let Err(e) = server.await {
|
||||
eprintln!("server error: {}", e);
|
||||
// Use an adapter to access something implementing `tokio::io` traits as if they implement
|
||||
// `hyper::rt` IO traits.
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
// Spawn a tokio task to serve multiple connections concurrently
|
||||
tokio::task::spawn(async move {
|
||||
// Finally, we bind the incoming connection to our `hello` service
|
||||
if let Err(err) = http1::Builder::new()
|
||||
// `service_fn` converts our function in a `Service`
|
||||
.serve_connection(io, service_fn(move |req| handle(client_ip, req)))
|
||||
.await
|
||||
{
|
||||
eprintln!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
278
src/lib.rs
278
src/lib.rs
@ -3,43 +3,78 @@
|
||||
#[macro_use]
|
||||
extern crate tracing;
|
||||
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
||||
use hyper::http::uri::InvalidUri;
|
||||
use hyper::upgrade::OnUpgrade;
|
||||
use hyper::{Body, Client, Error, Request, Response, StatusCode};
|
||||
use lazy_static::lazy_static;
|
||||
use std::net::IpAddr;
|
||||
use hyper::{body::Incoming, Error, Request, Response, StatusCode};
|
||||
use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError};
|
||||
use hyper_util::rt::tokio::TokioIo;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::OnceLock;
|
||||
use tokio::io::copy_bidirectional;
|
||||
|
||||
lazy_static! {
|
||||
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
|
||||
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
|
||||
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
|
||||
static ref TRAILER_HEADER: HeaderName = HeaderName::from_static("trailer");
|
||||
static ref TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers");
|
||||
// A list of the headers, using hypers actual HeaderName comparison
|
||||
static ref HOP_HEADERS: [HeaderName; 9] = [
|
||||
CONNECTION_HEADER.clone(),
|
||||
TE_HEADER.clone(),
|
||||
TRAILER_HEADER.clone(),
|
||||
HeaderName::from_static("keep-alive"),
|
||||
HeaderName::from_static("proxy-connection"),
|
||||
HeaderName::from_static("proxy-authenticate"),
|
||||
HeaderName::from_static("proxy-authorization"),
|
||||
HeaderName::from_static("transfer-encoding"),
|
||||
HeaderName::from_static("upgrade"),
|
||||
];
|
||||
fn te_header() -> &'static HeaderName {
|
||||
static TE_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
||||
TE_HEADER.get_or_init(|| HeaderName::from_static("te"))
|
||||
}
|
||||
|
||||
static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
|
||||
fn connection_header() -> &'static HeaderName {
|
||||
static CONNECTION_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
||||
CONNECTION_HEADER.get_or_init(|| HeaderName::from_static("connection"))
|
||||
}
|
||||
|
||||
fn upgrade_header() -> &'static HeaderName {
|
||||
static UPGRADE_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
||||
UPGRADE_HEADER.get_or_init(|| HeaderName::from_static("upgrade"))
|
||||
}
|
||||
|
||||
fn trailer_header() -> &'static HeaderName {
|
||||
static TRAILER_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
||||
TRAILER_HEADER.get_or_init(|| HeaderName::from_static("trailer"))
|
||||
}
|
||||
|
||||
fn trailers_header() -> &'static HeaderName {
|
||||
static TRAILERS_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
||||
TRAILERS_HEADER.get_or_init(|| HeaderName::from_static("trailers"))
|
||||
}
|
||||
|
||||
fn x_forwarded_for_header() -> &'static HeaderName {
|
||||
static X_FORWARDED_FOR: OnceLock<HeaderName> = OnceLock::new();
|
||||
X_FORWARDED_FOR.get_or_init(|| HeaderName::from_static("x-forwarded-for"))
|
||||
}
|
||||
|
||||
fn hop_headers() -> &'static [HeaderName; 9] {
|
||||
static HOP_HEADERS: OnceLock<[HeaderName; 9]> = OnceLock::new();
|
||||
HOP_HEADERS.get_or_init(|| {
|
||||
[
|
||||
connection_header().clone(),
|
||||
te_header().clone(),
|
||||
trailer_header().clone(),
|
||||
HeaderName::from_static("keep-alive"),
|
||||
HeaderName::from_static("proxy-connection"),
|
||||
HeaderName::from_static("proxy-authenticate"),
|
||||
HeaderName::from_static("proxy-authorization"),
|
||||
HeaderName::from_static("transfer-encoding"),
|
||||
HeaderName::from_static("upgrade"),
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProxyError {
|
||||
InvalidUri(InvalidUri),
|
||||
LegacyHyperError(LegacyError),
|
||||
HyperError(Error),
|
||||
ForwardHeaderError,
|
||||
UpgradeError(String),
|
||||
UpstreamError(String),
|
||||
}
|
||||
|
||||
impl From<LegacyError> for ProxyError {
|
||||
fn from(err: LegacyError) -> ProxyError {
|
||||
ProxyError::LegacyHyperError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for ProxyError {
|
||||
@ -69,25 +104,25 @@ impl From<InvalidHeaderValue> for ProxyError {
|
||||
fn remove_hop_headers(headers: &mut HeaderMap) {
|
||||
debug!("Removing hop headers");
|
||||
|
||||
for header in &*HOP_HEADERS {
|
||||
for header in hop_headers() {
|
||||
headers.remove(header);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
|
||||
#[allow(clippy::blocks_in_if_conditions)]
|
||||
#[allow(clippy::blocks_in_conditions)]
|
||||
if headers
|
||||
.get(&*CONNECTION_HEADER)
|
||||
.get(connection_header())
|
||||
.map(|value| {
|
||||
value
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.split(',')
|
||||
.any(|e| e.trim() == *UPGRADE_HEADER)
|
||||
.any(|e| e.trim() == *upgrade_header())
|
||||
})
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
|
||||
if let Some(upgrade_value) = headers.get(upgrade_header()) {
|
||||
debug!(
|
||||
"Found upgrade header with value: {}",
|
||||
upgrade_value.to_str().unwrap().to_owned()
|
||||
@ -101,10 +136,10 @@ fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
|
||||
}
|
||||
|
||||
fn remove_connection_headers(headers: &mut HeaderMap) {
|
||||
if headers.get(&*CONNECTION_HEADER).is_some() {
|
||||
if headers.get(connection_header()).is_some() {
|
||||
debug!("Removing connection headers");
|
||||
|
||||
let value = headers.get(&*CONNECTION_HEADER).cloned().unwrap();
|
||||
let value = headers.get(connection_header()).cloned().unwrap();
|
||||
|
||||
for name in value.to_str().unwrap().split(',') {
|
||||
if !name.trim().is_empty() {
|
||||
@ -115,7 +150,7 @@ fn remove_connection_headers(headers: &mut HeaderMap) {
|
||||
}
|
||||
|
||||
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
|
||||
info!("Creating proxied response");
|
||||
debug!("Creating proxied response");
|
||||
|
||||
remove_hop_headers(response.headers_mut());
|
||||
remove_connection_headers(response.headers_mut());
|
||||
@ -123,12 +158,12 @@ fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
|
||||
response
|
||||
}
|
||||
|
||||
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||
fn create_forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||
debug!("Building forward uri");
|
||||
|
||||
let split_url = forward_url.split('?').collect::<Vec<&str>>();
|
||||
|
||||
let mut base_url: &str = split_url.get(0).unwrap_or(&"");
|
||||
let mut base_url: &str = split_url.first().unwrap_or(&"");
|
||||
let forward_url_query: &str = split_url.get(1).unwrap_or(&"");
|
||||
|
||||
let path2 = req.uri().path();
|
||||
@ -203,34 +238,25 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||
|
||||
fn create_proxied_request<B>(
|
||||
client_ip: IpAddr,
|
||||
forward_url: &str,
|
||||
mut request: Request<B>,
|
||||
upgrade_type: Option<&String>,
|
||||
) -> Result<Request<B>, ProxyError> {
|
||||
info!("Creating proxied request");
|
||||
debug!("Creating proxied request");
|
||||
|
||||
let contains_te_trailers_value = request
|
||||
.headers()
|
||||
.get(&*TE_HEADER)
|
||||
.get(te_header())
|
||||
.map(|value| {
|
||||
value
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.split(',')
|
||||
.any(|e| e.trim() == *TRAILERS_HEADER)
|
||||
.any(|e| e.trim() == *trailers_header())
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
|
||||
|
||||
debug!("Setting headers of proxied request");
|
||||
|
||||
//request
|
||||
// .headers_mut()
|
||||
// .insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
|
||||
|
||||
*request.uri_mut() = uri;
|
||||
|
||||
remove_hop_headers(request.headers_mut());
|
||||
remove_connection_headers(request.headers_mut());
|
||||
|
||||
@ -239,7 +265,7 @@ fn create_proxied_request<B>(
|
||||
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
|
||||
.insert(te_header(), HeaderValue::from_static("trailers"));
|
||||
}
|
||||
|
||||
if let Some(value) = upgrade_type {
|
||||
@ -247,21 +273,21 @@ fn create_proxied_request<B>(
|
||||
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(&*UPGRADE_HEADER, value.parse().unwrap());
|
||||
.insert(upgrade_header(), value.parse().unwrap());
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
|
||||
.insert(connection_header(), HeaderValue::from_static("UPGRADE"));
|
||||
}
|
||||
|
||||
// Add forwarding information in the headers
|
||||
match request.headers_mut().entry(&*X_FORWARDED_FOR) {
|
||||
match request.headers_mut().entry(x_forwarded_for_header()) {
|
||||
hyper::header::Entry::Vacant(entry) => {
|
||||
debug!("X-Fowraded-for header was vacant");
|
||||
debug!("X-Forwarded-for header was vacant");
|
||||
entry.insert(client_ip.to_string().parse()?);
|
||||
}
|
||||
|
||||
hyper::header::Entry::Occupied(entry) => {
|
||||
debug!("X-Fowraded-for header was occupied");
|
||||
debug!("X-Forwarded-for header was occupied");
|
||||
let client_ip_str = client_ip.to_string();
|
||||
let mut addr =
|
||||
String::with_capacity(entry.get().as_bytes().len() + 2 + client_ip_str.len());
|
||||
@ -278,13 +304,30 @@ fn create_proxied_request<B>(
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + 'static>(
|
||||
fn get_upstream_addr(forward_uri: &str) -> Result<SocketAddr, ProxyError> {
|
||||
let forward_uri: hyper::Uri = forward_uri.parse().map_err(|e| {
|
||||
ProxyError::UpstreamError(format!("parsing forward_uri as a Uri: {e}").to_string())
|
||||
})?;
|
||||
let host = forward_uri.host().ok_or(ProxyError::UpstreamError(
|
||||
"forward_uri has no host".to_string(),
|
||||
))?;
|
||||
let port = forward_uri.port_u16().ok_or(ProxyError::UpstreamError(
|
||||
"forward_uri has no port".to_string(),
|
||||
))?;
|
||||
format!("{host}:{port}").parse().map_err(|_| {
|
||||
ProxyError::UpstreamError("forward_uri host must be an IP address".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
type ResponseBody = http_body_util::combinators::UnsyncBoxBody<hyper::body::Bytes, std::io::Error>;
|
||||
|
||||
pub async fn call<'a, T: Connect + Clone + Send + Sync + 'static>(
|
||||
client_ip: IpAddr,
|
||||
forward_uri: &str,
|
||||
mut request: Request<Body>,
|
||||
client: &'a Client<T>,
|
||||
) -> Result<Response<Body>, ProxyError> {
|
||||
info!(
|
||||
request: Request<Incoming>,
|
||||
client: &'a Client<T, Incoming>,
|
||||
) -> Result<Response<ResponseBody>, ProxyError> {
|
||||
debug!(
|
||||
"Received proxy call from {} to {}, client: {}",
|
||||
request.uri().to_string(),
|
||||
forward_uri,
|
||||
@ -292,64 +335,83 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
|
||||
);
|
||||
|
||||
let request_upgrade_type = get_upgrade_type(request.headers());
|
||||
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
|
||||
|
||||
let proxied_request = create_proxied_request(
|
||||
client_ip,
|
||||
forward_uri,
|
||||
request,
|
||||
request_upgrade_type.as_ref(),
|
||||
)?;
|
||||
let mut response = client.request(proxied_request).await?;
|
||||
let mut request = create_proxied_request(client_ip, request, request_upgrade_type.as_ref())?;
|
||||
|
||||
if response.status() == StatusCode::SWITCHING_PROTOCOLS {
|
||||
let response_upgrade_type = get_upgrade_type(response.headers());
|
||||
if request_upgrade_type.is_none() {
|
||||
let request_uri: hyper::Uri = create_forward_uri(forward_uri, &request).parse()?;
|
||||
*request.uri_mut() = request_uri.clone();
|
||||
|
||||
if request_upgrade_type == response_upgrade_type {
|
||||
if let Some(request_upgraded) = request_upgraded {
|
||||
let mut response_upgraded = response
|
||||
.extensions_mut()
|
||||
.remove::<OnUpgrade>()
|
||||
.expect("response does not have an upgrade extension")
|
||||
.await?;
|
||||
|
||||
debug!("Responding to a connection upgrade response");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut request_upgraded =
|
||||
request_upgraded.await.expect("failed to upgrade request");
|
||||
|
||||
copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
|
||||
.await
|
||||
.expect("coping between upgraded connections failed");
|
||||
});
|
||||
|
||||
Ok(response)
|
||||
} else {
|
||||
Err(ProxyError::UpgradeError(
|
||||
"request does not have an upgrade extension".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(ProxyError::UpgradeError(format!(
|
||||
"backend tried to switch to protocol {:?} when {:?} was requested",
|
||||
response_upgrade_type, request_upgrade_type
|
||||
)))
|
||||
}
|
||||
} else {
|
||||
let proxied_response = create_proxied_response(response);
|
||||
let response = client.request(request).await?;
|
||||
|
||||
debug!("Responding to call with response");
|
||||
Ok(proxied_response)
|
||||
return Ok(create_proxied_response(
|
||||
response.map(|body| body.map_err(std::io::Error::other).boxed_unsync()),
|
||||
));
|
||||
}
|
||||
|
||||
let upstream_addr = get_upstream_addr(forward_uri)?;
|
||||
let (request_parts, request_body) = request.into_parts();
|
||||
let upstream_request =
|
||||
Request::from_parts(request_parts.clone(), Empty::<hyper::body::Bytes>::new());
|
||||
let mut downstream_request = Request::from_parts(request_parts, request_body);
|
||||
|
||||
let (mut upstream_conn, downstream_response) = {
|
||||
let conn = TokioIo::new(
|
||||
tokio::net::TcpStream::connect(upstream_addr)
|
||||
.await
|
||||
.map_err(|e| ProxyError::UpstreamError(e.to_string()))?,
|
||||
);
|
||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(conn).await?;
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
if let Err(err) = conn.with_upgrades().await {
|
||||
warn!("Upgrading connection failed: {:?}", err);
|
||||
}
|
||||
});
|
||||
|
||||
let response = sender.send_request(upstream_request).await?;
|
||||
|
||||
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
|
||||
return Err(ProxyError::UpgradeError(
|
||||
"Server did not response with Switching Protocols status".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let (response_parts, response_body) = response.into_parts();
|
||||
let upstream_response = Response::from_parts(response_parts.clone(), response_body);
|
||||
let downstream_response = Response::from_parts(response_parts, Empty::new());
|
||||
|
||||
(
|
||||
TokioIo::new(hyper::upgrade::on(upstream_response).await?),
|
||||
downstream_response,
|
||||
)
|
||||
};
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let mut downstream_conn = match hyper::upgrade::on(&mut downstream_request).await {
|
||||
Ok(upgraded) => TokioIo::new(upgraded),
|
||||
Err(e) => {
|
||||
warn!("Failed to upgrade request: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = copy_bidirectional(&mut downstream_conn, &mut upstream_conn).await {
|
||||
warn!("Bidirectional copy failed: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(downstream_response.map(|body| body.map_err(std::io::Error::other).boxed_unsync()))
|
||||
}
|
||||
|
||||
pub struct ReverseProxy<T: hyper::client::connect::Connect + Clone + Send + Sync + 'static> {
|
||||
client: Client<T>,
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReverseProxy<T: Connect + Clone + Send + Sync + 'static> {
|
||||
client: Client<T, Incoming>,
|
||||
}
|
||||
|
||||
impl<T: hyper::client::connect::Connect + Clone + Send + Sync + 'static> ReverseProxy<T> {
|
||||
pub fn new(client: Client<T>) -> Self {
|
||||
impl<T: Connect + Clone + Send + Sync + 'static> ReverseProxy<T> {
|
||||
pub fn new(client: Client<T, Incoming>) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
|
||||
@ -357,8 +419,8 @@ impl<T: hyper::client::connect::Connect + Clone + Send + Sync + 'static> Reverse
|
||||
&self,
|
||||
client_ip: IpAddr,
|
||||
forward_uri: &str,
|
||||
request: Request<Body>,
|
||||
) -> Result<Response<Body>, ProxyError> {
|
||||
request: Request<Incoming>,
|
||||
) -> Result<Response<ResponseBody>, ProxyError> {
|
||||
call::<T>(client_ip, forward_uri, request, &self.client).await
|
||||
}
|
||||
}
|
||||
@ -373,8 +435,8 @@ pub mod benches {
|
||||
super::create_proxied_response(response);
|
||||
}
|
||||
|
||||
pub fn forward_uri<B>(forward_url: &str, req: &crate::Request<B>) {
|
||||
super::forward_uri(forward_url, req);
|
||||
pub fn create_forward_uri<B>(forward_url: &str, req: &crate::Request<B>) {
|
||||
super::create_forward_uri(forward_url, req);
|
||||
}
|
||||
|
||||
pub fn create_proxied_request<B>(
|
||||
|
@ -7,6 +7,7 @@ use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
|
||||
use hyper_reverse_proxy::ReverseProxy;
|
||||
use std::convert::Infallible;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::OnceLock;
|
||||
use test_context::test_context;
|
||||
use test_context::AsyncTestContext;
|
||||
use tokio::sync::oneshot::Sender;
|
||||
@ -14,12 +15,9 @@ use tokio::task::JoinHandle;
|
||||
use tokiotest_httpserver::handler::HandlerBuilder;
|
||||
use tokiotest_httpserver::{take_port, HttpTestContext};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = {
|
||||
ReverseProxy::new(
|
||||
hyper::Client::new(),
|
||||
)
|
||||
};
|
||||
fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
|
||||
static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
|
||||
PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
|
||||
}
|
||||
|
||||
struct ProxyTestContext {
|
||||
@ -99,7 +97,7 @@ async fn handle(
|
||||
req: Request<Body>,
|
||||
backend_port: u16,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
match PROXY_CLIENT
|
||||
match proxy_client()
|
||||
.call(
|
||||
client_ip,
|
||||
format!("http://127.0.0.1:{}", backend_port).as_str(),
|
||||
|
@ -2,6 +2,7 @@ use std::{
|
||||
convert::Infallible,
|
||||
net::{IpAddr, SocketAddr},
|
||||
process::exit,
|
||||
sync::OnceLock,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
@ -20,12 +21,9 @@ use tokiotest_httpserver::take_port;
|
||||
use tungstenite::Message;
|
||||
use url::Url;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = {
|
||||
ReverseProxy::new(
|
||||
hyper::Client::new(),
|
||||
)
|
||||
};
|
||||
fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
|
||||
static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
|
||||
PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
|
||||
}
|
||||
|
||||
struct ProxyTestContext {
|
||||
@ -66,7 +64,7 @@ async fn handle(
|
||||
req: Request<Body>,
|
||||
backend_port: u16,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
match PROXY_CLIENT
|
||||
match proxy_client()
|
||||
.call(
|
||||
client_ip,
|
||||
format!("http://127.0.0.1:{}", backend_port).as_str(),
|
||||
|
Loading…
Reference in New Issue
Block a user