Add LazyConfigAcceptor API (#69)
This commit is contained in:
parent
48caaf751f
commit
33506018e7
@ -1,11 +1,9 @@
|
||||
use super::Stream;
|
||||
use futures_util::future::poll_fn;
|
||||
use futures_util::task::noop_waker_ref;
|
||||
use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection};
|
||||
use rustls_pemfile::{certs, rsa_private_keys};
|
||||
use std::io::{self, BufReader, Cursor, Read, Write};
|
||||
use rustls::{ClientConnection, Connection, ServerConnection};
|
||||
use std::io::{self, Cursor, Read, Write};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
|
||||
@ -261,45 +259,11 @@ async fn stream_eof() -> io::Result<()> {
|
||||
fn make_pair() -> (ServerConnection, ClientConnection) {
|
||||
use std::convert::TryFrom;
|
||||
|
||||
const CERT: &str = include_str!("../../tests/end.cert");
|
||||
const CHAIN: &str = include_str!("../../tests/end.chain");
|
||||
const RSA: &str = include_str!("../../tests/end.rsa");
|
||||
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
|
||||
.unwrap()
|
||||
.drain(..)
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
let mut keys = keys.drain(..).map(rustls::PrivateKey);
|
||||
let sconfig = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert, keys.next().unwrap())
|
||||
.unwrap();
|
||||
let server = ServerConnection::new(Arc::new(sconfig)).unwrap();
|
||||
let (sconfig, cconfig) = utils::make_configs();
|
||||
let server = ServerConnection::new(sconfig).unwrap();
|
||||
|
||||
let domain = rustls::ServerName::try_from("localhost").unwrap();
|
||||
let mut client_root_cert_store = RootCertStore::empty();
|
||||
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
||||
let certs = certs(&mut chain).unwrap();
|
||||
let trust_anchors = certs
|
||||
.iter()
|
||||
.map(|cert| {
|
||||
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
|
||||
let cconfig = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(client_root_cert_store)
|
||||
.with_no_client_auth();
|
||||
let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap();
|
||||
let client = ClientConnection::new(cconfig, domain).unwrap();
|
||||
|
||||
(server, client)
|
||||
}
|
||||
@ -322,3 +286,6 @@ fn do_handshake(
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
// Share `utils` module with integration tests
|
||||
include!("../../tests/utils.rs");
|
||||
|
@ -188,6 +188,124 @@ impl TlsAcceptor {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LazyConfigAcceptor<IO> {
|
||||
acceptor: rustls::server::Acceptor,
|
||||
buf: Vec<u8>,
|
||||
used: usize,
|
||||
io: Option<IO>,
|
||||
}
|
||||
|
||||
impl<IO> LazyConfigAcceptor<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
#[inline]
|
||||
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
|
||||
Self {
|
||||
acceptor,
|
||||
buf: vec![0; 512],
|
||||
used: 0,
|
||||
io: Some(io),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO> Future for LazyConfigAcceptor<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = Result<StartHandshake<IO>, io::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
loop {
|
||||
let io = match this.io.as_mut() {
|
||||
Some(io) => io,
|
||||
None => {
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"acceptor cannot be polled after acceptance",
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let mut buf = ReadBuf::new(&mut this.buf);
|
||||
buf.advance(this.used);
|
||||
if buf.remaining() > 0 {
|
||||
if let Err(err) = ready!(Pin::new(io).poll_read(cx, &mut buf)) {
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
}
|
||||
|
||||
let read = match this.acceptor.read_tls(&mut buf.filled()) {
|
||||
Ok(read) => read,
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
};
|
||||
|
||||
let received = buf.filled().len();
|
||||
if read < received {
|
||||
this.buf.copy_within(read.., 0);
|
||||
this.used = received - read;
|
||||
} else {
|
||||
this.used = 0;
|
||||
}
|
||||
|
||||
match this.acceptor.accept() {
|
||||
Ok(Some(accepted)) => {
|
||||
let io = this.io.take().unwrap();
|
||||
return Poll::Ready(Ok(StartHandshake { accepted, io }));
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StartHandshake<IO> {
|
||||
accepted: rustls::server::Accepted,
|
||||
io: IO,
|
||||
}
|
||||
|
||||
impl<IO> StartHandshake<IO>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
|
||||
self.accepted.client_hello()
|
||||
}
|
||||
|
||||
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
|
||||
self.into_stream_with(config, |_| ())
|
||||
}
|
||||
|
||||
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
|
||||
where
|
||||
F: FnOnce(&mut ServerConnection),
|
||||
{
|
||||
let mut conn = match self.accepted.into_connection(config) {
|
||||
Ok(conn) => conn,
|
||||
Err(error) => {
|
||||
return Accept(MidHandshake::Error {
|
||||
io: self.io,
|
||||
// TODO(eliza): should this really return an `io::Error`?
|
||||
// Probably not...
|
||||
error: io::Error::new(io::ErrorKind::Other, error),
|
||||
});
|
||||
}
|
||||
};
|
||||
f(&mut conn);
|
||||
|
||||
Accept(MidHandshake::Handshaking(server::TlsStream {
|
||||
session: conn,
|
||||
io: self.io,
|
||||
state: TlsState::Stream,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Future returned from `TlsConnector::connect` which will resolve
|
||||
/// once the connection handshake has finished.
|
||||
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
|
||||
|
@ -11,7 +11,7 @@ use std::{io, thread};
|
||||
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::runtime;
|
||||
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
||||
use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector};
|
||||
|
||||
const CERT: &str = include_str!("end.cert");
|
||||
const CHAIN: &[u8] = include_bytes!("end.chain");
|
||||
@ -164,3 +164,43 @@ async fn fail() -> io::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_lazy_config_acceptor() -> io::Result<()> {
|
||||
let (sconfig, cconfig) = utils::make_configs();
|
||||
use std::convert::TryFrom;
|
||||
|
||||
let (cstream, sstream) = tokio::io::duplex(1200);
|
||||
let domain = rustls::ServerName::try_from("localhost").unwrap();
|
||||
tokio::spawn(async move {
|
||||
let connector = crate::TlsConnector::from(cconfig);
|
||||
let mut client = connector.connect(domain, cstream).await.unwrap();
|
||||
client.write_all(b"hello, world!").await.unwrap();
|
||||
|
||||
let mut buf = Vec::new();
|
||||
client.read_to_end(&mut buf).await.unwrap();
|
||||
});
|
||||
|
||||
let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::new().unwrap(), sstream);
|
||||
let start = acceptor.await.unwrap();
|
||||
let ch = start.client_hello();
|
||||
|
||||
assert_eq!(ch.server_name(), Some("localhost"));
|
||||
assert_eq!(
|
||||
ch.alpn()
|
||||
.map(|protos| protos.collect::<Vec<_>>())
|
||||
.unwrap_or(Vec::new()),
|
||||
Vec::<&[u8]>::new()
|
||||
);
|
||||
|
||||
let mut stream = start.into_stream(sconfig).await.unwrap();
|
||||
let mut buf = [0; 13];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..], b"hello, world!");
|
||||
|
||||
stream.write_all(b"bye").await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Include `utils` module
|
||||
include!("utils.rs");
|
||||
|
49
tokio-rustls/tests/utils.rs
Normal file
49
tokio-rustls/tests/utils.rs
Normal file
@ -0,0 +1,49 @@
|
||||
mod utils {
|
||||
use std::io::{BufReader, Cursor};
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustls::{ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerConfig};
|
||||
use rustls_pemfile::{certs, rsa_private_keys};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn make_configs() -> (Arc<ServerConfig>, Arc<ClientConfig>) {
|
||||
const CERT: &str = include_str!("end.cert");
|
||||
const CHAIN: &str = include_str!("end.chain");
|
||||
const RSA: &str = include_str!("end.rsa");
|
||||
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
|
||||
.unwrap()
|
||||
.drain(..)
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
|
||||
let mut keys = keys.drain(..).map(PrivateKey);
|
||||
let sconfig = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert, keys.next().unwrap())
|
||||
.unwrap();
|
||||
|
||||
let mut client_root_cert_store = RootCertStore::empty();
|
||||
let mut chain = BufReader::new(Cursor::new(CHAIN));
|
||||
let certs = certs(&mut chain).unwrap();
|
||||
let trust_anchors = certs
|
||||
.iter()
|
||||
.map(|cert| {
|
||||
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
|
||||
let cconfig = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(client_root_cert_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
(Arc::new(sconfig), Arc::new(cconfig))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user