Got proxy origin working, more or less

This commit is contained in:
Brian Picciano 2023-07-16 17:10:13 +02:00
parent 9beeffcdcf
commit a917f32f04
4 changed files with 94 additions and 38 deletions

View File

@ -32,6 +32,9 @@ pub enum GetFileError {
#[error("file not found")]
FileNotFound,
#[error("origin is of kind proxy")]
OriginIsProxy { url: String },
#[error(transparent)]
Unexpected(#[from] unexpected::Error),
}
@ -245,6 +248,11 @@ impl Manager for ManagerImpl {
path: &str,
) -> Result<util::BoxByteStream, GetFileError> {
let config = self.domain_store.get(domain)?;
if let origin::Descr::Proxy { url } = config.origin_descr {
return Err(GetFileError::OriginIsProxy { url });
}
let f = self.origin_store.get_file(&config.origin_descr, path)?;
Ok(f)
}

View File

@ -1,5 +1,5 @@
use crate::error::unexpected;
use std::net;
use crate::error::unexpected::{self, Mappable};
use std::{net, str::FromStr};
// proxy is a special case because it is so tied to the underlying protocol that a request is
// being served on, it can't be abstracted out into a simple "get_file" operation like other
@ -8,8 +8,37 @@ use std::net;
pub async fn serve_http_request(
client_ip: net::IpAddr,
proxy_url: &str,
req: hyper::Request<hyper::Body>,
mut req: hyper::Request<hyper::Body>,
) -> unexpected::Result<hyper::Response<hyper::Body>> {
let parsed_proxy_url =
http::Uri::from_str(proxy_url).or_unexpected_while("parsing proxy url")?;
let scheme = parsed_proxy_url
.scheme()
.or_unexpected_while("expected a scheme of http in the proxy url")?;
if scheme != "http" {
return Err(unexpected::Error::from("proxy url scheme should be 'http"));
}
// figure out what the host header should be, based on the host[:port] of the proxy_url
let host = {
let authority = parsed_proxy_url
.authority()
.or_unexpected_while("getting host from proxy url, there is no host")?;
let host_and_port;
let mut host = authority.host();
if let Some(port) = authority.port() {
host_and_port = format!("{host}:{port}");
host = host_and_port.as_str();
};
http::header::HeaderValue::from_str(host).or_unexpected()?
};
req.headers_mut().insert("host", host);
match hyper_reverse_proxy::call(client_ip, proxy_url, req).await {
Ok(res) => Ok(res),
// ProxyError doesn't actually implement Error :facepalm: so we have to format the error

View File

@ -9,10 +9,10 @@ use hyper::{Body, Method, Request, Response};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{future, sync};
use std::{future, net, sync};
use crate::error::unexpected;
use crate::{domain, service, util};
use crate::{domain, origin, service, util};
pub struct Service {
domain_manager: sync::Arc<dyn domain::manager::Manager>,
@ -158,8 +158,14 @@ impl<'svc> Service {
)
}
fn serve_origin(&self, domain: domain::Name, path: &str) -> Response<Body> {
async fn serve_origin(
&self,
client_ip: net::IpAddr,
domain: domain::Name,
req: Request<Body>,
) -> Response<Body> {
let mut path_owned;
let path = req.uri().path();
let path = match path.ends_with('/') {
true => {
@ -178,6 +184,13 @@ impl<'svc> Service {
Err(domain::manager::GetFileError::FileNotFound) => {
self.render_error_page(404, "File not found")
}
Err(domain::manager::GetFileError::OriginIsProxy { url }) => {
origin::proxy::serve_http_request(client_ip, &url, req)
.await
.unwrap_or_else(|e| {
self.internal_error(format!("proxying {domain} to {url}: {e}").as_str())
})
}
Err(domain::manager::GetFileError::Unexpected(e)) => {
self.internal_error(format!("failed to fetch file {path}: {e}").as_str())
}
@ -366,15 +379,13 @@ impl<'svc> Service {
self.render_page("/domains.html", Response { domains })
}
async fn handle_request(&self, req: Request<Body>) -> Response<Body> {
let (req, body) = req.into_parts();
async fn handle_request(&self, client_ip: net::IpAddr, req: Request<Body>) -> Response<Body> {
let maybe_host = match (
req.headers
req.headers()
.get("Host")
.and_then(|v| v.to_str().ok())
.map(strip_port),
req.uri.host().map(strip_port),
req.uri().host().map(strip_port),
) {
(Some(h), _) if h != self.config.primary_domain.as_str() => Some(h),
(_, Some(h)) if h != self.config.primary_domain.as_str() => Some(h),
@ -382,12 +393,13 @@ impl<'svc> Service {
}
.and_then(|h| domain::Name::from_str(h).ok());
let path = req.uri.path();
{
let path = req.uri().path();
// Serving acme challenges always takes priority. We serve them from the same store no
// matter the domain, presumably they are cryptographically random enough that it doesn't
// matter.
if req.method == Method::GET && path.starts_with("/.well-known/acme-challenge/") {
if req.method() == Method::GET && path.starts_with("/.well-known/acme-challenge/") {
let token = path.trim_start_matches("/.well-known/acme-challenge/");
if let Ok(key) = self.domain_manager.get_acme_http01_challenge_key(token) {
@ -396,7 +408,7 @@ impl<'svc> Service {
}
// Serving domani challenges similarly takes priority.
if req.method == Method::GET && path == "/.well-known/domani-challenge" {
if req.method() == Method::GET && path == "/.well-known/domani-challenge" {
if let Some(ref domain) = maybe_host {
match self
.domain_manager
@ -412,13 +424,16 @@ impl<'svc> Service {
}
}
}
}
// If a managed domain was given then serve that from its origin
if let Some(domain) = maybe_host {
return self.serve_origin(domain, req.uri.path());
return self.serve_origin(client_ip, domain, req).await;
}
// Serve main domani site
let (req, body) = req.into_parts();
let path = req.uri.path();
if req.method == Method::GET && path.starts_with("/static/") {
return self.render(200, path, ());

View File

@ -4,6 +4,8 @@ use crate::service;
use std::{convert, future, sync};
use futures::StreamExt;
use hyper::server::conn::AddrStream;
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
pub async fn listen_http(
@ -13,13 +15,14 @@ pub async fn listen_http(
let addr = service.config.http.http_addr.clone();
let primary_domain = service.config.primary_domain.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let make_service = hyper::service::make_service_fn(move |conn: &AddrStream| {
let service = service.clone();
let client_ip = conn.remote_addr().ip();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone();
async move { Ok::<_, convert::Infallible>(service.handle_request(req).await) }
async move { Ok::<_, convert::Infallible>(service.handle_request(client_ip, req).await) }
});
// Return the service to hyper.
@ -48,13 +51,14 @@ pub async fn listen_https(
let addr = service.config.http.https_addr.unwrap().clone();
let primary_domain = service.config.primary_domain.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let make_service = hyper::service::make_service_fn(move |conn: &TlsStream<AddrStream>| {
let service = service.clone();
let client_ip = conn.get_ref().0.remote_addr().ip();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone();
async move { Ok::<_, convert::Infallible>(service.handle_request(req).await) }
async move { Ok::<_, convert::Infallible>(service.handle_request(client_ip, req).await) }
});
// Return the service to hyper.