From 017b1b64d18d80875909e7bf0f09c7e2529f0989 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 4 May 2019 22:44:40 +0800 Subject: [PATCH] start migrate to futures 0.3 (again) --- Cargo.toml | 5 +- src/common/mod.rs | 173 ++++++++++++++++++++++++++++------------------ src/lib.rs | 14 ++-- 3 files changed, 114 insertions(+), 78 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 98d70b8..5a61dcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,14 +9,15 @@ documentation = "https://docs.rs/tokio-rustls" readme = "README.md" description = "Asynchronous TLS/SSL streams for Tokio using Rustls." categories = ["asynchronous", "cryptography", "network-programming"] +edition = "2018" [badges] travis-ci = { repository = "quininer/tokio-rustls" } appveyor = { repository = "quininer/tokio-rustls" } [dependencies] -futures = "0.1" -tokio-io = "0.1.6" +smallvec = "*" +futures = { package = "futures-preview", version = "0.3.0-alpha.15" } bytes = "0.4" iovec = "0.1" rustls = "0.15" diff --git a/src/common/mod.rs b/src/common/mod.rs index 14d2f71..ed29f09 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,18 +1,23 @@ -mod vecbuf; +// mod vecbuf; +use std::pin::Pin; +use std::task::Poll; +use std::marker::Unpin; use std::io::{ self, Read, Write }; use rustls::Session; use rustls::WriteV; -use tokio_io::{ AsyncRead, AsyncWrite }; +use futures::task::Context; +use futures::io::{ AsyncRead, AsyncWrite, IoVec }; +use smallvec::SmallVec; -pub struct Stream<'a, IO: 'a, S: 'a> { +pub struct Stream<'a, IO, S> { pub io: &'a mut IO, - pub session: &'a mut S + pub session: &'a mut S, } -pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write { - fn write_tls(&mut self) -> io::Result; +pub trait WriteTls { + fn write_tls(&mut self, cx: &mut Context) -> io::Result; } #[derive(Clone, Copy)] @@ -22,36 +27,59 @@ enum Focus { Writable } -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { io, session } } - pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { - self.complete_inner_io(Focus::Empty) + pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { + self.complete_inner_io(cx, Focus::Empty) } - fn complete_read_io(&mut self) -> io::Result { - let n = self.session.read_tls(self.io)?; + fn complete_read_io(&mut self, cx: &mut Context) -> Poll> { + struct Reader<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b> + } + + impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match Pin::new(&mut self.io).poll_read(self.cx, buf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) + } + } + } + + let mut reader = Reader { io: self.io, cx }; + + let n = match self.session.read_tls(&mut reader) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; self.session.process_new_packets() .map_err(|err| { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. - let _ = self.write_tls(); + let _ = self.write_tls(cx); io::Error::new(io::ErrorKind::InvalidData, err) })?; - Ok(n) + Poll::Ready(Ok(n)) } - fn complete_write_io(&mut self) -> io::Result { - self.write_tls() + fn complete_write_io(&mut self, cx: &mut Context) -> Poll> { + match self.write_tls(cx) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } } - fn complete_inner_io(&mut self, focus: Focus) -> io::Result<(usize, usize)> { + fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { let mut wrlen = 0; let mut rdlen = 0; let mut eof = false; @@ -61,22 +89,22 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { let mut read_would_block = false; while self.session.wants_write() { - match self.complete_write_io() { - Ok(n) => wrlen += n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + match self.complete_write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { write_would_block = true; break }, - Err(err) => return Err(err) + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } if !eof && self.session.wants_read() { - match self.complete_read_io() { - Ok(0) => eof = true, - Ok(n) => rdlen += n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => read_would_block = true, - Err(err) => return Err(err) + match self.complete_read_io(cx) { + Poll::Ready(Ok(0)) => eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => read_would_block = true, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } @@ -87,7 +115,7 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { }; match (eof, self.session.is_handshaking(), would_block) { - (true, true, _) => return Err(io::ErrorKind::UnexpectedEof.into()), + (true, true, _) => return Poll::Pending, (_, false, true) => { let would_block = match focus { Focus::Empty => rdlen == 0 && wrlen == 0, @@ -96,83 +124,96 @@ impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { }; return if would_block { - Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending } else { - Ok((rdlen, wrlen)) + Poll::Ready(Ok((rdlen, wrlen))) }; }, - (_, false, _) => return Ok((rdlen, wrlen)), - (_, true, true) => return Err(io::ErrorKind::WouldBlock.into()), + (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), + (_, true, true) => return Poll::Pending, (..) => () } } } } -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<'a, IO, S> { - fn write_tls(&mut self) -> io::Result { - use futures::Async; - use self::vecbuf::VecBuf; +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { + fn write_tls(&mut self, cx: &mut Context) -> io::Result { + struct Writer<'a, 'b, IO> { + io: &'a mut IO, + cx: &'a mut Context<'b> + } - struct V<'a, IO: 'a>(&'a mut IO); - - impl<'a, IO: AsyncWrite> WriteV for V<'a, IO> { + impl<'a, 'b, IO: AsyncWrite + Unpin> WriteV for Writer<'a, 'b, IO> { fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result { - let mut vbytes = VecBuf::new(vbytes); - match self.0.write_buf(&mut vbytes) { - Ok(Async::Ready(n)) => Ok(n), - Ok(Async::NotReady) => Err(io::ErrorKind::WouldBlock.into()), - Err(err) => Err(err) + let vbytes = vbytes + .into_iter() + .try_fold(SmallVec::<[&'_ IoVec; 16]>::new(), |mut sum, next| { + sum.push(IoVec::from_bytes(next)?); + Some(sum) + }) + .unwrap_or_default(); + + match Pin::new(&mut self.io).poll_vectored_write(self.cx, &vbytes) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()) } } } - let mut vecio = V(self.io); + let mut vecio = Writer { io: self.io, cx }; self.session.writev_tls(&mut vecio) } } -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> { - fn read(&mut self, buf: &mut [u8]) -> io::Result { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { + fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { while self.session.wants_read() { - if let (0, _) = self.complete_inner_io(Focus::Readable)? { - break + match self.complete_inner_io(cx, Focus::Readable) { + Poll::Ready(Ok((0, _))) => break, + Poll::Ready(Ok(_)) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - self.session.read(buf) - } -} -impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { - fn write(&mut self, buf: &[u8]) -> io::Result { + // FIXME rustls always ready ? + Poll::Ready(self.session.read(buf)) + } + + fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { let len = self.session.write(buf)?; while self.session.wants_write() { - match self.complete_inner_io(Focus::Writable) { - Ok(_) => (), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock && len != 0 => break, - Err(err) => return Err(err) + match self.complete_inner_io(cx, Focus::Writable) { + Poll::Ready(Ok(_)) => (), + Poll::Pending if len != 0 => break, + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } if len != 0 || buf.is_empty() { - Ok(len) + Poll::Ready(Ok(len)) } else { // not write zero - self.session.write(buf) - .and_then(|len| if len != 0 { - Ok(len) - } else { - Err(io::ErrorKind::WouldBlock.into()) - }) + match self.session.write(buf) { + Ok(0) => Poll::Pending, + Ok(n) => Poll::Ready(Ok(n)), + Err(err) => Poll::Ready(Err(err)) + } } } - fn flush(&mut self) -> io::Result<()> { + fn poll_flush(&mut self, cx: &mut Context) -> Poll> { self.session.flush()?; while self.session.wants_write() { - self.complete_inner_io(Focus::Writable)?; + match self.complete_inner_io(cx, Focus::Writable) { + Poll::Ready(Ok(_)) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } } - Ok(()) + Poll::Ready(Ok(())) } } diff --git a/src/lib.rs b/src/lib.rs index 04d7421..9e15ed7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,10 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -pub extern crate rustls; -pub extern crate webpki; - -extern crate bytes; -extern crate futures; -extern crate iovec; -extern crate tokio_io; - -pub mod client; +// pub mod client; mod common; -pub mod server; +// pub mod server; +/* use common::Stream; use futures::{Async, Future, Poll}; use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; @@ -194,3 +187,4 @@ impl Future for Accept { #[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt; +*/