From a3c823c7b2d9d2ca05c0a9a4feed48a0754b3ca3 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sat, 8 Jul 2023 16:04:33 +0200 Subject: [PATCH] Simplify the http service a bunch, better error handling --- src/service/http.rs | 54 +++++++++++++++++++++------------------ src/service/http/tasks.rs | 4 +-- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/src/service/http.rs b/src/service/http.rs index 8be9fe3..28eb871 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -4,15 +4,12 @@ mod tpl; use hyper::{Body, Method, Request, Response}; use serde::{Deserialize, Serialize}; -use std::convert::Infallible; use std::str::FromStr; use std::{future, net, sync}; use crate::error::unexpected; use crate::{domain, service, util}; -type SvcResponse = Result, String>; - pub struct Service { domain_manager: sync::Arc, target_a: net::Ipv4Addr, @@ -94,7 +91,7 @@ struct DomainSyncArgs { } impl<'svc> Service { - fn serve(&self, status_code: u16, path: &'_ str, body: Body) -> SvcResponse { + fn serve(&self, status_code: u16, path: &'_ str, body: Body) -> Response { let content_type = mime_guess::from_path(path) .first_or_octet_stream() .to_string(); @@ -104,12 +101,22 @@ impl<'svc> Service { .header("Content-Type", content_type) .body(body) { - Ok(res) => Ok(res), - Err(err) => Err(format!("failed to build {}: {}", path, err)), + Ok(res) => res, + Err(err) => { + if status_code == 500 { + panic!("failed to build {}: {}", path, err); + } + + self.serve( + 500, + "error.txt", + format!("failed to build {}: {}", path, err).into(), + ) + } } } - fn render(&self, status_code: u16, name: &'_ str, value: T) -> SvcResponse + fn render(&self, status_code: u16, name: &'_ str, value: T) -> Response where T: Serialize, { @@ -127,7 +134,7 @@ impl<'svc> Service { self.serve(status_code, name, rendered.into()) } - fn render_error_page(&'svc self, status_code: u16, e: &'_ str) -> SvcResponse { + fn render_error_page(&self, status_code: u16, e: &'_ str) -> Response { #[derive(Serialize)] struct Response<'a> { error_msg: &'a str, @@ -143,7 +150,7 @@ impl<'svc> Service { ) } - fn render_page(&self, name: &'_ str, data: T) -> SvcResponse + fn render_page(&self, name: &'_ str, data: T) -> Response where T: Serialize, { @@ -157,7 +164,7 @@ impl<'svc> Service { ) } - fn serve_origin(&self, domain: domain::Name, path: &'_ str) -> SvcResponse { + fn serve_origin(&self, domain: domain::Name, path: &'_ str) -> Response { let mut path_owned; let path = match path.ends_with('/') { @@ -183,20 +190,24 @@ impl<'svc> Service { } } - async fn with_query_req<'a, F, In, Out>(&self, req: &'a Request, f: F) -> SvcResponse + async fn with_query_req<'a, F, In, Out>(&self, req: &'a Request, f: F) -> Response where In: Deserialize<'a>, F: FnOnce(In) -> Out, - Out: future::Future, + Out: future::Future>, { let query = req.uri().query().unwrap_or(""); match serde_urlencoded::from_str::(query) { Ok(args) => f(args).await, - Err(err) => Err(format!("failed to parse query args: {}", err)), + Err(err) => self.serve( + 400, + "error.txt", + format!("failed to parse query args: {}", err).into(), + ), } } - fn domain_get(&self, args: DomainGetArgs) -> SvcResponse { + fn domain_get(&self, args: DomainGetArgs) -> Response { #[derive(Serialize)] struct Response { domain: domain::Name, @@ -225,7 +236,7 @@ impl<'svc> Service { &self, args: DomainInitArgs, domain_config: service::util::FlatConfig, - ) -> SvcResponse { + ) -> Response { #[derive(Serialize)] struct Response { domain: domain::Name, @@ -265,7 +276,7 @@ impl<'svc> Service { &self, args: DomainSyncArgs, domain_config: service::util::FlatConfig, - ) -> SvcResponse { + ) -> Response { if args.passphrase != self.passphrase.as_str() { return self.render_error_page(401, "Incorrect passphrase"); } @@ -307,7 +318,7 @@ impl<'svc> Service { self.render_page("/domain_sync.html", response) } - fn domains(&self) -> SvcResponse { + fn domains(&self) -> Response { #[derive(Serialize)] struct Response { domains: Vec, @@ -330,7 +341,7 @@ impl<'svc> Service { self.render_page("/domains.html", Response { domains }) } - async fn handle_request_inner(&self, req: Request) -> SvcResponse { + async fn handle_request(&self, req: Request) -> Response { let maybe_host = match ( req.headers() .get("Host") @@ -398,13 +409,6 @@ impl<'svc> Service { _ => self.render_error_page(404, "Page not found!"), } } - - pub async fn handle_request(&self, req: Request) -> Result, Infallible> { - match self.handle_request_inner(req).await { - Ok(res) => Ok(res), - Err(err) => panic!("unexpected error {err}"), - } - } } fn strip_port(host: &str) -> &str { diff --git a/src/service/http/tasks.rs b/src/service/http/tasks.rs index 532a92c..cc5e247 100644 --- a/src/service/http/tasks.rs +++ b/src/service/http/tasks.rs @@ -18,7 +18,7 @@ pub async fn listen_http( // Create a `Service` for responding to the request. let hyper_service = hyper::service::service_fn(move |req| { let service = service.clone(); - async move { service.handle_request(req).await } + async move { Ok::<_, convert::Infallible>(service.handle_request(req).await) } }); // Return the service to hyper. @@ -48,7 +48,7 @@ pub async fn listen_https( // Create a `Service` for responding to the request. let hyper_service = hyper::service::service_fn(move |req| { let service = service.clone(); - async move { service.handle_request(req).await } + async move { Ok::<_, convert::Infallible>(service.handle_request(req).await) } }); // Return the service to hyper.