diff --git a/src/client.rs b/src/client.rs index c901043..ac961b2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -53,11 +53,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } } @@ -81,7 +81,32 @@ where fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] - TlsState::EarlyData => Poll::Pending, + TlsState::EarlyData => { + let this = self.get_mut(); + + let mut stream = Stream::new(&mut this.io, &mut this.session) + .set_eof(!this.state.readable()); + let (pos, data) = &mut this.early_data; + + // complete handshake + if stream.session.is_handshaking() { + futures::ready!(stream.complete_io(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = futures::ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + // end + this.state = TlsState::Stream; + data.clear(); + + Pin::new(this).poll_read(cx, buf) + } TlsState::Stream | TlsState::WriteShutdown => { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) @@ -91,7 +116,7 @@ where Poll::Ready(Ok(0)) => { this.state.shutdown_read(); Poll::Ready(Ok(0)) - }, + } Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)), Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => { this.state.shutdown_read(); @@ -100,8 +125,9 @@ where this.state.shutdown_write(); } Poll::Ready(Ok(0)) - }, - output => output + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending } } TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), @@ -127,7 +153,7 @@ where // write early data if let Some(mut early_data) = stream.session.early_data() { - let len = match dbg!(early_data.write(buf)) { + let len = match early_data.write(buf) { Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, @@ -139,7 +165,7 @@ where // complete handshake if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } // write early data (fallback) @@ -163,14 +189,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - - #[cfg(feature = "early-data")] { - // complete handshake - if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; - } - } - stream.as_mut_pin().poll_flush(cx) } @@ -183,11 +201,6 @@ where let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); - - // TODO - // - // should we complete the handshake? - stream.as_mut_pin().poll_shutdown(cx) } } diff --git a/src/common/mod.rs b/src/common/mod.rs index e9fc783..1af5ecb 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -13,6 +13,17 @@ pub struct Stream<'a, IO, S> { pub eof: bool } +trait WriteTls { + fn write_tls(&mut self, cx: &mut Context) -> io::Result; +} + +#[derive(Clone, Copy)] +enum Focus { + Empty, + Readable, + Writable +} + impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { Stream { @@ -33,7 +44,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - fn read_io(&mut self, cx: &mut Context) -> Poll> { + pub fn complete_io(&mut self, cx: &mut Context) -> Poll> { + self.complete_inner_io(cx, Focus::Empty) + } + + fn complete_read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -61,7 +76,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { // In case we have an alert to send describing this error, // try a last-gasp write -- but don't predate the primary // error. - let _ = self.write_io(cx); + let _ = self.write_tls(cx); io::Error::new(io::ErrorKind::InvalidData, err) })?; @@ -69,7 +84,85 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } - fn write_io(&mut self, cx: &mut Context) -> Poll> { + fn complete_write_io(&mut self, cx: &mut Context) -> Poll> { + match self.write_tls(cx) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result) + } + } + + fn complete_inner_io(&mut self, cx: &mut Context, focus: Focus) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + + while self.session.wants_write() { + match self.complete_write_io(cx) { + Poll::Ready(Ok(n)) => wrlen += n, + Poll::Pending => { + write_would_block = true; + break + }, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + if let Focus::Writable = focus { + if !write_would_block { + return Poll::Ready(Ok((rdlen, wrlen))); + } else { + return Poll::Pending; + } + } + + if !self.eof && self.session.wants_read() { + match self.complete_read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => read_would_block = true, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) + } + } + + let would_block = match focus { + Focus::Empty => write_would_block || read_would_block, + Focus::Readable => read_would_block, + Focus::Writable => write_would_block, + }; + + match (self.eof, self.session.is_handshaking(), would_block) { + (true, true, _) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + return Poll::Ready(Err(err)); + }, + (_, false, true) => { + let would_block = match focus { + Focus::Empty => rdlen == 0 && wrlen == 0, + Focus::Readable => rdlen == 0, + Focus::Writable => wrlen == 0 + }; + + return if would_block { + Poll::Pending + } else { + Poll::Ready(Ok((rdlen, wrlen))) + }; + }, + (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), + (_, true, true) => return Poll::Pending, + (..) => () + } + } + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls for Stream<'a, IO, S> { + fn write_tls(&mut self, cx: &mut Context) -> io::Result { + // TODO writev + struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -92,58 +185,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } let mut writer = Writer { io: self.io, cx }; - - match self.session.write_tls(&mut writer) { - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, - result => Poll::Ready(result) - } - } - - pub fn handshake(&mut self, cx: &mut Context) -> Poll> { - let mut wrlen = 0; - let mut rdlen = 0; - - loop { - let mut write_would_block = false; - let mut read_would_block = false; - - while self.session.wants_write() { - match self.write_io(cx) { - Poll::Ready(Ok(n)) => wrlen += n, - Poll::Pending => { - write_would_block = true; - break - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - if !self.eof && self.session.wants_read() { - match self.read_io(cx) { - Poll::Ready(Ok(0)) => self.eof = true, - Poll::Ready(Ok(n)) => rdlen += n, - Poll::Pending => read_would_block = true, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) - } - } - - let would_block = write_would_block || read_would_block; - - return match (self.eof, self.session.is_handshaking(), would_block) { - (true, true, _) => { - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); - Poll::Ready(Err(err)) - }, - (_, false, true) => if rdlen != 0 || wrlen != 0 { - Poll::Ready(Ok((rdlen, wrlen))) - } else { - Poll::Pending - }, - (_, false, _) => Poll::Ready(Ok((rdlen, wrlen))), - (_, true, true) => Poll::Pending, - (..) => continue - } - } + self.session.write_tls(&mut writer) } } @@ -152,8 +194,8 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a let this = self.get_mut(); while this.session.wants_read() { - match this.read_io(cx) { - Poll::Ready(Ok(0)) => break, + match this.complete_inner_io(cx, Focus::Readable) { + Poll::Ready(Ok((0, _))) => break, Poll::Ready(Ok(_)) => (), Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) @@ -178,7 +220,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' Err(err) => return Poll::Ready(Err(err)) }; while this.session.wants_write() { - match this.write_io(cx) { + match this.complete_inner_io(cx, Focus::Writable) { Poll::Ready(Ok(_)) => (), Poll::Pending if len != 0 => break, Poll::Pending => return Poll::Pending, @@ -204,7 +246,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' this.session.flush()?; while this.session.wants_write() { - futures::ready!(this.write_io(cx))?; + futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } Pin::new(&mut this.io).poll_flush(cx) } @@ -213,7 +255,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<' let this = self.get_mut(); while this.session.wants_write() { - futures::ready!(this.write_io(cx))?; + futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; } Pin::new(&mut this.io).poll_shutdown(cx) diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 20cc4eb..d109369 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -83,7 +83,6 @@ async fn stream_good() -> io::Result<()> { stream.read_to_end(&mut buf).await?; assert_eq!(buf, FILE); stream.write_all(b"Hello World!").await?; - stream.flush().await?; } let mut buf = String::new(); @@ -120,12 +119,12 @@ async fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; + let (r, w) = poll_fn(|cx| stream.complete_io(cx)).await?; assert!(r > 0); assert!(w > 0); - poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake + poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake } assert!(!server.is_handshaking()); @@ -142,7 +141,7 @@ async fn stream_handshake_eof() -> io::Result<()> { let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.handshake(&mut cx); + let r = stream.complete_io(&mut cx); assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof))); Ok(()) as io::Result<()> @@ -188,11 +187,11 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut let mut stream = Stream::new(&mut good, client); if stream.session.is_handshaking() { - ready!(stream.handshake(cx))?; + ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - ready!(stream.handshake(cx))?; + ready!(stream.complete_io(cx))?; } Poll::Ready(Ok(())) diff --git a/src/server.rs b/src/server.rs index 92043c9..6a94347 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,11 +48,11 @@ where let mut stream = Stream::new(io, session).set_eof(eof); if stream.session.is_handshaking() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + futures::ready!(stream.complete_io(cx))?; } } diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 898deef..cb3e94b 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -22,7 +22,6 @@ async fn get(config: Arc, domain: &str, rtt0: bool) let stream = TcpStream::connect(&addr).await?; let mut stream = connector.connect(domain, stream).await?; stream.write_all(input.as_bytes()).await?; - stream.flush().await?; stream.read_to_end(&mut buf).await?; Ok((stream, String::from_utf8(buf).unwrap())) diff --git a/tests/test.rs b/tests/test.rs index 74918ca..5749efe 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -52,19 +52,18 @@ lazy_static!{ let mut buf = vec![0; 8192]; let n = stream.read(&mut buf).await?; stream.write(&buf[..n]).await?; - stream.flush().await?; let _ = stream.read(&mut buf).await?; Ok(()) as io::Result<()> - }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); + }; - handle.spawn(fut).unwrap(); + handle.spawn(fut.unwrap_or_else(|err| eprintln!("{:?}", err))).unwrap(); } Ok(()) as io::Result<()> - }.unwrap_or_else(|err| eprintln!("server: {:?}", err)); + }; - runtime.block_on(done); + runtime.block_on(done.unwrap_or_else(|err| eprintln!("{:?}", err))); }); let addr = recv.recv().unwrap(); @@ -86,7 +85,6 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) let stream = TcpStream::connect(&addr).await?; let mut stream = config.connect(domain, stream).await?; stream.write_all(FILE).await?; - stream.flush().await?; stream.read_exact(&mut buf).await?; assert_eq!(buf, FILE);