perf: remove headers inline
This commit is contained in:
parent
f9db949910
commit
bf833a765e
120
src/lib.rs
120
src/lib.rs
@ -100,22 +100,25 @@
|
|||||||
extern crate test;
|
extern crate test;
|
||||||
|
|
||||||
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
|
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
|
||||||
use hyper::header::{HeaderMap, HeaderValue, HeaderName, HOST};
|
use hyper::header::{HeaderName, HeaderValue, HOST};
|
||||||
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
||||||
use hyper::http::uri::InvalidUri;
|
use hyper::http::uri::InvalidUri;
|
||||||
use hyper::{Body, Client, Error, Request, Response};
|
use hyper::{Body, Client, Error, HeaderMap, Request, Response};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
|
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
|
||||||
|
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
|
||||||
|
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
|
||||||
// A list of the headers, using hypers actual HeaderName comparison
|
// A list of the headers, using hypers actual HeaderName comparison
|
||||||
static ref HOP_HEADERS: [HeaderName; 8] = [
|
static ref HOP_HEADERS: [HeaderName; 8] = [
|
||||||
HeaderName::from_static("connection"),
|
CONNECTION_HEADER.clone(),
|
||||||
|
TE_HEADER.clone(),
|
||||||
HeaderName::from_static("keep-alive"),
|
HeaderName::from_static("keep-alive"),
|
||||||
HeaderName::from_static("proxy-authenticate"),
|
HeaderName::from_static("proxy-authenticate"),
|
||||||
HeaderName::from_static("proxy-authorization"),
|
HeaderName::from_static("proxy-authorization"),
|
||||||
HeaderName::from_static("te"),
|
HeaderName::from_static("trailer"),
|
||||||
HeaderName::from_static("trailers"),
|
|
||||||
HeaderName::from_static("transfer-encoding"),
|
HeaderName::from_static("transfer-encoding"),
|
||||||
HeaderName::from_static("upgrade"),
|
HeaderName::from_static("upgrade"),
|
||||||
];
|
];
|
||||||
@ -154,32 +157,54 @@ impl From<InvalidHeaderValue> for ProxyError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_hop_header(name: &str) -> bool {
|
fn remove_hop_headers(headers: &mut HeaderMap) {
|
||||||
HOP_HEADERS.iter().any(|h| h == &name)
|
for header in &*HOP_HEADERS {
|
||||||
|
headers.remove(header);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a clone of the headers without the [hop-by-hop headers].
|
fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
|
||||||
///
|
if headers
|
||||||
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
.get(&*CONNECTION_HEADER)
|
||||||
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
|
.map(|value| {
|
||||||
let mut result = HeaderMap::new();
|
value
|
||||||
for (k, v) in headers.iter() {
|
.to_str()
|
||||||
if !is_hop_header(k.as_str()) {
|
.unwrap()
|
||||||
result.insert(k.clone(), v.clone());
|
.split(",")
|
||||||
|
.any(|e| e.to_lowercase() == "upgrade")
|
||||||
|
})
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
|
||||||
|
return Some(upgrade_value.to_str().unwrap().to_owned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_connection_headers(headers: &mut HeaderMap) {
|
||||||
|
if headers.get(&*CONNECTION_HEADER).is_some() {
|
||||||
|
let value = headers.get(&*CONNECTION_HEADER).map(|e| e.clone()).unwrap();
|
||||||
|
|
||||||
|
for name in value.to_str().unwrap().split(",") {
|
||||||
|
if !name.trim().is_empty() {
|
||||||
|
headers.remove(name.trim());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
|
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
|
||||||
*response.headers_mut() = remove_hop_headers(response.headers());
|
remove_hop_headers(response.headers_mut());
|
||||||
|
remove_connection_headers(response.headers_mut());
|
||||||
|
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
|
||||||
if let Some(query) = req.uri().query() {
|
if let Some(query) = req.uri().query() {
|
||||||
let mut forwarding_uri = String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1);
|
let mut forwarding_uri =
|
||||||
|
String::with_capacity(forward_url.len() + req.uri().path().len() + query.len() + 1);
|
||||||
|
|
||||||
forwarding_uri.push_str(forward_url);
|
forwarding_uri.push_str(forward_url);
|
||||||
forwarding_uri.push_str(req.uri().path());
|
forwarding_uri.push_str(req.uri().path());
|
||||||
@ -202,15 +227,43 @@ fn create_proxied_request<B>(
|
|||||||
forward_url: &str,
|
forward_url: &str,
|
||||||
mut request: Request<B>,
|
mut request: Request<B>,
|
||||||
) -> Result<Request<B>, ProxyError> {
|
) -> Result<Request<B>, ProxyError> {
|
||||||
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
|
let contains_te_trailers_value = request
|
||||||
|
.headers()
|
||||||
|
.get(&*TE_HEADER)
|
||||||
|
.map(|value| {
|
||||||
|
value
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
.split(",")
|
||||||
|
.any(|e| e.to_lowercase() == "trailers")
|
||||||
|
})
|
||||||
|
.unwrap_or(false);
|
||||||
|
let upgrade_type = get_upgrade_type(request.headers());
|
||||||
|
|
||||||
|
let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
|
||||||
request
|
request
|
||||||
.headers_mut()
|
.headers_mut()
|
||||||
.insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
|
.insert(HOST, HeaderValue::from_str(uri.host().unwrap())?);
|
||||||
|
|
||||||
*request.uri_mut() = uri;
|
*request.uri_mut() = uri;
|
||||||
|
|
||||||
*request.headers_mut() = remove_hop_headers(request.headers());
|
remove_hop_headers(request.headers_mut());
|
||||||
|
remove_connection_headers(request.headers_mut());
|
||||||
|
|
||||||
|
if contains_te_trailers_value {
|
||||||
|
request
|
||||||
|
.headers_mut()
|
||||||
|
.insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(value) = upgrade_type {
|
||||||
|
request
|
||||||
|
.headers_mut()
|
||||||
|
.insert(&*UPGRADE_HEADER, value.parse().unwrap());
|
||||||
|
request
|
||||||
|
.headers_mut()
|
||||||
|
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
|
||||||
|
}
|
||||||
|
|
||||||
// Add forwarding information in the headers
|
// Add forwarding information in the headers
|
||||||
match request.headers_mut().entry(&*X_FORWARDED_FOR) {
|
match request.headers_mut().entry(&*X_FORWARDED_FOR) {
|
||||||
@ -257,21 +310,19 @@ pub async fn call(
|
|||||||
Ok(proxied_response)
|
Ok(proxied_response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(all(not(stable), test))]
|
#[cfg(all(not(stable), test))]
|
||||||
mod tests {
|
mod tests {
|
||||||
use rand::distributions::Alphanumeric;
|
|
||||||
use rand::prelude::*;
|
|
||||||
use hyper::header::HeaderName;
|
use hyper::header::HeaderName;
|
||||||
use hyper::Uri;
|
use hyper::Uri;
|
||||||
use hyper::{HeaderMap, Request, Response};
|
use hyper::{HeaderMap, Request, Response};
|
||||||
|
use rand::distributions::Alphanumeric;
|
||||||
|
use rand::prelude::*;
|
||||||
use std::net::Ipv4Addr;
|
use std::net::Ipv4Addr;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use test::Bencher;
|
use test::Bencher;
|
||||||
use test_context::AsyncTestContext;
|
use test_context::AsyncTestContext;
|
||||||
use tokiotest_httpserver::HttpTestContext;
|
use tokiotest_httpserver::HttpTestContext;
|
||||||
|
|
||||||
|
|
||||||
fn generate_string() -> String {
|
fn generate_string() -> String {
|
||||||
let take = rand::thread_rng().gen::<u8>().into();
|
let take = rand::thread_rng().gen::<u8>().into();
|
||||||
rand::thread_rng()
|
rand::thread_rng()
|
||||||
@ -343,7 +394,10 @@ mod tests {
|
|||||||
|
|
||||||
*response.headers_mut().unwrap() = headers_map.clone();
|
*response.headers_mut().unwrap() = headers_map.clone();
|
||||||
|
|
||||||
super::create_proxied_response(response.body(()).unwrap(), hyper::header::HeaderValue::from_static("me"));
|
super::create_proxied_response(
|
||||||
|
response.body(()).unwrap(),
|
||||||
|
hyper::header::HeaderValue::from_static("me"),
|
||||||
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,7 +459,6 @@ mod tests {
|
|||||||
let port = rand::thread_rng().gen::<u8>();
|
let port = rand::thread_rng().gen::<u8>();
|
||||||
let forward_url = &format!("http://0.0.0.0:{}", port);
|
let forward_url = &format!("http://0.0.0.0:{}", port);
|
||||||
|
|
||||||
|
|
||||||
let mut headers_map = build_headers();
|
let mut headers_map = build_headers();
|
||||||
|
|
||||||
headers_map.insert(
|
headers_map.insert(
|
||||||
@ -420,11 +473,7 @@ mod tests {
|
|||||||
|
|
||||||
*request.headers_mut().unwrap() = headers_map.clone();
|
*request.headers_mut().unwrap() = headers_map.clone();
|
||||||
|
|
||||||
super::create_proxied_request(
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
|
||||||
client_ip,
|
|
||||||
forward_url,
|
|
||||||
request.body(()).unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -444,13 +493,8 @@ mod tests {
|
|||||||
|
|
||||||
*request.headers_mut().unwrap() = headers_map.clone();
|
*request.headers_mut().unwrap() = headers_map.clone();
|
||||||
|
|
||||||
super::create_proxied_request(
|
super::create_proxied_request(client_ip, forward_url, request.body(()).unwrap())
|
||||||
client_ip,
|
|
||||||
forward_url,
|
|
||||||
request.body(()).unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,50 +1,69 @@
|
|||||||
|
use hyper::server::conn::AddrStream;
|
||||||
|
use hyper::service::{make_service_fn, service_fn};
|
||||||
|
use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
|
||||||
|
use std::convert::Infallible;
|
||||||
|
use std::net::{IpAddr, SocketAddr};
|
||||||
|
use test_context::test_context;
|
||||||
|
use test_context::AsyncTestContext;
|
||||||
use tokio::sync::oneshot::Sender;
|
use tokio::sync::oneshot::Sender;
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
use hyper::service::{make_service_fn, service_fn};
|
|
||||||
use hyper::server::conn::AddrStream;
|
|
||||||
use std::convert::Infallible;
|
|
||||||
use hyper::{Uri, Client, Request, Body, Server, Response, StatusCode};
|
|
||||||
use tokiotest_httpserver::{HttpTestContext, take_port};
|
|
||||||
use test_context::AsyncTestContext;
|
|
||||||
use test_context::test_context;
|
|
||||||
use std::net::{IpAddr, SocketAddr};
|
|
||||||
use tokiotest_httpserver::handler::HandlerBuilder;
|
use tokiotest_httpserver::handler::HandlerBuilder;
|
||||||
|
use tokiotest_httpserver::{take_port, HttpTestContext};
|
||||||
|
|
||||||
struct ProxyTestContext {
|
struct ProxyTestContext {
|
||||||
sender: Sender<()>,
|
sender: Sender<()>,
|
||||||
proxy_handler: JoinHandle<Result<(), hyper::Error>>,
|
proxy_handler: JoinHandle<Result<(), hyper::Error>>,
|
||||||
http_back: HttpTestContext,
|
http_back: HttpTestContext,
|
||||||
port: u16
|
port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test_context(ProxyTestContext)]
|
#[test_context(ProxyTestContext)]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_error_500(ctx: &mut ProxyTestContext) {
|
async fn test_get_error_500(ctx: &mut ProxyTestContext) {
|
||||||
let resp = Client::new().get(ctx.uri("/500")).await.unwrap();
|
let client = Client::new();
|
||||||
|
let resp = client
|
||||||
|
.request(
|
||||||
|
Request::builder()
|
||||||
|
.header("keep-alive", "treu")
|
||||||
|
.method("GET")
|
||||||
|
.uri(ctx.uri("/500"))
|
||||||
|
.body(Body::from(""))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert_eq!(500, resp.status());
|
assert_eq!(500, resp.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test_context(ProxyTestContext)]
|
#[test_context(ProxyTestContext)]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get(ctx: &mut ProxyTestContext) {
|
async fn test_get(ctx: &mut ProxyTestContext) {
|
||||||
ctx.http_back.add(HandlerBuilder::new("/foo").status_code(StatusCode::OK).build());
|
ctx.http_back.add(
|
||||||
|
HandlerBuilder::new("/foo")
|
||||||
|
.status_code(StatusCode::OK)
|
||||||
|
.build(),
|
||||||
|
);
|
||||||
let resp = Client::new().get(ctx.uri("/foo")).await.unwrap();
|
let resp = Client::new().get(ctx.uri("/foo")).await.unwrap();
|
||||||
assert_eq!(200, resp.status());
|
assert_eq!(200, resp.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle(client_ip: IpAddr, req: Request<Body>, backend_port: u16) -> Result<Response<Body>, Infallible> {
|
async fn handle(
|
||||||
match hyper_reverse_proxy::call(client_ip,
|
client_ip: IpAddr,
|
||||||
|
req: Request<Body>,
|
||||||
|
backend_port: u16,
|
||||||
|
) -> Result<Response<Body>, Infallible> {
|
||||||
|
match hyper_reverse_proxy::call(
|
||||||
|
client_ip,
|
||||||
format!("http://127.0.0.1:{}", backend_port).as_str(),
|
format!("http://127.0.0.1:{}", backend_port).as_str(),
|
||||||
req).await {
|
req,
|
||||||
Ok(response) => {Ok(response)}
|
)
|
||||||
Err(_) => {Ok(Response::builder()
|
.await
|
||||||
.status(502)
|
{
|
||||||
.body(Body::empty())
|
Ok(response) => Ok(response),
|
||||||
.unwrap())}
|
Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl AsyncTestContext for ProxyTestContext {
|
impl AsyncTestContext for ProxyTestContext {
|
||||||
async fn setup() -> ProxyTestContext {
|
async fn setup() -> ProxyTestContext {
|
||||||
@ -60,13 +79,17 @@ impl AsyncTestContext for ProxyTestContext {
|
|||||||
});
|
});
|
||||||
let port = take_port();
|
let port = take_port();
|
||||||
let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port);
|
let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port);
|
||||||
let server = Server::bind(&addr).serve(make_svc).with_graceful_shutdown(async { receiver.await.ok(); });
|
let server = Server::bind(&addr)
|
||||||
|
.serve(make_svc)
|
||||||
|
.with_graceful_shutdown(async {
|
||||||
|
receiver.await.ok();
|
||||||
|
});
|
||||||
let proxy_handler = tokio::spawn(server);
|
let proxy_handler = tokio::spawn(server);
|
||||||
ProxyTestContext {
|
ProxyTestContext {
|
||||||
sender,
|
sender,
|
||||||
proxy_handler,
|
proxy_handler,
|
||||||
http_back,
|
http_back,
|
||||||
port
|
port,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn teardown(self) {
|
async fn teardown(self) {
|
||||||
@ -77,6 +100,8 @@ impl AsyncTestContext for ProxyTestContext {
|
|||||||
}
|
}
|
||||||
impl ProxyTestContext {
|
impl ProxyTestContext {
|
||||||
pub fn uri(&self, path: &str) -> Uri {
|
pub fn uri(&self, path: &str) -> Uri {
|
||||||
format!("http://{}:{}{}", "localhost", self.port, path).parse::<Uri>().unwrap()
|
format!("http://{}:{}{}", "localhost", self.port, path)
|
||||||
|
.parse::<Uri>()
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user