diff --git a/tokio-rustls/src/client.rs b/tokio-rustls/src/client.rs index 1244019..f8d8d07 100644 --- a/tokio-rustls/src/client.rs +++ b/tokio-rustls/src/client.rs @@ -119,9 +119,9 @@ where Poll::Ready(Ok(())) } - Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { + Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); - Poll::Ready(Ok(())) + Poll::Ready(Err(err)) } output => output, } diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index 478c0ff..a90c3fb 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -62,7 +62,6 @@ pub struct Stream<'a, IO, C> { pub io: &'a mut IO, pub session: &'a mut C, pub eof: bool, - pub unexpected_eof: bool, } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C> @@ -77,7 +76,6 @@ where // The state so far is only used to detect EOF, so either Stream // or EarlyData state should both be all right. eof: false, - unexpected_eof: false, } } @@ -238,70 +236,53 @@ where cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let prev = buf.remaining(); + let mut io_pending = false; - while buf.remaining() != 0 { - let mut io_pending = false; - - // read a packet - while !self.eof && self.session.wants_read() { - match self.read_io(cx) { - Poll::Ready(Ok(0)) => { - self.eof = true; - break; - } - Poll::Ready(Ok(_)) => (), - Poll::Pending => { - io_pending = true; - break; - } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + // read a packet + while !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => { + break; } + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + io_pending = true; + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), } - - return match self.session.reader().read(buf.initialize_unfilled()) { - // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the - // connection with a `CloseNotify` message and no more data will be forthcoming. - Ok(0) => break, - - // Rustls yielded more data: advance the buffer, then see if more data is coming. - Ok(n) => { - buf.advance(n); - - if self.eof || io_pending { - break; - } else { - continue; - } - } - - // Rustls doesn't have more data to yield, but it believes the connection is open. - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - if prev == buf.remaining() && io_pending { - Poll::Pending - } else if self.eof || io_pending { - break; - } else { - continue; - } - } - - Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => { - self.eof = true; - self.unexpected_eof = true; - if prev == buf.remaining() { - Poll::Ready(Err(err)) - } else { - break; - } - } - - // This should be unreachable. - Err(err) => Poll::Ready(Err(err)), - }; } - Poll::Ready(Ok(())) + match self.session.reader().read(buf.initialize_unfilled()) { + // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the + // connection with a `CloseNotify` message and no more data will be forthcoming. + // + // Rustls yielded more data: advance the buffer, then see if more data is coming. + // + // We don't need to modify `self.eof` here, because it is only a temporary mark. + // rustls will only return 0 if is has received `CloseNotify`, + // in which case no additional processing is required. + Ok(n) => { + buf.advance(n); + Poll::Ready(Ok(())) + } + + // Rustls doesn't have more data to yield, but it believes the connection is open. + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + if !io_pending { + // If `wants_read()` is satisfied, rustls will not return `WouldBlock`. + // but if it does, we can try again. + // + // If the rustls state is abnormal, it may cause a cyclic wakeup. + // but tokio's cooperative budget will prevent infinite wakeup. + cx.waker().wake_by_ref(); + } + + Poll::Pending + } + + Err(err) => Poll::Ready(Err(err)), + } } } diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index 8623d14..89a5686 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -127,8 +127,10 @@ async fn stream_good() -> io::Result<()> { let (server, mut client) = make_pair(); let mut server = Connection::from(server); poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + io::copy(&mut Cursor::new(FILE), &mut server.writer())?; server.send_close_notify(); + let mut server = Connection::from(server); { @@ -138,8 +140,10 @@ async fn stream_good() -> io::Result<()> { let mut buf = Vec::new(); dbg!(stream.read_to_end(&mut buf).await)?; assert_eq!(buf, FILE); + dbg!(stream.write_all(b"Hello World!").await)?; stream.session.send_close_notify(); + dbg!(stream.shutdown().await)?; } @@ -241,8 +245,8 @@ async fn stream_eof() -> io::Result<()> { let mut server = Connection::from(server); poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; - let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client).set_eof(true); + let mut bad = Expected(Cursor::new(Vec::new())); + let mut stream = Stream::new(&mut bad, &mut client); let mut buf = Vec::new(); let result = stream.read_to_end(&mut buf).await; diff --git a/tokio-rustls/src/server.rs b/tokio-rustls/src/server.rs index 4b8ec49..f39f80f 100644 --- a/tokio-rustls/src/server.rs +++ b/tokio-rustls/src/server.rs @@ -77,12 +77,11 @@ where Poll::Ready(Ok(())) } - Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => { this.state.shutdown_read(); - Poll::Ready(Ok(())) + Poll::Ready(Err(err)) } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, + output => output, } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),