From 18fd688b335430e17e054e15ff7d6ce073db2419 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sat, 22 Jul 2023 13:40:19 +0200 Subject: [PATCH] Implement TransparentConfigAcceptor The goal of the TransparentConfigAcceptor is to support an SNI-based reverse-proxy, where the server reads the SNI and then transparently forwards the entire TLS session, ClientHello included, to a backend server, without terminating the TLS session itself. This isn't possible with the current LazyConfigAcceptor, which only allows you to pick a different ServerConfig depending on the SNI, but will always terminate the session. The TransparentConfigAcceptor will buffer all bytes read from the connection (the ClientHello) internally, and then replay them if the user decides they want to hijack the connection. The TransparentConfigAcceptor supports all functionality that the LazyConfigAcceptor does, but due to the internal buffering of the ClientHello I did not want to add it to the LazyConfigAcceptor, since it's possible someone wouldn't want to incur that extra cost. --- src/common/mod.rs | 142 ++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 96 ++++++++++++++++++++++++++++++- 2 files changed, 236 insertions(+), 2 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index fde34c0..1cc976a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -359,5 +359,147 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { } } +/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements +/// AsyncRead + AsyncWrite. +pub struct AsyncReadWrite { + r: Pin>, + w: Pin>, +} + +impl AsyncReadWrite +where + R: Unpin, + W: Unpin, +{ + pub fn new(r: R, w: W) -> Self { + Self { + r: Box::pin(r), + w: Box::pin(w), + } + } + + pub fn into_inner(self) -> (R, W) { + (*Pin::into_inner(self.r), *Pin::into_inner(self.w)) + } +} + +impl AsyncRead for AsyncReadWrite +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.r.as_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncReadWrite +where + W: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.w.as_mut().poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.w.as_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.w.as_mut().poll_shutdown(cx) + } +} + +/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal +/// buffer. +pub struct AsyncReadCapture { + r: Pin>, + buf: Vec, +} + +impl AsyncReadCapture +where + R: AsyncRead + Unpin, +{ + /// Initializes an AsyncReadCapture with an empty internal buffer of the given size. + pub fn with_capacity(r: R, cap: usize) -> Self { + Self { + r: Box::pin(r), + buf: Vec::with_capacity(cap), + } + } + + pub fn into_inner(self) -> (R, Vec) { + (*Pin::into_inner(self.r), self.buf) + } +} + +impl AsyncRead for AsyncReadCapture +where + R: AsyncRead, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let res = self.r.as_mut().poll_read(cx, buf); + + if let Poll::Ready(Ok(())) = res { + self.buf.extend_from_slice(buf.filled()); + } + + res + } +} + +pub struct AsyncReadPrefixed { + r: Pin>, + prefix: Vec, +} + +impl AsyncReadPrefixed +where + R: AsyncRead + Unpin, +{ + pub fn new(r: R, prefix: Vec) -> Self { + Self { + r: Box::pin(r), + prefix, + } + } +} + +impl AsyncRead for AsyncReadPrefixed +where + R: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + + let prefix_len = this.prefix.len(); + if prefix_len == 0 { + return this.r.as_mut().poll_read(cx, buf); + } + + let n = std::cmp::min(prefix_len, buf.remaining()); + let to_write = this.prefix.drain(..n); + + buf.put_slice(to_write.as_slice()); + Poll::Ready(Ok(())) + } +} + #[cfg(test)] mod test_stream; diff --git a/src/lib.rs b/src/lib.rs index 000245c..8574cda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ pub mod client; mod common; pub mod server; -use common::{MidHandshake, Stream, TlsState}; +use common::{AsyncReadCapture, AsyncReadPrefixed, AsyncReadWrite, MidHandshake, Stream, TlsState}; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use std::future::Future; use std::io; @@ -60,7 +60,7 @@ use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}; pub use rustls; @@ -333,6 +333,98 @@ where } } +type IOWithCapture = AsyncReadWrite>, WriteHalf>; +type IOWithPrefix = AsyncReadWrite>, WriteHalf>; + +fn unwrap_io_with_capture( + io_with_capture: IOWithCapture, +) -> (ReadHalf, WriteHalf, Vec) +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + let (r, w) = io_with_capture.into_inner(); + let (r, bytes_read) = r.into_inner(); + (r, w, bytes_read) +} + +pub struct TransparentConfigAcceptor { + acceptor: Pin>>>, +} + +impl TransparentConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { + let (r, w) = tokio::io::split(io); + let r = AsyncReadCapture::with_capacity(r, 1024); + let rw = AsyncReadWrite::new(r, w); + Self { + acceptor: Box::pin(LazyConfigAcceptor::new(acceptor, rw)), + } + } +} + +impl Future for TransparentConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + type Output = io::Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut().acceptor.as_mut().poll(cx) { + Poll::Ready(Ok(h)) => { + let (r, w, bytes_read) = unwrap_io_with_capture(h.io); + Poll::Ready(Ok(TransparentStartHandshake { + accepted: h.accepted, + r, + w, + bytes_read, + })) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +pub struct TransparentStartHandshake { + accepted: rustls::server::Accepted, + r: ReadHalf, + w: WriteHalf, + bytes_read: Vec, +} + +impl TransparentStartHandshake +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { + self.accepted.client_hello() + } + + pub fn into_stream(self, config: Arc) -> Accept { + self.into_stream_with(config, |_| ()) + } + + pub fn into_stream_with(self, config: Arc, f: F) -> Accept + where + F: FnOnce(&mut ServerConnection), + { + let start_handshake = StartHandshake { + accepted: self.accepted, + io: self.r.unsplit(self.w), + }; + + start_handshake.into_stream_with(config, f) + } + + pub fn into_original_stream(self) -> IOWithPrefix { + let r = AsyncReadPrefixed::new(self.r, self.bytes_read); + AsyncReadWrite::new(r, self.w) + } +} + /// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. pub struct Connect(MidHandshake>);