diff --git a/tokio-rustls/Cargo.toml b/tokio-rustls/Cargo.toml index 5c9f205..17a5d9f 100644 --- a/tokio-rustls/Cargo.toml +++ b/tokio-rustls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-rustls" -version = "0.23.2" +version = "0.23.3" authors = ["quininer kel "] license = "MIT/Apache-2.0" repository = "https://github.com/tokio-rs/tls" diff --git a/tokio-rustls/src/common/handshake.rs b/tokio-rustls/src/common/handshake.rs index fcb6dc9..6725277 100644 --- a/tokio-rustls/src/common/handshake.rs +++ b/tokio-rustls/src/common/handshake.rs @@ -62,9 +62,7 @@ where try_poll!(tls_stream.handshake(cx)); } - while tls_stream.session.wants_write() { - try_poll!(tls_stream.write_io(cx)); - } + try_poll!(Pin::new(&mut tls_stream).poll_flush(cx)); } Poll::Ready(Ok(stream)) diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index 6de5b97..fde34c0 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -166,10 +166,14 @@ where loop { let mut write_would_block = false; let mut read_would_block = false; + let mut need_flush = false; while self.session.wants_write() { match self.write_io(cx) { - Poll::Ready(Ok(n)) => wrlen += n, + Poll::Ready(Ok(n)) => { + wrlen += n; + need_flush = true; + } Poll::Pending => { write_would_block = true; break; @@ -178,6 +182,14 @@ where } } + if need_flush { + match Pin::new(&mut self.io).poll_flush(cx) { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => write_would_block = true, + } + } + while !self.eof && self.session.wants_read() { match self.read_io(cx) { Poll::Ready(Ok(0)) => self.eof = true, diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index 3b3966d..710a324 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -202,6 +202,30 @@ async fn stream_handshake() -> io::Result<()> { Ok(()) as io::Result<()> } +#[tokio::test] +async fn stream_buffered_handshake() -> io::Result<()> { + use tokio::io::BufWriter; + + let (server, mut client) = make_pair(); + let mut server = Connection::from(server); + + { + let mut good = BufWriter::new(Good(&mut server)); + let mut stream = Stream::new(&mut good, &mut client); + let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; + + assert!(r > 0); + assert!(w > 0); + + poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake + } + + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); + + Ok(()) as io::Result<()> +} + #[tokio::test] async fn stream_handshake_eof() -> io::Result<()> { let (_, mut client) = make_pair();