test: split bad channel

This commit is contained in:
quininer 2019-10-11 01:24:27 +08:00
parent 9a161beb87
commit 10c139df08
3 changed files with 39 additions and 11 deletions

View File

@ -32,6 +32,31 @@ impl<'a> AsyncWrite for Good<'a> {
Poll::Ready(Ok(len)) Poll::Ready(Ok(len))
} }
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Poll::Ready(Ok(()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.send_close_notify();
Poll::Ready(Ok(()))
}
}
struct Pending;
impl AsyncRead for Pending {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
Poll::Pending
}
}
impl AsyncWrite for Pending {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &[u8]) -> Poll<io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
@ -41,21 +66,17 @@ impl<'a> AsyncWrite for Good<'a> {
} }
} }
struct Bad(bool); struct Eof;
impl AsyncRead for Bad { impl AsyncRead for Eof {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(0)) Poll::Ready(Ok(0))
} }
} }
impl AsyncWrite for Bad { impl AsyncWrite for Eof {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
if self.0 { Poll::Ready(Ok(buf.len()))
Poll::Pending
} else {
Poll::Ready(Ok(buf.len()))
}
} }
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@ -99,7 +120,7 @@ async fn stream_bad() -> io::Result<()> {
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
client.set_buffer_limit(1024); client.set_buffer_limit(1024);
let mut bad = Bad(true); let mut bad = Pending;
let mut stream = Stream::new(&mut bad, &mut client); let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8); assert_eq!(poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, 8);
@ -138,7 +159,7 @@ async fn stream_handshake() -> io::Result<()> {
async fn stream_handshake_eof() -> io::Result<()> { async fn stream_handshake_eof() -> io::Result<()> {
let (_, mut client) = make_pair(); let (_, mut client) = make_pair();
let mut bad = Bad(false); let mut bad = Eof;
let mut stream = Stream::new(&mut bad, &mut client); let mut stream = Stream::new(&mut bad, &mut client);
let mut cx = Context::from_waker(noop_waker_ref()); let mut cx = Context::from_waker(noop_waker_ref());

View File

@ -146,6 +146,7 @@ impl TlsConnector {
} }
impl TlsAcceptor { impl TlsAcceptor {
#[inline]
pub fn accept<IO>(&self, stream: IO) -> Accept<IO> pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
@ -153,7 +154,6 @@ impl TlsAcceptor {
self.accept_with(stream, |_| ()) self.accept_with(stream, |_| ())
} }
#[inline]
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,

View File

@ -96,17 +96,24 @@ async fn test_0rtt() -> io::Result<()> {
let stdout = handle.0.stdout.as_mut().unwrap(); let stdout = handle.0.stdout.as_mut().unwrap();
let mut lines = BufReader::new(stdout).lines(); let mut lines = BufReader::new(stdout).lines();
let mut f1 = false;
let mut f2 = false;
for line in lines.by_ref() { for line in lines.by_ref() {
if line?.contains("hello") { if line?.contains("hello") {
f1 = true;
break break
} }
} }
for line in lines.by_ref() { for line in lines.by_ref() {
if line?.contains("world!") { if line?.contains("world!") {
f2 = true;
break break
} }
} }
assert!(f1 && f2);
Ok(()) Ok(())
} }