diff --git a/src/lib.rs b/src/lib.rs index f18a9b3..000245c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -203,6 +203,51 @@ where io: Some(io), } } + + /// Takes back the client connection. Will return `None` if called more than once or if the + /// connection has been accepted. + /// + /// # Example + /// + /// ```no_run + /// # fn choose_server_config( + /// # _: rustls::server::ClientHello, + /// # ) -> std::sync::Arc { + /// # unimplemented!(); + /// # } + /// # #[allow(unused_variables)] + /// # async fn listen() { + /// use tokio::io::AsyncWriteExt; + /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap(); + /// let (stream, _) = listener.accept().await.unwrap(); + /// + /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream); + /// futures_util::pin_mut!(acceptor); + /// + /// match acceptor.as_mut().await { + /// Ok(start) => { + /// let clientHello = start.client_hello(); + /// let config = choose_server_config(clientHello); + /// let stream = start.into_stream(config).await.unwrap(); + /// // Proceed with handling the ServerConnection... + /// } + /// Err(err) => { + /// if let Some(mut stream) = acceptor.take_io() { + /// stream + /// .write_all( + /// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err) + /// .as_bytes() + /// ) + /// .await + /// .unwrap(); + /// } + /// } + /// } + /// # } + /// ``` + pub fn take_io(&mut self) -> Option { + self.io.take() + } } impl Future for LazyConfigAcceptor diff --git a/tests/test.rs b/tests/test.rs index c7ba137..c206890 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -11,6 +11,7 @@ use std::time::Duration; use std::{io, thread}; use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::oneshot; use tokio::{runtime, time}; use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector}; @@ -215,5 +216,40 @@ async fn lazy_config_acceptor_eof() { } } +#[tokio::test] +async fn lazy_config_acceptor_take_io() -> Result<(), rustls::Error> { + let (mut cstream, sstream) = tokio::io::duplex(1200); + + let (tx, rx) = oneshot::channel(); + + tokio::spawn(async move { + cstream.write_all(b"hello, world!").await.unwrap(); + + let mut buf = Vec::new(); + cstream.read_to_end(&mut buf).await.unwrap(); + tx.send(buf).unwrap(); + }); + + let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::default(), sstream); + futures_util::pin_mut!(acceptor); + if (acceptor.as_mut().await).is_ok() { + panic!("Expected Err(err)"); + } + + let server_msg = b"message from server"; + + let some_io = acceptor.take_io(); + assert!(some_io.is_some(), "Expected Some(io)"); + some_io.unwrap().write_all(server_msg).await.unwrap(); + + assert_eq!(rx.await.unwrap(), server_msg); + + assert!( + acceptor.take_io().is_none(), + "Should not be able to take twice" + ); + Ok(()) +} + // Include `utils` module include!("utils.rs");