[DRAFT] update tokio-rustls to rustls 0.20.x (#64)

* update to rustls 0.20

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* track simple renamings in rustls

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* use reader/writer methods

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* fix find and replace

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* use rustls-pemfile crate for pem file parsing

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* update misc api breakage

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* update client example with api changes

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* update server example with new APIs

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* update test_stream test

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* update tests to use new APIs

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* rm unused imports

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* handle rustls `WouldBlock` on eof

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* expect rustls to return wouldblock in tests

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* i think this is *actually* the right EOF behavior

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* bump version

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* okay that seems to fix it

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* update to track builder API changes

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* actually shutdown read side on close notify

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* Further updates to rustls 0.20 (#68)

* Adapt to RootCertStore API changes

* Handle UnexpectedEof errors

* Rename would_block to io_pending

* Try to make badssl test failures more verbose

* Rebuild AsyncRead impl

* Upgrade to current rustls

* Revert to using assert!()

* Update to rustls 0.20

* Forward rustls features

Co-authored-by: Dirkjan Ochtman <dirkjan@ochtman.nl>
This commit is contained in:
Eliza Weisman 2021-09-28 10:01:37 -07:00 committed by GitHub
parent db01bce007
commit 8501aafae5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 354 additions and 142 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.22.0" version = "0.23.0"
authors = ["quininer kel <quininer@live.com>"] authors = ["quininer kel <quininer@live.com>"]
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
repository = "https://github.com/tokio-rs/tls" repository = "https://github.com/tokio-rs/tls"
@ -13,15 +13,19 @@ edition = "2018"
[dependencies] [dependencies]
tokio = "1.0" tokio = "1.0"
rustls = "0.19" rustls = { version = "0.20", default-features = false }
webpki = "0.21" webpki = "0.22"
[features] [features]
early-data = [] default = ["logging", "tls12"]
dangerous_configuration = ["rustls/dangerous_configuration"] dangerous_configuration = ["rustls/dangerous_configuration"]
early-data = []
logging = ["rustls/logging"]
tls12 = ["rustls/tls12"]
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }
futures-util = "0.3.1" futures-util = "0.3.1"
lazy_static = "1" lazy_static = "1"
webpki-roots = "0.21" webpki-roots = "0.22"
rustls-pemfile = "0.2.1"

View File

@ -8,4 +8,5 @@ edition = "2018"
tokio = { version = "1.0", features = [ "full" ] } tokio = { version = "1.0", features = [ "full" ] }
argh = "0.1" argh = "0.1"
tokio-rustls = { path = "../.." } tokio-rustls = { path = "../.." }
webpki-roots = "0.21" webpki-roots = "0.22"
rustls-pemfile = "0.2"

View File

@ -1,4 +1,5 @@
use argh::FromArgs; use argh::FromArgs;
use std::convert::TryFrom;
use std::fs::File; use std::fs::File;
use std::io; use std::io;
use std::io::BufReader; use std::io::BufReader;
@ -7,7 +8,8 @@ use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{copy, split, stdin as tokio_stdin, stdout as tokio_stdout, AsyncWriteExt}; use tokio::io::{copy, split, stdin as tokio_stdin, stdout as tokio_stdout, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::{rustls::ClientConfig, webpki::DNSNameRef, TlsConnector}; use tokio_rustls::rustls::{self, OwnedTrustAnchor};
use tokio_rustls::{webpki, TlsConnector};
/// Tokio Rustls client example /// Tokio Rustls client example
#[derive(FromArgs)] #[derive(FromArgs)]
@ -40,25 +42,42 @@ async fn main() -> io::Result<()> {
let domain = options.domain.unwrap_or(options.host); let domain = options.domain.unwrap_or(options.host);
let content = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); let content = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
let mut config = ClientConfig::new(); let mut root_cert_store = rustls::RootCertStore::empty();
if let Some(cafile) = &options.cafile { if let Some(cafile) = &options.cafile {
let mut pem = BufReader::new(File::open(cafile)?); let mut pem = BufReader::new(File::open(cafile)?);
config let certs = rustls_pemfile::certs(&mut pem)?;
.root_store let trust_anchors = certs.iter().map(|cert| {
.add_pem_file(&mut pem) let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?; OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
});
root_cert_store.add_server_trust_anchors(trust_anchors);
} else { } else {
config root_cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(
.root_store |ta| {
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
},
));
} }
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth(); // i guess this was previously the default?
let connector = TlsConnector::from(Arc::new(config)); let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(&addr).await?; let stream = TcpStream::connect(&addr).await?;
let (mut stdin, mut stdout) = (tokio_stdin(), tokio_stdout()); let (mut stdin, mut stdout) = (tokio_stdin(), tokio_stdout());
let domain = DNSNameRef::try_from_ascii_str(&domain) let domain = rustls::ServerName::try_from(domain.as_str())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
let mut stream = connector.connect(domain, stream).await?; let mut stream = connector.connect(domain, stream).await?;

View File

@ -8,3 +8,4 @@ edition = "2018"
tokio = { version = "1.0", features = [ "full" ] } tokio = { version = "1.0", features = [ "full" ] }
argh = "0.1" argh = "0.1"
tokio-rustls = { path = "../.." } tokio-rustls = { path = "../.." }
rustls-pemfile = "0.2.1"

View File

@ -1,4 +1,5 @@
use argh::FromArgs; use argh::FromArgs;
use rustls_pemfile::{certs, rsa_private_keys};
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader}; use std::io::{self, BufReader};
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
@ -6,8 +7,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{copy, sink, split, AsyncWriteExt}; use tokio::io::{copy, sink, split, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys}; use tokio_rustls::rustls::{self, Certificate, PrivateKey};
use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
/// Tokio Rustls server example /// Tokio Rustls server example
@ -33,11 +33,13 @@ struct Options {
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> { fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
certs(&mut BufReader::new(File::open(path)?)) certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())
} }
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> { fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
rsa_private_keys(&mut BufReader::new(File::open(path)?)) rsa_private_keys(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key")) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
} }
#[tokio::main] #[tokio::main]
@ -53,9 +55,10 @@ async fn main() -> io::Result<()> {
let mut keys = load_keys(&options.key)?; let mut keys = load_keys(&options.key)?;
let flag_echo = options.echo_mode; let flag_echo = options.echo_mode;
let mut config = ServerConfig::new(NoClientAuth::new()); let config = rustls::ServerConfig::builder()
config .with_safe_defaults()
.set_single_cert(certs, keys.remove(0)) .with_no_client_auth()
.with_single_cert(certs, keys.remove(0))
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let acceptor = TlsAcceptor::from(Arc::new(config)); let acceptor = TlsAcceptor::from(Arc::new(config));

View File

@ -1,36 +1,35 @@
use super::*; use super::*;
use crate::common::IoSession; use crate::common::IoSession;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
#[derive(Debug)] #[derive(Debug)]
pub struct TlsStream<IO> { pub struct TlsStream<IO> {
pub(crate) io: IO, pub(crate) io: IO,
pub(crate) session: ClientSession, pub(crate) session: ClientConnection,
pub(crate) state: TlsState, pub(crate) state: TlsState,
} }
impl<IO> TlsStream<IO> { impl<IO> TlsStream<IO> {
#[inline] #[inline]
pub fn get_ref(&self) -> (&IO, &ClientSession) { pub fn get_ref(&self) -> (&IO, &ClientConnection) {
(&self.io, &self.session) (&self.io, &self.session)
} }
#[inline] #[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) { pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
(&mut self.io, &mut self.session) (&mut self.io, &mut self.session)
} }
#[inline] #[inline]
pub fn into_inner(self) -> (IO, ClientSession) { pub fn into_inner(self) -> (IO, ClientConnection) {
(self.io, self.session) (self.io, self.session)
} }
} }
impl<IO> IoSession for TlsStream<IO> { impl<IO> IoSession for TlsStream<IO> {
type Io = IO; type Io = IO;
type Session = ClientSession; type Session = ClientConnection;
#[inline] #[inline]
fn skip_handshake(&self) -> bool { fn skip_handshake(&self) -> bool {
@ -68,7 +67,7 @@ where
match stream.as_mut_pin().poll_read(cx, buf) { match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(())) => { Poll::Ready(Ok(())) => {
if prev == buf.remaining() { if prev == buf.remaining() || stream.eof {
this.state.shutdown_read(); this.state.shutdown_read();
} }

View File

@ -1,6 +1,7 @@
use crate::common::{Stream, TlsState}; use crate::common::{Stream, TlsState};
use rustls::Session; use rustls::{ConnectionCommon, SideData};
use std::future::Future; use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{io, mem}; use std::{io, mem};
@ -15,28 +16,30 @@ pub(crate) trait IoSession {
fn into_io(self) -> Self::Io; fn into_io(self) -> Self::Io;
} }
pub(crate) enum MidHandshake<IS> { pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS), Handshaking(IS),
End, End,
Error { io: IS::Io, error: io::Error },
} }
impl<IS> Future for MidHandshake<IS> impl<IS, SD> Future for MidHandshake<IS>
where where
IS: IoSession + Unpin, IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin, IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin, IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin,
SD: SideData,
{ {
type Output = Result<IS, (io::Error, IS::Io)>; type Output = Result<IS, (io::Error, IS::Io)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut(); let this = self.get_mut();
let mut stream = let mut stream = match mem::replace(this, MidHandshake::End) {
if let MidHandshake::Handshaking(stream) = mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => stream,
stream // Starting the handshake returned an error; fail the future immediately.
} else { MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))),
panic!("unexpected polling after handshake") _ => panic!("unexpected polling after handshake"),
}; };
if !stream.skip_handshake() { if !stream.skip_handshake() {
let (state, io, session) = stream.get_mut(); let (state, io, session) = stream.get_mut();

View File

@ -1,8 +1,9 @@
mod handshake; mod handshake;
pub(crate) use handshake::{IoSession, MidHandshake}; pub(crate) use handshake::{IoSession, MidHandshake};
use rustls::Session; use rustls::{ConnectionCommon, SideData};
use std::io::{self, IoSlice, Read, Write}; use std::io::{self, IoSlice, Read, Write};
use std::ops::{Deref, DerefMut};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
@ -57,20 +58,26 @@ impl TlsState {
} }
} }
pub struct Stream<'a, IO, S> { pub struct Stream<'a, IO, C> {
pub io: &'a mut IO, pub io: &'a mut IO,
pub session: &'a mut S, pub session: &'a mut C,
pub eof: bool, pub eof: bool,
pub unexpected_eof: bool,
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { where
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
SD: SideData,
{
pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
Stream { Stream {
io, io,
session, session,
// The state so far is only used to detect EOF, so either Stream // The state so far is only used to detect EOF, so either Stream
// or EarlyData state should both be all right. // or EarlyData state should both be all right.
eof: false, eof: false,
unexpected_eof: false,
} }
} }
@ -214,7 +221,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
} }
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
where
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
SD: SideData,
{
fn poll_read( fn poll_read(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@ -223,10 +234,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
let prev = buf.remaining(); let prev = buf.remaining();
while buf.remaining() != 0 { while buf.remaining() != 0 {
let mut would_block = false; let mut io_pending = false;
// read a packet // read a packet
while self.session.wants_read() { while !self.eof && self.session.wants_read() {
match self.read_io(cx) { match self.read_io(cx) {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
self.eof = true; self.eof = true;
@ -234,30 +245,51 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
} }
Poll::Ready(Ok(_)) => (), Poll::Ready(Ok(_)) => (),
Poll::Pending => { Poll::Pending => {
would_block = true; io_pending = true;
break; break;
} }
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
} }
return match self.session.read(buf.initialize_unfilled()) { return match self.session.reader().read(buf.initialize_unfilled()) {
Ok(0) if prev == buf.remaining() && would_block => Poll::Pending, // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
// connection with a `CloseNotify` message and no more data will be forthcoming.
Ok(0) => break,
// Rustls yielded more data: advance the buffer, then see if more data is coming.
Ok(n) => { Ok(n) => {
buf.advance(n); buf.advance(n);
if self.eof || would_block { if self.eof || io_pending {
break; break;
} else { } else {
continue; continue;
} }
} }
Err(ref err)
if err.kind() == io::ErrorKind::ConnectionAborted // Rustls doesn't have more data to yield, but it believes the connection is open.
&& prev != buf.remaining() => Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
{ if prev == buf.remaining() && io_pending {
break Poll::Pending
} else if self.eof || io_pending {
break;
} else {
continue;
}
} }
Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
self.eof = true;
self.unexpected_eof = true;
if prev == buf.remaining() {
Poll::Ready(Err(err))
} else {
break;
}
}
// This should be unreachable.
Err(err) => Poll::Ready(Err(err)), Err(err) => Poll::Ready(Err(err)),
}; };
} }
@ -266,7 +298,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
} }
} }
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
where
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
SD: SideData,
{
fn poll_write( fn poll_write(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut Context, cx: &mut Context,
@ -277,7 +313,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
while pos != buf.len() { while pos != buf.len() {
let mut would_block = false; let mut would_block = false;
match self.session.write(&buf[pos..]) { match self.session.writer().write(&buf[pos..]) {
Ok(n) => pos += n, Ok(n) => pos += n,
Err(err) => return Poll::Ready(Err(err)), Err(err) => return Poll::Ready(Err(err)),
}; };
@ -304,7 +340,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
} }
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.flush()?; self.session.writer().flush()?;
while self.session.wants_write() { while self.session.wants_write() {
ready!(self.write_io(cx))?; ready!(self.write_io(cx))?;
} }

View File

@ -1,16 +1,15 @@
use super::Stream; use super::Stream;
use futures_util::future::poll_fn; use futures_util::future::poll_fn;
use futures_util::task::noop_waker_ref; use futures_util::task::noop_waker_ref;
use rustls::internal::pemfile::{certs, rsa_private_keys}; use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection};
use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session}; use rustls_pemfile::{certs, rsa_private_keys};
use std::io::{self, BufReader, Cursor, Read, Write}; use std::io::{self, BufReader, Cursor, Read, Write};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use webpki::DNSNameRef;
struct Good<'a>(&'a mut dyn Session); struct Good<'a>(&'a mut Connection);
impl<'a> AsyncRead for Good<'a> { impl<'a> AsyncRead for Good<'a> {
fn poll_read( fn poll_read(
@ -50,9 +49,10 @@ impl<'a> AsyncWrite for Good<'a> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.0.send_close_notify(); self.0.send_close_notify();
Poll::Ready(Ok(())) dbg!("sent close notify");
self.poll_flush(cx)
} }
} }
@ -120,23 +120,28 @@ impl AsyncWrite for Eof {
async fn stream_good() -> io::Result<()> { async fn stream_good() -> io::Result<()> {
const FILE: &[u8] = include_bytes!("../../README.md"); const FILE: &[u8] = include_bytes!("../../README.md");
let (mut server, mut client) = make_pair(); let (server, mut client) = make_pair();
let mut server = Connection::from(server);
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
io::copy(&mut Cursor::new(FILE), &mut server)?; io::copy(&mut Cursor::new(FILE), &mut server.writer())?;
server.send_close_notify();
let mut server = Connection::from(server);
{ {
let mut good = Good(&mut server); let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client); let mut stream = Stream::new(&mut good, &mut client);
let mut buf = Vec::new(); let mut buf = Vec::new();
stream.read_to_end(&mut buf).await?; dbg!(stream.read_to_end(&mut buf).await)?;
assert_eq!(buf, FILE); assert_eq!(buf, FILE);
stream.write_all(b"Hello World!").await?; dbg!(stream.write_all(b"Hello World!").await)?;
stream.flush().await?; stream.session.send_close_notify();
dbg!(stream.shutdown().await)?;
} }
let mut buf = String::new(); let mut buf = String::new();
server.read_to_string(&mut buf)?; dbg!(server.process_new_packets()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
dbg!(server.reader().read_to_string(&mut buf))?;
assert_eq!(buf, "Hello World!"); assert_eq!(buf, "Hello World!");
Ok(()) as io::Result<()> Ok(()) as io::Result<()>
@ -144,9 +149,10 @@ async fn stream_good() -> io::Result<()> {
#[tokio::test] #[tokio::test]
async fn stream_bad() -> io::Result<()> { async fn stream_bad() -> io::Result<()> {
let (mut server, mut client) = make_pair(); let (server, mut client) = make_pair();
let mut server = Connection::from(server);
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
client.set_buffer_limit(1024); client.set_buffer_limit(Some(1024));
let mut bad = Pending; let mut bad = Pending;
let mut stream = Stream::new(&mut bad, &mut client); let mut stream = Stream::new(&mut bad, &mut client);
@ -170,7 +176,8 @@ async fn stream_bad() -> io::Result<()> {
#[tokio::test] #[tokio::test]
async fn stream_handshake() -> io::Result<()> { async fn stream_handshake() -> io::Result<()> {
let (mut server, mut client) = make_pair(); let (server, mut client) = make_pair();
let mut server = Connection::from(server);
{ {
let mut good = Good(&mut server); let mut good = Good(&mut server);
@ -208,42 +215,72 @@ async fn stream_handshake_eof() -> io::Result<()> {
#[tokio::test] #[tokio::test]
async fn stream_eof() -> io::Result<()> { async fn stream_eof() -> io::Result<()> {
let (mut server, mut client) = make_pair(); let (server, mut client) = make_pair();
let mut server = Connection::from(server);
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
let mut good = Good(&mut server); let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client).set_eof(true); let mut stream = Stream::new(&mut good, &mut client).set_eof(true);
let mut buf = Vec::new(); let mut buf = Vec::new();
stream.read_to_end(&mut buf).await?; let result = stream.read_to_end(&mut buf).await;
assert_eq!(buf.len(), 0); assert_eq!(
result.err().map(|e| e.kind()),
Some(io::ErrorKind::UnexpectedEof)
);
Ok(()) as io::Result<()> Ok(()) as io::Result<()>
} }
fn make_pair() -> (ServerSession, ClientSession) { fn make_pair() -> (ServerConnection, ClientConnection) {
use std::convert::TryFrom;
const CERT: &str = include_str!("../../tests/end.cert"); const CERT: &str = include_str!("../../tests/end.cert");
const CHAIN: &str = include_str!("../../tests/end.chain"); const CHAIN: &str = include_str!("../../tests/end.chain");
const RSA: &str = include_str!("../../tests/end.rsa"); const RSA: &str = include_str!("../../tests/end.rsa");
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.unwrap()
.drain(..)
.map(rustls::Certificate)
.collect();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut sconfig = ServerConfig::new(NoClientAuth::new()); let mut keys = keys.drain(..).map(rustls::PrivateKey);
sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap(); let sconfig = rustls::ServerConfig::builder()
let server = ServerSession::new(&Arc::new(sconfig)); .with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, keys.next().unwrap())
.unwrap();
let server = ServerConnection::new(Arc::new(sconfig)).unwrap();
let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap(); let domain = rustls::ServerName::try_from("localhost").unwrap();
let mut cconfig = ClientConfig::new(); let mut client_root_cert_store = RootCertStore::empty();
let mut chain = BufReader::new(Cursor::new(CHAIN)); let mut chain = BufReader::new(Cursor::new(CHAIN));
cconfig.root_store.add_pem_file(&mut chain).unwrap(); let certs = certs(&mut chain).unwrap();
let client = ClientSession::new(&Arc::new(cconfig), domain); let trust_anchors = certs
.iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
let cconfig = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(client_root_cert_store)
.with_no_client_auth();
let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap();
(server, client) (server, client)
} }
fn do_handshake( fn do_handshake(
client: &mut ClientSession, client: &mut ClientConnection,
server: &mut ServerSession, server: &mut Connection,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
let mut good = Good(server); let mut good = Good(server);

View File

@ -14,14 +14,13 @@ mod common;
pub mod server; pub mod server;
use common::{MidHandshake, Stream, TlsState}; use common::{MidHandshake, Stream, TlsState};
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use std::future::Future; use std::future::Future;
use std::io; use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use webpki::DNSNameRef;
pub use rustls; pub use rustls;
pub use webpki; pub use webpki;
@ -68,19 +67,29 @@ impl TlsConnector {
} }
#[inline] #[inline]
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
{ {
self.connect_with(domain, stream, |_| ()) self.connect_with(domain, stream, |_| ())
} }
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO> pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ClientSession), F: FnOnce(&mut ClientConnection),
{ {
let mut session = ClientSession::new(&self.inner, domain); let mut session = match ClientConnection::new(self.inner.clone(), domain) {
Ok(session) => session,
Err(error) => {
return Connect(MidHandshake::Error {
io: stream,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
});
}
};
f(&mut session); f(&mut session);
Connect(MidHandshake::Handshaking(client::TlsStream { Connect(MidHandshake::Handshaking(client::TlsStream {
@ -113,9 +122,19 @@ impl TlsAcceptor {
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
where where
IO: AsyncRead + AsyncWrite + Unpin, IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ServerSession), F: FnOnce(&mut ServerConnection),
{ {
let mut session = ServerSession::new(&self.inner); let mut session = match ServerConnection::new(self.inner.clone()) {
Ok(session) => session,
Err(error) => {
return Accept(MidHandshake::Error {
io: stream,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
});
}
};
f(&mut session); f(&mut session);
Accept(MidHandshake::Handshaking(server::TlsStream { Accept(MidHandshake::Handshaking(server::TlsStream {
@ -201,7 +220,7 @@ pub enum TlsStream<T> {
} }
impl<T> TlsStream<T> { impl<T> TlsStream<T> {
pub fn get_ref(&self) -> (&T, &dyn Session) { pub fn get_ref(&self) -> (&T, &CommonState) {
use TlsStream::*; use TlsStream::*;
match self { match self {
Client(io) => { Client(io) => {
@ -215,7 +234,7 @@ impl<T> TlsStream<T> {
} }
} }
pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) { pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
use TlsStream::*; use TlsStream::*;
match self { match self {
Client(io) => { Client(io) => {

View File

@ -1,36 +1,35 @@
use super::*; use super::*;
use crate::common::IoSession; use crate::common::IoSession;
use rustls::Session;
/// A wrapper around an underlying raw stream which implements the TLS or SSL /// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol. /// protocol.
#[derive(Debug)] #[derive(Debug)]
pub struct TlsStream<IO> { pub struct TlsStream<IO> {
pub(crate) io: IO, pub(crate) io: IO,
pub(crate) session: ServerSession, pub(crate) session: ServerConnection,
pub(crate) state: TlsState, pub(crate) state: TlsState,
} }
impl<IO> TlsStream<IO> { impl<IO> TlsStream<IO> {
#[inline] #[inline]
pub fn get_ref(&self) -> (&IO, &ServerSession) { pub fn get_ref(&self) -> (&IO, &ServerConnection) {
(&self.io, &self.session) (&self.io, &self.session)
} }
#[inline] #[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) { pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
(&mut self.io, &mut self.session) (&mut self.io, &mut self.session)
} }
#[inline] #[inline]
pub fn into_inner(self) -> (IO, ServerSession) { pub fn into_inner(self) -> (IO, ServerConnection) {
(self.io, self.session) (self.io, self.session)
} }
} }
impl<IO> IoSession for TlsStream<IO> { impl<IO> IoSession for TlsStream<IO> {
type Io = IO; type Io = IO;
type Session = ServerSession; type Session = ServerConnection;
#[inline] #[inline]
fn skip_handshake(&self) -> bool { fn skip_handshake(&self) -> bool {
@ -67,7 +66,7 @@ where
match stream.as_mut_pin().poll_read(cx, buf) { match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(())) => { Poll::Ready(Ok(())) => {
if prev == buf.remaining() { if prev == buf.remaining() || stream.eof {
this.state.shutdown_read(); this.state.shutdown_read();
} }

View File

@ -1,10 +1,14 @@
use rustls::ClientConfig; use std::convert::TryFrom;
use std::io; use std::io;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector}; use tokio_rustls::{
client::TlsStream,
rustls::{self, ClientConfig, OwnedTrustAnchor},
TlsConnector,
};
async fn get( async fn get(
config: Arc<ClientConfig>, config: Arc<ClientConfig>,
@ -15,7 +19,7 @@ async fn get(
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
let addr = (domain, port).to_socket_addrs()?.next().unwrap(); let addr = (domain, port).to_socket_addrs()?.next().unwrap();
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap(); let domain = rustls::ServerName::try_from(domain).unwrap();
let mut buf = Vec::new(); let mut buf = Vec::new();
let stream = TcpStream::connect(&addr).await?; let stream = TcpStream::connect(&addr).await?;
@ -29,16 +33,31 @@ async fn get(
#[tokio::test] #[tokio::test]
async fn test_tls12() -> io::Result<()> { async fn test_tls12() -> io::Result<()> {
let mut config = ClientConfig::new(); let mut root_store = rustls::RootCertStore::empty();
config root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
.root_store OwnedTrustAnchor::from_subject_spki_name_constraints(
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); ta.subject,
config.versions = vec![rustls::ProtocolVersion::TLSv1_2]; ta.spki,
ta.name_constraints,
)
}));
let config = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS12])
.unwrap()
.with_root_certificates(root_store)
.with_no_client_auth();
let config = Arc::new(config); let config = Arc::new(config);
let domain = "tls-v1-2.badssl.com"; let domain = "tls-v1-2.badssl.com";
let (_, output) = get(config.clone(), domain, 1012).await?; let (_, output) = get(config.clone(), domain, 1012).await?;
assert!(output.contains("<title>tls-v1-2.badssl.com</title>")); assert!(
output.contains("<title>tls-v1-2.badssl.com</title>"),
"failed badssl test, output: {}",
output
);
Ok(()) Ok(())
} }
@ -52,15 +71,27 @@ fn test_tls13() {
#[tokio::test] #[tokio::test]
async fn test_modern() -> io::Result<()> { async fn test_modern() -> io::Result<()> {
let mut config = ClientConfig::new(); let mut root_store = rustls::RootCertStore::empty();
config root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
.root_store OwnedTrustAnchor::from_subject_spki_name_constraints(
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let config = Arc::new(config); let config = Arc::new(config);
let domain = "mozilla-modern.badssl.com"; let domain = "mozilla-modern.badssl.com";
let (_, output) = get(config.clone(), domain, 443).await?; let (_, output) = get(config.clone(), domain, 443).await?;
assert!(output.contains("<title>mozilla-modern.badssl.com</title>")); assert!(
output.contains("<title>mozilla-modern.badssl.com</title>"),
"failed badssl test, output: {}",
output
);
Ok(()) Ok(())
} }

View File

@ -1,7 +1,8 @@
#![cfg(feature = "early-data")] #![cfg(feature = "early-data")]
use futures_util::{future, future::Future, ready}; use futures_util::{future, future::Future, ready};
use rustls::ClientConfig; use rustls::RootCertStore;
use std::convert::TryFrom;
use std::io::{self, BufRead, BufReader, Cursor}; use std::io::{self, BufRead, BufReader, Cursor};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
@ -12,7 +13,11 @@ use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::sleep; use tokio::time::sleep;
use tokio_rustls::{client::TlsStream, TlsConnector}; use tokio_rustls::{
client::TlsStream,
rustls::{self, ClientConfig, OwnedTrustAnchor},
TlsConnector,
};
struct Read1<T>(T); struct Read1<T>(T);
@ -34,7 +39,7 @@ async fn send(
) -> io::Result<TlsStream<TcpStream>> { ) -> io::Result<TlsStream<TcpStream>> {
let connector = TlsConnector::from(config).early_data(true); let connector = TlsConnector::from(config).early_data(true);
let stream = TcpStream::connect(&addr).await?; let stream = TcpStream::connect(&addr).await?;
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap(); let domain = rustls::ServerName::try_from("testserver.com").unwrap();
let mut stream = connector.connect(domain, stream).await?; let mut stream = connector.connect(domain, stream).await?;
stream.write_all(data).await?; stream.write_all(data).await?;
@ -81,10 +86,28 @@ async fn test_0rtt() -> io::Result<()> {
// wait openssl server // wait openssl server
sleep(Duration::from_secs(1)).await; sleep(Duration::from_secs(1)).await;
let mut config = ClientConfig::new();
let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
config.root_store.add_pem_file(&mut chain).unwrap(); let certs = rustls_pemfile::certs(&mut chain).unwrap();
config.versions = vec![rustls::ProtocolVersion::TLSv1_3]; let trust_anchors = certs
.iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(trust_anchors.into_iter());
let mut config = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13])
.unwrap()
.with_root_certificates(root_store)
.with_no_client_auth();
config.enable_early_data = true; config.enable_early_data = true;
let config = Arc::new(config); let config = Arc::new(config);
let addr = SocketAddr::from(([127, 0, 0, 1], 12354)); let addr = SocketAddr::from(([127, 0, 0, 1], 12354));

View File

@ -1,7 +1,8 @@
use futures_util::future::TryFutureExt; use futures_util::future::TryFutureExt;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rustls::internal::pemfile::{certs, rsa_private_keys}; use rustls::{ClientConfig, OwnedTrustAnchor};
use rustls::{ClientConfig, ServerConfig}; use rustls_pemfile::{certs, rsa_private_keys};
use std::convert::TryFrom;
use std::io::{BufReader, Cursor}; use std::io::{BufReader, Cursor};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::mpsc::channel; use std::sync::mpsc::channel;
@ -13,18 +14,24 @@ use tokio::runtime;
use tokio_rustls::{TlsAcceptor, TlsConnector}; use tokio_rustls::{TlsAcceptor, TlsConnector};
const CERT: &str = include_str!("end.cert"); const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain"); const CHAIN: &[u8] = include_bytes!("end.chain");
const RSA: &str = include_str!("end.rsa"); const RSA: &str = include_str!("end.rsa");
lazy_static! { lazy_static! {
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = { static ref TEST_SERVER: (SocketAddr, &'static str, &'static [u8]) = {
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.unwrap()
.drain(..)
.map(rustls::Certificate)
.collect();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut keys = keys.drain(..).map(rustls::PrivateKey);
let mut config = ServerConfig::new(rustls::NoClientAuth::new()); let config = rustls::ServerConfig::builder()
config .with_safe_defaults()
.set_single_cert(cert, keys.pop().unwrap()) .with_no_client_auth()
.expect("invalid key or certificate"); .with_single_cert(cert, keys.next().unwrap())
.unwrap();
let acceptor = TlsAcceptor::from(Arc::new(config)); let acceptor = TlsAcceptor::from(Arc::new(config));
let (send, recv) = channel(); let (send, recv) = channel();
@ -70,14 +77,14 @@ lazy_static! {
}; };
} }
fn start_server() -> &'static (SocketAddr, &'static str, &'static str) { fn start_server() -> &'static (SocketAddr, &'static str, &'static [u8]) {
&*TEST_SERVER &*TEST_SERVER
} }
async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> { async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
const FILE: &[u8] = include_bytes!("../README.md"); const FILE: &[u8] = include_bytes!("../README.md");
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap(); let domain = rustls::ServerName::try_from(domain).unwrap();
let config = TlsConnector::from(config); let config = TlsConnector::from(config);
let mut buf = vec![0; FILE.len()]; let mut buf = vec![0; FILE.len()];
@ -102,12 +109,27 @@ async fn pass() -> io::Result<()> {
use std::time::*; use std::time::*;
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
let mut config = ClientConfig::new(); let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap();
let mut chain = BufReader::new(Cursor::new(chain)); let trust_anchors = chain
config.root_store.add_pem_file(&mut chain).unwrap(); .iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(trust_anchors.into_iter());
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let config = Arc::new(config); let config = Arc::new(config);
start_client(*addr, domain, config.clone()).await?; start_client(*addr, domain, config).await?;
Ok(()) Ok(())
} }
@ -116,9 +138,24 @@ async fn pass() -> io::Result<()> {
async fn fail() -> io::Result<()> { async fn fail() -> io::Result<()> {
let (addr, domain, chain) = start_server(); let (addr, domain, chain) = start_server();
let mut config = ClientConfig::new(); let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap();
let mut chain = BufReader::new(Cursor::new(chain)); let trust_anchors = chain
config.root_store.add_pem_file(&mut chain).unwrap(); .iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(trust_anchors.into_iter());
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let config = Arc::new(config); let config = Arc::new(config);
assert_ne!(domain, &"google.com"); assert_ne!(domain, &"google.com");