From bcf4f8e3f96983dbb7a61808b0f1fcd04fb678ae Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 19 Mar 2022 13:09:28 +0800 Subject: [PATCH] Rustls buffered handshake eof failed (#98) * rustls/tests: use BufWriter in handshake * tokio-rustls: move test to stream_buffered_handshake * Fix tokio-rustls bufwriter handshake fail #96 * Use need_flush * More flush * tokio-rustls: release 0.23.3 * Fix fmt Co-authored-by: tharvik --- tokio-rustls/Cargo.toml | 2 +- tokio-rustls/src/common/handshake.rs | 4 +--- tokio-rustls/src/common/mod.rs | 14 +++++++++++++- tokio-rustls/src/common/test_stream.rs | 24 ++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/tokio-rustls/Cargo.toml b/tokio-rustls/Cargo.toml index 5c9f205..17a5d9f 100644 --- a/tokio-rustls/Cargo.toml +++ b/tokio-rustls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.23.2" +version = "0.23.3" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/tokio-rs/tls" diff --git a/tokio-rustls/src/common/handshake.rs b/tokio-rustls/src/common/handshake.rs index fcb6dc9..6725277 100644 --- a/tokio-rustls/src/common/handshake.rs +++ b/tokio-rustls/src/common/handshake.rs @@ -62,9 +62,7 @@ where try_poll!(tls_stream.handshake(cx)); } - while tls_stream.session.wants_write() { - try_poll!(tls_stream.write_io(cx)); - } + try_poll!(Pin::new(&mut tls_stream).poll_flush(cx)); } Poll::Ready(Ok(stream)) diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index 6de5b97..fde34c0 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -166,10 +166,14 @@ where loop { let mut write_would_block = false; let mut read_would_block = false; + let mut need_flush = false; while self.session.wants_write() { match self.write_io(cx) { - Poll::Ready(Ok(n)) => wrlen += n, + Poll::Ready(Ok(n)) => { + wrlen += n; + need_flush = true; + } Poll::Pending => { write_would_block = true; break; @@ -178,6 +182,14 @@ where } } + if need_flush { + match Pin::new(&mut self.io).poll_flush(cx) { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => write_would_block = true, + } + } + while !self.eof && self.session.wants_read() { match self.read_io(cx) { Poll::Ready(Ok(0)) => self.eof = true, diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index 3b3966d..710a324 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -202,6 +202,30 @@ async fn stream_handshake() -> io::Result<()> { Ok(()) as io::Result<()> } +#[tokio::test] +async fn stream_buffered_handshake() -> io::Result<()> { + use tokio::io::BufWriter; + + let (server, mut client) = make_pair(); + let mut server = Connection::from(server); + + { + let mut good = BufWriter::new(Good(&mut server)); + let mut stream = Stream::new(&mut good, &mut client); + let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; + + assert!(r > 0); + assert!(w > 0); + + poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake + } + + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); + + Ok(()) as io::Result<()> +} + #[tokio::test] async fn stream_handshake_eof() -> io::Result<()> { let (_, mut client) = make_pair();