diff --git a/tokio-rustls/README.md b/tokio-rustls/README.md index 6faef8f..e870aa2 100644 --- a/tokio-rustls/README.md +++ b/tokio-rustls/README.md @@ -42,7 +42,7 @@ See [examples/server](examples/server/src/main.rs). You can run it with: ```sh cd examples/server -cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der +cargo run -- 127.0.0.1:8000 --cert mycert.der --key mykey.der ``` ### License & Origin diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs index 7292b1a..1244019 100644 --- a/tokio-rustls/src/client.rs +++ b/tokio-rustls/src/client.rs @@ -12,6 +12,9 @@ pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ClientConnection, pub(crate) state: TlsState, + + #[cfg(feature = "early-data")] + pub(crate) early_waker: Option, } impl TlsStream { @@ -82,7 +85,26 @@ where ) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData(..) => Poll::Pending, + TlsState::EarlyData(..) => { + let this = self.get_mut(); + + // In the EarlyData state, we have not really established a Tls connection. + // Before writing data through `AsyncWrite` and completing the tls handshake, + // we ignore read readiness and return to pending. + // + // In order to avoid event loss, + // we need to register a waker and wake it up after tls is connected. + if this + .early_waker + .as_ref() + .filter(|waker| cx.waker().will_wake(waker)) + .is_none() + { + this.early_waker = Some(cx.waker().clone()); + } + + Poll::Pending + } TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = @@ -134,9 +156,6 @@ where if let Some(mut early_data) = stream.session.early_data() { let len = match early_data.write(buf) { Ok(n) => n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - return Poll::Pending - } Err(err) => return Poll::Ready(Err(err)), }; if len != 0 { @@ -160,6 +179,11 @@ where // end this.state = TlsState::Stream; + + if let Some(waker) = this.early_waker.take() { + waker.wake(); + } + stream.as_mut_pin().poll_write(cx, buf) } _ => stream.as_mut_pin().poll_write(cx, buf), @@ -188,6 +212,10 @@ where } this.state = TlsState::Stream; + + if let Some(waker) = this.early_waker.take() { + waker.wake(); + } } } @@ -195,19 +223,19 @@ where } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[cfg(feature = "early-data")] + { + // complete handshake + if matches!(self.state, TlsState::EarlyData(..)) { + ready!(self.as_mut().poll_flush(cx))?; + } + } + if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); } - #[cfg(feature = "early-data")] - { - // we skip the handshake - if let TlsState::EarlyData(..) = self.state { - return Pin::new(&mut self.io).poll_shutdown(cx); - } - } - let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index fee3cce..eedf463 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -109,6 +109,9 @@ impl TlsConnector { TlsState::Stream }, + #[cfg(feature = "early-data")] + early_waker: None, + session, })) } diff --git a/tokio-rustls/tests/early-data.rs b/tokio-rustls/tests/early-data.rs index 80d6d15..b619f6d 100644 --- a/tokio-rustls/tests/early-data.rs +++ b/tokio-rustls/tests/early-data.rs @@ -10,8 +10,9 @@ use std::process::{Child, Command, Stdio}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; +use tokio::io::{split, AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; +use tokio::sync::oneshot; use tokio::time::sleep; use tokio_rustls::{ client::TlsStream, @@ -26,9 +27,15 @@ impl Future for Read1 { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut buf = [0]; - let mut buf = &mut ReadBuf::new(&mut buf); + let mut buf = ReadBuf::new(&mut buf); + ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?; - Poll::Pending + + if buf.filled().is_empty() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } } } @@ -41,24 +48,48 @@ async fn send( let stream = TcpStream::connect(&addr).await?; let domain = rustls::ServerName::try_from("testserver.com").unwrap(); - let mut stream = connector.connect(domain, stream).await?; - stream.write_all(data).await?; - stream.flush().await?; + let stream = connector.connect(domain, stream).await?; + let (mut rd, mut wd) = split(stream); + let (notify, wait) = oneshot::channel(); - // sleep 1s - // - // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html - let sleep1 = sleep(Duration::from_secs(1)); - futures_util::pin_mut!(sleep1); - let mut stream = match future::select(Read1(stream), sleep1).await { - future::Either::Right((_, Read1(stream))) => stream, - future::Either::Left((Err(err), _)) => return Err(err), - future::Either::Left((Ok(_), _)) => unreachable!(), - }; + let j = tokio::spawn(async move { + // read to eof + // + // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html + let mut read_task = Read1(&mut rd); + let mut notify = Some(notify); - stream.shutdown().await?; + // read once, then write + // + // this is a regression test, see https://github.com/tokio-rs/tls/issues/54 + future::poll_fn(|cx| { + let ret = Pin::new(&mut read_task).poll(cx)?; + assert_eq!(ret, Poll::Pending); - Ok(stream) + notify.take().unwrap().send(()).unwrap(); + + Poll::Ready(Ok(())) as Poll> + }) + .await?; + + match read_task.await { + Ok(()) => (), + Err(ref err) if err.kind() == io::ErrorKind::UnexpectedEof => (), + Err(err) => return Err(err.into()), + } + + Ok(rd) as io::Result<_> + }); + + wait.await.unwrap(); + + wd.write_all(data).await?; + wd.flush().await?; + wd.shutdown().await?; + + let rd: tokio::io::ReadHalf<_> = j.await??; + + Ok(rd.unsplit(wd)) } struct DropKill(Child);