feat: static client (#27)

* feat: static client

* feat: client as input
This commit is contained in:
Christof Weickhardt 2022-05-01 20:36:09 +02:00 committed by GitHub
parent 537484122d
commit 96a398de85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 236 additions and 126 deletions

View File

@ -18,7 +18,7 @@ jobs:
- uses: actions-rs/toolchain@v1 - uses: actions-rs/toolchain@v1
with: with:
profile: minimal profile: minimal
toolchain: stable toolchain: nightly
override: true override: true
- uses: actions-rs/cargo@v1 - uses: actions-rs/cargo@v1
with: with:
@ -71,7 +71,7 @@ jobs:
- uses: actions-rs/toolchain@v1 - uses: actions-rs/toolchain@v1
with: with:
toolchain: stable toolchain: nightly
override: true override: true
- name: Test Dev - name: Test Dev

View File

@ -1,7 +1,11 @@
[package] [package]
name = "hyper-reverse-proxy" name = "hyper-reverse-proxy"
version = "0.5.2-dev" version = "0.5.2-dev"
authors = ["Brendan Zabarauskas <bjzaba@yahoo.com.au>", "Felipe Noronha <felipenoris@gmail.com>", "Jan Kantert <jan-hyper-reverse-proxy@kantert.net>"] authors = [
"Brendan Zabarauskas <bjzaba@yahoo.com.au>",
"Felipe Noronha <felipenoris@gmail.com>",
"Jan Kantert <jan-hyper-reverse-proxy@kantert.net>",
]
license = "Apache-2.0" license = "Apache-2.0"
description = "A simple reverse proxy, to be used with Hyper and Tokio." description = "A simple reverse proxy, to be used with Hyper and Tokio."
homepage = "https://github.com/felipenoris/hyper-reverse-proxy" homepage = "https://github.com/felipenoris/hyper-reverse-proxy"
@ -12,35 +16,25 @@ categories = ["network-programming", "web-programming"]
readme = "README.md" readme = "README.md"
edition = "2018" edition = "2018"
include = [ include = ["Cargo.toml", "LICENSE", "src/**/*"]
"Cargo.toml",
"LICENSE",
"src/**/*"
]
[dependencies] [dependencies]
hyper = { version = "0.14.18", features = ["full"] } hyper = { version = "0.14.18", features = ["client"] }
hyper-trust-dns = { version = "0.4.2", optional = true, default-features = false, features = ["rustls-webpki", "rustls-http1"] }
lazy_static = "1.4.0" lazy_static = "1.4.0"
rand = "0.8.5"
tracing = "0.1.34" tracing = "0.1.34"
[dev-dependencies] [dev-dependencies]
hyper = { version = "0.14.18", features = ["server"] }
tokio = { version = "1.17.0", features = ["full"] } tokio = { version = "1.17.0", features = ["full"] }
futures = "0.3.21" futures = "0.3.21"
async-trait = "0.1.53" async-trait = "0.1.53"
tokio-test = "0.4.2" tokio-test = "0.4.2"
test-context = "0.1.3" test-context = "0.1.3"
tokiotest-httpserver = "0.2.1" tokiotest-httpserver = "0.2.1"
hyper-trust-dns = { version = "0.4.2", features = [
[features] "rustls-http2",
default = ["https"] "dnssec-ring",
"dns-over-https-rustls",
https = ["hyper-trust-dns", "dnssec", "hyper-trust-dns/rustls-webpki", "http2"] "rustls-webpki"
doh = ["hyper-trust-dns/dns-over-https-rustls"] ] }
dot = ["hyper-trust-dns/dns-over-rustls"] rand = "0.8.5"
dnssec = ["hyper-trust-dns/dnssec-ring"]
http2 = ["hyper/http2", "hyper-trust-dns/rustls-http2"]
https-only = ["hyper-trust-dns/https-only"]
tls-1-2 = ["hyper-trust-dns/rustls-tls-12"]
native-cert-store = ["hyper-trust-dns/rustls-native"]

View File

@ -32,9 +32,17 @@ Add these dependencies to your `Cargo.toml` file.
```toml ```toml
[dependencies] [dependencies]
hyper-reverse-proxy = "0.5" hyper-reverse-proxy = "?"
hyper = { version = "0.14", features = ["full"] } hyper = { version = "?", features = ["full"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "?", features = ["full"] }
lazy_static = "?"
hyper-trust-dns = { version = "?", features = [
"rustls-http2",
"dnssec-ring",
"dns-over-https-rustls",
"rustls-webpki",
"https-only"
] }
``` ```
The following example will set up a reverse proxy listening on `127.0.0.1:13900`, The following example will set up a reverse proxy listening on `127.0.0.1:13900`,
@ -46,52 +54,65 @@ and will proxy these calls:
* All other URLs will be handled by `debug_request` function, that will display request information. * All other URLs will be handled by `debug_request` function, that will display request information.
```rust,no_run ```rust
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server, StatusCode}; use hyper::{Body, Request, Response, Server, StatusCode};
use hyper::service::{service_fn, make_service_fn}; use hyper_reverse_proxy::ReverseProxy;
use std::{convert::Infallible, net::SocketAddr}; use hyper_trust_dns::{RustlsHttpsConnector, TrustDnsResolver};
use std::net::IpAddr; use std::net::IpAddr;
use std::{convert::Infallible, net::SocketAddr};
fn debug_request(req: Request<Body>) -> Result<Response<Body>, Infallible> { lazy_static::lazy_static! {
static ref PROXY_CLIENT: ReverseProxy<RustlsHttpsConnector> = {
ReverseProxy::new(
hyper::Client::builder().build::<_, hyper::Body>(TrustDnsResolver::default().into_rustls_webpki_https_connector()),
)
};
}
fn debug_request(req: &Request<Body>) -> Result<Response<Body>, Infallible> {
let body_str = format!("{:?}", req); let body_str = format!("{:?}", req);
Ok(Response::new(Body::from(body_str))) Ok(Response::new(Body::from(body_str)))
} }
async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> { async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
if req.uri().path().starts_with("/target/first") { if req.uri().path().starts_with("/target/first") {
// will forward requests to port 13901 match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13901", req)
match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13901", req).await { .await
Ok(response) => {Ok(response)} {
Err(_error) => {Ok(Response::builder() Ok(response) => {
.status(StatusCode::INTERNAL_SERVER_ERROR) Ok(response)
.body(Body::empty()) },
.unwrap())} Err(_error) => {
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap())},
} }
} else if req.uri().path().starts_with("/target/second") { } else if req.uri().path().starts_with("/target/second") {
// will forward requests to port 13902 match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13902", req)
match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13902", req).await { .await
Ok(response) => {Ok(response)} {
Err(_error) => {Ok(Response::builder() Ok(response) => Ok(response),
.status(StatusCode::INTERNAL_SERVER_ERROR) Err(_error) => Ok(Response::builder()
.body(Body::empty()) .status(StatusCode::INTERNAL_SERVER_ERROR)
.unwrap())} .body(Body::empty())
.unwrap()),
} }
} else { } else {
debug_request(req) debug_request(&req)
} }
} }
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let bind_addr = "127.0.0.1:8000"; let bind_addr = "127.0.0.1:8000";
let addr:SocketAddr = bind_addr.parse().expect("Could not parse ip:port."); let addr: SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
let make_svc = make_service_fn(|conn: &AddrStream| { let make_svc = make_service_fn(|conn: &AddrStream| {
let remote_addr = conn.remote_addr().ip(); let remote_addr = conn.remote_addr().ip();
async move { async move { Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req))) }
Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req)))
}
}); });
let server = Server::bind(&addr).serve(make_svc); let server = Server::bind(&addr).serve(make_svc);
@ -104,33 +125,41 @@ async fn main() {
} }
``` ```
### Security ### A word about Security
Handling outgoing requests can be a security nightmare. This crate includes some features to reduce some of the risks. Everthing uses `rustls` benieth, a rust implementation for tls, faster and more secure as `openssl` Handling outgoing requests can be a security nightmare. This crate does not control the client for the outgoing requests, as it needs to be supplied to the proxy call. The following chapters may give you an overview on how you can secure your client using the `hyper-trust-dns` crate.
> You can see them being used in the example.
#### HTTPS #### HTTPS
By default the `https` feature is enabled which will allow you to request resources over https. This does not limit to only `https` traffic, if you would like so add the feature `https-only` to your `Cargo.toml` for this crate. You should use a secure transport in order to know who you are talking to and so you can trust the connection. By default `hyper-trust-dns` enables the feature flag `https-only` which will panic if you supply a transport scheme which isn't `https`. It is a healthy default as it's not only you needing to trust the source but also everyone else seeing the content on unsecure connections.
> ATTENTION: if you are running on a host with added certificates in your cert store, make sure to audit them in a interval, so neither old certificates nor malicious certificates are considered as valid by your client.
#### TLS 1.2 #### TLS 1.2
By default `tls 1.2` is disabled in favor of `tls 1.3`. As not yet all services support it `tls 1.2` can be enabled via the `tls-1-2` feature. By default `tls 1.2` is disabled in favor of `tls 1.3`, because many parts of `tls 1.2` can be considered as attach friendly. As not yet all services support it `tls 1.2` can be enabled via the `rustls-tls-12` feature.
> ATTENTION: make sure to audit the services you connect to on an interval
#### DNSSEC #### DNSSEC
By default if you enable `https` (which is enabled by default) `dnssec` is enabled. As dns queries and entries aren't "trustworthy" by default from a security standpoint. `DNSSEC` adds a new cryptographic layer for verification. To enable it use the `dnssec-ring` feature.
#### HTTP/2 #### HTTP/2
While `http/3` might be just around the corner. `http/2` support can be enabled using the `http2` feature. By default only rustlss `http1` feature is enabled for dns queries. While `http/3` might be just around the corner. `http/2` support can be enabled using the `rustls-http2` feature.
#### DoT & DoH #### DoT & DoH
DoT and DoH provide you with a secure transport between you and your dns.
By default none of them are enabled. If you would like to enabled them, you can do so using the features `doh` and `dot`. By default none of them are enabled. If you would like to enabled them, you can do so using the features `doh` and `dot`.
Recommendations: Recommendations:
- If you need to monitor network activities in relation to accessed ports, use `dot` - If you need to monitor network activities in relation to accessed ports, use dot with the `dns-over-rustls` feature flag
- If you are out in the wild and have no need to monitor based on ports, use `doh` as it will blend in with other `https` traffic - If you are out in the wild and have no need to monitor based on ports, doh with the `dns-over-https-rustls` feature flag as it will blend in with other `https` traffic
It is highly recommended to use one of them. It is highly recommended to use one of them.

67
examples/simple.rs Normal file
View File

@ -0,0 +1,67 @@
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server, StatusCode};
use hyper_reverse_proxy::ReverseProxy;
use hyper_trust_dns::{RustlsHttpsConnector, TrustDnsResolver};
use std::net::IpAddr;
use std::{convert::Infallible, net::SocketAddr};
lazy_static::lazy_static! {
static ref PROXY_CLIENT: ReverseProxy<RustlsHttpsConnector> = {
ReverseProxy::new(
hyper::Client::builder().build::<_, hyper::Body>(TrustDnsResolver::default().into_rustls_webpki_https_connector()),
)
};
}
fn debug_request(req: &Request<Body>) -> Result<Response<Body>, Infallible> {
let body_str = format!("{:?}", req);
Ok(Response::new(Body::from(body_str)))
}
async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
if req.uri().path().starts_with("/target/first") {
match PROXY_CLIENT
.call(client_ip, "http://127.0.0.1:13901", req)
.await
{
Ok(response) => Ok(response),
Err(_error) => Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()),
}
} else if req.uri().path().starts_with("/target/second") {
match PROXY_CLIENT
.call(client_ip, "http://127.0.0.1:13902", req)
.await
{
Ok(response) => Ok(response),
Err(_error) => Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()),
}
} else {
debug_request(&req)
}
}
#[tokio::main]
async fn main() {
let bind_addr = "127.0.0.1:8000";
let addr: SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
let make_svc = make_service_fn(|conn: &AddrStream| {
let remote_addr = conn.remote_addr().ip();
async move { Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req))) }
});
let server = Server::bind(&addr).serve(make_svc);
println!("Running server on {:?}", addr);
if let Err(e) = server.await {
eprintln!("server error: {}", e);
}
}

View File

@ -39,50 +39,63 @@
//! //!
//! ```rust,no_run //! ```rust,no_run
//! use hyper::server::conn::AddrStream; //! use hyper::server::conn::AddrStream;
//! use hyper::service::{make_service_fn, service_fn};
//! use hyper::{Body, Request, Response, Server, StatusCode}; //! use hyper::{Body, Request, Response, Server, StatusCode};
//! use hyper::service::{service_fn, make_service_fn}; //! use hyper_reverse_proxy::ReverseProxy;
//! use std::{convert::Infallible, net::SocketAddr}; //! use hyper_trust_dns::{RustlsHttpsConnector, TrustDnsResolver};
//! use std::net::IpAddr; //! use std::net::IpAddr;
//! use std::{convert::Infallible, net::SocketAddr};
//! //!
//! fn debug_request(req: Request<Body>) -> Result<Response<Body>, Infallible> { //! lazy_static::lazy_static! {
//! static ref PROXY_CLIENT: ReverseProxy<RustlsHttpsConnector> = {
//! ReverseProxy::new(
//! hyper::Client::builder().build::<_, hyper::Body>(TrustDnsResolver::default().into_rustls_webpki_https_connector()),
//! )
//! };
//! }
//!
//! fn debug_request(req: &Request<Body>) -> Result<Response<Body>, Infallible> {
//! let body_str = format!("{:?}", req); //! let body_str = format!("{:?}", req);
//! Ok(Response::new(Body::from(body_str))) //! Ok(Response::new(Body::from(body_str)))
//! } //! }
//! //!
//! async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> { //! async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
//! if req.uri().path().starts_with("/target/first") { //! if req.uri().path().starts_with("/target/first") {
//! // will forward requests to port 13901 //! match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13901", req)
//! match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13901", req).await { //! .await
//! Ok(response) => {Ok(response)} //! {
//! Err(_error) => {Ok(Response::builder() //! Ok(response) => {
//! .status(StatusCode::INTERNAL_SERVER_ERROR) //! Ok(response)
//! .body(Body::empty()) //! },
//! .unwrap())} //! Err(_error) => {
//! Ok(Response::builder()
//! .status(StatusCode::INTERNAL_SERVER_ERROR)
//! .body(Body::empty())
//! .unwrap())},
//! } //! }
//! } else if req.uri().path().starts_with("/target/second") { //! } else if req.uri().path().starts_with("/target/second") {
//! // will forward requests to port 13902 //! match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13902", req)
//! match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13902", req).await { //! .await
//! Ok(response) => {Ok(response)} //! {
//! Err(_error) => {Ok(Response::builder() //! Ok(response) => Ok(response),
//! .status(StatusCode::INTERNAL_SERVER_ERROR) //! Err(_error) => Ok(Response::builder()
//! .body(Body::empty()) //! .status(StatusCode::INTERNAL_SERVER_ERROR)
//! .unwrap())} //! .body(Body::empty())
//! .unwrap()),
//! } //! }
//! } else { //! } else {
//! debug_request(req) //! debug_request(&req)
//! } //! }
//! } //! }
//! //!
//! #[tokio::main] //! #[tokio::main]
//! async fn main() { //! async fn main() {
//! let bind_addr = "127.0.0.1:8000"; //! let bind_addr = "127.0.0.1:8000";
//! let addr:SocketAddr = bind_addr.parse().expect("Could not parse ip:port."); //! let addr: SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
//! //!
//! let make_svc = make_service_fn(|conn: &AddrStream| { //! let make_svc = make_service_fn(|conn: &AddrStream| {
//! let remote_addr = conn.remote_addr().ip(); //! let remote_addr = conn.remote_addr().ip();
//! async move { //! async move { Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req))) }
//! Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req)))
//! }
//! }); //! });
//! //!
//! let server = Server::bind(&addr).serve(make_svc); //! let server = Server::bind(&addr).serve(make_svc);
@ -93,6 +106,7 @@
//! eprintln!("server error: {}", e); //! eprintln!("server error: {}", e);
//! } //! }
//! } //! }
//!
//! ``` //! ```
#![cfg_attr(all(not(stable), test), feature(test))] #![cfg_attr(all(not(stable), test), feature(test))]
@ -102,12 +116,6 @@ extern crate tracing;
#[cfg(all(not(stable), test))] #[cfg(all(not(stable), test))]
extern crate test; extern crate test;
#[cfg(feature = "https")]
use hyper_trust_dns::TrustDnsResolver;
#[cfg(not(feature = "https"))]
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST}; use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST};
use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::header::{InvalidHeaderValue, ToStrError};
use hyper::http::uri::InvalidUri; use hyper::http::uri::InvalidUri;
@ -234,7 +242,7 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
if base_url.ends_with('/') { if base_url.ends_with('/') {
let mut path1_chars = base_url.chars(); let mut path1_chars = base_url.chars();
path1_chars.next(); path1_chars.next_back();
base_url = path1_chars.as_str(); base_url = path1_chars.as_str();
} }
@ -264,18 +272,10 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
} else { } else {
debug!("Merging request and forward_url query"); debug!("Merging request and forward_url query");
let request_query_items = req let request_query_items = req.uri().query().unwrap_or("").split('&').map(|el| {
.uri() let parts = el.split('=').collect::<Vec<&str>>();
.query() (parts[0], if parts.len() > 1 { parts[1] } else { "" })
.unwrap_or("") });
.split('&')
.collect::<Vec<&str>>()
.iter()
.map(|el| {
let parts = el.split('=').collect::<Vec<&str>>();
(parts[0], if parts.len() > 1 { parts[1] } else { "" })
})
.collect::<Vec<(&str, &str)>>();
let forward_query_items = forward_url_query let forward_query_items = forward_url_query
.split('&') .split('&')
@ -285,8 +285,9 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for (key, value) in request_query_items.iter() { for (key, value) in request_query_items {
if !forward_query_items.contains(key) { if !forward_query_items.iter().any(|e| e == &key) {
url.push('&');
url.push_str(key); url.push_str(key);
url.push('='); url.push('=');
url.push_str(value); url.push_str(value);
@ -384,26 +385,11 @@ fn create_proxied_request<B>(
Ok(request) Ok(request)
} }
#[cfg(feature = "https")] pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + 'static>(
fn build_client() -> Client<hyper_trust_dns::RustlsHttpsConnector, hyper::Body> {
#[cfg(feature = "native-cert-store")]
let https = TrustDnsResolver::default().into_rustls_native_https_connector();
#[cfg(not(feature = "native-cert-store"))]
let https = TrustDnsResolver::default().into_rustls_webpki_https_connector();
Client::builder().build::<_, hyper::Body>(https)
}
#[cfg(not(feature = "https"))]
fn build_client() -> Client<HttpConnector<GaiResolver>, hyper::Body> {
Client::new()
}
pub async fn call(
client_ip: IpAddr, client_ip: IpAddr,
forward_uri: &str, forward_uri: &str,
request: Request<Body>, request: Request<Body>,
client: &'a Client<T>,
) -> Result<Response<Body>, ProxyError> { ) -> Result<Response<Body>, ProxyError> {
info!( info!(
"Received proxy call from {} to {}, client: {}", "Received proxy call from {} to {}, client: {}",
@ -414,7 +400,6 @@ pub async fn call(
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
let client = build_client();
let response = client.request(proxied_request).await?; let response = client.request(proxied_request).await?;
let proxied_response = create_proxied_response(response); let proxied_response = create_proxied_response(response);
@ -422,10 +407,29 @@ pub async fn call(
Ok(proxied_response) Ok(proxied_response)
} }
pub struct ReverseProxy<T: hyper::client::connect::Connect + Clone + Send + Sync + 'static> {
client: Client<T>,
}
impl<T: hyper::client::connect::Connect + Clone + Send + Sync + 'static> ReverseProxy<T> {
pub fn new(client: Client<T>) -> Self {
Self { client }
}
pub async fn call(
&self,
client_ip: IpAddr,
forward_uri: &str,
request: Request<Body>,
) -> Result<Response<Body>, ProxyError> {
call::<T>(client_ip, forward_uri, request, &self.client).await
}
}
#[cfg(all(not(stable), test))] #[cfg(all(not(stable), test))]
mod tests { mod tests {
use hyper::header::HeaderName; use hyper::header::HeaderName;
use hyper::Uri; use hyper::{Client, Uri};
use hyper::{HeaderMap, Request, Response}; use hyper::{HeaderMap, Request, Response};
use rand::distributions::Alphanumeric; use rand::distributions::Alphanumeric;
use rand::prelude::*; use rand::prelude::*;
@ -480,6 +484,8 @@ mod tests {
let client_ip = std::net::IpAddr::from(Ipv4Addr::from_str("0.0.0.0").unwrap()); let client_ip = std::net::IpAddr::from(Ipv4Addr::from_str("0.0.0.0").unwrap());
let client = Client::new();
b.iter(|| { b.iter(|| {
rt.block_on(async { rt.block_on(async {
let mut request = Request::builder().uri(uri.clone()); let mut request = Request::builder().uri(uri.clone());
@ -490,6 +496,7 @@ mod tests {
client_ip, client_ip,
forward_url, forward_url,
request.body(hyper::Body::from("")).unwrap(), request.body(hyper::Body::from("")).unwrap(),
&client,
) )
.await .await
.unwrap(); .unwrap();

View File

@ -1,6 +1,9 @@
use hyper::client::connect::dns::GaiResolver;
use hyper::client::HttpConnector;
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn}; use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
use hyper_reverse_proxy::ReverseProxy;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use test_context::test_context; use test_context::test_context;
@ -10,6 +13,14 @@ use tokio::task::JoinHandle;
use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::handler::HandlerBuilder;
use tokiotest_httpserver::{take_port, HttpTestContext}; use tokiotest_httpserver::{take_port, HttpTestContext};
lazy_static::lazy_static! {
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = {
ReverseProxy::new(
hyper::Client::new(),
)
};
}
struct ProxyTestContext { struct ProxyTestContext {
sender: Sender<()>, sender: Sender<()>,
proxy_handler: JoinHandle<Result<(), hyper::Error>>, proxy_handler: JoinHandle<Result<(), hyper::Error>>,
@ -52,12 +63,13 @@ async fn handle(
req: Request<Body>, req: Request<Body>,
backend_port: u16, backend_port: u16,
) -> Result<Response<Body>, Infallible> { ) -> Result<Response<Body>, Infallible> {
match hyper_reverse_proxy::call( match PROXY_CLIENT
client_ip, .call(
format!("http://127.0.0.1:{}", backend_port).as_str(), client_ip,
req, format!("http://127.0.0.1:{}", backend_port).as_str(),
) req,
.await )
.await
{ {
Ok(response) => Ok(response), Ok(response) => Ok(response),
Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()), Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()),
@ -65,11 +77,12 @@ async fn handle(
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl AsyncTestContext for ProxyTestContext { impl<'a> AsyncTestContext for ProxyTestContext {
async fn setup() -> ProxyTestContext { async fn setup() -> ProxyTestContext {
let http_back: HttpTestContext = AsyncTestContext::setup().await; let http_back: HttpTestContext = AsyncTestContext::setup().await;
let (sender, receiver) = tokio::sync::oneshot::channel::<()>(); let (sender, receiver) = tokio::sync::oneshot::channel::<()>();
let bp_to_move = http_back.port; let bp_to_move = http_back.port;
let make_svc = make_service_fn(move |conn: &AddrStream| { let make_svc = make_service_fn(move |conn: &AddrStream| {
let remote_addr = conn.remote_addr().ip(); let remote_addr = conn.remote_addr().ip();
let back_port = bp_to_move; let back_port = bp_to_move;