diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 84bcbe6..d706e37 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -32,6 +32,31 @@ impl<'a> AsyncWrite for Good<'a> { Poll::Ready(Ok(len)) } + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.0.process_new_packets() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + Poll::Ready(Ok(())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.0.send_close_notify(); + Poll::Ready(Ok(())) + } +} + +struct Pending; + +impl AsyncRead for Pending { + fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { + Poll::Pending + } +} + +impl AsyncWrite for Pending { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll> { + Poll::Pending + } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -41,21 +66,17 @@ impl<'a> AsyncWrite for Good<'a> { } } -struct Bad(bool); +struct Eof; -impl AsyncRead for Bad { +impl AsyncRead for Eof { fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll> { Poll::Ready(Ok(0)) } } -impl AsyncWrite for Bad { +impl AsyncWrite for Eof { fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - if self.0 { - Poll::Pending - } else { - Poll::Ready(Ok(buf.len())) - } + Poll::Ready(Ok(buf.len())) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -99,7 +120,7 @@ async fn stream_bad() -> io::Result<()> { poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; client.set_buffer_limit(1024); - let mut bad = Bad(true); + let mut bad = Pending; let mut stream = Stream::new(&mut bad, &mut client); assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); @@ -138,7 +159,7 @@ async fn stream_handshake() -> io::Result<()> { async fn stream_handshake_eof() -> io::Result<()> { let (_, mut client) = make_pair(); - let mut bad = Bad(false); + let mut bad = Eof; let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); diff --git a/src/lib.rs b/src/lib.rs index 382e43a..3dea67f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -146,6 +146,7 @@ impl TlsConnector { } impl TlsAcceptor { + #[inline] pub fn accept(&self, stream: IO) -> Accept where IO: AsyncRead + AsyncWrite + Unpin, @@ -153,7 +154,6 @@ impl TlsAcceptor { self.accept_with(stream, |_| ()) } - #[inline] pub fn accept_with(&self, stream: IO, f: F) -> Accept where IO: AsyncRead + AsyncWrite + Unpin, diff --git a/tests/early-data.rs b/tests/early-data.rs index 9dd6b5e..7a43034 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -96,17 +96,24 @@ async fn test_0rtt() -> io::Result<()> { let stdout = handle.0.stdout.as_mut().unwrap(); let mut lines = BufReader::new(stdout).lines(); + let mut f1 = false; + let mut f2 = false; + for line in lines.by_ref() { if line?.contains("hello") { + f1 = true; break } } for line in lines.by_ref() { if line?.contains("world!") { + f2 = true; break } } + assert!(f1 && f2); + Ok(()) }