From 262796af396737266d055b76643201a523182cbc Mon Sep 17 00:00:00 2001 From: quininer Date: Thu, 7 Nov 2019 22:52:27 +0800 Subject: [PATCH] Clean TlsState --- Cargo.toml | 1 - src/client.rs | 40 ++++++++++++---------------------- src/lib.rs | 59 ++++++++++++++++++++++++++------------------------- src/server.rs | 2 +- 4 files changed, 45 insertions(+), 57 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8b44e63..f205f6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ edition = "2018" github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } [dependencies] -smallvec = "0.6" tokio-io = "=0.2.0-alpha.6" futures-core-preview = "=0.3.0-alpha.19" pin-project = "0.4" diff --git a/src/client.rs b/src/client.rs index 2c6229b..632e4ed 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,15 +8,10 @@ pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ClientSession, pub(crate) state: TlsState, - - #[cfg(feature = "early-data")] - pub(crate) early_data: (usize, Vec), } pub(crate) enum MidHandshake { Handshaking(TlsStream), - #[cfg(feature = "early-data")] - EarlyData(TlsStream), End, } @@ -48,23 +43,23 @@ where let this = self.get_mut(); if let MidHandshake::Handshaking(stream) = this { - let eof = !stream.state.readable(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session).set_eof(eof); + if !stream.state.is_early_data() { + let eof = !stream.state.readable(); + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session).set_eof(eof); - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } + while stream.session.is_handshaking() { + futures::ready!(stream.handshake(cx))?; + } - while stream.session.wants_write() { - futures::ready!(stream.write_io(cx))?; + while stream.session.wants_write() { + futures::ready!(stream.write_io(cx))?; + } } } match mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), - #[cfg(feature = "early-data")] - MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)), MidHandshake::End => panic!(), } } @@ -81,7 +76,7 @@ where fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => Poll::Pending, + TlsState::EarlyData(..) => Poll::Pending, TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) @@ -122,11 +117,9 @@ where match this.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => { + TlsState::EarlyData(ref mut pos, ref mut data) => { use std::io::Write; - let (pos, data) = &mut this.early_data; - // write early data if let Some(mut early_data) = stream.session.early_data() { let len = match early_data.write(buf) { @@ -154,7 +147,6 @@ where // end this.state = TlsState::Stream; - *data = Vec::new(); stream.as_mut_pin().poll_write(cx, buf) } _ => stream.as_mut_pin().poll_write(cx, buf), @@ -167,9 +159,7 @@ where .set_eof(!this.state.readable()); #[cfg(feature = "early-data")] { - if let TlsState::EarlyData = this.state { - let (pos, data) = &mut this.early_data; - + if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { // complete handshake while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; @@ -184,8 +174,6 @@ where } this.state = TlsState::Stream; - let (_, data) = &mut this.early_data; - *data = Vec::new(); } } @@ -200,7 +188,7 @@ where #[cfg(feature = "early-data")] { // we skip the handshake - if let TlsState::EarlyData = self.state { + if let TlsState::EarlyData(..) = self.state { return Pin::new(&mut self.io).poll_shutdown(cx); } } diff --git a/src/lib.rs b/src/lib.rs index 1c09cef..d5113e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,10 +19,10 @@ use common::Stream; pub use rustls; pub use webpki; -#[derive(Debug, Copy, Clone)] +#[derive(Debug)] enum TlsState { #[cfg(feature = "early-data")] - EarlyData, + EarlyData(usize, Vec), Stream, ReadShutdown, WriteShutdown, @@ -51,12 +51,25 @@ impl TlsState { } } - fn readable(self) -> bool { + fn readable(&self) -> bool { match self { TlsState::ReadShutdown | TlsState::FullyShutdown => false, _ => true, } } + + #[cfg(feature = "early-data")] + fn is_early_data(&self) -> bool { + match self { + TlsState::EarlyData(..) => true, + _ => false + } + } + + #[cfg(not(feature = "early-data"))] + const fn is_early_data(&self) -> bool { + false + } } /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. @@ -100,6 +113,7 @@ impl TlsConnector { self } + #[inline] pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, @@ -107,7 +121,6 @@ impl TlsConnector { self.connect_with(domain, stream, |_| ()) } - #[inline] pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, @@ -116,33 +129,21 @@ impl TlsConnector { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - #[cfg(not(feature = "early-data"))] - { - Connect(client::MidHandshake::Handshaking(client::TlsStream { - session, - io: stream, - state: TlsState::Stream, - })) - } + Connect(client::MidHandshake::Handshaking(client::TlsStream { + io: stream, - #[cfg(feature = "early-data")] - { - Connect(if self.early_data && session.early_data().is_some() { - client::MidHandshake::EarlyData(client::TlsStream { - session, - io: stream, - state: TlsState::EarlyData, - early_data: (0, Vec::new()), - }) + #[cfg(not(feature = "early-data"))] + state: TlsState::Stream, + + #[cfg(feature = "early-data")] + state: if self.early_data && session.early_data().is_some() { + TlsState::EarlyData(0, Vec::new()) } else { - client::MidHandshake::Handshaking(client::TlsStream { - session, - io: stream, - state: TlsState::Stream, - early_data: (0, Vec::new()), - }) - }) - } + TlsState::Stream + }, + + session + })) } } diff --git a/src/server.rs b/src/server.rs index ac72904..87ce3f8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -76,7 +76,7 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - match this.state { + match &this.state { TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read();