Replace lazy_static with OnceLock.

This commit is contained in:
Stefan Sundin 2024-05-19 11:00:33 -07:00
parent 2ec415ecac
commit 695f9639ef
6 changed files with 96 additions and 76 deletions

View File

@ -26,7 +26,6 @@ harness = false
http-body-util = "0.1.0"
hyper = { version = "1.2.0", features = ["client", "http1"] }
hyper-util = { version = "0.1.3", features = ["client-legacy", "http1","tokio"] }
lazy_static = "1.4.0"
tokio = { version = "1.17.0", features = ["io-util", "rt"] }
tracing = "0.1.34"

View File

@ -10,16 +10,14 @@ use rand::distributions::Alphanumeric;
use rand::prelude::*;
use std::net::Ipv4Addr;
use std::str::FromStr;
use std::sync::OnceLock;
use test_context::AsyncTestContext;
use tokio::runtime::Runtime;
use tokiotest_httpserver::HttpTestContext;
lazy_static::lazy_static! {
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = {
ReverseProxy::new(
hyper::Client::new(),
)
};
fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
}
fn create_proxied_response(b: &mut Criterion) {
@ -46,7 +44,7 @@ fn generate_string() -> String {
}
fn build_headers() -> HeaderMap {
let mut headers_map: HeaderMap = (&*internal_benches::hop_headers())
let mut headers_map: HeaderMap = (internal_benches::hop_headers())
.iter()
.map(|el: &'static HeaderName| (el.clone(), generate_string().parse().unwrap()))
.collect();
@ -86,7 +84,7 @@ fn proxy_call(b: &mut Criterion) {
*request.headers_mut().unwrap() = headers_map.clone();
black_box(&PROXY_CLIENT)
black_box(&proxy_client())
.call(
black_box(client_ip),
black_box(forward_url),

View File

@ -1,6 +1,7 @@
use std::convert::Infallible;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::OnceLock;
use std::time::Duration;
use http_body_util::combinators::UnsyncBoxBody;
@ -19,25 +20,26 @@ use hyper_util::client::legacy::connect::HttpConnector;
type Connector = HttpsConnector<HttpConnector>;
type ResponseBody = UnsyncBoxBody<Bytes, std::io::Error>;
lazy_static::lazy_static! {
static ref PROXY_CLIENT: ReverseProxy<Connector> = {
let connector: Connector = Connector::builder()
.with_tls_config(
rustls::ClientConfig::builder()
.with_native_roots()
.expect("with_native_roots")
.with_no_client_auth(),
)
.https_or_http()
.enable_http1()
.build();
ReverseProxy::new(
hyper_util::client::legacy::Builder::new(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(3))
.pool_timer(TokioTimer::new())
.build::<_, Incoming>(connector),
)
};
fn proxy_client() -> &'static ReverseProxy<Connector> {
static PROXY_CLIENT: OnceLock<ReverseProxy<Connector>> = OnceLock::new();
PROXY_CLIENT.get_or_init(|| {
let connector: Connector = Connector::builder()
.with_tls_config(
rustls::ClientConfig::builder()
.with_native_roots()
.expect("with_native_roots")
.with_no_client_auth(),
)
.https_or_http()
.enable_http1()
.build();
ReverseProxy::new(
hyper_util::client::legacy::Builder::new(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(3))
.pool_timer(TokioTimer::new())
.build::<_, Incoming>(connector),
)
})
}
async fn handle(
@ -46,7 +48,7 @@ async fn handle(
) -> Result<Response<ResponseBody>, Infallible> {
let host = req.headers().get("host").and_then(|v| v.to_str().ok());
if host.is_some_and(|host| host.starts_with("service1.localhost")) {
match PROXY_CLIENT
match proxy_client()
.call(client_ip, "http://127.0.0.1:13901", req)
.await
{
@ -59,7 +61,7 @@ async fn handle(
.unwrap()),
}
} else if host.is_some_and(|host| host.starts_with("service2.localhost")) {
match PROXY_CLIENT
match proxy_client()
.call(client_ip, "http://127.0.0.1:13902", req)
.await
{

View File

@ -10,30 +10,55 @@ use hyper::http::uri::InvalidUri;
use hyper::{body::Incoming, Error, Request, Response, StatusCode};
use hyper_util::client::legacy::{connect::Connect, Client, Error as LegacyError};
use hyper_util::rt::tokio::TokioIo;
use lazy_static::lazy_static;
use std::net::{IpAddr, SocketAddr};
use std::sync::OnceLock;
use tokio::io::copy_bidirectional;
lazy_static! {
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
static ref TRAILER_HEADER: HeaderName = HeaderName::from_static("trailer");
static ref TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers");
// A list of the headers, using hypers actual HeaderName comparison
static ref HOP_HEADERS: [HeaderName; 9] = [
CONNECTION_HEADER.clone(),
TE_HEADER.clone(),
TRAILER_HEADER.clone(),
HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-connection"),
HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"),
];
fn te_header() -> &'static HeaderName {
static TE_HEADER: OnceLock<HeaderName> = OnceLock::new();
TE_HEADER.get_or_init(|| HeaderName::from_static("te"))
}
static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
fn connection_header() -> &'static HeaderName {
static CONNECTION_HEADER: OnceLock<HeaderName> = OnceLock::new();
CONNECTION_HEADER.get_or_init(|| HeaderName::from_static("connection"))
}
fn upgrade_header() -> &'static HeaderName {
static UPGRADE_HEADER: OnceLock<HeaderName> = OnceLock::new();
UPGRADE_HEADER.get_or_init(|| HeaderName::from_static("upgrade"))
}
fn trailer_header() -> &'static HeaderName {
static TRAILER_HEADER: OnceLock<HeaderName> = OnceLock::new();
TRAILER_HEADER.get_or_init(|| HeaderName::from_static("trailer"))
}
fn trailers_header() -> &'static HeaderName {
static TRAILERS_HEADER: OnceLock<HeaderName> = OnceLock::new();
TRAILERS_HEADER.get_or_init(|| HeaderName::from_static("trailers"))
}
fn x_forwarded_for_header() -> &'static HeaderName {
static X_FORWARDED_FOR: OnceLock<HeaderName> = OnceLock::new();
X_FORWARDED_FOR.get_or_init(|| HeaderName::from_static("x-forwarded-for"))
}
fn hop_headers() -> &'static [HeaderName; 9] {
static HOP_HEADERS: OnceLock<[HeaderName; 9]> = OnceLock::new();
HOP_HEADERS.get_or_init(|| {
[
connection_header().clone(),
te_header().clone(),
trailer_header().clone(),
HeaderName::from_static("keep-alive"),
HeaderName::from_static("proxy-connection"),
HeaderName::from_static("proxy-authenticate"),
HeaderName::from_static("proxy-authorization"),
HeaderName::from_static("transfer-encoding"),
HeaderName::from_static("upgrade"),
]
})
}
#[derive(Debug)]
@ -79,7 +104,7 @@ impl From<InvalidHeaderValue> for ProxyError {
fn remove_hop_headers(headers: &mut HeaderMap) {
debug!("Removing hop headers");
for header in &*HOP_HEADERS {
for header in hop_headers() {
headers.remove(header);
}
}
@ -87,17 +112,17 @@ fn remove_hop_headers(headers: &mut HeaderMap) {
fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
#[allow(clippy::blocks_in_conditions)]
if headers
.get(&*CONNECTION_HEADER)
.get(connection_header())
.map(|value| {
value
.to_str()
.unwrap()
.split(',')
.any(|e| e.trim() == *UPGRADE_HEADER)
.any(|e| e.trim() == *upgrade_header())
})
.unwrap_or(false)
{
if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
if let Some(upgrade_value) = headers.get(upgrade_header()) {
debug!(
"Found upgrade header with value: {}",
upgrade_value.to_str().unwrap().to_owned()
@ -111,10 +136,10 @@ fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
}
fn remove_connection_headers(headers: &mut HeaderMap) {
if headers.get(&*CONNECTION_HEADER).is_some() {
if headers.get(connection_header()).is_some() {
debug!("Removing connection headers");
let value = headers.get(&*CONNECTION_HEADER).cloned().unwrap();
let value = headers.get(connection_header()).cloned().unwrap();
for name in value.to_str().unwrap().split(',') {
if !name.trim().is_empty() {
@ -220,13 +245,13 @@ fn create_proxied_request<B>(
let contains_te_trailers_value = request
.headers()
.get(&*TE_HEADER)
.get(te_header())
.map(|value| {
value
.to_str()
.unwrap()
.split(',')
.any(|e| e.trim() == *TRAILERS_HEADER)
.any(|e| e.trim() == *trailers_header())
})
.unwrap_or(false);
@ -240,7 +265,7 @@ fn create_proxied_request<B>(
request
.headers_mut()
.insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
.insert(te_header(), HeaderValue::from_static("trailers"));
}
if let Some(value) = upgrade_type {
@ -248,14 +273,14 @@ fn create_proxied_request<B>(
request
.headers_mut()
.insert(&*UPGRADE_HEADER, value.parse().unwrap());
.insert(upgrade_header(), value.parse().unwrap());
request
.headers_mut()
.insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
.insert(connection_header(), HeaderValue::from_static("UPGRADE"));
}
// Add forwarding information in the headers
match request.headers_mut().entry(&*X_FORWARDED_FOR) {
match request.headers_mut().entry(x_forwarded_for_header()) {
hyper::header::Entry::Vacant(entry) => {
debug!("X-Forwarded-for header was vacant");
entry.insert(client_ip.to_string().parse()?);

View File

@ -7,6 +7,7 @@ use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri};
use hyper_reverse_proxy::ReverseProxy;
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::sync::OnceLock;
use test_context::test_context;
use test_context::AsyncTestContext;
use tokio::sync::oneshot::Sender;
@ -14,12 +15,9 @@ use tokio::task::JoinHandle;
use tokiotest_httpserver::handler::HandlerBuilder;
use tokiotest_httpserver::{take_port, HttpTestContext};
lazy_static::lazy_static! {
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = {
ReverseProxy::new(
hyper::Client::new(),
)
};
fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
}
struct ProxyTestContext {
@ -99,7 +97,7 @@ async fn handle(
req: Request<Body>,
backend_port: u16,
) -> Result<Response<Body>, Infallible> {
match PROXY_CLIENT
match proxy_client()
.call(
client_ip,
format!("http://127.0.0.1:{}", backend_port).as_str(),

View File

@ -2,6 +2,7 @@ use std::{
convert::Infallible,
net::{IpAddr, SocketAddr},
process::exit,
sync::OnceLock,
time::Duration,
};
@ -20,12 +21,9 @@ use tokiotest_httpserver::take_port;
use tungstenite::Message;
use url::Url;
lazy_static::lazy_static! {
static ref PROXY_CLIENT: ReverseProxy<HttpConnector<GaiResolver>> = {
ReverseProxy::new(
hyper::Client::new(),
)
};
fn proxy_client() -> &'static ReverseProxy<HttpConnector<GaiResolver>> {
static PROXY_CLIENT: OnceLock<ReverseProxy<HttpConnector<GaiResolver>>> = OnceLock::new();
PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new()))
}
struct ProxyTestContext {
@ -66,7 +64,7 @@ async fn handle(
req: Request<Body>,
backend_port: u16,
) -> Result<Response<Body>, Infallible> {
match PROXY_CLIENT
match proxy_client()
.call(
client_ip,
format!("http://127.0.0.1:{}", backend_port).as_str(),