diff --git a/src/common/mod.rs b/src/common/mod.rs index 7db198e..e18aa6e 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -12,9 +12,10 @@ use rustls::WriteV; #[cfg(feature = "tokio-support")] use tokio::io::AsyncWrite; + pub struct Stream<'a, S: 'a, IO: 'a> { - session: &'a mut S, - io: &'a mut IO + pub session: &'a mut S, + pub io: &'a mut IO } pub trait WriteTls<'a, S: Session, IO: Read + Write>: Read + Write { diff --git a/src/lib.rs b/src/lib.rs index 736cad6..ee1524d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,9 +129,11 @@ impl TlsStream { } } -impl From<(IO, S)> for TlsStream { +impl From<(IO, S)> for TlsStream { #[inline] fn from((io, session): (IO, S)) -> TlsStream { + assert!(!session.is_handshaking()); + TlsStream { is_shutdown: false, eof: false, diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 00b4722..644e4f0 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -5,6 +5,17 @@ use tokio::prelude::Poll; use common::Stream; +macro_rules! try_async { + ( $e:expr ) => { + match $e { + Ok(n) => n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => + return Ok(Async::NotReady), + Err(e) => return Err(e) + } + } +} + impl Future for Connect { type Item = TlsStream; type Error = io::Error; @@ -24,7 +35,9 @@ impl Future for Accept { } impl Future for MidHandshake - where IO: io::Read + io::Write, S: Session +where + IO: io::Read + io::Write, + S: Session { type Item = TlsStream; type Error = io::Error; @@ -32,15 +45,15 @@ impl Future for MidHandshake fn poll(&mut self) -> Poll { { let stream = self.inner.as_mut().unwrap(); - if stream.session.is_handshaking() { - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(session, io); + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(session, io); - match stream.complete_io() { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) - } + if stream.session.is_handshaking() { + try_async!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_async!(stream.complete_io()); } } @@ -69,12 +82,10 @@ impl AsyncWrite for TlsStream self.is_shutdown = true; } - match self.session.complete_io(&mut self.io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), - Err(e) => return Err(e) + { + let mut stream = Stream::new(&mut self.session, &mut self.io); + try_async!(stream.complete_io()); } - self.io.shutdown() } }