[DRAFT] update tokio-rustls
to rustls
0.20.x (#64)
* update to rustls 0.20 Signed-off-by: Eliza Weisman <eliza@buoyant.io> * track simple renamings in rustls Signed-off-by: Eliza Weisman <eliza@buoyant.io> * use reader/writer methods Signed-off-by: Eliza Weisman <eliza@buoyant.io> * fix find and replace Signed-off-by: Eliza Weisman <eliza@buoyant.io> * use rustls-pemfile crate for pem file parsing Signed-off-by: Eliza Weisman <eliza@buoyant.io> * update misc api breakage Signed-off-by: Eliza Weisman <eliza@buoyant.io> * update client example with api changes Signed-off-by: Eliza Weisman <eliza@buoyant.io> * update server example with new APIs Signed-off-by: Eliza Weisman <eliza@buoyant.io> * update test_stream test Signed-off-by: Eliza Weisman <eliza@buoyant.io> * update tests to use new APIs Signed-off-by: Eliza Weisman <eliza@buoyant.io> * rm unused imports Signed-off-by: Eliza Weisman <eliza@buoyant.io> * handle rustls `WouldBlock` on eof Signed-off-by: Eliza Weisman <eliza@buoyant.io> * expect rustls to return wouldblock in tests Signed-off-by: Eliza Weisman <eliza@buoyant.io> * i think this is *actually* the right EOF behavior Signed-off-by: Eliza Weisman <eliza@buoyant.io> * bump version Signed-off-by: Eliza Weisman <eliza@buoyant.io> * okay that seems to fix it Signed-off-by: Eliza Weisman <eliza@buoyant.io> * update to track builder API changes Signed-off-by: Eliza Weisman <eliza@buoyant.io> * actually shutdown read side on close notify Signed-off-by: Eliza Weisman <eliza@buoyant.io> * Further updates to rustls 0.20 (#68) * Adapt to RootCertStore API changes * Handle UnexpectedEof errors * Rename would_block to io_pending * Try to make badssl test failures more verbose * Rebuild AsyncRead impl * Upgrade to current rustls * Revert to using assert!() * Update to rustls 0.20 * Forward rustls features Co-authored-by: Dirkjan Ochtman <dirkjan@ochtman.nl>
This commit is contained in:
parent
db01bce007
commit
8501aafae5
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tokio-rustls"
|
name = "tokio-rustls"
|
||||||
version = "0.22.0"
|
version = "0.23.0"
|
||||||
authors = ["quininer kel <quininer@live.com>"]
|
authors = ["quininer kel <quininer@live.com>"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
repository = "https://github.com/tokio-rs/tls"
|
repository = "https://github.com/tokio-rs/tls"
|
||||||
@ -13,15 +13,19 @@ edition = "2018"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = "1.0"
|
tokio = "1.0"
|
||||||
rustls = "0.19"
|
rustls = { version = "0.20", default-features = false }
|
||||||
webpki = "0.21"
|
webpki = "0.22"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
early-data = []
|
default = ["logging", "tls12"]
|
||||||
dangerous_configuration = ["rustls/dangerous_configuration"]
|
dangerous_configuration = ["rustls/dangerous_configuration"]
|
||||||
|
early-data = []
|
||||||
|
logging = ["rustls/logging"]
|
||||||
|
tls12 = ["rustls/tls12"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
tokio = { version = "1.0", features = ["full"] }
|
||||||
futures-util = "0.3.1"
|
futures-util = "0.3.1"
|
||||||
lazy_static = "1"
|
lazy_static = "1"
|
||||||
webpki-roots = "0.21"
|
webpki-roots = "0.22"
|
||||||
|
rustls-pemfile = "0.2.1"
|
||||||
|
@ -8,4 +8,5 @@ edition = "2018"
|
|||||||
tokio = { version = "1.0", features = [ "full" ] }
|
tokio = { version = "1.0", features = [ "full" ] }
|
||||||
argh = "0.1"
|
argh = "0.1"
|
||||||
tokio-rustls = { path = "../.." }
|
tokio-rustls = { path = "../.." }
|
||||||
webpki-roots = "0.21"
|
webpki-roots = "0.22"
|
||||||
|
rustls-pemfile = "0.2"
|
@ -1,4 +1,5 @@
|
|||||||
use argh::FromArgs;
|
use argh::FromArgs;
|
||||||
|
use std::convert::TryFrom;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
@ -7,7 +8,8 @@ use std::path::PathBuf;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{copy, split, stdin as tokio_stdin, stdout as tokio_stdout, AsyncWriteExt};
|
use tokio::io::{copy, split, stdin as tokio_stdin, stdout as tokio_stdout, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_rustls::{rustls::ClientConfig, webpki::DNSNameRef, TlsConnector};
|
use tokio_rustls::rustls::{self, OwnedTrustAnchor};
|
||||||
|
use tokio_rustls::{webpki, TlsConnector};
|
||||||
|
|
||||||
/// Tokio Rustls client example
|
/// Tokio Rustls client example
|
||||||
#[derive(FromArgs)]
|
#[derive(FromArgs)]
|
||||||
@ -40,25 +42,42 @@ async fn main() -> io::Result<()> {
|
|||||||
let domain = options.domain.unwrap_or(options.host);
|
let domain = options.domain.unwrap_or(options.host);
|
||||||
let content = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
|
let content = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
|
||||||
|
|
||||||
let mut config = ClientConfig::new();
|
let mut root_cert_store = rustls::RootCertStore::empty();
|
||||||
if let Some(cafile) = &options.cafile {
|
if let Some(cafile) = &options.cafile {
|
||||||
let mut pem = BufReader::new(File::open(cafile)?);
|
let mut pem = BufReader::new(File::open(cafile)?);
|
||||||
config
|
let certs = rustls_pemfile::certs(&mut pem)?;
|
||||||
.root_store
|
let trust_anchors = certs.iter().map(|cert| {
|
||||||
.add_pem_file(&mut pem)
|
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?;
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
root_cert_store.add_server_trust_anchors(trust_anchors);
|
||||||
} else {
|
} else {
|
||||||
config
|
root_cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(
|
||||||
.root_store
|
|ta| {
|
||||||
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(root_cert_store)
|
||||||
|
.with_no_client_auth(); // i guess this was previously the default?
|
||||||
let connector = TlsConnector::from(Arc::new(config));
|
let connector = TlsConnector::from(Arc::new(config));
|
||||||
|
|
||||||
let stream = TcpStream::connect(&addr).await?;
|
let stream = TcpStream::connect(&addr).await?;
|
||||||
|
|
||||||
let (mut stdin, mut stdout) = (tokio_stdin(), tokio_stdout());
|
let (mut stdin, mut stdout) = (tokio_stdin(), tokio_stdout());
|
||||||
|
|
||||||
let domain = DNSNameRef::try_from_ascii_str(&domain)
|
let domain = rustls::ServerName::try_from(domain.as_str())
|
||||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
|
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
|
||||||
|
|
||||||
let mut stream = connector.connect(domain, stream).await?;
|
let mut stream = connector.connect(domain, stream).await?;
|
||||||
|
@ -8,3 +8,4 @@ edition = "2018"
|
|||||||
tokio = { version = "1.0", features = [ "full" ] }
|
tokio = { version = "1.0", features = [ "full" ] }
|
||||||
argh = "0.1"
|
argh = "0.1"
|
||||||
tokio-rustls = { path = "../.." }
|
tokio-rustls = { path = "../.." }
|
||||||
|
rustls-pemfile = "0.2.1"
|
@ -1,4 +1,5 @@
|
|||||||
use argh::FromArgs;
|
use argh::FromArgs;
|
||||||
|
use rustls_pemfile::{certs, rsa_private_keys};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{self, BufReader};
|
use std::io::{self, BufReader};
|
||||||
use std::net::ToSocketAddrs;
|
use std::net::ToSocketAddrs;
|
||||||
@ -6,8 +7,7 @@ use std::path::{Path, PathBuf};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{copy, sink, split, AsyncWriteExt};
|
use tokio::io::{copy, sink, split, AsyncWriteExt};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys};
|
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||||
use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
|
|
||||||
use tokio_rustls::TlsAcceptor;
|
use tokio_rustls::TlsAcceptor;
|
||||||
|
|
||||||
/// Tokio Rustls server example
|
/// Tokio Rustls server example
|
||||||
@ -33,11 +33,13 @@ struct Options {
|
|||||||
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
|
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
|
||||||
certs(&mut BufReader::new(File::open(path)?))
|
certs(&mut BufReader::new(File::open(path)?))
|
||||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
|
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
|
||||||
|
.map(|mut certs| certs.drain(..).map(Certificate).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
|
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
|
||||||
rsa_private_keys(&mut BufReader::new(File::open(path)?))
|
rsa_private_keys(&mut BufReader::new(File::open(path)?))
|
||||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
|
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
|
||||||
|
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -53,9 +55,10 @@ async fn main() -> io::Result<()> {
|
|||||||
let mut keys = load_keys(&options.key)?;
|
let mut keys = load_keys(&options.key)?;
|
||||||
let flag_echo = options.echo_mode;
|
let flag_echo = options.echo_mode;
|
||||||
|
|
||||||
let mut config = ServerConfig::new(NoClientAuth::new());
|
let config = rustls::ServerConfig::builder()
|
||||||
config
|
.with_safe_defaults()
|
||||||
.set_single_cert(certs, keys.remove(0))
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(certs, keys.remove(0))
|
||||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
||||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||||
|
|
||||||
|
@ -1,36 +1,35 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::common::IoSession;
|
use crate::common::IoSession;
|
||||||
use rustls::Session;
|
|
||||||
|
|
||||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||||
/// protocol.
|
/// protocol.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TlsStream<IO> {
|
pub struct TlsStream<IO> {
|
||||||
pub(crate) io: IO,
|
pub(crate) io: IO,
|
||||||
pub(crate) session: ClientSession,
|
pub(crate) session: ClientConnection,
|
||||||
pub(crate) state: TlsState,
|
pub(crate) state: TlsState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<IO> TlsStream<IO> {
|
impl<IO> TlsStream<IO> {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_ref(&self) -> (&IO, &ClientSession) {
|
pub fn get_ref(&self) -> (&IO, &ClientConnection) {
|
||||||
(&self.io, &self.session)
|
(&self.io, &self.session)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) {
|
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
|
||||||
(&mut self.io, &mut self.session)
|
(&mut self.io, &mut self.session)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn into_inner(self) -> (IO, ClientSession) {
|
pub fn into_inner(self) -> (IO, ClientConnection) {
|
||||||
(self.io, self.session)
|
(self.io, self.session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<IO> IoSession for TlsStream<IO> {
|
impl<IO> IoSession for TlsStream<IO> {
|
||||||
type Io = IO;
|
type Io = IO;
|
||||||
type Session = ClientSession;
|
type Session = ClientConnection;
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn skip_handshake(&self) -> bool {
|
fn skip_handshake(&self) -> bool {
|
||||||
@ -68,7 +67,7 @@ where
|
|||||||
|
|
||||||
match stream.as_mut_pin().poll_read(cx, buf) {
|
match stream.as_mut_pin().poll_read(cx, buf) {
|
||||||
Poll::Ready(Ok(())) => {
|
Poll::Ready(Ok(())) => {
|
||||||
if prev == buf.remaining() {
|
if prev == buf.remaining() || stream.eof {
|
||||||
this.state.shutdown_read();
|
this.state.shutdown_read();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
use crate::common::{Stream, TlsState};
|
use crate::common::{Stream, TlsState};
|
||||||
use rustls::Session;
|
use rustls::{ConnectionCommon, SideData};
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use std::{io, mem};
|
use std::{io, mem};
|
||||||
@ -15,28 +16,30 @@ pub(crate) trait IoSession {
|
|||||||
fn into_io(self) -> Self::Io;
|
fn into_io(self) -> Self::Io;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) enum MidHandshake<IS> {
|
pub(crate) enum MidHandshake<IS: IoSession> {
|
||||||
Handshaking(IS),
|
Handshaking(IS),
|
||||||
End,
|
End,
|
||||||
|
Error { io: IS::Io, error: io::Error },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<IS> Future for MidHandshake<IS>
|
impl<IS, SD> Future for MidHandshake<IS>
|
||||||
where
|
where
|
||||||
IS: IoSession + Unpin,
|
IS: IoSession + Unpin,
|
||||||
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
IS::Io: AsyncRead + AsyncWrite + Unpin,
|
||||||
IS::Session: Session + Unpin,
|
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin,
|
||||||
|
SD: SideData,
|
||||||
{
|
{
|
||||||
type Output = Result<IS, (io::Error, IS::Io)>;
|
type Output = Result<IS, (io::Error, IS::Io)>;
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
|
||||||
let mut stream =
|
let mut stream = match mem::replace(this, MidHandshake::End) {
|
||||||
if let MidHandshake::Handshaking(stream) = mem::replace(this, MidHandshake::End) {
|
MidHandshake::Handshaking(stream) => stream,
|
||||||
stream
|
// Starting the handshake returned an error; fail the future immediately.
|
||||||
} else {
|
MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))),
|
||||||
panic!("unexpected polling after handshake")
|
_ => panic!("unexpected polling after handshake"),
|
||||||
};
|
};
|
||||||
|
|
||||||
if !stream.skip_handshake() {
|
if !stream.skip_handshake() {
|
||||||
let (state, io, session) = stream.get_mut();
|
let (state, io, session) = stream.get_mut();
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
mod handshake;
|
mod handshake;
|
||||||
|
|
||||||
pub(crate) use handshake::{IoSession, MidHandshake};
|
pub(crate) use handshake::{IoSession, MidHandshake};
|
||||||
use rustls::Session;
|
use rustls::{ConnectionCommon, SideData};
|
||||||
use std::io::{self, IoSlice, Read, Write};
|
use std::io::{self, IoSlice, Read, Write};
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
@ -57,20 +58,26 @@ impl TlsState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Stream<'a, IO, S> {
|
pub struct Stream<'a, IO, C> {
|
||||||
pub io: &'a mut IO,
|
pub io: &'a mut IO,
|
||||||
pub session: &'a mut S,
|
pub session: &'a mut C,
|
||||||
pub eof: bool,
|
pub eof: bool,
|
||||||
|
pub unexpected_eof: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
|
||||||
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
|
where
|
||||||
|
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
|
||||||
|
SD: SideData,
|
||||||
|
{
|
||||||
|
pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
|
||||||
Stream {
|
Stream {
|
||||||
io,
|
io,
|
||||||
session,
|
session,
|
||||||
// The state so far is only used to detect EOF, so either Stream
|
// The state so far is only used to detect EOF, so either Stream
|
||||||
// or EarlyData state should both be all right.
|
// or EarlyData state should both be all right.
|
||||||
eof: false,
|
eof: false,
|
||||||
|
unexpected_eof: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,7 +221,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
|
||||||
|
where
|
||||||
|
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
|
||||||
|
SD: SideData,
|
||||||
|
{
|
||||||
fn poll_read(
|
fn poll_read(
|
||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
@ -223,10 +234,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
|
|||||||
let prev = buf.remaining();
|
let prev = buf.remaining();
|
||||||
|
|
||||||
while buf.remaining() != 0 {
|
while buf.remaining() != 0 {
|
||||||
let mut would_block = false;
|
let mut io_pending = false;
|
||||||
|
|
||||||
// read a packet
|
// read a packet
|
||||||
while self.session.wants_read() {
|
while !self.eof && self.session.wants_read() {
|
||||||
match self.read_io(cx) {
|
match self.read_io(cx) {
|
||||||
Poll::Ready(Ok(0)) => {
|
Poll::Ready(Ok(0)) => {
|
||||||
self.eof = true;
|
self.eof = true;
|
||||||
@ -234,30 +245,51 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
|
|||||||
}
|
}
|
||||||
Poll::Ready(Ok(_)) => (),
|
Poll::Ready(Ok(_)) => (),
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
would_block = true;
|
io_pending = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return match self.session.read(buf.initialize_unfilled()) {
|
return match self.session.reader().read(buf.initialize_unfilled()) {
|
||||||
Ok(0) if prev == buf.remaining() && would_block => Poll::Pending,
|
// If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
|
||||||
|
// connection with a `CloseNotify` message and no more data will be forthcoming.
|
||||||
|
Ok(0) => break,
|
||||||
|
|
||||||
|
// Rustls yielded more data: advance the buffer, then see if more data is coming.
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
buf.advance(n);
|
buf.advance(n);
|
||||||
|
|
||||||
if self.eof || would_block {
|
if self.eof || io_pending {
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(ref err)
|
|
||||||
if err.kind() == io::ErrorKind::ConnectionAborted
|
// Rustls doesn't have more data to yield, but it believes the connection is open.
|
||||||
&& prev != buf.remaining() =>
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||||
{
|
if prev == buf.remaining() && io_pending {
|
||||||
break
|
Poll::Pending
|
||||||
|
} else if self.eof || io_pending {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
|
||||||
|
self.eof = true;
|
||||||
|
self.unexpected_eof = true;
|
||||||
|
if prev == buf.remaining() {
|
||||||
|
Poll::Ready(Err(err))
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should be unreachable.
|
||||||
Err(err) => Poll::Ready(Err(err)),
|
Err(err) => Poll::Ready(Err(err)),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -266,7 +298,11 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
|
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
|
||||||
|
where
|
||||||
|
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
|
||||||
|
SD: SideData,
|
||||||
|
{
|
||||||
fn poll_write(
|
fn poll_write(
|
||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
@ -277,7 +313,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
|
|||||||
while pos != buf.len() {
|
while pos != buf.len() {
|
||||||
let mut would_block = false;
|
let mut would_block = false;
|
||||||
|
|
||||||
match self.session.write(&buf[pos..]) {
|
match self.session.writer().write(&buf[pos..]) {
|
||||||
Ok(n) => pos += n,
|
Ok(n) => pos += n,
|
||||||
Err(err) => return Poll::Ready(Err(err)),
|
Err(err) => return Poll::Ready(Err(err)),
|
||||||
};
|
};
|
||||||
@ -304,7 +340,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
self.session.flush()?;
|
self.session.writer().flush()?;
|
||||||
while self.session.wants_write() {
|
while self.session.wants_write() {
|
||||||
ready!(self.write_io(cx))?;
|
ready!(self.write_io(cx))?;
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
use super::Stream;
|
use super::Stream;
|
||||||
use futures_util::future::poll_fn;
|
use futures_util::future::poll_fn;
|
||||||
use futures_util::task::noop_waker_ref;
|
use futures_util::task::noop_waker_ref;
|
||||||
use rustls::internal::pemfile::{certs, rsa_private_keys};
|
use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection};
|
||||||
use rustls::{ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession, Session};
|
use rustls_pemfile::{certs, rsa_private_keys};
|
||||||
use std::io::{self, BufReader, Cursor, Read, Write};
|
use std::io::{self, BufReader, Cursor, Read, Write};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||||
use webpki::DNSNameRef;
|
|
||||||
|
|
||||||
struct Good<'a>(&'a mut dyn Session);
|
struct Good<'a>(&'a mut Connection);
|
||||||
|
|
||||||
impl<'a> AsyncRead for Good<'a> {
|
impl<'a> AsyncRead for Good<'a> {
|
||||||
fn poll_read(
|
fn poll_read(
|
||||||
@ -50,9 +49,10 @@ impl<'a> AsyncWrite for Good<'a> {
|
|||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
self.0.send_close_notify();
|
self.0.send_close_notify();
|
||||||
Poll::Ready(Ok(()))
|
dbg!("sent close notify");
|
||||||
|
self.poll_flush(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,23 +120,28 @@ impl AsyncWrite for Eof {
|
|||||||
async fn stream_good() -> io::Result<()> {
|
async fn stream_good() -> io::Result<()> {
|
||||||
const FILE: &[u8] = include_bytes!("../../README.md");
|
const FILE: &[u8] = include_bytes!("../../README.md");
|
||||||
|
|
||||||
let (mut server, mut client) = make_pair();
|
let (server, mut client) = make_pair();
|
||||||
|
let mut server = Connection::from(server);
|
||||||
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
|
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
|
||||||
io::copy(&mut Cursor::new(FILE), &mut server)?;
|
io::copy(&mut Cursor::new(FILE), &mut server.writer())?;
|
||||||
|
server.send_close_notify();
|
||||||
|
let mut server = Connection::from(server);
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut good = Good(&mut server);
|
let mut good = Good(&mut server);
|
||||||
let mut stream = Stream::new(&mut good, &mut client);
|
let mut stream = Stream::new(&mut good, &mut client);
|
||||||
|
|
||||||
let mut buf = Vec::new();
|
let mut buf = Vec::new();
|
||||||
stream.read_to_end(&mut buf).await?;
|
dbg!(stream.read_to_end(&mut buf).await)?;
|
||||||
assert_eq!(buf, FILE);
|
assert_eq!(buf, FILE);
|
||||||
stream.write_all(b"Hello World!").await?;
|
dbg!(stream.write_all(b"Hello World!").await)?;
|
||||||
stream.flush().await?;
|
stream.session.send_close_notify();
|
||||||
|
dbg!(stream.shutdown().await)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buf = String::new();
|
let mut buf = String::new();
|
||||||
server.read_to_string(&mut buf)?;
|
dbg!(server.process_new_packets()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||||
|
dbg!(server.reader().read_to_string(&mut buf))?;
|
||||||
assert_eq!(buf, "Hello World!");
|
assert_eq!(buf, "Hello World!");
|
||||||
|
|
||||||
Ok(()) as io::Result<()>
|
Ok(()) as io::Result<()>
|
||||||
@ -144,9 +149,10 @@ async fn stream_good() -> io::Result<()> {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stream_bad() -> io::Result<()> {
|
async fn stream_bad() -> io::Result<()> {
|
||||||
let (mut server, mut client) = make_pair();
|
let (server, mut client) = make_pair();
|
||||||
|
let mut server = Connection::from(server);
|
||||||
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
|
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
|
||||||
client.set_buffer_limit(1024);
|
client.set_buffer_limit(Some(1024));
|
||||||
|
|
||||||
let mut bad = Pending;
|
let mut bad = Pending;
|
||||||
let mut stream = Stream::new(&mut bad, &mut client);
|
let mut stream = Stream::new(&mut bad, &mut client);
|
||||||
@ -170,7 +176,8 @@ async fn stream_bad() -> io::Result<()> {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stream_handshake() -> io::Result<()> {
|
async fn stream_handshake() -> io::Result<()> {
|
||||||
let (mut server, mut client) = make_pair();
|
let (server, mut client) = make_pair();
|
||||||
|
let mut server = Connection::from(server);
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut good = Good(&mut server);
|
let mut good = Good(&mut server);
|
||||||
@ -208,42 +215,72 @@ async fn stream_handshake_eof() -> io::Result<()> {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stream_eof() -> io::Result<()> {
|
async fn stream_eof() -> io::Result<()> {
|
||||||
let (mut server, mut client) = make_pair();
|
let (server, mut client) = make_pair();
|
||||||
|
let mut server = Connection::from(server);
|
||||||
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
|
poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?;
|
||||||
|
|
||||||
let mut good = Good(&mut server);
|
let mut good = Good(&mut server);
|
||||||
let mut stream = Stream::new(&mut good, &mut client).set_eof(true);
|
let mut stream = Stream::new(&mut good, &mut client).set_eof(true);
|
||||||
|
|
||||||
let mut buf = Vec::new();
|
let mut buf = Vec::new();
|
||||||
stream.read_to_end(&mut buf).await?;
|
let result = stream.read_to_end(&mut buf).await;
|
||||||
assert_eq!(buf.len(), 0);
|
assert_eq!(
|
||||||
|
result.err().map(|e| e.kind()),
|
||||||
|
Some(io::ErrorKind::UnexpectedEof)
|
||||||
|
);
|
||||||
|
|
||||||
Ok(()) as io::Result<()>
|
Ok(()) as io::Result<()>
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_pair() -> (ServerSession, ClientSession) {
|
fn make_pair() -> (ServerConnection, ClientConnection) {
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
const CERT: &str = include_str!("../../tests/end.cert");
|
const CERT: &str = include_str!("../../tests/end.cert");
|
||||||
const CHAIN: &str = include_str!("../../tests/end.chain");
|
const CHAIN: &str = include_str!("../../tests/end.chain");
|
||||||
const RSA: &str = include_str!("../../tests/end.rsa");
|
const RSA: &str = include_str!("../../tests/end.rsa");
|
||||||
|
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
|
||||||
|
.unwrap()
|
||||||
|
.drain(..)
|
||||||
|
.map(rustls::Certificate)
|
||||||
|
.collect();
|
||||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||||
let mut sconfig = ServerConfig::new(NoClientAuth::new());
|
let mut keys = keys.drain(..).map(rustls::PrivateKey);
|
||||||
sconfig.set_single_cert(cert, keys.pop().unwrap()).unwrap();
|
let sconfig = rustls::ServerConfig::builder()
|
||||||
let server = ServerSession::new(&Arc::new(sconfig));
|
.with_safe_defaults()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(cert, keys.next().unwrap())
|
||||||
|
.unwrap();
|
||||||
|
let server = ServerConnection::new(Arc::new(sconfig)).unwrap();
|
||||||
|
|
||||||
let domain = DNSNameRef::try_from_ascii_str("localhost").unwrap();
|
let domain = rustls::ServerName::try_from("localhost").unwrap();
|
||||||
let mut cconfig = ClientConfig::new();
|
let mut client_root_cert_store = RootCertStore::empty();
|
||||||
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
||||||
cconfig.root_store.add_pem_file(&mut chain).unwrap();
|
let certs = certs(&mut chain).unwrap();
|
||||||
let client = ClientSession::new(&Arc::new(cconfig), domain);
|
let trust_anchors = certs
|
||||||
|
.iter()
|
||||||
|
.map(|cert| {
|
||||||
|
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||||
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
|
||||||
|
let cconfig = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(client_root_cert_store)
|
||||||
|
.with_no_client_auth();
|
||||||
|
let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap();
|
||||||
|
|
||||||
(server, client)
|
(server, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn do_handshake(
|
fn do_handshake(
|
||||||
client: &mut ClientSession,
|
client: &mut ClientConnection,
|
||||||
server: &mut ServerSession,
|
server: &mut Connection,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
let mut good = Good(server);
|
let mut good = Good(server);
|
||||||
|
@ -14,14 +14,13 @@ mod common;
|
|||||||
pub mod server;
|
pub mod server;
|
||||||
|
|
||||||
use common::{MidHandshake, Stream, TlsState};
|
use common::{MidHandshake, Stream, TlsState};
|
||||||
use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session};
|
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use webpki::DNSNameRef;
|
|
||||||
|
|
||||||
pub use rustls;
|
pub use rustls;
|
||||||
pub use webpki;
|
pub use webpki;
|
||||||
@ -68,19 +67,29 @@ impl TlsConnector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
|
pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
|
||||||
where
|
where
|
||||||
IO: AsyncRead + AsyncWrite + Unpin,
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
self.connect_with(domain, stream, |_| ())
|
self.connect_with(domain, stream, |_| ())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
|
pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
|
||||||
where
|
where
|
||||||
IO: AsyncRead + AsyncWrite + Unpin,
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
F: FnOnce(&mut ClientSession),
|
F: FnOnce(&mut ClientConnection),
|
||||||
{
|
{
|
||||||
let mut session = ClientSession::new(&self.inner, domain);
|
let mut session = match ClientConnection::new(self.inner.clone(), domain) {
|
||||||
|
Ok(session) => session,
|
||||||
|
Err(error) => {
|
||||||
|
return Connect(MidHandshake::Error {
|
||||||
|
io: stream,
|
||||||
|
// TODO(eliza): should this really return an `io::Error`?
|
||||||
|
// Probably not...
|
||||||
|
error: io::Error::new(io::ErrorKind::Other, error),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
f(&mut session);
|
f(&mut session);
|
||||||
|
|
||||||
Connect(MidHandshake::Handshaking(client::TlsStream {
|
Connect(MidHandshake::Handshaking(client::TlsStream {
|
||||||
@ -113,9 +122,19 @@ impl TlsAcceptor {
|
|||||||
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
|
pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
|
||||||
where
|
where
|
||||||
IO: AsyncRead + AsyncWrite + Unpin,
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
F: FnOnce(&mut ServerSession),
|
F: FnOnce(&mut ServerConnection),
|
||||||
{
|
{
|
||||||
let mut session = ServerSession::new(&self.inner);
|
let mut session = match ServerConnection::new(self.inner.clone()) {
|
||||||
|
Ok(session) => session,
|
||||||
|
Err(error) => {
|
||||||
|
return Accept(MidHandshake::Error {
|
||||||
|
io: stream,
|
||||||
|
// TODO(eliza): should this really return an `io::Error`?
|
||||||
|
// Probably not...
|
||||||
|
error: io::Error::new(io::ErrorKind::Other, error),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
f(&mut session);
|
f(&mut session);
|
||||||
|
|
||||||
Accept(MidHandshake::Handshaking(server::TlsStream {
|
Accept(MidHandshake::Handshaking(server::TlsStream {
|
||||||
@ -201,7 +220,7 @@ pub enum TlsStream<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T> TlsStream<T> {
|
impl<T> TlsStream<T> {
|
||||||
pub fn get_ref(&self) -> (&T, &dyn Session) {
|
pub fn get_ref(&self) -> (&T, &CommonState) {
|
||||||
use TlsStream::*;
|
use TlsStream::*;
|
||||||
match self {
|
match self {
|
||||||
Client(io) => {
|
Client(io) => {
|
||||||
@ -215,7 +234,7 @@ impl<T> TlsStream<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) {
|
pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
|
||||||
use TlsStream::*;
|
use TlsStream::*;
|
||||||
match self {
|
match self {
|
||||||
Client(io) => {
|
Client(io) => {
|
||||||
|
@ -1,36 +1,35 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::common::IoSession;
|
use crate::common::IoSession;
|
||||||
use rustls::Session;
|
|
||||||
|
|
||||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||||
/// protocol.
|
/// protocol.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TlsStream<IO> {
|
pub struct TlsStream<IO> {
|
||||||
pub(crate) io: IO,
|
pub(crate) io: IO,
|
||||||
pub(crate) session: ServerSession,
|
pub(crate) session: ServerConnection,
|
||||||
pub(crate) state: TlsState,
|
pub(crate) state: TlsState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<IO> TlsStream<IO> {
|
impl<IO> TlsStream<IO> {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_ref(&self) -> (&IO, &ServerSession) {
|
pub fn get_ref(&self) -> (&IO, &ServerConnection) {
|
||||||
(&self.io, &self.session)
|
(&self.io, &self.session)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) {
|
pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
|
||||||
(&mut self.io, &mut self.session)
|
(&mut self.io, &mut self.session)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn into_inner(self) -> (IO, ServerSession) {
|
pub fn into_inner(self) -> (IO, ServerConnection) {
|
||||||
(self.io, self.session)
|
(self.io, self.session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<IO> IoSession for TlsStream<IO> {
|
impl<IO> IoSession for TlsStream<IO> {
|
||||||
type Io = IO;
|
type Io = IO;
|
||||||
type Session = ServerSession;
|
type Session = ServerConnection;
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn skip_handshake(&self) -> bool {
|
fn skip_handshake(&self) -> bool {
|
||||||
@ -67,7 +66,7 @@ where
|
|||||||
|
|
||||||
match stream.as_mut_pin().poll_read(cx, buf) {
|
match stream.as_mut_pin().poll_read(cx, buf) {
|
||||||
Poll::Ready(Ok(())) => {
|
Poll::Ready(Ok(())) => {
|
||||||
if prev == buf.remaining() {
|
if prev == buf.remaining() || stream.eof {
|
||||||
this.state.shutdown_read();
|
this.state.shutdown_read();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
use rustls::ClientConfig;
|
use std::convert::TryFrom;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::net::ToSocketAddrs;
|
use std::net::ToSocketAddrs;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
use tokio_rustls::{
|
||||||
|
client::TlsStream,
|
||||||
|
rustls::{self, ClientConfig, OwnedTrustAnchor},
|
||||||
|
TlsConnector,
|
||||||
|
};
|
||||||
|
|
||||||
async fn get(
|
async fn get(
|
||||||
config: Arc<ClientConfig>,
|
config: Arc<ClientConfig>,
|
||||||
@ -15,7 +19,7 @@ async fn get(
|
|||||||
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
|
let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain);
|
||||||
|
|
||||||
let addr = (domain, port).to_socket_addrs()?.next().unwrap();
|
let addr = (domain, port).to_socket_addrs()?.next().unwrap();
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str(&domain).unwrap();
|
let domain = rustls::ServerName::try_from(domain).unwrap();
|
||||||
let mut buf = Vec::new();
|
let mut buf = Vec::new();
|
||||||
|
|
||||||
let stream = TcpStream::connect(&addr).await?;
|
let stream = TcpStream::connect(&addr).await?;
|
||||||
@ -29,16 +33,31 @@ async fn get(
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_tls12() -> io::Result<()> {
|
async fn test_tls12() -> io::Result<()> {
|
||||||
let mut config = ClientConfig::new();
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
config
|
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||||
.root_store
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
ta.subject,
|
||||||
config.versions = vec![rustls::ProtocolVersion::TLSv1_2];
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
let config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_default_cipher_suites()
|
||||||
|
.with_safe_default_kx_groups()
|
||||||
|
.with_protocol_versions(&[&rustls::version::TLS12])
|
||||||
|
.unwrap()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let domain = "tls-v1-2.badssl.com";
|
let domain = "tls-v1-2.badssl.com";
|
||||||
|
|
||||||
let (_, output) = get(config.clone(), domain, 1012).await?;
|
let (_, output) = get(config.clone(), domain, 1012).await?;
|
||||||
assert!(output.contains("<title>tls-v1-2.badssl.com</title>"));
|
assert!(
|
||||||
|
output.contains("<title>tls-v1-2.badssl.com</title>"),
|
||||||
|
"failed badssl test, output: {}",
|
||||||
|
output
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -52,15 +71,27 @@ fn test_tls13() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_modern() -> io::Result<()> {
|
async fn test_modern() -> io::Result<()> {
|
||||||
let mut config = ClientConfig::new();
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
config
|
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||||
.root_store
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
let config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let domain = "mozilla-modern.badssl.com";
|
let domain = "mozilla-modern.badssl.com";
|
||||||
|
|
||||||
let (_, output) = get(config.clone(), domain, 443).await?;
|
let (_, output) = get(config.clone(), domain, 443).await?;
|
||||||
assert!(output.contains("<title>mozilla-modern.badssl.com</title>"));
|
assert!(
|
||||||
|
output.contains("<title>mozilla-modern.badssl.com</title>"),
|
||||||
|
"failed badssl test, output: {}",
|
||||||
|
output
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
#![cfg(feature = "early-data")]
|
#![cfg(feature = "early-data")]
|
||||||
|
|
||||||
use futures_util::{future, future::Future, ready};
|
use futures_util::{future, future::Future, ready};
|
||||||
use rustls::ClientConfig;
|
use rustls::RootCertStore;
|
||||||
|
use std::convert::TryFrom;
|
||||||
use std::io::{self, BufRead, BufReader, Cursor};
|
use std::io::{self, BufRead, BufReader, Cursor};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
@ -12,7 +13,11 @@ use std::time::Duration;
|
|||||||
use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
use tokio_rustls::{
|
||||||
|
client::TlsStream,
|
||||||
|
rustls::{self, ClientConfig, OwnedTrustAnchor},
|
||||||
|
TlsConnector,
|
||||||
|
};
|
||||||
|
|
||||||
struct Read1<T>(T);
|
struct Read1<T>(T);
|
||||||
|
|
||||||
@ -34,7 +39,7 @@ async fn send(
|
|||||||
) -> io::Result<TlsStream<TcpStream>> {
|
) -> io::Result<TlsStream<TcpStream>> {
|
||||||
let connector = TlsConnector::from(config).early_data(true);
|
let connector = TlsConnector::from(config).early_data(true);
|
||||||
let stream = TcpStream::connect(&addr).await?;
|
let stream = TcpStream::connect(&addr).await?;
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str("testserver.com").unwrap();
|
let domain = rustls::ServerName::try_from("testserver.com").unwrap();
|
||||||
|
|
||||||
let mut stream = connector.connect(domain, stream).await?;
|
let mut stream = connector.connect(domain, stream).await?;
|
||||||
stream.write_all(data).await?;
|
stream.write_all(data).await?;
|
||||||
@ -81,10 +86,28 @@ async fn test_0rtt() -> io::Result<()> {
|
|||||||
// wait openssl server
|
// wait openssl server
|
||||||
sleep(Duration::from_secs(1)).await;
|
sleep(Duration::from_secs(1)).await;
|
||||||
|
|
||||||
let mut config = ClientConfig::new();
|
|
||||||
let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
|
let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
|
||||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
let certs = rustls_pemfile::certs(&mut chain).unwrap();
|
||||||
config.versions = vec![rustls::ProtocolVersion::TLSv1_3];
|
let trust_anchors = certs
|
||||||
|
.iter()
|
||||||
|
.map(|cert| {
|
||||||
|
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||||
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut root_store = RootCertStore::empty();
|
||||||
|
root_store.add_server_trust_anchors(trust_anchors.into_iter());
|
||||||
|
let mut config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_default_cipher_suites()
|
||||||
|
.with_safe_default_kx_groups()
|
||||||
|
.with_protocol_versions(&[&rustls::version::TLS13])
|
||||||
|
.unwrap()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
config.enable_early_data = true;
|
config.enable_early_data = true;
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let addr = SocketAddr::from(([127, 0, 0, 1], 12354));
|
let addr = SocketAddr::from(([127, 0, 0, 1], 12354));
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
use futures_util::future::TryFutureExt;
|
use futures_util::future::TryFutureExt;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use rustls::internal::pemfile::{certs, rsa_private_keys};
|
use rustls::{ClientConfig, OwnedTrustAnchor};
|
||||||
use rustls::{ClientConfig, ServerConfig};
|
use rustls_pemfile::{certs, rsa_private_keys};
|
||||||
|
use std::convert::TryFrom;
|
||||||
use std::io::{BufReader, Cursor};
|
use std::io::{BufReader, Cursor};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::mpsc::channel;
|
use std::sync::mpsc::channel;
|
||||||
@ -13,18 +14,24 @@ use tokio::runtime;
|
|||||||
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
||||||
|
|
||||||
const CERT: &str = include_str!("end.cert");
|
const CERT: &str = include_str!("end.cert");
|
||||||
const CHAIN: &str = include_str!("end.chain");
|
const CHAIN: &[u8] = include_bytes!("end.chain");
|
||||||
const RSA: &str = include_str!("end.rsa");
|
const RSA: &str = include_str!("end.rsa");
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref TEST_SERVER: (SocketAddr, &'static str, &'static str) = {
|
static ref TEST_SERVER: (SocketAddr, &'static str, &'static [u8]) = {
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
|
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
|
||||||
|
.unwrap()
|
||||||
|
.drain(..)
|
||||||
|
.map(rustls::Certificate)
|
||||||
|
.collect();
|
||||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||||
|
let mut keys = keys.drain(..).map(rustls::PrivateKey);
|
||||||
|
|
||||||
let mut config = ServerConfig::new(rustls::NoClientAuth::new());
|
let config = rustls::ServerConfig::builder()
|
||||||
config
|
.with_safe_defaults()
|
||||||
.set_single_cert(cert, keys.pop().unwrap())
|
.with_no_client_auth()
|
||||||
.expect("invalid key or certificate");
|
.with_single_cert(cert, keys.next().unwrap())
|
||||||
|
.unwrap();
|
||||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||||
|
|
||||||
let (send, recv) = channel();
|
let (send, recv) = channel();
|
||||||
@ -70,14 +77,14 @@ lazy_static! {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_server() -> &'static (SocketAddr, &'static str, &'static str) {
|
fn start_server() -> &'static (SocketAddr, &'static str, &'static [u8]) {
|
||||||
&*TEST_SERVER
|
&*TEST_SERVER
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
|
async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
|
||||||
const FILE: &[u8] = include_bytes!("../README.md");
|
const FILE: &[u8] = include_bytes!("../README.md");
|
||||||
|
|
||||||
let domain = webpki::DNSNameRef::try_from_ascii_str(domain).unwrap();
|
let domain = rustls::ServerName::try_from(domain).unwrap();
|
||||||
let config = TlsConnector::from(config);
|
let config = TlsConnector::from(config);
|
||||||
let mut buf = vec![0; FILE.len()];
|
let mut buf = vec![0; FILE.len()];
|
||||||
|
|
||||||
@ -102,12 +109,27 @@ async fn pass() -> io::Result<()> {
|
|||||||
use std::time::*;
|
use std::time::*;
|
||||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
|
|
||||||
let mut config = ClientConfig::new();
|
let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap();
|
||||||
let mut chain = BufReader::new(Cursor::new(chain));
|
let trust_anchors = chain
|
||||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
.iter()
|
||||||
|
.map(|cert| {
|
||||||
|
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||||
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
|
root_store.add_server_trust_anchors(trust_anchors.into_iter());
|
||||||
|
let config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
|
|
||||||
start_client(*addr, domain, config.clone()).await?;
|
start_client(*addr, domain, config).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -116,9 +138,24 @@ async fn pass() -> io::Result<()> {
|
|||||||
async fn fail() -> io::Result<()> {
|
async fn fail() -> io::Result<()> {
|
||||||
let (addr, domain, chain) = start_server();
|
let (addr, domain, chain) = start_server();
|
||||||
|
|
||||||
let mut config = ClientConfig::new();
|
let chain = certs(&mut std::io::Cursor::new(*chain)).unwrap();
|
||||||
let mut chain = BufReader::new(Cursor::new(chain));
|
let trust_anchors = chain
|
||||||
config.root_store.add_pem_file(&mut chain).unwrap();
|
.iter()
|
||||||
|
.map(|cert| {
|
||||||
|
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||||
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
|
root_store.add_server_trust_anchors(trust_anchors.into_iter());
|
||||||
|
let config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_no_client_auth();
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
|
|
||||||
assert_ne!(domain, &"google.com");
|
assert_ne!(domain, &"google.com");
|
||||||
|
Loading…
Reference in New Issue
Block a user