perf: remove headers inline

This commit is contained in:
Christof Weickhardt 2022-04-12 14:41:51 +00:00 committed by Felipe Noronha
parent f9db949910
commit bf833a765e
2 changed files with 135 additions and 66 deletions

View File

@ -100,22 +100,25 @@
extern crate test; extern crate test;
use hyper::client::{connect::dns::GaiResolver, HttpConnector}; use hyper::client::{connect::dns::GaiResolver, HttpConnector};
use hyper::header::{HeaderMap, HeaderValue, HeaderName, HOST}; use hyper::header::{HeaderName, HeaderValue, HOST};
use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::header::{InvalidHeaderValue, ToStrError};
use hyper::http::uri::InvalidUri; use hyper::http::uri::InvalidUri;
use hyper::{Body, Client, Error, Request, Response}; use hyper::{Body, Client, Error, HeaderMap, Request, Response};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use std::net::IpAddr; use std::net::IpAddr;
lazy_static! { lazy_static! {
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
// A list of the headers, using hypers actual HeaderName comparison // A list of the headers, using hypers actual HeaderName comparison
static ref HOP_HEADERS: [HeaderName; 8] = [ static ref HOP_HEADERS: [HeaderName; 8] = [
HeaderName::from_static("connection"), CONNECTION_HEADER.clone(),
TE_HEADER.clone(),
HeaderName::from_static("keep-alive"), HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-authenticate"), HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"), HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("te"), HeaderName::from_static("trailer"),
HeaderName::from_static("trailers"),
HeaderName::from_static("transfer-encoding"), HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"), HeaderName::from_static("upgrade"),
]; ];
@ -154,32 +157,54 @@ impl From<InvalidHeaderValue> for ProxyError {
} }
} }
fn is_hop_header(name: &str) -> bool { fn remove_hop_headers(headers: &mut HeaderMap) {
HOP_HEADERS.iter().any(|h| h == &name) for header in &*HOP_HEADERS {
headers.remove(header);
}
} }
/// Returns a clone of the headers without the [hop-by-hop headers]. fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
/// if headers
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html .get(&*CONNECTION_HEADER)
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> { .map(|value| {
let mut result = HeaderMap::new(); value
for (k, v) in headers.iter() { .to_str()
if !is_hop_header(k.as_str()) { .unwrap()
result.insert(k.clone(), v.clone()); .split(",")
.any(|e| e.to_lowercase() == "upgrade")
})
.unwrap_or(false)
{
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
return Some(upgrade_value.to_str().unwrap().to_owned());
}
}
None
}
fn remove_connection_headers(headers: &mut HeaderMap) {
if headers.get(&*CONNECTION_HEADER).is_some() {
let value = headers.get(&*CONNECTION_HEADER).map(|e| e.clone()).unwrap();
for name in value.to_str().unwrap().split(",") {
if !name.trim().is_empty() {
headers.remove(name.trim());
}
} }
} }
result
} }
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> { fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
*response.headers_mut() = remove_hop_headers(response.headers()); remove_hop_headers(response.headers_mut());
remove_connection_headers(response.headers_mut());
response response
} }
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String { fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
if let Some(query) = req.uri().query() { if let Some(query) = req.uri().query() {
let mut forwarding_uri = String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1); let mut forwarding_uri =
String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1);
forwarding_uri.push_str(forward_url); forwarding_uri.push_str(forward_url);
forwarding_uri.push_str(req.uri().path()); forwarding_uri.push_str(req.uri().path());
@ -202,15 +227,43 @@ fn create_proxied_request<B>(
forward_url: &str, forward_url: &str,
mut request: Request<B>, mut request: Request<B>,
) -> Result<Request<B>, ProxyError> { ) -> Result<Request<B>, ProxyError> {
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?; let contains_te_trailers_value = request
.headers()
.get(&*TE_HEADER)
.map(|value| {
value
.to_str()
.unwrap()
.split(",")
.any(|e| e.to_lowercase() == "trailers")
})
.unwrap_or(false);
let upgrade_type = get_upgrade_type(request.headers());
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
request request
.headers_mut() .headers_mut()
.insert(HOST, HeaderValue::from_str(uri.host().unwrap())?); .insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
*request.uri_mut() = uri; *request.uri_mut() = uri;
*request.headers_mut() = remove_hop_headers(request.headers()); remove_hop_headers(request.headers_mut());
remove_connection_headers(request.headers_mut());
if contains_te_trailers_value {
request
.headers_mut()
.insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
}
if let Some(value) = upgrade_type {
request
.headers_mut()
.insert(&*UPGRADE_HEADER, value.parse().unwrap());
request
.headers_mut()
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
}
// Add forwarding information in the headers // Add forwarding information in the headers
match request.headers_mut().entry(&*X_FORWARDED_FOR) { match request.headers_mut().entry(&*X_FORWARDED_FOR) {
@ -257,21 +310,19 @@ pub async fn call(
Ok(proxied_response) Ok(proxied_response)
} }
#[cfg(all(not(stable), test))] #[cfg(all(not(stable), test))]
mod tests { mod tests {
use rand::distributions::Alphanumeric;
use rand::prelude::*;
use hyper::header::HeaderName; use hyper::header::HeaderName;
use hyper::Uri; use hyper::Uri;
use hyper::{HeaderMap, Request, Response}; use hyper::{HeaderMap, Request, Response};
use rand::distributions::Alphanumeric;
use rand::prelude::*;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use std::str::FromStr; use std::str::FromStr;
use test::Bencher; use test::Bencher;
use test_context::AsyncTestContext; use test_context::AsyncTestContext;
use tokiotest_httpserver::HttpTestContext; use tokiotest_httpserver::HttpTestContext;
fn generate_string() -> String { fn generate_string() -> String {
let take = rand::thread_rng().gen::<u8>().into(); let take = rand::thread_rng().gen::<u8>().into();
rand::thread_rng() rand::thread_rng()
@ -343,7 +394,10 @@ mod tests {
*response.headers_mut().unwrap() = headers_map.clone(); *response.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_response(response.body(()).unwrap(), hyper::header::HeaderValue::from_static("me")); super::create_proxied_response(
response.body(()).unwrap(),
hyper::header::HeaderValue::from_static("me"),
);
}); });
} }
@ -405,7 +459,6 @@ mod tests {
let port = rand::thread_rng().gen::<u8>(); let port = rand::thread_rng().gen::<u8>();
let forward_url = &format!("http://0.0.0.0:{}", port); let forward_url = &format!("http://0.0.0.0:{}", port);
let mut headers_map = build_headers(); let mut headers_map = build_headers();
headers_map.insert( headers_map.insert(
@ -420,12 +473,8 @@ mod tests {
*request.headers_mut().unwrap() = headers_map.clone(); *request.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_request( super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
client_ip, .unwrap();
forward_url,
request.body(()).unwrap(),
)
.unwrap();
}); });
} }
@ -444,13 +493,8 @@ mod tests {
*request.headers_mut().unwrap() = headers_map.clone(); *request.headers_mut().unwrap() = headers_map.clone();
super::create_proxied_request( super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
client_ip, .unwrap();
forward_url,
request.body(()).unwrap(),
)
.unwrap();
}); });
} }
} }

View File

@ -1,50 +1,69 @@
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use test_context::test_context;
use test_context::AsyncTestContext;
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use hyper::service::{make_service_fn, service_fn};
use hyper::server::conn::AddrStream;
use std::convert::Infallible;
use hyper::{Uri, Client, Request, Body, Server, Response, StatusCode};
use tokiotest_httpserver::{HttpTestContext, take_port};
use test_context::AsyncTestContext;
use test_context::test_context;
use std::net::{IpAddr, SocketAddr};
use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::handler::HandlerBuilder;
use tokiotest_httpserver::{take_port, HttpTestContext};
struct ProxyTestContext { struct ProxyTestContext {
sender: Sender<()>, sender: Sender<()>,
proxy_handler: JoinHandle<Result<(), hyper::Error>>, proxy_handler: JoinHandle<Result<(), hyper::Error>>,
http_back: HttpTestContext, http_back: HttpTestContext,
port: u16 port: u16,
} }
#[test_context(ProxyTestContext)] #[test_context(ProxyTestContext)]
#[tokio::test] #[tokio::test]
async fn test_get_error_500(ctx: &mut ProxyTestContext) { async fn test_get_error_500(ctx: &mut ProxyTestContext) {
let resp = Client::new().get(ctx.uri("/500")).await.unwrap(); let client = Client::new();
let resp = client
.request(
Request::builder()
.header("keep-alive", "treu")
.method("GET")
.uri(ctx.uri("/500"))
.body(Body::from(""))
.unwrap(),
)
.await
.unwrap();
assert_eq!(500, resp.status()); assert_eq!(500, resp.status());
} }
#[test_context(ProxyTestContext)] #[test_context(ProxyTestContext)]
#[tokio::test] #[tokio::test]
async fn test_get(ctx: &mut ProxyTestContext) { async fn test_get(ctx: &mut ProxyTestContext) {
ctx.http_back.add(HandlerBuilder::new("/foo").status_code(StatusCode::OK).build()); ctx.http_back.add(
HandlerBuilder::new("/foo")
.status_code(StatusCode::OK)
.build(),
);
let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); let resp = Client::new().get(ctx.uri("/foo")).await.unwrap();
assert_eq!(200, resp.status()); assert_eq!(200, resp.status());
} }
async fn handle(client_ip: IpAddr, req: Request<Body>, backend_port: u16) -> Result<Response<Body>, Infallible> { async fn handle(
match hyper_reverse_proxy::call(client_ip, client_ip: IpAddr,
format!("http://127.0.0.1:{}", backend_port).as_str(), req: Request<Body>,
req).await { backend_port: u16,
Ok(response) => {Ok(response)} ) -> Result<Response<Body>, Infallible> {
Err(_) => {Ok(Response::builder() match hyper_reverse_proxy::call(
.status(502) client_ip,
.body(Body::empty()) format!("http://127.0.0.1:{}", backend_port).as_str(),
.unwrap())} req,
)
.await
{
Ok(response) => Ok(response),
Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()),
} }
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl AsyncTestContext for ProxyTestContext { impl AsyncTestContext for ProxyTestContext {
async fn setup() -> ProxyTestContext { async fn setup() -> ProxyTestContext {
@ -60,13 +79,17 @@ impl AsyncTestContext for ProxyTestContext {
}); });
let port = take_port(); let port = take_port();
let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port); let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port);
let server = Server::bind(&addr).serve(make_svc).with_graceful_shutdown(async { receiver.await.ok(); }); let server = Server::bind(&addr)
.serve(make_svc)
.with_graceful_shutdown(async {
receiver.await.ok();
});
let proxy_handler = tokio::spawn(server); let proxy_handler = tokio::spawn(server);
ProxyTestContext { ProxyTestContext {
sender, sender,
proxy_handler, proxy_handler,
http_back, http_back,
port port,
} }
} }
async fn teardown(self) { async fn teardown(self) {
@ -77,6 +100,8 @@ impl AsyncTestContext for ProxyTestContext {
} }
impl ProxyTestContext { impl ProxyTestContext {
pub fn uri(&self, path: &str) -> Uri { pub fn uri(&self, path: &str) -> Uri {
format!("http://{}:{}{}", "localhost", self.port, path).parse::<Uri>().unwrap() format!("http://{}:{}{}", "localhost", self.port, path)
.parse::<Uri>()
.unwrap()
} }
} }