From d4fdbf2a2e1435d6b6c22696f58e531cea6119fe Mon Sep 17 00:00:00 2001 From: chesedo Date: Wed, 4 May 2022 13:42:26 +0200 Subject: [PATCH] feat: ws support --- Cargo.toml | 2 +- src/lib.rs | 37 ++++++++++++++++++++++++++++++++++--- tests/test_websocket.rs | 35 +++++++++++++++++++++++++---------- 3 files changed, 60 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 76f5953..66b8018 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,11 +21,11 @@ include = ["Cargo.toml", "LICENSE", "src/**/*"] [dependencies] hyper = { version = "0.14.18", features = ["client"] } lazy_static = "1.4.0" +tokio = { version = "1.17.0", features = ["full"] } tracing = "0.1.34" [dev-dependencies] hyper = { version = "0.14.18", features = ["server"] } -tokio = { version = "1.17.0", features = ["full"] } futures = "0.3.21" async-trait = "0.1.53" async-tungstenite = { version = "0.17", features = ["tokio-runtime"] } diff --git a/src/lib.rs b/src/lib.rs index 620ea65..7a96b4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,9 +119,11 @@ extern crate test; use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST}; use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::uri::InvalidUri; -use hyper::{Body, Client, Error, Request, Response}; +use hyper::upgrade::OnUpgrade; +use hyper::{upgrade, Body, Client, Error, Request, Response, StatusCode}; use lazy_static::lazy_static; use std::net::IpAddr; +use tokio::io::copy_bidirectional; lazy_static! { static ref TE_HEADER: HeaderName = HeaderName::from_static("te"); @@ -397,11 +399,40 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync + forward_uri, client_ip ); + let mut request = request; + + let request_upgraded = request.extensions_mut().remove::(); let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; - let response = client.request(proxied_request).await?; - let proxied_response = create_proxied_response(response); + let proxied_response = client.request(proxied_request).await?; + + if proxied_response.status() == StatusCode::SWITCHING_PROTOCOLS { + // if response.status() != proxied_request.st + + let mut response = Response::new(Body::empty()); + *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS; + + for (k, v) in proxied_response.headers().into_iter() { + response.headers_mut().append(k, v.clone()); + } + + let mut response_upgraded = upgrade::on(proxied_response) + .await + .expect("failed to upgrade response"); + + tokio::spawn(async move { + let mut request_upgraded = request_upgraded + .expect("test") + .await + .expect("failed to upgrade request"); + + copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await; + }); + + return Ok(response); + } + let proxied_response = create_proxied_response(proxied_response); debug!("Responding to call with response"); Ok(proxied_response) diff --git a/tests/test_websocket.rs b/tests/test_websocket.rs index fab6d90..afdc597 100644 --- a/tests/test_websocket.rs +++ b/tests/test_websocket.rs @@ -38,17 +38,27 @@ struct ProxyTestContext { #[test_context(ProxyTestContext)] #[tokio::test] async fn test_websocket(ctx: &mut ProxyTestContext) { - println!("making client connection"); let (mut client, _) = connect_async(Url::parse(&format!("ws://127.0.0.1:{}", ctx.port)).unwrap()) .await .unwrap(); - println!("made client connection"); - client.send(Message::Ping("ping".into())).await.unwrap(); + client.send(Message::Ping("hello".into())).await.unwrap(); let msg = client.next().await.unwrap().unwrap(); - assert!(matches!(msg, Message::Pong(inner) if inner == "pong".as_bytes())); + assert!( + matches!(&msg, Message::Pong(inner) if inner == "hello".as_bytes()), + "did not get pong, but {:?}", + msg + ); + + let msg = client.next().await.unwrap().unwrap(); + + assert!( + matches!(&msg, Message::Text(inner) if inner == "All done"), + "did not get text, but {:?}", + msg + ); } async fn handle( @@ -84,16 +94,21 @@ impl<'a> AsyncTestContext for ProxyTestContext { let ws_handler = tokio::spawn(async move { let ws_server = TcpListener::bind(("127.0.0.1", ws_port)).await.unwrap(); - while let Ok((stream, addr)) = ws_server.accept().await { - println!("incoming connection: {addr}"); + if let Ok((stream, _)) = ws_server.accept().await { let mut websocket = accept_async(stream).await.unwrap(); let msg = websocket.next().await.unwrap().unwrap(); - assert!(matches!(msg, Message::Ping(inner) if inner == "ping".as_bytes())); - println!("past ping"); + assert!( + matches!(&msg, Message::Ping(inner) if inner == "hello".as_bytes()), + "did not get ping, but: {:?}", + msg + ); + // Tungstenite will auto send a Pong as a response to a Ping - websocket.send(Message::Pong("pong".into())).await.unwrap(); - println!("past pong"); + websocket + .send(Message::Text("All done".to_string())) + .await + .unwrap(); } });