make all test work!

This commit is contained in:
quininer 2019-05-21 01:47:50 +08:00
parent b03c327ab6
commit 7949f4377a
4 changed files with 184 additions and 123 deletions

View File

@ -97,7 +97,7 @@ where
// write early data (fallback) // write early data (fallback)
if !stream.session.is_early_data_accepted() { if !stream.session.is_early_data_accepted() {
while *pos < data.len() { while *pos < data.len() {
let len = try_ready!(stream.poll_write(cx, &data[*pos..])); let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..]));
*pos += len; *pos += len;
} }
} }
@ -113,7 +113,7 @@ where
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
match stream.poll_read(cx, buf) { match stream.pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
this.state.shutdown_read(); this.state.shutdown_read();
Poll::Ready(Ok(0)) Poll::Ready(Ok(0))
@ -167,7 +167,7 @@ where
// write early data (fallback) // write early data (fallback)
if !stream.session.is_early_data_accepted() { if !stream.session.is_early_data_accepted() {
while *pos < data.len() { while *pos < data.len() {
let len = try_ready!(stream.poll_write(cx, &data[*pos..])); let len = try_ready!(stream.pin().poll_write(cx, &data[*pos..]));
*pos += len; *pos += len;
} }
} }
@ -175,17 +175,17 @@ where
// end // end
this.state = TlsState::Stream; this.state = TlsState::Stream;
data.clear(); data.clear();
stream.poll_write(cx, buf) stream.pin().poll_write(cx, buf)
} }
_ => stream.poll_write(cx, buf), _ => stream.pin().poll_write(cx, buf),
} }
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()) .set_eof(!this.state.readable());
.poll_flush(cx) stream.pin().poll_flush(cx)
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@ -197,7 +197,6 @@ where
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
try_ready!(stream.poll_flush(cx)); stream.pin().poll_close(cx)
Pin::new(&mut this.io).poll_close(cx)
} }
} }

View File

@ -2,8 +2,7 @@ use std::pin::Pin;
use std::task::Poll; use std::task::Poll;
use std::marker::Unpin; use std::marker::Unpin;
use std::io::{ self, Read }; use std::io::{ self, Read };
use rustls::Session; use rustls::{ Session, WriteV };
use rustls::WriteV;
use futures::task::Context; use futures::task::Context;
use futures::io::{ AsyncRead, AsyncWrite, IoSlice }; use futures::io::{ AsyncRead, AsyncWrite, IoSlice };
use smallvec::SmallVec; use smallvec::SmallVec;
@ -42,6 +41,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
self self
} }
pub fn pin(&mut self) -> Pin<&mut Self> {
Pin::new(self)
}
pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> { pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
self.complete_inner_io(cx, Focus::Empty) self.complete_inner_io(cx, Focus::Empty)
} }
@ -124,7 +127,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}; };
match (self.eof, self.session.is_handshaking(), would_block) { match (self.eof, self.session.is_handshaking(), would_block) {
(true, true, _) => return Poll::Pending, (true, true, _) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
(_, false, true) => { (_, false, true) => {
let would_block = match focus { let would_block = match focus {
Focus::Empty => rdlen == 0 && wrlen == 0, Focus::Empty => rdlen == 0 && wrlen == 0,
@ -172,10 +175,12 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Str
} }
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
pub fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
while self.session.wants_read() { let this = self.get_mut();
match self.complete_inner_io(cx, Focus::Readable) {
while this.session.wants_read() {
match this.complete_inner_io(cx, Focus::Readable) {
Poll::Ready(Ok((0, _))) => break, Poll::Ready(Ok((0, _))) => break,
Poll::Ready(Ok(_)) => (), Poll::Ready(Ok(_)) => (),
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
@ -184,13 +189,17 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
} }
// FIXME rustls always ready ? // FIXME rustls always ready ?
Poll::Ready(self.session.read(buf)) Poll::Ready(this.session.read(buf))
} }
}
pub fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
let len = self.session.write(buf)?; fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
while self.session.wants_write() { let this = self.get_mut();
match self.complete_inner_io(cx, Focus::Writable) {
let len = this.session.write(buf)?;
while this.session.wants_write() {
match this.complete_inner_io(cx, Focus::Writable) {
Poll::Ready(Ok(_)) => (), Poll::Ready(Ok(_)) => (),
Poll::Pending if len != 0 => break, Poll::Pending if len != 0 => break,
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
@ -202,7 +211,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
Poll::Ready(Ok(len)) Poll::Ready(Ok(len))
} else { } else {
// not write zero // not write zero
match self.session.write(buf) { match this.session.write(buf) {
Ok(0) => Poll::Pending, Ok(0) => Poll::Pending,
Ok(n) => Poll::Ready(Ok(n)), Ok(n) => Poll::Ready(Ok(n)),
Err(err) => Poll::Ready(Err(err)) Err(err) => Poll::Ready(Err(err))
@ -210,18 +219,33 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
} }
} }
pub fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.flush()?; let this = self.get_mut();
while self.session.wants_write() {
match self.complete_inner_io(cx, Focus::Writable) { this.session.flush()?;
while this.session.wants_write() {
match this.complete_inner_io(cx, Focus::Writable) {
Poll::Ready(Ok(_)) => (), Poll::Ready(Ok(_)) => (),
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
} }
} }
Pin::new(&mut self.io).poll_flush(cx) Pin::new(&mut this.io).poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
while this.session.wants_write() {
match this.complete_inner_io(cx, Focus::Writable) {
Poll::Ready(Ok(_)) => (),
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err))
}
}
Pin::new(&mut this.io).poll_close(cx)
} }
} }
// #[cfg(test)] #[cfg(test)]
// mod test_stream; mod test_stream;

View File

@ -1,4 +1,10 @@
use std::pin::Pin;
use std::task::Poll;
use std::sync::Arc; use std::sync::Arc;
use futures::prelude::*;
use futures::task::{ Context, noop_waker_ref };
use futures::executor;
use futures::io::{ AsyncRead, AsyncWrite };
use std::io::{ self, Read, Write, BufReader, Cursor }; use std::io::{ self, Read, Write, BufReader, Cursor };
use webpki::DNSNameRef; use webpki::DNSNameRef;
use rustls::internal::pemfile::{ certs, rsa_private_keys }; use rustls::internal::pemfile::{ certs, rsa_private_keys };
@ -7,146 +13,172 @@ use rustls::{
ServerSession, ClientSession, ServerSession, ClientSession,
Session, NoClientAuth Session, NoClientAuth
}; };
use futures::{ Async, Poll };
use tokio_io::{ AsyncRead, AsyncWrite };
use super::Stream; use super::Stream;
struct Good<'a>(&'a mut Session); struct Good<'a>(&'a mut Session);
impl<'a> Read for Good<'a> { impl<'a> AsyncRead for Good<'a> {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> { fn poll_read(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.0.write_tls(buf.by_ref()) Poll::Ready(self.0.write_tls(buf.by_ref()))
} }
} }
impl<'a> Write for Good<'a> { impl<'a> AsyncWrite for Good<'a> {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> { fn poll_write(mut self: Pin<&mut Self>, _cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<io::Result<usize>> {
let len = self.0.read_tls(buf.by_ref())?; let len = self.0.read_tls(buf.by_ref())?;
self.0.process_new_packets() self.0.process_new_packets()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
Ok(len) Poll::Ready(Ok(len))
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()) Poll::Ready(Ok(()))
} }
}
impl<'a> AsyncRead for Good<'a> {} fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
impl<'a> AsyncWrite for Good<'a> { Poll::Ready(Ok(()))
fn shutdown(&mut self) -> Poll<(), io::Error> {
Ok(Async::Ready(()))
} }
} }
struct Bad(bool); struct Bad(bool);
impl Read for Bad { impl AsyncRead for Bad {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> { fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
Ok(0) Poll::Ready(Ok(0))
} }
} }
impl Write for Bad { impl AsyncWrite for Bad {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
if self.0 { if self.0 {
Err(io::ErrorKind::WouldBlock.into()) Poll::Pending
} else { } else {
Ok(buf.len()) Poll::Ready(Ok(buf.len()))
} }
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()) Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
} }
} }
impl AsyncRead for Bad {}
impl AsyncWrite for Bad {
fn shutdown(&mut self) -> Poll<(), io::Error> {
Ok(Async::Ready(()))
}
}
#[test] #[test]
fn stream_good() -> io::Result<()> { fn stream_good() -> io::Result<()> {
const FILE: &'static [u8] = include_bytes!("../../README.md"); const FILE: &'static [u8] = include_bytes!("../../README.md");
let (mut server, mut client) = make_pair(); let fut = async {
do_handshake(&mut client, &mut server); let (mut server, mut client) = make_pair();
io::copy(&mut Cursor::new(FILE), &mut server)?; future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
io::copy(&mut Cursor::new(FILE), &mut server)?;
{ {
let mut good = Good(&mut server); let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client); let mut stream = Stream::new(&mut good, &mut client);
let mut buf = Vec::new(); let mut buf = Vec::new();
stream.read_to_end(&mut buf)?; stream.read_to_end(&mut buf).await?;
assert_eq!(buf, FILE); assert_eq!(buf, FILE);
stream.write_all(b"Hello World!")?; stream.write_all(b"Hello World!").await?;
} }
let mut buf = String::new(); let mut buf = String::new();
server.read_to_string(&mut buf)?; server.read_to_string(&mut buf)?;
assert_eq!(buf, "Hello World!"); assert_eq!(buf, "Hello World!");
Ok(()) Ok(()) as io::Result<()>
};
executor::block_on(fut)
} }
#[test] #[test]
fn stream_bad() -> io::Result<()> { fn stream_bad() -> io::Result<()> {
let (mut server, mut client) = make_pair(); let fut = async {
do_handshake(&mut client, &mut server); let (mut server, mut client) = make_pair();
client.set_buffer_limit(1024); future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
client.set_buffer_limit(1024);
let mut bad = Bad(true); let mut bad = Bad(true);
let mut stream = Stream::new(&mut bad, &mut client); let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(stream.write(&[0x42; 8])?, 8); assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8);
assert_eq!(stream.write(&[0x42; 8])?, 8); assert_eq!(future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x42; 8])).await?, 8);
let r = stream.write(&[0x00; 1024])?; // fill buffer let r = future::poll_fn(|cx| stream.pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer
assert!(r < 1024); assert!(r < 1024);
assert_eq!(
stream.write(&[0x01]).unwrap_err().kind(),
io::ErrorKind::WouldBlock
);
Ok(()) let mut cx = Context::from_waker(noop_waker_ref());
assert!(stream.pin().poll_write(&mut cx, &[0x01]).is_pending());
Ok(()) as io::Result<()>
};
executor::block_on(fut)
} }
#[test] #[test]
fn stream_handshake() -> io::Result<()> { fn stream_handshake() -> io::Result<()> {
let (mut server, mut client) = make_pair(); let fut = async {
let (mut server, mut client) = make_pair();
{ {
let mut good = Good(&mut server); let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client); let mut stream = Stream::new(&mut good, &mut client);
let (r, w) = stream.complete_io()?; let (r, w) = future::poll_fn(|cx| stream.complete_io(cx)).await?;
assert!(r > 0); assert!(r > 0);
assert!(w > 0); assert!(w > 0);
stream.complete_io()?; // finish server handshake future::poll_fn(|cx| stream.complete_io(cx)).await?; // finish server handshake
} }
assert!(!server.is_handshaking()); assert!(!server.is_handshaking());
assert!(!client.is_handshaking()); assert!(!client.is_handshaking());
Ok(()) Ok(()) as io::Result<()>
};
executor::block_on(fut)
} }
#[test] #[test]
fn stream_handshake_eof() -> io::Result<()> { fn stream_handshake_eof() -> io::Result<()> {
let (_, mut client) = make_pair(); let fut = async {
let (_, mut client) = make_pair();
let mut bad = Bad(false); let mut bad = Bad(false);
let mut stream = Stream::new(&mut bad, &mut client); let mut stream = Stream::new(&mut bad, &mut client);
let r = stream.complete_io();
assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.complete_io(&mut cx);
assert_eq!(r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof)));
Ok(()) Ok(()) as io::Result<()>
};
executor::block_on(fut)
}
#[test]
fn stream_eof() -> io::Result<()> {
let fut = async {
let (mut server, mut client) = make_pair();
future::poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client).set_eof(true);
let mut buf = Vec::new();
stream.read_to_end(&mut buf).await?;
assert_eq!(buf.len(), 0);
Ok(()) as io::Result<()>
};
executor::block_on(fut)
} }
fn make_pair() -> (ServerSession, ClientSession) { fn make_pair() -> (ServerSession, ClientSession) {
@ -169,9 +201,17 @@ fn make_pair() -> (ServerSession, ClientSession) {
(server, client) (server, client)
} }
fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { fn do_handshake(client: &mut ClientSession, server: &mut ServerSession, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut good = Good(server); let mut good = Good(server);
let mut stream = Stream::new(&mut good, client); let mut stream = Stream::new(&mut good, client);
stream.complete_io().unwrap();
stream.complete_io().unwrap(); if stream.session.is_handshaking() {
try_ready!(stream.complete_io(cx));
}
if stream.session.wants_write() {
try_ready!(stream.complete_io(cx));
}
Poll::Ready(Ok(()))
} }

View File

@ -78,7 +78,7 @@ where
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
match this.state { match this.state {
TlsState::Stream | TlsState::WriteShutdown => match stream.poll_read(cx, buf) { TlsState::Stream | TlsState::WriteShutdown => match stream.pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
this.state.shutdown_read(); this.state.shutdown_read();
Poll::Ready(Ok(0)) Poll::Ready(Ok(0))
@ -108,16 +108,16 @@ where
{ {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()) .set_eof(!this.state.readable());
.poll_write(cx, buf) stream.pin().poll_write(cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()) .set_eof(!this.state.readable());
.poll_flush(cx) stream.pin().poll_flush(cx)
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@ -127,9 +127,7 @@ where
} }
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
.set_eof(!this.state.readable()); stream.pin().poll_close(cx)
try_ready!(stream.complete_io(cx));
Pin::new(&mut this.io).poll_close(cx)
} }
} }