Add 0-RTT test

This commit is contained in:
quininer 2019-02-18 20:01:37 +08:00
parent 3e605aafe4
commit 65932f5150
4 changed files with 73 additions and 9 deletions

View File

@ -25,3 +25,4 @@ webpki = "0.19"
[dev-dependencies] [dev-dependencies]
tokio = "0.1.6" tokio = "0.1.6"
lazy_static = "1" lazy_static = "1"
webpki-roots = "0.16"

View File

@ -48,6 +48,10 @@ impl From<Arc<ServerConfig>> for TlsAcceptor {
} }
impl TlsConnector { impl TlsConnector {
/// Enable 0-RTT.
///
/// Note that you want to use 0-RTT.
/// You must set `enable_early_data` to `true` in `ClientConfig`.
pub fn early_data(mut self, flag: bool) -> TlsConnector { pub fn early_data(mut self, flag: bool) -> TlsConnector {
self.early_data = flag; self.early_data = flag;
self self
@ -186,7 +190,6 @@ where IO: AsyncRead + AsyncWrite
// end // end
self.state = TlsState::Stream; self.state = TlsState::Stream;
*pos = 0;
data.clear(); data.clear();
stream.read(buf) stream.read(buf)
}, },
@ -266,7 +269,6 @@ where IO: AsyncRead + AsyncWrite
// end // end
self.state = TlsState::Stream; self.state = TlsState::Stream;
*pos = 0;
data.clear(); data.clear();
stream.write(buf) stream.write(buf)
}, },
@ -293,3 +295,6 @@ where IO: AsyncRead + AsyncWrite
self.io.flush() self.io.flush()
} }
} }
#[cfg(test)]
mod test_0rtt;

51
src/test_0rtt.rs Normal file
View File

@ -0,0 +1,51 @@
extern crate tokio;
extern crate webpki;
extern crate webpki_roots;
use std::io;
use std::sync::Arc;
use std::net::ToSocketAddrs;
use self::tokio::io as aio;
use self::tokio::prelude::*;
use self::tokio::net::TcpStream;
use rustls::{ ClientConfig, ClientSession };
use ::{ TlsConnector, TlsStream };
fn get(config: Arc<ClientConfig>, domain: &str, rtt0: bool)
-> io::Result<(TlsStream<TcpStream, ClientSession>, String)>
{
let config = TlsConnector::from(config).early_data(rtt0);
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
let addr = (domain, 443)
.to_socket_addrs()?
.next().unwrap();
TcpStream::connect(&addr)
.and_then(move |stream| {
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
config.connect(domain, stream)
})
.and_then(move |stream| aio::write_all(stream, input))
.and_then(move |(stream, _)| aio::read_to_end(stream, Vec::new()))
.map(|(stream, buf)| (stream, String::from_utf8(buf).unwrap()))
.wait()
}
#[test]
fn test_0rtt() {
let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
config.enable_early_data = true;
let config = Arc::new(config);
let domain = "mozilla-modern.badssl.com";
let (_, output) = get(config.clone(), domain, false).unwrap();
assert!(output.contains("<title>mozilla-modern.badssl.com</title>"));
let (io, output) = get(config.clone(), domain, true).unwrap();
assert!(output.contains("<title>mozilla-modern.badssl.com</title>"));
assert_eq!(io.early_data.0, 0);
}

View File

@ -66,17 +66,14 @@ fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
&*TEST_SERVER &*TEST_SERVER
} }
fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { fn start_client(addr: &SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
use tokio::prelude::*; use tokio::prelude::*;
use tokio::io as aio; use tokio::io as aio;
const FILE: &'static [u8] = include_bytes!("../README.md"); const FILE: &'static [u8] = include_bytes!("../README.md");
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
let mut config = ClientConfig::new(); let config = TlsConnector::from(config);
let mut chain = BufReader::new(Cursor::new(chain));
config.root_store.add_pem_file(&mut chain).unwrap();
let config = TlsConnector::from(Arc::new(config));
let done = TcpStream::connect(addr) let done = TcpStream::connect(addr)
.and_then(|stream| config.connect(domain, stream)) .and_then(|stream| config.connect(domain, stream))
@ -95,13 +92,23 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()>
fn pass() { fn pass() {
let (addr, domain, chain) = start_server(); let (addr, domain, chain) = start_server();
start_client(addr, domain, chain).unwrap(); let mut config = ClientConfig::new();
let mut chain = BufReader::new(Cursor::new(chain));
config.root_store.add_pem_file(&mut chain).unwrap();
let config = Arc::new(config);
start_client(addr, domain, config.clone()).unwrap();
} }
#[test] #[test]
fn fail() { fn fail() {
let (addr, domain, chain) = start_server(); let (addr, domain, chain) = start_server();
let mut config = ClientConfig::new();
let mut chain = BufReader::new(Cursor::new(chain));
config.root_store.add_pem_file(&mut chain).unwrap();
let config = Arc::new(config);
assert_ne!(domain, &"google.com"); assert_ne!(domain, &"google.com");
assert!(start_client(addr, "google.com", chain).is_err()); assert!(start_client(addr, "google.com", config).is_err());
} }