From f7472e89a214b77826fbc57ec5078fe9ef4068ac Mon Sep 17 00:00:00 2001 From: quininer Date: Sun, 19 May 2019 00:48:56 +0800 Subject: [PATCH] make early data test work --- Cargo.toml | 1 + src/client.rs | 3 ++- src/common/mod.rs | 4 ++-- src/lib.rs | 20 ++++++++++---------- src/test_0rtt.rs | 37 ++++++++++++++++--------------------- tests/test.rs | 2 ++ 6 files changed, 33 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8a44728..cb60566 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ webpki = "0.19" early-data = [] [dev-dependencies] +romio = "0.3.0-alpha.8" tokio = "0.1.6" lazy_static = "1" webpki-roots = "0.16" diff --git a/src/client.rs b/src/client.rs index 8527121..a2ebdd2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,5 @@ use super::*; use rustls::Session; -use std::io::Write; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -149,6 +148,8 @@ where match this.state { #[cfg(feature = "early-data")] TlsState::EarlyData => { + use std::io::Write; + let (pos, data) = &mut this.early_data; // write early data diff --git a/src/common/mod.rs b/src/common/mod.rs index 98afcd6..d20d5a8 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -225,5 +225,5 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { } } -#[cfg(test)] -mod test_stream; +// #[cfg(test)] +// mod test_stream; diff --git a/src/lib.rs b/src/lib.rs index f962fc0..d849f33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). +#![feature(async_await)] + macro_rules! try_ready { ( $e:expr ) => { match $e { @@ -10,19 +12,19 @@ macro_rules! try_ready { } } -pub mod client; mod common; +pub mod client; pub mod server; -use common::Stream; -use std::pin::Pin; -use std::task::{ Poll, Context }; -use std::future::Future; -use futures::io::{ AsyncRead, AsyncWrite, Initializer }; -use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession}; +use std::{ io, mem }; use std::sync::Arc; -use std::{io, mem}; +use std::pin::Pin; +use std::future::Future; +use std::task::{ Poll, Context }; +use futures::io::{ AsyncRead, AsyncWrite, Initializer }; +use rustls::{ ClientConfig, ClientSession, ServerConfig, ServerSession }; use webpki::DNSNameRef; +use common::Stream; #[derive(Debug, Copy, Clone)] enum TlsState { @@ -200,8 +202,6 @@ impl Future for Accept { } } -/* #[cfg(feature = "early-data")] #[cfg(test)] mod test_0rtt; -*/ diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs index 0182406..8c8db6c 100644 --- a/src/test_0rtt.rs +++ b/src/test_0rtt.rs @@ -1,36 +1,31 @@ -extern crate tokio; -extern crate webpki; -extern crate webpki_roots; - use std::io; use std::sync::Arc; use std::net::ToSocketAddrs; -use self::tokio::io as aio; -use self::tokio::prelude::*; -use self::tokio::net::TcpStream; +use futures::executor; +use futures::prelude::*; +use romio::tcp::TcpStream; use rustls::ClientConfig; -use ::{ TlsConnector, client::TlsStream }; +use crate::{ TlsConnector, client::TlsStream }; -fn get(config: Arc, domain: &str, rtt0: bool) +async fn get(config: Arc, domain: &str, rtt0: bool) -> io::Result<(TlsStream, String)> { - let config = TlsConnector::from(config).early_data(rtt0); + let connector = TlsConnector::from(config).early_data(rtt0); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); let addr = (domain, 443) .to_socket_addrs()? .next().unwrap(); + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + let mut buf = Vec::new(); - TcpStream::connect(&addr) - .and_then(move |stream| { - let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); - config.connect(domain, stream) - }) - .and_then(move |stream| aio::write_all(stream, input)) - .and_then(move |(stream, _)| aio::read_to_end(stream, Vec::new())) - .map(|(stream, buf)| (stream, String::from_utf8(buf).unwrap())) - .wait() + let stream = TcpStream::connect(&addr).await?; + let mut stream = connector.connect(domain, stream).await?; + stream.write_all(input.as_bytes()).await?; + stream.read_to_end(&mut buf).await?; + + Ok((stream, String::from_utf8(buf).unwrap())) } #[test] @@ -41,10 +36,10 @@ fn test_0rtt() { let config = Arc::new(config); let domain = "mozilla-modern.badssl.com"; - let (_, output) = get(config.clone(), domain, false).unwrap(); + let (_, output) = executor::block_on(get(config.clone(), domain, false)).unwrap(); assert!(output.contains("mozilla-modern.badssl.com")); - let (io, output) = get(config.clone(), domain, true).unwrap(); + let (io, output) = executor::block_on(get(config.clone(), domain, true)).unwrap(); assert!(output.contains("mozilla-modern.badssl.com")); assert_eq!(io.early_data.0, 0); diff --git a/tests/test.rs b/tests/test.rs index f0703f8..533e4e4 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,3 +1,5 @@ +#![cfg(not(test))] + #[macro_use] extern crate lazy_static; extern crate rustls; extern crate tokio;