make simple test work
This commit is contained in:
parent
f7472e89a2
commit
b03c327ab6
@ -44,7 +44,7 @@ where
|
||||
type Output = io::Result<TlsStream<IO>>;
|
||||
|
||||
#[inline]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if let MidHandshake::Handshaking(stream) = this {
|
||||
@ -79,7 +79,7 @@ where
|
||||
Initializer::nop()
|
||||
}
|
||||
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||
match self.state {
|
||||
#[cfg(feature = "early-data")]
|
||||
TlsState::EarlyData => {
|
||||
@ -140,7 +140,7 @@ impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||
.set_eof(!this.state.readable());
|
||||
@ -181,14 +181,14 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
Stream::new(&mut this.io, &mut this.session)
|
||||
.set_eof(!this.state.readable())
|
||||
.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
if self.state.writeable() {
|
||||
self.session.send_close_notify();
|
||||
self.state.shutdown_write();
|
||||
|
@ -1,5 +1,3 @@
|
||||
// mod vecbuf;
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
use std::marker::Unpin;
|
||||
|
@ -1,122 +0,0 @@
|
||||
use std::cmp::{ self, Ordering };
|
||||
use bytes::Buf;
|
||||
use iovec::IoVec;
|
||||
|
||||
pub struct VecBuf<'a, 'b: 'a> {
|
||||
pos: usize,
|
||||
cur: usize,
|
||||
inner: &'a [&'b [u8]]
|
||||
}
|
||||
|
||||
impl<'a, 'b> VecBuf<'a, 'b> {
|
||||
pub fn new(vbytes: &'a [&'b [u8]]) -> Self {
|
||||
VecBuf { pos: 0, cur: 0, inner: vbytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> Buf for VecBuf<'a, 'b> {
|
||||
fn remaining(&self) -> usize {
|
||||
let sum = self.inner
|
||||
.iter()
|
||||
.skip(self.pos)
|
||||
.map(|bytes| bytes.len())
|
||||
.sum::<usize>();
|
||||
sum - self.cur
|
||||
}
|
||||
|
||||
fn bytes(&self) -> &[u8] {
|
||||
&self.inner[self.pos][self.cur..]
|
||||
}
|
||||
|
||||
fn advance(&mut self, cnt: usize) {
|
||||
let current = self.inner[self.pos].len();
|
||||
match (self.cur + cnt).cmp(¤t) {
|
||||
Ordering::Equal => if self.pos + 1 < self.inner.len() {
|
||||
self.pos += 1;
|
||||
self.cur = 0;
|
||||
} else {
|
||||
self.cur += cnt;
|
||||
},
|
||||
Ordering::Greater => {
|
||||
if self.pos + 1 < self.inner.len() {
|
||||
self.pos += 1;
|
||||
}
|
||||
let remaining = self.cur + cnt - current;
|
||||
self.advance(remaining);
|
||||
},
|
||||
Ordering::Less => self.cur += cnt,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
fn bytes_vec<'c>(&'c self, dst: &mut [&'c IoVec]) -> usize {
|
||||
let len = cmp::min(self.inner.len() - self.pos, dst.len());
|
||||
|
||||
if len > 0 {
|
||||
dst[0] = self.bytes().into();
|
||||
}
|
||||
|
||||
for i in 1..len {
|
||||
dst[i] = self.inner[self.pos + i].into();
|
||||
}
|
||||
|
||||
len
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_vecbuf {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fresh_cursor_vec() {
|
||||
let mut buf = VecBuf::new(&[b"he", b"llo"]);
|
||||
|
||||
assert_eq!(buf.remaining(), 5);
|
||||
assert_eq!(buf.bytes(), b"he");
|
||||
|
||||
buf.advance(2);
|
||||
|
||||
assert_eq!(buf.remaining(), 3);
|
||||
assert_eq!(buf.bytes(), b"llo");
|
||||
|
||||
buf.advance(3);
|
||||
|
||||
assert_eq!(buf.remaining(), 0);
|
||||
assert_eq!(buf.bytes(), b"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_u8() {
|
||||
let mut buf = VecBuf::new(&[b"\x21z", b"omg"]);
|
||||
assert_eq!(0x21, buf.get_u8());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_u16() {
|
||||
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||
assert_eq!(0x2154, buf.get_u16_be());
|
||||
let mut buf = VecBuf::new(&[b"\x21\x54z", b"omg"]);
|
||||
assert_eq!(0x5421, buf.get_u16_le());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_get_u16_buffer_underflow() {
|
||||
let mut buf = VecBuf::new(&[b"\x21"]);
|
||||
buf.get_u16_be();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bufs_vec() {
|
||||
let buf = VecBuf::new(&[b"he", b"llo"]);
|
||||
|
||||
let b1: &[u8] = &mut [0];
|
||||
let b2: &[u8] = &mut [0];
|
||||
|
||||
let mut dst: [&IoVec; 2] =
|
||||
[b1.into(), b2.into()];
|
||||
|
||||
assert_eq!(2, buf.bytes_vec(&mut dst[..]));
|
||||
}
|
||||
}
|
14
src/lib.rs
14
src/lib.rs
@ -109,7 +109,7 @@ impl TlsConnector {
|
||||
|
||||
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
self.connect_with(domain, stream, |_| ())
|
||||
}
|
||||
@ -117,7 +117,7 @@ impl TlsConnector {
|
||||
#[inline]
|
||||
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
F: FnOnce(&mut ClientSession),
|
||||
{
|
||||
let mut session = ClientSession::new(&self.inner, domain);
|
||||
@ -156,7 +156,7 @@ impl TlsConnector {
|
||||
impl TlsAcceptor {
|
||||
pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
self.accept_with(stream, |_| ())
|
||||
}
|
||||
@ -164,7 +164,7 @@ impl TlsAcceptor {
|
||||
#[inline]
|
||||
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite,
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
F: FnOnce(&mut ServerSession),
|
||||
{
|
||||
let mut session = ServerSession::new(&self.inner);
|
||||
@ -189,7 +189,8 @@ pub struct Accept<IO>(server::MidHandshake<IO>);
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
|
||||
type Output = io::Result<client::TlsStream<IO>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||
#[inline]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.0).poll(cx)
|
||||
}
|
||||
}
|
||||
@ -197,7 +198,8 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
|
||||
type Output = io::Result<server::TlsStream<IO>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||
#[inline]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.0).poll(cx)
|
||||
}
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ where
|
||||
type Output = io::Result<TlsStream<IO>>;
|
||||
|
||||
#[inline]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if let MidHandshake::Handshaking(stream) = this {
|
||||
@ -72,7 +72,7 @@ where
|
||||
Initializer::nop()
|
||||
}
|
||||
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
let mut stream = Stream::new(&mut this.io, &mut this.session)
|
||||
.set_eof(!this.state.readable());
|
||||
@ -106,21 +106,21 @@ impl<IO> AsyncWrite for TlsStream<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
Stream::new(&mut this.io, &mut this.session)
|
||||
.set_eof(!this.state.readable())
|
||||
.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
Stream::new(&mut this.io, &mut this.session)
|
||||
.set_eof(!this.state.readable())
|
||||
.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
if self.state.writeable() {
|
||||
self.session.send_close_notify();
|
||||
self.state.shutdown_write();
|
||||
|
@ -1,17 +1,15 @@
|
||||
#![cfg(not(test))]
|
||||
|
||||
#[macro_use] extern crate lazy_static;
|
||||
extern crate rustls;
|
||||
extern crate tokio;
|
||||
extern crate tokio_rustls;
|
||||
extern crate webpki;
|
||||
#![feature(async_await)]
|
||||
|
||||
use std::{ io, thread };
|
||||
use std::io::{ BufReader, Cursor };
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::{ TcpListener, TcpStream };
|
||||
use lazy_static::lazy_static;
|
||||
use futures::prelude::*;
|
||||
use futures::executor;
|
||||
use futures::task::SpawnExt;
|
||||
use romio::tcp::{ TcpListener, TcpStream };
|
||||
use rustls::{ ServerConfig, ClientConfig };
|
||||
use rustls::internal::pemfile::{ certs, rsa_private_keys };
|
||||
use tokio_rustls::{ TlsConnector, TlsAcceptor };
|
||||
@ -22,9 +20,6 @@ const RSA: &str = include_str!("end.rsa");
|
||||
|
||||
lazy_static!{
|
||||
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
|
||||
use tokio::prelude::*;
|
||||
use tokio::io as aio;
|
||||
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
|
||||
@ -36,26 +31,32 @@ lazy_static!{
|
||||
let (send, recv) = channel();
|
||||
|
||||
thread::spawn(move || {
|
||||
let done = async {
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
|
||||
let listener = TcpListener::bind(&addr).unwrap();
|
||||
let mut pool = executor::ThreadPool::new()?;
|
||||
let mut listener = TcpListener::bind(&addr)?;
|
||||
|
||||
send.send(listener.local_addr().unwrap()).unwrap();
|
||||
send.send(listener.local_addr()?).unwrap();
|
||||
|
||||
let done = listener.incoming()
|
||||
.for_each(move |stream| {
|
||||
let done = config.accept(stream)
|
||||
.and_then(|stream| {
|
||||
let (reader, writer) = stream.split();
|
||||
aio::copy(reader, writer)
|
||||
})
|
||||
.then(|_| Ok(()));
|
||||
let mut incoming = listener.incoming();
|
||||
while let Some(stream) = incoming.next().await {
|
||||
let config = config.clone();
|
||||
pool.spawn(
|
||||
async move {
|
||||
let stream = stream?;
|
||||
let stream = config.accept(stream).await?;
|
||||
let (mut reader, mut write) = stream.split();
|
||||
reader.copy_into(&mut write).await?;
|
||||
Ok(()) as io::Result<()>
|
||||
}
|
||||
.unwrap_or_else(|err| eprintln!("{:?}", err))
|
||||
).unwrap();
|
||||
}
|
||||
|
||||
tokio::spawn(done);
|
||||
Ok(())
|
||||
})
|
||||
.map_err(|err| panic!("{:?}", err));
|
||||
Ok(()) as io::Result<()>
|
||||
};
|
||||
|
||||
tokio::run(done);
|
||||
executor::block_on(done).unwrap();
|
||||
});
|
||||
|
||||
let addr = recv.recv().unwrap();
|
||||
@ -63,31 +64,26 @@ lazy_static!{
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
|
||||
&*TEST_SERVER
|
||||
}
|
||||
|
||||
fn start_client(addr: &SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
|
||||
use tokio::prelude::*;
|
||||
use tokio::io as aio;
|
||||
|
||||
async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
|
||||
const FILE: &'static [u8] = include_bytes!("../README.md");
|
||||
|
||||
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
||||
let config = TlsConnector::from(config);
|
||||
let mut buf = vec![0; FILE.len()];
|
||||
|
||||
let stream = TcpStream::connect(&addr).await?;
|
||||
let mut stream = config.connect(domain, stream).await?;
|
||||
stream.write_all(FILE).await?;
|
||||
stream.read_exact(&mut buf).await?;
|
||||
|
||||
let done = TcpStream::connect(addr)
|
||||
.and_then(|stream| config.connect(domain, stream))
|
||||
.and_then(|stream| aio::write_all(stream, FILE))
|
||||
.and_then(|(stream, _)| aio::read_exact(stream, vec![0; FILE.len()]))
|
||||
.and_then(|(stream, buf)| {
|
||||
assert_eq!(buf, FILE);
|
||||
aio::shutdown(stream)
|
||||
})
|
||||
.map(drop);
|
||||
|
||||
done.wait()
|
||||
stream.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -99,7 +95,7 @@ fn pass() {
|
||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
||||
let config = Arc::new(config);
|
||||
|
||||
start_client(addr, domain, config.clone()).unwrap();
|
||||
executor::block_on(start_client(addr.clone(), domain, config.clone())).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -112,5 +108,5 @@ fn fail() {
|
||||
let config = Arc::new(config);
|
||||
|
||||
assert_ne!(domain, &"google.com");
|
||||
assert!(start_client(addr, "google.com", config).is_err());
|
||||
assert!(executor::block_on(start_client(addr.clone(), "google.com", config)).is_err());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user