feat: ws support

pull/32/head
chesedo 2 years ago committed by Felipe Noronha
parent 20dbf00931
commit d4fdbf2a2e
  1. 2
      Cargo.toml
  2. 37
      src/lib.rs
  3. 35
      tests/test_websocket.rs

@ -21,11 +21,11 @@ include = ["Cargo.toml", "LICENSE", "src/**/*"]
[dependencies]
hyper = { version = "0.14.18", features = ["client"] }
lazy_static = "1.4.0"
tokio = { version = "1.17.0", features = ["full"] }
tracing = "0.1.34"
[dev-dependencies]
hyper = { version = "0.14.18", features = ["server"] }
tokio = { version = "1.17.0", features = ["full"] }
futures = "0.3.21"
async-trait = "0.1.53"
async-tungstenite = { version = "0.17", features = ["tokio-runtime"] }

@ -119,9 +119,11 @@ extern crate test;
use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST};
use hyper::http::header::{InvalidHeaderValue, ToStrError};
use hyper::http::uri::InvalidUri;
use hyper::{Body, Client, Error, Request, Response};
use hyper::upgrade::OnUpgrade;
use hyper::{upgrade, Body, Client, Error, Request, Response, StatusCode};
use lazy_static::lazy_static;
use std::net::IpAddr;
use tokio::io::copy_bidirectional;
lazy_static! {
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
@ -397,11 +399,40 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
forward_uri,
client_ip
);
let mut request = request;
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
let response = client.request(proxied_request).await?;
let proxied_response = create_proxied_response(response);
let proxied_response = client.request(proxied_request).await?;
if proxied_response.status() == StatusCode::SWITCHING_PROTOCOLS {
// if response.status() != proxied_request.st
let mut response = Response::new(Body::empty());
*response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
for (k, v) in proxied_response.headers().into_iter() {
response.headers_mut().append(k, v.clone());
}
let mut response_upgraded = upgrade::on(proxied_response)
.await
.expect("failed to upgrade response");
tokio::spawn(async move {
let mut request_upgraded = request_upgraded
.expect("test")
.await
.expect("failed to upgrade request");
copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await;
});
return Ok(response);
}
let proxied_response = create_proxied_response(proxied_response);
debug!("Responding to call with response");
Ok(proxied_response)

@ -38,17 +38,27 @@ struct ProxyTestContext {
#[test_context(ProxyTestContext)]
#[tokio::test]
async fn test_websocket(ctx: &mut ProxyTestContext) {
println!("making client connection");
let (mut client, _) =
connect_async(Url::parse(&format!("ws://127.0.0.1:{}", ctx.port)).unwrap())
.await
.unwrap();
println!("made client connection");
client.send(Message::Ping("ping".into())).await.unwrap();
client.send(Message::Ping("hello".into())).await.unwrap();
let msg = client.next().await.unwrap().unwrap();
assert!(matches!(msg, Message::Pong(inner) if inner == "pong".as_bytes()));
assert!(
matches!(&msg, Message::Pong(inner) if inner == "hello".as_bytes()),
"did not get pong, but {:?}",
msg
);
let msg = client.next().await.unwrap().unwrap();
assert!(
matches!(&msg, Message::Text(inner) if inner == "All done"),
"did not get text, but {:?}",
msg
);
}
async fn handle(
@ -84,16 +94,21 @@ impl<'a> AsyncTestContext for ProxyTestContext {
let ws_handler = tokio::spawn(async move {
let ws_server = TcpListener::bind(("127.0.0.1", ws_port)).await.unwrap();
while let Ok((stream, addr)) = ws_server.accept().await {
println!("incoming connection: {addr}");
if let Ok((stream, _)) = ws_server.accept().await {
let mut websocket = accept_async(stream).await.unwrap();
let msg = websocket.next().await.unwrap().unwrap();
assert!(matches!(msg, Message::Ping(inner) if inner == "ping".as_bytes()));
println!("past ping");
assert!(
matches!(&msg, Message::Ping(inner) if inner == "hello".as_bytes()),
"did not get ping, but: {:?}",
msg
);
// Tungstenite will auto send a Pong as a response to a Ping
websocket.send(Message::Pong("pong".into())).await.unwrap();
println!("past pong");
websocket
.send(Message::Text("All done".to_string()))
.await
.unwrap();
}
});

Loading…
Cancel
Save