From e3c13123dbfe24ce0280480f5560756639c5a283 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Mon, 15 May 2023 17:42:32 +0200 Subject: [PATCH] Switched to hyper, cleaned up manager lifetimes, got domain_sync working --- Cargo.lock | 190 +-------------- Cargo.toml | 4 +- src/domain/checker.rs | 46 ++-- src/domain/config.rs | 2 +- src/domain/manager.rs | 71 +++--- src/main.rs | 36 ++- src/service.rs | 498 +++++++++++++++++++--------------------- src/service/http_tpl.rs | 2 +- 8 files changed, 328 insertions(+), 521 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a6f9419..6b80734 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -107,12 +107,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" version = "0.21.0" @@ -167,12 +161,6 @@ version = "3.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b1ce199063694f33ffb7dd4e0ee620741495c32833cde5aa08f02a0bf96f0c8" -[[package]] -name = "byteorder" -version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" - [[package]] name = "bytes" version = "1.4.0" @@ -377,11 +365,14 @@ dependencies = [ "gix", "handlebars", "hex", + "http", + "hyper", "mime_guess", "mockall", "rust-embed", "serde", "serde_json", + "serde_urlencoded", "sha2", "signal-hook", "signal-hook-tokio", @@ -389,7 +380,6 @@ dependencies = [ "thiserror", "tokio", "trust-dns-client", - "warp", ] [[package]] @@ -1156,7 +1146,7 @@ version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f01c2bf7b989c679695ef635fc7d9e80072e08101be4b53193c8e8b649900102" dependencies = [ - "base64 0.21.0", + "base64", "bstr", "gix-command", "gix-credentials", @@ -1281,31 +1271,6 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" -[[package]] -name = "headers" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" -dependencies = [ - "base64 0.13.1", - "bitflags 1.3.2", - "bytes", - "headers-core", - "http", - "httpdate", - "mime", - "sha1", -] - -[[package]] -name = "headers-core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" -dependencies = [ - "http", -] - [[package]] name = "heck" version = "0.4.1" @@ -1743,24 +1708,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "multer" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2" -dependencies = [ - "bytes", - "encoding_rs", - "futures-util", - "http", - "httparse", - "log", - "memchr", - "mime", - "spin 0.9.8", - "version_check", -] - [[package]] name = "nibble_vec" version = "0.1.0" @@ -1893,26 +1840,6 @@ dependencies = [ "sha2", ] -[[package]] -name = "pin-project" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "pin-project-lite" version = "0.2.9" @@ -2150,7 +2077,7 @@ version = "0.11.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91" dependencies = [ - "base64 0.21.0", + "base64", "bytes", "encoding_rs", "futures-core", @@ -2203,7 +2130,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin 0.5.2", + "spin", "untrusted", "web-sys", "winapi", @@ -2275,7 +2202,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.21.0", + "base64", ] [[package]] @@ -2293,12 +2220,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - [[package]] name = "scopeguard" version = "1.1.0" @@ -2358,17 +2279,6 @@ dependencies = [ "serde", ] -[[package]] -name = "sha1" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha1_smol" version = "1.0.0" @@ -2448,12 +2358,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "static_assertions" version = "1.1.0" @@ -2622,29 +2526,6 @@ dependencies = [ "webpki", ] -[[package]] -name = "tokio-stream" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tokio-tungstenite" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] - [[package]] name = "tokio-util" version = "0.7.8" @@ -2672,7 +2553,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", - "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2769,25 +2649,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" -[[package]] -name = "tungstenite" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788" -dependencies = [ - "base64 0.13.1", - "byteorder", - "bytes", - "http", - "httparse", - "log", - "rand 0.8.5", - "sha1", - "thiserror", - "url", - "utf-8", -] - [[package]] name = "typenum" version = "1.16.0" @@ -2862,12 +2723,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - [[package]] name = "utf8parse" version = "0.2.1" @@ -2900,37 +2755,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "warp" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "headers", - "http", - "hyper", - "log", - "mime", - "mime_guess", - "multer", - "percent-encoding", - "pin-project", - "rustls-pemfile", - "scoped-tls", - "serde", - "serde_json", - "serde_urlencoded", - "tokio", - "tokio-stream", - "tokio-tungstenite", - "tokio-util", - "tower-service", - "tracing", -] - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 3fef27d..5c11751 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ trust-dns-client = "0.22.0" mockall = "0.11.4" thiserror = "1.0.40" tokio = { version = "1.28.1", features = [ "full" ]} -warp = "0.3.5" signal-hook = "0.3.15" futures = "0.3.28" signal-hook-tokio = { version = "0.3.1", features = [ "futures-v0_3" ]} @@ -29,3 +28,6 @@ clap = { version = "4.2.7", features = ["derive", "env"] } handlebars = { version = "4.3.7", features = [ "rust-embed" ]} rust-embed = "6.6.1" mime_guess = "2.0.4" +hyper = { version = "0.14.26", features = [ "server" ]} +http = "0.2.9" +serde_urlencoded = "0.7.1" diff --git a/src/domain/checker.rs b/src/domain/checker.rs index 85014de..730e57d 100644 --- a/src/domain/checker.rs +++ b/src/domain/checker.rs @@ -4,7 +4,6 @@ use std::sync; use crate::domain; -use mockall::automock; use trust_dns_client::client::{AsyncClient, ClientHandle}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; use trust_dns_client::udp; @@ -33,17 +32,7 @@ pub enum CheckDomainError { Unexpected(Box), } -#[automock] -pub trait Checker: std::marker::Send + std::marker::Sync { - fn check_domain( - &self, - domain: &domain::Name, - challenge_token: &str, - ) -> Result<(), CheckDomainError>; -} - pub struct DNSChecker { - tokio_runtime: sync::Arc, target_cname: Name, // TODO we should use some kind of connection pool here, I suppose @@ -54,7 +43,7 @@ pub fn new( tokio_runtime: sync::Arc, target_cname: domain::Name, resolver_addr: &str, -) -> Result { +) -> Result { let resolver_addr = resolver_addr .parse() .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; @@ -68,14 +57,13 @@ pub fn new( tokio_runtime.spawn(bg); Ok(DNSChecker { - tokio_runtime, target_cname: target_cname.inner, client: tokio::sync::Mutex::new(client), }) } -impl Checker for DNSChecker { - fn check_domain( +impl DNSChecker { + pub async fn check_domain( &self, domain: &domain::Name, challenge_token: &str, @@ -84,13 +72,13 @@ impl Checker for DNSChecker { // check that the CNAME is installed correctly on the domain { - let response = match self.tokio_runtime.block_on(async { - self.client - .lock() - .await - .query(domain.clone(), DNSClass::IN, RecordType::CNAME) - .await - }) { + let response = match self + .client + .lock() + .await + .query(domain.clone(), DNSClass::IN, RecordType::CNAME) + .await + { Ok(res) => res, Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), }; @@ -116,13 +104,13 @@ impl Checker for DNSChecker { .append_domain(&domain) .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; - let response = match self.tokio_runtime.block_on(async { - self.client - .lock() - .await - .query(domain, DNSClass::IN, RecordType::TXT) - .await - }) { + let response = match self + .client + .lock() + .await + .query(domain, DNSClass::IN, RecordType::TXT) + .await + { Ok(res) => res, Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), }; diff --git a/src/domain/config.rs b/src/domain/config.rs index fe76070..445722c 100644 --- a/src/domain/config.rs +++ b/src/domain/config.rs @@ -9,7 +9,7 @@ use hex::ToHex; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] /// Values which the owner of a domain can configure when they install a domain. pub struct Config { pub origin_descr: Descr, diff --git a/src/domain/manager.rs b/src/domain/manager.rs index 80ae211..d0c9a80 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -1,6 +1,8 @@ use crate::domain::{self, checker, config}; use crate::origin; use std::error::Error; +use std::future::Future; +use std::pin; #[derive(thiserror::Error, Debug)] pub enum GetConfigError { @@ -111,31 +113,29 @@ impl From for SyncWithConfigError { } } -#[mockall::automock(type Origin=origin::MockOrigin;)] -pub trait Manager: std::marker::Send + std::marker::Sync { - type Origin<'mgr>: origin::Origin + 'mgr - where - Self: 'mgr; - +//#[mockall::automock(type Origin=origin::MockOrigin;)] +pub trait Manager: Send + Sync { fn get_config(&self, domain: &domain::Name) -> Result; - fn get_origin(&self, domain: &domain::Name) -> Result, GetOriginError>; + fn get_origin( + &self, + domain: &domain::Name, + ) -> Result, GetOriginError>; fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>; fn sync_with_config( &self, - domain: &domain::Name, - config: &config::Config, - ) -> Result<(), SyncWithConfigError>; + domain: domain::Name, + config: config::Config, + ) -> pin::Pin> + Send + '_>>; } -pub fn new( +pub fn new( origin_store: OriginStore, domain_config_store: DomainConfigStore, - domain_checker: DomainChecker, + domain_checker: checker::DNSChecker, ) -> impl Manager where OriginStore: origin::store::Store, DomainConfigStore: config::Store, - DomainChecker: checker::Checker, { ManagerImpl { origin_store, @@ -144,39 +144,36 @@ where } } -struct ManagerImpl +struct ManagerImpl where OriginStore: origin::store::Store, DomainConfigStore: config::Store, - DomainChecker: checker::Checker, { origin_store: OriginStore, domain_config_store: DomainConfigStore, - domain_checker: DomainChecker, + domain_checker: checker::DNSChecker, } -impl Manager - for ManagerImpl +impl Manager for ManagerImpl where OriginStore: origin::store::Store, DomainConfigStore: config::Store, - DomainChecker: checker::Checker, { - type Origin<'mgr> = OriginStore::Origin<'mgr> - where Self: 'mgr; - fn get_config(&self, domain: &domain::Name) -> Result { Ok(self.domain_config_store.get(domain)?) } - fn get_origin(&self, domain: &domain::Name) -> Result, GetOriginError> { + fn get_origin( + &self, + domain: &domain::Name, + ) -> Result, GetOriginError> { let config = self.domain_config_store.get(domain)?; let origin = self .origin_store .get(config.origin_descr) // if there's a config there should be an origin, any error here is unexpected .map_err(|e| GetOriginError::Unexpected(Box::from(e)))?; - Ok(origin) + Ok(Box::from(origin)) } fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> { @@ -192,20 +189,24 @@ where fn sync_with_config( &self, - domain: &domain::Name, - config: &config::Config, - ) -> Result<(), SyncWithConfigError> { - let config_hash = config - .hash() - .map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?; + domain: domain::Name, + config: config::Config, + ) -> pin::Pin> + Send + '_>> { + Box::pin(async move { + let config_hash = config + .hash() + .map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?; - self.domain_checker.check_domain(&domain, &config_hash)?; + self.domain_checker + .check_domain(&domain, &config_hash) + .await?; - self.origin_store - .sync(config.origin_descr.clone(), origin::store::Limits {})?; + self.origin_store + .sync(config.origin_descr.clone(), origin::store::Limits {})?; - self.domain_config_store.set(domain, config)?; + self.domain_config_store.set(&domain, &config)?; - Ok(()) + Ok(()) + }) } } diff --git a/src/main.rs b/src/main.rs index 23bff5d..6a5f5f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,9 +4,11 @@ use signal_hook::consts::signal; use signal_hook_tokio::Signals; use tokio::sync::oneshot; +use std::convert::Infallible; use std::net::SocketAddr; use std::path; use std::str::FromStr; +use std::sync; #[derive(Parser, Debug)] #[command(version)] @@ -74,22 +76,42 @@ fn main() { .expect("domain config store initialized"); let manager = domiply::domain::manager::new(origin_store, domain_config_store, domain_checker); + let manager = sync::Arc::new(manager); let service = domiply::service::new( manager, config.domain_checker_target_cname, config.passphrase, - ) - .expect("service initialized"); + ); - tokio_runtime.block_on(async { - let (addr, server) = - warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async { - stop_ch_rx.await.ok(); + let service = sync::Arc::new(service); + + let make_service = + hyper::service::make_service_fn(move |_conn: &hyper::server::conn::AddrStream| { + let service = service.clone(); + + // Create a `Service` for responding to the request. + let service = hyper::service::service_fn(move |req| { + domiply::service::handle_request(service.clone(), req) }); + // Return the service to hyper. + async move { Ok::<_, Infallible>(service) } + }); + + tokio_runtime.block_on(async { + let addr = config.http_listen_addr; + println!("Listening on {addr}"); - server.await; + let server = hyper::Server::bind(&addr).serve(make_service); + + let graceful = server.with_graceful_shutdown(async { + stop_ch_rx.await.ok(); + }); + + if let Err(e) = graceful.await { + panic!("server error: {}", e); + } }); println!("Graceful shutdown complete"); diff --git a/src/service.rs b/src/service.rs index f5c8ed1..27a5068 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,45 +1,68 @@ +use http::status::StatusCode; +use hyper::{Body, Method, Request, Response}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::error::Error; + +use std::convert::Infallible; +use std::future::Future; use std::sync; -use warp::Filter; use crate::domain; pub mod http_tpl; mod util; -type Handlebars<'a> = sync::Arc>; +type SvcResponse = Result, String>; -struct Renderer<'a, DM> -where - DM: domain::manager::Manager, -{ - domain_manager: sync::Arc, - target_cname: sync::Arc, - passphrase: sync::Arc, +#[derive(Clone)] +pub struct Service<'svc> { + domain_manager: sync::Arc, + target_cname: domain::Name, + passphrase: String, + handlebars: handlebars::Handlebars<'svc>, +} - handlebars: Handlebars<'a>, - query_args: HashMap, +pub fn new<'svc, 'mgr>( + domain_manager: sync::Arc, + target_cname: domain::Name, + passphrase: String, +) -> Service<'svc> { + Service { + domain_manager, + target_cname, + passphrase, + handlebars: self::http_tpl::get().expect("Retrieved Handlebars templates"), + } } #[derive(Serialize)] struct BasePresenter<'a, T> { page_name: &'a str, - query_args: &'a HashMap, - data: &'a T, + data: T, } -impl<'a, DM> Renderer<'a, DM> -where - DM: domain::manager::Manager, -{ - // TODO make this use an io::Write, rather than warp::Reply - fn render(&self, name: &'_ str, value: &'_ T) -> Box +#[derive(Deserialize)] +struct DomainGetArgs { + domain: domain::Name, +} + +#[derive(Deserialize)] +struct DomainInitArgs { + domain: domain::Name, + passphrase: String, +} + +#[derive(Deserialize)] +struct DomainSyncArgs { + domain: domain::Name, +} + +impl<'svc> Service<'svc> { + //// TODO make this use an io::Write, rather than SvcResponse + fn render(&self, status_code: u16, name: &'_ str, value: T) -> SvcResponse where T: Serialize, { - let rendered = match self.handlebars.render(name, value) { + let rendered = match self.handlebars.render(name, &value) { Ok(res) => res, Err(handlebars::RenderError { template_name: None, @@ -54,268 +77,215 @@ where .first_or_octet_stream() .to_string(); - let reply = warp::reply::html(rendered); - - Box::from(warp::reply::with_header( - reply, - "Content-Type", - content_type, - )) + match Response::builder() + .status(status_code) + .header("Content-Type", content_type) + .body(rendered) + { + Ok(res) => Ok(res), + Err(err) => Err(format!("failed to build {}: {}", name, err)), + } } - fn render_error_page(&self, status_code: u16, e: &'_ str) -> Box { + fn render_error_page(&'svc self, status_code: u16, e: &'_ str) -> SvcResponse { #[derive(Serialize)] struct Response<'a> { error_msg: &'a str, } - Box::from(warp::reply::with_status( - self.render( - "/base.html", - &BasePresenter { - page_name: "/error.html", - query_args: &HashMap::default(), - data: &Response { error_msg: e }, - }, - ), - status_code.try_into().unwrap(), - )) + self.render( + status_code, + "/base.html", + &BasePresenter { + page_name: "/error.html", + data: &Response { error_msg: e }, + }, + ) } - fn render_page(&self, name: &'_ str, data: &'_ T) -> Box + fn render_page(&self, name: &'_ str, data: T) -> SvcResponse where T: Serialize, { self.render( + 200, "/base.html", - &BasePresenter { + BasePresenter { page_name: name, - query_args: &self.query_args, data, }, ) } + + async fn with_query_req<'a, F, In, Out>(&self, req: &'a Request, f: F) -> SvcResponse + where + In: Deserialize<'a>, + F: FnOnce(In) -> Out, + Out: Future, + { + let query = req.uri().query().unwrap_or(""); + match serde_urlencoded::from_str::(query) { + Ok(args) => f(args).await, + Err(err) => Err(format!("failed to parse query args: {}", err)), + } + } + + fn domain_get(&self, args: DomainGetArgs) -> SvcResponse { + #[derive(Serialize)] + struct Response { + domain: domain::Name, + config: Option, + } + + match self.domain_manager.get_config(&args.domain) { + Ok(_config) => self.render_error_page(500, "TODO not yet implemented"), + Err(domain::manager::GetConfigError::NotFound) => self.render_page( + "/domain.html", + &Response { + domain: args.domain, + config: None, + }, + ), + Err(domain::manager::GetConfigError::Unexpected(e)) => { + self.render_error_page(500, format!("retrieving configuration: {}", e).as_str()) + } + } + } + + fn domain_init(&self, args: DomainInitArgs, domain_config: util::FlatConfig) -> SvcResponse { + if args.passphrase != self.passphrase.as_str() { + return self.render_error_page(401, "Incorrect passphrase"); + } + + #[derive(Serialize)] + struct Response { + domain: domain::Name, + flat_config: util::FlatConfig, + target_cname: domain::Name, + challenge_token: String, + } + + let config: domain::config::Config = match domain_config.try_into() { + Ok(Some(config)) => config, + Ok(None) => return self.render_error_page(400, "domain config is required"), + Err(e) => { + return self.render_error_page(400, format!("invalid domain config: {e}").as_str()) + } + }; + + let config_hash = match config.hash() { + Ok(hash) => hash, + Err(e) => { + return self + .render_error_page(500, format!("failed to hash domain config: {e}").as_str()) + } + }; + + let target_cname = self.target_cname.clone(); + + return self.render_page( + "/domain_init.html", + &Response { + domain: args.domain, + flat_config: config.into(), + target_cname: target_cname, + challenge_token: config_hash, + }, + ); + } + + async fn domain_sync( + &self, + args: DomainSyncArgs, + domain_config: util::FlatConfig, + ) -> SvcResponse { + let config: domain::config::Config = match domain_config.try_into() { + Ok(Some(config)) => config, + Ok(None) => return self.render_error_page(400, "domain config is required"), + Err(e) => { + return self.render_error_page(400, format!("invalid domain config: {e}").as_str()) + } + }; + + let sync_result = self + .domain_manager + .sync_with_config(args.domain.clone(), config) + .await; + + #[derive(Serialize)] + struct Response { + domain: domain::Name, + error_msg: Option, + } + + let error_msg = match sync_result { + Ok(_) => None, + Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()), + Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()), + Err(domain::manager::SyncWithConfigError::AlreadyInProgress) => Some("The configuration of your domain is still in progress, please refresh in a few minutes.".to_string()), + Err(domain::manager::SyncWithConfigError::TargetCNAMENotSet) => Some("The CNAME record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()), + Err(domain::manager::SyncWithConfigError::ChallengeTokenNotSet) => Some("The TXT record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()), + Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")), + }; + + let response = Response { + domain: args.domain, + error_msg, + }; + + return self.render_page("/domain_sync.html", response); + } } -pub fn new( - manager: DM, - target_cname: domain::Name, - passphrase: String, -) -> Result< - impl warp::Filter + Clone + 'static, - Box, -> -where - DM: domain::manager::Manager + 'static, -{ - let manager = sync::Arc::new(manager); - let target_cname = sync::Arc::new(target_cname); - let passphrase = sync::Arc::new(passphrase); - - let hbs = sync::Arc::new(self::http_tpl::get()?); - let with_renderer = warp::any() - .and(warp::query::>()) - .map(move |query_args: HashMap| Renderer { - domain_manager: manager.clone(), - target_cname: target_cname.clone(), - passphrase: passphrase.clone(), - handlebars: hbs.clone(), - query_args, - }); - - let static_dir = warp::get() - .and(with_renderer.clone()) - .and(warp::path("static")) - .and(warp::path::full()) - .map(|renderer: Renderer<'_, DM>, full: warp::path::FullPath| { - renderer.render(full.as_str(), &()) - }); - - let index = warp::get() - .and(with_renderer.clone()) - .and(warp::path::end()) - .map(|renderer: Renderer<'_, DM>| renderer.render_page("/index.html", &())); - - #[derive(Deserialize)] - struct DomainGetNewRequest { - domain: domain::Name, +pub async fn handle_request<'svc>( + svc: sync::Arc>, + req: Request, +) -> Result, Infallible> { + match handle_request_inner(svc, req).await { + Ok(res) => Ok(res), + Err(err) => { + let mut res = Response::new(format!("failed to serve request: {}", err)); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(res) + } + } +} + +pub async fn handle_request_inner<'svc>( + svc: sync::Arc>, + req: Request, +) -> SvcResponse { + let method = req.method(); + let path = req.uri().path(); + + if method == &Method::GET && path.starts_with("/static/") { + return svc.render(200, path, ()); + } + + match (method, path) { + (&Method::GET, "/") | (&Method::GET, "/index.html") => svc.render_page("/index.html", ()), + (&Method::GET, "/domain.html") => { + svc.with_query_req(&req, |args: DomainGetArgs| async { svc.domain_get(args) }) + .await + } + (&Method::GET, "/domain_init.html") => { + svc.with_query_req(&req, |args: DomainInitArgs| async { + svc.with_query_req(&req, |config: util::FlatConfig| async { + svc.domain_init(args, config) + }) + .await + }) + .await + } + (&Method::GET, "/domain_sync.html") => { + svc.with_query_req(&req, |args: DomainSyncArgs| async { + svc.with_query_req(&req, |config: util::FlatConfig| async { + svc.domain_sync(args, config).await + }) + .await + }) + .await + } + _ => svc.render_error_page(404, "Page not found!"), } - - #[derive(Serialize)] - struct DomainGetNewResponse { - domain: domain::Name, - config: Option, - } - - let domain = warp::get() - .and(with_renderer.clone()) - .and(warp::path!("domain.html")) - .and(warp::query::()) - .and(warp::query::()) - .map( - |renderer: Renderer<'_, DM>, - req: DomainGetNewRequest, - domain_config: util::FlatConfig| { - match renderer.domain_manager.get_config(&req.domain) { - Ok(_config) => renderer.render_error_page(500, "TODO not yet implemented"), - Err(domain::manager::GetConfigError::NotFound) => { - let domain_config = match domain_config.try_into() { - Ok(domain_config) => domain_config, - Err(e) => { - return renderer.render_error_page( - 400, - format!("parsing domain configuration: {}", e).as_str(), - ) - } - }; - - renderer.render_page( - "/domain.html", - &DomainGetNewResponse { - domain: req.domain, - config: domain_config, - }, - ) - } - Err(domain::manager::GetConfigError::Unexpected(e)) => renderer - .render_error_page( - 500, - format!("retrieving configuration: {}", e).as_str(), - ), - } - }, - ); - - #[derive(Deserialize)] - struct DomainInitRequest { - domain: domain::Name, - passphrase: String, - } - - let domain_init = warp::get() - .and(with_renderer.clone()) - .and(warp::path!("domain_init.html")) - .and(warp::query::()) - .and(warp::query::()) - .map( - |renderer: Renderer<'_, DM>, - req: DomainInitRequest, - domain_config: util::FlatConfig| { - if req.passphrase != renderer.passphrase.as_str() { - return renderer.render_error_page(401, "Incorrect passphrase"); - } - - #[derive(Serialize)] - struct Response { - domain: domain::Name, - flat_config: util::FlatConfig, - target_cname: domain::Name, - challenge_token: String, - } - - let config: domain::config::Config = match domain_config.try_into() { - Ok(Some(config)) => config, - Ok(None) => { - return renderer.render_error_page(400, "domain config is required") - } - Err(e) => { - return renderer - .render_error_page(400, format!("invalid domain config: {e}").as_str()) - } - }; - - let config_hash = match config.hash() { - Ok(hash) => hash, - Err(e) => { - return renderer.render_error_page( - 500, - format!("failed to hash domain config: {e}").as_str(), - ) - } - }; - - let target_cname = (*renderer.target_cname).clone(); - - return renderer.render_page( - "/domain_init.html", - &Response { - domain: req.domain, - flat_config: config.into(), - target_cname: target_cname, - challenge_token: config_hash, - }, - ); - }, - ); - - #[derive(Deserialize)] - struct DomainSyncRequest { - domain: domain::Name, - } - - let domain_sync = warp::get() - .and(with_renderer.clone()) - .and(warp::path!("domain_sync.html")) - .and(warp::query::()) - .and(warp::query::()) - .map( - |renderer: Renderer<'_, DM>, - req: DomainSyncRequest, - domain_config: util::FlatConfig| { - let config: domain::config::Config = match domain_config.try_into() { - Ok(Some(config)) => config, - Ok(None) => { - return renderer.render_error_page(400, "domain config is required") - } - Err(e) => { - return renderer - .render_error_page(400, format!("invalid domain config: {e}").as_str()) - } - }; - - let sync_result = renderer - .domain_manager - .sync_with_config(&req.domain, &config); - - #[derive(Serialize)] - struct Response { - domain: domain::Name, - flat_config: util::FlatConfig, - error_msg: Option, - } - - let mut response = Response { - domain: req.domain, - flat_config: config.into(), - error_msg: None, - }; - - response.error_msg = match sync_result - { - Ok(_) => None, - Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()), - Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()), - Err(domain::manager::SyncWithConfigError::AlreadyInProgress) => Some("The configuration of your domain is still in progress, please refresh in a few minutes.".to_string()), - Err(domain::manager::SyncWithConfigError::TargetCNAMENotSet) => Some("The CNAME record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()), - Err(domain::manager::SyncWithConfigError::ChallengeTokenNotSet) => Some("The TXT record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()), - Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")), - }; - - return renderer.render_page( - "/domain_sync.html", - &response, - ) - }, - ); - - let not_found = warp::any() - .and(with_renderer.clone()) - .map(|renderer: Renderer<'_, DM>| renderer.render_error_page(404, "Page not found")); - - Ok(static_dir - .or(index) - .or(domain) - .or(domain_init) - .or(domain_sync) - .or(not_found)) } diff --git a/src/service/http_tpl.rs b/src/service/http_tpl.rs index 8aac6e4..d403d8e 100644 --- a/src/service/http_tpl.rs +++ b/src/service/http_tpl.rs @@ -5,7 +5,7 @@ use handlebars::{Handlebars, TemplateError}; #[prefix = "/"] struct Dir; -pub fn get() -> Result, TemplateError> { +pub fn get<'hbs>() -> Result, TemplateError> { let mut reg = Handlebars::new(); reg.register_embed_templates::()?; Ok(reg)