From 86171d34a8c8d9e630b9ca5c2af999389ae400e1 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 1 Oct 2019 14:36:16 +0800 Subject: [PATCH] refactor: more read an write --- src/client.rs | 8 +-- src/common/mod.rs | 133 +++++++++++++++++++++++--------------- src/common/test_stream.rs | 2 +- src/server.rs | 2 +- 4 files changed, 88 insertions(+), 57 deletions(-) diff --git a/src/client.rs b/src/client.rs index c901043..4803410 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,7 +52,7 @@ where let (io, session) = stream.get_mut(); let mut stream = Stream::new(io, session).set_eof(eof); - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } @@ -127,7 +127,7 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = match dbg!(early_data.write(buf)) { + let len = match early_data.write(buf) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, @@ -138,7 +138,7 @@ where } // complete handshake - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } @@ -166,7 +166,7 @@ where #[cfg(feature = "early-data")] { // complete handshake - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; } } diff --git a/src/common/mod.rs b/src/common/mod.rs index e9fc783..195d0da 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -33,6 +33,18 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } + fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { + self.session.process_new_packets() + .map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_io(cx); + + io::Error::new(io::ErrorKind::InvalidData, err) + }) + } + fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, @@ -56,16 +68,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Err(err) => return Poll::Ready(Err(err)) }; - self.session.process_new_packets() - .map_err(|err| { - // In case we have an alert to send describing this error, - // try a last-gasp write -- but don't predate the primary - // error. - let _ = self.write_io(cx); - - io::Error::new(io::ErrorKind::InvalidData, err) - })?; - Poll::Ready(Ok(n)) } @@ -118,29 +120,31 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - if !self.eof && self.session.wants_read() { + while !self.eof && self.session.wants_read() { match self.read_io(cx) { Poll::Ready(Ok(0)) => self.eof = true, Poll::Ready(Ok(n)) => rdlen += n, - Poll::Pending => read_would_block = true, + Poll::Pending => { + read_would_block = true; + break + }, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - let would_block = write_would_block || read_would_block; + self.process_new_packets(cx)?; - return match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => { + return match (self.eof, self.session.is_handshaking()) { + (true, true) => { let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); Poll::Ready(Err(err)) }, - (_, false, true) => if rdlen != 0 || wrlen != 0 { + (_, false) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true) if write_would_block || read_would_block => if rdlen != 0 || wrlen != 0 { Poll::Ready(Ok((rdlen, wrlen))) } else { Poll::Pending }, - (_, false, _) => Poll::Ready(Ok((rdlen, wrlen))), - (_, true, true) => Poll::Pending, (..) => continue } } @@ -150,53 +154,80 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { let this = self.get_mut(); + let mut pos = 0; - while this.session.wants_read() { - match this.read_io(cx) { - Poll::Ready(Ok(0)) => break, - Poll::Ready(Ok(_)) => (), - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + while pos != buf.len() { + let mut would_block = false; + + // read a packet + while this.session.wants_read() { + match this.read_io(cx) { + Poll::Ready(Ok(0)) => { + this.eof = true; + break + }, + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + this.process_new_packets(cx)?; + + return match this.session.read(&mut buf[pos..]) { + Ok(0) if pos == 0 && would_block => Poll::Pending, + Ok(n) if this.eof || would_block => Poll::Ready(Ok(pos + n)), + Ok(n) => { + pos += n; + continue + }, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => + Poll::Ready(Ok(pos)), + Err(err) => Poll::Ready(Err(err)) } } - match this.session.read(buf) { - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) - } + Poll::Ready(Ok(pos)) } } impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = self.get_mut(); + let mut pos = 0; - let len = match this.session.write(buf) { - Ok(n) => n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => - return Poll::Pending, - Err(err) => return Poll::Ready(Err(err)) - }; - while this.session.wants_write() { - match this.write_io(cx) { - Poll::Ready(Ok(_)) => (), - Poll::Pending if len != 0 => break, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + while pos != buf.len() { + let mut would_block = false; + + match this.session.write(&buf[pos..]) { + Ok(n) => pos += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), + Err(err) => return Poll::Ready(Err(err)) + }; + + while this.session.wants_write() { + match this.write_io(cx) { + Poll::Ready(Ok(0)) | Poll::Pending => { + would_block = true; + break + }, + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + return match (pos, would_block) { + (0, true) => Poll::Pending, + (n, true) => Poll::Ready(Ok(n)), + (_, false) => continue } } - if len != 0 || buf.is_empty() { - Poll::Ready(Ok(len)) - } else { - // not write zero - match this.session.write(buf) { - Ok(0) => Poll::Pending, - Ok(n) => Poll::Ready(Ok(n)), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - Err(err) => Poll::Ready(Err(err)) - } - } + Poll::Ready(Ok(pos)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 20cc4eb..8be20f4 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -187,7 +187,7 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut good = Good(server); let mut stream = Stream::new(&mut good, client); - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { ready!(stream.handshake(cx))?; } diff --git a/src/server.rs b/src/server.rs index 92043c9..91c1cb4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -47,7 +47,7 @@ where let (io, session) = stream.get_mut(); let mut stream = Stream::new(io, session).set_eof(eof); - if stream.session.is_handshaking() { + while stream.session.is_handshaking() { futures::ready!(stream.handshake(cx))?; }