Add Failable{Connect,Accept}
This commit is contained in:
parent
7f69e889a4
commit
368f32ea9f
@ -1,5 +1,7 @@
|
||||
use super::*;
|
||||
use rustls::Session;
|
||||
use crate::common::IoSession;
|
||||
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
@ -10,11 +12,6 @@ pub struct TlsStream<IO> {
|
||||
pub(crate) state: TlsState,
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IO> {
|
||||
Handshaking(TlsStream<IO>),
|
||||
End,
|
||||
}
|
||||
|
||||
impl<IO> TlsStream<IO> {
|
||||
#[inline]
|
||||
pub fn get_ref(&self) -> (&IO, &ClientSession) {
|
||||
@ -32,36 +29,23 @@ impl<IO> TlsStream<IO> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> Future for MidHandshake<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = io::Result<TlsStream<IO>>;
|
||||
impl<IO> IoSession for TlsStream<IO> {
|
||||
type Io = IO;
|
||||
type Session = ClientSession;
|
||||
|
||||
#[inline]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if let MidHandshake::Handshaking(stream) = this {
|
||||
if !stream.state.is_early_data() {
|
||||
let eof = !stream.state.readable();
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut stream = Stream::new(io, session).set_eof(eof);
|
||||
|
||||
while stream.session.is_handshaking() {
|
||||
futures::ready!(stream.handshake(cx))?;
|
||||
fn skip_handshake(&self) -> bool {
|
||||
self.state.is_early_data()
|
||||
}
|
||||
|
||||
while stream.session.wants_write() {
|
||||
futures::ready!(stream.write_io(cx))?;
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
|
||||
(&mut self.state, &mut self.io, &mut self.session)
|
||||
}
|
||||
|
||||
match mem::replace(this, MidHandshake::End) {
|
||||
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
|
||||
MidHandshake::End => panic!(),
|
||||
}
|
||||
#[inline]
|
||||
fn into_io(self) -> Self::Io {
|
||||
self.io
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,6 +103,7 @@ where
|
||||
match this.state {
|
||||
#[cfg(feature = "early-data")]
|
||||
TlsState::EarlyData(ref mut pos, ref mut data) => {
|
||||
use futures_core::ready;
|
||||
use std::io::Write;
|
||||
|
||||
// write early data
|
||||
@ -137,13 +122,13 @@ where
|
||||
|
||||
// complete handshake
|
||||
while stream.session.is_handshaking() {
|
||||
futures::ready!(stream.handshake(cx))?;
|
||||
ready!(stream.handshake(cx))?;
|
||||
}
|
||||
|
||||
// write early data (fallback)
|
||||
if !stream.session.is_early_data_accepted() {
|
||||
while *pos < data.len() {
|
||||
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
|
||||
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
|
||||
*pos += len;
|
||||
}
|
||||
}
|
||||
@ -162,16 +147,18 @@ where
|
||||
.set_eof(!this.state.readable());
|
||||
|
||||
#[cfg(feature = "early-data")] {
|
||||
use futures_core::ready;
|
||||
|
||||
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
|
||||
// complete handshake
|
||||
while stream.session.is_handshaking() {
|
||||
futures::ready!(stream.handshake(cx))?;
|
||||
ready!(stream.handshake(cx))?;
|
||||
}
|
||||
|
||||
// write early data (fallback)
|
||||
if !stream.session.is_early_data_accepted() {
|
||||
while *pos < data.len() {
|
||||
let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
|
||||
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
|
||||
*pos += len;
|
||||
}
|
||||
}
|
||||
|
84
src/common/handshake.rs
Normal file
84
src/common/handshake.rs
Normal file
@ -0,0 +1,84 @@
|
||||
use std::{ io, mem };
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use std::task::{ Context, Poll };
|
||||
use futures_core::future::FusedFuture;
|
||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
||||
use rustls::Session;
|
||||
use crate::common::{ TlsState, Stream };
|
||||
|
||||
|
||||
pub(crate) trait IoSession {
|
||||
type Io;
|
||||
type Session;
|
||||
|
||||
fn skip_handshake(&self) -> bool;
|
||||
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
|
||||
fn into_io(self) -> Self::Io;
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IS> {
|
||||
Handshaking(IS),
|
||||
End,
|
||||
}
|
||||
|
||||
impl<IS> FusedFuture for MidHandshake<IS>
|
||||
where
|
||||
IS: IoSession + Unpin,
|
||||
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
||||
IS::Session: Session + Unpin
|
||||
{
|
||||
fn is_terminated(&self) -> bool {
|
||||
if let MidHandshake::End = self {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IS> Future for MidHandshake<IS>
|
||||
where
|
||||
IS: IoSession + Unpin,
|
||||
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
||||
IS::Session: Session + Unpin
|
||||
{
|
||||
type Output = Result<IS, (io::Error, IS::Io)>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) {
|
||||
if !stream.skip_handshake() {
|
||||
let (state, io, session) = stream.get_mut();
|
||||
let mut tls_stream = Stream::new(io, session)
|
||||
.set_eof(!state.readable());
|
||||
|
||||
macro_rules! try_poll {
|
||||
( $e:expr ) => {
|
||||
match $e {
|
||||
Poll::Ready(Ok(_)) => (),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
|
||||
Poll::Pending => {
|
||||
*this = MidHandshake::Handshaking(stream);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while tls_stream.session.is_handshaking() {
|
||||
try_poll!(tls_stream.handshake(cx));
|
||||
}
|
||||
|
||||
while tls_stream.session.wants_write() {
|
||||
try_poll!(tls_stream.write_io(cx));
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(stream))
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,10 +1,12 @@
|
||||
mod handshake;
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::task::{ Poll, Context };
|
||||
use std::marker::Unpin;
|
||||
use std::io::{ self, Read, Write };
|
||||
use rustls::Session;
|
||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
||||
use futures_core as futures;
|
||||
pub(crate) use handshake::{ IoSession, MidHandshake };
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -18,6 +20,7 @@ pub enum TlsState {
|
||||
}
|
||||
|
||||
impl TlsState {
|
||||
#[inline]
|
||||
pub fn shutdown_read(&mut self) {
|
||||
match *self {
|
||||
TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
||||
@ -25,6 +28,7 @@ impl TlsState {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn shutdown_write(&mut self) {
|
||||
match *self {
|
||||
TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
|
||||
@ -32,6 +36,7 @@ impl TlsState {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn writeable(&self) -> bool {
|
||||
match *self {
|
||||
TlsState::WriteShutdown | TlsState::FullyShutdown => false,
|
||||
@ -39,6 +44,7 @@ impl TlsState {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn readable(&self) -> bool {
|
||||
match self {
|
||||
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
|
||||
@ -46,6 +52,7 @@ impl TlsState {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[cfg(feature = "early-data")]
|
||||
pub fn is_early_data(&self) -> bool {
|
||||
match self {
|
||||
@ -54,6 +61,7 @@ impl TlsState {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[cfg(not(feature = "early-data"))]
|
||||
pub const fn is_early_data(&self) -> bool {
|
||||
false
|
||||
|
85
src/lib.rs
85
src/lib.rs
@ -4,16 +4,16 @@ mod common;
|
||||
pub mod client;
|
||||
pub mod server;
|
||||
|
||||
use std::{ io, mem };
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::future::Future;
|
||||
use std::task::{ Context, Poll };
|
||||
use futures_core as futures;
|
||||
use futures_core::future::FusedFuture;
|
||||
use tokio::io::{ AsyncRead, AsyncWrite };
|
||||
use webpki::DNSNameRef;
|
||||
use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession, Session };
|
||||
use common::{ Stream, TlsState };
|
||||
use common::{ Stream, TlsState, MidHandshake };
|
||||
|
||||
pub use rustls;
|
||||
pub use webpki;
|
||||
@ -75,7 +75,7 @@ impl TlsConnector {
|
||||
let mut session = ClientSession::new(&self.inner, domain);
|
||||
f(&mut session);
|
||||
|
||||
Connect(client::MidHandshake::Handshaking(client::TlsStream {
|
||||
Connect(MidHandshake::Handshaking(client::TlsStream {
|
||||
io: stream,
|
||||
|
||||
#[cfg(not(feature = "early-data"))]
|
||||
@ -110,7 +110,7 @@ impl TlsAcceptor {
|
||||
let mut session = ServerSession::new(&self.inner);
|
||||
f(&mut session);
|
||||
|
||||
Accept(server::MidHandshake::Handshaking(server::TlsStream {
|
||||
Accept(MidHandshake::Handshaking(server::TlsStream {
|
||||
session,
|
||||
io: stream,
|
||||
state: TlsState::Stream,
|
||||
@ -120,30 +120,99 @@ impl TlsAcceptor {
|
||||
|
||||
/// Future returned from `TlsConnector::connect` which will resolve
|
||||
/// once the connection handshake has finished.
|
||||
pub struct Connect<IO>(client::MidHandshake<IO>);
|
||||
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
|
||||
|
||||
/// Future returned from `TlsAcceptor::accept` which will resolve
|
||||
/// once the accept handshake has finished.
|
||||
pub struct Accept<IO>(server::MidHandshake<IO>);
|
||||
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
|
||||
|
||||
/// Like [Connect], but returns `IO` on failure.
|
||||
pub struct FailableConnect<IO>(MidHandshake<client::TlsStream<IO>>);
|
||||
|
||||
/// Like [Accept], but returns `IO` on failure.
|
||||
pub struct FailableAccept<IO>(MidHandshake<server::TlsStream<IO>>);
|
||||
|
||||
impl<IO> Connect<IO> {
|
||||
#[inline]
|
||||
pub fn into_failable(self) -> FailableConnect<IO> {
|
||||
FailableConnect(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> Accept<IO> {
|
||||
#[inline]
|
||||
pub fn into_failable(self) -> FailableAccept<IO> {
|
||||
FailableAccept(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
|
||||
type Output = io::Result<client::TlsStream<IO>>;
|
||||
|
||||
#[inline]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.0).poll(cx)
|
||||
Pin::new(&mut self.0)
|
||||
.poll(cx)
|
||||
.map_err(|(err, _)| err)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Connect<IO> {
|
||||
fn is_terminated(&self) -> bool {
|
||||
self.0.is_terminated()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
|
||||
type Output = io::Result<server::TlsStream<IO>>;
|
||||
|
||||
#[inline]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.0)
|
||||
.poll(cx)
|
||||
.map_err(|(err, _)| err)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Accept<IO> {
|
||||
#[inline]
|
||||
fn is_terminated(&self) -> bool {
|
||||
self.0.is_terminated()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableConnect<IO> {
|
||||
type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
|
||||
|
||||
#[inline]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.0).poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableConnect<IO> {
|
||||
#[inline]
|
||||
fn is_terminated(&self) -> bool {
|
||||
self.0.is_terminated()
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableAccept<IO> {
|
||||
type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
|
||||
|
||||
#[inline]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.0).poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableAccept<IO> {
|
||||
#[inline]
|
||||
fn is_terminated(&self) -> bool {
|
||||
self.0.is_terminated()
|
||||
}
|
||||
}
|
||||
|
||||
/// Unified TLS stream type
|
||||
///
|
||||
/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
|
||||
|
@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
use rustls::Session;
|
||||
use crate::common::IoSession;
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
@ -10,11 +11,6 @@ pub struct TlsStream<IO> {
|
||||
pub(crate) state: TlsState,
|
||||
}
|
||||
|
||||
pub(crate) enum MidHandshake<IO> {
|
||||
Handshaking(TlsStream<IO>),
|
||||
End,
|
||||
}
|
||||
|
||||
impl<IO> TlsStream<IO> {
|
||||
#[inline]
|
||||
pub fn get_ref(&self) -> (&IO, &ServerSession) {
|
||||
@ -32,34 +28,23 @@ impl<IO> TlsStream<IO> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> Future for MidHandshake<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = io::Result<TlsStream<IO>>;
|
||||
impl<IO> IoSession for TlsStream<IO> {
|
||||
type Io = IO;
|
||||
type Session = ServerSession;
|
||||
|
||||
#[inline]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if let MidHandshake::Handshaking(stream) = this {
|
||||
let eof = !stream.state.readable();
|
||||
let (io, session) = stream.get_mut();
|
||||
let mut stream = Stream::new(io, session).set_eof(eof);
|
||||
|
||||
while stream.session.is_handshaking() {
|
||||
futures::ready!(stream.handshake(cx))?;
|
||||
fn skip_handshake(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
while stream.session.wants_write() {
|
||||
futures::ready!(stream.write_io(cx))?;
|
||||
}
|
||||
#[inline]
|
||||
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
|
||||
(&mut self.state, &mut self.io, &mut self.session)
|
||||
}
|
||||
|
||||
match mem::replace(this, MidHandshake::End) {
|
||||
MidHandshake::Handshaking(stream) => Poll::Ready(Ok(stream)),
|
||||
MidHandshake::End => panic!(),
|
||||
}
|
||||
#[inline]
|
||||
fn into_io(self) -> Self::Io {
|
||||
self.io
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user