diff --git a/src/client.rs b/src/client.rs index ac961b2..c901043 100644 --- a/src/client.rs +++ b/src/client.rs @@ -53,11 +53,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } } @@ -81,32 +81,7 @@ where fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => { - let this = self.get_mut(); - - let mut stream = Stream::new(&mut this.io, &mut this.session) - .set_eof(!this.state.readable()); - let (pos, data) = &mut this.early_data; - - // complete handshake - if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; - *pos += len; - } - } - - // end - this.state = TlsState::Stream; - data.clear(); - - Pin::new(this).poll_read(cx, buf) - } + TlsState::EarlyData => Poll::Pending, TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) @@ -116,7 +91,7 @@ where Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) - } + }, Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); @@ -125,9 +100,8 @@ where this.state.shutdown_write(); } Poll::Ready(Ok(0)) - } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => Poll::Pending + }, + output => output } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), @@ -153,7 +127,7 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = match early_data.write(buf) { + let len = match dbg!(early_data.write(buf)) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, @@ -165,7 +139,7 @@ where // complete handshake if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } // write early data (fallback) @@ -189,6 +163,14 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); + + #[cfg(feature = "early-data")] { + // complete handshake + if stream.session.is_handshaking() { + futures::ready!(stream.handshake(cx))?; + } + } + stream.as_mut_pin().poll_flush(cx) } @@ -201,6 +183,11 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); + + // TODO + // + // should we complete the handshake? + stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 1af5ecb..e9fc783 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -13,17 +13,6 @@ pub struct Stream<'a, IO, S> { pub eof: bool } -trait WriteTls { - fn write_tls(&mut self, cx: &mut Context) -> io::Result; -} - -#[derive(Clone, Copy)] -enum Focus { - Empty, - Readable, - Writable -} - impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { @@ -44,11 +33,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { - self.complete_inner_io(cx, Focus::Empty) - } - - fn complete_read_io(&mut self, cx: &mut Context) -> Poll> { + fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -76,7 +61,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. - let _ = self.write_tls(cx); + let _ = self.write_io(cx); io::Error::new(io::ErrorKind::InvalidData, err) })?; @@ -84,85 +69,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } - fn complete_write_io(&mut self, cx: &mut Context) -> Poll> { - match self.write_tls(cx) { - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) - } - } - - fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { - let mut wrlen = 0; - let mut rdlen = 0; - - loop { - let mut write_would_block = false; - let mut read_would_block = false; - - while self.session.wants_write() { - match self.complete_write_io(cx) { - Poll::Ready(Ok(n)) => wrlen += n, - Poll::Pending => { - write_would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - if let Focus::Writable = focus { - if !write_would_block { - return Poll::Ready(Ok((rdlen, wrlen))); - } else { - return Poll::Pending; - } - } - - if !self.eof && self.session.wants_read() { - match self.complete_read_io(cx) { - Poll::Ready(Ok(0)) => self.eof = true, - Poll::Ready(Ok(n)) => rdlen += n, - Poll::Pending => read_would_block = true, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - let would_block = match focus { - Focus::Empty => write_would_block || read_would_block, - Focus::Readable => read_would_block, - Focus::Writable => write_would_block, - }; - - match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => { - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); - return Poll::Ready(Err(err)); - }, - (_, false, true) => { - let would_block = match focus { - Focus::Empty => rdlen == 0 && wrlen == 0, - Focus::Readable => rdlen == 0, - Focus::Writable => wrlen == 0 - }; - - return if would_block { - Poll::Pending - } else { - Poll::Ready(Ok((rdlen, wrlen))) - }; - }, - (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), - (_, true, true) => return Poll::Pending, - (..) => () - } - } - } -} - -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { - fn write_tls(&mut self, cx: &mut Context) -> io::Result { - // TODO writev - + fn write_io(&mut self, cx: &mut Context) -> Poll> { struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -185,7 +92,58 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Str } let mut writer = Writer { io: self.io, cx }; - self.session.write_tls(&mut writer) + + match self.session.write_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { + write_would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + if !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => read_would_block = true, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + let would_block = write_would_block || read_would_block; + + return match (self.eof, self.session.is_handshaking(), would_block) { + (true, true, _) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + Poll::Ready(Err(err)) + }, + (_, false, true) => if rdlen != 0 || wrlen != 0 { + Poll::Ready(Ok((rdlen, wrlen))) + } else { + Poll::Pending + }, + (_, false, _) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true, true) => Poll::Pending, + (..) => continue + } + } } } @@ -194,8 +152,8 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a let this = self.get_mut(); while this.session.wants_read() { - match this.complete_inner_io(cx, Focus::Readable) { - Poll::Ready(Ok((0, _))) => break, + match this.read_io(cx) { + Poll::Ready(Ok(0)) => break, Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) @@ -220,7 +178,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' Err(err) => return Poll::Ready(Err(err)) }; while this.session.wants_write() { - match this.complete_inner_io(cx, Focus::Writable) { + match this.write_io(cx) { Poll::Ready(Ok(_)) => (), Poll::Pending if len != 0 => break, Poll::Pending => return Poll::Pending, @@ -246,7 +204,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; + futures::ready!(this.write_io(cx))?; } Pin::new(&mut this.io).poll_flush(cx) } @@ -255,7 +213,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; + futures::ready!(this.write_io(cx))?; } Pin::new(&mut this.io).poll_shutdown(cx) diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index d109369..20cc4eb 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -83,6 +83,7 @@ async fn stream_good() -> io::Result<()> { stream.read_to_end(&mut buf).await?; assert_eq!(buf, FILE); stream.write_all(b"Hello World!").await?; + stream.flush().await?; } let mut buf = String::new(); @@ -119,12 +120,12 @@ async fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; + let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; assert!(r > 0); assert!(w > 0); - poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake + poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake } assert!(!server.is_handshaking()); @@ -141,7 +142,7 @@ async fn stream_handshake_eof() -> io::Result<()> { let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.complete_io(&mut cx); + let r = stream.handshake(&mut cx); assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); Ok(()) as io::Result<()> @@ -187,11 +188,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut stream = Stream::new(&mut good, client); if stream.session.is_handshaking() { - ready!(stream.complete_io(cx))?; + ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - ready!(stream.complete_io(cx))?; + ready!(stream.handshake(cx))?; } Poll::Ready(Ok(())) diff --git a/src/server.rs b/src/server.rs index 6a94347..92043c9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,11 +48,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.complete_io(cx))?; + futures::ready!(stream.handshake(cx))?; } } diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index cb3e94b..898deef 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -22,6 +22,7 @@ async fn get(config: Arc, domain: &str, rtt0: bool) let stream = TcpStream::connect(&addr).await?; let mut stream = connector.connect(domain, stream).await?; stream.write_all(input.as_bytes()).await?; + stream.flush().await?; stream.read_to_end(&mut buf).await?; Ok((stream, String::from_utf8(buf).unwrap())) diff --git a/tests/test.rs b/tests/test.rs index 5749efe..74918ca 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -52,18 +52,19 @@ lazy_static!{ let mut buf = vec![0; 8192]; let n = stream.read(&mut buf).await?; stream.write(&buf[..n]).await?; + stream.flush().await?; let _ = stream.read(&mut buf).await?; Ok(()) as io::Result<()> - }; + }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); - handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); + handle.spawn(fut).unwrap(); } Ok(()) as io::Result<()> - }; + }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); - runtime.block_on(done.unwrap_or_else(|err| eprintln!("{:?}", err))); + runtime.block_on(done); }); let addr = recv.recv().unwrap(); @@ -85,6 +86,7 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) let stream = TcpStream::connect(&addr).await?; let mut stream = config.connect(domain, stream).await?; stream.write_all(FILE).await?; + stream.flush().await?; stream.read_exact(&mut buf).await?; assert_eq!(buf, FILE);