From b03c327ab6ae86ca89da602e7c4e716a2cb54e7c Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 20 May 2019 00:28:27 +0800 Subject: [PATCH] make simple test work --- src/client.rs | 10 ++-- src/common/mod.rs | 2 - src/common/vecbuf.rs | 122 ------------------------------------------- src/lib.rs | 14 ++--- src/server.rs | 10 ++-- tests/test.rs | 84 ++++++++++++++--------------- 6 files changed, 58 insertions(+), 184 deletions(-) delete mode 100644 src/common/vecbuf.rs diff --git a/src/client.rs b/src/client.rs index a2ebdd2..9e89468 100644 --- a/src/client.rs +++ b/src/client.rs @@ -44,7 +44,7 @@ where type Output = io::Result>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); if let MidHandshake::Handshaking(stream) = this { @@ -79,7 +79,7 @@ where Initializer::nop() } - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { match self.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { @@ -140,7 +140,7 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); @@ -181,14 +181,14 @@ where } } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); diff --git a/src/common/mod.rs b/src/common/mod.rs index d20d5a8..eacf585 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,5 +1,3 @@ -// mod vecbuf; - use std::pin::Pin; use std::task::Poll; use std::marker::Unpin; diff --git a/src/common/vecbuf.rs b/src/common/vecbuf.rs deleted file mode 100644 index e550505..0000000 --- a/src/common/vecbuf.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::cmp::{ self, Ordering }; -use bytes::Buf; -use iovec::IoVec; - -pub struct VecBuf<'a, 'b: 'a> { - pos: usize, - cur: usize, - inner: &'a [&'b [u8]] -} - -impl<'a, 'b> VecBuf<'a, 'b> { - pub fn new(vbytes: &'a [&'b [u8]]) -> Self { - VecBuf { pos: 0, cur: 0, inner: vbytes } - } -} - -impl<'a, 'b> Buf for VecBuf<'a, 'b> { - fn remaining(&self) -> usize { - let sum = self.inner - .iter() - .skip(self.pos) - .map(|bytes| bytes.len()) - .sum::(); - sum - self.cur - } - - fn bytes(&self) -> &[u8] { - &self.inner[self.pos][self.cur..] - } - - fn advance(&mut self, cnt: usize) { - let current = self.inner[self.pos].len(); - match (self.cur + cnt).cmp(¤t) { - Ordering::Equal => if self.pos + 1 < self.inner.len() { - self.pos += 1; - self.cur = 0; - } else { - self.cur += cnt; - }, - Ordering::Greater => { - if self.pos + 1 < self.inner.len() { - self.pos += 1; - } - let remaining = self.cur + cnt - current; - self.advance(remaining); - }, - Ordering::Less => self.cur += cnt, - } - } - - #[allow(clippy::needless_range_loop)] - fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize { - let len = cmp::min(self.inner.len() - self.pos, dst.len()); - - if len > 0 { - dst[0] = self.bytes().into(); - } - - for i in 1..len { - dst[i] = self.inner[self.pos + i].into(); - } - - len - } -} - -#[cfg(test)] -mod test_vecbuf { - use super::*; - - #[test] - fn test_fresh_cursor_vec() { - let mut buf = VecBuf::new(&[b"he", b"llo"]); - - assert_eq!(buf.remaining(), 5); - assert_eq!(buf.bytes(), b"he"); - - buf.advance(2); - - assert_eq!(buf.remaining(), 3); - assert_eq!(buf.bytes(), b"llo"); - - buf.advance(3); - - assert_eq!(buf.remaining(), 0); - assert_eq!(buf.bytes(), b""); - } - - #[test] - fn test_get_u8() { - let mut buf = VecBuf::new(&[b"\x21z", b"omg"]); - assert_eq!(0x21, buf.get_u8()); - } - - #[test] - fn test_get_u16() { - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x2154, buf.get_u16_be()); - let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]); - assert_eq!(0x5421, buf.get_u16_le()); - } - - #[test] - #[should_panic] - fn test_get_u16_buffer_underflow() { - let mut buf = VecBuf::new(&[b"\x21"]); - buf.get_u16_be(); - } - - #[test] - fn test_bufs_vec() { - let buf = VecBuf::new(&[b"he", b"llo"]); - - let b1: &[u8] = &mut [0]; - let b2: &[u8] = &mut [0]; - - let mut dst: [&IoVec; 2] = - [b1.into(), b2.into()]; - - assert_eq!(2, buf.bytes_vec(&mut dst[..])); - } -} diff --git a/src/lib.rs b/src/lib.rs index d849f33..df3c259 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,7 +109,7 @@ impl TlsConnector { pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { self.connect_with(domain, stream, |_| ()) } @@ -117,7 +117,7 @@ impl TlsConnector { #[inline] pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientSession), { let mut session = ClientSession::new(&self.inner, domain); @@ -156,7 +156,7 @@ impl TlsConnector { impl TlsAcceptor { pub fn accept(&self, stream: IO) -> Accept where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, { self.accept_with(stream, |_| ()) } @@ -164,7 +164,7 @@ impl TlsAcceptor { #[inline] pub fn accept_with(&self, stream: IO, f: F) -> Accept where - IO: AsyncRead + AsyncWrite, + IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ServerSession), { let mut session = ServerSession::new(&self.inner); @@ -189,7 +189,8 @@ pub struct Accept(server::MidHandshake); impl Future for Connect { type Output = io::Result>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } @@ -197,7 +198,8 @@ impl Future for Connect { impl Future for Accept { type Output = io::Result>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::new(&mut self.0).poll(cx) } } diff --git a/src/server.rs b/src/server.rs index 9db4867..ba054a9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -39,7 +39,7 @@ where type Output = io::Result>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); if let MidHandshake::Handshaking(stream) = this { @@ -72,7 +72,7 @@ where Initializer::nop() } - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()); @@ -106,21 +106,21 @@ impl AsyncWrite for TlsStream where IO: AsyncRead + AsyncWrite + Unpin, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .poll_write(cx, buf) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); Stream::new(&mut this.io, &mut this.session) .set_eof(!this.state.readable()) .poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.writeable() { self.session.send_close_notify(); self.state.shutdown_write(); diff --git a/tests/test.rs b/tests/test.rs index 533e4e4..a7fd2f2 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,17 +1,15 @@ -#![cfg(not(test))] - -#[macro_use] extern crate lazy_static; -extern crate rustls; -extern crate tokio; -extern crate tokio_rustls; -extern crate webpki; +#![feature(async_await)] use std::{ io, thread }; use std::io::{ BufReader, Cursor }; use std::sync::Arc; use std::sync::mpsc::channel; use std::net::SocketAddr; -use tokio::net::{ TcpListener, TcpStream }; +use lazy_static::lazy_static; +use futures::prelude::*; +use futures::executor; +use futures::task::SpawnExt; +use romio::tcp::{ TcpListener, TcpStream }; use rustls::{ ServerConfig, ClientConfig }; use rustls::internal::pemfile::{ certs, rsa_private_keys }; use tokio_rustls::{ TlsConnector, TlsAcceptor }; @@ -22,9 +20,6 @@ const RSA: &str = include_str!("end.rsa"); lazy_static!{ static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { - use tokio::prelude::*; - use tokio::io as aio; - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); @@ -36,26 +31,32 @@ lazy_static!{ let (send, recv) = channel(); thread::spawn(move || { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(&addr).unwrap(); + let done = async { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let mut pool = executor::ThreadPool::new()?; + let mut listener = TcpListener::bind(&addr)?; - send.send(listener.local_addr().unwrap()).unwrap(); + send.send(listener.local_addr()?).unwrap(); - let done = listener.incoming() - .for_each(move |stream| { - let done = config.accept(stream) - .and_then(|stream| { - let (reader, writer) = stream.split(); - aio::copy(reader, writer) - }) - .then(|_| Ok(())); + let mut incoming = listener.incoming(); + while let Some(stream) = incoming.next().await { + let config = config.clone(); + pool.spawn( + async move { + let stream = stream?; + let stream = config.accept(stream).await?; + let (mut reader, mut write) = stream.split(); + reader.copy_into(&mut write).await?; + Ok(()) as io::Result<()> + } + .unwrap_or_else(|err| eprintln!("{:?}", err)) + ).unwrap(); + } - tokio::spawn(done); - Ok(()) - }) - .map_err(|err| panic!("{:?}", err)); + Ok(()) as io::Result<()> + }; - tokio::run(done); + executor::block_on(done).unwrap(); }); let addr = recv.recv().unwrap(); @@ -63,31 +64,26 @@ lazy_static!{ }; } - fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { &*TEST_SERVER } -fn start_client(addr: &SocketAddr, domain: &str, config: Arc) -> io::Result<()> { - use tokio::prelude::*; - use tokio::io as aio; - +async fn start_client(addr: SocketAddr, domain: &str, config: Arc) -> io::Result<()> { const FILE: &'static [u8] = include_bytes!("../README.md"); let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let config = TlsConnector::from(config); + let mut buf = vec![0; FILE.len()]; - let done = TcpStream::connect(addr) - .and_then(|stream| config.connect(domain, stream)) - .and_then(|stream| aio::write_all(stream, FILE)) - .and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()])) - .and_then(|(stream, buf)| { - assert_eq!(buf, FILE); - aio::shutdown(stream) - }) - .map(drop); + let stream = TcpStream::connect(&addr).await?; + let mut stream = config.connect(domain, stream).await?; + stream.write_all(FILE).await?; + stream.read_exact(&mut buf).await?; - done.wait() + assert_eq!(buf, FILE); + + stream.close().await?; + Ok(()) } #[test] @@ -99,7 +95,7 @@ fn pass() { config.root_store.add_pem_file(&mut chain).unwrap(); let config = Arc::new(config); - start_client(addr, domain, config.clone()).unwrap(); + executor::block_on(start_client(addr.clone(), domain, config.clone())).unwrap(); } #[test] @@ -112,5 +108,5 @@ fn fail() { let config = Arc::new(config); assert_ne!(domain, &"google.com"); - assert!(start_client(addr, "google.com", config).is_err()); + assert!(executor::block_on(start_client(addr.clone(), "google.com", config)).is_err()); }