From fe8a0f415217bac092ef15de40b3714b45472530 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 16 Dec 2021 08:32:46 -0500 Subject: [PATCH] fix: Fix EOF spin loop by removing intermediate buffer in LazyConfigAcceptor (#87) * chore: Remove intermediate buffer in LazyConfigAcceptor * chore: Document WouldBlock behavior * chore: satisfy clippy * chore: Rename Reader -> SyncReadAdapter * chore: add test for EOF --- tokio-rustls/src/common/mod.rs | 42 +++++++++++++++++++--------------- tokio-rustls/src/lib.rs | 29 +++++------------------ tokio-rustls/tests/test.rs | 23 +++++++++++++++++-- 3 files changed, 50 insertions(+), 44 deletions(-) diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index a90c3fb..6de5b97 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -89,24 +89,7 @@ where } pub fn read_io(&mut self, cx: &mut Context) -> Poll> { - struct Reader<'a, 'b, T> { - io: &'a mut T, - cx: &'a mut Context<'b>, - } - - impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut buf = ReadBuf::new(buf); - match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) { - Poll::Ready(Ok(())) => Ok(buf.filled().len()), - Poll::Ready(Err(err)) => Err(err), - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), - } - } - } - - let mut reader = Reader { io: self.io, cx }; + let mut reader = SyncReadAdapter { io: self.io, cx }; let n = match self.session.read_tls(&mut reader) { Ok(n) => n, @@ -145,7 +128,7 @@ where &mut self, f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, ) -> io::Result { - match f(Pin::new(&mut self.io), self.cx) { + match f(Pin::new(self.io), self.cx) { Poll::Ready(result) => result, Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), } @@ -343,5 +326,26 @@ where } } +/// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an +/// associated [`Context`]. +/// +/// Turns `Poll::Pending` into `WouldBlock`. +pub struct SyncReadAdapter<'a, 'b, T> { + pub io: &'a mut T, + pub cx: &'a mut Context<'b>, +} + +impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut buf = ReadBuf::new(buf); + match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) { + Poll::Ready(Ok(())) => Ok(buf.filled().len()), + Poll::Ready(Err(err)) => Err(err), + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } +} + #[cfg(test)] mod test_stream; diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index 9c2a8d4..242b090 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -190,8 +190,6 @@ impl TlsAcceptor { pub struct LazyConfigAcceptor { acceptor: rustls::server::Acceptor, - buf: Vec, - used: usize, io: Option, } @@ -203,8 +201,6 @@ where pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { Self { acceptor, - buf: vec![0; 512], - used: 0, io: Some(io), } } @@ -229,25 +225,12 @@ where } }; - let mut buf = ReadBuf::new(&mut this.buf); - buf.advance(this.used); - if buf.remaining() > 0 { - if let Err(err) = ready!(Pin::new(io).poll_read(cx, &mut buf)) { - return Poll::Ready(Err(err)); - } - } - - let read = match this.acceptor.read_tls(&mut buf.filled()) { - Ok(read) => read, - Err(err) => return Poll::Ready(Err(err)), - }; - - let received = buf.filled().len(); - if read < received { - this.buf.copy_within(read.., 0); - this.used = received - read; - } else { - this.used = 0; + let mut reader = common::SyncReadAdapter { io, cx }; + match this.acceptor.read_tls(&mut reader) { + Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(), + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(e) => return Err(e).into(), } match this.acceptor.accept() { diff --git a/tokio-rustls/tests/test.rs b/tokio-rustls/tests/test.rs index 29fa603..78cebf5 100644 --- a/tokio-rustls/tests/test.rs +++ b/tokio-rustls/tests/test.rs @@ -3,14 +3,15 @@ use lazy_static::lazy_static; use rustls::{ClientConfig, OwnedTrustAnchor}; use rustls_pemfile::{certs, rsa_private_keys}; use std::convert::TryFrom; -use std::io::{BufReader, Cursor}; +use std::io::{BufReader, Cursor, ErrorKind}; use std::net::SocketAddr; use std::sync::mpsc::channel; use std::sync::Arc; +use std::time::Duration; use std::{io, thread}; use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::runtime; +use tokio::{runtime, time}; use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector}; const CERT: &str = include_str!("end.cert"); @@ -202,5 +203,23 @@ async fn test_lazy_config_acceptor() -> io::Result<()> { Ok(()) } +// This test is a follow-up from https://github.com/tokio-rs/tls/issues/85 +#[tokio::test] +async fn lazy_config_acceptor_eof() { + let buf = Cursor::new(Vec::new()); + let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::new().unwrap(), buf); + + let accept_result = match time::timeout(Duration::from_secs(3), acceptor).await { + Ok(res) => res, + Err(_elapsed) => panic!("timeout"), + }; + + match accept_result { + Ok(_) => panic!("accepted a connection from zero bytes of data"), + Err(e) if e.kind() == ErrorKind::UnexpectedEof => {} + Err(e) => panic!("unexpected error: {:?}", e), + } +} + // Include `utils` module include!("utils.rs");