perf: remove headers inline

pull/28/head
Christof Weickhardt 2 years ago committed by Felipe Noronha
parent f9db949910
commit bf833a765e
  1. 126
      src/lib.rs
  2. 73
      tests/test_http.rs

@ -100,26 +100,29 @@
extern crate test;
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
use hyper::header::{HeaderMap, HeaderValue, HeaderName, HOST};
use hyper::header::{HeaderName, HeaderValue, HOST};
use hyper::http::header::{InvalidHeaderValue, ToStrError};
use hyper::http::uri::InvalidUri;
use hyper::{Body, Client, Error, Request, Response};
use hyper::{Body, Client, Error, HeaderMap, Request, Response};
use lazy_static::lazy_static;
use std::net::IpAddr;
lazy_static! {
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
// A list of the headers, using hypers actual HeaderName comparison
static ref HOP_HEADERS: [HeaderName; 8] = [
HeaderName::from_static("connection"),
CONNECTION_HEADER.clone(),
TE_HEADER.clone(),
HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("te"),
HeaderName::from_static("trailers"),
HeaderName::from_static("trailer"),
HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"),
];
static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
}
@ -154,32 +157,54 @@ impl From<InvalidHeaderValue> for ProxyError {
}
}
fn is_hop_header(name: &str) -> bool {
HOP_HEADERS.iter().any(|h| h == &name)
fn remove_hop_headers(headers: &mut HeaderMap) {
for header in &*HOP_HEADERS {
headers.remove(header);
}
}
fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
if headers
.get(&*CONNECTION_HEADER)
.map(|value| {
value
.to_str()
.unwrap()
.split(",")
.any(|e| e.to_lowercase() == "upgrade")
})
.unwrap_or(false)
{
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
return Some(upgrade_value.to_str().unwrap().to_owned());
}
}
None
}
/// Returns a clone of the headers without the [hop-by-hop headers].
///
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
let mut result = HeaderMap::new();
for (k, v) in headers.iter() {
if !is_hop_header(k.as_str()) {
result.insert(k.clone(), v.clone());
fn remove_connection_headers(headers: &mut HeaderMap) {
if headers.get(&*CONNECTION_HEADER).is_some() {
let value = headers.get(&*CONNECTION_HEADER).map(|e| e.clone()).unwrap();
for name in value.to_str().unwrap().split(",") {
if !name.trim().is_empty() {
headers.remove(name.trim());
}
}
}
result
}
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
*response.headers_mut() = remove_hop_headers(response.headers());
remove_hop_headers(response.headers_mut());
remove_connection_headers(response.headers_mut());
response
}
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
if let Some(query) = req.uri().query() {
let mut forwarding_uri = String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1);
let mut forwarding_uri =
String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1);
forwarding_uri.push_str(forward_url);
forwarding_uri.push_str(req.uri().path());
@ -202,15 +227,43 @@ fn create_proxied_request<B>(
forward_url: &str,
mut request: Request<B>,
) -> Result<Request<B>, ProxyError> {
let contains_te_trailers_value = request
.headers()
.get(&*TE_HEADER)
.map(|value| {
value
.to_str()
.unwrap()
.split(",")
.any(|e| e.to_lowercase() == "trailers")
})
.unwrap_or(false);
let upgrade_type = get_upgrade_type(request.headers());
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
request
.headers_mut()
.insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
*request.uri_mut() = uri;
*request.headers_mut() = remove_hop_headers(request.headers());
remove_hop_headers(request.headers_mut());
remove_connection_headers(request.headers_mut());
if contains_te_trailers_value {
request
.headers_mut()
.insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
}
if let Some(value) = upgrade_type {
request
.headers_mut()
.insert(&*UPGRADE_HEADER, value.parse().unwrap());
request
.headers_mut()
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
}
// Add forwarding information in the headers
match request.headers_mut().entry(&*X_FORWARDED_FOR) {
@ -257,21 +310,19 @@ pub async fn call(
Ok(proxied_response)
}
#[cfg(all(not(stable), test))]
mod tests {
use rand::distributions::Alphanumeric;
use rand::prelude::*;
use hyper::header::HeaderName;
use hyper::Uri;
use hyper::{HeaderMap, Request, Response};
use rand::distributions::Alphanumeric;
use rand::prelude::*;
use std::net::Ipv4Addr;
use std::str::FromStr;
use test::Bencher;
use test_context::AsyncTestContext;
use tokiotest_httpserver::HttpTestContext;
fn generate_string() -> String {
let take = rand::thread_rng().gen::<u8>().into();
rand::thread_rng()
@ -343,7 +394,10 @@ mod tests {
*response.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_response(response.body(()).unwrap(), hyper::header::HeaderValue::from_static("me"));
super::create_proxied_response(
response.body(()).unwrap(),
hyper::header::HeaderValue::from_static("me"),
);
});
}
@ -405,7 +459,6 @@ mod tests {
let port = rand::thread_rng().gen::<u8>();
let forward_url = &format!("http://0.0.0.0:{}", port);
let mut headers_map = build_headers();
headers_map.insert(
@ -420,12 +473,8 @@ mod tests {
*request.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_request(
client_ip,
forward_url,
request.body(()).unwrap(),
)
.unwrap();
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
.unwrap();
});
}
@ -444,13 +493,8 @@ mod tests {
*request.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_request(
client_ip,
forward_url,
request.body(()).unwrap(),
)
.unwrap();
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
.unwrap();
});
}
}

@ -1,50 +1,69 @@
use tokio::sync::oneshot::Sender;
use tokio::task::JoinHandle;
use hyper::service::{make_service_fn, service_fn};
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
use std::convert::Infallible;
use hyper::{Uri, Client, Request, Body, Server, Response, StatusCode};
use tokiotest_httpserver::{HttpTestContext, take_port};
use test_context::AsyncTestContext;
use test_context::test_context;
use std::net::{IpAddr, SocketAddr};
use test_context::test_context;
use test_context::AsyncTestContext;
use tokio::sync::oneshot::Sender;
use tokio::task::JoinHandle;
use tokiotest_httpserver::handler::HandlerBuilder;
use tokiotest_httpserver::{take_port, HttpTestContext};
struct ProxyTestContext {
sender: Sender<()>,
proxy_handler: JoinHandle<Result<(), hyper::Error>>,
http_back: HttpTestContext,
port: u16
port: u16,
}
#[test_context(ProxyTestContext)]
#[tokio::test]
async fn test_get_error_500(ctx: &mut ProxyTestContext) {
let resp = Client::new().get(ctx.uri("/500")).await.unwrap();
let client = Client::new();
let resp = client
.request(
Request::builder()
.header("keep-alive", "treu")
.method("GET")
.uri(ctx.uri("/500"))
.body(Body::from(""))
.unwrap(),
)
.await
.unwrap();
assert_eq!(500, resp.status());
}
#[test_context(ProxyTestContext)]
#[tokio::test]
async fn test_get(ctx: &mut ProxyTestContext) {
ctx.http_back.add(HandlerBuilder::new("/foo").status_code(StatusCode::OK).build());
ctx.http_back.add(
HandlerBuilder::new("/foo")
.status_code(StatusCode::OK)
.build(),
);
let resp = Client::new().get(ctx.uri("/foo")).await.unwrap();
assert_eq!(200, resp.status());
}
async fn handle(client_ip: IpAddr, req: Request<Body>, backend_port: u16) -> Result<Response<Body>, Infallible> {
match hyper_reverse_proxy::call(client_ip,
format!("http://127.0.0.1:{}", backend_port).as_str(),
req).await {
Ok(response) => {Ok(response)}
Err(_) => {Ok(Response::builder()
.status(502)
.body(Body::empty())
.unwrap())}
async fn handle(
client_ip: IpAddr,
req: Request<Body>,
backend_port: u16,
) -> Result<Response<Body>, Infallible> {
match hyper_reverse_proxy::call(
client_ip,
format!("http://127.0.0.1:{}", backend_port).as_str(),
req,
)
.await
{
Ok(response) => Ok(response),
Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()),
}
}
#[async_trait::async_trait]
impl AsyncTestContext for ProxyTestContext {
async fn setup() -> ProxyTestContext {
@ -60,13 +79,17 @@ impl AsyncTestContext for ProxyTestContext {
});
let port = take_port();
let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port);
let server = Server::bind(&addr).serve(make_svc).with_graceful_shutdown(async { receiver.await.ok(); });
let server = Server::bind(&addr)
.serve(make_svc)
.with_graceful_shutdown(async {
receiver.await.ok();
});
let proxy_handler = tokio::spawn(server);
ProxyTestContext {
sender,
proxy_handler,
http_back,
port
port,
}
}
async fn teardown(self) {
@ -77,6 +100,8 @@ impl AsyncTestContext for ProxyTestContext {
}
impl ProxyTestContext {
pub fn uri(&self, path: &str) -> Uri {
format!("http://{}:{}{}", "localhost", self.port, path).parse::<Uri>().unwrap()
format!("http://{}:{}{}", "localhost", self.port, path)
.parse::<Uri>()
.unwrap()
}
}
}

Loading…
Cancel
Save