feat: ws support

This commit is contained in:
chesedo 2022-05-04 13:42:26 +02:00 committed by Felipe Noronha
parent 20dbf00931
commit d4fdbf2a2e
3 changed files with 60 additions and 14 deletions

View File

@ -21,11 +21,11 @@ include = ["Cargo.toml", "LICENSE", "src/**/*"]
[dependencies] [dependencies]
hyper = { version = "0.14.18", features = ["client"] } hyper = { version = "0.14.18", features = ["client"] }
lazy_static = "1.4.0" lazy_static = "1.4.0"
tokio = { version = "1.17.0", features = ["full"] }
tracing = "0.1.34" tracing = "0.1.34"
[dev-dependencies] [dev-dependencies]
hyper = { version = "0.14.18", features = ["server"] } hyper = { version = "0.14.18", features = ["server"] }
tokio = { version = "1.17.0", features = ["full"] }
futures = "0.3.21" futures = "0.3.21"
async-trait = "0.1.53" async-trait = "0.1.53"
async-tungstenite = { version = "0.17", features = ["tokio-runtime"] } async-tungstenite = { version = "0.17", features = ["tokio-runtime"] }

View File

@ -119,9 +119,11 @@ extern crate test;
use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST}; use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST};
use hyper::http::header::{InvalidHeaderValue, ToStrError}; use hyper::http::header::{InvalidHeaderValue, ToStrError};
use hyper::http::uri::InvalidUri; use hyper::http::uri::InvalidUri;
use hyper::{Body, Client, Error, Request, Response}; use hyper::upgrade::OnUpgrade;
use hyper::{upgrade, Body, Client, Error, Request, Response, StatusCode};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use std::net::IpAddr; use std::net::IpAddr;
use tokio::io::copy_bidirectional;
lazy_static! { lazy_static! {
static ref TE_HEADER: HeaderName = HeaderName::from_static("te"); static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
@ -397,11 +399,40 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
forward_uri, forward_uri,
client_ip client_ip
); );
let mut request = request;
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?; let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
let response = client.request(proxied_request).await?; let proxied_response = client.request(proxied_request).await?;
let proxied_response = create_proxied_response(response);
if proxied_response.status() == StatusCode::SWITCHING_PROTOCOLS {
// if response.status() != proxied_request.st
let mut response = Response::new(Body::empty());
*response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
for (k, v) in proxied_response.headers().into_iter() {
response.headers_mut().append(k, v.clone());
}
let mut response_upgraded = upgrade::on(proxied_response)
.await
.expect("failed to upgrade response");
tokio::spawn(async move {
let mut request_upgraded = request_upgraded
.expect("test")
.await
.expect("failed to upgrade request");
copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await;
});
return Ok(response);
}
let proxied_response = create_proxied_response(proxied_response);
debug!("Responding to call with response"); debug!("Responding to call with response");
Ok(proxied_response) Ok(proxied_response)

View File

@ -38,17 +38,27 @@ struct ProxyTestContext {
#[test_context(ProxyTestContext)] #[test_context(ProxyTestContext)]
#[tokio::test] #[tokio::test]
async fn test_websocket(ctx: &mut ProxyTestContext) { async fn test_websocket(ctx: &mut ProxyTestContext) {
println!("making client connection");
let (mut client, _) = let (mut client, _) =
connect_async(Url::parse(&format!("ws://127.0.0.1:{}", ctx.port)).unwrap()) connect_async(Url::parse(&format!("ws://127.0.0.1:{}", ctx.port)).unwrap())
.await .await
.unwrap(); .unwrap();
println!("made client connection"); client.send(Message::Ping("hello".into())).await.unwrap();
client.send(Message::Ping("ping".into())).await.unwrap();
let msg = client.next().await.unwrap().unwrap(); let msg = client.next().await.unwrap().unwrap();
assert!(matches!(msg, Message::Pong(inner) if inner == "pong".as_bytes())); assert!(
matches!(&msg, Message::Pong(inner) if inner == "hello".as_bytes()),
"did not get pong, but {:?}",
msg
);
let msg = client.next().await.unwrap().unwrap();
assert!(
matches!(&msg, Message::Text(inner) if inner == "All done"),
"did not get text, but {:?}",
msg
);
} }
async fn handle( async fn handle(
@ -84,16 +94,21 @@ impl<'a> AsyncTestContext for ProxyTestContext {
let ws_handler = tokio::spawn(async move { let ws_handler = tokio::spawn(async move {
let ws_server = TcpListener::bind(("127.0.0.1", ws_port)).await.unwrap(); let ws_server = TcpListener::bind(("127.0.0.1", ws_port)).await.unwrap();
while let Ok((stream, addr)) = ws_server.accept().await { if let Ok((stream, _)) = ws_server.accept().await {
println!("incoming connection: {addr}");
let mut websocket = accept_async(stream).await.unwrap(); let mut websocket = accept_async(stream).await.unwrap();
let msg = websocket.next().await.unwrap().unwrap(); let msg = websocket.next().await.unwrap().unwrap();
assert!(matches!(msg, Message::Ping(inner) if inner == "ping".as_bytes())); assert!(
println!("past ping"); matches!(&msg, Message::Ping(inner) if inner == "hello".as_bytes()),
"did not get ping, but: {:?}",
msg
);
// Tungstenite will auto send a Pong as a response to a Ping
websocket.send(Message::Pong("pong".into())).await.unwrap(); websocket
println!("past pong"); .send(Message::Text("All done".to_string()))
.await
.unwrap();
} }
}); });