diff --git a/src/client.rs b/src/client.rs index 7807f12..25d5874 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,7 @@ use super::*; use rustls::Session; +use crate::common::IoSession; + /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -10,11 +12,6 @@ pub struct TlsStream { pub(crate) state: TlsState, } -pub(crate) enum MidHandshake { - Handshaking(TlsStream), - End, -} - impl TlsStream { #[inline] pub fn get_ref(&self) -> (&IO, &ClientSession) { @@ -32,36 +29,23 @@ impl TlsStream { } } -impl Future for MidHandshake -where - IO: AsyncRead + AsyncWrite + Unpin, -{ - type Output = io::Result>; +impl IoSession for TlsStream { + type Io = IO; + type Session = ClientSession; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + fn skip_handshake(&self) -> bool { + self.state.is_early_data() + } - if let MidHandshake::Handshaking(stream) = this { - if !stream.state.is_early_data() { - let eof = !stream.state.readable(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session).set_eof(eof); + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } - - while stream.session.wants_write() { - futures::ready!(stream.write_io(cx))?; - } - } - } - - match mem::replace(this, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), - MidHandshake::End => panic!(), - } + #[inline] + fn into_io(self) -> Self::Io { + self.io } } @@ -119,6 +103,7 @@ where match this.state { #[cfg(feature = "early-data")] TlsState::EarlyData(ref mut pos, ref mut data) => { + use futures_core::ready; use std::io::Write; // write early data @@ -137,13 +122,13 @@ where // complete handshake while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + ready!(stream.handshake(cx))?; } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } @@ -162,16 +147,18 @@ where .set_eof(!this.state.readable()); #[cfg(feature = "early-data")] { + use futures_core::ready; + if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { // complete handshake while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + ready!(stream.handshake(cx))?; } // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } diff --git a/src/common/handshake.rs b/src/common/handshake.rs new file mode 100644 index 0000000..0006b56 --- /dev/null +++ b/src/common/handshake.rs @@ -0,0 +1,84 @@ +use std::{ io, mem }; +use std::pin::Pin; +use std::future::Future; +use std::task::{ Context, Poll }; +use futures_core::future::FusedFuture; +use tokio::io::{ AsyncRead, AsyncWrite }; +use rustls::Session; +use crate::common::{ TlsState, Stream }; + + +pub(crate) trait IoSession { + type Io; + type Session; + + fn skip_handshake(&self) -> bool; + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); + fn into_io(self) -> Self::Io; +} + +pub(crate) enum MidHandshake { + Handshaking(IS), + End, +} + +impl FusedFuture for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: Session + Unpin +{ + fn is_terminated(&self) -> bool { + if let MidHandshake::End = self { + true + } else { + false + } + } +} + +impl Future for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: Session + Unpin +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) { + if !stream.skip_handshake() { + let (state, io, session) = stream.get_mut(); + let mut tls_stream = Stream::new(io, session) + .set_eof(!state.readable()); + + macro_rules! try_poll { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), + Poll::Pending => { + *this = MidHandshake::Handshaking(stream); + return Poll::Pending; + } + } + } + } + + while tls_stream.session.is_handshaking() { + try_poll!(tls_stream.handshake(cx)); + } + + while tls_stream.session.wants_write() { + try_poll!(tls_stream.write_io(cx)); + } + } + + Poll::Ready(Ok(stream)) + } else { + panic!() + } + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs index a53f548..9f6d9ac 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,10 +1,12 @@ +mod handshake; + use std::pin::Pin; use std::task::{ Poll, Context }; -use std::marker::Unpin; use std::io::{ self, Read, Write }; use rustls::Session; use tokio::io::{ AsyncRead, AsyncWrite }; use futures_core as futures; +pub(crate) use handshake::{ IoSession, MidHandshake }; #[derive(Debug)] @@ -18,6 +20,7 @@ pub enum TlsState { } impl TlsState { + #[inline] pub fn shutdown_read(&mut self) { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, @@ -25,6 +28,7 @@ impl TlsState { } } + #[inline] pub fn shutdown_write(&mut self) { match *self { TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, @@ -32,6 +36,7 @@ impl TlsState { } } + #[inline] pub fn writeable(&self) -> bool { match *self { TlsState::WriteShutdown | TlsState::FullyShutdown => false, @@ -39,6 +44,7 @@ impl TlsState { } } + #[inline] pub fn readable(&self) -> bool { match self { TlsState::ReadShutdown | TlsState::FullyShutdown => false, @@ -46,6 +52,7 @@ impl TlsState { } } + #[inline] #[cfg(feature = "early-data")] pub fn is_early_data(&self) -> bool { match self { @@ -54,6 +61,7 @@ impl TlsState { } } + #[inline] #[cfg(not(feature = "early-data"))] pub const fn is_early_data(&self) -> bool { false diff --git a/src/lib.rs b/src/lib.rs index 09c10e2..28d9de1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,16 +4,16 @@ mod common; pub mod client; pub mod server; -use std::{ io, mem }; +use std::io; use std::pin::Pin; use std::sync::Arc; use std::future::Future; use std::task::{ Context, Poll }; -use futures_core as futures; +use futures_core::future::FusedFuture; use tokio::io::{ AsyncRead, AsyncWrite }; use webpki::DNSNameRef; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; -use common::{ Stream, TlsState }; +use common::{ Stream, TlsState, MidHandshake }; pub use rustls; pub use webpki; @@ -75,7 +75,7 @@ impl TlsConnector { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - Connect(client::MidHandshake::Handshaking(client::TlsStream { + Connect(MidHandshake::Handshaking(client::TlsStream { io: stream, #[cfg(not(feature = "early-data"))] @@ -110,7 +110,7 @@ impl TlsAcceptor { let mut session = ServerSession::new(&self.inner); f(&mut session); - Accept(server::MidHandshake::Handshaking(server::TlsStream { + Accept(MidHandshake::Handshaking(server::TlsStream { session, io: stream, state: TlsState::Stream, @@ -120,30 +120,99 @@ impl TlsAcceptor { /// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. -pub struct Connect(client::MidHandshake); +pub struct Connect(MidHandshake>); /// Future returned from `TlsAcceptor::accept` which will resolve /// once the accept handshake has finished. -pub struct Accept(server::MidHandshake); +pub struct Accept(MidHandshake>); + +/// Like [Connect], but returns `IO` on failure. +pub struct FailableConnect(MidHandshake>); + +/// Like [Accept], but returns `IO` on failure. +pub struct FailableAccept(MidHandshake>); + +impl Connect { + #[inline] + pub fn into_failable(self) -> FailableConnect { + FailableConnect(self.0) + } +} + +impl Accept { + #[inline] + pub fn into_failable(self) -> FailableAccept { + FailableAccept(self.0) + } +} impl Future for Connect { type Output = io::Result>; #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx) + Pin::new(&mut self.0) + .poll(cx) + .map_err(|(err, _)| err) + } +} + +impl FusedFuture for Connect { + fn is_terminated(&self) -> bool { + self.0.is_terminated() } } impl Future for Accept { type Output = io::Result>; + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0) + .poll(cx) + .map_err(|(err, _)| err) + } +} + +impl FusedFuture for Accept { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Future for FailableConnect { + type Output = Result, (io::Error, IO)>; + #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } +impl FusedFuture for FailableConnect { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Future for FailableAccept { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl FusedFuture for FailableAccept { + #[inline] + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + /// Unified TLS stream type /// /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use diff --git a/src/server.rs b/src/server.rs index 0563341..aa7164e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,6 @@ use super::*; use rustls::Session; +use crate::common::IoSession; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -10,11 +11,6 @@ pub struct TlsStream { pub(crate) state: TlsState, } -pub(crate) enum MidHandshake { - Handshaking(TlsStream), - End, -} - impl TlsStream { #[inline] pub fn get_ref(&self) -> (&IO, &ServerSession) { @@ -32,34 +28,23 @@ impl TlsStream { } } -impl Future for MidHandshake -where - IO: AsyncRead + AsyncWrite + Unpin, -{ - type Output = io::Result>; +impl IoSession for TlsStream { + type Io = IO; + type Session = ServerSession; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + fn skip_handshake(&self) -> bool { + false + } - if let MidHandshake::Handshaking(stream) = this { - let eof = !stream.state.readable(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(io, session).set_eof(eof); + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } - while stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } - - while stream.session.wants_write() { - futures::ready!(stream.write_io(cx))?; - } - } - - match mem::replace(this, MidHandshake::End) { - MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)), - MidHandshake::End => panic!(), - } + #[inline] + fn into_io(self) -> Self::Io { + self.io } }