From cdd0eacdd8c6b89ed3f2bf8b36aab7ddb502722e Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Thu, 27 Jul 2023 12:50:56 +0200 Subject: [PATCH] Update gemini proxying according to feedback from the tokio_rustls issue --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/service/gemini.rs | 13 ++- src/service/gemini/proxy.rs | 157 ++++++++++++++++++++++++++++++++++++ src/service/http.rs | 1 + 5 files changed, 170 insertions(+), 5 deletions(-) create mode 100644 src/service/gemini/proxy.rs diff --git a/Cargo.lock b/Cargo.lock index e4d1894..b4f3867 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2964,7 +2964,7 @@ dependencies = [ [[package]] name = "tokio-rustls" version = "0.24.1" -source = "git+https://code.betamike.com/micropelago/tokio-rustls.git?branch=transparent-acceptor#18fd688b335430e17e054e15ff7d6ce073db2419" +source = "git+https://code.betamike.com/micropelago/tokio-rustls.git?branch=start-handshake-into-inner#3d462a1d97836cdb0600f0bc69c5e3b3310f6d8c" dependencies = [ "rustls", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 1b48bda..937cf7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,4 +47,4 @@ reqwest = "0.11.18" hyper-reverse-proxy = "0.5.1" [patch.crates-io] -tokio-rustls = { git = "https://code.betamike.com/micropelago/tokio-rustls.git", branch = "transparent-acceptor" } +tokio-rustls = { git = "https://code.betamike.com/micropelago/tokio-rustls.git", branch = "start-handshake-into-inner" } diff --git a/src/service/gemini.rs b/src/service/gemini.rs index f5f2c26..0f24a74 100644 --- a/src/service/gemini.rs +++ b/src/service/gemini.rs @@ -1,4 +1,5 @@ mod config; +mod proxy; pub use config::*; @@ -60,8 +61,14 @@ impl Service { conn: tokio::net::TcpStream, _tls_config: sync::Arc, ) -> Result<(), HandleConnError> { + let teed_conn = { + let (r, w) = tokio::io::split(conn); + let r = proxy::AsyncTeeRead::with_capacity(r, 1024); + proxy::AsyncReadWrite::new(r, w) + }; + let acceptor = - tokio_rustls::TransparentConfigAcceptor::new(rustls::server::Acceptor::default(), conn); + tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), teed_conn); futures::pin_mut!(acceptor); match acceptor.as_mut().await { @@ -80,8 +87,8 @@ impl Service { // If the domain should be proxied, then proxy it if let Some(proxied_domain) = self.config.gemini.proxied_domains.get(&domain) { - let conn = start.into_original_stream(); - self.proxy_conn(proxied_domain, conn).await?; + let prefixed_conn = proxy::teed_io_to_prefixed(start.into_inner()); + self.proxy_conn(proxied_domain, prefixed_conn).await?; return Ok(()); } diff --git a/src/service/gemini/proxy.rs b/src/service/gemini/proxy.rs new file mode 100644 index 0000000..3a5b05b --- /dev/null +++ b/src/service/gemini/proxy.rs @@ -0,0 +1,157 @@ +use std::{io, pin, task}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}; + +pub type TeedIO = AsyncReadWrite>, WriteHalf>; +pub type PrefixedIO = AsyncReadWrite>, WriteHalf>; + +pub fn teed_io_to_prefixed(teed_io: TeedIO) -> PrefixedIO { + let (r, w) = teed_io.into_inner(); + let (r, bytes_read) = r.into_inner(); + let r = AsyncPrefixedRead::new(r, bytes_read); + AsyncReadWrite::new(r, w) +} + +/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements +/// AsyncRead + AsyncWrite. +pub struct AsyncReadWrite { + r: pin::Pin>, + w: pin::Pin>, +} + +impl AsyncReadWrite +where + R: Unpin, + W: Unpin, +{ + pub fn new(r: R, w: W) -> Self { + Self { + r: Box::pin(r), + w: Box::pin(w), + } + } + + pub fn into_inner(self) -> (R, W) { + (*pin::Pin::into_inner(self.r), *pin::Pin::into_inner(self.w)) + } +} + +impl AsyncRead for AsyncReadWrite +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: pin::Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + self.r.as_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncReadWrite +where + W: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: pin::Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> task::Poll> { + self.w.as_mut().poll_write(cx, buf) + } + + fn poll_flush( + mut self: pin::Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.w.as_mut().poll_flush(cx) + } + + fn poll_shutdown( + mut self: pin::Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.w.as_mut().poll_shutdown(cx) + } +} + +/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal +/// buffer. +pub struct AsyncTeeRead { + r: pin::Pin>, + buf: Vec, +} + +impl AsyncTeeRead +where + R: Unpin, +{ + /// Initializes an AsyncTeeRead with an empty internal buffer of the given size. + pub fn with_capacity(r: R, cap: usize) -> Self { + Self { + r: Box::pin(r), + buf: Vec::with_capacity(cap), + } + } + + pub fn into_inner(self) -> (R, Vec) { + (*pin::Pin::into_inner(self.r), self.buf) + } +} + +impl AsyncRead for AsyncTeeRead +where + R: AsyncRead, +{ + fn poll_read( + mut self: pin::Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + let res = self.r.as_mut().poll_read(cx, buf); + + if let task::Poll::Ready(Ok(())) = res { + self.buf.extend_from_slice(buf.filled()); + } + + res + } +} + +pub struct AsyncPrefixedRead { + r: pin::Pin>, + prefix: Vec, +} + +impl AsyncPrefixedRead { + pub fn new(r: R, prefix: Vec) -> Self { + Self { + r: Box::pin(r), + prefix, + } + } +} + +impl AsyncRead for AsyncPrefixedRead +where + R: AsyncRead, +{ + fn poll_read( + self: pin::Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + let this = self.get_mut(); + + let prefix_len = this.prefix.len(); + if prefix_len == 0 { + return this.r.as_mut().poll_read(cx, buf); + } + + let n = std::cmp::min(prefix_len, buf.remaining()); + let to_write = this.prefix.drain(..n); + + buf.put_slice(to_write.as_slice()); + task::Poll::Ready(Ok(())) + } +} diff --git a/src/service/http.rs b/src/service/http.rs index d42df63..b7206a0 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -293,6 +293,7 @@ impl<'svc> Service { } }; + // TODO this is wrong, e.g. something.co.uk let domain_is_zone_apex = args.domain.as_rr().num_labels() == 2; let dns_records_have_cname = self.config.dns_records.iter().any(|r| match r { service::ConfigDNSRecord::CNAME { .. } => true,