diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index 06dc39b..478c0ff 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -116,7 +116,7 @@ where Err(err) => return Poll::Ready(Err(err)), }; - self.session.process_new_packets().map_err(|err| { + let stats = self.session.process_new_packets().map_err(|err| { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. @@ -125,6 +125,13 @@ where io::Error::new(io::ErrorKind::InvalidData, err) })?; + if stats.peer_has_closed() && self.session.is_handshaking() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "tls handshake alert", + ))); + } + Poll::Ready(Ok(n)) } diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index 9f1359c..8623d14 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -86,19 +86,23 @@ impl AsyncWrite for Pending { } } -struct Eof; +struct Expected(Cursor>); -impl AsyncRead for Eof { +impl AsyncRead for Expected { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _: &mut ReadBuf<'_>, + buf: &mut ReadBuf<'_>, ) -> Poll> { + let this = self.get_mut(); + let n = std::io::Read::read(&mut this.0, buf.initialize_unfilled())?; + buf.advance(n); + Poll::Ready(Ok(())) } } -impl AsyncWrite for Eof { +impl AsyncWrite for Expected { fn poll_write( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -200,7 +204,25 @@ async fn stream_handshake() -> io::Result<()> { async fn stream_handshake_eof() -> io::Result<()> { let (_, mut client) = make_pair(); - let mut bad = Eof; + let mut bad = Expected(Cursor::new(Vec::new())); + let mut stream = Stream::new(&mut bad, &mut client); + + let mut cx = Context::from_waker(noop_waker_ref()); + let r = stream.handshake(&mut cx); + assert_eq!( + r.map_err(|err| err.kind()), + Poll::Ready(Err(io::ErrorKind::UnexpectedEof)) + ); + + Ok(()) as io::Result<()> +} + +// see https://github.com/tokio-rs/tls/issues/77 +#[tokio::test] +async fn stream_handshake_regression_issues_77() -> io::Result<()> { + let (_, mut client) = make_pair(); + + let mut bad = Expected(Cursor::new(b"\x15\x03\x01\x00\x02\x02\x00".to_vec())); let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref());