diff --git a/src/client.rs b/src/client.rs index 9d57268..616c151 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,6 @@ use super::*; -use std::io::Write; use rustls::Session; - +use std::io::Write; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -12,21 +11,14 @@ pub struct TlsStream { pub(crate) state: TlsState, #[cfg(feature = "early-data")] - pub(crate) early_data: (usize, Vec) -} - -#[derive(Debug)] -pub(crate) enum TlsState { - #[cfg(feature = "early-data")] EarlyData, - Stream, - Eof, - Shutdown + pub(crate) early_data: (usize, Vec), } pub(crate) enum MidHandshake { Handshaking(TlsStream), - #[cfg(feature = "early-data")] EarlyData(TlsStream), - End + #[cfg(feature = "early-data")] + EarlyData(TlsStream), + End, } impl TlsStream { @@ -47,7 +39,8 @@ impl TlsStream { } impl Future for MidHandshake -where IO: AsyncRead + AsyncWrite, +where + IO: AsyncRead + AsyncWrite, { type Item = TlsStream; type Error = io::Error; @@ -71,13 +64,14 @@ where IO: AsyncRead + AsyncWrite, MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), #[cfg(feature = "early-data")] MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), - MidHandshake::End => panic!() + MidHandshake::End => panic!(), } } } impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self.state { @@ -106,31 +100,35 @@ where IO: AsyncRead + AsyncWrite } self.read(buf) - }, - TlsState::Stream => { + } + TlsState::Stream | TlsState::WriteShutdown => { let mut stream = Stream::new(&mut self.io, &mut self.session); match stream.read(buf) { Ok(0) => { - self.state = TlsState::Eof; + self.state.shutdown_read(); Ok(0) - }, + } Ok(n) => Ok(n), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); + self.state.shutdown_read(); + if self.state.writeable() { + stream.session.send_close_notify(); + self.state.shutdown_write(); + } Ok(0) - }, - Err(e) => Err(e) + } + Err(e) => Err(e), } - }, - TlsState::Eof | TlsState::Shutdown => Ok(0), + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), } } } impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn write(&mut self, buf: &[u8]) -> io::Result { let mut stream = Stream::new(&mut self.io, &mut self.session); @@ -164,8 +162,8 @@ where IO: AsyncRead + AsyncWrite self.state = TlsState::Stream; data.clear(); stream.write(buf) - }, - _ => stream.write(buf) + } + _ => stream.write(buf), } } @@ -176,7 +174,8 @@ where IO: AsyncRead + AsyncWrite } impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { false @@ -184,15 +183,13 @@ where IO: AsyncRead + AsyncWrite } impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - TlsState::Shutdown => (), - _ => { - self.session.send_close_notify(); - self.state = TlsState::Shutdown; - } + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); } let mut stream = Stream::new(&mut self.io, &mut self.session); diff --git a/src/lib.rs b/src/lib.rs index 6a77fbb..04d7421 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,39 +3,68 @@ pub extern crate rustls; pub extern crate webpki; -extern crate futures; -extern crate tokio_io; extern crate bytes; +extern crate futures; extern crate iovec; +extern crate tokio_io; -mod common; pub mod client; +mod common; pub mod server; -use std::{ io, mem }; -use std::sync::Arc; -use webpki::DNSNameRef; -use rustls::{ - ClientSession, ServerSession, - ClientConfig, ServerConfig -}; -use futures::{Async, Future, Poll}; -use tokio_io::{ AsyncRead, AsyncWrite, try_nb }; use common::Stream; +use futures::{Async, Future, Poll}; +use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; +use std::sync::Arc; +use std::{io, mem}; +use tokio_io::{try_nb, AsyncRead, AsyncWrite}; +use webpki::DNSNameRef; +#[derive(Debug, Copy, Clone)] +pub enum TlsState { + #[cfg(feature = "early-data")] + EarlyData, + Stream, + ReadShutdown, + WriteShutdown, + FullyShutdown, +} + +impl TlsState { + pub(crate) fn shutdown_read(&mut self) { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::ReadShutdown, + } + } + + pub(crate) fn shutdown_write(&mut self) { + match *self { + TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::WriteShutdown, + } + } + + pub(crate) fn writeable(&self) -> bool { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => false, + _ => true, + } + } +} /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] pub struct TlsConnector { inner: Arc, #[cfg(feature = "early-data")] - early_data: bool + early_data: bool, } /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. #[derive(Clone)] pub struct TlsAcceptor { - inner: Arc + inner: Arc, } impl From> for TlsConnector { @@ -43,7 +72,7 @@ impl From> for TlsConnector { TlsConnector { inner, #[cfg(feature = "early-data")] - early_data: false + early_data: false, } } } @@ -66,40 +95,45 @@ impl TlsConnector { } pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect - where IO: AsyncRead + AsyncWrite + where + IO: AsyncRead + AsyncWrite, { self.connect_with(domain, stream, |_| ()) } #[inline] - pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) - -> Connect + pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite, - F: FnOnce(&mut ClientSession) + F: FnOnce(&mut ClientSession), { let mut session = ClientSession::new(&self.inner, domain); f(&mut session); - #[cfg(not(feature = "early-data"))] { + #[cfg(not(feature = "early-data"))] + { Connect(client::MidHandshake::Handshaking(client::TlsStream { - session, io: stream, - state: client::TlsState::Stream, + session, + io: stream, + state: TlsState::Stream, })) } - #[cfg(feature = "early-data")] { + #[cfg(feature = "early-data")] + { Connect(if self.early_data { client::MidHandshake::EarlyData(client::TlsStream { - session, io: stream, - state: client::TlsState::EarlyData, - early_data: (0, Vec::new()) + session, + io: stream, + state: TlsState::EarlyData, + early_data: (0, Vec::new()), }) } else { client::MidHandshake::Handshaking(client::TlsStream { - session, io: stream, - state: client::TlsState::Stream, - early_data: (0, Vec::new()) + session, + io: stream, + state: TlsState::Stream, + early_data: (0, Vec::new()), }) }) } @@ -108,29 +142,29 @@ impl TlsConnector { impl TlsAcceptor { pub fn accept(&self, stream: IO) -> Accept - where IO: AsyncRead + AsyncWrite, + where + IO: AsyncRead + AsyncWrite, { self.accept_with(stream, |_| ()) } #[inline] - pub fn accept_with(&self, stream: IO, f: F) - -> Accept + pub fn accept_with(&self, stream: IO, f: F) -> Accept where IO: AsyncRead + AsyncWrite, - F: FnOnce(&mut ServerSession) + F: FnOnce(&mut ServerSession), { let mut session = ServerSession::new(&self.inner); f(&mut session); Accept(server::MidHandshake::Handshaking(server::TlsStream { - session, io: stream, - state: server::TlsState::Stream, + session, + io: stream, + state: TlsState::Stream, })) } } - /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. pub struct Connect(client::MidHandshake); @@ -139,7 +173,6 @@ pub struct Connect(client::MidHandshake); /// once the accept handshake has finished. pub struct Accept(server::MidHandshake); - impl Future for Connect { type Item = client::TlsStream; type Error = io::Error; diff --git a/src/server.rs b/src/server.rs index 67d47d3..1568414 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,26 +1,18 @@ use super::*; use rustls::Session; - /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { pub(crate) io: IO, pub(crate) session: ServerSession, - pub(crate) state: TlsState -} - -#[derive(Debug)] -pub(crate) enum TlsState { - Stream, - Eof, - Shutdown + pub(crate) state: TlsState, } pub(crate) enum MidHandshake { Handshaking(TlsStream), - End + End, } impl TlsStream { @@ -41,7 +33,8 @@ impl TlsStream { } impl Future for MidHandshake -where IO: AsyncRead + AsyncWrite, +where + IO: AsyncRead + AsyncWrite, { type Item = TlsStream; type Error = io::Error; @@ -63,38 +56,45 @@ where IO: AsyncRead + AsyncWrite, match mem::replace(self, MidHandshake::End) { MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), - MidHandshake::End => panic!() + MidHandshake::End => panic!(), } } } impl io::Read for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn read(&mut self, buf: &mut [u8]) -> io::Result { let mut stream = Stream::new(&mut self.io, &mut self.session); match self.state { - TlsState::Stream => match stream.read(buf) { + TlsState::Stream | TlsState::WriteShutdown => match stream.read(buf) { Ok(0) => { - self.state = TlsState::Eof; + self.state.shutdown_read(); Ok(0) - }, + } Ok(n) => Ok(n), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.state = TlsState::Shutdown; - stream.session.send_close_notify(); + self.state.shutdown_read(); + if self.state.writeable() { + stream.session.send_close_notify(); + self.state.shutdown_write(); + } Ok(0) - }, - Err(e) => Err(e) + } + Err(e) => Err(e), }, - TlsState::Eof | TlsState::Shutdown => Ok(0) + TlsState::ReadShutdown | TlsState::FullyShutdown => Ok(0), + #[cfg(feature = "early-data")] + s => unreachable!("server TLS can not hit this state: {:?}", s), } } } impl io::Write for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { fn write(&mut self, buf: &[u8]) -> io::Result { let mut stream = Stream::new(&mut self.io, &mut self.session); @@ -108,7 +108,8 @@ where IO: AsyncRead + AsyncWrite } impl AsyncRead for TlsStream -where IO: AsyncRead + AsyncWrite +where + IO: AsyncRead + AsyncWrite, { unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { false @@ -116,15 +117,13 @@ where IO: AsyncRead + AsyncWrite } impl AsyncWrite for TlsStream -where IO: AsyncRead + AsyncWrite, +where + IO: AsyncRead + AsyncWrite, { fn shutdown(&mut self) -> Poll<(), io::Error> { - match self.state { - TlsState::Shutdown => (), - _ => { - self.session.send_close_notify(); - self.state = TlsState::Shutdown; - } + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); } let mut stream = Stream::new(&mut self.io, &mut self.session);