From 7e4fcca0321e94f23c9eba67809a5847b349078d Mon Sep 17 00:00:00 2001 From: quininer kel Date: Mon, 27 Feb 2017 20:59:35 +0800 Subject: [PATCH] [Improved] MidHandshake/TlsStream - [Improved] README.md - [Improved] MidHandshake poll - [Improved] TlsStream read - [Fixed] TlsStream write, possible of repeat write - [Removed] TlsStream poll_{read, write} --- README.md | 18 ++++++++++++++++++ src/lib.rs | 56 ++++++++++++++++++++++++------------------------------ 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index efd4c48..b521c92 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,21 @@ [![docs.rs](https://docs.rs/tokio-rustls/badge.svg)](https://docs.rs/tokio-rustls/) [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). + +### exmaple + +```rust +// ... + +use rustls::ClientConfig; +use tokio_rustls::ClientConfigExt; + +let mut config = ClientConfig::new(); +config.root_store.add_trust_anchors(&webpki_roots::ROOTS); +let config = Arc::new(config); + +TcpStream::connect(&addr, &handle) + .and_then(|socket| config.connect_async("www.rust-lang.org", socket)) + +// ... +``` diff --git a/src/lib.rs b/src/lib.rs index ae56975..9bd1953 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,12 +98,10 @@ impl Future for MidHandshake if !stream.session.is_handshaking() { break }; match stream.do_io() { - Ok(()) => if stream.eof { - return Err(io::Error::from(io::ErrorKind::UnexpectedEof)) - } else if stream.session.is_handshaking() { - continue - } else { - break + Ok(()) => match (stream.eof, stream.session.is_handshaking()) { + (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (false, true) => continue, + (..) => break }, Err(e) => match (e.kind(), stream.session.is_handshaking()) { (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), @@ -189,11 +187,17 @@ impl io::Read for TlsStream where S: Io, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.do_io()?; - if self.eof { - Ok(0) - } else { - self.session.read(buf) + loop { + match self.session.read(buf) { + Ok(0) if !self.eof => self.do_io()?, + Ok(n) => return Ok(n), + Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { + self.do_io()?; + return if self.eof { Ok(0) } else { Err(e) } + } else { + return Err(e) + } + } } } } @@ -202,11 +206,17 @@ impl io::Write for TlsStream where S: Io, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - let output = self.session.write(buf); + let output = self.session.write(buf)?; + while self.session.wants_write() && self.io.poll_write().is_ready() { - self.session.write_tls(&mut self.io)?; + match self.session.write_tls(&mut self.io) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => return Err(e) + } } - output + + Ok(output) } fn flush(&mut self) -> io::Result<()> { @@ -218,20 +228,4 @@ impl io::Write for TlsStream } } -impl Io for TlsStream where S: Io, C: Session { - fn poll_read(&mut self) -> Async<()> { - if !self.eof && self.session.wants_read() && self.io.poll_read().is_not_ready() { - Async::NotReady - } else { - Async::Ready(()) - } - } - - fn poll_write(&mut self) -> Async<()> { - if self.session.wants_write() && self.io.poll_write().is_not_ready() { - Async::NotReady - } else { - Async::Ready(()) - } - } -} +impl Io for TlsStream where S: Io, C: Session {}