diff --git a/Cargo.toml b/Cargo.toml index 2d42213..47a9ddd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,8 +22,15 @@ webpki = "0.18.0-alpha" [dev-dependencies] tokio = "0.1" +tokio-io = "0.1" +# tokio-core = "0.1" +# tokio-file-unix = "0.4" clap = "2.26" webpki-roots = "0.14" [features] -default = [ "futures", "tokio" ] +unstable-futures = [ "futures", "tokio/unstable-futures" ] +default = [ "unstable-futures", "tokio" ] + +[patch.crates-io] +tokio = { path = "../ref/tokio" } diff --git a/examples/client.rs b/examples/client.rs index 418a6a4..d454b01 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,7 +1,6 @@ extern crate clap; extern crate rustls; -extern crate futures; -extern crate tokio_io; +extern crate tokio; extern crate tokio_core; extern crate webpki; extern crate webpki_roots; @@ -14,16 +13,16 @@ use std::sync::Arc; use std::net::ToSocketAddrs; use std::io::{ BufReader, stdout, stdin }; use std::fs; -use futures::Future; +use tokio::io; +use tokio::prelude::*; use tokio_core::net::TcpStream; use tokio_core::reactor::Core; -use tokio_io::io; use clap::{ App, Arg }; use rustls::ClientConfig; use tokio_rustls::ClientConfigExt; #[cfg(unix)] -use tokio_io::AsyncRead; +use tokio::io::AsyncRead; #[cfg(unix)] use tokio_file_unix::{ StdFile, File }; diff --git a/src/futures_impl.rs b/src/futures_impl.rs index e8c1f79..0f5cc80 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -6,7 +6,7 @@ use self::futures::io::{ Error, AsyncRead, AsyncWrite }; use self::futures::task::Context; -impl Future for ConnectAsync { +impl Future for ConnectAsync { type Item = TlsStream; type Error = io::Error; @@ -15,7 +15,7 @@ impl Future for ConnectAsync { } } -impl Future for AcceptAsync { +impl Future for AcceptAsync { type Item = TlsStream; type Error = io::Error; @@ -24,20 +24,67 @@ impl Future for AcceptAsync { } } +macro_rules! async { + ( to $r:expr ) => { + match $r { + Ok(Async::Ready(n)) => Ok(n), + Ok(Async::Pending) => Err(io::ErrorKind::WouldBlock.into()), + Err(e) => Err(e) + } + }; + ( from $r:expr ) => { + match $r { + Ok(n) => Ok(Async::Ready(n)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), + Err(e) => Err(e) + } + }; +} + +struct TaskStream<'a, 'b: 'a, S: 'a> { + io: &'a mut S, + task: &'a mut Context<'b> +} + +impl<'a, 'b, S> io::Read for TaskStream<'a, 'b, S> + where S: AsyncRead + AsyncWrite +{ + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + async!(to self.io.poll_read(self.task, buf)) + } +} + +impl<'a, 'b, S> io::Write for TaskStream<'a, 'b, S> + where S: AsyncRead + AsyncWrite +{ + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + async!(to self.io.poll_write(self.task, buf)) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + async!(to self.io.poll_flush(self.task)) + } +} + impl Future for MidHandshake - where S: io::Read + io::Write, C: Session + where S: AsyncRead + AsyncWrite, C: Session { type Item = TlsStream; type Error = io::Error; - fn poll(&mut self, _: &mut Context) -> Poll { + fn poll(&mut self, ctx: &mut Context) -> Poll { loop { let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match stream.do_io() { + let mut taskio = TaskStream { io: &mut stream.io, task: ctx }; + + match TlsStream::do_io(&mut stream.session, &mut taskio, &mut stream.eof) { Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + (true, true) => return Err(io::ErrorKind::UnexpectedEof.into()), (false, true) => continue, (..) => break }, @@ -58,8 +105,10 @@ impl AsyncRead for TlsStream S: AsyncRead + AsyncWrite, C: Session { - fn poll_read(&mut self, _: &mut Context, buf: &mut [u8]) -> Poll { - unimplemented!() + fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + // FIXME TlsStream + TaskStream + async!(from io::Read::read(&mut taskio, buf)) } } @@ -68,15 +117,29 @@ impl AsyncWrite for TlsStream S: AsyncRead + AsyncWrite, C: Session { - fn poll_write(&mut self, _: &mut Context, buf: &[u8]) -> Poll { - unimplemented!() + fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + // FIXME TlsStream + TaskStream + async!(from io::Write::write(&mut taskio, buf)) } - fn poll_flush(&mut self, _: &mut Context) -> Poll<(), Error> { - unimplemented!() + fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + // FIXME TlsStream + TaskStream + async!(from io::Write::flush(&mut taskio)) } - fn poll_close(&mut self, _: &mut Context) -> Poll<(), Error> { - unimplemented!() + fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { + if !self.is_shutdown { + self.session.send_close_notify(); + self.is_shutdown = true; + } + + { + let mut taskio = TaskStream { io: &mut self.io, task: ctx }; + while TlsStream::do_write(&mut self.session, &mut taskio)? {}; + } + + self.io.poll_close(ctx) } } diff --git a/src/lib.rs b/src/lib.rs index b7cd5b0..8167a5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ extern crate rustls; extern crate webpki; #[cfg(feature = "tokio")] mod tokio_impl; -#[cfg(feature = "futures")] mod futures_impl; +#[cfg(feature = "unstable-futures")] mod futures_impl; use std::io; use std::sync::Arc; @@ -101,25 +101,6 @@ impl TlsStream { } } - -macro_rules! try_wouldblock { - ( continue $r:expr ) => { - match $r { - Ok(true) => continue, - Ok(false) => false, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e) - } - }; - ( ignore $r:expr ) => { - match $r { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e) - } - }; -} - impl TlsStream where S: io::Read + io::Write, C: Session { @@ -133,19 +114,19 @@ impl TlsStream } } - fn do_read(&mut self) -> io::Result { - if !self.eof && self.session.wants_read() { - if self.session.read_tls(&mut self.io)? == 0 { - self.eof = true; + fn do_read(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result { + if !*eof && session.wants_read() { + if session.read_tls(io)? == 0 { + *eof = true; } - if let Err(err) = self.session.process_new_packets() { + if let Err(err) = session.process_new_packets() { // flush queued messages before returning an Err in // order to send alerts instead of abruptly closing // the socket - if self.session.wants_write() { + if session.wants_write() { // ignore result to avoid masking original error - let _ = self.session.write_tls(&mut self.io); + let _ = session.write_tls(io); } return Err(io::Error::new(io::ErrorKind::InvalidData, err)); } @@ -156,9 +137,9 @@ impl TlsStream } } - fn do_write(&mut self) -> io::Result { - if self.session.wants_write() { - self.session.write_tls(&mut self.io)?; + fn do_write(session: &mut C, io: &mut S) -> io::Result { + if session.wants_write() { + session.write_tls(io)?; Ok(true) } else { @@ -167,10 +148,21 @@ impl TlsStream } #[inline] - pub fn do_io(&mut self) -> io::Result<()> { + pub fn do_io(session: &mut C, io: &mut S, eof: &mut bool) -> io::Result<()> { + macro_rules! try_wouldblock { + ( $r:expr ) => { + match $r { + Ok(true) => continue, + Ok(false) => false, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, + Err(e) => return Err(e) + } + }; + } + loop { - let write_would_block = try_wouldblock!(continue self.do_write()); - let read_would_block = try_wouldblock!(continue self.do_read()); + let write_would_block = try_wouldblock!(Self::do_write(session, io)); + let read_would_block = try_wouldblock!(Self::do_read(session, io, eof)); if write_would_block || read_would_block { return Err(io::Error::from(io::ErrorKind::WouldBlock)); @@ -181,18 +173,28 @@ impl TlsStream } } +macro_rules! try_ignore { + ( $r:expr ) => { + match $r { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => return Err(e) + } + } +} + impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { fn read(&mut self, buf: &mut [u8]) -> io::Result { - try_wouldblock!(ignore self.do_io()); + try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); loop { match self.session.read(buf) { - Ok(0) if !self.eof => while self.do_read()? {}, + Ok(0) if !self.eof => while Self::do_read(&mut self.session, &mut self.io, &mut self.eof)? {}, Ok(n) => return Ok(n), Err(e) => if e.kind() == io::ErrorKind::ConnectionAborted { - try_wouldblock!(ignore self.do_read()); + try_ignore!(Self::do_read(&mut self.session, &mut self.io, &mut self.eof)); return if self.eof { Ok(0) } else { Err(e) } } else { return Err(e) @@ -206,12 +208,12 @@ impl io::Write for TlsStream where S: io::Read + io::Write, C: Session { fn write(&mut self, buf: &[u8]) -> io::Result { - try_wouldblock!(ignore self.do_io()); + try_ignore!(Self::do_io(&mut self.session, &mut self.io, &mut self.eof)); let mut wlen = self.session.write(buf)?; loop { - match self.do_write() { + match Self::do_write(&mut self.session, &mut self.io) { Ok(true) => continue, Ok(false) if wlen == 0 => (), Ok(false) => break, @@ -234,7 +236,7 @@ impl io::Write for TlsStream fn flush(&mut self) -> io::Result<()> { self.session.flush()?; - while self.do_write()? {}; + while Self::do_write(&mut self.session, &mut self.io)? {}; self.io.flush() } } diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index edbda5b..294f915 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -35,7 +35,7 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match stream.do_io() { + match TlsStream::do_io(&mut stream.session, &mut stream.io, &mut stream.eof) { Ok(()) => match (stream.eof, stream.session.is_handshaking()) { (true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), (false, true) => continue, @@ -69,7 +69,7 @@ impl AsyncWrite for TlsStream self.session.send_close_notify(); self.is_shutdown = true; } - while self.do_write()? {}; + while TlsStream::do_write(&mut self.session, &mut self.io)? {}; self.io.shutdown() } } diff --git a/tests/test.rs b/tests/test.rs index 5e37e73..c0e2c8f 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,18 +1,16 @@ extern crate rustls; -extern crate futures; extern crate tokio; extern crate tokio_rustls; extern crate webpki; +#[cfg(feature = "unstable-futures")] extern crate futures; + use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::{ SocketAddr, IpAddr, Ipv4Addr }; -use tokio::prelude::*; -// use futures::{ FutureExt, StreamExt }; use tokio::net::{ TcpListener, TcpStream }; -use tokio::io as aio; use rustls::{ Certificate, PrivateKey, ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ ClientConfigExt, ServerConfigExt }; @@ -24,6 +22,9 @@ const HELLO_WORLD: &[u8] = b"Hello world!"; fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { + use tokio::prelude::*; + use tokio::io as aio; + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); config.set_single_cert(cert, rsa); let config = Arc::new(config); @@ -45,21 +46,62 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { aio::write_all(stream, HELLO_WORLD) }) .map(drop) - .map_err(drop); + .map_err(|err| panic!("{:?}", err)); tokio::spawn(done); Ok(()) }) - .then(|_| Ok(())); + .map_err(|err| panic!("{:?}", err)); - tokio::runtime::run(done); + tokio::run(done); }); recv.recv().unwrap() } -fn start_client(addr: &SocketAddr, domain: &str, - chain: Option>>) -> io::Result<()> { +fn start_server2(cert: Vec, rsa: PrivateKey) -> SocketAddr { + use futures::{ FutureExt, StreamExt }; + use futures::io::{ AsyncReadExt, AsyncWriteExt }; + + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, rsa); + let config = Arc::new(config); + + let (send, recv) = channel(); + + thread::spawn(move || { + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); + let listener = TcpListener::bind(&addr).unwrap(); + + send.send(listener.local_addr().unwrap()).unwrap(); + + let done = listener.incoming() + .for_each(move |stream| { + let done = config.accept_async(stream) + .and_then(|stream| stream.read_exact(vec![0; HELLO_WORLD.len()])) + .and_then(|(stream, buf)| { + assert_eq!(buf, HELLO_WORLD); + stream.write_all(HELLO_WORLD) + }) + .map(drop) + .map_err(|err| panic!("{:?}", err)); + + tokio::spawn2(done); + Ok(()) + }) + .map(drop) + .map_err(|err| panic!("{:?}", err)); + + tokio::runtime::run2(done); + }); + + recv.recv().unwrap() +} + +fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + use tokio::prelude::*; + use tokio::io as aio; + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let mut config = ClientConfig::new(); if let Some(mut chain) = chain { @@ -79,9 +121,41 @@ fn start_client(addr: &SocketAddr, domain: &str, done.wait() } +#[cfg(feature = "unstable-futures")] +fn start_client2(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { + use futures::FutureExt; + use futures::io::{ AsyncReadExt, AsyncWriteExt }; + use futures::executor::block_on; + + let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); + let mut config = ClientConfig::new(); + if let Some(mut chain) = chain { + config.root_store.add_pem_file(&mut chain).unwrap(); + } + let config = Arc::new(config); + + let done = TcpStream::connect(addr) + .and_then(|stream| config.connect_async(domain, stream)) + .and_then(|stream| { + eprintln!("WRITE: {:?}", stream); + stream.write_all(HELLO_WORLD) + }) + .and_then(|(stream, _)| { + eprintln!("READ: {:?}", stream); + stream.read_exact(vec![0; HELLO_WORLD.len()]) + }) + .and_then(|(stream, buf)| { + eprintln!("OK: {:?}", stream); + assert_eq!(buf, HELLO_WORLD); + Ok(()) + }); + + block_on(done) +} + #[test] -fn main() { +fn pass() { let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let chain = BufReader::new(Cursor::new(CHAIN)); @@ -90,6 +164,17 @@ fn main() { start_client(&addr, "localhost", Some(chain)).unwrap(); } +#[cfg(feature = "unstable-futures")] +#[test] +fn pass2() { + let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let chain = BufReader::new(Cursor::new(CHAIN)); + + let addr = start_server2(cert, keys.pop().unwrap()); + start_client2(&addr, "localhost", Some(chain)).unwrap(); +} + #[should_panic] #[test] fn fail() {