Clean TlsState
This commit is contained in:
parent
fe113dc6b0
commit
262796af39
@ -15,7 +15,6 @@ edition = "2018"
|
||||
github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" }
|
||||
|
||||
[dependencies]
|
||||
smallvec = "0.6"
|
||||
tokio-io = "=0.2.0-alpha.6"
|
||||
futures-core-preview = "=0.3.0-alpha.19"
|
||||
pin-project = "0.4"
|
||||
|
@ -8,15 +8,10 @@ pub struct TlsStream<IO> {
|
||||
pub(crate) io: IO,
|
||||
pub(crate) session: ClientSession,
|
||||
pub(crate) state: TlsState,
|
||||
|
||||
#[cfg(feature = "early-data")]
|
||||
pub(crate) early_data: (usize, Vec<u8>),
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IO> {
|
||||
Handshaking(TlsStream<IO>),
|
||||
#[cfg(feature = "early-data")]
|
||||
EarlyData(TlsStream<IO>),
|
||||
End,
|
||||
}
|
||||
|
||||
@ -48,6 +43,7 @@ where
|
||||
let this = self.get_mut();
|
||||
|
||||
if let MidHandshake::Handshaking(stream) = this {
|
||||
if !stream.state.is_early_data() {
|
||||
let eof = !stream.state.readable();
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut stream = Stream::new(io, session).set_eof(eof);
|
||||
@ -60,11 +56,10 @@ where
|
||||
futures::ready!(stream.write_io(cx))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match mem::replace(this, MidHandshake::End) {
|
||||
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
|
||||
#[cfg(feature = "early-data")]
|
||||
MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
|
||||
MidHandshake::End => panic!(),
|
||||
}
|
||||
}
|
||||
@ -81,7 +76,7 @@ where
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||
match self.state {
|
||||
#[cfg(feature = "early-data")]
|
||||
TlsState::EarlyData => Poll::Pending,
|
||||
TlsState::EarlyData(..) => Poll::Pending,
|
||||
TlsState::Stream | TlsState::WriteShutdown => {
|
||||
let this = self.get_mut();
|
||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||
@ -122,11 +117,9 @@ where
|
||||
|
||||
match this.state {
|
||||
#[cfg(feature = "early-data")]
|
||||
TlsState::EarlyData => {
|
||||
TlsState::EarlyData(ref mut pos, ref mut data) => {
|
||||
use std::io::Write;
|
||||
|
||||
let (pos, data) = &mut this.early_data;
|
||||
|
||||
// write early data
|
||||
if let Some(mut early_data) = stream.session.early_data() {
|
||||
let len = match early_data.write(buf) {
|
||||
@ -154,7 +147,6 @@ where
|
||||
|
||||
// end
|
||||
this.state = TlsState::Stream;
|
||||
*data = Vec::new();
|
||||
stream.as_mut_pin().poll_write(cx, buf)
|
||||
}
|
||||
_ => stream.as_mut_pin().poll_write(cx, buf),
|
||||
@ -167,9 +159,7 @@ where
|
||||
.set_eof(!this.state.readable());
|
||||
|
||||
#[cfg(feature = "early-data")] {
|
||||
if let TlsState::EarlyData = this.state {
|
||||
let (pos, data) = &mut this.early_data;
|
||||
|
||||
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
|
||||
// complete handshake
|
||||
while stream.session.is_handshaking() {
|
||||
futures::ready!(stream.handshake(cx))?;
|
||||
@ -184,8 +174,6 @@ where
|
||||
}
|
||||
|
||||
this.state = TlsState::Stream;
|
||||
let (_, data) = &mut this.early_data;
|
||||
*data = Vec::new();
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,7 +188,7 @@ where
|
||||
|
||||
#[cfg(feature = "early-data")] {
|
||||
// we skip the handshake
|
||||
if let TlsState::EarlyData = self.state {
|
||||
if let TlsState::EarlyData(..) = self.state {
|
||||
return Pin::new(&mut self.io).poll_shutdown(cx);
|
||||
}
|
||||
}
|
||||
|
51
src/lib.rs
51
src/lib.rs
@ -19,10 +19,10 @@ use common::Stream;
|
||||
pub use rustls;
|
||||
pub use webpki;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[derive(Debug)]
|
||||
enum TlsState {
|
||||
#[cfg(feature = "early-data")]
|
||||
EarlyData,
|
||||
EarlyData(usize, Vec<u8>),
|
||||
Stream,
|
||||
ReadShutdown,
|
||||
WriteShutdown,
|
||||
@ -51,12 +51,25 @@ impl TlsState {
|
||||
}
|
||||
}
|
||||
|
||||
fn readable(self) -> bool {
|
||||
fn readable(&self) -> bool {
|
||||
match self {
|
||||
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "early-data")]
|
||||
fn is_early_data(&self) -> bool {
|
||||
match self {
|
||||
TlsState::EarlyData(..) => true,
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "early-data"))]
|
||||
const fn is_early_data(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
|
||||
@ -100,6 +113,7 @@ impl TlsConnector {
|
||||
self
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
@ -107,7 +121,6 @@ impl TlsConnector {
|
||||
self.connect_with(domain, stream, |_| ())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
@ -116,33 +129,21 @@ impl TlsConnector {
|
||||
let mut session = ClientSession::new(&self.inner, domain);
|
||||
f(&mut session);
|
||||
|
||||
#[cfg(not(feature = "early-data"))]
|
||||
{
|
||||
Connect(client::MidHandshake::Handshaking(client::TlsStream {
|
||||
session,
|
||||
io: stream,
|
||||
|
||||
#[cfg(not(feature = "early-data"))]
|
||||
state: TlsState::Stream,
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(feature = "early-data")]
|
||||
{
|
||||
Connect(if self.early_data && session.early_data().is_some() {
|
||||
client::MidHandshake::EarlyData(client::TlsStream {
|
||||
session,
|
||||
io: stream,
|
||||
state: TlsState::EarlyData,
|
||||
early_data: (0, Vec::new()),
|
||||
})
|
||||
state: if self.early_data && session.early_data().is_some() {
|
||||
TlsState::EarlyData(0, Vec::new())
|
||||
} else {
|
||||
client::MidHandshake::Handshaking(client::TlsStream {
|
||||
session,
|
||||
io: stream,
|
||||
state: TlsState::Stream,
|
||||
early_data: (0, Vec::new()),
|
||||
})
|
||||
})
|
||||
}
|
||||
TlsState::Stream
|
||||
},
|
||||
|
||||
session
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,7 @@ where
|
||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||
.set_eof(!this.state.readable());
|
||||
|
||||
match this.state {
|
||||
match &this.state {
|
||||
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
this.state.shutdown_read();
|
||||
|
Loading…
Reference in New Issue
Block a user