Add LazyConfigAcceptor API (#69)

This commit is contained in:
Dirkjan Ochtman 2021-10-30 08:10:58 +02:00 committed by GitHub
parent 48caaf751f
commit 33506018e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 216 additions and 42 deletions

View File

@ -1,11 +1,9 @@
use super::Stream; use super::Stream;
use futures_util::future::poll_fn; use futures_util::future::poll_fn;
use futures_util::task::noop_waker_ref; use futures_util::task::noop_waker_ref;
use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection}; use rustls::{ClientConnection, Connection, ServerConnection};
use rustls_pemfile::{certs, rsa_private_keys}; use std::io::{self, Cursor, Read, Write};
use std::io::{self, BufReader, Cursor, Read, Write};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
@ -261,45 +259,11 @@ async fn stream_eof() -> io::Result<()> {
fn make_pair() -> (ServerConnection, ClientConnection) { fn make_pair() -> (ServerConnection, ClientConnection) {
use std::convert::TryFrom; use std::convert::TryFrom;
const CERT: &str = include_str!("../../tests/end.cert"); let (sconfig, cconfig) = utils::make_configs();
const CHAIN: &str = include_str!("../../tests/end.chain"); let server = ServerConnection::new(sconfig).unwrap();
const RSA: &str = include_str!("../../tests/end.rsa");
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.unwrap()
.drain(..)
.map(rustls::Certificate)
.collect();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut keys = keys.drain(..).map(rustls::PrivateKey);
let sconfig = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, keys.next().unwrap())
.unwrap();
let server = ServerConnection::new(Arc::new(sconfig)).unwrap();
let domain = rustls::ServerName::try_from("localhost").unwrap(); let domain = rustls::ServerName::try_from("localhost").unwrap();
let mut client_root_cert_store = RootCertStore::empty(); let client = ClientConnection::new(cconfig, domain).unwrap();
let mut chain = BufReader::new(Cursor::new(CHAIN));
let certs = certs(&mut chain).unwrap();
let trust_anchors = certs
.iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
let cconfig = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(client_root_cert_store)
.with_no_client_auth();
let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap();
(server, client) (server, client)
} }
@ -322,3 +286,6 @@ fn do_handshake(
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
// Share `utils` module with integration tests
include!("../../tests/utils.rs");

View File

@ -188,6 +188,124 @@ impl TlsAcceptor {
} }
} }
pub struct LazyConfigAcceptor<IO> {
acceptor: rustls::server::Acceptor,
buf: Vec<u8>,
used: usize,
io: Option<IO>,
}
impl<IO> LazyConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
#[inline]
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
Self {
acceptor,
buf: vec![0; 512],
used: 0,
io: Some(io),
}
}
}
impl<IO> Future for LazyConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<StartHandshake<IO>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
let io = match this.io.as_mut() {
Some(io) => io,
None => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"acceptor cannot be polled after acceptance",
)))
}
};
let mut buf = ReadBuf::new(&mut this.buf);
buf.advance(this.used);
if buf.remaining() > 0 {
if let Err(err) = ready!(Pin::new(io).poll_read(cx, &mut buf)) {
return Poll::Ready(Err(err));
}
}
let read = match this.acceptor.read_tls(&mut buf.filled()) {
Ok(read) => read,
Err(err) => return Poll::Ready(Err(err)),
};
let received = buf.filled().len();
if read < received {
this.buf.copy_within(read.., 0);
this.used = received - read;
} else {
this.used = 0;
}
match this.acceptor.accept() {
Ok(Some(accepted)) => {
let io = this.io.take().unwrap();
return Poll::Ready(Ok(StartHandshake { accepted, io }));
}
Ok(None) => continue,
Err(err) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
}
}
}
}
}
pub struct StartHandshake<IO> {
accepted: rustls::server::Accepted,
io: IO,
}
impl<IO> StartHandshake<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
self.accepted.client_hello()
}
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
self.into_stream_with(config, |_| ())
}
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
where
F: FnOnce(&mut ServerConnection),
{
let mut conn = match self.accepted.into_connection(config) {
Ok(conn) => conn,
Err(error) => {
return Accept(MidHandshake::Error {
io: self.io,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
});
}
};
f(&mut conn);
Accept(MidHandshake::Handshaking(server::TlsStream {
session: conn,
io: self.io,
state: TlsState::Stream,
}))
}
}
/// Future returned from `TlsConnector::connect` which will resolve /// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished. /// once the connection handshake has finished.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>); pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);

View File

@ -11,7 +11,7 @@ use std::{io, thread};
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt}; use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::runtime; use tokio::runtime;
use tokio_rustls::{TlsAcceptor, TlsConnector}; use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector};
const CERT: &str = include_str!("end.cert"); const CERT: &str = include_str!("end.cert");
const CHAIN: &[u8] = include_bytes!("end.chain"); const CHAIN: &[u8] = include_bytes!("end.chain");
@ -164,3 +164,43 @@ async fn fail() -> io::Result<()> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_lazy_config_acceptor() -> io::Result<()> {
let (sconfig, cconfig) = utils::make_configs();
use std::convert::TryFrom;
let (cstream, sstream) = tokio::io::duplex(1200);
let domain = rustls::ServerName::try_from("localhost").unwrap();
tokio::spawn(async move {
let connector = crate::TlsConnector::from(cconfig);
let mut client = connector.connect(domain, cstream).await.unwrap();
client.write_all(b"hello, world!").await.unwrap();
let mut buf = Vec::new();
client.read_to_end(&mut buf).await.unwrap();
});
let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::new().unwrap(), sstream);
let start = acceptor.await.unwrap();
let ch = start.client_hello();
assert_eq!(ch.server_name(), Some("localhost"));
assert_eq!(
ch.alpn()
.map(|protos| protos.collect::<Vec<_>>())
.unwrap_or(Vec::new()),
Vec::<&[u8]>::new()
);
let mut stream = start.into_stream(sconfig).await.unwrap();
let mut buf = [0; 13];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf[..], b"hello, world!");
stream.write_all(b"bye").await.unwrap();
Ok(())
}
// Include `utils` module
include!("utils.rs");

View File

@ -0,0 +1,49 @@
mod utils {
use std::io::{BufReader, Cursor};
use std::sync::Arc;
use rustls::{ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerConfig};
use rustls_pemfile::{certs, rsa_private_keys};
#[allow(dead_code)]
pub fn make_configs() -> (Arc<ServerConfig>, Arc<ClientConfig>) {
const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain");
const RSA: &str = include_str!("end.rsa");
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.unwrap()
.drain(..)
.map(rustls::Certificate)
.collect();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut keys = keys.drain(..).map(PrivateKey);
let sconfig = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, keys.next().unwrap())
.unwrap();
let mut client_root_cert_store = RootCertStore::empty();
let mut chain = BufReader::new(Cursor::new(CHAIN));
let certs = certs(&mut chain).unwrap();
let trust_anchors = certs
.iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
let cconfig = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(client_root_cert_store)
.with_no_client_auth();
(Arc::new(sconfig), Arc::new(cconfig))
}
}