Fix 0-RTT fallback

This commit is contained in:
quininer 2019-11-07 10:57:14 +08:00
parent ff3d0a4de3
commit ba909ed95e
3 changed files with 27 additions and 14 deletions

View File

@ -167,15 +167,27 @@ where
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")] {
if let TlsState::EarlyData = this.state {
let (pos, data) = &mut this.early_data;
// complete handshake // complete handshake
while stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; futures::ready!(stream.handshake(cx))?;
} }
// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}
this.state = TlsState::Stream; this.state = TlsState::Stream;
let (_, data) = &mut this.early_data; let (_, data) = &mut this.early_data;
*data = Vec::new(); *data = Vec::new();
} }
}
stream.as_mut_pin().poll_flush(cx) stream.as_mut_pin().poll_flush(cx)
} }
@ -186,14 +198,16 @@ where
self.state.shutdown_write(); self.state.shutdown_write();
} }
#[cfg(feature = "early-data")] {
// we skip the handshake
if let TlsState::EarlyData = self.state {
return Pin::new(&mut self.io).poll_shutdown(cx);
}
}
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
// TODO
//
// should we complete the handshake?
stream.as_mut_pin().poll_shutdown(cx) stream.as_mut_pin().poll_shutdown(cx)
} }
} }

View File

@ -240,7 +240,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
while self.session.wants_write() { while self.session.wants_write() {
futures::ready!(self.write_io(cx))?; futures::ready!(self.write_io(cx))?;
} }
Pin::new(&mut self.io).poll_shutdown(cx) Pin::new(&mut self.io).poll_shutdown(cx)
} }
} }

View File

@ -127,7 +127,7 @@ impl TlsConnector {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
{ {
Connect(if self.early_data { Connect(if self.early_data && session.early_data().is_some() {
client::MidHandshake::EarlyData(client::TlsStream { client::MidHandshake::EarlyData(client::TlsStream {
session, session,
io: stream, io: stream,