feat: ws support
This commit is contained in:
parent
20dbf00931
commit
d4fdbf2a2e
@ -21,11 +21,11 @@ include = ["Cargo.toml", "LICENSE", "src/**/*"]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
hyper = { version = "0.14.18", features = ["client"] }
|
hyper = { version = "0.14.18", features = ["client"] }
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
|
tokio = { version = "1.17.0", features = ["full"] }
|
||||||
tracing = "0.1.34"
|
tracing = "0.1.34"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hyper = { version = "0.14.18", features = ["server"] }
|
hyper = { version = "0.14.18", features = ["server"] }
|
||||||
tokio = { version = "1.17.0", features = ["full"] }
|
|
||||||
futures = "0.3.21"
|
futures = "0.3.21"
|
||||||
async-trait = "0.1.53"
|
async-trait = "0.1.53"
|
||||||
async-tungstenite = { version = "0.17", features = ["tokio-runtime"] }
|
async-tungstenite = { version = "0.17", features = ["tokio-runtime"] }
|
||||||
|
37
src/lib.rs
37
src/lib.rs
@ -119,9 +119,11 @@ extern crate test;
|
|||||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST};
|
use hyper::header::{HeaderMap, HeaderName, HeaderValue, HOST};
|
||||||
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
use hyper::http::header::{InvalidHeaderValue, ToStrError};
|
||||||
use hyper::http::uri::InvalidUri;
|
use hyper::http::uri::InvalidUri;
|
||||||
use hyper::{Body, Client, Error, Request, Response};
|
use hyper::upgrade::OnUpgrade;
|
||||||
|
use hyper::{upgrade, Body, Client, Error, Request, Response, StatusCode};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
|
use tokio::io::copy_bidirectional;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
|
static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
|
||||||
@ -397,11 +399,40 @@ pub async fn call<'a, T: hyper::client::connect::Connect + Clone + Send + Sync +
|
|||||||
forward_uri,
|
forward_uri,
|
||||||
client_ip
|
client_ip
|
||||||
);
|
);
|
||||||
|
let mut request = request;
|
||||||
|
|
||||||
|
let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
|
||||||
|
|
||||||
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
|
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
|
||||||
|
|
||||||
let response = client.request(proxied_request).await?;
|
let proxied_response = client.request(proxied_request).await?;
|
||||||
let proxied_response = create_proxied_response(response);
|
|
||||||
|
if proxied_response.status() == StatusCode::SWITCHING_PROTOCOLS {
|
||||||
|
// if response.status() != proxied_request.st
|
||||||
|
|
||||||
|
let mut response = Response::new(Body::empty());
|
||||||
|
*response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
|
||||||
|
|
||||||
|
for (k, v) in proxied_response.headers().into_iter() {
|
||||||
|
response.headers_mut().append(k, v.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut response_upgraded = upgrade::on(proxied_response)
|
||||||
|
.await
|
||||||
|
.expect("failed to upgrade response");
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut request_upgraded = request_upgraded
|
||||||
|
.expect("test")
|
||||||
|
.await
|
||||||
|
.expect("failed to upgrade request");
|
||||||
|
|
||||||
|
copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
let proxied_response = create_proxied_response(proxied_response);
|
||||||
|
|
||||||
debug!("Responding to call with response");
|
debug!("Responding to call with response");
|
||||||
Ok(proxied_response)
|
Ok(proxied_response)
|
||||||
|
@ -38,17 +38,27 @@ struct ProxyTestContext {
|
|||||||
#[test_context(ProxyTestContext)]
|
#[test_context(ProxyTestContext)]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_websocket(ctx: &mut ProxyTestContext) {
|
async fn test_websocket(ctx: &mut ProxyTestContext) {
|
||||||
println!("making client connection");
|
|
||||||
let (mut client, _) =
|
let (mut client, _) =
|
||||||
connect_async(Url::parse(&format!("ws://127.0.0.1:{}", ctx.port)).unwrap())
|
connect_async(Url::parse(&format!("ws://127.0.0.1:{}", ctx.port)).unwrap())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
println!("made client connection");
|
client.send(Message::Ping("hello".into())).await.unwrap();
|
||||||
client.send(Message::Ping("ping".into())).await.unwrap();
|
|
||||||
let msg = client.next().await.unwrap().unwrap();
|
let msg = client.next().await.unwrap().unwrap();
|
||||||
|
|
||||||
assert!(matches!(msg, Message::Pong(inner) if inner == "pong".as_bytes()));
|
assert!(
|
||||||
|
matches!(&msg, Message::Pong(inner) if inner == "hello".as_bytes()),
|
||||||
|
"did not get pong, but {:?}",
|
||||||
|
msg
|
||||||
|
);
|
||||||
|
|
||||||
|
let msg = client.next().await.unwrap().unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(&msg, Message::Text(inner) if inner == "All done"),
|
||||||
|
"did not get text, but {:?}",
|
||||||
|
msg
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle(
|
async fn handle(
|
||||||
@ -84,16 +94,21 @@ impl<'a> AsyncTestContext for ProxyTestContext {
|
|||||||
let ws_handler = tokio::spawn(async move {
|
let ws_handler = tokio::spawn(async move {
|
||||||
let ws_server = TcpListener::bind(("127.0.0.1", ws_port)).await.unwrap();
|
let ws_server = TcpListener::bind(("127.0.0.1", ws_port)).await.unwrap();
|
||||||
|
|
||||||
while let Ok((stream, addr)) = ws_server.accept().await {
|
if let Ok((stream, _)) = ws_server.accept().await {
|
||||||
println!("incoming connection: {addr}");
|
|
||||||
let mut websocket = accept_async(stream).await.unwrap();
|
let mut websocket = accept_async(stream).await.unwrap();
|
||||||
|
|
||||||
let msg = websocket.next().await.unwrap().unwrap();
|
let msg = websocket.next().await.unwrap().unwrap();
|
||||||
assert!(matches!(msg, Message::Ping(inner) if inner == "ping".as_bytes()));
|
assert!(
|
||||||
println!("past ping");
|
matches!(&msg, Message::Ping(inner) if inner == "hello".as_bytes()),
|
||||||
|
"did not get ping, but: {:?}",
|
||||||
|
msg
|
||||||
|
);
|
||||||
|
// Tungstenite will auto send a Pong as a response to a Ping
|
||||||
|
|
||||||
websocket.send(Message::Pong("pong".into())).await.unwrap();
|
websocket
|
||||||
println!("past pong");
|
.send(Message::Text("All done".to_string()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user