//! Async TLS streams //! //! [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). #[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; extern crate tokio_core; extern crate rustls; pub mod proto; use std::io; use std::sync::Arc; use futures::{ Future, Poll, Async }; use tokio_core::io::Io; use rustls::{ Session, ClientSession, ServerSession }; use rustls::{ ClientConfig, ServerConfig }; /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ClientConfigExt { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: Io; } /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ServerConfigExt { fn accept_async(&self, stream: S) -> AcceptAsync where S: Io; } /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. pub struct ConnectAsync(MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. pub struct AcceptAsync(MidHandshake); impl ClientConfigExt for Arc { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: Io { ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, ClientSession::new(self, domain))) }) } } impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync where S: Io { AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, ServerSession::new(self))) }) } } impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { self.0.poll() } } impl Future for AcceptAsync { type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { self.0.poll() } } struct MidHandshake { inner: Option> } impl Future for MidHandshake where S: Io, C: Session { type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { loop { let stream = self.inner.as_mut().unwrap_or_else(|| unreachable!()); if !stream.session.is_handshaking() { break }; match stream.do_io() { Ok(()) => match (stream.eof, stream.session.is_handshaking()) { (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), (false, true) => continue, (..) => break }, Err(e) => match (e.kind(), stream.session.is_handshaking()) { (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), (io::ErrorKind::WouldBlock, false) => break, (..) => return Err(e) } } } Ok(Async::Ready(self.inner.take().unwrap_or_else(|| unreachable!()))) } } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { eof: bool, io: S, session: C } impl TlsStream { pub fn get_ref(&self) -> (&S, &C) { (&self.io, &self.session) } pub fn get_mut(&mut self) -> (&mut S, &mut C) { (&mut self.io, &mut self.session) } } impl TlsStream where S: Io, C: Session { #[inline] pub fn new(io: S, session: C) -> TlsStream { TlsStream { eof: false, io: io, session: session } } pub fn do_io(&mut self) -> io::Result<()> { loop { let read_would_block = match (!self.eof && self.session.wants_read(), self.io.poll_read()) { (true, Async::Ready(())) => { match self.session.read_tls(&mut self.io) { Ok(0) => self.eof = true, Ok(_) => self.session.process_new_packets() .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) => return Err(e) }; continue }, (true, Async::NotReady) => true, (false, _) => false, }; let write_would_block = match (self.session.wants_write(), self.io.poll_write()) { (true, Async::Ready(())) => match self.session.write_tls(&mut self.io) { Ok(_) => continue, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, Err(e) => return Err(e) }, (true, Async::NotReady) => true, (false, _) => false }; if read_would_block || write_would_block { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { return Ok(()); } } } } impl io::Read for TlsStream where S: Io, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { loop { match self.session.read(buf) { Ok(0) if !self.eof => self.do_io()?, Ok(n) => return Ok(n), Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { self.do_io()?; return if self.eof { Ok(0) } else { Err(e) } } else { return Err(e) } } } } } impl io::Write for TlsStream where S: Io, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { let output = self.session.write(buf)?; while self.session.wants_write() && self.io.poll_write().is_ready() { match self.session.write_tls(&mut self.io) { Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, Err(e) => return Err(e) } } Ok(output) } fn flush(&mut self) -> io::Result<()> { self.session.flush()?; while self.session.wants_write() { self.session.write_tls(&mut self.io)?; } Ok(()) } } impl Io for TlsStream where S: Io, C: Session {}