make simple test work

This commit is contained in:
quininer 2019-05-20 00:28:27 +08:00
parent f7472e89a2
commit b03c327ab6
6 changed files with 58 additions and 184 deletions

View File

@ -44,7 +44,7 @@ where
type Output = io::Result<TlsStream<IO>>; type Output = io::Result<TlsStream<IO>>;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
if let MidHandshake::Handshaking(stream) = this { if let MidHandshake::Handshaking(stream) = this {
@ -79,7 +79,7 @@ where
Initializer::nop() Initializer::nop()
} }
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
match self.state { match self.state {
#[cfg(feature = "early-data")] #[cfg(feature = "early-data")]
TlsState::EarlyData => { TlsState::EarlyData => {
@ -140,7 +140,7 @@ impl<IO> AsyncWrite for TlsStream<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
{ {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
@ -181,14 +181,14 @@ where
} }
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session) Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()) .set_eof(!this.state.readable())
.poll_flush(cx) .poll_flush(cx)
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.state.writeable() { if self.state.writeable() {
self.session.send_close_notify(); self.session.send_close_notify();
self.state.shutdown_write(); self.state.shutdown_write();

View File

@ -1,5 +1,3 @@
// mod vecbuf;
use std::pin::Pin; use std::pin::Pin;
use std::task::Poll; use std::task::Poll;
use std::marker::Unpin; use std::marker::Unpin;

View File

@ -1,122 +0,0 @@
use std::cmp::{ self, Ordering };
use bytes::Buf;
use iovec::IoVec;
pub struct VecBuf<'a, 'b: 'a> {
pos: usize,
cur: usize,
inner: &'a [&'b [u8]]
}
impl<'a, 'b> VecBuf<'a, 'b> {
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
VecBuf { pos: 0, cur: 0, inner: vbytes }
}
}
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
fn remaining(&self) -> usize {
let sum = self.inner
.iter()
.skip(self.pos)
.map(|bytes| bytes.len())
.sum::<usize>();
sum - self.cur
}
fn bytes(&self) -> &[u8] {
&self.inner[self.pos][self.cur..]
}
fn advance(&mut self, cnt: usize) {
let current = self.inner[self.pos].len();
match (self.cur + cnt).cmp(&current) {
Ordering::Equal => if self.pos + 1 < self.inner.len() {
self.pos += 1;
self.cur = 0;
} else {
self.cur += cnt;
},
Ordering::Greater => {
if self.pos + 1 < self.inner.len() {
self.pos += 1;
}
let remaining = self.cur + cnt - current;
self.advance(remaining);
},
Ordering::Less => self.cur += cnt,
}
}
#[allow(clippy::needless_range_loop)]
fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize {
let len = cmp::min(self.inner.len() - self.pos, dst.len());
if len > 0 {
dst[0] = self.bytes().into();
}
for i in 1..len {
dst[i] = self.inner[self.pos + i].into();
}
len
}
}
#[cfg(test)]
mod test_vecbuf {
use super::*;
#[test]
fn test_fresh_cursor_vec() {
let mut buf = VecBuf::new(&[b"he", b"llo"]);
assert_eq!(buf.remaining(), 5);
assert_eq!(buf.bytes(), b"he");
buf.advance(2);
assert_eq!(buf.remaining(), 3);
assert_eq!(buf.bytes(), b"llo");
buf.advance(3);
assert_eq!(buf.remaining(), 0);
assert_eq!(buf.bytes(), b"");
}
#[test]
fn test_get_u8() {
let mut buf = VecBuf::new(&[b"\x21z", b"omg"]);
assert_eq!(0x21, buf.get_u8());
}
#[test]
fn test_get_u16() {
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
assert_eq!(0x2154, buf.get_u16_be());
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
assert_eq!(0x5421, buf.get_u16_le());
}
#[test]
#[should_panic]
fn test_get_u16_buffer_underflow() {
let mut buf = VecBuf::new(&[b"\x21"]);
buf.get_u16_be();
}
#[test]
fn test_bufs_vec() {
let buf = VecBuf::new(&[b"he", b"llo"]);
let b1: &[u8] = &mut [0];
let b2: &[u8] = &mut [0];
let mut dst: [&IoVec; 2] =
[b1.into(), b2.into()];
assert_eq!(2, buf.bytes_vec(&mut dst[..]));
}
}

View File

@ -109,7 +109,7 @@ impl TlsConnector {
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
self.connect_with(domain, stream, |_| ()) self.connect_with(domain, stream, |_| ())
} }
@ -117,7 +117,7 @@ impl TlsConnector {
#[inline] #[inline]
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO> pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ClientSession), F: FnOnce(&mut ClientSession),
{ {
let mut session = ClientSession::new(&self.inner, domain); let mut session = ClientSession::new(&self.inner, domain);
@ -156,7 +156,7 @@ impl TlsConnector {
impl TlsAcceptor { impl TlsAcceptor {
pub fn accept<IO>(&self, stream: IO) -> Accept<IO> pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
{ {
self.accept_with(stream, |_| ()) self.accept_with(stream, |_| ())
} }
@ -164,7 +164,7 @@ impl TlsAcceptor {
#[inline] #[inline]
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
where where
IO: AsyncRead + AsyncWrite, IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ServerSession), F: FnOnce(&mut ServerSession),
{ {
let mut session = ServerSession::new(&self.inner); let mut session = ServerSession::new(&self.inner);
@ -189,7 +189,8 @@ pub struct Accept<IO>(server::MidHandshake<IO>);
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
type Output = io::Result<client::TlsStream<IO>>; type Output = io::Result<client::TlsStream<IO>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx) Pin::new(&mut self.0).poll(cx)
} }
} }
@ -197,7 +198,8 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
type Output = io::Result<server::TlsStream<IO>>; type Output = io::Result<server::TlsStream<IO>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx) Pin::new(&mut self.0).poll(cx)
} }
} }

View File

@ -39,7 +39,7 @@ where
type Output = io::Result<TlsStream<IO>>; type Output = io::Result<TlsStream<IO>>;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
if let MidHandshake::Handshaking(stream) = this { if let MidHandshake::Handshaking(stream) = this {
@ -72,7 +72,7 @@ where
Initializer::nop() Initializer::nop()
} }
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = Stream::new(&mut this.io, &mut this.session) let mut stream = Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()); .set_eof(!this.state.readable());
@ -106,21 +106,21 @@ impl<IO> AsyncWrite for TlsStream<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
{ {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session) Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()) .set_eof(!this.state.readable())
.poll_write(cx, buf) .poll_write(cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
Stream::new(&mut this.io, &mut this.session) Stream::new(&mut this.io, &mut this.session)
.set_eof(!this.state.readable()) .set_eof(!this.state.readable())
.poll_flush(cx) .poll_flush(cx)
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.state.writeable() { if self.state.writeable() {
self.session.send_close_notify(); self.session.send_close_notify();
self.state.shutdown_write(); self.state.shutdown_write();

View File

@ -1,17 +1,15 @@
#![cfg(not(test))] #![feature(async_await)]
#[macro_use] extern crate lazy_static;
extern crate rustls;
extern crate tokio;
extern crate tokio_rustls;
extern crate webpki;
use std::{ io, thread }; use std::{ io, thread };
use std::io::{ BufReader, Cursor }; use std::io::{ BufReader, Cursor };
use std::sync::Arc; use std::sync::Arc;
use std::sync::mpsc::channel; use std::sync::mpsc::channel;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::net::{ TcpListener, TcpStream }; use lazy_static::lazy_static;
use futures::prelude::*;
use futures::executor;
use futures::task::SpawnExt;
use romio::tcp::{ TcpListener, TcpStream };
use rustls::{ ServerConfig, ClientConfig }; use rustls::{ ServerConfig, ClientConfig };
use rustls::internal::pemfile::{ certs, rsa_private_keys }; use rustls::internal::pemfile::{ certs, rsa_private_keys };
use tokio_rustls::{ TlsConnector, TlsAcceptor }; use tokio_rustls::{ TlsConnector, TlsAcceptor };
@ -22,9 +20,6 @@ const RSA: &str = include_str!("end.rsa");
lazy_static!{ lazy_static!{
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
use tokio::prelude::*;
use tokio::io as aio;
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
@ -36,26 +31,32 @@ lazy_static!{
let (send, recv) = channel(); let (send, recv) = channel();
thread::spawn(move || { thread::spawn(move || {
let done = async {
let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(&addr).unwrap(); let mut pool = executor::ThreadPool::new()?;
let mut listener = TcpListener::bind(&addr)?;
send.send(listener.local_addr().unwrap()).unwrap(); send.send(listener.local_addr()?).unwrap();
let done = listener.incoming() let mut incoming = listener.incoming();
.for_each(move |stream| { while let Some(stream) = incoming.next().await {
let done = config.accept(stream) let config = config.clone();
.and_then(|stream| { pool.spawn(
let (reader, writer) = stream.split(); async move {
aio::copy(reader, writer) let stream = stream?;
}) let stream = config.accept(stream).await?;
.then(|_| Ok(())); let (mut reader, mut write) = stream.split();
reader.copy_into(&mut write).await?;
Ok(()) as io::Result<()>
}
.unwrap_or_else(|err| eprintln!("{:?}", err))
).unwrap();
}
tokio::spawn(done); Ok(()) as io::Result<()>
Ok(()) };
})
.map_err(|err| panic!("{:?}", err));
tokio::run(done); executor::block_on(done).unwrap();
}); });
let addr = recv.recv().unwrap(); let addr = recv.recv().unwrap();
@ -63,31 +64,26 @@ lazy_static!{
}; };
} }
fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
&*TEST_SERVER &*TEST_SERVER
} }
fn start_client(addr: &SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> { async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
use tokio::prelude::*;
use tokio::io as aio;
const FILE: &'static [u8] = include_bytes!("../README.md"); const FILE: &'static [u8] = include_bytes!("../README.md");
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
let config = TlsConnector::from(config); let config = TlsConnector::from(config);
let mut buf = vec![0; FILE.len()];
let stream = TcpStream::connect(&addr).await?;
let mut stream = config.connect(domain, stream).await?;
stream.write_all(FILE).await?;
stream.read_exact(&mut buf).await?;
let done = TcpStream::connect(addr)
.and_then(|stream| config.connect(domain, stream))
.and_then(|stream| aio::write_all(stream, FILE))
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()]))
.and_then(|(stream, buf)| {
assert_eq!(buf, FILE); assert_eq!(buf, FILE);
aio::shutdown(stream)
})
.map(drop);
done.wait() stream.close().await?;
Ok(())
} }
#[test] #[test]
@ -99,7 +95,7 @@ fn pass() {
config.root_store.add_pem_file(&mut chain).unwrap(); config.root_store.add_pem_file(&mut chain).unwrap();
let config = Arc::new(config); let config = Arc::new(config);
start_client(addr, domain, config.clone()).unwrap(); executor::block_on(start_client(addr.clone(), domain, config.clone())).unwrap();
} }
#[test] #[test]
@ -112,5 +108,5 @@ fn fail() {
let config = Arc::new(config); let config = Arc::new(config);
assert_ne!(domain, &"google.com"); assert_ne!(domain, &"google.com");
assert!(start_client(addr, "google.com", config).is_err()); assert!(executor::block_on(start_client(addr.clone(), "google.com", config)).is_err());
} }