From 18208689299be1ed7d5ace5ed95b33d099165d67 Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 23 Mar 2018 17:31:55 +0800 Subject: [PATCH] fix: futures_impl --- Cargo.toml | 16 +++++++------- src/futures_impl.rs | 52 +++++++++++++++++++++++++-------------------- src/lib.rs | 10 --------- src/tokio_impl.rs | 6 ++++-- tests/test.rs | 47 +++------------------------------------- 5 files changed, 44 insertions(+), 87 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 427015d..89d6b49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,23 +15,23 @@ travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = { version = "0.2.0-alpha", optional = true } +futures = { version = "0.2.0-beta", optional = true } tokio = { version = "0.1", optional = true } rustls = "0.12" webpki = "0.18.0-alpha" [dev-dependencies] tokio = "0.1" -tokio-io = "0.1" -tokio-core = "0.1" -tokio-file-unix = "0.4" +# tokio-io = "0.1" +# tokio-core = "0.1" +# tokio-file-unix = "0.4" clap = "2.26" webpki-roots = "0.14" [features] -default = [ "tokio" ] -# unstable-futures = [ "futures", "tokio/unstable-futures" ] -# default = [ "unstable-futures", "tokio" ] +# default = [ "tokio" ] +default = [ "unstable-futures", "tokio" ] +unstable-futures = [ "futures", "tokio/unstable-futures" ] [patch.crates-io] -# tokio = { path = "../ref/tokio" } +tokio = { path = "../ref/tokio" } diff --git a/src/futures_impl.rs b/src/futures_impl.rs index 0f5cc80..22c637c 100644 --- a/src/futures_impl.rs +++ b/src/futures_impl.rs @@ -80,19 +80,13 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - let mut taskio = TaskStream { io: &mut stream.io, task: ctx }; + let (io, session) = stream.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; - match TlsStream::do_io(&mut stream.session, &mut taskio, &mut stream.eof) { - Ok(()) => match (stream.eof, stream.session.is_handshaking()) { - (true, true) => return Err(io::ErrorKind::UnexpectedEof.into()), - (false, true) => continue, - (..) => break - }, - Err(e) => match (e.kind(), stream.session.is_handshaking()) { - (io::ErrorKind::WouldBlock, true) => return Ok(Async::Pending), - (io::ErrorKind::WouldBlock, false) => break, - (..) => return Err(e) - } + match session.complete_io(&mut taskio) { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::Pending), + Err(e) => return Err(e) } } @@ -106,9 +100,16 @@ impl AsyncRead for TlsStream C: Session { fn poll_read(&mut self, ctx: &mut Context, buf: &mut [u8]) -> Poll { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - // FIXME TlsStream + TaskStream - async!(from io::Read::read(&mut taskio, buf)) + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + + match io::Read::read(&mut stream, buf) { + Ok(n) => Ok(Async::Ready(n)), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(Async::Ready(0)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::Pending), + Err(e) => Err(e) + } } } @@ -118,15 +119,19 @@ impl AsyncWrite for TlsStream C: Session { fn poll_write(&mut self, ctx: &mut Context, buf: &[u8]) -> Poll { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - // FIXME TlsStream + TaskStream - async!(from io::Write::write(&mut taskio, buf)) + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + + async!(from io::Write::write(&mut stream, buf)) } fn poll_flush(&mut self, ctx: &mut Context) -> Poll<(), Error> { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - // FIXME TlsStream + TaskStream - async!(from io::Write::flush(&mut taskio)) + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + let mut stream = Stream::new(session, &mut taskio); + + async!(from io::Write::flush(&mut stream)) } fn poll_close(&mut self, ctx: &mut Context) -> Poll<(), Error> { @@ -136,8 +141,9 @@ impl AsyncWrite for TlsStream } { - let mut taskio = TaskStream { io: &mut self.io, task: ctx }; - while TlsStream::do_write(&mut self.session, &mut taskio)? {}; + let (io, session) = self.get_mut(); + let mut taskio = TaskStream { io, task: ctx }; + session.complete_io(&mut taskio)?; } self.io.poll_close(ctx) diff --git a/src/lib.rs b/src/lib.rs index 1d4ab7f..293a112 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -176,16 +176,6 @@ impl TlsStream } } -macro_rules! try_ignore { - ( $r:expr ) => { - match $r { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Err(e) => return Err(e) - } - } -} - impl io::Read for TlsStream where S: io::Read + io::Write, C: Session { diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs index fa4fbe8..f5d3c6c 100644 --- a/src/tokio_impl.rs +++ b/src/tokio_impl.rs @@ -35,7 +35,9 @@ impl Future for MidHandshake let stream = self.inner.as_mut().unwrap(); if !stream.session.is_handshaking() { break }; - match stream.session.complete_io(&mut stream.io) { + let (io, session) = stream.get_mut(); + + match session.complete_io(io) { Ok(_) => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Async::NotReady), Err(e) => return Err(e) @@ -62,7 +64,7 @@ impl AsyncWrite for TlsStream self.session.send_close_notify(); self.is_shutdown = true; } - while TlsStream::do_write(&mut self.session, &mut self.io)? {}; + self.session.complete_io(&mut self.io)?; self.io.shutdown() } } diff --git a/tests/test.rs b/tests/test.rs index e231737..92c904a 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -45,8 +45,7 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { assert_eq!(buf, HELLO_WORLD); aio::write_all(stream, HELLO_WORLD) }) - .map(drop) - .map_err(|err| panic!("{:?}", err)); + .then(|_| Ok(())); tokio::spawn(done); Ok(()) @@ -59,46 +58,6 @@ fn start_server(cert: Vec, rsa: PrivateKey) -> SocketAddr { recv.recv().unwrap() } -#[cfg(feature = "unstable-futures")] -fn start_server2(cert: Vec, rsa: PrivateKey) -> SocketAddr { - use futures::{ FutureExt, StreamExt }; - use futures::io::{ AsyncReadExt, AsyncWriteExt }; - - let mut config = ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(cert, rsa); - let config = Arc::new(config); - - let (send, recv) = channel(); - - thread::spawn(move || { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); - let listener = TcpListener::bind(&addr).unwrap(); - - send.send(listener.local_addr().unwrap()).unwrap(); - - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept_async(stream) - .and_then(|stream| stream.read_exact(vec![0; HELLO_WORLD.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, HELLO_WORLD); - stream.write_all(HELLO_WORLD) - }) - .map(drop) - .map_err(|err| panic!("{:?}", err)); - - tokio::spawn2(done); - Ok(()) - }) - .map(drop) - .map_err(|err| panic!("{:?}", err)); - - tokio::runtime::run2(done); - }); - - recv.recv().unwrap() -} - fn start_client(addr: &SocketAddr, domain: &str, chain: Option>>) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; @@ -139,7 +98,7 @@ fn start_client2(addr: &SocketAddr, domain: &str, chain: Option