Switched to hyper, cleaned up manager lifetimes, got domain_sync working

This commit is contained in:
Brian Picciano 2023-05-15 17:42:32 +02:00
parent 26ebda90e8
commit e3c13123db
8 changed files with 328 additions and 521 deletions

190
Cargo.lock generated
View File

@ -107,12 +107,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.21.0" version = "0.21.0"
@ -167,12 +161,6 @@ version = "3.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b1ce199063694f33ffb7dd4e0ee620741495c32833cde5aa08f02a0bf96f0c8" checksum = "9b1ce199063694f33ffb7dd4e0ee620741495c32833cde5aa08f02a0bf96f0c8"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.4.0" version = "1.4.0"
@ -377,11 +365,14 @@ dependencies = [
"gix", "gix",
"handlebars", "handlebars",
"hex", "hex",
"http",
"hyper",
"mime_guess", "mime_guess",
"mockall", "mockall",
"rust-embed", "rust-embed",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded",
"sha2", "sha2",
"signal-hook", "signal-hook",
"signal-hook-tokio", "signal-hook-tokio",
@ -389,7 +380,6 @@ dependencies = [
"thiserror", "thiserror",
"tokio", "tokio",
"trust-dns-client", "trust-dns-client",
"warp",
] ]
[[package]] [[package]]
@ -1156,7 +1146,7 @@ version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f01c2bf7b989c679695ef635fc7d9e80072e08101be4b53193c8e8b649900102" checksum = "f01c2bf7b989c679695ef635fc7d9e80072e08101be4b53193c8e8b649900102"
dependencies = [ dependencies = [
"base64 0.21.0", "base64",
"bstr", "bstr",
"gix-command", "gix-command",
"gix-credentials", "gix-credentials",
@ -1281,31 +1271,6 @@ version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
[[package]]
name = "headers"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584"
dependencies = [
"base64 0.13.1",
"bitflags 1.3.2",
"bytes",
"headers-core",
"http",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
dependencies = [
"http",
]
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.4.1" version = "0.4.1"
@ -1743,24 +1708,6 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "multer"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http",
"httparse",
"log",
"memchr",
"mime",
"spin 0.9.8",
"version_check",
]
[[package]] [[package]]
name = "nibble_vec" name = "nibble_vec"
version = "0.1.0" version = "0.1.0"
@ -1893,26 +1840,6 @@ dependencies = [
"sha2", "sha2",
] ]
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.9" version = "0.2.9"
@ -2150,7 +2077,7 @@ version = "0.11.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91" checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91"
dependencies = [ dependencies = [
"base64 0.21.0", "base64",
"bytes", "bytes",
"encoding_rs", "encoding_rs",
"futures-core", "futures-core",
@ -2203,7 +2130,7 @@ dependencies = [
"cc", "cc",
"libc", "libc",
"once_cell", "once_cell",
"spin 0.5.2", "spin",
"untrusted", "untrusted",
"web-sys", "web-sys",
"winapi", "winapi",
@ -2275,7 +2202,7 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
dependencies = [ dependencies = [
"base64 0.21.0", "base64",
] ]
[[package]] [[package]]
@ -2293,12 +2220,6 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.1.0" version = "1.1.0"
@ -2358,17 +2279,6 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "sha1_smol" name = "sha1_smol"
version = "1.0.0" version = "1.0.0"
@ -2448,12 +2358,6 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]] [[package]]
name = "static_assertions" name = "static_assertions"
version = "1.1.0" version = "1.1.0"
@ -2622,29 +2526,6 @@ dependencies = [
"webpki", "webpki",
] ]
[[package]]
name = "tokio-stream"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.8" version = "0.7.8"
@ -2672,7 +2553,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"log",
"pin-project-lite", "pin-project-lite",
"tracing-attributes", "tracing-attributes",
"tracing-core", "tracing-core",
@ -2769,25 +2649,6 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
dependencies = [
"base64 0.13.1",
"byteorder",
"bytes",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.16.0" version = "1.16.0"
@ -2862,12 +2723,6 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "utf8parse" name = "utf8parse"
version = "0.2.1" version = "0.2.1"
@ -2900,37 +2755,6 @@ dependencies = [
"try-lock", "try-lock",
] ]
[[package]]
name = "warp"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"headers",
"http",
"hyper",
"log",
"mime",
"mime_guess",
"multer",
"percent-encoding",
"pin-project",
"rustls-pemfile",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-stream",
"tokio-tungstenite",
"tokio-util",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"

View File

@ -21,7 +21,6 @@ trust-dns-client = "0.22.0"
mockall = "0.11.4" mockall = "0.11.4"
thiserror = "1.0.40" thiserror = "1.0.40"
tokio = { version = "1.28.1", features = [ "full" ]} tokio = { version = "1.28.1", features = [ "full" ]}
warp = "0.3.5"
signal-hook = "0.3.15" signal-hook = "0.3.15"
futures = "0.3.28" futures = "0.3.28"
signal-hook-tokio = { version = "0.3.1", features = [ "futures-v0_3" ]} signal-hook-tokio = { version = "0.3.1", features = [ "futures-v0_3" ]}
@ -29,3 +28,6 @@ clap = { version = "4.2.7", features = ["derive", "env"] }
handlebars = { version = "4.3.7", features = [ "rust-embed" ]} handlebars = { version = "4.3.7", features = [ "rust-embed" ]}
rust-embed = "6.6.1" rust-embed = "6.6.1"
mime_guess = "2.0.4" mime_guess = "2.0.4"
hyper = { version = "0.14.26", features = [ "server" ]}
http = "0.2.9"
serde_urlencoded = "0.7.1"

View File

@ -4,7 +4,6 @@ use std::sync;
use crate::domain; use crate::domain;
use mockall::automock;
use trust_dns_client::client::{AsyncClient, ClientHandle}; use trust_dns_client::client::{AsyncClient, ClientHandle};
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
use trust_dns_client::udp; use trust_dns_client::udp;
@ -33,17 +32,7 @@ pub enum CheckDomainError {
Unexpected(Box<dyn Error>), Unexpected(Box<dyn Error>),
} }
#[automock]
pub trait Checker: std::marker::Send + std::marker::Sync {
fn check_domain(
&self,
domain: &domain::Name,
challenge_token: &str,
) -> Result<(), CheckDomainError>;
}
pub struct DNSChecker { pub struct DNSChecker {
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: Name, target_cname: Name,
// TODO we should use some kind of connection pool here, I suppose // TODO we should use some kind of connection pool here, I suppose
@ -54,7 +43,7 @@ pub fn new(
tokio_runtime: sync::Arc<tokio::runtime::Runtime>, tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: domain::Name, target_cname: domain::Name,
resolver_addr: &str, resolver_addr: &str,
) -> Result<impl Checker, NewDNSCheckerError> { ) -> Result<DNSChecker, NewDNSCheckerError> {
let resolver_addr = resolver_addr let resolver_addr = resolver_addr
.parse() .parse()
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
@ -68,14 +57,13 @@ pub fn new(
tokio_runtime.spawn(bg); tokio_runtime.spawn(bg);
Ok(DNSChecker { Ok(DNSChecker {
tokio_runtime,
target_cname: target_cname.inner, target_cname: target_cname.inner,
client: tokio::sync::Mutex::new(client), client: tokio::sync::Mutex::new(client),
}) })
} }
impl Checker for DNSChecker { impl DNSChecker {
fn check_domain( pub async fn check_domain(
&self, &self,
domain: &domain::Name, domain: &domain::Name,
challenge_token: &str, challenge_token: &str,
@ -84,13 +72,13 @@ impl Checker for DNSChecker {
// check that the CNAME is installed correctly on the domain // check that the CNAME is installed correctly on the domain
{ {
let response = match self.tokio_runtime.block_on(async { let response = match self
self.client .client
.lock() .lock()
.await .await
.query(domain.clone(), DNSClass::IN, RecordType::CNAME) .query(domain.clone(), DNSClass::IN, RecordType::CNAME)
.await .await
}) { {
Ok(res) => res, Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
}; };
@ -116,13 +104,13 @@ impl Checker for DNSChecker {
.append_domain(&domain) .append_domain(&domain)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = match self.tokio_runtime.block_on(async { let response = match self
self.client .client
.lock() .lock()
.await .await
.query(domain, DNSClass::IN, RecordType::TXT) .query(domain, DNSClass::IN, RecordType::TXT)
.await .await
}) { {
Ok(res) => res, Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
}; };

View File

@ -9,7 +9,7 @@ use hex::ToHex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
/// Values which the owner of a domain can configure when they install a domain. /// Values which the owner of a domain can configure when they install a domain.
pub struct Config { pub struct Config {
pub origin_descr: Descr, pub origin_descr: Descr,

View File

@ -1,6 +1,8 @@
use crate::domain::{self, checker, config}; use crate::domain::{self, checker, config};
use crate::origin; use crate::origin;
use std::error::Error; use std::error::Error;
use std::future::Future;
use std::pin;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum GetConfigError { pub enum GetConfigError {
@ -111,31 +113,29 @@ impl From<config::SetError> for SyncWithConfigError {
} }
} }
#[mockall::automock(type Origin=origin::MockOrigin;)] //#[mockall::automock(type Origin=origin::MockOrigin;)]
pub trait Manager: std::marker::Send + std::marker::Sync { pub trait Manager: Send + Sync {
type Origin<'mgr>: origin::Origin + 'mgr
where
Self: 'mgr;
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>; fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>;
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError>; fn get_origin(
&self,
domain: &domain::Name,
) -> Result<Box<dyn origin::Origin + '_>, GetOriginError>;
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>; fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>;
fn sync_with_config( fn sync_with_config(
&self, &self,
domain: &domain::Name, domain: domain::Name,
config: &config::Config, config: config::Config,
) -> Result<(), SyncWithConfigError>; ) -> pin::Pin<Box<dyn Future<Output = Result<(), SyncWithConfigError>> + Send + '_>>;
} }
pub fn new<OriginStore, DomainConfigStore, DomainChecker>( pub fn new<OriginStore, DomainConfigStore>(
origin_store: OriginStore, origin_store: OriginStore,
domain_config_store: DomainConfigStore, domain_config_store: DomainConfigStore,
domain_checker: DomainChecker, domain_checker: checker::DNSChecker,
) -> impl Manager ) -> impl Manager
where where
OriginStore: origin::store::Store, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker,
{ {
ManagerImpl { ManagerImpl {
origin_store, origin_store,
@ -144,39 +144,36 @@ where
} }
} }
struct ManagerImpl<OriginStore, DomainConfigStore, DomainChecker> struct ManagerImpl<OriginStore, DomainConfigStore>
where where
OriginStore: origin::store::Store, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker,
{ {
origin_store: OriginStore, origin_store: OriginStore,
domain_config_store: DomainConfigStore, domain_config_store: DomainConfigStore,
domain_checker: DomainChecker, domain_checker: checker::DNSChecker,
} }
impl<OriginStore, DomainConfigStore, DomainChecker> Manager impl<OriginStore, DomainConfigStore> Manager for ManagerImpl<OriginStore, DomainConfigStore>
for ManagerImpl<OriginStore, DomainConfigStore, DomainChecker>
where where
OriginStore: origin::store::Store, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker,
{ {
type Origin<'mgr> = OriginStore::Origin<'mgr>
where Self: 'mgr;
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> { fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> {
Ok(self.domain_config_store.get(domain)?) Ok(self.domain_config_store.get(domain)?)
} }
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError> { fn get_origin(
&self,
domain: &domain::Name,
) -> Result<Box<dyn origin::Origin + '_>, GetOriginError> {
let config = self.domain_config_store.get(domain)?; let config = self.domain_config_store.get(domain)?;
let origin = self let origin = self
.origin_store .origin_store
.get(config.origin_descr) .get(config.origin_descr)
// if there's a config there should be an origin, any error here is unexpected // if there's a config there should be an origin, any error here is unexpected
.map_err(|e| GetOriginError::Unexpected(Box::from(e)))?; .map_err(|e| GetOriginError::Unexpected(Box::from(e)))?;
Ok(origin) Ok(Box::from(origin))
} }
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> { fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> {
@ -192,20 +189,24 @@ where
fn sync_with_config( fn sync_with_config(
&self, &self,
domain: &domain::Name, domain: domain::Name,
config: &config::Config, config: config::Config,
) -> Result<(), SyncWithConfigError> { ) -> pin::Pin<Box<dyn Future<Output = Result<(), SyncWithConfigError>> + Send + '_>> {
let config_hash = config Box::pin(async move {
.hash() let config_hash = config
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?; .hash()
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?;
self.domain_checker.check_domain(&domain, &config_hash)?; self.domain_checker
.check_domain(&domain, &config_hash)
.await?;
self.origin_store self.origin_store
.sync(config.origin_descr.clone(), origin::store::Limits {})?; .sync(config.origin_descr.clone(), origin::store::Limits {})?;
self.domain_config_store.set(domain, config)?; self.domain_config_store.set(&domain, &config)?;
Ok(()) Ok(())
})
} }
} }

View File

@ -4,9 +4,11 @@ use signal_hook::consts::signal;
use signal_hook_tokio::Signals; use signal_hook_tokio::Signals;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path; use std::path;
use std::str::FromStr; use std::str::FromStr;
use std::sync;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version)] #[command(version)]
@ -74,22 +76,42 @@ fn main() {
.expect("domain config store initialized"); .expect("domain config store initialized");
let manager = domiply::domain::manager::new(origin_store, domain_config_store, domain_checker); let manager = domiply::domain::manager::new(origin_store, domain_config_store, domain_checker);
let manager = sync::Arc::new(manager);
let service = domiply::service::new( let service = domiply::service::new(
manager, manager,
config.domain_checker_target_cname, config.domain_checker_target_cname,
config.passphrase, config.passphrase,
) );
.expect("service initialized");
tokio_runtime.block_on(async { let service = sync::Arc::new(service);
let (addr, server) =
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async { let make_service =
stop_ch_rx.await.ok(); hyper::service::make_service_fn(move |_conn: &hyper::server::conn::AddrStream| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::handle_request(service.clone(), req)
}); });
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio_runtime.block_on(async {
let addr = config.http_listen_addr;
println!("Listening on {addr}"); println!("Listening on {addr}");
server.await; let server = hyper::Server::bind(&addr).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
stop_ch_rx.await.ok();
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
}
}); });
println!("Graceful shutdown complete"); println!("Graceful shutdown complete");

View File

@ -1,45 +1,68 @@
use http::status::StatusCode;
use hyper::{Body, Method, Request, Response};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::error::Error; use std::convert::Infallible;
use std::future::Future;
use std::sync; use std::sync;
use warp::Filter;
use crate::domain; use crate::domain;
pub mod http_tpl; pub mod http_tpl;
mod util; mod util;
type Handlebars<'a> = sync::Arc<handlebars::Handlebars<'a>>; type SvcResponse = Result<Response<String>, String>;
struct Renderer<'a, DM> #[derive(Clone)]
where pub struct Service<'svc> {
DM: domain::manager::Manager, domain_manager: sync::Arc<dyn domain::manager::Manager>,
{ target_cname: domain::Name,
domain_manager: sync::Arc<DM>, passphrase: String,
target_cname: sync::Arc<domain::Name>, handlebars: handlebars::Handlebars<'svc>,
passphrase: sync::Arc<String>, }
handlebars: Handlebars<'a>, pub fn new<'svc, 'mgr>(
query_args: HashMap<String, String>, domain_manager: sync::Arc<dyn domain::manager::Manager>,
target_cname: domain::Name,
passphrase: String,
) -> Service<'svc> {
Service {
domain_manager,
target_cname,
passphrase,
handlebars: self::http_tpl::get().expect("Retrieved Handlebars templates"),
}
} }
#[derive(Serialize)] #[derive(Serialize)]
struct BasePresenter<'a, T> { struct BasePresenter<'a, T> {
page_name: &'a str, page_name: &'a str,
query_args: &'a HashMap<String, String>, data: T,
data: &'a T,
} }
impl<'a, DM> Renderer<'a, DM> #[derive(Deserialize)]
where struct DomainGetArgs {
DM: domain::manager::Manager, domain: domain::Name,
{ }
// TODO make this use an io::Write, rather than warp::Reply
fn render<T>(&self, name: &'_ str, value: &'_ T) -> Box<dyn warp::Reply> #[derive(Deserialize)]
struct DomainInitArgs {
domain: domain::Name,
passphrase: String,
}
#[derive(Deserialize)]
struct DomainSyncArgs {
domain: domain::Name,
}
impl<'svc> Service<'svc> {
//// TODO make this use an io::Write, rather than SvcResponse
fn render<T>(&self, status_code: u16, name: &'_ str, value: T) -> SvcResponse
where where
T: Serialize, T: Serialize,
{ {
let rendered = match self.handlebars.render(name, value) { let rendered = match self.handlebars.render(name, &value) {
Ok(res) => res, Ok(res) => res,
Err(handlebars::RenderError { Err(handlebars::RenderError {
template_name: None, template_name: None,
@ -54,268 +77,215 @@ where
.first_or_octet_stream() .first_or_octet_stream()
.to_string(); .to_string();
let reply = warp::reply::html(rendered); match Response::builder()
.status(status_code)
Box::from(warp::reply::with_header( .header("Content-Type", content_type)
reply, .body(rendered)
"Content-Type", {
content_type, Ok(res) => Ok(res),
)) Err(err) => Err(format!("failed to build {}: {}", name, err)),
}
} }
fn render_error_page(&self, status_code: u16, e: &'_ str) -> Box<dyn warp::Reply> { fn render_error_page(&'svc self, status_code: u16, e: &'_ str) -> SvcResponse {
#[derive(Serialize)] #[derive(Serialize)]
struct Response<'a> { struct Response<'a> {
error_msg: &'a str, error_msg: &'a str,
} }
Box::from(warp::reply::with_status( self.render(
self.render( status_code,
"/base.html", "/base.html",
&BasePresenter { &BasePresenter {
page_name: "/error.html", page_name: "/error.html",
query_args: &HashMap::default(), data: &Response { error_msg: e },
data: &Response { error_msg: e }, },
}, )
),
status_code.try_into().unwrap(),
))
} }
fn render_page<T>(&self, name: &'_ str, data: &'_ T) -> Box<dyn warp::Reply> fn render_page<T>(&self, name: &'_ str, data: T) -> SvcResponse
where where
T: Serialize, T: Serialize,
{ {
self.render( self.render(
200,
"/base.html", "/base.html",
&BasePresenter { BasePresenter {
page_name: name, page_name: name,
query_args: &self.query_args,
data, data,
}, },
) )
} }
async fn with_query_req<'a, F, In, Out>(&self, req: &'a Request<Body>, f: F) -> SvcResponse
where
In: Deserialize<'a>,
F: FnOnce(In) -> Out,
Out: Future<Output = SvcResponse>,
{
let query = req.uri().query().unwrap_or("");
match serde_urlencoded::from_str::<In>(query) {
Ok(args) => f(args).await,
Err(err) => Err(format!("failed to parse query args: {}", err)),
}
}
fn domain_get(&self, args: DomainGetArgs) -> SvcResponse {
#[derive(Serialize)]
struct Response {
domain: domain::Name,
config: Option<domain::config::Config>,
}
match self.domain_manager.get_config(&args.domain) {
Ok(_config) => self.render_error_page(500, "TODO not yet implemented"),
Err(domain::manager::GetConfigError::NotFound) => self.render_page(
"/domain.html",
&Response {
domain: args.domain,
config: None,
},
),
Err(domain::manager::GetConfigError::Unexpected(e)) => {
self.render_error_page(500, format!("retrieving configuration: {}", e).as_str())
}
}
}
fn domain_init(&self, args: DomainInitArgs, domain_config: util::FlatConfig) -> SvcResponse {
if args.passphrase != self.passphrase.as_str() {
return self.render_error_page(401, "Incorrect passphrase");
}
#[derive(Serialize)]
struct Response {
domain: domain::Name,
flat_config: util::FlatConfig,
target_cname: domain::Name,
challenge_token: String,
}
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => return self.render_error_page(400, "domain config is required"),
Err(e) => {
return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let config_hash = match config.hash() {
Ok(hash) => hash,
Err(e) => {
return self
.render_error_page(500, format!("failed to hash domain config: {e}").as_str())
}
};
let target_cname = self.target_cname.clone();
return self.render_page(
"/domain_init.html",
&Response {
domain: args.domain,
flat_config: config.into(),
target_cname: target_cname,
challenge_token: config_hash,
},
);
}
async fn domain_sync(
&self,
args: DomainSyncArgs,
domain_config: util::FlatConfig,
) -> SvcResponse {
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => return self.render_error_page(400, "domain config is required"),
Err(e) => {
return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let sync_result = self
.domain_manager
.sync_with_config(args.domain.clone(), config)
.await;
#[derive(Serialize)]
struct Response {
domain: domain::Name,
error_msg: Option<String>,
}
let error_msg = match sync_result {
Ok(_) => None,
Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()),
Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()),
Err(domain::manager::SyncWithConfigError::AlreadyInProgress) => Some("The configuration of your domain is still in progress, please refresh in a few minutes.".to_string()),
Err(domain::manager::SyncWithConfigError::TargetCNAMENotSet) => Some("The CNAME record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::ChallengeTokenNotSet) => Some("The TXT record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")),
};
let response = Response {
domain: args.domain,
error_msg,
};
return self.render_page("/domain_sync.html", response);
}
} }
pub fn new<DM>( pub async fn handle_request<'svc>(
manager: DM, svc: sync::Arc<Service<'svc>>,
target_cname: domain::Name, req: Request<Body>,
passphrase: String, ) -> Result<Response<String>, Infallible> {
) -> Result< match handle_request_inner(svc, req).await {
impl warp::Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone + 'static, Ok(res) => Ok(res),
Box<dyn Error>, Err(err) => {
> let mut res = Response::new(format!("failed to serve request: {}", err));
where *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
DM: domain::manager::Manager + 'static, Ok(res)
{ }
let manager = sync::Arc::new(manager); }
let target_cname = sync::Arc::new(target_cname); }
let passphrase = sync::Arc::new(passphrase);
pub async fn handle_request_inner<'svc>(
let hbs = sync::Arc::new(self::http_tpl::get()?); svc: sync::Arc<Service<'svc>>,
let with_renderer = warp::any() req: Request<Body>,
.and(warp::query::<HashMap<String, String>>()) ) -> SvcResponse {
.map(move |query_args: HashMap<String, String>| Renderer { let method = req.method();
domain_manager: manager.clone(), let path = req.uri().path();
target_cname: target_cname.clone(),
passphrase: passphrase.clone(), if method == &Method::GET && path.starts_with("/static/") {
handlebars: hbs.clone(), return svc.render(200, path, ());
query_args, }
});
match (method, path) {
let static_dir = warp::get() (&Method::GET, "/") | (&Method::GET, "/index.html") => svc.render_page("/index.html", ()),
.and(with_renderer.clone()) (&Method::GET, "/domain.html") => {
.and(warp::path("static")) svc.with_query_req(&req, |args: DomainGetArgs| async { svc.domain_get(args) })
.and(warp::path::full()) .await
.map(|renderer: Renderer<'_, DM>, full: warp::path::FullPath| { }
renderer.render(full.as_str(), &()) (&Method::GET, "/domain_init.html") => {
}); svc.with_query_req(&req, |args: DomainInitArgs| async {
svc.with_query_req(&req, |config: util::FlatConfig| async {
let index = warp::get() svc.domain_init(args, config)
.and(with_renderer.clone()) })
.and(warp::path::end()) .await
.map(|renderer: Renderer<'_, DM>| renderer.render_page("/index.html", &())); })
.await
#[derive(Deserialize)] }
struct DomainGetNewRequest { (&Method::GET, "/domain_sync.html") => {
domain: domain::Name, svc.with_query_req(&req, |args: DomainSyncArgs| async {
svc.with_query_req(&req, |config: util::FlatConfig| async {
svc.domain_sync(args, config).await
})
.await
})
.await
}
_ => svc.render_error_page(404, "Page not found!"),
} }
#[derive(Serialize)]
struct DomainGetNewResponse {
domain: domain::Name,
config: Option<domain::config::Config>,
}
let domain = warp::get()
.and(with_renderer.clone())
.and(warp::path!("domain.html"))
.and(warp::query::<DomainGetNewRequest>())
.and(warp::query::<util::FlatConfig>())
.map(
|renderer: Renderer<'_, DM>,
req: DomainGetNewRequest,
domain_config: util::FlatConfig| {
match renderer.domain_manager.get_config(&req.domain) {
Ok(_config) => renderer.render_error_page(500, "TODO not yet implemented"),
Err(domain::manager::GetConfigError::NotFound) => {
let domain_config = match domain_config.try_into() {
Ok(domain_config) => domain_config,
Err(e) => {
return renderer.render_error_page(
400,
format!("parsing domain configuration: {}", e).as_str(),
)
}
};
renderer.render_page(
"/domain.html",
&DomainGetNewResponse {
domain: req.domain,
config: domain_config,
},
)
}
Err(domain::manager::GetConfigError::Unexpected(e)) => renderer
.render_error_page(
500,
format!("retrieving configuration: {}", e).as_str(),
),
}
},
);
#[derive(Deserialize)]
struct DomainInitRequest {
domain: domain::Name,
passphrase: String,
}
let domain_init = warp::get()
.and(with_renderer.clone())
.and(warp::path!("domain_init.html"))
.and(warp::query::<DomainInitRequest>())
.and(warp::query::<util::FlatConfig>())
.map(
|renderer: Renderer<'_, DM>,
req: DomainInitRequest,
domain_config: util::FlatConfig| {
if req.passphrase != renderer.passphrase.as_str() {
return renderer.render_error_page(401, "Incorrect passphrase");
}
#[derive(Serialize)]
struct Response {
domain: domain::Name,
flat_config: util::FlatConfig,
target_cname: domain::Name,
challenge_token: String,
}
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => {
return renderer.render_error_page(400, "domain config is required")
}
Err(e) => {
return renderer
.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let config_hash = match config.hash() {
Ok(hash) => hash,
Err(e) => {
return renderer.render_error_page(
500,
format!("failed to hash domain config: {e}").as_str(),
)
}
};
let target_cname = (*renderer.target_cname).clone();
return renderer.render_page(
"/domain_init.html",
&Response {
domain: req.domain,
flat_config: config.into(),
target_cname: target_cname,
challenge_token: config_hash,
},
);
},
);
#[derive(Deserialize)]
struct DomainSyncRequest {
domain: domain::Name,
}
let domain_sync = warp::get()
.and(with_renderer.clone())
.and(warp::path!("domain_sync.html"))
.and(warp::query::<DomainSyncRequest>())
.and(warp::query::<util::FlatConfig>())
.map(
|renderer: Renderer<'_, DM>,
req: DomainSyncRequest,
domain_config: util::FlatConfig| {
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => {
return renderer.render_error_page(400, "domain config is required")
}
Err(e) => {
return renderer
.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let sync_result = renderer
.domain_manager
.sync_with_config(&req.domain, &config);
#[derive(Serialize)]
struct Response {
domain: domain::Name,
flat_config: util::FlatConfig,
error_msg: Option<String>,
}
let mut response = Response {
domain: req.domain,
flat_config: config.into(),
error_msg: None,
};
response.error_msg = match sync_result
{
Ok(_) => None,
Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()),
Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()),
Err(domain::manager::SyncWithConfigError::AlreadyInProgress) => Some("The configuration of your domain is still in progress, please refresh in a few minutes.".to_string()),
Err(domain::manager::SyncWithConfigError::TargetCNAMENotSet) => Some("The CNAME record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::ChallengeTokenNotSet) => Some("The TXT record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")),
};
return renderer.render_page(
"/domain_sync.html",
&response,
)
},
);
let not_found = warp::any()
.and(with_renderer.clone())
.map(|renderer: Renderer<'_, DM>| renderer.render_error_page(404, "Page not found"));
Ok(static_dir
.or(index)
.or(domain)
.or(domain_init)
.or(domain_sync)
.or(not_found))
} }

View File

@ -5,7 +5,7 @@ use handlebars::{Handlebars, TemplateError};
#[prefix = "/"] #[prefix = "/"]
struct Dir; struct Dir;
pub fn get() -> Result<Handlebars<'static>, TemplateError> { pub fn get<'hbs>() -> Result<Handlebars<'hbs>, TemplateError> {
let mut reg = Handlebars::new(); let mut reg = Handlebars::new();
reg.register_embed_templates::<Dir>()?; reg.register_embed_templates::<Dir>()?;
Ok(reg) Ok(reg)