diff --git a/Cargo.toml b/Cargo.toml index bf70636..c5f05f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ for nonblocking I/O streams. [dependencies] futures = "0.1" tokio-core = "0.1" +tokio-io = "0.1" rustls = "0.5" tokio-proto = { version = "0.1", optional = true } diff --git a/examples/client.rs b/examples/client.rs index 63db56d..2643aba 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,6 +1,7 @@ extern crate clap; extern crate rustls; extern crate futures; +extern crate tokio_io; extern crate tokio_core; extern crate webpki_roots; extern crate tokio_file_unix; @@ -11,7 +12,7 @@ use std::net::ToSocketAddrs; use std::io::{ BufReader, stdout, stdin }; use std::fs; use futures::Future; -use tokio_core::io::{ self, Io }; +use tokio_io::{ io, AsyncRead }; use tokio_core::net::TcpStream; use tokio_core::reactor::Core; use clap::{ App, Arg }; @@ -71,8 +72,9 @@ fn main() { .and_then(|stream| io::write_all(stream, text.as_bytes())) .and_then(|(stream, _)| { let (r, w) = stream.split(); - io::copy(r, stdout).select(io::copy(stdin, w)) + io::copy(r, stdout) .map(|_| ()) + .select(io::copy(stdin, w).map(|_| ())) .map_err(|(e, _)| e) }); diff --git a/examples/server.rs b/examples/server.rs index 283d016..5d27474 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,6 +1,7 @@ extern crate clap; extern crate rustls; extern crate futures; +extern crate tokio_io; extern crate tokio_core; extern crate webpki_roots; extern crate tokio_rustls; @@ -12,7 +13,7 @@ use std::fs::File; use futures::{ Future, Stream }; use rustls::{ Certificate, PrivateKey, ServerConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; -use tokio_core::io::{ self, Io }; +use tokio_io::{ io, AsyncRead }; use tokio_core::net::TcpListener; use tokio_core::reactor::Core; use clap::{ App, Arg }; @@ -62,7 +63,7 @@ fn main() { let (reader, writer) = stream.split(); io::copy(reader, writer) }) - .map(move |n| println!("Echo: {} - {}", n, addr)) + .map(move |(n, _, _)| println!("Echo: {} - {}", n, addr)) .map_err(move |err| println!("Error: {:?} - {}", err, addr)); handle.spawn(done); diff --git a/src/lib.rs b/src/lib.rs index e6f5856..36b4d09 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,8 @@ //! //! [tokio-tls](https://github.com/tokio-rs/tokio-tls) fork, use [rustls](https://github.com/ctz/rustls). -#[cfg_attr(feature = "tokio-proto", macro_use)] -extern crate futures; +#[cfg_attr(feature = "tokio-proto", macro_use)] extern crate futures; +extern crate tokio_io; extern crate tokio_core; extern crate rustls; @@ -12,7 +12,7 @@ pub mod proto; use std::io; use std::sync::Arc; use futures::{ Future, Poll, Async }; -use tokio_core::io::Io; +use tokio_io::{ AsyncRead, AsyncWrite }; use rustls::{ Session, ClientSession, ServerSession }; use rustls::{ ClientConfig, ServerConfig }; @@ -21,14 +21,14 @@ use rustls::{ ClientConfig, ServerConfig }; pub trait ClientConfigExt { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where S: Io; + where S: AsyncRead + AsyncWrite; } /// Extension trait for the `Arc` type in the `rustls` crate. pub trait ServerConfigExt { fn accept_async(&self, stream: S) -> AcceptAsync - where S: Io; + where S: AsyncRead + AsyncWrite; } @@ -44,7 +44,7 @@ pub struct AcceptAsync(MidHandshake); impl ClientConfigExt for Arc { fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where S: Io + where S: AsyncRead + AsyncWrite { ConnectAsync(MidHandshake { inner: Some(TlsStream::new(stream, ClientSession::new(self, domain))) @@ -55,7 +55,7 @@ impl ClientConfigExt for Arc { impl ServerConfigExt for Arc { fn accept_async(&self, stream: S) -> AcceptAsync - where S: Io + where S: AsyncRead + AsyncWrite { AcceptAsync(MidHandshake { inner: Some(TlsStream::new(stream, ServerSession::new(self))) @@ -63,7 +63,7 @@ impl ServerConfigExt for Arc { } } -impl Future for ConnectAsync { +impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; @@ -72,7 +72,7 @@ impl Future for ConnectAsync { } } -impl Future for AcceptAsync { +impl Future for AcceptAsync { type Item = TlsStream; type Error = io::Error; @@ -87,7 +87,7 @@ struct MidHandshake { } impl Future for MidHandshake - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { type Item = TlsStream; type Error = io::Error; @@ -136,7 +136,7 @@ impl TlsStream { } impl TlsStream - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { #[inline] pub fn new(io: S, session: C) -> TlsStream { @@ -149,29 +149,32 @@ impl TlsStream pub fn do_io(&mut self) -> io::Result<()> { loop { - let read_would_block = match (!self.eof && self.session.wants_read(), self.io.poll_read()) { - (true, Async::Ready(())) => { - match self.session.read_tls(&mut self.io) { - Ok(0) => self.eof = true, - Ok(_) => self.session.process_new_packets() - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e) - }; - continue - }, - (true, Async::NotReady) => true, - (false, _) => false, + let read_would_block = if !self.eof && self.session.wants_read() { + match self.session.read_tls(&mut self.io) { + Ok(0) => { + self.eof = true; + continue + }, + Ok(_) => { + self.session.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + continue + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, + Err(e) => return Err(e) + } + } else { + false }; - let write_would_block = match (self.session.wants_write(), self.io.poll_write()) { - (true, Async::Ready(())) => match self.session.write_tls(&mut self.io) { + let write_would_block = if self.session.wants_write() { + match self.session.write_tls(&mut self.io) { Ok(_) => continue, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, Err(e) => return Err(e) - }, - (true, Async::NotReady) => true, - (false, _) => false + } + } else { + false }; if read_would_block || write_would_block { @@ -184,7 +187,7 @@ impl TlsStream } impl io::Read for TlsStream - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { loop { @@ -203,12 +206,12 @@ impl io::Read for TlsStream } impl io::Write for TlsStream - where S: Io, C: Session + where S: AsyncRead + AsyncWrite, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { let output = self.session.write(buf)?; - while self.session.wants_write() && self.io.poll_write().is_ready() { + while self.session.wants_write() { match self.session.write_tls(&mut self.io) { Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, @@ -228,6 +231,19 @@ impl io::Write for TlsStream } } -impl Io for TlsStream where S: Io, C: Session { - // TODO impl poll_{read, write} +impl AsyncRead for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{} + +impl AsyncWrite for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.session.send_close_notify(); + self.io.shutdown() + } }