use hyper::client::connect::dns::GaiResolver; use hyper::client::HttpConnector; use hyper::header::{CONNECTION, UPGRADE}; use hyper::server::conn::AddrStream; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use hyper_reverse_proxy::ReverseProxy; use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; use std::sync::OnceLock; use test_context::test_context; use test_context::AsyncTestContext; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::{take_port, HttpTestContext}; fn proxy_client() -> &'static ReverseProxy> { static PROXY_CLIENT: OnceLock>> = OnceLock::new(); PROXY_CLIENT.get_or_init(|| ReverseProxy::new(hyper::Client::new())) } struct ProxyTestContext { sender: Sender<()>, proxy_handler: JoinHandle>, http_back: HttpTestContext, port: u16, } #[test_context(ProxyTestContext)] #[tokio::test] async fn test_get_error_500(ctx: &mut ProxyTestContext) { let client = Client::new(); let resp = client .request( Request::builder() .header("keep-alive", "treu") .method("GET") .uri(ctx.uri("/500")) .body(Body::from("")) .unwrap(), ) .await .unwrap(); assert_eq!(500, resp.status()); } #[test_context(ProxyTestContext)] #[tokio::test] async fn test_upgrade_mismatch(ctx: &mut ProxyTestContext) { ctx.http_back.add( HandlerBuilder::new("/ws") .status_code(StatusCode::SWITCHING_PROTOCOLS) .build(), ); let resp = Client::new() .request( Request::builder() .header(CONNECTION, "Upgrade") .header(UPGRADE, "websocket") .method("GET") .uri(ctx.uri("/ws")) .body(Body::from("")) .unwrap(), ) .await .unwrap(); assert_eq!(resp.status(), 502); } #[test_context(ProxyTestContext)] #[tokio::test] async fn test_upgrade_unrequested(ctx: &mut ProxyTestContext) { ctx.http_back.add( HandlerBuilder::new("/wrong_switch") .status_code(StatusCode::SWITCHING_PROTOCOLS) .build(), ); let resp = Client::new().get(ctx.uri("/wrong_switch")).await.unwrap(); assert_eq!(resp.status(), 502); } #[test_context(ProxyTestContext)] #[tokio::test] async fn test_get(ctx: &mut ProxyTestContext) { ctx.http_back.add( HandlerBuilder::new("/foo") .status_code(StatusCode::OK) .build(), ); let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); assert_eq!(200, resp.status()); } async fn handle( client_ip: IpAddr, req: Request, backend_port: u16, ) -> Result, Infallible> { match proxy_client() .call( client_ip, format!("http://127.0.0.1:{}", backend_port).as_str(), req, ) .await { Ok(response) => Ok(response), Err(_) => Ok(Response::builder().status(502).body(Body::empty()).unwrap()), } } #[async_trait::async_trait] impl<'a> AsyncTestContext for ProxyTestContext { async fn setup() -> ProxyTestContext { let http_back: HttpTestContext = AsyncTestContext::setup().await; let (sender, receiver) = tokio::sync::oneshot::channel::<()>(); let bp_to_move = http_back.port; let make_svc = make_service_fn(move |conn: &AddrStream| { let remote_addr = conn.remote_addr().ip(); let back_port = bp_to_move; async move { Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req, back_port))) } }); let port = take_port(); let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port); let server = Server::bind(&addr) .serve(make_svc) .with_graceful_shutdown(async { receiver.await.ok(); }); let proxy_handler = tokio::spawn(server); ProxyTestContext { sender, proxy_handler, http_back, port, } } async fn teardown(self) { let _ = AsyncTestContext::teardown(self.http_back); let _ = self.sender.send(()).unwrap(); let _ = tokio::join!(self.proxy_handler); } } impl ProxyTestContext { pub fn uri(&self, path: &str) -> Uri { format!("http://{}:{}{}", "localhost", self.port, path) .parse::() .unwrap() } }