diff --git a/src/client.rs b/src/client.rs index 4803410..26607b3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -113,6 +113,8 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) diff --git a/src/lib.rs b/src/lib.rs index f631a09..382e43a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -195,7 +195,3 @@ impl Future for Accept { Pin::new(&mut self.0).poll(cx) } } - -#[cfg(feature = "early-data")] -#[cfg(test)] -mod test_0rtt; diff --git a/src/server.rs b/src/server.rs index 91c1cb4..4dad3f6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -105,6 +105,8 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) diff --git a/tests/early-data.rs b/tests/early-data.rs new file mode 100644 index 0000000..57645d1 --- /dev/null +++ b/tests/early-data.rs @@ -0,0 +1,118 @@ +#![cfg(feature = "early-data")] + +use std::io::{ self, BufReader, BufRead, Cursor }; +use std::process::{ Command, Child, Stdio }; +use std::net::SocketAddr; +use std::sync::Arc; +use std::marker::Unpin; +use std::pin::{ Pin }; +use std::task::{ Context, Poll }; +use std::time::Duration; +use tokio::prelude::*; +use tokio::net::TcpStream; +use tokio::io::split; +use tokio::timer::delay_for; +use futures_util::{ future, ready }; +use rustls::ClientConfig; +use tokio_rustls::{ TlsConnector, client::TlsStream }; + + +struct Read1(T); + +impl Future for Read1 { + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut buf = [0]; + ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?; + Poll::Pending + } +} + +async fn send(config: Arc, addr: SocketAddr, data: &[u8]) + -> io::Result> +{ + let connector = TlsConnector::from(config) + .early_data(true); + let stream = TcpStream::connect(&addr).await?; + let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap(); + + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(data).await?; + stream.flush().await?; + + let (r, mut w) = split(stream); + let fut = Read1(r); + let fut2 = async move { + // sleep 3s + // + // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html + delay_for(Duration::from_secs(3)).await; + w.shutdown().await?; + Ok(w) as io::Result<_> + }; + + let stream = match future::select(fut, fut2.boxed()).await { + future::Either::Left(_) => unreachable!(), + future::Either::Right((Ok(w), Read1(r))) => r.unsplit(w), + future::Either::Right((Err(err), _)) => return Err(err) + }; + + Ok(stream) +} + +struct DropKill(Child); + +impl Drop for DropKill { + fn drop(&mut self) { + self.0.kill().unwrap(); + } +} + +#[tokio::test] +async fn test_0rtt() -> io::Result<()> { + let mut handle = Command::new("openssl") + .arg("s_server") + .arg("-early_data") + .arg("-tls1_3") + .args(&["-cert", "./tests/end.cert"]) + .args(&["-key", "./tests/end.rsa"]) + .args(&["-port", "12354"]) + .stdout(Stdio::piped()) + .spawn() + .map(DropKill)?; + + // wait openssl server + delay_for(Duration::from_secs(3)).await; + + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); + config.root_store.add_pem_file(&mut chain).unwrap(); + config.versions = vec![rustls::ProtocolVersion::TLSv1_3]; + config.enable_early_data = true; + let config = Arc::new(config); + let addr = SocketAddr::from(([127, 0, 0, 1], 12354)); + + let io = send(config.clone(), addr, b"hello").await?; + assert!(!io.get_ref().1.is_early_data_accepted()); + + let io = send(config, addr, b"world!").await?; + assert!(io.get_ref().1.is_early_data_accepted()); + + let stdout = handle.0.stdout.as_mut().unwrap(); + let mut lines = BufReader::new(stdout).lines(); + + for line in lines.by_ref() { + if line?.contains("hello") { + break + } + } + + for line in lines.by_ref() { + if line?.contains("world!") { + break + } + } + + Ok(()) +} diff --git a/tests/test.rs b/tests/test.rs index 6ebdee9..30f8e8a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -7,6 +7,7 @@ use lazy_static::lazy_static; use tokio::prelude::*; use tokio::runtime::current_thread; use tokio::net::{ TcpListener, TcpStream }; +use tokio::io::split; use futures_util::try_future::TryFutureExt; use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; @@ -42,17 +43,10 @@ lazy_static!{ while let Some(stream) = incoming.next().await { let acceptor = acceptor.clone(); let fut = async move { - let mut stream = acceptor.accept(stream?).await?; + let stream = acceptor.accept(stream?).await?; - // TODO split - // - // let (mut reader, mut write) = stream.split(); - // reader.copy(&mut write).await?; - - let mut buf = vec![0; 8192]; - let n = stream.read(&mut buf).await?; - stream.write(&buf[..n]).await?; - stream.flush().await?; + let (mut reader, mut writer) = split(stream); + reader.copy(&mut writer).await?; Ok(()) as io::Result<()> }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); @@ -67,7 +61,7 @@ lazy_static!{ }); let addr = recv.recv().unwrap(); - (addr, "localhost", CHAIN) + (addr, "testserver.com", CHAIN) }; }