From 72de25ebce08a9c9421bfac4d23cf23deb8694cf Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 21 Mar 2018 21:44:36 +0800 Subject: [PATCH] change: impl io::{Read,Write} --- Cargo.toml | 2 +- examples/server.rs | 13 ++-- src/futures_impl.rs | 8 ++- src/lib.rs | 160 +++++++++++++++++++++++++------------------- src/tokio_impl.rs | 13 ++-- 5 files changed, 107 insertions(+), 89 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 765cdde..2d42213 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.5.0" +version = "0.6.0-alpha" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/quininer/tokio-rustls" diff --git a/examples/server.rs b/examples/server.rs index a450393..178ff75 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,7 +1,5 @@ extern crate clap; extern crate rustls; -extern crate futures; -extern crate tokio_io; extern crate tokio; extern crate webpki_roots; extern crate tokio_rustls; @@ -10,12 +8,11 @@ use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::BufReader; use std::fs::File; -use futures::{ Future, Stream }; use rustls::{ Certificate, NoClientAuth, PrivateKey, ServerConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_io::{ io, AsyncRead }; +use tokio::prelude::{ Future, Stream }; +use tokio::io::{ self, AsyncRead }; use tokio::net::TcpListener; -use tokio::executor::current_thread; use clap::{ App, Arg }; use tokio_rustls::ServerConfigExt; @@ -64,7 +61,7 @@ fn main() { }) .map(move |(n, ..)| println!("Echo: {} - {:?}", n, addr)) .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); - current_thread::spawn(done); + tokio::spawn(done); Ok(()) } else { @@ -82,10 +79,10 @@ fn main() { .and_then(|(stream, _)| io::flush(stream)) .map(move |_| println!("Accept: {:?}", addr)) .map_err(move |err| println!("Error: {:?} - {:?}", err, addr2)); - current_thread::spawn(done); + tokio::spawn(done); Ok(()) }); - current_thread::run(|_| current_thread::spawn(done.map_err(drop))); + tokio::run(done.map_err(drop)); } diff --git a/src/futures_impl.rs b/src/futures_impl.rs index b8b91bf..e8c1f79 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -1,7 +1,9 @@ +extern crate futures; + use super::*; -use futures::{ Future, Poll, Async }; -use futures::io::{ Error, AsyncRead, AsyncWrite }; -use futures::task::Context; +use self::futures::{ Future, Poll, Async }; +use self::futures::io::{ Error, AsyncRead, AsyncWrite }; +use self::futures::task::Context; impl Future for ConnectAsync { diff --git a/src/lib.rs b/src/lib.rs index 98ff6c0..b7cd5b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,10 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -extern crate futures; -extern crate tokio; extern crate rustls; extern crate webpki; -mod tokio_impl; -mod futures_impl; +#[cfg(feature = "tokio")] mod tokio_impl; +#[cfg(feature = "futures")] mod futures_impl; use std::io; use std::sync::Arc; @@ -17,14 +15,14 @@ use rustls::{ /// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ClientConfigExt { +pub trait ClientConfigExt: sealed::Sealed { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) -> ConnectAsync where S: io::Read + io::Write; } /// Extension trait for the `Arc` type in the `rustls` crate. -pub trait ServerConfigExt { +pub trait ServerConfigExt: sealed::Sealed { fn accept_async(&self, stream: S) -> AcceptAsync where S: io::Read + io::Write; @@ -39,6 +37,7 @@ pub struct ConnectAsync(MidHandshake); /// once the accept handshake has finished. pub struct AcceptAsync(MidHandshake); +impl sealed::Sealed for Arc {} impl ClientConfigExt for Arc { fn connect_async(&self, domain: webpki::DNSNameRef, stream: S) @@ -54,11 +53,11 @@ pub fn connect_async_with_session(stream: S, session: ClientSession) -> ConnectAsync where S: io::Read + io::Write { - ConnectAsync(MidHandshake { - inner: Some(TlsStream::new(stream, session)) - }) + ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) } +impl sealed::Sealed for Arc {} + impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync @@ -73,9 +72,7 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) -> AcceptAsync where S: io::Read + io::Write { - AcceptAsync(MidHandshake { - inner: Some(TlsStream::new(stream, session)) - }) + AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, session)) }) } @@ -104,11 +101,30 @@ impl TlsStream { } } + +macro_rules! try_wouldblock { + ( continue $r:expr ) => { + match $r { + Ok(true) => continue, + Ok(false) => false, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, + Err(e) => return Err(e) + } + }; + ( ignore $r:expr ) => { + match $r { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + } + }; +} + impl TlsStream where S: io::Read + io::Write, C: Session { #[inline] - pub fn new(io: S, session: C) -> TlsStream { + fn new(io: S, session: C) -> TlsStream { TlsStream { is_shutdown: false, eof: false, @@ -117,45 +133,46 @@ impl TlsStream } } + fn do_read(&mut self) -> io::Result { + if !self.eof && self.session.wants_read() { + if self.session.read_tls(&mut self.io)? == 0 { + self.eof = true; + } + + if let Err(err) = self.session.process_new_packets() { + // flush queued messages before returning an Err in + // order to send alerts instead of abruptly closing + // the socket + if self.session.wants_write() { + // ignore result to avoid masking original error + let _ = self.session.write_tls(&mut self.io); + } + return Err(io::Error::new(io::ErrorKind::InvalidData, err)); + } + + Ok(true) + } else { + Ok(false) + } + } + + fn do_write(&mut self) -> io::Result { + if self.session.wants_write() { + self.session.write_tls(&mut self.io)?; + + Ok(true) + } else { + Ok(false) + } + } + + #[inline] pub fn do_io(&mut self) -> io::Result<()> { loop { - let read_would_block = if !self.eof && self.session.wants_read() { - match self.session.read_tls(&mut self.io) { - Ok(0) => { - self.eof = true; - continue - }, - Ok(_) => { - if let Err(err) = self.session.process_new_packets() { - // flush queued messages before returning an Err in - // order to send alerts instead of abruptly closing - // the socket - if self.session.wants_write() { - // ignore result to avoid masking original error - let _ = self.session.write_tls(&mut self.io); - } - return Err(io::Error::new(io::ErrorKind::Other, err)); - } - continue - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - } else { - false - }; + let write_would_block = try_wouldblock!(continue self.do_write()); + let read_would_block = try_wouldblock!(continue self.do_read()); - let write_would_block = if self.session.wants_write() { - match self.session.write_tls(&mut self.io) { - Ok(_) => continue, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - } else { - false - }; - - if read_would_block || write_would_block { + if write_would_block || read_would_block { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { return Ok(()); @@ -168,12 +185,14 @@ impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { + try_wouldblock!(ignore self.do_io()); + loop { match self.session.read(buf) { - Ok(0) if !self.eof => self.do_io()?, + Ok(0) if !self.eof => while self.do_read()? {}, Ok(n) => return Ok(n), Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { - self.do_io()?; + try_wouldblock!(ignore self.do_read()); return if self.eof { Ok(0) } else { Err(e) } } else { return Err(e) @@ -187,38 +206,39 @@ impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - if buf.is_empty() { - return Ok(0); - } + try_wouldblock!(ignore self.do_io()); + + let mut wlen = self.session.write(buf)?; loop { - let output = self.session.write(buf)?; - - while self.session.wants_write() { - match self.session.write_tls(&mut self.io) { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => if output == 0 { + match self.do_write() { + Ok(true) => continue, + Ok(false) if wlen == 0 => (), + Ok(false) => break, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => + if wlen == 0 { // Both rustls buffer and IO buffer are blocking. return Err(io::Error::from(io::ErrorKind::WouldBlock)); } else { - break; + continue }, - Err(e) => return Err(e) - } + Err(e) => return Err(e) } - if output > 0 { - // Already wrote something out. - return Ok(output); - } + assert_eq!(wlen, 0); + wlen = self.session.write(buf)?; } + + Ok(wlen) } fn flush(&mut self) -> io::Result<()> { self.session.flush()?; - while self.session.wants_write() { - self.session.write_tls(&mut self.io)?; - } + while self.do_write()? {}; self.io.flush() } } + +mod sealed { + pub trait Sealed {} +} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index 117b56b..edbda5b 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -1,7 +1,9 @@ +extern crate tokio; + use super::*; -use tokio::prelude::*; -use tokio::io::{ AsyncRead, AsyncWrite }; -use tokio::prelude::Poll; +use self::tokio::prelude::*; +use self::tokio::io::{ AsyncRead, AsyncWrite }; +use self::tokio::prelude::Poll; impl Future for ConnectAsync { @@ -67,10 +69,7 @@ impl AsyncWrite for TlsStream self.session.send_close_notify(); self.is_shutdown = true; } - while self.session.wants_write() { - self.session.write_tls(&mut self.io)?; - } - self.io.flush()?; + while self.do_write()? {}; self.io.shutdown() } }