Add LazyConfigAcceptor API (#69)
This commit is contained in:
parent
48caaf751f
commit
33506018e7
@ -1,11 +1,9 @@
|
|||||||
use super::Stream;
|
use super::Stream;
|
||||||
use futures_util::future::poll_fn;
|
use futures_util::future::poll_fn;
|
||||||
use futures_util::task::noop_waker_ref;
|
use futures_util::task::noop_waker_ref;
|
||||||
use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection};
|
use rustls::{ClientConnection, Connection, ServerConnection};
|
||||||
use rustls_pemfile::{certs, rsa_private_keys};
|
use std::io::{self, Cursor, Read, Write};
|
||||||
use std::io::{self, BufReader, Cursor, Read, Write};
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||||
|
|
||||||
@ -261,45 +259,11 @@ async fn stream_eof() -> io::Result<()> {
|
|||||||
fn make_pair() -> (ServerConnection, ClientConnection) {
|
fn make_pair() -> (ServerConnection, ClientConnection) {
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
const CERT: &str = include_str!("../../tests/end.cert");
|
let (sconfig, cconfig) = utils::make_configs();
|
||||||
const CHAIN: &str = include_str!("../../tests/end.chain");
|
let server = ServerConnection::new(sconfig).unwrap();
|
||||||
const RSA: &str = include_str!("../../tests/end.rsa");
|
|
||||||
|
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
|
|
||||||
.unwrap()
|
|
||||||
.drain(..)
|
|
||||||
.map(rustls::Certificate)
|
|
||||||
.collect();
|
|
||||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
|
||||||
let mut keys = keys.drain(..).map(rustls::PrivateKey);
|
|
||||||
let sconfig = rustls::ServerConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_no_client_auth()
|
|
||||||
.with_single_cert(cert, keys.next().unwrap())
|
|
||||||
.unwrap();
|
|
||||||
let server = ServerConnection::new(Arc::new(sconfig)).unwrap();
|
|
||||||
|
|
||||||
let domain = rustls::ServerName::try_from("localhost").unwrap();
|
let domain = rustls::ServerName::try_from("localhost").unwrap();
|
||||||
let mut client_root_cert_store = RootCertStore::empty();
|
let client = ClientConnection::new(cconfig, domain).unwrap();
|
||||||
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
|
||||||
let certs = certs(&mut chain).unwrap();
|
|
||||||
let trust_anchors = certs
|
|
||||||
.iter()
|
|
||||||
.map(|cert| {
|
|
||||||
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
|
||||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
|
||||||
ta.subject,
|
|
||||||
ta.spki,
|
|
||||||
ta.name_constraints,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
|
|
||||||
let cconfig = rustls::ClientConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_root_certificates(client_root_cert_store)
|
|
||||||
.with_no_client_auth();
|
|
||||||
let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap();
|
|
||||||
|
|
||||||
(server, client)
|
(server, client)
|
||||||
}
|
}
|
||||||
@ -322,3 +286,6 @@ fn do_handshake(
|
|||||||
|
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Share `utils` module with integration tests
|
||||||
|
include!("../../tests/utils.rs");
|
||||||
|
@ -188,6 +188,124 @@ impl TlsAcceptor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct LazyConfigAcceptor<IO> {
|
||||||
|
acceptor: rustls::server::Acceptor,
|
||||||
|
buf: Vec<u8>,
|
||||||
|
used: usize,
|
||||||
|
io: Option<IO>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<IO> LazyConfigAcceptor<IO>
|
||||||
|
where
|
||||||
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
#[inline]
|
||||||
|
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
|
||||||
|
Self {
|
||||||
|
acceptor,
|
||||||
|
buf: vec![0; 512],
|
||||||
|
used: 0,
|
||||||
|
io: Some(io),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<IO> Future for LazyConfigAcceptor<IO>
|
||||||
|
where
|
||||||
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
type Output = Result<StartHandshake<IO>, io::Error>;
|
||||||
|
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
loop {
|
||||||
|
let io = match this.io.as_mut() {
|
||||||
|
Some(io) => io,
|
||||||
|
None => {
|
||||||
|
return Poll::Ready(Err(io::Error::new(
|
||||||
|
io::ErrorKind::Other,
|
||||||
|
"acceptor cannot be polled after acceptance",
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut buf = ReadBuf::new(&mut this.buf);
|
||||||
|
buf.advance(this.used);
|
||||||
|
if buf.remaining() > 0 {
|
||||||
|
if let Err(err) = ready!(Pin::new(io).poll_read(cx, &mut buf)) {
|
||||||
|
return Poll::Ready(Err(err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let read = match this.acceptor.read_tls(&mut buf.filled()) {
|
||||||
|
Ok(read) => read,
|
||||||
|
Err(err) => return Poll::Ready(Err(err)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let received = buf.filled().len();
|
||||||
|
if read < received {
|
||||||
|
this.buf.copy_within(read.., 0);
|
||||||
|
this.used = received - read;
|
||||||
|
} else {
|
||||||
|
this.used = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
match this.acceptor.accept() {
|
||||||
|
Ok(Some(accepted)) => {
|
||||||
|
let io = this.io.take().unwrap();
|
||||||
|
return Poll::Ready(Ok(StartHandshake { accepted, io }));
|
||||||
|
}
|
||||||
|
Ok(None) => continue,
|
||||||
|
Err(err) => {
|
||||||
|
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct StartHandshake<IO> {
|
||||||
|
accepted: rustls::server::Accepted,
|
||||||
|
io: IO,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<IO> StartHandshake<IO>
|
||||||
|
where
|
||||||
|
IO: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
|
||||||
|
self.accepted.client_hello()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
|
||||||
|
self.into_stream_with(config, |_| ())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
|
||||||
|
where
|
||||||
|
F: FnOnce(&mut ServerConnection),
|
||||||
|
{
|
||||||
|
let mut conn = match self.accepted.into_connection(config) {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(error) => {
|
||||||
|
return Accept(MidHandshake::Error {
|
||||||
|
io: self.io,
|
||||||
|
// TODO(eliza): should this really return an `io::Error`?
|
||||||
|
// Probably not...
|
||||||
|
error: io::Error::new(io::ErrorKind::Other, error),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
f(&mut conn);
|
||||||
|
|
||||||
|
Accept(MidHandshake::Handshaking(server::TlsStream {
|
||||||
|
session: conn,
|
||||||
|
io: self.io,
|
||||||
|
state: TlsState::Stream,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Future returned from `TlsConnector::connect` which will resolve
|
/// Future returned from `TlsConnector::connect` which will resolve
|
||||||
/// once the connection handshake has finished.
|
/// once the connection handshake has finished.
|
||||||
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
|
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
|
||||||
|
@ -11,7 +11,7 @@ use std::{io, thread};
|
|||||||
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::runtime;
|
use tokio::runtime;
|
||||||
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector};
|
||||||
|
|
||||||
const CERT: &str = include_str!("end.cert");
|
const CERT: &str = include_str!("end.cert");
|
||||||
const CHAIN: &[u8] = include_bytes!("end.chain");
|
const CHAIN: &[u8] = include_bytes!("end.chain");
|
||||||
@ -164,3 +164,43 @@ async fn fail() -> io::Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_lazy_config_acceptor() -> io::Result<()> {
|
||||||
|
let (sconfig, cconfig) = utils::make_configs();
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
|
let (cstream, sstream) = tokio::io::duplex(1200);
|
||||||
|
let domain = rustls::ServerName::try_from("localhost").unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let connector = crate::TlsConnector::from(cconfig);
|
||||||
|
let mut client = connector.connect(domain, cstream).await.unwrap();
|
||||||
|
client.write_all(b"hello, world!").await.unwrap();
|
||||||
|
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
client.read_to_end(&mut buf).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::new().unwrap(), sstream);
|
||||||
|
let start = acceptor.await.unwrap();
|
||||||
|
let ch = start.client_hello();
|
||||||
|
|
||||||
|
assert_eq!(ch.server_name(), Some("localhost"));
|
||||||
|
assert_eq!(
|
||||||
|
ch.alpn()
|
||||||
|
.map(|protos| protos.collect::<Vec<_>>())
|
||||||
|
.unwrap_or(Vec::new()),
|
||||||
|
Vec::<&[u8]>::new()
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut stream = start.into_stream(sconfig).await.unwrap();
|
||||||
|
let mut buf = [0; 13];
|
||||||
|
stream.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf[..], b"hello, world!");
|
||||||
|
|
||||||
|
stream.write_all(b"bye").await.unwrap();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include `utils` module
|
||||||
|
include!("utils.rs");
|
||||||
|
49
tokio-rustls/tests/utils.rs
Normal file
49
tokio-rustls/tests/utils.rs
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
mod utils {
|
||||||
|
use std::io::{BufReader, Cursor};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use rustls::{ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerConfig};
|
||||||
|
use rustls_pemfile::{certs, rsa_private_keys};
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn make_configs() -> (Arc<ServerConfig>, Arc<ClientConfig>) {
|
||||||
|
const CERT: &str = include_str!("end.cert");
|
||||||
|
const CHAIN: &str = include_str!("end.chain");
|
||||||
|
const RSA: &str = include_str!("end.rsa");
|
||||||
|
|
||||||
|
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
|
||||||
|
.unwrap()
|
||||||
|
.drain(..)
|
||||||
|
.map(rustls::Certificate)
|
||||||
|
.collect();
|
||||||
|
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||||
|
let mut keys = keys.drain(..).map(PrivateKey);
|
||||||
|
let sconfig = ServerConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_single_cert(cert, keys.next().unwrap())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut client_root_cert_store = RootCertStore::empty();
|
||||||
|
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
||||||
|
let certs = certs(&mut chain).unwrap();
|
||||||
|
let trust_anchors = certs
|
||||||
|
.iter()
|
||||||
|
.map(|cert| {
|
||||||
|
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||||
|
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||||
|
ta.subject,
|
||||||
|
ta.spki,
|
||||||
|
ta.name_constraints,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
|
||||||
|
let cconfig = ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(client_root_cert_store)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
(Arc::new(sconfig), Arc::new(cconfig))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user