Clean TlsState

This commit is contained in:
quininer 2019-11-07 22:52:27 +08:00
parent fe113dc6b0
commit 262796af39
4 changed files with 45 additions and 57 deletions

View File

@ -15,7 +15,6 @@ edition = "2018"
github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" }
[dependencies]
smallvec = "0.6"
tokio-io = "=0.2.0-alpha.6"
futures-core-preview = "=0.3.0-alpha.19"
pin-project = "0.4"

View File

@ -8,15 +8,10 @@ pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ClientSession,
pub(crate) state: TlsState,
#[cfg(feature = "early-data")]
pub(crate) early_data: (usize, Vec<u8>),
}
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
#[cfg(feature = "early-data")]
EarlyData(TlsStream<IO>),
End,
}
@ -48,23 +43,23 @@ where
let this = self.get_mut();
if let MidHandshake::Handshaking(stream) = this {
let eof = !stream.state.readable();
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session).set_eof(eof);
if !stream.state.is_early_data() {
let eof = !stream.state.readable();
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session).set_eof(eof);
while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
}
while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
}
while stream.session.wants_write() {
futures::ready!(stream.write_io(cx))?;
while stream.session.wants_write() {
futures::ready!(stream.write_io(cx))?;
}
}
}
match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
#[cfg(feature = "early-data")]
MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(),
}
}
@ -81,7 +76,7 @@ where
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => Poll::Pending,
TlsState::EarlyData(..) => Poll::Pending,
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session)
@ -122,11 +117,9 @@ where
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData => {
TlsState::EarlyData(ref mut pos, ref mut data) => {
use std::io::Write;
let (pos, data) = &mut this.early_data;
// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) {
@ -154,7 +147,6 @@ where
// end
this.state = TlsState::Stream;
*data = Vec::new();
stream.as_mut_pin().poll_write(cx, buf)
}
_ => stream.as_mut_pin().poll_write(cx, buf),
@ -167,9 +159,7 @@ where
.set_eof(!this.state.readable());
#[cfg(feature = "early-data")] {
if let TlsState::EarlyData = this.state {
let (pos, data) = &mut this.early_data;
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
// complete handshake
while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?;
@ -184,8 +174,6 @@ where
}
this.state = TlsState::Stream;
let (_, data) = &mut this.early_data;
*data = Vec::new();
}
}
@ -200,7 +188,7 @@ where
#[cfg(feature = "early-data")] {
// we skip the handshake
if let TlsState::EarlyData = self.state {
if let TlsState::EarlyData(..) = self.state {
return Pin::new(&mut self.io).poll_shutdown(cx);
}
}

View File

@ -19,10 +19,10 @@ use common::Stream;
pub use rustls;
pub use webpki;
#[derive(Debug, Copy, Clone)]
#[derive(Debug)]
enum TlsState {
#[cfg(feature = "early-data")]
EarlyData,
EarlyData(usize, Vec<u8>),
Stream,
ReadShutdown,
WriteShutdown,
@ -51,12 +51,25 @@ impl TlsState {
}
}
fn readable(self) -> bool {
fn readable(&self) -> bool {
match self {
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
_ => true,
}
}
#[cfg(feature = "early-data")]
fn is_early_data(&self) -> bool {
match self {
TlsState::EarlyData(..) => true,
_ => false
}
}
#[cfg(not(feature = "early-data"))]
const fn is_early_data(&self) -> bool {
false
}
}
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
@ -100,6 +113,7 @@ impl TlsConnector {
self
}
#[inline]
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
@ -107,7 +121,6 @@ impl TlsConnector {
self.connect_with(domain, stream, |_| ())
}
#[inline]
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
@ -116,33 +129,21 @@ impl TlsConnector {
let mut session = ClientSession::new(&self.inner, domain);
f(&mut session);
#[cfg(not(feature = "early-data"))]
{
Connect(client::MidHandshake::Handshaking(client::TlsStream {
session,
io: stream,
state: TlsState::Stream,
}))
}
Connect(client::MidHandshake::Handshaking(client::TlsStream {
io: stream,
#[cfg(feature = "early-data")]
{
Connect(if self.early_data && session.early_data().is_some() {
client::MidHandshake::EarlyData(client::TlsStream {
session,
io: stream,
state: TlsState::EarlyData,
early_data: (0, Vec::new()),
})
#[cfg(not(feature = "early-data"))]
state: TlsState::Stream,
#[cfg(feature = "early-data")]
state: if self.early_data && session.early_data().is_some() {
TlsState::EarlyData(0, Vec::new())
} else {
client::MidHandshake::Handshaking(client::TlsStream {
session,
io: stream,
state: TlsState::Stream,
early_data: (0, Vec::new()),
})
})
}
TlsState::Stream
},
session
}))
}
}

View File

@ -76,7 +76,7 @@ where
let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable());
match this.state {
match &this.state {
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();