Clean TlsState

This commit is contained in:
quininer 2019-11-07 22:52:27 +08:00
parent fe113dc6b0
commit 262796af39
4 changed files with 45 additions and 57 deletions

View File

@ -15,7 +15,6 @@ edition = "2018"
github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" } github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" }
[dependencies] [dependencies]
smallvec = "0.6"
tokio-io = "=0.2.0-alpha.6" tokio-io = "=0.2.0-alpha.6"
futures-core-preview = "=0.3.0-alpha.19" futures-core-preview = "=0.3.0-alpha.19"
pin-project = "0.4" pin-project = "0.4"

View File

@ -8,15 +8,10 @@ pub struct TlsStream<IO> {
pub(crate) io: IO, pub(crate) io: IO,
pub(crate) session: ClientSession, pub(crate) session: ClientSession,
pub(crate) state: TlsState, pub(crate) state: TlsState,
#[cfg(feature = "early-data")]
pub(crate) early_data: (usize, Vec<u8>),
} }
pub(crate) enum MidHandshake<IO> { pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>), Handshaking(TlsStream<IO>),
#[cfg(feature = "early-data")]
EarlyData(TlsStream<IO>),
End, End,
} }
@ -48,6 +43,7 @@ where
let this = self.get_mut(); let this = self.get_mut();
if let MidHandshake::Handshaking(stream) = this { if let MidHandshake::Handshaking(stream) = this {
if !stream.state.is_early_data() {
let eof = !stream.state.readable(); let eof = !stream.state.readable();
let (io, session) = stream.get_mut(); let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session).set_eof(eof); let mut stream = Stream::new(io, session).set_eof(eof);
@ -60,11 +56,10 @@ where
futures::ready!(stream.write_io(cx))?; futures::ready!(stream.write_io(cx))?;
} }
} }
}
match mem::replace(this, MidHandshake::End) { match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
#[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(), MidHandshake::End => panic!(),
} }
} }
@ -81,7 +76,7 @@ where
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
match self.state { match self.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
TlsState::EarlyData => Poll::Pending, TlsState::EarlyData(..) => Poll::Pending,
TlsState::Stream | TlsState::WriteShutdown => { TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
@ -122,11 +117,9 @@ where
match this.state { match this.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
TlsState::EarlyData => { TlsState::EarlyData(ref mut pos, ref mut data) => {
use std::io::Write; use std::io::Write;
let (pos, data) = &mut this.early_data;
// write early data // write early data
if let Some(mut early_data) = stream.session.early_data() { if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) { let len = match early_data.write(buf) {
@ -154,7 +147,6 @@ where
// end // end
this.state = TlsState::Stream; this.state = TlsState::Stream;
*data = Vec::new();
stream.as_mut_pin().poll_write(cx, buf) stream.as_mut_pin().poll_write(cx, buf)
} }
_ => stream.as_mut_pin().poll_write(cx, buf), _ => stream.as_mut_pin().poll_write(cx, buf),
@ -167,9 +159,7 @@ where
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")] {
if let TlsState::EarlyData = this.state { if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
let (pos, data) = &mut this.early_data;
// complete handshake // complete handshake
while stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; futures::ready!(stream.handshake(cx))?;
@ -184,8 +174,6 @@ where
} }
this.state = TlsState::Stream; this.state = TlsState::Stream;
let (_, data) = &mut this.early_data;
*data = Vec::new();
} }
} }
@ -200,7 +188,7 @@ where
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")] {
// we skip the handshake // we skip the handshake
if let TlsState::EarlyData = self.state { if let TlsState::EarlyData(..) = self.state {
return Pin::new(&mut self.io).poll_shutdown(cx); return Pin::new(&mut self.io).poll_shutdown(cx);
} }
} }

View File

@ -19,10 +19,10 @@ use common::Stream;
pub use rustls; pub use rustls;
pub use webpki; pub use webpki;
#[derive(Debug, Copy, Clone)] #[derive(Debug)]
enum TlsState { enum TlsState {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
EarlyData, EarlyData(usize, Vec<u8>),
Stream, Stream,
ReadShutdown, ReadShutdown,
WriteShutdown, WriteShutdown,
@ -51,12 +51,25 @@ impl TlsState {
} }
} }
fn readable(self) -> bool { fn readable(&self) -> bool {
match self { match self {
TlsState::ReadShutdown | TlsState::FullyShutdown => false, TlsState::ReadShutdown | TlsState::FullyShutdown => false,
_ => true, _ => true,
} }
} }
#[cfg(feature = "early-data")]
fn is_early_data(&self) -> bool {
match self {
TlsState::EarlyData(..) => true,
_ => false
}
}
#[cfg(not(feature = "early-data"))]
const fn is_early_data(&self) -> bool {
false
}
} }
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
@ -100,6 +113,7 @@ impl TlsConnector {
self self
} }
#[inline]
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
@ -107,7 +121,6 @@ impl TlsConnector {
self.connect_with(domain, stream, |_| ()) self.connect_with(domain, stream, |_| ())
} }
#[inline]
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO> pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
@ -116,33 +129,21 @@ impl TlsConnector {
let mut session = ClientSession::new(&self.inner, domain); let mut session = ClientSession::new(&self.inner, domain);
f(&mut session); f(&mut session);
#[cfg(not(feature = "early-data"))]
{
Connect(client::MidHandshake::Handshaking(client::TlsStream { Connect(client::MidHandshake::Handshaking(client::TlsStream {
session,
io: stream, io: stream,
#[cfg(not(feature = "early-data"))]
state: TlsState::Stream, state: TlsState::Stream,
}))
}
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
{ state: if self.early_data && session.early_data().is_some() {
Connect(if self.early_data && session.early_data().is_some() { TlsState::EarlyData(0, Vec::new())
client::MidHandshake::EarlyData(client::TlsStream {
session,
io: stream,
state: TlsState::EarlyData,
early_data: (0, Vec::new()),
})
} else { } else {
client::MidHandshake::Handshaking(client::TlsStream { TlsState::Stream
session, },
io: stream,
state: TlsState::Stream, session
early_data: (0, Vec::new()), }))
})
})
}
} }
} }

View File

@ -76,7 +76,7 @@ where
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
match this.state { match &this.state {
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
this.state.shutdown_read(); this.state.shutdown_read();