don't throw eof error to keep consistency (#79)

This commit is contained in:
quininer 2021-10-12 16:05:51 +08:00 committed by GitHub
parent 5aae337945
commit 56855b7166
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 69 deletions

View File

@ -119,9 +119,9 @@ where
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read(); this.state.shutdown_read();
Poll::Ready(Ok(())) Poll::Ready(Err(err))
} }
output => output, output => output,
} }

View File

@ -62,7 +62,6 @@ pub struct Stream<'a, IO, C> {
pub io: &'a mut IO, pub io: &'a mut IO,
pub session: &'a mut C, pub session: &'a mut C,
pub eof: bool, pub eof: bool,
pub unexpected_eof: bool,
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C> impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
@ -77,7 +76,6 @@ where
// The state so far is only used to detect EOF, so either Stream // The state so far is only used to detect EOF, so either Stream
// or EarlyData state should both be all right. // or EarlyData state should both be all right.
eof: false, eof: false,
unexpected_eof: false,
} }
} }
@ -238,70 +236,53 @@ where
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
let prev = buf.remaining(); let mut io_pending = false;
while buf.remaining() != 0 { // read a packet
let mut io_pending = false; while !self.eof && self.session.wants_read() {
match self.read_io(cx) {
// read a packet Poll::Ready(Ok(0)) => {
while !self.eof && self.session.wants_read() { break;
match self.read_io(cx) {
Poll::Ready(Ok(0)) => {
self.eof = true;
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Pending => {
io_pending = true;
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
Poll::Ready(Ok(_)) => (),
Poll::Pending => {
io_pending = true;
break;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
return match self.session.reader().read(buf.initialize_unfilled()) {
// If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
// connection with a `CloseNotify` message and no more data will be forthcoming.
Ok(0) => break,
// Rustls yielded more data: advance the buffer, then see if more data is coming.
Ok(n) => {
buf.advance(n);
if self.eof || io_pending {
break;
} else {
continue;
}
}
// Rustls doesn't have more data to yield, but it believes the connection is open.
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
if prev == buf.remaining() && io_pending {
Poll::Pending
} else if self.eof || io_pending {
break;
} else {
continue;
}
}
Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
self.eof = true;
self.unexpected_eof = true;
if prev == buf.remaining() {
Poll::Ready(Err(err))
} else {
break;
}
}
// This should be unreachable.
Err(err) => Poll::Ready(Err(err)),
};
} }
Poll::Ready(Ok(())) match self.session.reader().read(buf.initialize_unfilled()) {
// If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
// connection with a `CloseNotify` message and no more data will be forthcoming.
//
// Rustls yielded more data: advance the buffer, then see if more data is coming.
//
// We don't need to modify `self.eof` here, because it is only a temporary mark.
// rustls will only return 0 if is has received `CloseNotify`,
// in which case no additional processing is required.
Ok(n) => {
buf.advance(n);
Poll::Ready(Ok(()))
}
// Rustls doesn't have more data to yield, but it believes the connection is open.
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
if !io_pending {
// If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
// but if it does, we can try again.
//
// If the rustls state is abnormal, it may cause a cyclic wakeup.
// but tokio's cooperative budget will prevent infinite wakeup.
cx.waker().wake_by_ref();
}
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
} }
} }

View File

@ -127,8 +127,10 @@ async fn stream_good() -> io::Result<()> {
let (server, mut client) = make_pair(); let (server, mut client) = make_pair();
let mut server = Connection::from(server); let mut server = Connection::from(server);
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
io::copy(&mut Cursor::new(FILE), &mut server.writer())?; io::copy(&mut Cursor::new(FILE), &mut server.writer())?;
server.send_close_notify(); server.send_close_notify();
let mut server = Connection::from(server); let mut server = Connection::from(server);
{ {
@ -138,8 +140,10 @@ async fn stream_good() -> io::Result<()> {
let mut buf = Vec::new(); let mut buf = Vec::new();
dbg!(stream.read_to_end(&mut buf).await)?; dbg!(stream.read_to_end(&mut buf).await)?;
assert_eq!(buf, FILE); assert_eq!(buf, FILE);
dbg!(stream.write_all(b"Hello World!").await)?; dbg!(stream.write_all(b"Hello World!").await)?;
stream.session.send_close_notify(); stream.session.send_close_notify();
dbg!(stream.shutdown().await)?; dbg!(stream.shutdown().await)?;
} }
@ -241,8 +245,8 @@ async fn stream_eof() -> io::Result<()> {
let mut server = Connection::from(server); let mut server = Connection::from(server);
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
let mut good = Good(&mut server); let mut bad = Expected(Cursor::new(Vec::new()));
let mut stream = Stream::new(&mut good, &mut client).set_eof(true); let mut stream = Stream::new(&mut bad, &mut client);
let mut buf = Vec::new(); let mut buf = Vec::new();
let result = stream.read_to_end(&mut buf).await; let result = stream.read_to_end(&mut buf).await;

View File

@ -77,12 +77,11 @@ where
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => { Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
this.state.shutdown_read(); this.state.shutdown_read();
Poll::Ready(Ok(())) Poll::Ready(Err(err))
} }
Poll::Ready(Err(e)) => Poll::Ready(Err(e)), output => output,
Poll::Pending => Poll::Pending,
} }
} }
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),