test: split bad channel

This commit is contained in:
quininer 2019-10-11 01:24:27 +08:00
parent 9a161beb87
commit 10c139df08
3 changed files with 39 additions and 11 deletions

View File

@ -32,6 +32,31 @@ impl<'a> AsyncWrite for Good<'a> {
Poll::Ready(Ok(len))
}
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Poll::Ready(Ok(()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.send_close_notify();
Poll::Ready(Ok(()))
}
}
struct Pending;
impl AsyncRead for Pending {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
Poll::Pending
}
}
impl AsyncWrite for Pending {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll<io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
@ -41,22 +66,18 @@ impl<'a> AsyncWrite for Good<'a> {
}
}
struct Bad(bool);
struct Eof;
impl AsyncRead for Bad {
impl AsyncRead for Eof {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(0))
}
}
impl AsyncWrite for Bad {
impl AsyncWrite for Eof {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
if self.0 {
Poll::Pending
} else {
Poll::Ready(Ok(buf.len()))
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
@ -99,7 +120,7 @@ async fn stream_bad() -> io::Result<()> {
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
client.set_buffer_limit(1024);
let mut bad = Bad(true);
let mut bad = Pending;
let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
@ -138,7 +159,7 @@ async fn stream_handshake() -> io::Result<()> {
async fn stream_handshake_eof() -> io::Result<()> {
let (_, mut client) = make_pair();
let mut bad = Bad(false);
let mut bad = Eof;
let mut stream = Stream::new(&mut bad, &mut client);
let mut cx = Context::from_waker(noop_waker_ref());

View File

@ -146,6 +146,7 @@ impl TlsConnector {
}
impl TlsAcceptor {
#[inline]
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
@ -153,7 +154,6 @@ impl TlsAcceptor {
self.accept_with(stream, |_| ())
}
#[inline]
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,

View File

@ -96,17 +96,24 @@ async fn test_0rtt() -> io::Result<()> {
let stdout = handle.0.stdout.as_mut().unwrap();
let mut lines = BufReader::new(stdout).lines();
let mut f1 = false;
let mut f2 = false;
for line in lines.by_ref() {
if line?.contains("hello") {
f1 = true;
break
}
}
for line in lines.by_ref() {
if line?.contains("world!") {
f2 = true;
break
}
}
assert!(f1 && f2);
Ok(())
}