From e6ef54641b911cfcf23b77a8c4826ae0f8e9870e Mon Sep 17 00:00:00 2001 From: quininer Date: Fri, 16 Oct 2020 18:26:32 +0800 Subject: [PATCH] Tokio 0.3 (#29) * Remove futures-core * Upgrade Tokio 0.3 * clean code * Fix ci * Fix lint --- tokio-native-tls/src/lib.rs | 2 +- tokio-rustls/Cargo.toml | 10 +- tokio-rustls/examples/client/Cargo.toml | 2 +- tokio-rustls/examples/server/Cargo.toml | 2 +- tokio-rustls/examples/server/src/main.rs | 2 +- tokio-rustls/src/client.rs | 27 ++--- tokio-rustls/src/common/handshake.rs | 16 --- tokio-rustls/src/common/mod.rs | 82 +++++-------- tokio-rustls/src/common/test_stream.rs | 27 +++-- tokio-rustls/src/common/vecbuf.rs | 140 ----------------------- tokio-rustls/src/lib.rs | 44 ++----- tokio-rustls/src/server.rs | 28 ++--- tokio-rustls/tests/early-data.rs | 8 +- tokio-rustls/tests/test.rs | 13 +-- 14 files changed, 100 insertions(+), 303 deletions(-) delete mode 100644 tokio-rustls/src/common/vecbuf.rs diff --git a/tokio-native-tls/src/lib.rs b/tokio-native-tls/src/lib.rs index a81e915..6b64280 100644 --- a/tokio-native-tls/src/lib.rs +++ b/tokio-native-tls/src/lib.rs @@ -5,7 +5,7 @@ rust_2018_idioms, unreachable_pub )] -#![deny(intra_doc_link_resolution_failure)] +#![deny(broken_intra_doc_links)] #![doc(test( no_crate_inject, attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) diff --git a/tokio-rustls/Cargo.toml b/tokio-rustls/Cargo.toml index 94e55e0..7d11f9e 100644 --- a/tokio-rustls/Cargo.toml +++ b/tokio-rustls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.14.1" +version = "0.20.0" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/tokio-rs/tls" @@ -12,20 +12,16 @@ categories = ["asynchronous", "cryptography", "network-programming"] edition = "2018" [dependencies] -tokio = "0.2.0" -futures-core = "0.3.1" +tokio = "0.3" rustls = "0.18" webpki = "0.21" -bytes = { version = "0.5", optional = true } - [features] early-data = [] dangerous_configuration = ["rustls/dangerous_configuration"] -unstable = ["bytes"] [dev-dependencies] -tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] } +tokio = { version = "0.3", features = ["full"] } futures-util = "0.3.1" lazy_static = "1" webpki-roots = "0.20" diff --git a/tokio-rustls/examples/client/Cargo.toml b/tokio-rustls/examples/client/Cargo.toml index b3a48b1..3d59914 100644 --- a/tokio-rustls/examples/client/Cargo.toml +++ b/tokio-rustls/examples/client/Cargo.toml @@ -5,7 +5,7 @@ authors = ["quininer "] edition = "2018" [dependencies] -tokio = { version = "0.2", features = [ "full" ] } +tokio = { version = "0.3", features = [ "full" ] } argh = "0.1" tokio-rustls = { path = "../.." } webpki-roots = "0.20" diff --git a/tokio-rustls/examples/server/Cargo.toml b/tokio-rustls/examples/server/Cargo.toml index f425e31..20b9b56 100644 --- a/tokio-rustls/examples/server/Cargo.toml +++ b/tokio-rustls/examples/server/Cargo.toml @@ -5,6 +5,6 @@ authors = ["quininer "] edition = "2018" [dependencies] -tokio = { version = "0.2", features = [ "full" ] } +tokio = { version = "0.3", features = [ "full" ] } argh = "0.1" tokio-rustls = { path = "../.." } diff --git a/tokio-rustls/examples/server/src/main.rs b/tokio-rustls/examples/server/src/main.rs index af340e9..65fcf39 100644 --- a/tokio-rustls/examples/server/src/main.rs +++ b/tokio-rustls/examples/server/src/main.rs @@ -59,7 +59,7 @@ async fn main() -> io::Result<()> { .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; let acceptor = TlsAcceptor::from(Arc::new(config)); - let mut listener = TcpListener::bind(&addr).await?; + let listener = TcpListener::bind(&addr).await?; loop { let (stream, peer_addr) = listener.accept().await?; diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs index 444659e..9bd20ad 100644 --- a/tokio-rustls/src/client.rs +++ b/tokio-rustls/src/client.rs @@ -52,16 +52,11 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - #[cfg(feature = "unstable")] - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData(..) => Poll::Pending, @@ -69,21 +64,24 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + let prev = buf.remaining(); match stream.as_mut_pin().poll_read(cx, buf) { - Poll::Ready(Ok(0)) => { - this.state.shutdown_read(); - Poll::Ready(Ok(0)) + Poll::Ready(Ok(())) => { + if prev == buf.remaining() { + this.state.shutdown_read(); + } + + Poll::Ready(Ok(())) } - Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); - Poll::Ready(Ok(0)) + Poll::Ready(Ok(())) } output => output, } } - TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), } } } @@ -107,7 +105,6 @@ where match this.state { #[cfg(feature = "early-data")] TlsState::EarlyData(ref mut pos, ref mut data) => { - use futures_core::ready; use std::io::Write; // write early data @@ -153,8 +150,6 @@ where #[cfg(feature = "early-data")] { - use futures_core::ready; - if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { // complete handshake while stream.session.is_handshaking() { diff --git a/tokio-rustls/src/common/handshake.rs b/tokio-rustls/src/common/handshake.rs index a00a3e1..39139fa 100644 --- a/tokio-rustls/src/common/handshake.rs +++ b/tokio-rustls/src/common/handshake.rs @@ -1,5 +1,4 @@ use crate::common::{Stream, TlsState}; -use futures_core::future::FusedFuture; use rustls::Session; use std::future::Future; use std::pin::Pin; @@ -21,21 +20,6 @@ pub(crate) enum MidHandshake { End, } -impl FusedFuture for MidHandshake -where - IS: IoSession + Unpin, - IS::Io: AsyncRead + AsyncWrite + Unpin, - IS::Session: Session + Unpin, -{ - fn is_terminated(&self) -> bool { - if let MidHandshake::End = self { - true - } else { - false - } - } -} - impl Future for MidHandshake where IS: IoSession + Unpin, diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index a71ab7a..2a2d3e1 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -1,15 +1,11 @@ mod handshake; -#[cfg(feature = "unstable")] -mod vecbuf; - -use futures_core as futures; pub(crate) use handshake::{IoSession, MidHandshake}; use rustls::Session; use std::io::{self, Read, Write}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[derive(Debug)] pub enum TlsState { @@ -40,27 +36,18 @@ impl TlsState { #[inline] pub fn writeable(&self) -> bool { - match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => false, - _ => true, - } + !matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown) } #[inline] pub fn readable(&self) -> bool { - match self { - TlsState::ReadShutdown | TlsState::FullyShutdown => false, - _ => true, - } + !matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown) } #[inline] #[cfg(feature = "early-data")] pub fn is_early_data(&self) -> bool { - match self { - TlsState::EarlyData(..) => true, - _ => false, - } + matches!(self, TlsState::EarlyData(..)) } #[inline] @@ -105,8 +92,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { #[inline] fn read(&mut self, buf: &mut [u8]) -> io::Result { - match Pin::new(&mut self.io).poll_read(self.cx, buf) { - Poll::Ready(result) => result, + let mut buf = ReadBuf::new(buf); + match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) { + Poll::Ready(Ok(())) => Ok(buf.filled().len()), + Poll::Ready(Err(err)) => Err(err), Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), } } @@ -133,9 +122,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } pub fn write_io(&mut self, cx: &mut Context) -> Poll> { - #[cfg(feature = "unstable")] - use std::io::IoSlice; - struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b>, @@ -150,19 +136,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - #[cfg(feature = "unstable")] - #[inline] - fn write_vectored(&mut self, bufs: &[IoSlice]) -> io::Result { - use vecbuf::VecBuf; - - let mut vbuf = VecBuf::new(bufs); - - match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) { - Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), - } - } - fn flush(&mut self) -> io::Result<()> { match Pin::new(&mut self.io).poll_flush(self.cx) { Poll::Ready(result) => result, @@ -232,12 +205,12 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { fn poll_read( mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { - let mut pos = 0; + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let prev = buf.remaining(); - while pos != buf.len() { + while buf.remaining() != 0 { let mut would_block = false; // read a packet @@ -256,22 +229,28 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } } - return match self.session.read(&mut buf[pos..]) { - Ok(0) if pos == 0 && would_block => Poll::Pending, - Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)), + return match self.session.read(buf.initialize_unfilled()) { + Ok(0) if prev == buf.remaining() && would_block => Poll::Pending, Ok(n) => { - pos += n; - continue; + buf.advance(n); + + if self.eof || would_block { + break; + } else { + continue; + } } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => { - Poll::Ready(Ok(pos)) + Err(ref err) + if err.kind() == io::ErrorKind::ConnectionAborted + && prev != buf.remaining() => + { + break } Err(err) => Poll::Ready(Err(err)), }; } - Poll::Ready(Ok(pos)) + Poll::Ready(Ok(())) } } @@ -288,7 +267,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' match self.session.write(&buf[pos..]) { Ok(n) => pos += n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), Err(err) => return Poll::Ready(Err(err)), }; @@ -316,14 +294,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.session.flush()?; while self.session.wants_write() { - futures::ready!(self.write_io(cx))?; + ready!(self.write_io(cx))?; } Pin::new(&mut self.io).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { while self.session.wants_write() { - futures::ready!(self.write_io(cx))?; + ready!(self.write_io(cx))?; } Pin::new(&mut self.io).poll_shutdown(cx) } diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index b333239..9faf762 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -1,5 +1,4 @@ use super::Stream; -use futures_core::ready; use futures_util::future::poll_fn; use futures_util::task::noop_waker_ref; use rustls::internal::pemfile::{certs, rsa_private_keys}; @@ -8,7 +7,7 @@ use std::io::{self, BufReader, Cursor, Read, Write}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use webpki::DNSNameRef; struct Good<'a>(&'a mut dyn Session); @@ -17,9 +16,17 @@ impl<'a> AsyncRead for Good<'a> { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - mut buf: &mut [u8], - ) -> Poll> { - Poll::Ready(self.0.write_tls(buf.by_ref())) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut buf2 = buf.initialize_unfilled(); + + Poll::Ready(match self.0.write_tls(buf2.by_ref()) { + Ok(n) => { + buf.advance(n); + Ok(()) + } + Err(err) => Err(err), + }) } } @@ -55,8 +62,8 @@ impl AsyncRead for Pending { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _: &mut [u8], - ) -> Poll> { + _: &mut ReadBuf<'_>, + ) -> Poll> { Poll::Pending } } @@ -85,9 +92,9 @@ impl AsyncRead for Eof { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _: &mut [u8], - ) -> Poll> { - Poll::Ready(Ok(0)) + _: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) } } diff --git a/tokio-rustls/src/common/vecbuf.rs b/tokio-rustls/src/common/vecbuf.rs deleted file mode 100644 index c8b809a..0000000 --- a/tokio-rustls/src/common/vecbuf.rs +++ /dev/null @@ -1,140 +0,0 @@ -use bytes::Buf; -use std::cmp::{self, Ordering}; -use std::io::IoSlice; - -pub struct VecBuf<'a> { - pos: usize, - cur: usize, - inner: &'a [IoSlice<'a>], -} - -impl<'a> VecBuf<'a> { - pub fn new(vbytes: &'a [IoSlice<'a>]) -> Self { - VecBuf { - pos: 0, - cur: 0, - inner: vbytes, - } - } -} - -impl<'a> Buf for VecBuf<'a> { - fn remaining(&self) -> usize { - let sum = self - .inner - .iter() - .skip(self.pos) - .map(|bytes| bytes.len()) - .sum::(); - sum - self.cur - } - - fn bytes(&self) -> &[u8] { - &self.inner[self.pos][self.cur..] - } - - fn advance(&mut self, cnt: usize) { - let current = self.inner[self.pos].len(); - match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => { - if self.pos + 1 < self.inner.len() { - self.pos += 1; - self.cur = 0; - } else { - self.cur += cnt; - } - } - Ordering::Greater => { - if self.pos + 1 < self.inner.len() { - self.pos += 1; - } - let remaining = self.cur + cnt - current; - self.advance(remaining); - } - Ordering::Less => self.cur += cnt, - } - } - - #[allow(clippy::needless_range_loop)] - #[inline] - fn bytes_vectored<'c>(&'c self, dst: &mut [IoSlice<'c>]) -> usize { - let len = cmp::min(self.inner.len() - self.pos, dst.len()); - - if len > 0 { - dst[0] = IoSlice::new(self.bytes()); - } - - for i in 1..len { - dst[i] = self.inner[self.pos + i]; - } - - len - } -} - -#[cfg(test)] -mod test_vecbuf { - use super::*; - - #[test] - fn test_fresh_cursor_vec() { - let buf = [IoSlice::new(b"he"), IoSlice::new(b"llo")]; - let mut buf = VecBuf::new(&buf); - - assert_eq!(buf.remaining(), 5); - assert_eq!(buf.bytes(), b"he"); - - buf.advance(1); - - assert_eq!(buf.remaining(), 4); - assert_eq!(buf.bytes(), b"e"); - - buf.advance(1); - - assert_eq!(buf.remaining(), 3); - assert_eq!(buf.bytes(), b"llo"); - - buf.advance(3); - - assert_eq!(buf.remaining(), 0); - assert_eq!(buf.bytes(), b""); - } - - #[test] - fn test_get_u8() { - let buf = [IoSlice::new(b"\x21z"), IoSlice::new(b"omg")]; - let mut buf = VecBuf::new(&buf); - assert_eq!(0x21, buf.get_u8()); - } - - #[test] - fn test_get_u16() { - let buf = [IoSlice::new(b"\x21\x54z"), IoSlice::new(b"omg")]; - let mut buf = VecBuf::new(&buf); - assert_eq!(0x2154, buf.get_u16()); - let buf = [IoSlice::new(b"\x21\x54z"), IoSlice::new(b"omg")]; - let mut buf = VecBuf::new(&buf); - assert_eq!(0x5421, buf.get_u16_le()); - } - - #[test] - #[should_panic] - fn test_get_u16_buffer_underflow() { - let buf = [IoSlice::new(b"\x21")]; - let mut buf = VecBuf::new(&buf); - buf.get_u16(); - } - - #[test] - fn test_bufs_vec() { - let buf = [IoSlice::new(b"he"), IoSlice::new(b"llo")]; - let buf = VecBuf::new(&buf); - - let b1: &[u8] = &mut [0]; - let b2: &[u8] = &mut [0]; - - let mut dst: [IoSlice; 2] = [IoSlice::new(b1), IoSlice::new(b2)]; - - assert_eq!(2, buf.bytes_vectored(&mut dst[..])); - } -} diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index 161a31a..29af923 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -1,18 +1,26 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). +macro_rules! ready { + ( $e:expr ) => { + match $e { + std::task::Poll::Ready(t) => t, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} + pub mod client; mod common; pub mod server; use common::{MidHandshake, Stream, TlsState}; -use futures_core::future::FusedFuture; use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; use std::future::Future; use std::io; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use webpki::DNSNameRef; pub use rustls; @@ -155,13 +163,6 @@ impl Future for Connect { } } -impl FusedFuture for Connect { - #[inline] - fn is_terminated(&self) -> bool { - self.0.is_terminated() - } -} - impl Future for Accept { type Output = io::Result>; @@ -171,13 +172,6 @@ impl Future for Accept { } } -impl FusedFuture for Accept { - #[inline] - fn is_terminated(&self) -> bool { - self.0.is_terminated() - } -} - impl Future for FailableConnect { type Output = Result, (io::Error, IO)>; @@ -187,13 +181,6 @@ impl Future for FailableConnect { } } -impl FusedFuture for FailableConnect { - #[inline] - fn is_terminated(&self) -> bool { - self.0.is_terminated() - } -} - impl Future for FailableAccept { type Output = Result, (io::Error, IO)>; @@ -203,13 +190,6 @@ impl Future for FailableAccept { } } -impl FusedFuture for FailableAccept { - #[inline] - fn is_terminated(&self) -> bool { - self.0.is_terminated() - } -} - /// Unified TLS stream type /// /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use @@ -269,8 +249,8 @@ where fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { match self.get_mut() { TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf), TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf), diff --git a/tokio-rustls/src/server.rs b/tokio-rustls/src/server.rs index d9cc62c..7ea7ce9 100644 --- a/tokio-rustls/src/server.rs +++ b/tokio-rustls/src/server.rs @@ -52,40 +52,36 @@ impl AsyncRead for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - #[cfg(feature = "unstable")] - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - // TODO - // - // https://doc.rust-lang.org/nightly/std/io/trait.Read.html#method.initializer - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); match &this.state { TlsState::Stream | TlsState::WriteShutdown => { + let prev = buf.remaining(); + match stream.as_mut_pin().poll_read(cx, buf) { - Poll::Ready(Ok(0)) => { - this.state.shutdown_read(); - Poll::Ready(Ok(0)) + Poll::Ready(Ok(())) => { + if prev == buf.remaining() { + this.state.shutdown_read(); + } + + Poll::Ready(Ok(())) } - Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); - Poll::Ready(Ok(0)) + Poll::Ready(Ok(())) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, } } - TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), #[cfg(feature = "early-data")] s => unreachable!("server TLS can not hit this state: {:?}", s), } diff --git a/tokio-rustls/tests/early-data.rs b/tokio-rustls/tests/early-data.rs index 4ba9338..86915bd 100644 --- a/tokio-rustls/tests/early-data.rs +++ b/tokio-rustls/tests/early-data.rs @@ -9,9 +9,10 @@ use std::process::{Child, Command, Stdio}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; +use tokio::io::ReadBuf; use tokio::net::TcpStream; use tokio::prelude::*; -use tokio::time::delay_for; +use tokio::time::sleep; use tokio_rustls::{client::TlsStream, TlsConnector}; struct Read1(T); @@ -21,6 +22,7 @@ impl Future for Read1 { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut buf = [0]; + let mut buf = &mut ReadBuf::new(&mut buf); ready!(Pin::new(&mut self.0).poll_read(cx, &mut buf))?; Poll::Pending } @@ -42,7 +44,7 @@ async fn send( // sleep 1s // // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html - let sleep1 = delay_for(Duration::from_secs(1)); + let sleep1 = sleep(Duration::from_secs(1)); let mut stream = match future::select(Read1(stream), sleep1).await { future::Either::Right((_, Read1(stream))) => stream, future::Either::Left((Err(err), _)) => return Err(err), @@ -77,7 +79,7 @@ async fn test_0rtt() -> io::Result<()> { .map(DropKill)?; // wait openssl server - delay_for(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; let mut config = ClientConfig::new(); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); diff --git a/tokio-rustls/tests/test.rs b/tokio-rustls/tests/test.rs index d0b449d..c92040b 100644 --- a/tokio-rustls/tests/test.rs +++ b/tokio-rustls/tests/test.rs @@ -31,17 +31,16 @@ lazy_static! { let (send, recv) = channel(); thread::spawn(move || { - let mut runtime = runtime::Builder::new() - .basic_scheduler() + let runtime = runtime::Builder::new_current_thread() .enable_io() .build() .unwrap(); - - let handle = runtime.handle().clone(); + let runtime = Arc::new(runtime); + let runtime2 = runtime.clone(); let done = async move { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let mut listener = TcpListener::bind(&addr).await?; + let listener = TcpListener::bind(&addr).await?; send.send(listener.local_addr()?).unwrap(); @@ -59,7 +58,7 @@ lazy_static! { } .unwrap_or_else(|err| eprintln!("server: {:?}", err)); - handle.spawn(fut); + runtime2.spawn(fut); } } .unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err)); @@ -102,7 +101,7 @@ async fn pass() -> io::Result<()> { // TcpStream::bind now returns a future it creates a race // condition until its ready sometimes. use std::time::*; - tokio::time::delay_for(Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_secs(1)).await; let mut config = ClientConfig::new(); let mut chain = BufReader::new(Cursor::new(chain));