perf: remove headers inline
This commit is contained in:
parent
f9db949910
commit
bf833a765e
126
src/lib.rs
126
src/lib.rs
@ -100,26 +100,29 @@
|
||||
extern crate test;
|
||||
|
||||
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
|
||||
use hyper::header::{HeaderMap, HeaderValue, HeaderName, HOST};
|
||||
use hyper::header::{HeaderName, HeaderValue, HOST};
|
||||
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
||||
use hyper::http::uri::InvalidUri;
|
||||
use hyper::{Body, Client, Error, Request, Response};
|
||||
use hyper::{Body, Client, Error, HeaderMap, Request, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use std::net::IpAddr;
|
||||
|
||||
lazy_static! {
|
||||
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
|
||||
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
|
||||
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
|
||||
// A list of the headers, using hypers actual HeaderName comparison
|
||||
static ref HOP_HEADERS: [HeaderName; 8] = [
|
||||
HeaderName::from_static("connection"),
|
||||
CONNECTION_HEADER.clone(),
|
||||
TE_HEADER.clone(),
|
||||
HeaderName::from_static("keep-alive"),
|
||||
HeaderName::from_static("proxy-authenticate"),
|
||||
HeaderName::from_static("proxy-authorization"),
|
||||
HeaderName::from_static("te"),
|
||||
HeaderName::from_static("trailers"),
|
||||
HeaderName::from_static("trailer"),
|
||||
HeaderName::from_static("transfer-encoding"),
|
||||
HeaderName::from_static("upgrade"),
|
||||
];
|
||||
|
||||
|
||||
static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
|
||||
}
|
||||
|
||||
@ -154,32 +157,54 @@ impl From<InvalidHeaderValue> for ProxyError {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_hop_header(name: &str) -> bool {
|
||||
HOP_HEADERS.iter().any(|h| h == &name)
|
||||
fn remove_hop_headers(headers: &mut HeaderMap) {
|
||||
for header in &*HOP_HEADERS {
|
||||
headers.remove(header);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a clone of the headers without the [hop-by-hop headers].
|
||||
///
|
||||
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
|
||||
let mut result = HeaderMap::new();
|
||||
for (k, v) in headers.iter() {
|
||||
if !is_hop_header(k.as_str()) {
|
||||
result.insert(k.clone(), v.clone());
|
||||
fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
|
||||
if headers
|
||||
.get(&*CONNECTION_HEADER)
|
||||
.map(|value| {
|
||||
value
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.split(",")
|
||||
.any(|e| e.to_lowercase() == "upgrade")
|
||||
})
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
|
||||
return Some(upgrade_value.to_str().unwrap().to_owned());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn remove_connection_headers(headers: &mut HeaderMap) {
|
||||
if headers.get(&*CONNECTION_HEADER).is_some() {
|
||||
let value = headers.get(&*CONNECTION_HEADER).map(|e| e.clone()).unwrap();
|
||||
|
||||
for name in value.to_str().unwrap().split(",") {
|
||||
if !name.trim().is_empty() {
|
||||
headers.remove(name.trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
|
||||
*response.headers_mut() = remove_hop_headers(response.headers());
|
||||
remove_hop_headers(response.headers_mut());
|
||||
remove_connection_headers(response.headers_mut());
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
|
||||
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||
if let Some(query) = req.uri().query() {
|
||||
let mut forwarding_uri = String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1);
|
||||
let mut forwarding_uri =
|
||||
String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1);
|
||||
|
||||
forwarding_uri.push_str(forward_url);
|
||||
forwarding_uri.push_str(req.uri().path());
|
||||
@ -202,15 +227,43 @@ fn create_proxied_request<B>(
|
||||
forward_url: &str,
|
||||
mut request: Request<B>,
|
||||
) -> Result<Request<B>, ProxyError> {
|
||||
let contains_te_trailers_value = request
|
||||
.headers()
|
||||
.get(&*TE_HEADER)
|
||||
.map(|value| {
|
||||
value
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.split(",")
|
||||
.any(|e| e.to_lowercase() == "trailers")
|
||||
})
|
||||
.unwrap_or(false);
|
||||
let upgrade_type = get_upgrade_type(request.headers());
|
||||
|
||||
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
|
||||
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
|
||||
|
||||
*request.uri_mut() = uri;
|
||||
|
||||
*request.headers_mut() = remove_hop_headers(request.headers());
|
||||
remove_hop_headers(request.headers_mut());
|
||||
remove_connection_headers(request.headers_mut());
|
||||
|
||||
if contains_te_trailers_value {
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
|
||||
}
|
||||
|
||||
if let Some(value) = upgrade_type {
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(&*UPGRADE_HEADER, value.parse().unwrap());
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
|
||||
}
|
||||
|
||||
// Add forwarding information in the headers
|
||||
match request.headers_mut().entry(&*X_FORWARDED_FOR) {
|
||||
@ -257,21 +310,19 @@ pub async fn call(
|
||||
Ok(proxied_response)
|
||||
}
|
||||
|
||||
|
||||
#[cfg(all(not(stable), test))]
|
||||
mod tests {
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::prelude::*;
|
||||
use hyper::header::HeaderName;
|
||||
use hyper::Uri;
|
||||
use hyper::{HeaderMap, Request, Response};
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::prelude::*;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::str::FromStr;
|
||||
use test::Bencher;
|
||||
use test_context::AsyncTestContext;
|
||||
use tokiotest_httpserver::HttpTestContext;
|
||||
|
||||
|
||||
fn generate_string() -> String {
|
||||
let take = rand::thread_rng().gen::<u8>().into();
|
||||
rand::thread_rng()
|
||||
@ -343,7 +394,10 @@ mod tests {
|
||||
|
||||
*response.headers_mut().unwrap() = headers_map.clone();
|
||||
|
||||
super::create_proxied_response(response.body(()).unwrap(), hyper::header::HeaderValue::from_static("me"));
|
||||
super::create_proxied_response(
|
||||
response.body(()).unwrap(),
|
||||
hyper::header::HeaderValue::from_static("me"),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@ -405,7 +459,6 @@ mod tests {
|
||||
let port = rand::thread_rng().gen::<u8>();
|
||||
let forward_url = &format!("http://0.0.0.0:{}", port);
|
||||
|
||||
|
||||
let mut headers_map = build_headers();
|
||||
|
||||
headers_map.insert(
|
||||
@ -420,12 +473,8 @@ mod tests {
|
||||
|
||||
*request.headers_mut().unwrap() = headers_map.clone();
|
||||
|
||||
super::create_proxied_request(
|
||||
client_ip,
|
||||
forward_url,
|
||||
request.body(()).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
@ -444,13 +493,8 @@ mod tests {
|
||||
|
||||
*request.headers_mut().unwrap() = headers_map.clone();
|
||||
|
||||
super::create_proxied_request(
|
||||
client_ip,
|
||||
forward_url,
|
||||
request.body(()).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,50 +1,69 @@
|
||||
use hyper::server::conn::AddrStream;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
|
||||
use std::convert::Infallible;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use test_context::test_context;
|
||||
use test_context::AsyncTestContext;
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::task::JoinHandle;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::server::conn::AddrStream;
|
||||
use std::convert::Infallible;
|
||||
use hyper::{Uri, Client, Request, Body, Server, Response, StatusCode};
|
||||
use tokiotest_httpserver::{HttpTestContext, take_port};
|
||||
use test_context::AsyncTestContext;
|
||||
use test_context::test_context;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use tokiotest_httpserver::handler::HandlerBuilder;
|
||||
use tokiotest_httpserver::{take_port, HttpTestContext};
|
||||
|
||||
struct ProxyTestContext {
|
||||
sender: Sender<()>,
|
||||
proxy_handler: JoinHandle<Result<(), hyper::Error>>,
|
||||
http_back: HttpTestContext,
|
||||
port: u16
|
||||
port: u16,
|
||||
}
|
||||
|
||||
#[test_context(ProxyTestContext)]
|
||||
#[tokio::test]
|
||||
async fn test_get_error_500(ctx: &mut ProxyTestContext) {
|
||||
let resp = Client::new().get(ctx.uri("/500")).await.unwrap();
|
||||
let client = Client::new();
|
||||
let resp = client
|
||||
.request(
|
||||
Request::builder()
|
||||
.header("keep-alive", "treu")
|
||||
.method("GET")
|
||||
.uri(ctx.uri("/500"))
|
||||
.body(Body::from(""))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(500, resp.status());
|
||||
}
|
||||
|
||||
#[test_context(ProxyTestContext)]
|
||||
#[tokio::test]
|
||||
async fn test_get(ctx: &mut ProxyTestContext) {
|
||||
ctx.http_back.add(HandlerBuilder::new("/foo").status_code(StatusCode::OK).build());
|
||||
ctx.http_back.add(
|
||||
HandlerBuilder::new("/foo")
|
||||
.status_code(StatusCode::OK)
|
||||
.build(),
|
||||
);
|
||||
let resp = Client::new().get(ctx.uri("/foo")).await.unwrap();
|
||||
assert_eq!(200, resp.status());
|
||||
}
|
||||
|
||||
async fn handle(client_ip: IpAddr, req: Request<Body>, backend_port: u16) -> Result<Response<Body>, Infallible> {
|
||||
match hyper_reverse_proxy::call(client_ip,
|
||||
format!("http://127.0.0.1:{}", backend_port).as_str(),
|
||||
req).await {
|
||||
Ok(response) => {Ok(response)}
|
||||
Err(_) => {Ok(Response::builder()
|
||||
.status(502)
|
||||
.body(Body::empty())
|
||||
.unwrap())}
|
||||
async fn handle(
|
||||
client_ip: IpAddr,
|
||||
req: Request<Body>,
|
||||
backend_port: u16,
|
||||
) -> Result<Response<Body>, Infallible> {
|
||||
match hyper_reverse_proxy::call(
|
||||
client_ip,
|
||||
format!("http://127.0.0.1:{}", backend_port).as_str(),
|
||||
req,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncTestContext for ProxyTestContext {
|
||||
async fn setup() -> ProxyTestContext {
|
||||
@ -60,13 +79,17 @@ impl AsyncTestContext for ProxyTestContext {
|
||||
});
|
||||
let port = take_port();
|
||||
let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port);
|
||||
let server = Server::bind(&addr).serve(make_svc).with_graceful_shutdown(async { receiver.await.ok(); });
|
||||
let server = Server::bind(&addr)
|
||||
.serve(make_svc)
|
||||
.with_graceful_shutdown(async {
|
||||
receiver.await.ok();
|
||||
});
|
||||
let proxy_handler = tokio::spawn(server);
|
||||
ProxyTestContext {
|
||||
sender,
|
||||
proxy_handler,
|
||||
http_back,
|
||||
port
|
||||
port,
|
||||
}
|
||||
}
|
||||
async fn teardown(self) {
|
||||
@ -77,6 +100,8 @@ impl AsyncTestContext for ProxyTestContext {
|
||||
}
|
||||
impl ProxyTestContext {
|
||||
pub fn uri(&self, path: &str) -> Uri {
|
||||
format!("http://{}:{}{}", "localhost", self.port, path).parse::<Uri>().unwrap()
|
||||
format!("http://{}:{}{}", "localhost", self.port, path)
|
||||
.parse::<Uri>()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user