Replace lazy_static with OnceLock.

This commit is contained in:
Stefan Sundin 2024-05-19 11:00:33 -07:00
parent 2ec415ecac
commit 695f9639ef
6 changed files with 96 additions and 76 deletions

View File

@ -26,7 +26,6 @@ harness = false
http-body-util = "0.1.0" http-body-util = "0.1.0"
hyper = { version = "1.2.0", features = ["client", "http1"] } hyper = { version = "1.2.0", features = ["client", "http1"] }
hyper-util = { version = "0.1.3", features = ["client-legacy", "http1","tokio"] } hyper-util = { version = "0.1.3", features = ["client-legacy", "http1","tokio"] }
lazy_static = "1.4.0"
tokio = { version = "1.17.0", features = ["io-util", "rt"] } tokio = { version = "1.17.0", features = ["io-util", "rt"] }
tracing = "0.1.34" tracing = "0.1.34"

View File

@ -10,16 +10,14 @@ use rand::distributions::Alphanumeric;
use rand::prelude::*; use rand::prelude::*;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use std::str::FromStr; use std::str::FromStr;
use std::sync::OnceLock;
use test_context::AsyncTestContext; use test_context::AsyncTestContext;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use tokiotest_httpserver::HttpTestContext; use tokiotest_httpserver::HttpTestContext;
lazy_static::lazy_static! { fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = { static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
ReverseProxy::new( PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
hyper::Client::new(),
)
};
} }
fn create_proxied_response(b: &mut Criterion) { fn create_proxied_response(b: &mut Criterion) {
@ -46,7 +44,7 @@ fn generate_string() -> String {
} }
fn build_headers() -> HeaderMap { fn build_headers() -> HeaderMap {
let mut headers_map: HeaderMap = (&*internal_benches::hop_headers()) let mut headers_map: HeaderMap = (internal_benches::hop_headers())
.iter() .iter()
.map(|el: &'static HeaderName| (el.clone(), generate_string().parse().unwrap())) .map(|el: &'static HeaderName| (el.clone(), generate_string().parse().unwrap()))
.collect(); .collect();
@ -86,7 +84,7 @@ fn proxy_call(b: &mut Criterion) {
*request.headers_mut().unwrap() = headers_map.clone(); *request.headers_mut().unwrap() = headers_map.clone();
black_box(&PROXY_CLIENT) black_box(&proxy_client())
.call( .call(
black_box(client_ip), black_box(client_ip),
black_box(forward_url), black_box(forward_url),

View File

@ -1,6 +1,7 @@
use std::convert::Infallible; use std::convert::Infallible;
use std::io; use std::io;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::OnceLock;
use std::time::Duration; use std::time::Duration;
use http_body_util::combinators::UnsyncBoxBody; use http_body_util::combinators::UnsyncBoxBody;
@ -19,25 +20,26 @@ use hyper_util::client::legacy::connect::HttpConnector;
type Connector = HttpsConnector<HttpConnector>; type Connector = HttpsConnector<HttpConnector>;
type ResponseBody = UnsyncBoxBody<Bytes, std::io::Error>; type ResponseBody = UnsyncBoxBody<Bytes, std::io::Error>;
lazy_static::lazy_static! { fn proxy_client() -> &'static ReverseProxy<Connector> {
static ref PROXY_CLIENT: ReverseProxy<Connector> = { static PROXY_CLIENT: OnceLock<ReverseProxy<Connector>> = OnceLock::new();
let connector: Connector = Connector::builder() PROXY_CLIENT.get_or_init(|| {
.with_tls_config( let connector: Connector = Connector::builder()
rustls::ClientConfig::builder() .with_tls_config(
.with_native_roots() rustls::ClientConfig::builder()
.expect("with_native_roots") .with_native_roots()
.with_no_client_auth(), .expect("with_native_roots")
) .with_no_client_auth(),
.https_or_http() )
.enable_http1() .https_or_http()
.build(); .enable_http1()
ReverseProxy::new( .build();
hyper_util::client::legacy::Builder::new(TokioExecutor::new()) ReverseProxy::new(
.pool_idle_timeout(Duration::from_secs(3)) hyper_util::client::legacy::Builder::new(TokioExecutor::new())
.pool_timer(TokioTimer::new()) .pool_idle_timeout(Duration::from_secs(3))
.build::<_, Incoming>(connector), .pool_timer(TokioTimer::new())
) .build::<_, Incoming>(connector),
}; )
})
} }
async fn handle( async fn handle(
@ -46,7 +48,7 @@ async fn handle(
) -> Result<Response<ResponseBody>, Infallible> { ) -> Result<Response<ResponseBody>, Infallible> {
let host = req.headers().get("host").and_then(|v| v.to_str().ok()); let host = req.headers().get("host").and_then(|v| v.to_str().ok());
if host.is_some_and(|host| host.starts_with("service1.localhost")) { if host.is_some_and(|host| host.starts_with("service1.localhost")) {
match PROXY_CLIENT match proxy_client()
.call(client_ip, "http://127.0.0.1:13901", req) .call(client_ip, "http://127.0.0.1:13901", req)
.await .await
{ {
@ -59,7 +61,7 @@ async fn handle(
.unwrap()), .unwrap()),
} }
} else if host.is_some_and(|host| host.starts_with("service2.localhost")) { } else if host.is_some_and(|host| host.starts_with("service2.localhost")) {
match PROXY_CLIENT match proxy_client()
.call(client_ip, "http://127.0.0.1:13902", req) .call(client_ip, "http://127.0.0.1:13902", req)
.await .await
{ {

View File

@ -10,30 +10,55 @@ use hyper::http::uri::InvalidUri;
use hyper::{body::Incoming, Error, Request, Response, StatusCode}; use hyper::{body::Incoming, Error, Request, Response, StatusCode};
use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError}; use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError};
use hyper_util::rt::tokio::TokioIo; use hyper_util::rt::tokio::TokioIo;
use lazy_static::lazy_static;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::OnceLock;
use tokio::io::copy_bidirectional; use tokio::io::copy_bidirectional;
lazy_static! { fn te_header() -> &'static HeaderName {
static ref TE_HEADER: HeaderName = HeaderName::from_static("te"); static TE_HEADER: OnceLock<HeaderName> = OnceLock::new();
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection"); TE_HEADER.get_or_init(|| HeaderName::from_static("te"))
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade"); }
static ref TRAILER_HEADER: HeaderName = HeaderName::from_static("trailer");
static ref TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers");
// A list of the headers, using hypers actual HeaderName comparison
static ref HOP_HEADERS: [HeaderName; 9] = [
CONNECTION_HEADER.clone(),
TE_HEADER.clone(),
TRAILER_HEADER.clone(),
HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-connection"),
HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"),
];
static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); fn connection_header() -> &'static HeaderName {
static CONNECTION_HEADER: OnceLock<HeaderName> = OnceLock::new();
CONNECTION_HEADER.get_or_init(|| HeaderName::from_static("connection"))
}
fn upgrade_header() -> &'static HeaderName {
static UPGRADE_HEADER: OnceLock<HeaderName> = OnceLock::new();
UPGRADE_HEADER.get_or_init(|| HeaderName::from_static("upgrade"))
}
fn trailer_header() -> &'static HeaderName {
static TRAILER_HEADER: OnceLock<HeaderName> = OnceLock::new();
TRAILER_HEADER.get_or_init(|| HeaderName::from_static("trailer"))
}
fn trailers_header() -> &'static HeaderName {
static TRAILERS_HEADER: OnceLock<HeaderName> = OnceLock::new();
TRAILERS_HEADER.get_or_init(|| HeaderName::from_static("trailers"))
}
fn x_forwarded_for_header() -> &'static HeaderName {
static X_FORWARDED_FOR: OnceLock<HeaderName> = OnceLock::new();
X_FORWARDED_FOR.get_or_init(|| HeaderName::from_static("x-forwarded-for"))
}
fn hop_headers() -> &'static [HeaderName; 9] {
static HOP_HEADERS: OnceLock<[HeaderName; 9]> = OnceLock::new();
HOP_HEADERS.get_or_init(|| {
[
connection_header().clone(),
te_header().clone(),
trailer_header().clone(),
HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-connection"),
HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"),
]
})
} }
#[derive(Debug)] #[derive(Debug)]
@ -79,7 +104,7 @@ impl From<InvalidHeaderValue> for ProxyError {
fn remove_hop_headers(headers: &mut HeaderMap) { fn remove_hop_headers(headers: &mut HeaderMap) {
debug!("Removing hop headers"); debug!("Removing hop headers");
for header in &*HOP_HEADERS { for header in hop_headers() {
headers.remove(header); headers.remove(header);
} }
} }
@ -87,17 +112,17 @@ fn remove_hop_headers(headers: &mut HeaderMap) {
fn get_upgrade_type(headers: &HeaderMap) -> Option<String> { fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
#[allow(clippy::blocks_in_conditions)] #[allow(clippy::blocks_in_conditions)]
if headers if headers
.get(&*CONNECTION_HEADER) .get(connection_header())
.map(|value| { .map(|value| {
value value
.to_str() .to_str()
.unwrap() .unwrap()
.split(',') .split(',')
.any(|e| e.trim() == *UPGRADE_HEADER) .any(|e| e.trim() == *upgrade_header())
}) })
.unwrap_or(false) .unwrap_or(false)
{ {
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) { if let Some(upgrade_value) = headers.get(upgrade_header()) {
debug!( debug!(
"Found upgrade header with value: {}", "Found upgrade header with value: {}",
upgrade_value.to_str().unwrap().to_owned() upgrade_value.to_str().unwrap().to_owned()
@ -111,10 +136,10 @@ fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
} }
fn remove_connection_headers(headers: &mut HeaderMap) { fn remove_connection_headers(headers: &mut HeaderMap) {
if headers.get(&*CONNECTION_HEADER).is_some() { if headers.get(connection_header()).is_some() {
debug!("Removing connection headers"); debug!("Removing connection headers");
let value = headers.get(&*CONNECTION_HEADER).cloned().unwrap(); let value = headers.get(connection_header()).cloned().unwrap();
for name in value.to_str().unwrap().split(',') { for name in value.to_str().unwrap().split(',') {
if !name.trim().is_empty() { if !name.trim().is_empty() {
@ -220,13 +245,13 @@ fn create_proxied_request<B>(
let contains_te_trailers_value = request let contains_te_trailers_value = request
.headers() .headers()
.get(&*TE_HEADER) .get(te_header())
.map(|value| { .map(|value| {
value value
.to_str() .to_str()
.unwrap() .unwrap()
.split(',') .split(',')
.any(|e| e.trim() == *TRAILERS_HEADER) .any(|e| e.trim() == *trailers_header())
}) })
.unwrap_or(false); .unwrap_or(false);
@ -240,7 +265,7 @@ fn create_proxied_request<B>(
request request
.headers_mut() .headers_mut()
.insert(&*TE_HEADER, HeaderValue::from_static("trailers")); .insert(te_header(), HeaderValue::from_static("trailers"));
} }
if let Some(value) = upgrade_type { if let Some(value) = upgrade_type {
@ -248,14 +273,14 @@ fn create_proxied_request<B>(
request request
.headers_mut() .headers_mut()
.insert(&*UPGRADE_HEADER, value.parse().unwrap()); .insert(upgrade_header(), value.parse().unwrap());
request request
.headers_mut() .headers_mut()
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE")); .insert(connection_header(), HeaderValue::from_static("UPGRADE"));
} }
// Add forwarding information in the headers // Add forwarding information in the headers
match request.headers_mut().entry(&*X_FORWARDED_FOR) { match request.headers_mut().entry(x_forwarded_for_header()) {
hyper::header::Entry::Vacant(entry) => { hyper::header::Entry::Vacant(entry) => {
debug!("X-Forwarded-for header was vacant"); debug!("X-Forwarded-for header was vacant");
entry.insert(client_ip.to_string().parse()?); entry.insert(client_ip.to_string().parse()?);

View File

@ -7,6 +7,7 @@ use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
use hyper_reverse_proxy::ReverseProxy; use hyper_reverse_proxy::ReverseProxy;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::OnceLock;
use test_context::test_context; use test_context::test_context;
use test_context::AsyncTestContext; use test_context::AsyncTestContext;
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
@ -14,12 +15,9 @@ use tokio::task::JoinHandle;
use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::handler::HandlerBuilder;
use tokiotest_httpserver::{take_port, HttpTestContext}; use tokiotest_httpserver::{take_port, HttpTestContext};
lazy_static::lazy_static! { fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = { static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
ReverseProxy::new( PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
hyper::Client::new(),
)
};
} }
struct ProxyTestContext { struct ProxyTestContext {
@ -99,7 +97,7 @@ async fn handle(
req: Request<Body>, req: Request<Body>,
backend_port: u16, backend_port: u16,
) -> Result<Response<Body>, Infallible> { ) -> Result<Response<Body>, Infallible> {
match PROXY_CLIENT match proxy_client()
.call( .call(
client_ip, client_ip,
format!("http://127.0.0.1:{}", backend_port).as_str(), format!("http://127.0.0.1:{}", backend_port).as_str(),

View File

@ -2,6 +2,7 @@ use std::{
convert::Infallible, convert::Infallible,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
process::exit, process::exit,
sync::OnceLock,
time::Duration, time::Duration,
}; };
@ -20,12 +21,9 @@ use tokiotest_httpserver::take_port;
use tungstenite::Message; use tungstenite::Message;
use url::Url; use url::Url;
lazy_static::lazy_static! { fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = { static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
ReverseProxy::new( PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
hyper::Client::new(),
)
};
} }
struct ProxyTestContext { struct ProxyTestContext {
@ -66,7 +64,7 @@ async fn handle(
req: Request<Body>, req: Request<Body>,
backend_port: u16, backend_port: u16,
) -> Result<Response<Body>, Infallible> { ) -> Result<Response<Body>, Infallible> {
match PROXY_CLIENT match proxy_client()
.call( .call(
client_ip, client_ip,
format!("http://127.0.0.1:{}", backend_port).as_str(), format!("http://127.0.0.1:{}", backend_port).as_str(),