extern crate futures; extern crate tokio_core; extern crate rustls; use std::io; use std::sync::Arc; use futures::{ Future, Poll, Async }; use tokio_core::io::Io; use rustls::{ Session, ClientSession, ServerSession }; pub use rustls::{ ClientConfig, ServerConfig }; pub trait TlsConnectorExt { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: Io; } pub trait TlsAcceptorExt { fn accept_async(&self, stream: S) -> AcceptAsync where S: Io; } pub struct ConnectAsync(MidHandshake); pub struct AcceptAsync(MidHandshake); impl TlsConnectorExt for Arc { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync where S: Io { ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, ClientSession::new(self, domain))) }) } } impl TlsAcceptorExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync where S: Io { AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, ServerSession::new(self))) }) } } impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { self.0.poll() } } impl Future for AcceptAsync { type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { self.0.poll() } } struct MidHandshake { inner: Option> } impl Future for MidHandshake where S: Io, C: Session { type Item = TlsStream; type Error = io::Error; fn poll(&mut self) -> Poll { loop { let stream = self.inner.as_mut().unwrap_or_else(|| unreachable!()); if !stream.session.is_handshaking() { break }; match stream.do_io() { Ok(()) => continue, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) => return Err(e) } if !stream.session.is_handshaking() { break }; if stream.eof { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } else { return Ok(Async::NotReady); } } Ok(Async::Ready(self.inner.take().unwrap_or_else(|| unreachable!()))) } } pub struct TlsStream { eof: bool, io: S, session: C } impl TlsStream where S: Io, C: Session { #[inline] pub fn new(io: S, session: C) -> TlsStream { TlsStream { eof: false, io: io, session: session } } pub fn do_io(&mut self) -> io::Result<()> { loop { let read_would_block = match (!self.eof && self.session.wants_read(), self.io.poll_read()) { (true, Async::Ready(())) => { match self.session.read_tls(&mut self.io) { Ok(0) => self.eof = true, Ok(_) => self.session.process_new_packets() .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) => return Err(e) }; continue }, (true, Async::NotReady) => true, (false, _) => false, }; let write_would_block = match (self.session.wants_write(), self.io.poll_write()) { (true, Async::Ready(())) => match self.session.write_tls(&mut self.io) { Ok(_) => continue, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, Err(e) => return Err(e) }, (true, Async::NotReady) => true, (false, _) => false }; if read_would_block || write_would_block { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { return Ok(()); } } } } impl io::Read for TlsStream where S: Io, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.do_io()?; if self.eof { Ok(0) } else { self.session.read(buf) } } } impl io::Write for TlsStream where S: Io, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { let output = self.session.write(buf); while self.session.wants_write() && self.io.poll_write().is_ready() { self.session.write_tls(&mut self.io)?; } output } fn flush(&mut self) -> io::Result<()> { self.session.flush()?; while self.session.wants_write() && self.io.poll_write().is_ready() { self.session.write_tls(&mut self.io)?; } Ok(()) } } impl Io for TlsStream where S: Io, C: Session {}