From fc90b3f378952b25dcbc2cb299cd16f9abeac43a Mon Sep 17 00:00:00 2001 From: quininer Date: Wed, 20 May 2020 13:09:24 +0800 Subject: [PATCH] tokio-rustls: Add to README and clean code (#15) * tokio-rustls: Add to README and clean code * cargo fmt --- README.md | 1 + tokio-rustls/src/common/handshake.rs | 57 ++++++++++++++-------------- tokio-rustls/src/common/mod.rs | 27 ++++++------- 3 files changed, 41 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 8078a01..70cbed2 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ This crate contains a collection of Tokio based TLS libraries. - [`tokio-native-tls`](tokio-native-tls) +- [`tokio-rustls`](tokio-rustls) ## Getting Help diff --git a/tokio-rustls/src/common/handshake.rs b/tokio-rustls/src/common/handshake.rs index b9b7894..a00a3e1 100644 --- a/tokio-rustls/src/common/handshake.rs +++ b/tokio-rustls/src/common/handshake.rs @@ -47,38 +47,39 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - if let MidHandshake::Handshaking(mut stream) = mem::replace(this, MidHandshake::End) { - if !stream.skip_handshake() { - let (state, io, session) = stream.get_mut(); - let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); + let mut stream = + if let MidHandshake::Handshaking(stream) = mem::replace(this, MidHandshake::End) { + stream + } else { + panic!("unexpected polling after handshake") + }; - macro_rules! try_poll { - ( $e:expr ) => { - match $e { - Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(err)) => { - return Poll::Ready(Err((err, stream.into_io()))) - } - Poll::Pending => { - *this = MidHandshake::Handshaking(stream); - return Poll::Pending; - } + if !stream.skip_handshake() { + let (state, io, session) = stream.get_mut(); + let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); + + macro_rules! try_poll { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), + Poll::Pending => { + *this = MidHandshake::Handshaking(stream); + return Poll::Pending; } - }; - } - - while tls_stream.session.is_handshaking() { - try_poll!(tls_stream.handshake(cx)); - } - - while tls_stream.session.wants_write() { - try_poll!(tls_stream.write_io(cx)); - } + } + }; } - Poll::Ready(Ok(stream)) - } else { - panic!("unexpected polling after handshake") + while tls_stream.session.is_handshaking() { + try_poll!(tls_stream.handshake(cx)); + } + + while tls_stream.session.wants_write() { + try_poll!(tls_stream.write_io(cx)); + } } + + Poll::Ready(Ok(stream)) } } diff --git a/tokio-rustls/src/common/mod.rs b/tokio-rustls/src/common/mod.rs index 53ed976..d93179f 100644 --- a/tokio-rustls/src/common/mod.rs +++ b/tokio-rustls/src/common/mod.rs @@ -96,17 +96,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { - self.session.process_new_packets().map_err(|err| { - // In case we have an alert to send describing this error, - // try a last-gasp write -- but don't predate the primary - // error. - let _ = self.write_io(cx); - - io::Error::new(io::ErrorKind::InvalidData, err) - }) - } - pub fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, @@ -130,6 +119,15 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Err(err) => return Poll::Ready(Err(err)), }; + self.session.process_new_packets().map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_io(cx); + + io::Error::new(io::ErrorKind::InvalidData, err) + })?; + Poll::Ready(Ok(n)) } @@ -218,10 +216,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { while !self.eof && self.session.wants_read() { match self.read_io(cx) { Poll::Ready(Ok(0)) => self.eof = true, - Poll::Ready(Ok(n)) => { - rdlen += n; - self.process_new_packets(cx)?; - } + Poll::Ready(Ok(n)) => rdlen += n, Poll::Pending => { read_would_block = true; break; @@ -267,7 +262,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a self.eof = true; break; } - Poll::Ready(Ok(_)) => self.process_new_packets(cx)?, + Poll::Ready(Ok(_)) => (), Poll::Pending => { would_block = true; break;