feat: ws support
This commit is contained in:
parent
20dbf00931
commit
d4fdbf2a2e
@ -21,11 +21,11 @@ include = ["Cargo.toml", "LICENSE", "src/**/*"]
|
||||
[dependencies]
|
||||
hyper = { version = "0.14.18", features = ["client"] }
|
||||
lazy_static = "1.4.0"
|
||||
tokio = { version = "1.17.0", features = ["full"] }
|
||||
tracing = "0.1.34"
|
||||
|
||||
[dev-dependencies]
|
||||
hyper = { version = "0.14.18", features = ["server"] }
|
||||
tokio = { version = "1.17.0", features = ["full"] }
|
||||
futures = "0.3.21"
|
||||
async-trait = "0.1.53"
|
||||
async-tungstenite = { version = "0.17", features = ["tokio-runtime"] }
|
||||
|
37
src/lib.rs
37
src/lib.rs
@ -119,9 +119,11 @@ extern crate test;
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST};
|
||||
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
||||
use hyper::http::uri::InvalidUri;
|
||||
use hyper::{Body, Client, Error, Request, Response};
|
||||
use hyper::upgrade::OnUpgrade;
|
||||
use hyper::{upgrade, Body, Client, Error, Request, Response, StatusCode};
|
||||
use lazy_static::lazy_static;
|
||||
use std::net::IpAddr;
|
||||
use tokio::io::copy_bidirectional;
|
||||
|
||||
lazy_static! {
|
||||
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
|
||||
@ -397,11 +399,40 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
|
||||
forward_uri,
|
||||
client_ip
|
||||
);
|
||||
let mut request = request;
|
||||
|
||||
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
|
||||
|
||||
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
|
||||
|
||||
let response = client.request(proxied_request).await?;
|
||||
let proxied_response = create_proxied_response(response);
|
||||
let proxied_response = client.request(proxied_request).await?;
|
||||
|
||||
if proxied_response.status() == StatusCode::SWITCHING_PROTOCOLS {
|
||||
// if response.status() != proxied_request.st
|
||||
|
||||
let mut response = Response::new(Body::empty());
|
||||
*response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
|
||||
|
||||
for (k, v) in proxied_response.headers().into_iter() {
|
||||
response.headers_mut().append(k, v.clone());
|
||||
}
|
||||
|
||||
let mut response_upgraded = upgrade::on(proxied_response)
|
||||
.await
|
||||
.expect("failed to upgrade response");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut request_upgraded = request_upgraded
|
||||
.expect("test")
|
||||
.await
|
||||
.expect("failed to upgrade request");
|
||||
|
||||
copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await;
|
||||
});
|
||||
|
||||
return Ok(response);
|
||||
}
|
||||
let proxied_response = create_proxied_response(proxied_response);
|
||||
|
||||
debug!("Responding to call with response");
|
||||
Ok(proxied_response)
|
||||
|
@ -38,17 +38,27 @@ struct ProxyTestContext {
|
||||
#[test_context(ProxyTestContext)]
|
||||
#[tokio::test]
|
||||
async fn test_websocket(ctx: &mut ProxyTestContext) {
|
||||
println!("making client connection");
|
||||
let (mut client, _) =
|
||||
connect_async(Url::parse(&format!("ws://127.0.0.1:{}", ctx.port)).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("made client connection");
|
||||
client.send(Message::Ping("ping".into())).await.unwrap();
|
||||
client.send(Message::Ping("hello".into())).await.unwrap();
|
||||
let msg = client.next().await.unwrap().unwrap();
|
||||
|
||||
assert!(matches!(msg, Message::Pong(inner) if inner == "pong".as_bytes()));
|
||||
assert!(
|
||||
matches!(&msg, Message::Pong(inner) if inner == "hello".as_bytes()),
|
||||
"did not get pong, but {:?}",
|
||||
msg
|
||||
);
|
||||
|
||||
let msg = client.next().await.unwrap().unwrap();
|
||||
|
||||
assert!(
|
||||
matches!(&msg, Message::Text(inner) if inner == "All done"),
|
||||
"did not get text, but {:?}",
|
||||
msg
|
||||
);
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
@ -84,16 +94,21 @@ impl<'a> AsyncTestContext for ProxyTestContext {
|
||||
let ws_handler = tokio::spawn(async move {
|
||||
let ws_server = TcpListener::bind(("127.0.0.1", ws_port)).await.unwrap();
|
||||
|
||||
while let Ok((stream, addr)) = ws_server.accept().await {
|
||||
println!("incoming connection: {addr}");
|
||||
if let Ok((stream, _)) = ws_server.accept().await {
|
||||
let mut websocket = accept_async(stream).await.unwrap();
|
||||
|
||||
let msg = websocket.next().await.unwrap().unwrap();
|
||||
assert!(matches!(msg, Message::Ping(inner) if inner == "ping".as_bytes()));
|
||||
println!("past ping");
|
||||
assert!(
|
||||
matches!(&msg, Message::Ping(inner) if inner == "hello".as_bytes()),
|
||||
"did not get ping, but: {:?}",
|
||||
msg
|
||||
);
|
||||
// Tungstenite will auto send a Pong as a response to a Ping
|
||||
|
||||
websocket.send(Message::Pong("pong".into())).await.unwrap();
|
||||
println!("past pong");
|
||||
websocket
|
||||
.send(Message::Text("All done".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user