//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). extern crate rustls; extern crate webpki; #[cfg(feature = "tokio")] mod tokio_impl; #[cfg(feature = "unstable-futures")] mod futures_impl; use std::io; use std::sync::Arc; use webpki::DNSNameRef; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig, Stream }; /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ClientConfigExt: sealed::Sealed { fn connect_async(&self, domain: DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write; } /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ServerConfigExt: sealed::Sealed { fn accept_async(&self, stream: S) -> AcceptAsync where S: io::Read + io::Write; } /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. pub struct ConnectAsync(MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. pub struct AcceptAsync(MidHandshake); impl sealed::Sealed for Arc {} impl ClientConfigExt for Arc { fn connect_async(&self, domain: DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write { connect_async_with_session(stream, ClientSession::new(self, domain)) } } #[inline] pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync where S: io::Read + io::Write { ConnectAsync(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) }) } impl sealed::Sealed for Arc {} impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync where S: io::Read + io::Write { accept_async_with_session(stream, ServerSession::new(self)) } } #[inline] pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync where S: io::Read + io::Write { AcceptAsync(MidHandshake { inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) }) } struct MidHandshake { inner: Option> } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { is_shutdown: bool, eof: bool, io: S, session: C } impl TlsStream { #[inline] pub fn get_ref(&self) -> (&S, &C) { (&self.io, &self.session) } #[inline] pub fn get_mut(&mut self) -> (&mut S, &mut C) { (&mut self.io, &mut self.session) } } impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { if self.eof { return Ok(0); } // TODO nll let result = { let (io, session) = self.get_mut(); let mut stream = Stream::new(session, io); stream.read(buf) }; match result { Ok(0) => { self.eof = true; Ok(0) }, Ok(n) => Ok(n), Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { self.eof = true; self.is_shutdown = true; self.session.send_close_notify(); Ok(0) }, Err(e) => Err(e) } } } impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { let (io, session) = self.get_mut(); let mut stream = Stream::new(session, io); stream.write(buf) } fn flush(&mut self) -> io::Result<()> { { let (io, session) = self.get_mut(); let mut stream = Stream::new(session, io); stream.flush()?; } self.io.flush() } } mod sealed { pub trait Sealed {} }