Fix early-data wakeup loss (#72)
This commit is contained in:
parent
438cb8f9c8
commit
0bf243566d
@ -42,7 +42,7 @@ See [examples/server](examples/server/src/main.rs). You can run it with:
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
cd examples/server
|
cd examples/server
|
||||||
cargo run -- 127.0.0.1 --cert mycert.der --key mykey.der
|
cargo run -- 127.0.0.1:8000 --cert mycert.der --key mykey.der
|
||||||
```
|
```
|
||||||
|
|
||||||
### License & Origin
|
### License & Origin
|
||||||
|
@ -12,6 +12,9 @@ pub struct TlsStream<IO> {
|
|||||||
pub(crate) io: IO,
|
pub(crate) io: IO,
|
||||||
pub(crate) session: ClientConnection,
|
pub(crate) session: ClientConnection,
|
||||||
pub(crate) state: TlsState,
|
pub(crate) state: TlsState,
|
||||||
|
|
||||||
|
#[cfg(feature = "early-data")]
|
||||||
|
pub(crate) early_waker: Option<std::task::Waker>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<IO> TlsStream<IO> {
|
impl<IO> TlsStream<IO> {
|
||||||
@ -82,7 +85,26 @@ where
|
|||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
match self.state {
|
match self.state {
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
TlsState::EarlyData(..) => Poll::Pending,
|
TlsState::EarlyData(..) => {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
// In the EarlyData state, we have not really established a Tls connection.
|
||||||
|
// Before writing data through `AsyncWrite` and completing the tls handshake,
|
||||||
|
// we ignore read readiness and return to pending.
|
||||||
|
//
|
||||||
|
// In order to avoid event loss,
|
||||||
|
// we need to register a waker and wake it up after tls is connected.
|
||||||
|
if this
|
||||||
|
.early_waker
|
||||||
|
.as_ref()
|
||||||
|
.filter(|waker| cx.waker().will_wake(waker))
|
||||||
|
.is_none()
|
||||||
|
{
|
||||||
|
this.early_waker = Some(cx.waker().clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
TlsState::Stream | TlsState::WriteShutdown => {
|
TlsState::Stream | TlsState::WriteShutdown => {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream =
|
let mut stream =
|
||||||
@ -134,9 +156,6 @@ where
|
|||||||
if let Some(mut early_data) = stream.session.early_data() {
|
if let Some(mut early_data) = stream.session.early_data() {
|
||||||
let len = match early_data.write(buf) {
|
let len = match early_data.write(buf) {
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
|
||||||
return Poll::Pending
|
|
||||||
}
|
|
||||||
Err(err) => return Poll::Ready(Err(err)),
|
Err(err) => return Poll::Ready(Err(err)),
|
||||||
};
|
};
|
||||||
if len != 0 {
|
if len != 0 {
|
||||||
@ -160,6 +179,11 @@ where
|
|||||||
|
|
||||||
// end
|
// end
|
||||||
this.state = TlsState::Stream;
|
this.state = TlsState::Stream;
|
||||||
|
|
||||||
|
if let Some(waker) = this.early_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
|
||||||
stream.as_mut_pin().poll_write(cx, buf)
|
stream.as_mut_pin().poll_write(cx, buf)
|
||||||
}
|
}
|
||||||
_ => stream.as_mut_pin().poll_write(cx, buf),
|
_ => stream.as_mut_pin().poll_write(cx, buf),
|
||||||
@ -188,6 +212,10 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
this.state = TlsState::Stream;
|
this.state = TlsState::Stream;
|
||||||
|
|
||||||
|
if let Some(waker) = this.early_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,19 +223,19 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
#[cfg(feature = "early-data")]
|
||||||
|
{
|
||||||
|
// complete handshake
|
||||||
|
if matches!(self.state, TlsState::EarlyData(..)) {
|
||||||
|
ready!(self.as_mut().poll_flush(cx))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if self.state.writeable() {
|
if self.state.writeable() {
|
||||||
self.session.send_close_notify();
|
self.session.send_close_notify();
|
||||||
self.state.shutdown_write();
|
self.state.shutdown_write();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "early-data")]
|
|
||||||
{
|
|
||||||
// we skip the handshake
|
|
||||||
if let TlsState::EarlyData(..) = self.state {
|
|
||||||
return Pin::new(&mut self.io).poll_shutdown(cx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream =
|
let mut stream =
|
||||||
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
|
||||||
|
@ -109,6 +109,9 @@ impl TlsConnector {
|
|||||||
TlsState::Stream
|
TlsState::Stream
|
||||||
},
|
},
|
||||||
|
|
||||||
|
#[cfg(feature = "early-data")]
|
||||||
|
early_waker: None,
|
||||||
|
|
||||||
session,
|
session,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
@ -10,8 +10,9 @@ use std::process::{Child, Command, Stdio};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf};
|
use tokio::io::{split, AsyncRead, AsyncWriteExt, ReadBuf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tokio_rustls::{
|
use tokio_rustls::{
|
||||||
client::TlsStream,
|
client::TlsStream,
|
||||||
@ -26,10 +27,16 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
|
|||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
let mut buf = [0];
|
let mut buf = [0];
|
||||||
let mut buf = &mut ReadBuf::new(&mut buf);
|
let mut buf = ReadBuf::new(&mut buf);
|
||||||
|
|
||||||
ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?;
|
ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?;
|
||||||
|
|
||||||
|
if buf.filled().is_empty() {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
} else {
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(
|
async fn send(
|
||||||
@ -41,24 +48,48 @@ async fn send(
|
|||||||
let stream = TcpStream::connect(&addr).await?;
|
let stream = TcpStream::connect(&addr).await?;
|
||||||
let domain = rustls::ServerName::try_from("testserver.com").unwrap();
|
let domain = rustls::ServerName::try_from("testserver.com").unwrap();
|
||||||
|
|
||||||
let mut stream = connector.connect(domain, stream).await?;
|
let stream = connector.connect(domain, stream).await?;
|
||||||
stream.write_all(data).await?;
|
let (mut rd, mut wd) = split(stream);
|
||||||
stream.flush().await?;
|
let (notify, wait) = oneshot::channel();
|
||||||
|
|
||||||
// sleep 1s
|
let j = tokio::spawn(async move {
|
||||||
|
// read to eof
|
||||||
//
|
//
|
||||||
// see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html
|
// see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html
|
||||||
let sleep1 = sleep(Duration::from_secs(1));
|
let mut read_task = Read1(&mut rd);
|
||||||
futures_util::pin_mut!(sleep1);
|
let mut notify = Some(notify);
|
||||||
let mut stream = match future::select(Read1(stream), sleep1).await {
|
|
||||||
future::Either::Right((_, Read1(stream))) => stream,
|
|
||||||
future::Either::Left((Err(err), _)) => return Err(err),
|
|
||||||
future::Either::Left((Ok(_), _)) => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
stream.shutdown().await?;
|
// read once, then write
|
||||||
|
//
|
||||||
|
// this is a regression test, see https://github.com/tokio-rs/tls/issues/54
|
||||||
|
future::poll_fn(|cx| {
|
||||||
|
let ret = Pin::new(&mut read_task).poll(cx)?;
|
||||||
|
assert_eq!(ret, Poll::Pending);
|
||||||
|
|
||||||
Ok(stream)
|
notify.take().unwrap().send(()).unwrap();
|
||||||
|
|
||||||
|
Poll::Ready(Ok(())) as Poll<io::Result<_>>
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
match read_task.await {
|
||||||
|
Ok(()) => (),
|
||||||
|
Err(ref err) if err.kind() == io::ErrorKind::UnexpectedEof => (),
|
||||||
|
Err(err) => return Err(err.into()),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(rd) as io::Result<_>
|
||||||
|
});
|
||||||
|
|
||||||
|
wait.await.unwrap();
|
||||||
|
|
||||||
|
wd.write_all(data).await?;
|
||||||
|
wd.flush().await?;
|
||||||
|
wd.shutdown().await?;
|
||||||
|
|
||||||
|
let rd: tokio::io::ReadHalf<_> = j.await??;
|
||||||
|
|
||||||
|
Ok(rd.unsplit(wd))
|
||||||
}
|
}
|
||||||
|
|
||||||
struct DropKill(Child);
|
struct DropKill(Child);
|
||||||
|
Loading…
Reference in New Issue
Block a user