Clean TlsState
This commit is contained in:
parent
fe113dc6b0
commit
262796af39
@ -15,7 +15,6 @@ edition = "2018"
|
|||||||
github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" }
|
github-actions = { repository = "quininer/tokio-rustls", workflow = "ci" }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
smallvec = "0.6"
|
|
||||||
tokio-io = "=0.2.0-alpha.6"
|
tokio-io = "=0.2.0-alpha.6"
|
||||||
futures-core-preview = "=0.3.0-alpha.19"
|
futures-core-preview = "=0.3.0-alpha.19"
|
||||||
pin-project = "0.4"
|
pin-project = "0.4"
|
||||||
|
@ -8,15 +8,10 @@ pub struct TlsStream<IO> {
|
|||||||
pub(crate) io: IO,
|
pub(crate) io: IO,
|
||||||
pub(crate) session: ClientSession,
|
pub(crate) session: ClientSession,
|
||||||
pub(crate) state: TlsState,
|
pub(crate) state: TlsState,
|
||||||
|
|
||||||
#[cfg(feature = "early-data")]
|
|
||||||
pub(crate) early_data: (usize, Vec<u8>),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) enum MidHandshake<IO> {
|
pub(crate) enum MidHandshake<IO> {
|
||||||
Handshaking(TlsStream<IO>),
|
Handshaking(TlsStream<IO>),
|
||||||
#[cfg(feature = "early-data")]
|
|
||||||
EarlyData(TlsStream<IO>),
|
|
||||||
End,
|
End,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,6 +43,7 @@ where
|
|||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
|
||||||
if let MidHandshake::Handshaking(stream) = this {
|
if let MidHandshake::Handshaking(stream) = this {
|
||||||
|
if !stream.state.is_early_data() {
|
||||||
let eof = !stream.state.readable();
|
let eof = !stream.state.readable();
|
||||||
let (io, session) = stream.get_mut();
|
let (io, session) = stream.get_mut();
|
||||||
let mut stream = Stream::new(io, session).set_eof(eof);
|
let mut stream = Stream::new(io, session).set_eof(eof);
|
||||||
@ -60,11 +56,10 @@ where
|
|||||||
futures::ready!(stream.write_io(cx))?;
|
futures::ready!(stream.write_io(cx))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
match mem::replace(this, MidHandshake::End) {
|
match mem::replace(this, MidHandshake::End) {
|
||||||
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
|
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
|
||||||
#[cfg(feature = "early-data")]
|
|
||||||
MidHandshake::EarlyData(stream) => Poll::Ready(Ok(stream)),
|
|
||||||
MidHandshake::End => panic!(),
|
MidHandshake::End => panic!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -81,7 +76,7 @@ where
|
|||||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||||
match self.state {
|
match self.state {
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
TlsState::EarlyData => Poll::Pending,
|
TlsState::EarlyData(..) => Poll::Pending,
|
||||||
TlsState::Stream | TlsState::WriteShutdown => {
|
TlsState::Stream | TlsState::WriteShutdown => {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||||
@ -122,11 +117,9 @@ where
|
|||||||
|
|
||||||
match this.state {
|
match this.state {
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
TlsState::EarlyData => {
|
TlsState::EarlyData(ref mut pos, ref mut data) => {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
let (pos, data) = &mut this.early_data;
|
|
||||||
|
|
||||||
// write early data
|
// write early data
|
||||||
if let Some(mut early_data) = stream.session.early_data() {
|
if let Some(mut early_data) = stream.session.early_data() {
|
||||||
let len = match early_data.write(buf) {
|
let len = match early_data.write(buf) {
|
||||||
@ -154,7 +147,6 @@ where
|
|||||||
|
|
||||||
// end
|
// end
|
||||||
this.state = TlsState::Stream;
|
this.state = TlsState::Stream;
|
||||||
*data = Vec::new();
|
|
||||||
stream.as_mut_pin().poll_write(cx, buf)
|
stream.as_mut_pin().poll_write(cx, buf)
|
||||||
}
|
}
|
||||||
_ => stream.as_mut_pin().poll_write(cx, buf),
|
_ => stream.as_mut_pin().poll_write(cx, buf),
|
||||||
@ -167,9 +159,7 @@ where
|
|||||||
.set_eof(!this.state.readable());
|
.set_eof(!this.state.readable());
|
||||||
|
|
||||||
#[cfg(feature = "early-data")] {
|
#[cfg(feature = "early-data")] {
|
||||||
if let TlsState::EarlyData = this.state {
|
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
|
||||||
let (pos, data) = &mut this.early_data;
|
|
||||||
|
|
||||||
// complete handshake
|
// complete handshake
|
||||||
while stream.session.is_handshaking() {
|
while stream.session.is_handshaking() {
|
||||||
futures::ready!(stream.handshake(cx))?;
|
futures::ready!(stream.handshake(cx))?;
|
||||||
@ -184,8 +174,6 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
this.state = TlsState::Stream;
|
this.state = TlsState::Stream;
|
||||||
let (_, data) = &mut this.early_data;
|
|
||||||
*data = Vec::new();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,7 +188,7 @@ where
|
|||||||
|
|
||||||
#[cfg(feature = "early-data")] {
|
#[cfg(feature = "early-data")] {
|
||||||
// we skip the handshake
|
// we skip the handshake
|
||||||
if let TlsState::EarlyData = self.state {
|
if let TlsState::EarlyData(..) = self.state {
|
||||||
return Pin::new(&mut self.io).poll_shutdown(cx);
|
return Pin::new(&mut self.io).poll_shutdown(cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
51
src/lib.rs
51
src/lib.rs
@ -19,10 +19,10 @@ use common::Stream;
|
|||||||
pub use rustls;
|
pub use rustls;
|
||||||
pub use webpki;
|
pub use webpki;
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone)]
|
#[derive(Debug)]
|
||||||
enum TlsState {
|
enum TlsState {
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
EarlyData,
|
EarlyData(usize, Vec<u8>),
|
||||||
Stream,
|
Stream,
|
||||||
ReadShutdown,
|
ReadShutdown,
|
||||||
WriteShutdown,
|
WriteShutdown,
|
||||||
@ -51,12 +51,25 @@ impl TlsState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readable(self) -> bool {
|
fn readable(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
|
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
|
||||||
_ => true,
|
_ => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "early-data")]
|
||||||
|
fn is_early_data(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
TlsState::EarlyData(..) => true,
|
||||||
|
_ => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "early-data"))]
|
||||||
|
const fn is_early_data(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
|
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
|
||||||
@ -100,6 +113,7 @@ impl TlsConnector {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
|
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
|
||||||
where
|
where
|
||||||
IO: AsyncRead + AsyncWrite + Unpin,
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
@ -107,7 +121,6 @@ impl TlsConnector {
|
|||||||
self.connect_with(domain, stream, |_| ())
|
self.connect_with(domain, stream, |_| ())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
|
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
|
||||||
where
|
where
|
||||||
IO: AsyncRead + AsyncWrite + Unpin,
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
@ -116,33 +129,21 @@ impl TlsConnector {
|
|||||||
let mut session = ClientSession::new(&self.inner, domain);
|
let mut session = ClientSession::new(&self.inner, domain);
|
||||||
f(&mut session);
|
f(&mut session);
|
||||||
|
|
||||||
#[cfg(not(feature = "early-data"))]
|
|
||||||
{
|
|
||||||
Connect(client::MidHandshake::Handshaking(client::TlsStream {
|
Connect(client::MidHandshake::Handshaking(client::TlsStream {
|
||||||
session,
|
|
||||||
io: stream,
|
io: stream,
|
||||||
|
|
||||||
|
#[cfg(not(feature = "early-data"))]
|
||||||
state: TlsState::Stream,
|
state: TlsState::Stream,
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "early-data")]
|
#[cfg(feature = "early-data")]
|
||||||
{
|
state: if self.early_data && session.early_data().is_some() {
|
||||||
Connect(if self.early_data && session.early_data().is_some() {
|
TlsState::EarlyData(0, Vec::new())
|
||||||
client::MidHandshake::EarlyData(client::TlsStream {
|
|
||||||
session,
|
|
||||||
io: stream,
|
|
||||||
state: TlsState::EarlyData,
|
|
||||||
early_data: (0, Vec::new()),
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
client::MidHandshake::Handshaking(client::TlsStream {
|
TlsState::Stream
|
||||||
session,
|
},
|
||||||
io: stream,
|
|
||||||
state: TlsState::Stream,
|
session
|
||||||
early_data: (0, Vec::new()),
|
}))
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ where
|
|||||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||||
.set_eof(!this.state.readable());
|
.set_eof(!this.state.readable());
|
||||||
|
|
||||||
match this.state {
|
match &this.state {
|
||||||
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) {
|
TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) {
|
||||||
Poll::Ready(Ok(0)) => {
|
Poll::Ready(Ok(0)) => {
|
||||||
this.state.shutdown_read();
|
this.state.shutdown_read();
|
||||||
|
Loading…
Reference in New Issue
Block a user