//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). extern crate rustls; extern crate webpki; #[cfg(feature = "tokio")] mod tokio_impl; #[cfg(feature = "unstable-futures")] mod futures_impl; use std::io; use std::sync::Arc; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig }; /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ClientConfigExt: sealed::Sealed { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write; } /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ServerConfigExt: sealed::Sealed { fn accept_async(&self, stream: S) -> AcceptAsync where S: io::Read + io::Write; } /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. pub struct ConnectAsync(MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. pub struct AcceptAsync(MidHandshake); impl sealed::Sealed for Arc {} impl ClientConfigExt for Arc { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write { connect_async_with_session(stream, ClientSession::new(self, domain)) } } #[inline] pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync where S: io::Read + io::Write { ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) } impl sealed::Sealed for Arc {} impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync where S: io::Read + io::Write { accept_async_with_session(stream, ServerSession::new(self)) } } #[inline] pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync where S: io::Read + io::Write { AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) } struct MidHandshake { inner: Option> } /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. #[derive(Debug)] pub struct TlsStream { is_shutdown: bool, eof: bool, io: S, session: C } impl TlsStream { pub fn get_ref(&self) -> (&S, &C) { (&self.io, &self.session) } pub fn get_mut(&mut self) -> (&mut S, &mut C) { (&mut self.io, &mut self.session) } } impl TlsStream where S: io::Read + io::Write, C: Session { #[inline] fn new(io: S, session: C) -> TlsStream { TlsStream { is_shutdown: false, eof: false, io: io, session: session } } fn do_read(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result { if !*eof && session.wants_read() { if session.read_tls(io)? == 0 { *eof = true; } if let Err(err) = session.process_new_packets() { // flush queued messages before returning an Err in // order to send alerts instead of abruptly closing // the socket if session.wants_write() { // ignore result to avoid masking original error let _ = session.write_tls(io); } return Err(io::Error::new(io::ErrorKind::InvalidData, err)); } Ok(true) } else { Ok(false) } } fn do_write(session: &mut C, io: &mut S) -> io::Result { if session.wants_write() { session.write_tls(io)?; Ok(true) } else { Ok(false) } } #[inline] pub fn do_io(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result<()> { macro_rules! try_wouldblock { ( $r:expr ) => { match $r { Ok(true) => continue, Ok(false) => false, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, Err(e) => return Err(e) } }; } loop { let write_would_block = try_wouldblock!(Self::do_write(session, io)); let read_would_block = try_wouldblock!(Self::do_read(session, io, eof)); if write_would_block || read_would_block { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { return Ok(()); } } } } macro_rules! try_ignore { ( $r:expr ) => { match $r { Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) => return Err(e) } } } impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); loop { match self.session.read(buf) { Ok(0) if !self.eof => while Self::do_read(&mut self.session, &mut self.io, &mut self.eof)? {}, Ok(n) => return Ok(n), Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { try_ignore!(Self::do_read(&mut self.session, &mut self.io, &mut self.eof)); return if self.eof { Ok(0) } else { Err(e) } } else { return Err(e) } } } } } impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); let mut wlen = self.session.write(buf)?; loop { match Self::do_write(&mut self.session, &mut self.io) { Ok(true) => continue, Ok(false) if wlen == 0 => (), Ok(false) => break, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => if wlen == 0 { // Both rustls buffer and IO buffer are blocking. return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { continue }, Err(e) => return Err(e) } assert_eq!(wlen, 0); wlen = self.session.write(buf)?; } Ok(wlen) } fn flush(&mut self) -> io::Result<()> { self.session.flush()?; while Self::do_write(&mut self.session, &mut self.io)? {}; self.io.flush() } } mod sealed { pub trait Sealed {} }