From ba909ed95ea7352cba17ce0493d3aff0253e609f Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 7 Nov 2019 10:57:14 +0800 Subject: [PATCH] Fix 0-RTT fallback --- src/client.rs | 38 ++++++++++++++++++++++++++------------ src/common/mod.rs | 1 - src/lib.rs | 2 +- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/src/client.rs b/src/client.rs index a8447bb..2c6229b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -167,14 +167,26 @@ where .set_eof(!this.state.readable()); #[cfg(feature = "early-data")] { - // complete handshake - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } + if let TlsState::EarlyData = this.state { + let (pos, data) = &mut this.early_data; - this.state = TlsState::Stream; - let (_, data) = &mut this.early_data; - *data = Vec::new(); + // complete handshake + while stream.session.is_handshaking() { + futures::ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + this.state = TlsState::Stream; + let (_, data) = &mut this.early_data; + *data = Vec::new(); + } } stream.as_mut_pin().poll_flush(cx) @@ -186,14 +198,16 @@ where self.state.shutdown_write(); } + #[cfg(feature = "early-data")] { + // we skip the handshake + if let TlsState::EarlyData = self.state { + return Pin::new(&mut self.io).poll_shutdown(cx); + } + } + let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - - // TODO - // - // should we complete the handshake? - stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index a870131..3083534 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -240,7 +240,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' while self.session.wants_write() { futures::ready!(self.write_io(cx))?; } - Pin::new(&mut self.io).poll_shutdown(cx) } } diff --git a/src/lib.rs b/src/lib.rs index 31545dc..1c09cef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,7 +127,7 @@ impl TlsConnector { #[cfg(feature = "early-data")] { - Connect(if self.early_data { + Connect(if self.early_data && session.early_data().is_some() { client::MidHandshake::EarlyData(client::TlsStream { session, io: stream,