From b7925003e27342c31d3ca2a4781ba6fa164d4a65 Mon Sep 17 00:00:00 2001 From: quininer Date: Sat, 1 Jun 2019 22:37:13 +0800 Subject: [PATCH] clean code --- src/client.rs | 22 +++++++++++++--------- src/common/mod.rs | 18 +++++++++++++----- src/common/test_stream.rs | 8 ++++---- src/lib.rs | 3 ++- src/server.rs | 12 ++++++------ 5 files changed, 38 insertions(+), 25 deletions(-) diff --git a/src/client.rs b/src/client.rs index 7bce1b8..11a0331 100644 --- a/src/client.rs +++ b/src/client.rs @@ -75,7 +75,6 @@ where IO: AsyncRead + AsyncWrite + Unpin, { unsafe fn initializer(&self) -> Initializer { - // TODO Initializer::nop() } @@ -97,7 +96,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.pin().poll_write(cx, &data[*pos..]))?; + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } @@ -113,7 +112,7 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - match stream.pin().poll_read(cx, buf) { + match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -154,7 +153,12 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = early_data.write(buf)?; // TODO check pending + let len = match early_data.write(buf) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => + return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; data.extend_from_slice(&buf[..len]); return Poll::Ready(Ok(len)); } @@ -167,7 +171,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = futures::ready!(stream.pin().poll_write(cx, &data[*pos..]))?; + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; *pos += len; } } @@ -175,9 +179,9 @@ where // end this.state = TlsState::Stream; data.clear(); - stream.pin().poll_write(cx, buf) + stream.as_mut_pin().poll_write(cx, buf) } - _ => stream.pin().poll_write(cx, buf), + _ => stream.as_mut_pin().poll_write(cx, buf), } } @@ -185,7 +189,7 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_flush(cx) + stream.as_mut_pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -197,6 +201,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_close(cx) + stream.as_mut_pin().poll_close(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index bff9990..2e648ae 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -14,7 +14,7 @@ pub struct Stream<'a, IO, S> { pub eof: bool } -pub trait WriteTls { +trait WriteTls { fn write_tls(&mut self, cx: &mut Context) -> io::Result; } @@ -41,7 +41,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { self } - pub fn pin(&mut self) -> Pin<&mut Self> { + pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { Pin::new(self) } @@ -191,8 +191,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a } } - // FIXME rustls always ready ? - Poll::Ready(this.session.read(buf)) + match this.session.read(buf) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } } } @@ -200,7 +202,12 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = self.get_mut(); - let len = this.session.write(buf)?; + let len = match this.session.write(buf) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => + return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)) + }; while this.session.wants_write() { match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), @@ -217,6 +224,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' match this.session.write(buf) { Ok(0) => Poll::Pending, Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(err) => Poll::Ready(Err(err)) } } diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 1f7c14c..e774778 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -105,13 +105,13 @@ fn stream_bad() -> io::Result<()> { let mut bad = Bad(true); let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); - assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); - let r = future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer + assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!(future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); + let r = future::poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer assert!(r < 1024); let mut cx = Context::from_waker(noop_waker_ref()); - assert!(stream.pin().poll_write(&mut cx, &[0x01]).is_pending()); + assert!(stream.as_mut_pin().poll_write(&mut cx, &[0x01]).is_pending()); Ok(()) as io::Result<()> }; diff --git a/src/lib.rs b/src/lib.rs index a928b87..65c11f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). -#![feature(async_await)] +#![cfg_attr(test, feature(async_await))] + mod common; pub mod client; diff --git a/src/server.rs b/src/server.rs index 2ed7ba9..1e25145 100644 --- a/src/server.rs +++ b/src/server.rs @@ -68,7 +68,6 @@ where IO: AsyncRead + AsyncWrite + Unpin, { unsafe fn initializer(&self) -> Initializer { - // TODO Initializer::nop() } @@ -78,7 +77,7 @@ where .set_eof(!this.state.readable()); match this.state { - TlsState::Stream | TlsState::WriteShutdown => match stream.pin().poll_read(cx, buf) { + TlsState::Stream | TlsState::WriteShutdown => match stream.as_mut_pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -110,14 +109,14 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_write(cx, buf) + stream.as_mut_pin().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - stream.pin().poll_flush(cx) + stream.as_mut_pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -127,7 +126,8 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); - stream.pin().poll_close(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.as_mut_pin().poll_close(cx) } }