diff --git a/Cargo.toml b/Cargo.toml index 8372830..56b3f45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,6 @@ harness = false http-body-util = "0.1.0" hyper = { version = "1.2.0", features = ["client", "http1"] } hyper-util = { version = "0.1.3", features = ["client-legacy", "http1","tokio"] } -lazy_static = "1.4.0" tokio = { version = "1.17.0", features = ["io-util", "rt"] } tracing = "0.1.34" diff --git a/benches/internal.rs b/benches/internal.rs index fcd25df..4fed2be 100644 --- a/benches/internal.rs +++ b/benches/internal.rs @@ -10,16 +10,14 @@ use rand::distributions::Alphanumeric; use rand::prelude::*; use std::net::Ipv4Addr; use std::str::FromStr; +use std::sync::OnceLock; use test_context::AsyncTestContext; use tokio::runtime::Runtime; use tokiotest_httpserver::HttpTestContext; -lazy_static::lazy_static! { - static ref PROXY_CLIENT: ReverseProxy> = { - ReverseProxy::new( - hyper::Client::new(), - ) - }; +fn proxy_client() -> &'static ReverseProxy> { + static PROXY_CLIENT: OnceLock>> = OnceLock::new(); + PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new())) } fn create_proxied_response(b: &mut Criterion) { @@ -46,7 +44,7 @@ fn generate_string() -> String { } fn build_headers() -> HeaderMap { - let mut headers_map: HeaderMap = (&*internal_benches::hop_headers()) + let mut headers_map: HeaderMap = (internal_benches::hop_headers()) .iter() .map(|el: &'static HeaderName| (el.clone(), generate_string().parse().unwrap())) .collect(); @@ -86,7 +84,7 @@ fn proxy_call(b: &mut Criterion) { *request.headers_mut().unwrap() = headers_map.clone(); - black_box(&PROXY_CLIENT) + black_box(&proxy_client()) .call( black_box(client_ip), black_box(forward_url), diff --git a/examples/simple.rs b/examples/simple.rs index 33fe9ad..6d44eed 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,6 +1,7 @@ use std::convert::Infallible; use std::io; use std::net::{IpAddr, SocketAddr}; +use std::sync::OnceLock; use std::time::Duration; use http_body_util::combinators::UnsyncBoxBody; @@ -19,25 +20,26 @@ use hyper_util::client::legacy::connect::HttpConnector; type Connector = HttpsConnector; type ResponseBody = UnsyncBoxBody; -lazy_static::lazy_static! { - static ref PROXY_CLIENT: ReverseProxy = { - let connector: Connector = Connector::builder() - .with_tls_config( - rustls::ClientConfig::builder() - .with_native_roots() - .expect("with_native_roots") - .with_no_client_auth(), - ) - .https_or_http() - .enable_http1() - .build(); - ReverseProxy::new( - hyper_util::client::legacy::Builder::new(TokioExecutor::new()) - .pool_idle_timeout(Duration::from_secs(3)) - .pool_timer(TokioTimer::new()) - .build::<_, Incoming>(connector), - ) - }; +fn proxy_client() -> &'static ReverseProxy { + static PROXY_CLIENT: OnceLock> = OnceLock::new(); + PROXY_CLIENT.get_or_init(|| { + let connector: Connector = Connector::builder() + .with_tls_config( + rustls::ClientConfig::builder() + .with_native_roots() + .expect("with_native_roots") + .with_no_client_auth(), + ) + .https_or_http() + .enable_http1() + .build(); + ReverseProxy::new( + hyper_util::client::legacy::Builder::new(TokioExecutor::new()) + .pool_idle_timeout(Duration::from_secs(3)) + .pool_timer(TokioTimer::new()) + .build::<_, Incoming>(connector), + ) + }) } async fn handle( @@ -46,7 +48,7 @@ async fn handle( ) -> Result, Infallible> { let host = req.headers().get("host").and_then(|v| v.to_str().ok()); if host.is_some_and(|host| host.starts_with("service1.localhost")) { - match PROXY_CLIENT + match proxy_client() .call(client_ip, "http://127.0.0.1:13901", req) .await { @@ -59,7 +61,7 @@ async fn handle( .unwrap()), } } else if host.is_some_and(|host| host.starts_with("service2.localhost")) { - match PROXY_CLIENT + match proxy_client() .call(client_ip, "http://127.0.0.1:13902", req) .await { diff --git a/src/lib.rs b/src/lib.rs index 2859769..19cef39 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,30 +10,55 @@ use hyper::http::uri::InvalidUri; use hyper::{body::Incoming, Error, Request, Response, StatusCode}; use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError}; use hyper_util::rt::tokio::TokioIo; -use lazy_static::lazy_static; use std::net::{IpAddr, SocketAddr}; +use std::sync::OnceLock; use tokio::io::copy_bidirectional; -lazy_static! { - static ref TE_HEADER: HeaderName = HeaderName::from_static("te"); - static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection"); - static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade"); - static ref TRAILER_HEADER: HeaderName = HeaderName::from_static("trailer"); - static ref TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers"); - // A list of the headers, using hypers actual HeaderName comparison - static ref HOP_HEADERS: [HeaderName; 9] = [ - CONNECTION_HEADER.clone(), - TE_HEADER.clone(), - TRAILER_HEADER.clone(), - HeaderName::from_static("keep-alive"), - HeaderName::from_static("proxy-connection"), - HeaderName::from_static("proxy-authenticate"), - HeaderName::from_static("proxy-authorization"), - HeaderName::from_static("transfer-encoding"), - HeaderName::from_static("upgrade"), - ]; +fn te_header() -> &'static HeaderName { + static TE_HEADER: OnceLock = OnceLock::new(); + TE_HEADER.get_or_init(|| HeaderName::from_static("te")) +} - static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); +fn connection_header() -> &'static HeaderName { + static CONNECTION_HEADER: OnceLock = OnceLock::new(); + CONNECTION_HEADER.get_or_init(|| HeaderName::from_static("connection")) +} + +fn upgrade_header() -> &'static HeaderName { + static UPGRADE_HEADER: OnceLock = OnceLock::new(); + UPGRADE_HEADER.get_or_init(|| HeaderName::from_static("upgrade")) +} + +fn trailer_header() -> &'static HeaderName { + static TRAILER_HEADER: OnceLock = OnceLock::new(); + TRAILER_HEADER.get_or_init(|| HeaderName::from_static("trailer")) +} + +fn trailers_header() -> &'static HeaderName { + static TRAILERS_HEADER: OnceLock = OnceLock::new(); + TRAILERS_HEADER.get_or_init(|| HeaderName::from_static("trailers")) +} + +fn x_forwarded_for_header() -> &'static HeaderName { + static X_FORWARDED_FOR: OnceLock = OnceLock::new(); + X_FORWARDED_FOR.get_or_init(|| HeaderName::from_static("x-forwarded-for")) +} + +fn hop_headers() -> &'static [HeaderName; 9] { + static HOP_HEADERS: OnceLock<[HeaderName; 9]> = OnceLock::new(); + HOP_HEADERS.get_or_init(|| { + [ + connection_header().clone(), + te_header().clone(), + trailer_header().clone(), + HeaderName::from_static("keep-alive"), + HeaderName::from_static("proxy-connection"), + HeaderName::from_static("proxy-authenticate"), + HeaderName::from_static("proxy-authorization"), + HeaderName::from_static("transfer-encoding"), + HeaderName::from_static("upgrade"), + ] + }) } #[derive(Debug)] @@ -79,7 +104,7 @@ impl From for ProxyError { fn remove_hop_headers(headers: &mut HeaderMap) { debug!("Removing hop headers"); - for header in &*HOP_HEADERS { + for header in hop_headers() { headers.remove(header); } } @@ -87,17 +112,17 @@ fn remove_hop_headers(headers: &mut HeaderMap) { fn get_upgrade_type(headers: &HeaderMap) -> Option { #[allow(clippy::blocks_in_conditions)] if headers - .get(&*CONNECTION_HEADER) + .get(connection_header()) .map(|value| { value .to_str() .unwrap() .split(',') - .any(|e| e.trim() == *UPGRADE_HEADER) + .any(|e| e.trim() == *upgrade_header()) }) .unwrap_or(false) { - if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) { + if let Some(upgrade_value) = headers.get(upgrade_header()) { debug!( "Found upgrade header with value: {}", upgrade_value.to_str().unwrap().to_owned() @@ -111,10 +136,10 @@ fn get_upgrade_type(headers: &HeaderMap) -> Option { } fn remove_connection_headers(headers: &mut HeaderMap) { - if headers.get(&*CONNECTION_HEADER).is_some() { + if headers.get(connection_header()).is_some() { debug!("Removing connection headers"); - let value = headers.get(&*CONNECTION_HEADER).cloned().unwrap(); + let value = headers.get(connection_header()).cloned().unwrap(); for name in value.to_str().unwrap().split(',') { if !name.trim().is_empty() { @@ -220,13 +245,13 @@ fn create_proxied_request( let contains_te_trailers_value = request .headers() - .get(&*TE_HEADER) + .get(te_header()) .map(|value| { value .to_str() .unwrap() .split(',') - .any(|e| e.trim() == *TRAILERS_HEADER) + .any(|e| e.trim() == *trailers_header()) }) .unwrap_or(false); @@ -240,7 +265,7 @@ fn create_proxied_request( request .headers_mut() - .insert(&*TE_HEADER, HeaderValue::from_static("trailers")); + .insert(te_header(), HeaderValue::from_static("trailers")); } if let Some(value) = upgrade_type { @@ -248,14 +273,14 @@ fn create_proxied_request( request .headers_mut() - .insert(&*UPGRADE_HEADER, value.parse().unwrap()); + .insert(upgrade_header(), value.parse().unwrap()); request .headers_mut() - .insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE")); + .insert(connection_header(), HeaderValue::from_static("UPGRADE")); } // Add forwarding information in the headers - match request.headers_mut().entry(&*X_FORWARDED_FOR) { + match request.headers_mut().entry(x_forwarded_for_header()) { hyper::header::Entry::Vacant(entry) => { debug!("X-Forwarded-for header was vacant"); entry.insert(client_ip.to_string().parse()?); diff --git a/tests/test_http.rs b/tests/test_http.rs index d1bacbf..3df76db 100644 --- a/tests/test_http.rs +++ b/tests/test_http.rs @@ -7,6 +7,7 @@ use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use hyper_reverse_proxy::ReverseProxy; use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; +use std::sync::OnceLock; use test_context::test_context; use test_context::AsyncTestContext; use tokio::sync::oneshot::Sender; @@ -14,12 +15,9 @@ use tokio::task::JoinHandle; use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::{take_port, HttpTestContext}; -lazy_static::lazy_static! { - static ref PROXY_CLIENT: ReverseProxy> = { - ReverseProxy::new( - hyper::Client::new(), - ) - }; +fn proxy_client() -> &'static ReverseProxy> { + static PROXY_CLIENT: OnceLock>> = OnceLock::new(); + PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new())) } struct ProxyTestContext { @@ -99,7 +97,7 @@ async fn handle( req: Request, backend_port: u16, ) -> Result, Infallible> { - match PROXY_CLIENT + match proxy_client() .call( client_ip, format!("http://127.0.0.1:{}", backend_port).as_str(), diff --git a/tests/test_websocket.rs b/tests/test_websocket.rs index afdc597..c403cc9 100644 --- a/tests/test_websocket.rs +++ b/tests/test_websocket.rs @@ -2,6 +2,7 @@ use std::{ convert::Infallible, net::{IpAddr, SocketAddr}, process::exit, + sync::OnceLock, time::Duration, }; @@ -20,12 +21,9 @@ use tokiotest_httpserver::take_port; use tungstenite::Message; use url::Url; -lazy_static::lazy_static! { - static ref PROXY_CLIENT: ReverseProxy> = { - ReverseProxy::new( - hyper::Client::new(), - ) - }; +fn proxy_client() -> &'static ReverseProxy> { + static PROXY_CLIENT: OnceLock>> = OnceLock::new(); + PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new())) } struct ProxyTestContext { @@ -66,7 +64,7 @@ async fn handle( req: Request, backend_port: u16, ) -> Result, Infallible> { - match PROXY_CLIENT + match proxy_client() .call( client_ip, format!("http://127.0.0.1:{}", backend_port).as_str(),