Update gemini proxying according to feedback from the tokio_rustls issue

This commit is contained in:
Brian Picciano 2023-07-27 12:50:56 +02:00
parent c1659fab2a
commit cdd0eacdd8
5 changed files with 170 additions and 5 deletions

2
Cargo.lock generated
View File

@ -2964,7 +2964,7 @@ dependencies = [
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.24.1" version = "0.24.1"
source = "git+https://code.betamike.com/micropelago/tokio-rustls.git?branch=transparent-acceptor#18fd688b335430e17e054e15ff7d6ce073db2419" source = "git+https://code.betamike.com/micropelago/tokio-rustls.git?branch=start-handshake-into-inner#3d462a1d97836cdb0600f0bc69c5e3b3310f6d8c"
dependencies = [ dependencies = [
"rustls", "rustls",
"tokio", "tokio",

View File

@ -47,4 +47,4 @@ reqwest = "0.11.18"
hyper-reverse-proxy = "0.5.1" hyper-reverse-proxy = "0.5.1"
[patch.crates-io] [patch.crates-io]
tokio-rustls = { git = "https://code.betamike.com/micropelago/tokio-rustls.git", branch = "transparent-acceptor" } tokio-rustls = { git = "https://code.betamike.com/micropelago/tokio-rustls.git", branch = "start-handshake-into-inner" }

View File

@ -1,4 +1,5 @@
mod config; mod config;
mod proxy;
pub use config::*; pub use config::*;
@ -60,8 +61,14 @@ impl Service {
conn: tokio::net::TcpStream, conn: tokio::net::TcpStream,
_tls_config: sync::Arc<rustls::ServerConfig>, _tls_config: sync::Arc<rustls::ServerConfig>,
) -> Result<(), HandleConnError> { ) -> Result<(), HandleConnError> {
let teed_conn = {
let (r, w) = tokio::io::split(conn);
let r = proxy::AsyncTeeRead::with_capacity(r, 1024);
proxy::AsyncReadWrite::new(r, w)
};
let acceptor = let acceptor =
tokio_rustls::TransparentConfigAcceptor::new(rustls::server::Acceptor::default(), conn); tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), teed_conn);
futures::pin_mut!(acceptor); futures::pin_mut!(acceptor);
match acceptor.as_mut().await { match acceptor.as_mut().await {
@ -80,8 +87,8 @@ impl Service {
// If the domain should be proxied, then proxy it // If the domain should be proxied, then proxy it
if let Some(proxied_domain) = self.config.gemini.proxied_domains.get(&domain) { if let Some(proxied_domain) = self.config.gemini.proxied_domains.get(&domain) {
let conn = start.into_original_stream(); let prefixed_conn = proxy::teed_io_to_prefixed(start.into_inner());
self.proxy_conn(proxied_domain, conn).await?; self.proxy_conn(proxied_domain, prefixed_conn).await?;
return Ok(()); return Ok(());
} }

157
src/service/gemini/proxy.rs Normal file
View File

@ -0,0 +1,157 @@
use std::{io, pin, task};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
pub type TeedIO<IO> = AsyncReadWrite<AsyncTeeRead<ReadHalf<IO>>, WriteHalf<IO>>;
pub type PrefixedIO<IO> = AsyncReadWrite<AsyncPrefixedRead<ReadHalf<IO>>, WriteHalf<IO>>;
pub fn teed_io_to_prefixed<IO: Unpin>(teed_io: TeedIO<IO>) -> PrefixedIO<IO> {
let (r, w) = teed_io.into_inner();
let (r, bytes_read) = r.into_inner();
let r = AsyncPrefixedRead::new(r, bytes_read);
AsyncReadWrite::new(r, w)
}
/// Wraps an AsyncRead and AsyncWrite instance together to produce a single type which implements
/// AsyncRead + AsyncWrite.
pub struct AsyncReadWrite<R, W> {
r: pin::Pin<Box<R>>,
w: pin::Pin<Box<W>>,
}
impl<R, W> AsyncReadWrite<R, W>
where
R: Unpin,
W: Unpin,
{
pub fn new(r: R, w: W) -> Self {
Self {
r: Box::pin(r),
w: Box::pin(w),
}
}
pub fn into_inner(self) -> (R, W) {
(*pin::Pin::into_inner(self.r), *pin::Pin::into_inner(self.w))
}
}
impl<R, W> AsyncRead for AsyncReadWrite<R, W>
where
R: AsyncRead + Unpin,
{
fn poll_read(
mut self: pin::Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
self.r.as_mut().poll_read(cx, buf)
}
}
impl<R, W> AsyncWrite for AsyncReadWrite<R, W>
where
W: AsyncWrite + Unpin,
{
fn poll_write(
mut self: pin::Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
self.w.as_mut().poll_write(cx, buf)
}
fn poll_flush(
mut self: pin::Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.w.as_mut().poll_flush(cx)
}
fn poll_shutdown(
mut self: pin::Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.w.as_mut().poll_shutdown(cx)
}
}
/// Wraps an AsyncRead in order to capture all bytes which have been read by it into an internal
/// buffer.
pub struct AsyncTeeRead<R> {
r: pin::Pin<Box<R>>,
buf: Vec<u8>,
}
impl<R> AsyncTeeRead<R>
where
R: Unpin,
{
/// Initializes an AsyncTeeRead with an empty internal buffer of the given size.
pub fn with_capacity(r: R, cap: usize) -> Self {
Self {
r: Box::pin(r),
buf: Vec::with_capacity(cap),
}
}
pub fn into_inner(self) -> (R, Vec<u8>) {
(*pin::Pin::into_inner(self.r), self.buf)
}
}
impl<R> AsyncRead for AsyncTeeRead<R>
where
R: AsyncRead,
{
fn poll_read(
mut self: pin::Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
let res = self.r.as_mut().poll_read(cx, buf);
if let task::Poll::Ready(Ok(())) = res {
self.buf.extend_from_slice(buf.filled());
}
res
}
}
pub struct AsyncPrefixedRead<R> {
r: pin::Pin<Box<R>>,
prefix: Vec<u8>,
}
impl<R> AsyncPrefixedRead<R> {
pub fn new(r: R, prefix: Vec<u8>) -> Self {
Self {
r: Box::pin(r),
prefix,
}
}
}
impl<R> AsyncRead for AsyncPrefixedRead<R>
where
R: AsyncRead,
{
fn poll_read(
self: pin::Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.get_mut();
let prefix_len = this.prefix.len();
if prefix_len == 0 {
return this.r.as_mut().poll_read(cx, buf);
}
let n = std::cmp::min(prefix_len, buf.remaining());
let to_write = this.prefix.drain(..n);
buf.put_slice(to_write.as_slice());
task::Poll::Ready(Ok(()))
}
}

View File

@ -293,6 +293,7 @@ impl<'svc> Service {
} }
}; };
// TODO this is wrong, e.g. something.co.uk
let domain_is_zone_apex = args.domain.as_rr().num_labels() == 2; let domain_is_zone_apex = args.domain.as_rr().num_labels() == 2;
let dns_records_have_cname = self.config.dns_records.iter().any(|r| match r { let dns_records_have_cname = self.config.dns_records.iter().any(|r| match r {
service::ConfigDNSRecord::CNAME { .. } => true, service::ConfigDNSRecord::CNAME { .. } => true,