diff --git a/src/client.rs b/src/client.rs index 26607b3..9843f54 100644 --- a/src/client.rs +++ b/src/client.rs @@ -56,8 +56,8 @@ where futures::ready!(stream.handshake(cx))?; } - if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + while stream.session.wants_write() { + futures::ready!(stream.write_io(cx))?; } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 195d0da..c176ad5 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -33,7 +33,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Pin::new(self) } - fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { + pub fn process_new_packets(&mut self, cx: &mut Context) -> io::Result<()> { self.session.process_new_packets() .map_err(|err| { // In case we have an alert to send describing this error, @@ -45,7 +45,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { }) } - fn read_io(&mut self, cx: &mut Context) -> Poll> { + pub fn read_io(&mut self, cx: &mut Context) -> Poll> { struct Reader<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> @@ -71,7 +71,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { Poll::Ready(Ok(n)) } - fn write_io(&mut self, cx: &mut Context) -> Poll> { + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { struct Writer<'a, 'b, T> { io: &'a mut T, cx: &'a mut Context<'b> diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 8be20f4..84bcbe6 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -191,8 +191,8 @@ fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut ready!(stream.handshake(cx))?; } - if stream.session.wants_write() { - ready!(stream.handshake(cx))?; + while stream.session.wants_write() { + ready!(stream.write_io(cx))?; } Poll::Ready(Ok(())) diff --git a/src/server.rs b/src/server.rs index 4dad3f6..ac72904 100644 --- a/src/server.rs +++ b/src/server.rs @@ -51,8 +51,8 @@ where futures::ready!(stream.handshake(cx))?; } - if stream.session.wants_write() { - futures::ready!(stream.handshake(cx))?; + while stream.session.wants_write() { + futures::ready!(stream.write_io(cx))?; } } diff --git a/tests/early-data.rs b/tests/early-data.rs index ae0d614..9dd6b5e 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -40,11 +40,11 @@ async fn send(config: Arc, addr: SocketAddr, data: &[u8]) stream.write_all(data).await?; stream.flush().await?; - // sleep 3s + // sleep 1s // // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html - let sleep3 = delay_for(Duration::from_secs(3)); - let mut stream = match future::select(Read1(stream), sleep3).await { + let sleep1 = delay_for(Duration::from_secs(1)); + let mut stream = match future::select(Read1(stream), sleep1).await { future::Either::Right((_, Read1(stream))) => stream, future::Either::Left((Err(err), _)) => return Err(err), future::Either::Left((Ok(_), _)) => unreachable!(), @@ -77,7 +77,7 @@ async fn test_0rtt() -> io::Result<()> { .map(DropKill)?; // wait openssl server - delay_for(Duration::from_secs(3)).await; + delay_for(Duration::from_secs(1)).await; let mut config = ClientConfig::new(); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));