add: take_io method to LazyConfigAcceptor (#145)

* add: take_io method to LazyConfigAcceptor

The `take_io` method can be used to take back ownership of the client IO stream when an error occurs
during clientHello handshake.

An example of this is when a client tries to connect to an TLS socket expecting it to be plain text
connection. In this case take_io can be used to send a 400 response, "The plain HTTP request was
sent to HTTPS port", back to the client.

* rename test lazy_config_acceptor_take_io
This commit is contained in:
Geoff Jacobsen 2023-06-06 18:15:07 +12:00 committed by GitHub
parent 3fcf85892b
commit fcbae20f8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 0 deletions

View File

@ -203,6 +203,51 @@ where
io: Some(io),
}
}
/// Takes back the client connection. Will return `None` if called more than once or if the
/// connection has been accepted.
///
/// # Example
///
/// ```no_run
/// # fn choose_server_config(
/// # _: rustls::server::ClientHello,
/// # ) -> std::sync::Arc<rustls::ServerConfig> {
/// # unimplemented!();
/// # }
/// # #[allow(unused_variables)]
/// # async fn listen() {
/// use tokio::io::AsyncWriteExt;
/// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
/// let (stream, _) = listener.accept().await.unwrap();
///
/// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
/// futures_util::pin_mut!(acceptor);
///
/// match acceptor.as_mut().await {
/// Ok(start) => {
/// let clientHello = start.client_hello();
/// let config = choose_server_config(clientHello);
/// let stream = start.into_stream(config).await.unwrap();
/// // Proceed with handling the ServerConnection...
/// }
/// Err(err) => {
/// if let Some(mut stream) = acceptor.take_io() {
/// stream
/// .write_all(
/// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
/// .as_bytes()
/// )
/// .await
/// .unwrap();
/// }
/// }
/// }
/// # }
/// ```
pub fn take_io(&mut self) -> Option<IO> {
self.io.take()
}
}
impl<IO> Future for LazyConfigAcceptor<IO>

View File

@ -11,6 +11,7 @@ use std::time::Duration;
use std::{io, thread};
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio::{runtime, time};
use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector};
@ -215,5 +216,40 @@ async fn lazy_config_acceptor_eof() {
}
}
#[tokio::test]
async fn lazy_config_acceptor_take_io() -> Result<(), rustls::Error> {
let (mut cstream, sstream) = tokio::io::duplex(1200);
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
cstream.write_all(b"hello, world!").await.unwrap();
let mut buf = Vec::new();
cstream.read_to_end(&mut buf).await.unwrap();
tx.send(buf).unwrap();
});
let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::default(), sstream);
futures_util::pin_mut!(acceptor);
if (acceptor.as_mut().await).is_ok() {
panic!("Expected Err(err)");
}
let server_msg = b"message from server";
let some_io = acceptor.take_io();
assert!(some_io.is_some(), "Expected Some(io)");
some_io.unwrap().write_all(server_msg).await.unwrap();
assert_eq!(rx.await.unwrap(), server_msg);
assert!(
acceptor.take_io().is_none(),
"Should not be able to take twice"
);
Ok(())
}
// Include `utils` module
include!("utils.rs");