Implement TransparentConfigAcceptor

The goal of the TransparentConfigAcceptor is to support an SNI-based
reverse-proxy, where the server reads the SNI and then transparently
forwards the entire TLS session, ClientHello included, to a backend
server, without terminating the TLS session itself.

This isn't possible with the current LazyConfigAcceptor, which only
allows you to pick a different ServerConfig depending on the SNI, but
will always terminate the session.

The TransparentConfigAcceptor will buffer all bytes read from the
connection (the ClientHello) internally, and then replay them if the
user decides they want to hijack the connection.

The TransparentConfigAcceptor supports all functionality that the
LazyConfigAcceptor does, but due to the internal buffering of the
ClientHello I did not want to add it to the LazyConfigAcceptor, since
it's possible someone wouldn't want to incur that extra cost.
This commit is contained in:
Brian Picciano 2023-07-22 13:40:19 +02:00
parent b7289d7e7e
commit 18fd688b33
2 changed files with 236 additions and 2 deletions

View File

@ -359,5 +359,147 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
}
}
/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements
/// AsyncRead + AsyncWrite.
pub struct AsyncReadWrite<R, W> {
r: Pin<Box<R>>,
w: Pin<Box<W>>,
}
impl<R, W> AsyncReadWrite<R, W>
where
R: Unpin,
W: Unpin,
{
pub fn new(r: R, w: W) -> Self {
Self {
r: Box::pin(r),
w: Box::pin(w),
}
}
pub fn into_inner(self) -> (R, W) {
(*Pin::into_inner(self.r), *Pin::into_inner(self.w))
}
}
impl<R, W> AsyncRead for AsyncReadWrite<R, W>
where
R: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.r.as_mut().poll_read(cx, buf)
}
}
impl<R, W> AsyncWrite for AsyncReadWrite<R, W>
where
W: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.w.as_mut().poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.w.as_mut().poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.w.as_mut().poll_shutdown(cx)
}
}
/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal
/// buffer.
pub struct AsyncReadCapture<R> {
r: Pin<Box<R>>,
buf: Vec<u8>,
}
impl<R> AsyncReadCapture<R>
where
R: AsyncRead + Unpin,
{
/// Initializes an AsyncReadCapture with an empty internal buffer of the given size.
pub fn with_capacity(r: R, cap: usize) -> Self {
Self {
r: Box::pin(r),
buf: Vec::with_capacity(cap),
}
}
pub fn into_inner(self) -> (R, Vec<u8>) {
(*Pin::into_inner(self.r), self.buf)
}
}
impl<R> AsyncRead for AsyncReadCapture<R>
where
R: AsyncRead,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let res = self.r.as_mut().poll_read(cx, buf);
if let Poll::Ready(Ok(())) = res {
self.buf.extend_from_slice(buf.filled());
}
res
}
}
pub struct AsyncReadPrefixed<R> {
r: Pin<Box<R>>,
prefix: Vec<u8>,
}
impl<R> AsyncReadPrefixed<R>
where
R: AsyncRead + Unpin,
{
pub fn new(r: R, prefix: Vec<u8>) -> Self {
Self {
r: Box::pin(r),
prefix,
}
}
}
impl<R> AsyncRead for AsyncReadPrefixed<R>
where
R: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let prefix_len = this.prefix.len();
if prefix_len == 0 {
return this.r.as_mut().poll_read(cx, buf);
}
let n = std::cmp::min(prefix_len, buf.remaining());
let to_write = this.prefix.drain(..n);
buf.put_slice(to_write.as_slice());
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
mod test_stream;

View File

@ -49,7 +49,7 @@ pub mod client;
mod common;
pub mod server;
use common::{MidHandshake, Stream, TlsState};
use common::{AsyncReadCapture, AsyncReadPrefixed, AsyncReadWrite, MidHandshake, Stream, TlsState};
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use std::future::Future;
use std::io;
@ -60,7 +60,7 @@ use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
pub use rustls;
@ -333,6 +333,98 @@ where
}
}
type IOWithCapture<IO> = AsyncReadWrite<AsyncReadCapture<ReadHalf<IO>>, WriteHalf<IO>>;
type IOWithPrefix<IO> = AsyncReadWrite<AsyncReadPrefixed<ReadHalf<IO>>, WriteHalf<IO>>;
fn unwrap_io_with_capture<IO>(
io_with_capture: IOWithCapture<IO>,
) -> (ReadHalf<IO>, WriteHalf<IO>, Vec<u8>)
where
IO: AsyncRead + AsyncWrite + Unpin,
{
let (r, w) = io_with_capture.into_inner();
let (r, bytes_read) = r.into_inner();
(r, w, bytes_read)
}
pub struct TransparentConfigAcceptor<IO> {
acceptor: Pin<Box<LazyConfigAcceptor<IOWithCapture<IO>>>>,
}
impl<IO> TransparentConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
let (r, w) = tokio::io::split(io);
let r = AsyncReadCapture::with_capacity(r, 1024);
let rw = AsyncReadWrite::new(r, w);
Self {
acceptor: Box::pin(LazyConfigAcceptor::new(acceptor, rw)),
}
}
}
impl<IO> Future for TransparentConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
type Output = io::Result<TransparentStartHandshake<IO>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut().acceptor.as_mut().poll(cx) {
Poll::Ready(Ok(h)) => {
let (r, w, bytes_read) = unwrap_io_with_capture(h.io);
Poll::Ready(Ok(TransparentStartHandshake {
accepted: h.accepted,
r,
w,
bytes_read,
}))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
pub struct TransparentStartHandshake<IO> {
accepted: rustls::server::Accepted,
r: ReadHalf<IO>,
w: WriteHalf<IO>,
bytes_read: Vec<u8>,
}
impl<IO> TransparentStartHandshake<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
self.accepted.client_hello()
}
pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
self.into_stream_with(config, |_| ())
}
pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
where
F: FnOnce(&mut ServerConnection),
{
let start_handshake = StartHandshake {
accepted: self.accepted,
io: self.r.unsplit(self.w),
};
start_handshake.into_stream_with(config, f)
}
pub fn into_original_stream(self) -> IOWithPrefix<IO> {
let r = AsyncReadPrefixed::new(self.r, self.bytes_read);
AsyncReadWrite::new(r, self.w)
}
}
/// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);