Add Failable{Connect,Accept}

This commit is contained in:
quininer 2019-12-08 16:41:47 +08:00
parent 7f69e889a4
commit 368f32ea9f
5 changed files with 206 additions and 73 deletions

View File

@ -1,5 +1,7 @@
use super::*; use super::*;
use rustls::Session; use rustls::Session;
use crate::common::IoSession;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
@ -10,11 +12,6 @@ pub struct TlsStream<IO> {
pub(crate) state: TlsState, pub(crate) state: TlsState,
} }
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
End,
}
impl<IO> TlsStream<IO> { impl<IO> TlsStream<IO> {
#[inline] #[inline]
pub fn get_ref(&self) -> (&IO, &ClientSession) { pub fn get_ref(&self) -> (&IO, &ClientSession) {
@ -32,36 +29,23 @@ impl<IO> TlsStream<IO> {
} }
} }
impl<IO> Future for MidHandshake<IO> impl<IO> IoSession for TlsStream<IO> {
where type Io = IO;
IO: AsyncRead + AsyncWrite + Unpin, type Session = ClientSession;
{
type Output = io::Result<TlsStream<IO>>;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn skip_handshake(&self) -> bool {
let this = self.get_mut(); self.state.is_early_data()
}
if let MidHandshake::Handshaking(stream) = this { #[inline]
if !stream.state.is_early_data() { fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
let eof = !stream.state.readable(); (&mut self.state, &mut self.io, &mut self.session)
let (io, session) = stream.get_mut(); }
let mut stream = Stream::new(io, session).set_eof(eof);
while stream.session.is_handshaking() { #[inline]
futures::ready!(stream.handshake(cx))?; fn into_io(self) -> Self::Io {
} self.io
while stream.session.wants_write() {
futures::ready!(stream.write_io(cx))?;
}
}
}
match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(),
}
} }
} }
@ -119,6 +103,7 @@ where
match this.state { match this.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
TlsState::EarlyData(ref mut pos, ref mut data) => { TlsState::EarlyData(ref mut pos, ref mut data) => {
use futures_core::ready;
use std::io::Write; use std::io::Write;
// write early data // write early data
@ -137,13 +122,13 @@ where
// complete handshake // complete handshake
while stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; ready!(stream.handshake(cx))?;
} }
// write early data (fallback) // write early data (fallback)
if !stream.session.is_early_data_accepted() { if !stream.session.is_early_data_accepted() {
while *pos < data.len() { while *pos < data.len() {
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len; *pos += len;
} }
} }
@ -162,16 +147,18 @@ where
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
#[cfg(feature = "early-data")] { #[cfg(feature = "early-data")] {
use futures_core::ready;
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
// complete handshake // complete handshake
while stream.session.is_handshaking() { while stream.session.is_handshaking() {
futures::ready!(stream.handshake(cx))?; ready!(stream.handshake(cx))?;
} }
// write early data (fallback) // write early data (fallback)
if !stream.session.is_early_data_accepted() { if !stream.session.is_early_data_accepted() {
while *pos < data.len() { while *pos < data.len() {
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len; *pos += len;
} }
} }

84
src/common/handshake.rs Normal file
View File

@ -0,0 +1,84 @@
use std::{ io, mem };
use std::pin::Pin;
use std::future::Future;
use std::task::{ Context, Poll };
use futures_core::future::FusedFuture;
use tokio::io::{ AsyncRead, AsyncWrite };
use rustls::Session;
use crate::common::{ TlsState, Stream };
pub(crate) trait IoSession {
type Io;
type Session;
fn skip_handshake(&self) -> bool;
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
fn into_io(self) -> Self::Io;
}
pub(crate) enum MidHandshake<IS> {
Handshaking(IS),
End,
}
impl<IS> FusedFuture for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin
{
fn is_terminated(&self) -> bool {
if let MidHandshake::End = self {
true
} else {
false
}
}
}
impl<IS> Future for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin
{
type Output = Result<IS, (io::Error, IS::Io)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) {
if !stream.skip_handshake() {
let (state, io, session) = stream.get_mut();
let mut tls_stream = Stream::new(io, session)
.set_eof(!state.readable());
macro_rules! try_poll {
( $e:expr ) => {
match $e {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
Poll::Pending => {
*this = MidHandshake::Handshaking(stream);
return Poll::Pending;
}
}
}
}
while tls_stream.session.is_handshaking() {
try_poll!(tls_stream.handshake(cx));
}
while tls_stream.session.wants_write() {
try_poll!(tls_stream.write_io(cx));
}
}
Poll::Ready(Ok(stream))
} else {
panic!()
}
}
}

View File

@ -1,10 +1,12 @@
mod handshake;
use std::pin::Pin; use std::pin::Pin;
use std::task::{ Poll, Context }; use std::task::{ Poll, Context };
use std::marker::Unpin;
use std::io::{ self, Read, Write }; use std::io::{ self, Read, Write };
use rustls::Session; use rustls::Session;
use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::io::{ AsyncRead, AsyncWrite };
use futures_core as futures; use futures_core as futures;
pub(crate) use handshake::{ IoSession, MidHandshake };
#[derive(Debug)] #[derive(Debug)]
@ -18,6 +20,7 @@ pub enum TlsState {
} }
impl TlsState { impl TlsState {
#[inline]
pub fn shutdown_read(&mut self) { pub fn shutdown_read(&mut self) {
match *self { match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
@ -25,6 +28,7 @@ impl TlsState {
} }
} }
#[inline]
pub fn shutdown_write(&mut self) { pub fn shutdown_write(&mut self) {
match *self { match *self {
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
@ -32,6 +36,7 @@ impl TlsState {
} }
} }
#[inline]
pub fn writeable(&self) -> bool { pub fn writeable(&self) -> bool {
match *self { match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => false, TlsState::WriteShutdown | TlsState::FullyShutdown => false,
@ -39,6 +44,7 @@ impl TlsState {
} }
} }
#[inline]
pub fn readable(&self) -> bool { pub fn readable(&self) -> bool {
match self { match self {
TlsState::ReadShutdown | TlsState::FullyShutdown => false, TlsState::ReadShutdown | TlsState::FullyShutdown => false,
@ -46,6 +52,7 @@ impl TlsState {
} }
} }
#[inline]
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
pub fn is_early_data(&self) -> bool { pub fn is_early_data(&self) -> bool {
match self { match self {
@ -54,6 +61,7 @@ impl TlsState {
} }
} }
#[inline]
#[cfg(not(feature = "early-data"))] #[cfg(not(feature = "early-data"))]
pub const fn is_early_data(&self) -> bool { pub const fn is_early_data(&self) -> bool {
false false

View File

@ -4,16 +4,16 @@ mod common;
pub mod client; pub mod client;
pub mod server; pub mod server;
use std::{ io, mem }; use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::future::Future; use std::future::Future;
use std::task::{ Context, Poll }; use std::task::{ Context, Poll };
use futures_core as futures; use futures_core::future::FusedFuture;
use tokio::io::{ AsyncRead, AsyncWrite }; use tokio::io::{ AsyncRead, AsyncWrite };
use webpki::DNSNameRef; use webpki::DNSNameRef;
use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session }; use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session };
use common::{ Stream, TlsState }; use common::{ Stream, TlsState, MidHandshake };
pub use rustls; pub use rustls;
pub use webpki; pub use webpki;
@ -75,7 +75,7 @@ impl TlsConnector {
let mut session = ClientSession::new(&self.inner, domain); let mut session = ClientSession::new(&self.inner, domain);
f(&mut session); f(&mut session);
Connect(client::MidHandshake::Handshaking(client::TlsStream { Connect(MidHandshake::Handshaking(client::TlsStream {
io: stream, io: stream,
#[cfg(not(feature = "early-data"))] #[cfg(not(feature = "early-data"))]
@ -110,7 +110,7 @@ impl TlsAcceptor {
let mut session = ServerSession::new(&self.inner); let mut session = ServerSession::new(&self.inner);
f(&mut session); f(&mut session);
Accept(server::MidHandshake::Handshaking(server::TlsStream { Accept(MidHandshake::Handshaking(server::TlsStream {
session, session,
io: stream, io: stream,
state: TlsState::Stream, state: TlsState::Stream,
@ -120,30 +120,99 @@ impl TlsAcceptor {
/// Future returned from `TlsConnector::connect` which will resolve /// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished. /// once the connection handshake has finished.
pub struct Connect<IO>(client::MidHandshake<IO>); pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
/// Future returned from `TlsAcceptor::accept` which will resolve /// Future returned from `TlsAcceptor::accept` which will resolve
/// once the accept handshake has finished. /// once the accept handshake has finished.
pub struct Accept<IO>(server::MidHandshake<IO>); pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
/// Like [Connect], but returns `IO` on failure.
pub struct FailableConnect<IO>(MidHandshake<client::TlsStream<IO>>);
/// Like [Accept], but returns `IO` on failure.
pub struct FailableAccept<IO>(MidHandshake<server::TlsStream<IO>>);
impl<IO> Connect<IO> {
#[inline]
pub fn into_failable(self) -> FailableConnect<IO> {
FailableConnect(self.0)
}
}
impl<IO> Accept<IO> {
#[inline]
pub fn into_failable(self) -> FailableAccept<IO> {
FailableAccept(self.0)
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
type Output = io::Result<client::TlsStream<IO>>; type Output = io::Result<client::TlsStream<IO>>;
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx) Pin::new(&mut self.0)
.poll(cx)
.map_err(|(err, _)| err)
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Connect<IO> {
fn is_terminated(&self) -> bool {
self.0.is_terminated()
} }
} }
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
type Output = io::Result<server::TlsStream<IO>>; type Output = io::Result<server::TlsStream<IO>>;
#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0)
.poll(cx)
.map_err(|(err, _)| err)
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Accept<IO> {
#[inline]
fn is_terminated(&self) -> bool {
self.0.is_terminated()
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableConnect<IO> {
type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx) Pin::new(&mut self.0).poll(cx)
} }
} }
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableConnect<IO> {
#[inline]
fn is_terminated(&self) -> bool {
self.0.is_terminated()
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableAccept<IO> {
type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx)
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableAccept<IO> {
#[inline]
fn is_terminated(&self) -> bool {
self.0.is_terminated()
}
}
/// Unified TLS stream type /// Unified TLS stream type
/// ///
/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use

View File

@ -1,5 +1,6 @@
use super::*; use super::*;
use rustls::Session; use rustls::Session;
use crate::common::IoSession;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
@ -10,11 +11,6 @@ pub struct TlsStream<IO> {
pub(crate) state: TlsState, pub(crate) state: TlsState,
} }
pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
End,
}
impl<IO> TlsStream<IO> { impl<IO> TlsStream<IO> {
#[inline] #[inline]
pub fn get_ref(&self) -> (&IO, &ServerSession) { pub fn get_ref(&self) -> (&IO, &ServerSession) {
@ -32,34 +28,23 @@ impl<IO> TlsStream<IO> {
} }
} }
impl<IO> Future for MidHandshake<IO> impl<IO> IoSession for TlsStream<IO> {
where type Io = IO;
IO: AsyncRead + AsyncWrite + Unpin, type Session = ServerSession;
{
type Output = io::Result<TlsStream<IO>>;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn skip_handshake(&self) -> bool {
let this = self.get_mut(); false
}
if let MidHandshake::Handshaking(stream) = this { #[inline]
let eof = !stream.state.readable(); fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
let (io, session) = stream.get_mut(); (&mut self.state, &mut self.io, &mut self.session)
let mut stream = Stream::new(io, session).set_eof(eof); }
while stream.session.is_handshaking() { #[inline]
futures::ready!(stream.handshake(cx))?; fn into_io(self) -> Self::Io {
} self.io
while stream.session.wants_write() {
futures::ready!(stream.write_io(cx))?;
}
}
match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
MidHandshake::End => panic!(),
}
} }
} }