Fix early-data wakeup loss (#72)

This commit is contained in:
quininer 2021-10-05 16:43:54 +08:00 committed by GitHub
parent 438cb8f9c8
commit 0bf243566d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 93 additions and 31 deletions

View File

@ -42,7 +42,7 @@ See [examples/server](examples/server/src/main.rs). You can run it with:
```sh ```sh
cd examples/server cd examples/server
cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der cargo run -- 127.0.0.1:8000 --cert mycert.der --key mykey.der
``` ```
### License & Origin ### License & Origin

View File

@ -12,6 +12,9 @@ pub struct TlsStream<IO> {
pub(crate) io: IO, pub(crate) io: IO,
pub(crate) session: ClientConnection, pub(crate) session: ClientConnection,
pub(crate) state: TlsState, pub(crate) state: TlsState,
#[cfg(feature = "early-data")]
pub(crate) early_waker: Option<std::task::Waker>,
} }
impl<IO> TlsStream<IO> { impl<IO> TlsStream<IO> {
@ -82,7 +85,26 @@ where
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
match self.state { match self.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
TlsState::EarlyData(..) => Poll::Pending, TlsState::EarlyData(..) => {
let this = self.get_mut();
// In the EarlyData state, we have not really established a Tls connection.
// Before writing data through `AsyncWrite` and completing the tls handshake,
// we ignore read readiness and return to pending.
//
// In order to avoid event loss,
// we need to register a waker and wake it up after tls is connected.
if this
.early_waker
.as_ref()
.filter(|waker| cx.waker().will_wake(waker))
.is_none()
{
this.early_waker = Some(cx.waker().clone());
}
Poll::Pending
}
TlsState::Stream | TlsState::WriteShutdown => { TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = let mut stream =
@ -134,9 +156,6 @@ where
if let Some(mut early_data) = stream.session.early_data() { if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) { let len = match early_data.write(buf) {
Ok(n) => n, Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Poll::Pending
}
Err(err) => return Poll::Ready(Err(err)), Err(err) => return Poll::Ready(Err(err)),
}; };
if len != 0 { if len != 0 {
@ -160,6 +179,11 @@ where
// end // end
this.state = TlsState::Stream; this.state = TlsState::Stream;
if let Some(waker) = this.early_waker.take() {
waker.wake();
}
stream.as_mut_pin().poll_write(cx, buf) stream.as_mut_pin().poll_write(cx, buf)
} }
_ => stream.as_mut_pin().poll_write(cx, buf), _ => stream.as_mut_pin().poll_write(cx, buf),
@ -188,6 +212,10 @@ where
} }
this.state = TlsState::Stream; this.state = TlsState::Stream;
if let Some(waker) = this.early_waker.take() {
waker.wake();
}
} }
} }
@ -195,19 +223,19 @@ where
} }
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
#[cfg(feature = "early-data")]
{
// complete handshake
if matches!(self.state, TlsState::EarlyData(..)) {
ready!(self.as_mut().poll_flush(cx))?;
}
}
if self.state.writeable() { if self.state.writeable() {
self.session.send_close_notify(); self.session.send_close_notify();
self.state.shutdown_write(); self.state.shutdown_write();
} }
#[cfg(feature = "early-data")]
{
// we skip the handshake
if let TlsState::EarlyData(..) = self.state {
return Pin::new(&mut self.io).poll_shutdown(cx);
}
}
let this = self.get_mut(); let this = self.get_mut();
let mut stream = let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

View File

@ -109,6 +109,9 @@ impl TlsConnector {
TlsState::Stream TlsState::Stream
}, },
#[cfg(feature = "early-data")]
early_waker: None,
session, session,
})) }))
} }

View File

@ -10,8 +10,9 @@ use std::process::{Child, Command, Stdio};
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::io::{split, AsyncRead, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::time::sleep; use tokio::time::sleep;
use tokio_rustls::{ use tokio_rustls::{
client::TlsStream, client::TlsStream,
@ -26,10 +27,16 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut buf = [0]; let mut buf = [0];
let mut buf = &mut ReadBuf::new(&mut buf); let mut buf = ReadBuf::new(&mut buf);
ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?; ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?;
if buf.filled().is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending Poll::Pending
} }
}
} }
async fn send( async fn send(
@ -41,24 +48,48 @@ async fn send(
let stream = TcpStream::connect(&addr).await?; let stream = TcpStream::connect(&addr).await?;
let domain = rustls::ServerName::try_from("testserver.com").unwrap(); let domain = rustls::ServerName::try_from("testserver.com").unwrap();
let mut stream = connector.connect(domain, stream).await?; let stream = connector.connect(domain, stream).await?;
stream.write_all(data).await?; let (mut rd, mut wd) = split(stream);
stream.flush().await?; let (notify, wait) = oneshot::channel();
// sleep 1s let j = tokio::spawn(async move {
// read to eof
// //
// see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html
let sleep1 = sleep(Duration::from_secs(1)); let mut read_task = Read1(&mut rd);
futures_util::pin_mut!(sleep1); let mut notify = Some(notify);
let mut stream = match future::select(Read1(stream), sleep1).await {
future::Either::Right((_, Read1(stream))) => stream,
future::Either::Left((Err(err), _)) => return Err(err),
future::Either::Left((Ok(_), _)) => unreachable!(),
};
stream.shutdown().await?; // read once, then write
//
// this is a regression test, see https://github.com/tokio-rs/tls/issues/54
future::poll_fn(|cx| {
let ret = Pin::new(&mut read_task).poll(cx)?;
assert_eq!(ret, Poll::Pending);
Ok(stream) notify.take().unwrap().send(()).unwrap();
Poll::Ready(Ok(())) as Poll<io::Result<_>>
})
.await?;
match read_task.await {
Ok(()) => (),
Err(ref err) if err.kind() == io::ErrorKind::UnexpectedEof => (),
Err(err) => return Err(err.into()),
}
Ok(rd) as io::Result<_>
});
wait.await.unwrap();
wd.write_all(data).await?;
wd.flush().await?;
wd.shutdown().await?;
let rd: tokio::io::ReadHalf<_> = j.await??;
Ok(rd.unsplit(wd))
} }
struct DropKill(Child); struct DropKill(Child);