From 65932f5150158aa1816b4e5915a34cce637637cf Mon Sep 17 00:00:00 2001 From: quininer Date: Mon, 18 Feb 2019 20:01:37 +0800 Subject: [PATCH] Add 0-RTT test --- Cargo.toml | 1 + src/lib.rs | 9 +++++++-- src/test_0rtt.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/test.rs | 21 +++++++++++++------- 4 files changed, 73 insertions(+), 9 deletions(-) create mode 100644 src/test_0rtt.rs diff --git a/Cargo.toml b/Cargo.toml index 9de53ac..b15bea2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ webpki = "0.19" [dev-dependencies] tokio = "0.1.6" lazy_static = "1" +webpki-roots = "0.16" diff --git a/src/lib.rs b/src/lib.rs index 378c693..dd34452 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,10 @@ impl From> for TlsAcceptor { } impl TlsConnector { + /// Enable 0-RTT. + /// + /// Note that you want to use 0-RTT. + /// You must set `enable_early_data` to `true` in `ClientConfig`. pub fn early_data(mut self, flag: bool) -> TlsConnector { self.early_data = flag; self @@ -186,7 +190,6 @@ where IO: AsyncRead + AsyncWrite // end self.state = TlsState::Stream; - *pos = 0; data.clear(); stream.read(buf) }, @@ -266,7 +269,6 @@ where IO: AsyncRead + AsyncWrite // end self.state = TlsState::Stream; - *pos = 0; data.clear(); stream.write(buf) }, @@ -293,3 +295,6 @@ where IO: AsyncRead + AsyncWrite self.io.flush() } } + +#[cfg(test)] +mod test_0rtt; diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs new file mode 100644 index 0000000..56c9d7b --- /dev/null +++ b/src/test_0rtt.rs @@ -0,0 +1,51 @@ +extern crate tokio; +extern crate webpki; +extern crate webpki_roots; + +use std::io; +use std::sync::Arc; +use std::net::ToSocketAddrs; +use self::tokio::io as aio; +use self::tokio::prelude::*; +use self::tokio::net::TcpStream; +use rustls::{ ClientConfig, ClientSession }; +use ::{ TlsConnector, TlsStream }; + + +fn get(config: Arc, domain: &str, rtt0: bool) + -> io::Result<(TlsStream, String)> +{ + let config = TlsConnector::from(config).early_data(rtt0); + let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); + + let addr = (domain, 443) + .to_socket_addrs()? + .next().unwrap(); + + TcpStream::connect(&addr) + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + config.connect(domain, stream) + }) + .and_then(move |stream| aio::write_all(stream, input)) + .and_then(move |(stream, _)| aio::read_to_end(stream, Vec::new())) + .map(|(stream, buf)| (stream, String::from_utf8(buf).unwrap())) + .wait() +} + +#[test] +fn test_0rtt() { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + config.enable_early_data = true; + let config = Arc::new(config); + let domain = "mozilla-modern.badssl.com"; + + let (_, output) = get(config.clone(), domain, false).unwrap(); + assert!(output.contains("mozilla-modern.badssl.com")); + + let (io, output) = get(config.clone(), domain, true).unwrap(); + assert!(output.contains("mozilla-modern.badssl.com")); + + assert_eq!(io.early_data.0, 0); +} diff --git a/tests/test.rs b/tests/test.rs index 8833253..f0703f8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -66,17 +66,14 @@ fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { &*TEST_SERVER } -fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { +fn start_client(addr: &SocketAddr, domain: &str, config: Arc) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; const FILE: &'static [u8] = include_bytes!("../README.md"); let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - let mut config = ClientConfig::new(); - let mut chain = BufReader::new(Cursor::new(chain)); - config.root_store.add_pem_file(&mut chain).unwrap(); - let config = TlsConnector::from(Arc::new(config)); + let config = TlsConnector::from(config); let done = TcpStream::connect(addr) .and_then(|stream| config.connect(domain, stream)) @@ -95,13 +92,23 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> fn pass() { let (addr, domain, chain) = start_server(); - start_client(addr, domain, chain).unwrap(); + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + + start_client(addr, domain, config.clone()).unwrap(); } #[test] fn fail() { let (addr, domain, chain) = start_server(); + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + assert_ne!(domain, &"google.com"); - assert!(start_client(addr, "google.com", chain).is_err()); + assert!(start_client(addr, "google.com", config).is_err()); }