diff --git a/Cargo.toml b/Cargo.toml index 68d12c2..765cdde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,15 +15,15 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = "0.2.0-alpha" -tokio = { version = "0.1", features = [ "unstable-futures" ] } +futures = { version = "0.2.0-alpha", optional = true } +tokio = { version = "0.1", optional = true } rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] -tokio = { version = "0.1", features = [ "unstable-futures" ] } +tokio = "0.1" clap = "2.26" webpki-roots = "0.14" -[patch.crates-io] -tokio = { git = "https://github.com/tokio-rs/tokio" } +[features] +default = [ "futures", "tokio" ] diff --git a/src/futures_impl.rs b/src/futures_impl.rs new file mode 100644 index 0000000..b8b91bf --- /dev/null +++ b/src/futures_impl.rs @@ -0,0 +1,80 @@ +use super::*; +use futures::{ Future, Poll, Async }; +use futures::io::{ Error, AsyncRead, AsyncWrite }; +use futures::task::Context; + + +impl Future for ConnectAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self, ctx: &mut Context) -> Poll { + self.0.poll(ctx) + } +} + +impl Future for AcceptAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self, ctx: &mut Context) -> Poll { + self.0.poll(ctx) + } +} + +impl Future for MidHandshake + where S: io::Read + io::Write, C: Session +{ + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self, _: &mut Context) -> Poll { + loop { + let stream = self.inner.as_mut().unwrap(); + if !stream.session.is_handshaking() { break }; + + match stream.do_io() { + Ok(()) => match (stream.eof, stream.session.is_handshaking()) { + (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (false, true) => continue, + (..) => break + }, + Err(e) => match (e.kind(), stream.session.is_handshaking()) { + (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending), + (io::ErrorKind::WouldBlock, false) => break, + (..) => return Err(e) + } + } + } + + Ok(Async::Ready(self.inner.take().unwrap())) + } +} + +impl AsyncRead for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn poll_read(&mut self, _: &mut Context, buf: &mut [u8]) -> Poll { + unimplemented!() + } +} + +impl AsyncWrite for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn poll_write(&mut self, _: &mut Context, buf: &[u8]) -> Poll { + unimplemented!() + } + + fn poll_flush(&mut self, _: &mut Context) -> Poll<(), Error> { + unimplemented!() + } + + fn poll_close(&mut self, _: &mut Context) -> Poll<(), Error> { + unimplemented!() + } +} diff --git a/src/lib.rs b/src/lib.rs index 33a1dc3..98ff6c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,10 +5,11 @@ extern crate tokio; extern crate rustls; extern crate webpki; +mod tokio_impl; +mod futures_impl; + use std::io; use std::sync::Arc; -use futures::{ Future, Poll, Async }; -use futures::task::Context; use rustls::{ Session, ClientSession, ServerSession, ClientConfig, ServerConfig @@ -77,58 +78,11 @@ pub fn accept_async_with_session(stream: S, session: ServerSession) }) } -impl Future for ConnectAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - -impl Future for AcceptAsync { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, ctx: &mut Context) -> Poll { - self.0.poll(ctx) - } -} - struct MidHandshake { inner: Option> } -impl Future for MidHandshake - where S: io::Read + io::Write, C: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self, _: &mut Context) -> Poll { - loop { - let stream = self.inner.as_mut().unwrap(); - if !stream.session.is_handshaking() { break }; - - match stream.do_io() { - Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), - (false, true) => continue, - (..) => break - }, - Err(e) => match (e.kind(), stream.session.is_handshaking()) { - (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending), - (io::ErrorKind::WouldBlock, false) => break, - (..) => return Err(e) - } - } - } - - Ok(Async::Ready(self.inner.take().unwrap())) - } -} - /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -268,41 +222,3 @@ impl io::Write for TlsStream self.io.flush() } } - - -mod tokio_impl { - use super::*; - use tokio::io::{ AsyncRead, AsyncWrite }; - use tokio::prelude::Poll; - - impl AsyncRead for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session - {} - - impl AsyncWrite for TlsStream - where - S: AsyncRead + AsyncWrite, - C: Session - { - fn shutdown(&mut self) -> Poll<(), io::Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; - } - while self.session.wants_write() { - self.session.write_tls(&mut self.io)?; - } - self.io.flush()?; - self.io.shutdown() - } - } -} - -mod futures_impl { - use super::*; - use futures::io::{ AsyncRead, AsyncWrite }; - - // TODO -} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs new file mode 100644 index 0000000..117b56b --- /dev/null +++ b/src/tokio_impl.rs @@ -0,0 +1,76 @@ +use super::*; +use tokio::prelude::*; +use tokio::io::{ AsyncRead, AsyncWrite }; +use tokio::prelude::Poll; + + +impl Future for ConnectAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + +impl Future for AcceptAsync { + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + +impl Future for MidHandshake + where S: io::Read + io::Write, C: Session +{ + type Item = TlsStream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + let stream = self.inner.as_mut().unwrap(); + if !stream.session.is_handshaking() { break }; + + match stream.do_io() { + Ok(()) => match (stream.eof, stream.session.is_handshaking()) { + (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (false, true) => continue, + (..) => break + }, + Err(e) => match (e.kind(), stream.session.is_handshaking()) { + (io::ErrorKind::WouldBlock, true) => return Ok(Async::NotReady), + (io::ErrorKind::WouldBlock, false) => break, + (..) => return Err(e) + } + } + } + + Ok(Async::Ready(self.inner.take().unwrap())) + } +} + +impl AsyncRead for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{} + +impl AsyncWrite for TlsStream + where + S: AsyncRead + AsyncWrite, + C: Session +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + if !self.is_shutdown { + self.session.send_close_notify(); + self.is_shutdown = true; + } + while self.session.wants_write() { + self.session.write_tls(&mut self.io)?; + } + self.io.flush()?; + self.io.shutdown() + } +} diff --git a/tests/test.rs b/tests/test.rs index 5698ec0..5e37e73 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -9,7 +9,8 @@ use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; -use futures::{ FutureExt, StreamExt }; +use tokio::prelude::*; +// use futures::{ FutureExt, StreamExt }; use tokio::net::{ TcpListener, TcpStream }; use tokio::io as aio; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; @@ -46,12 +47,12 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { .map(drop) .map_err(drop); - tokio::spawn2(done); + tokio::spawn(done); Ok(()) }) .then(|_| Ok(())); - tokio::runtime::run2(done); + tokio::runtime::run(done); }); recv.recv().unwrap()