diff --git a/Cargo.toml b/Cargo.toml index 9de53ac..b15bea2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ webpki = "0.19" [dev-dependencies] tokio = "0.1.6" lazy_static = "1" +webpki-roots = "0.16" diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..8d6758e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,196 @@ +use super::*; +use std::io::Write; +use rustls::Session; + + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ClientSession, + pub(crate) state: TlsState, + pub(crate) early_data: (usize, Vec) +} + +#[derive(Debug)] +pub(crate) enum TlsState { + EarlyData, + Stream, + Eof, + Shutdown +} + +pub(crate) enum MidHandshake { + Handshaking(TlsStream), + EarlyData(TlsStream), + End +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ClientSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ClientSession) { + (self.io, self.session) + } +} + +impl Future for MidHandshake +where IO: AsyncRead + AsyncWrite, +{ + type Item = TlsStream; + type Error = io::Error; + + #[inline] + fn poll(&mut self) -> Poll { + match self { + MidHandshake::Handshaking(stream) => { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); + + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } + }, + _ => () + } + + match mem::replace(self, MidHandshake::End) { + MidHandshake::Handshaking(stream) + | MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)), + MidHandshake::End => panic!() + } + } +} + +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + data.clear(); + stream.read(buf) + }, + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0), + } + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::EarlyData => { + let (pos, data) = &mut self.early_data; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = early_data.write(buf)?; + data.extend_from_slice(&buf[..len]); + return Ok(len); + } + + // complete handshake + if stream.session.is_handshaking() { + stream.complete_io()?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = stream.write(&data[*pos..])?; + *pos += len; + } + } + + // end + self.state = TlsState::Stream; + data.clear(); + stream.write(buf) + }, + _ => stream.write(buf) + } + } + + fn flush(&mut self) -> io::Result<()> { + Stream::new(&mut self.io, &mut self.session).flush()?; + self.io.flush() + } +} + +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } + } + + { + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + } + self.io.shutdown() + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 19fedb1..9010d8d 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -6,18 +6,18 @@ use rustls::WriteV; use tokio_io::{ AsyncRead, AsyncWrite }; -pub struct Stream<'a, S: 'a, IO: 'a> { - pub session: &'a mut S, - pub io: &'a mut IO +pub struct Stream<'a, IO: 'a, S: 'a> { + pub io: &'a mut IO, + pub session: &'a mut S } -pub trait WriteTls<'a, S: Session, IO: AsyncRead + AsyncWrite>: Read + Write { +pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write { fn write_tls(&mut self) -> io::Result; } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> { - pub fn new(session: &'a mut S, io: &'a mut IO) -> Self { - Stream { session, io } +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> { + pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { + Stream { io, session } } pub fn complete_io(&mut self) -> io::Result<(usize, usize)> { @@ -66,7 +66,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> { +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<'a, IO, S> { fn write_tls(&mut self) -> io::Result { use futures::Async; use self::vecbuf::VecBuf; @@ -89,7 +89,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream< } } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> { +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> { fn read(&mut self, buf: &mut [u8]) -> io::Result { while self.session.wants_read() { if let (0, 0) = self.complete_io()? { @@ -100,7 +100,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> { } } -impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Write for Stream<'a, S, IO> { +impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> { fn write(&mut self, buf: &[u8]) -> io::Result { let len = self.session.write(buf)?; while self.session.wants_write() { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 66b34b6..a43622c 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -80,7 +80,7 @@ fn stream_good() -> io::Result<()> { { let mut good = Good(&mut server); - let mut stream = Stream::new(&mut client, &mut good); + let mut stream = Stream::new(&mut good, &mut client); let mut buf = Vec::new(); stream.read_to_end(&mut buf)?; @@ -102,7 +102,7 @@ fn stream_bad() -> io::Result<()> { client.set_buffer_limit(1024); let mut bad = Bad(true); - let mut stream = Stream::new(&mut client, &mut bad); + let mut stream = Stream::new(&mut bad, &mut client); assert_eq!(stream.write(&[0x42; 8])?, 8); assert_eq!(stream.write(&[0x42; 8])?, 8); let r = stream.write(&[0x00; 1024])?; // fill buffer @@ -121,7 +121,7 @@ fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); - let mut stream = Stream::new(&mut client, &mut good); + let mut stream = Stream::new(&mut good, &mut client); let (r, w) = stream.complete_io()?; assert!(r > 0); @@ -141,7 +141,7 @@ fn stream_handshake_eof() -> io::Result<()> { let (_, mut client) = make_pair(); let mut bad = Bad(false); - let mut stream = Stream::new(&mut client, &mut bad); + let mut stream = Stream::new(&mut bad, &mut client); let r = stream.complete_io(); assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); @@ -171,7 +171,7 @@ fn make_pair() -> (ServerSession, ClientSession) { fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) { let mut good = Good(server); - let mut stream = Stream::new(client, &mut good); + let mut stream = Stream::new(&mut good, client); stream.complete_io().unwrap(); stream.complete_io().unwrap(); } diff --git a/src/lib.rs b/src/lib.rs index 0cb0e3c..446e80d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,24 +8,26 @@ extern crate tokio_io; extern crate bytes; extern crate iovec; - mod common; -mod tokio_impl; +pub mod client; +pub mod server; -use std::io; +use std::{ io, mem }; use std::sync::Arc; use webpki::DNSNameRef; use rustls::{ - Session, ClientSession, ServerSession, - ClientConfig, ServerConfig, + ClientSession, ServerSession, + ClientConfig, ServerConfig }; -use tokio_io::{ AsyncRead, AsyncWrite }; +use futures::{Async, Future, Poll}; +use tokio_io::{ AsyncRead, AsyncWrite, try_nb }; use common::Stream; #[derive(Clone)] pub struct TlsConnector { - inner: Arc + inner: Arc, + early_data: bool } #[derive(Clone)] @@ -35,7 +37,7 @@ pub struct TlsAcceptor { impl From> for TlsConnector { fn from(inner: Arc) -> TlsConnector { - TlsConnector { inner } + TlsConnector { inner, early_data: false } } } @@ -46,19 +48,43 @@ impl From> for TlsAcceptor { } impl TlsConnector { + /// Enable 0-RTT. + /// + /// Note that you want to use 0-RTT. + /// You must set `enable_early_data` to `true` in `ClientConfig`. + pub fn early_data(mut self, flag: bool) -> TlsConnector { + self.early_data = flag; + self + } + pub fn connect(&self, domain: DNSNameRef, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite { - Self::connect_with_session(stream, ClientSession::new(&self.inner, domain)) + self.connect_with(domain, stream, |_| ()) } #[inline] - pub fn connect_with_session(stream: IO, session: ClientSession) + pub fn connect_with(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect - where IO: AsyncRead + AsyncWrite + where + IO: AsyncRead + AsyncWrite, + F: FnOnce(&mut ClientSession) { - Connect(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) + let mut session = ClientSession::new(&self.inner, domain); + f(&mut session); + + Connect(if self.early_data { + client::MidHandshake::EarlyData(client::TlsStream { + session, io: stream, + state: client::TlsState::EarlyData, + early_data: (0, Vec::new()) + }) + } else { + client::MidHandshake::Handshaking(client::TlsStream { + session, io: stream, + state: client::TlsState::Stream, + early_data: (0, Vec::new()) + }) }) } } @@ -67,105 +93,53 @@ impl TlsAcceptor { pub fn accept(&self, stream: IO) -> Accept where IO: AsyncRead + AsyncWrite, { - Self::accept_with_session(stream, ServerSession::new(&self.inner)) + self.accept_with(stream, |_| ()) } #[inline] - pub fn accept_with_session(stream: IO, session: ServerSession) -> Accept - where IO: AsyncRead + AsyncWrite + pub fn accept_with(&self, stream: IO, f: F) + -> Accept + where + IO: AsyncRead + AsyncWrite, + F: FnOnce(&mut ServerSession) { - Accept(MidHandshake { - inner: Some(TlsStream { session, io: stream, is_shutdown: false, eof: false }) - }) + let mut session = ServerSession::new(&self.inner); + f(&mut session); + + Accept(server::MidHandshake::Handshaking(server::TlsStream { + session, io: stream, + state: server::TlsState::Stream, + })) } } /// Future returned from `ClientConfigExt::connect_async` which will resolve /// once the connection handshake has finished. -pub struct Connect(MidHandshake); +pub struct Connect(client::MidHandshake); /// Future returned from `ServerConfigExt::accept_async` which will resolve /// once the accept handshake has finished. -pub struct Accept(MidHandshake); +pub struct Accept(server::MidHandshake); -struct MidHandshake { - inner: Option> -} +impl Future for Connect { + type Item = client::TlsStream; + type Error = io::Error; - -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -#[derive(Debug)] -pub struct TlsStream { - is_shutdown: bool, - eof: bool, - io: IO, - session: S -} - -impl TlsStream { - #[inline] - pub fn get_ref(&self) -> (&IO, &S) { - (&self.io, &self.session) - } - - #[inline] - pub fn get_mut(&mut self) -> (&mut IO, &mut S) { - (&mut self.io, &mut self.session) - } - - #[inline] - pub fn into_inner(self) -> (IO, S) { - (self.io, self.session) + fn poll(&mut self) -> Poll { + self.0.poll() } } -impl From<(IO, S)> for TlsStream { - #[inline] - fn from((io, session): (IO, S)) -> TlsStream { - assert!(!session.is_handshaking()); +impl Future for Accept { + type Item = server::TlsStream; + type Error = io::Error; - TlsStream { - is_shutdown: false, - eof: false, - io, session - } + fn poll(&mut self) -> Poll { + self.0.poll() } } -impl io::Read for TlsStream - where IO: AsyncRead + AsyncWrite, S: Session -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if self.eof { - return Ok(0); - } - - match Stream::new(&mut self.session, &mut self.io).read(buf) { - Ok(0) => { self.eof = true; Ok(0) }, - Ok(n) => Ok(n), - Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { - self.eof = true; - self.is_shutdown = true; - self.session.send_close_notify(); - Ok(0) - }, - Err(e) => Err(e) - } - } -} - -impl io::Write for TlsStream - where IO: AsyncRead + AsyncWrite, S: Session -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - Stream::new(&mut self.session, &mut self.io).write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - Stream::new(&mut self.session, &mut self.io).flush()?; - self.io.flush() - } -} +#[cfg(test)] +mod test_0rtt; diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..42dd18d --- /dev/null +++ b/src/server.rs @@ -0,0 +1,139 @@ +use super::*; +use rustls::Session; + + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ServerSession, + pub(crate) state: TlsState +} + +#[derive(Debug)] +pub(crate) enum TlsState { + Stream, + Eof, + Shutdown +} + +pub(crate) enum MidHandshake { + Handshaking(TlsStream), + End +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ServerSession) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ServerSession) { + (self.io, self.session) + } +} + +impl Future for MidHandshake +where IO: AsyncRead + AsyncWrite, +{ + type Item = TlsStream; + type Error = io::Error; + + #[inline] + fn poll(&mut self) -> Poll { + match self { + MidHandshake::Handshaking(stream) => { + let (io, session) = stream.get_mut(); + let mut stream = Stream::new(io, session); + + if stream.session.is_handshaking() { + try_nb!(stream.complete_io()); + } + + if stream.session.wants_write() { + try_nb!(stream.complete_io()); + } + }, + _ => () + } + + match mem::replace(self, MidHandshake::End) { + MidHandshake::Handshaking(stream) => Ok(Async::Ready(stream)), + MidHandshake::End => panic!() + } + } +} + +impl io::Read for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + + match self.state { + TlsState::Stream => match stream.read(buf) { + Ok(0) => { + self.state = TlsState::Eof; + Ok(0) + }, + Ok(n) => Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => { + self.state = TlsState::Shutdown; + stream.session.send_close_notify(); + Ok(0) + }, + Err(e) => Err(e) + }, + TlsState::Eof | TlsState::Shutdown => Ok(0) + } + } +} + +impl io::Write for TlsStream +where IO: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut stream = Stream::new(&mut self.io, &mut self.session); + stream.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Stream::new(&mut self.io, &mut self.session).flush()?; + self.io.flush() + } +} + +impl AsyncRead for TlsStream +where IO: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { + false + } +} + +impl AsyncWrite for TlsStream +where IO: AsyncRead + AsyncWrite, +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self.state { + TlsState::Shutdown => (), + _ => { + self.session.send_close_notify(); + self.state = TlsState::Shutdown; + } + } + + { + let mut stream = Stream::new(&mut self.io, &mut self.session); + try_nb!(stream.complete_io()); + } + self.io.shutdown() + } +} diff --git a/src/test_0rtt.rs b/src/test_0rtt.rs new file mode 100644 index 0000000..0182406 --- /dev/null +++ b/src/test_0rtt.rs @@ -0,0 +1,51 @@ +extern crate tokio; +extern crate webpki; +extern crate webpki_roots; + +use std::io; +use std::sync::Arc; +use std::net::ToSocketAddrs; +use self::tokio::io as aio; +use self::tokio::prelude::*; +use self::tokio::net::TcpStream; +use rustls::ClientConfig; +use ::{ TlsConnector, client::TlsStream }; + + +fn get(config: Arc, domain: &str, rtt0: bool) + -> io::Result<(TlsStream, String)> +{ + let config = TlsConnector::from(config).early_data(rtt0); + let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); + + let addr = (domain, 443) + .to_socket_addrs()? + .next().unwrap(); + + TcpStream::connect(&addr) + .and_then(move |stream| { + let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); + config.connect(domain, stream) + }) + .and_then(move |stream| aio::write_all(stream, input)) + .and_then(move |(stream, _)| aio::read_to_end(stream, Vec::new())) + .map(|(stream, buf)| (stream, String::from_utf8(buf).unwrap())) + .wait() +} + +#[test] +fn test_0rtt() { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + config.enable_early_data = true; + let config = Arc::new(config); + let domain = "mozilla-modern.badssl.com"; + + let (_, output) = get(config.clone(), domain, false).unwrap(); + assert!(output.contains("mozilla-modern.badssl.com")); + + let (io, output) = get(config.clone(), domain, true).unwrap(); + assert!(output.contains("mozilla-modern.badssl.com")); + + assert_eq!(io.early_data.0, 0); +} diff --git a/src/tokio_impl.rs b/src/tokio_impl.rs deleted file mode 100644 index 0897e93..0000000 --- a/src/tokio_impl.rs +++ /dev/null @@ -1,90 +0,0 @@ -use super::*; -use tokio_io::{ AsyncRead, AsyncWrite }; -use futures::{Async, Future, Poll}; -use common::Stream; - - -macro_rules! try_async { - ( $e:expr ) => { - match $e { - Ok(n) => n, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - return Ok(Async::NotReady), - Err(e) => return Err(e) - } - } -} - -impl Future for Connect { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -impl Future for Accept { - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -impl Future for MidHandshake -where - IO: AsyncRead + AsyncWrite, - S: Session -{ - type Item = TlsStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - { - let stream = self.inner.as_mut().unwrap(); - let (io, session) = stream.get_mut(); - let mut stream = Stream::new(session, io); - - if stream.session.is_handshaking() { - try_async!(stream.complete_io()); - } - - if stream.session.wants_write() { - try_async!(stream.complete_io()); - } - } - - Ok(Async::Ready(self.inner.take().unwrap())) - } -} - -impl AsyncRead for TlsStream - where - IO: AsyncRead + AsyncWrite, - S: Session -{ - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false - } -} - -impl AsyncWrite for TlsStream - where - IO: AsyncRead + AsyncWrite, - S: Session -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - if !self.is_shutdown { - self.session.send_close_notify(); - self.is_shutdown = true; - } - - { - let mut stream = Stream::new(&mut self.session, &mut self.io); - try_async!(stream.complete_io()); - } - self.io.shutdown() - } -} diff --git a/tests/test.rs b/tests/test.rs index 8833253..f0703f8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -66,17 +66,14 @@ fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { &*TEST_SERVER } -fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> { +fn start_client(addr: &SocketAddr, domain: &str, config: Arc) -> io::Result<()> { use tokio::prelude::*; use tokio::io as aio; const FILE: &'static [u8] = include_bytes!("../README.md"); let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); - let mut config = ClientConfig::new(); - let mut chain = BufReader::new(Cursor::new(chain)); - config.root_store.add_pem_file(&mut chain).unwrap(); - let config = TlsConnector::from(Arc::new(config)); + let config = TlsConnector::from(config); let done = TcpStream::connect(addr) .and_then(|stream| config.connect(domain, stream)) @@ -95,13 +92,23 @@ fn start_client(addr: &SocketAddr, domain: &str, chain: &str) -> io::Result<()> fn pass() { let (addr, domain, chain) = start_server(); - start_client(addr, domain, chain).unwrap(); + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + + start_client(addr, domain, config.clone()).unwrap(); } #[test] fn fail() { let (addr, domain, chain) = start_server(); + let mut config = ClientConfig::new(); + let mut chain = BufReader::new(Cursor::new(chain)); + config.root_store.add_pem_file(&mut chain).unwrap(); + let config = Arc::new(config); + assert_ne!(domain, &"google.com"); - assert!(start_client(addr, "google.com", chain).is_err()); + assert!(start_client(addr, "google.com", config).is_err()); }