diff --git a/src/client.rs b/src/client.rs index 9e89468..9f1f7f6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -97,7 +97,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = try_ready!(stream.poll_write(cx, &data[*pos..])); + let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..])); *pos += len; } } @@ -113,7 +113,7 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - match stream.poll_read(cx, buf) { + match stream.pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -167,7 +167,7 @@ where // write early data (fallback) if !stream.session.is_early_data_accepted() { while *pos < data.len() { - let len = try_ready!(stream.poll_write(cx, &data[*pos..])); + let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..])); *pos += len; } } @@ -175,17 +175,17 @@ where // end this.state = TlsState::Stream; data.clear(); - stream.poll_write(cx, buf) + stream.pin().poll_write(cx, buf) } - _ => stream.poll_write(cx, buf), + _ => stream.pin().poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()) - .poll_flush(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -197,7 +197,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - try_ready!(stream.poll_flush(cx)); - Pin::new(&mut this.io).poll_close(cx) + stream.pin().poll_close(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index eacf585..585e6c9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -2,8 +2,7 @@ use std::pin::Pin; use std::task::Poll; use std::marker::Unpin; use std::io::{ self, Read }; -use rustls::Session; -use rustls::WriteV; +use rustls::{ Session, WriteV }; use futures::task::Context; use futures::io::{ AsyncRead, AsyncWrite, IoSlice }; use smallvec::SmallVec; @@ -42,6 +41,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { self } + pub fn pin(&mut self) -> Pin<&mut Self> { + Pin::new(self) + } + pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { self.complete_inner_io(cx, Focus::Empty) } @@ -124,7 +127,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { }; match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => return Poll::Pending, + (true, true, _) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), (_, false, true) => { let would_block = match focus { Focus::Empty => rdlen == 0 && wrlen == 0, @@ -172,10 +175,12 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str } } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { - pub fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { - while self.session.wants_read() { - match self.complete_inner_io(cx, Focus::Readable) { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + let this = self.get_mut(); + + while this.session.wants_read() { + match this.complete_inner_io(cx, Focus::Readable) { Poll::Ready(Ok((0, _))) => break, Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, @@ -184,13 +189,17 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } // FIXME rustls always ready ? - Poll::Ready(self.session.read(buf)) + Poll::Ready(this.session.read(buf)) } +} - pub fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll> { - let len = self.session.write(buf)?; - while self.session.wants_write() { - match self.complete_inner_io(cx, Focus::Writable) { +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + + let len = this.session.write(buf)?; + while this.session.wants_write() { + match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), Poll::Pending if len != 0 => break, Poll::Pending => return Poll::Pending, @@ -202,7 +211,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(len)) } else { // not write zero - match self.session.write(buf) { + match this.session.write(buf) { Ok(0) => Poll::Pending, Ok(n) => Poll::Ready(Ok(n)), Err(err) => Poll::Ready(Err(err)) @@ -210,18 +219,33 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } - pub fn poll_flush(&mut self, cx: &mut Context) -> Poll> { - self.session.flush()?; - while self.session.wants_write() { - match self.complete_inner_io(cx, Focus::Writable) { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + this.session.flush()?; + while this.session.wants_write() { + match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) } } - Pin::new(&mut self.io).poll_flush(cx) + Pin::new(&mut this.io).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + while this.session.wants_write() { + match this.complete_inner_io(cx, Focus::Writable) { + Poll::Ready(Ok(_)) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + Pin::new(&mut this.io).poll_close(cx) } } -// #[cfg(test)] -// mod test_stream; +#[cfg(test)] +mod test_stream; diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 744758a..67b9146 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -1,4 +1,10 @@ +use std::pin::Pin; +use std::task::Poll; use std::sync::Arc; +use futures::prelude::*; +use futures::task::{ Context, noop_waker_ref }; +use futures::executor; +use futures::io::{ AsyncRead, AsyncWrite }; use std::io::{ self, Read, Write, BufReader, Cursor }; use webpki::DNSNameRef; use rustls::internal::pemfile::{ certs, rsa_private_keys }; @@ -7,146 +13,172 @@ use rustls::{ ServerSession, ClientSession, Session, NoClientAuth }; -use futures::{ Async, Poll }; -use tokio_io::{ AsyncRead, AsyncWrite }; use super::Stream; struct Good<'a>(&'a mut Session); -impl<'a> Read for Good<'a> { - fn read(&mut self, mut buf: &mut [u8]) -> io::Result { - self.0.write_tls(buf.by_ref()) +impl<'a> AsyncRead for Good<'a> { + fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll> { + Poll::Ready(self.0.write_tls(buf.by_ref())) } } -impl<'a> Write for Good<'a> { - fn write(&mut self, mut buf: &[u8]) -> io::Result { +impl<'a> AsyncWrite for Good<'a> { + fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll> { let len = self.0.read_tls(buf.by_ref())?; self.0.process_new_packets() .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; - Ok(len) + Poll::Ready(Ok(len)) } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } -} -impl<'a> AsyncRead for Good<'a> {} -impl<'a> AsyncWrite for Good<'a> { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } struct Bad(bool); -impl Read for Bad { - fn read(&mut self, _: &mut [u8]) -> io::Result { - Ok(0) +impl AsyncRead for Bad { + fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + Poll::Ready(Ok(0)) } } -impl Write for Bad { - fn write(&mut self, buf: &[u8]) -> io::Result { +impl AsyncWrite for Bad { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { if self.0 { - Err(io::ErrorKind::WouldBlock.into()) + Poll::Pending } else { - Ok(buf.len()) + Poll::Ready(Ok(buf.len())) } } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } -impl AsyncRead for Bad {} -impl AsyncWrite for Bad { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) - } -} - - #[test] fn stream_good() -> io::Result<()> { const FILE: &'static [u8] = include_bytes!("../../README.md"); - let (mut server, mut client) = make_pair(); - do_handshake(&mut client, &mut server); - io::copy(&mut Cursor::new(FILE), &mut server)?; + let fut = async { + let (mut server, mut client) = make_pair(); + future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + io::copy(&mut Cursor::new(FILE), &mut server)?; - { - let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client); + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client); - let mut buf = Vec::new(); - stream.read_to_end(&mut buf)?; - assert_eq!(buf, FILE); - stream.write_all(b"Hello World!")?; - } + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; + assert_eq!(buf, FILE); + stream.write_all(b"Hello World!").await?; + } - let mut buf = String::new(); - server.read_to_string(&mut buf)?; - assert_eq!(buf, "Hello World!"); + let mut buf = String::new(); + server.read_to_string(&mut buf)?; + assert_eq!(buf, "Hello World!"); - Ok(()) + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } #[test] fn stream_bad() -> io::Result<()> { - let (mut server, mut client) = make_pair(); - do_handshake(&mut client, &mut server); - client.set_buffer_limit(1024); + let fut = async { + let (mut server, mut client) = make_pair(); + future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + client.set_buffer_limit(1024); - let mut bad = Bad(true); - let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(stream.write(&[0x42; 8])?, 8); - assert_eq!(stream.write(&[0x42; 8])?, 8); - let r = stream.write(&[0x00; 1024])?; // fill buffer - assert!(r < 1024); - assert_eq!( - stream.write(&[0x01]).unwrap_err().kind(), - io::ErrorKind::WouldBlock - ); + let mut bad = Bad(true); + let mut stream = Stream::new(&mut bad, &mut client); + assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); + assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8); + let r = future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer + assert!(r < 1024); - Ok(()) + let mut cx = Context::from_waker(noop_waker_ref()); + assert!(stream.pin().poll_write(&mut cx, &[0x01]).is_pending()); + + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } #[test] fn stream_handshake() -> io::Result<()> { - let (mut server, mut client) = make_pair(); + let fut = async { + let (mut server, mut client) = make_pair(); - { - let mut good = Good(&mut server); - let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = stream.complete_io()?; + { + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client); + let (r, w) = future::poll_fn(|cx| stream.complete_io(cx)).await?; - assert!(r > 0); - assert!(w > 0); + assert!(r > 0); + assert!(w > 0); - stream.complete_io()?; // finish server handshake - } + future::poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake + } - assert!(!server.is_handshaking()); - assert!(!client.is_handshaking()); + assert!(!server.is_handshaking()); + assert!(!client.is_handshaking()); - Ok(()) + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } #[test] fn stream_handshake_eof() -> io::Result<()> { - let (_, mut client) = make_pair(); + let fut = async { + let (_, mut client) = make_pair(); - let mut bad = Bad(false); - let mut stream = Stream::new(&mut bad, &mut client); - let r = stream.complete_io(); + let mut bad = Bad(false); + let mut stream = Stream::new(&mut bad, &mut client); - assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); + let mut cx = Context::from_waker(noop_waker_ref()); + let r = stream.complete_io(&mut cx); + assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); - Ok(()) + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) +} + +#[test] +fn stream_eof() -> io::Result<()> { + let fut = async { + let (mut server, mut client) = make_pair(); + future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; + + let mut good = Good(&mut server); + let mut stream = Stream::new(&mut good, &mut client).set_eof(true); + + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; + assert_eq!(buf.len(), 0); + + Ok(()) as io::Result<()> + }; + + executor::block_on(fut) } fn make_pair() -> (ServerSession, ClientSession) { @@ -169,9 +201,17 @@ fn make_pair() -> (ServerSession, ClientSession) { (server, client) } -fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { +fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll> { let mut good = Good(server); let mut stream = Stream::new(&mut good, client); - stream.complete_io().unwrap(); - stream.complete_io().unwrap(); + + if stream.session.is_handshaking() { + try_ready!(stream.complete_io(cx)); + } + + if stream.session.wants_write() { + try_ready!(stream.complete_io(cx)); + } + + Poll::Ready(Ok(())) } diff --git a/src/server.rs b/src/server.rs index ba054a9..21cc5e6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -78,7 +78,7 @@ where .set_eof(!this.state.readable()); match this.state { - TlsState::Stream | TlsState::WriteShutdown => match stream.poll_read(cx, buf) { + TlsState::Stream | TlsState::WriteShutdown => match stream.pin().poll_read(cx, buf) { Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) @@ -108,16 +108,16 @@ where { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()) - .poll_write(cx, buf) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.pin().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()) - .poll_flush(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + stream.pin().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -127,9 +127,7 @@ where } let this = self.get_mut(); - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); - try_ready!(stream.complete_io(cx)); - Pin::new(&mut this.io).poll_close(cx) + let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.pin().poll_close(cx) } }