diff --git a/Cargo.toml b/Cargo.toml index 47a9ddd..427015d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,14 +23,15 @@ webpki = "0.18.0-alpha" [dev-dependencies] tokio = "0.1" tokio-io = "0.1" -# tokio-core = "0.1" -# tokio-file-unix = "0.4" +tokio-core = "0.1" +tokio-file-unix = "0.4" clap = "2.26" webpki-roots = "0.14" [features] -unstable-futures = [ "futures", "tokio/unstable-futures" ] -default = [ "unstable-futures", "tokio" ] +default = [ "tokio" ] +# unstable-futures = [ "futures", "tokio/unstable-futures" ] +# default = [ "unstable-futures", "tokio" ] [patch.crates-io] -tokio = { path = "../ref/tokio" } +# tokio = { path = "../ref/tokio" } diff --git a/src/lib.rs b/src/lib.rs index 8167a5c..1d4ab7f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,8 @@ use std::io; use std::sync::Arc; use rustls::{ Session, ClientSession, ServerSession, - ClientConfig, ServerConfig + ClientConfig, ServerConfig, + Stream }; @@ -92,10 +93,12 @@ pub struct TlsStream { } impl TlsStream { + #[inline] pub fn get_ref(&self) -> (&S, &C) { (&self.io, &self.session) } + #[inline] pub fn get_mut(&mut self) -> (&mut S, &mut C) { (&mut self.io, &mut self.session) } @@ -187,19 +190,13 @@ impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); - loop { - match self.session.read(buf) { - Ok(0) if !self.eof => while Self::do_read(&mut self.session, &mut self.io, &mut self.eof)? {}, - Ok(n) => return Ok(n), - Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { - try_ignore!(Self::do_read(&mut self.session, &mut self.io, &mut self.eof)); - return if self.eof { Ok(0) } else { Err(e) } - } else { - return Err(e) - } - } + match stream.read(buf) { + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0), + Err(e) => Err(e) } } } @@ -208,35 +205,19 @@ impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); - let mut wlen = self.session.write(buf)?; - - loop { - match Self::do_write(&mut self.session, &mut self.io) { - Ok(true) => continue, - Ok(false) if wlen == 0 => (), - Ok(false) => break, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - if wlen == 0 { - // Both rustls buffer and IO buffer are blocking. - return Err(io::Error::from(io::ErrorKind::WouldBlock)); - } else { - continue - }, - Err(e) => return Err(e) - } - - assert_eq!(wlen, 0); - wlen = self.session.write(buf)?; - } - - Ok(wlen) + stream.write(buf) } fn flush(&mut self) -> io::Result<()> { - self.session.flush()?; - while Self::do_write(&mut self.session, &mut self.io)? {}; + { + let (io, session) = self.get_mut(); + let mut stream = Stream::new(session, io); + stream.flush()?; + } + self.io.flush() } } diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 294f915..fa4fbe8 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -35,17 +35,10 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match TlsStream::do_io(&mut stream.session, &mut stream.io, &mut stream.eof) { - Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), - (false, true) => continue, - (..) => break - }, - Err(e) => match (e.kind(), stream.session.is_handshaking()) { - (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), - (io::ErrorKind::WouldBlock, false) => break, - (..) => return Err(e) - } + match stream.session.complete_io(&mut stream.io) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), + Err(e) => return Err(e) } } diff --git a/tests/test.rs b/tests/test.rs index c0e2c8f..e231737 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -59,6 +59,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { recv.recv().unwrap() } +#[cfg(feature = "unstable-futures")] fn start_server2(cert: Vec, rsa: PrivateKey) -> SocketAddr { use futures::{ FutureExt, StreamExt }; use futures::io::{ AsyncReadExt, AsyncWriteExt }; @@ -136,16 +137,9 @@ fn start_client2(addr: &SocketAddr, domain: &str, chain: Option