Switched to hyper, cleaned up manager lifetimes, got domain_sync working

This commit is contained in:
Brian Picciano 2023-05-15 17:42:32 +02:00
parent 26ebda90e8
commit e3c13123db
8 changed files with 328 additions and 521 deletions

190
Cargo.lock generated
View File

@ -107,12 +107,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.21.0" version = "0.21.0"
@ -167,12 +161,6 @@ version = "3.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b1ce199063694f33ffb7dd4e0ee620741495c32833cde5aa08f02a0bf96f0c8" checksum = "9b1ce199063694f33ffb7dd4e0ee620741495c32833cde5aa08f02a0bf96f0c8"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.4.0" version = "1.4.0"
@ -377,11 +365,14 @@ dependencies = [
"gix", "gix",
"handlebars", "handlebars",
"hex", "hex",
"http",
"hyper",
"mime_guess", "mime_guess",
"mockall", "mockall",
"rust-embed", "rust-embed",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded",
"sha2", "sha2",
"signal-hook", "signal-hook",
"signal-hook-tokio", "signal-hook-tokio",
@ -389,7 +380,6 @@ dependencies = [
"thiserror", "thiserror",
"tokio", "tokio",
"trust-dns-client", "trust-dns-client",
"warp",
] ]
[[package]] [[package]]
@ -1156,7 +1146,7 @@ version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f01c2bf7b989c679695ef635fc7d9e80072e08101be4b53193c8e8b649900102" checksum = "f01c2bf7b989c679695ef635fc7d9e80072e08101be4b53193c8e8b649900102"
dependencies = [ dependencies = [
"base64 0.21.0", "base64",
"bstr", "bstr",
"gix-command", "gix-command",
"gix-credentials", "gix-credentials",
@ -1281,31 +1271,6 @@ version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
[[package]]
name = "headers"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584"
dependencies = [
"base64 0.13.1",
"bitflags 1.3.2",
"bytes",
"headers-core",
"http",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
dependencies = [
"http",
]
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.4.1" version = "0.4.1"
@ -1743,24 +1708,6 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "multer"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http",
"httparse",
"log",
"memchr",
"mime",
"spin 0.9.8",
"version_check",
]
[[package]] [[package]]
name = "nibble_vec" name = "nibble_vec"
version = "0.1.0" version = "0.1.0"
@ -1893,26 +1840,6 @@ dependencies = [
"sha2", "sha2",
] ]
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.9" version = "0.2.9"
@ -2150,7 +2077,7 @@ version = "0.11.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91" checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91"
dependencies = [ dependencies = [
"base64 0.21.0", "base64",
"bytes", "bytes",
"encoding_rs", "encoding_rs",
"futures-core", "futures-core",
@ -2203,7 +2130,7 @@ dependencies = [
"cc", "cc",
"libc", "libc",
"once_cell", "once_cell",
"spin 0.5.2", "spin",
"untrusted", "untrusted",
"web-sys", "web-sys",
"winapi", "winapi",
@ -2275,7 +2202,7 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
dependencies = [ dependencies = [
"base64 0.21.0", "base64",
] ]
[[package]] [[package]]
@ -2293,12 +2220,6 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.1.0" version = "1.1.0"
@ -2358,17 +2279,6 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "sha1_smol" name = "sha1_smol"
version = "1.0.0" version = "1.0.0"
@ -2448,12 +2358,6 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]] [[package]]
name = "static_assertions" name = "static_assertions"
version = "1.1.0" version = "1.1.0"
@ -2622,29 +2526,6 @@ dependencies = [
"webpki", "webpki",
] ]
[[package]]
name = "tokio-stream"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.8" version = "0.7.8"
@ -2672,7 +2553,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"log",
"pin-project-lite", "pin-project-lite",
"tracing-attributes", "tracing-attributes",
"tracing-core", "tracing-core",
@ -2769,25 +2649,6 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
dependencies = [
"base64 0.13.1",
"byteorder",
"bytes",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.16.0" version = "1.16.0"
@ -2862,12 +2723,6 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "utf8parse" name = "utf8parse"
version = "0.2.1" version = "0.2.1"
@ -2900,37 +2755,6 @@ dependencies = [
"try-lock", "try-lock",
] ]
[[package]]
name = "warp"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba431ef570df1287f7f8b07e376491ad54f84d26ac473489427231e1718e1f69"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"headers",
"http",
"hyper",
"log",
"mime",
"mime_guess",
"multer",
"percent-encoding",
"pin-project",
"rustls-pemfile",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-stream",
"tokio-tungstenite",
"tokio-util",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"

View File

@ -21,7 +21,6 @@ trust-dns-client = "0.22.0"
mockall = "0.11.4" mockall = "0.11.4"
thiserror = "1.0.40" thiserror = "1.0.40"
tokio = { version = "1.28.1", features = [ "full" ]} tokio = { version = "1.28.1", features = [ "full" ]}
warp = "0.3.5"
signal-hook = "0.3.15" signal-hook = "0.3.15"
futures = "0.3.28" futures = "0.3.28"
signal-hook-tokio = { version = "0.3.1", features = [ "futures-v0_3" ]} signal-hook-tokio = { version = "0.3.1", features = [ "futures-v0_3" ]}
@ -29,3 +28,6 @@ clap = { version = "4.2.7", features = ["derive", "env"] }
handlebars = { version = "4.3.7", features = [ "rust-embed" ]} handlebars = { version = "4.3.7", features = [ "rust-embed" ]}
rust-embed = "6.6.1" rust-embed = "6.6.1"
mime_guess = "2.0.4" mime_guess = "2.0.4"
hyper = { version = "0.14.26", features = [ "server" ]}
http = "0.2.9"
serde_urlencoded = "0.7.1"

View File

@ -4,7 +4,6 @@ use std::sync;
use crate::domain; use crate::domain;
use mockall::automock;
use trust_dns_client::client::{AsyncClient, ClientHandle}; use trust_dns_client::client::{AsyncClient, ClientHandle};
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
use trust_dns_client::udp; use trust_dns_client::udp;
@ -33,17 +32,7 @@ pub enum CheckDomainError {
Unexpected(Box<dyn Error>), Unexpected(Box<dyn Error>),
} }
#[automock]
pub trait Checker: std::marker::Send + std::marker::Sync {
fn check_domain(
&self,
domain: &domain::Name,
challenge_token: &str,
) -> Result<(), CheckDomainError>;
}
pub struct DNSChecker { pub struct DNSChecker {
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: Name, target_cname: Name,
// TODO we should use some kind of connection pool here, I suppose // TODO we should use some kind of connection pool here, I suppose
@ -54,7 +43,7 @@ pub fn new(
tokio_runtime: sync::Arc<tokio::runtime::Runtime>, tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: domain::Name, target_cname: domain::Name,
resolver_addr: &str, resolver_addr: &str,
) -> Result<impl Checker, NewDNSCheckerError> { ) -> Result<DNSChecker, NewDNSCheckerError> {
let resolver_addr = resolver_addr let resolver_addr = resolver_addr
.parse() .parse()
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
@ -68,14 +57,13 @@ pub fn new(
tokio_runtime.spawn(bg); tokio_runtime.spawn(bg);
Ok(DNSChecker { Ok(DNSChecker {
tokio_runtime,
target_cname: target_cname.inner, target_cname: target_cname.inner,
client: tokio::sync::Mutex::new(client), client: tokio::sync::Mutex::new(client),
}) })
} }
impl Checker for DNSChecker { impl DNSChecker {
fn check_domain( pub async fn check_domain(
&self, &self,
domain: &domain::Name, domain: &domain::Name,
challenge_token: &str, challenge_token: &str,
@ -84,13 +72,13 @@ impl Checker for DNSChecker {
// check that the CNAME is installed correctly on the domain // check that the CNAME is installed correctly on the domain
{ {
let response = match self.tokio_runtime.block_on(async { let response = match self
self.client .client
.lock() .lock()
.await .await
.query(domain.clone(), DNSClass::IN, RecordType::CNAME) .query(domain.clone(), DNSClass::IN, RecordType::CNAME)
.await .await
}) { {
Ok(res) => res, Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
}; };
@ -116,13 +104,13 @@ impl Checker for DNSChecker {
.append_domain(&domain) .append_domain(&domain)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = match self.tokio_runtime.block_on(async { let response = match self
self.client .client
.lock() .lock()
.await .await
.query(domain, DNSClass::IN, RecordType::TXT) .query(domain, DNSClass::IN, RecordType::TXT)
.await .await
}) { {
Ok(res) => res, Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
}; };

View File

@ -9,7 +9,7 @@ use hex::ToHex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
/// Values which the owner of a domain can configure when they install a domain. /// Values which the owner of a domain can configure when they install a domain.
pub struct Config { pub struct Config {
pub origin_descr: Descr, pub origin_descr: Descr,

View File

@ -1,6 +1,8 @@
use crate::domain::{self, checker, config}; use crate::domain::{self, checker, config};
use crate::origin; use crate::origin;
use std::error::Error; use std::error::Error;
use std::future::Future;
use std::pin;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum GetConfigError { pub enum GetConfigError {
@ -111,31 +113,29 @@ impl From<config::SetError> for SyncWithConfigError {
} }
} }
#[mockall::automock(type Origin=origin::MockOrigin;)] //#[mockall::automock(type Origin=origin::MockOrigin;)]
pub trait Manager: std::marker::Send + std::marker::Sync { pub trait Manager: Send + Sync {
type Origin<'mgr>: origin::Origin + 'mgr
where
Self: 'mgr;
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>; fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>;
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError>; fn get_origin(
&self,
domain: &domain::Name,
) -> Result<Box<dyn origin::Origin + '_>, GetOriginError>;
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>; fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>;
fn sync_with_config( fn sync_with_config(
&self, &self,
domain: &domain::Name, domain: domain::Name,
config: &config::Config, config: config::Config,
) -> Result<(), SyncWithConfigError>; ) -> pin::Pin<Box<dyn Future<Output = Result<(), SyncWithConfigError>> + Send + '_>>;
} }
pub fn new<OriginStore, DomainConfigStore, DomainChecker>( pub fn new<OriginStore, DomainConfigStore>(
origin_store: OriginStore, origin_store: OriginStore,
domain_config_store: DomainConfigStore, domain_config_store: DomainConfigStore,
domain_checker: DomainChecker, domain_checker: checker::DNSChecker,
) -> impl Manager ) -> impl Manager
where where
OriginStore: origin::store::Store, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker,
{ {
ManagerImpl { ManagerImpl {
origin_store, origin_store,
@ -144,39 +144,36 @@ where
} }
} }
struct ManagerImpl<OriginStore, DomainConfigStore, DomainChecker> struct ManagerImpl<OriginStore, DomainConfigStore>
where where
OriginStore: origin::store::Store, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker,
{ {
origin_store: OriginStore, origin_store: OriginStore,
domain_config_store: DomainConfigStore, domain_config_store: DomainConfigStore,
domain_checker: DomainChecker, domain_checker: checker::DNSChecker,
} }
impl<OriginStore, DomainConfigStore, DomainChecker> Manager impl<OriginStore, DomainConfigStore> Manager for ManagerImpl<OriginStore, DomainConfigStore>
for ManagerImpl<OriginStore, DomainConfigStore, DomainChecker>
where where
OriginStore: origin::store::Store, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker,
{ {
type Origin<'mgr> = OriginStore::Origin<'mgr>
where Self: 'mgr;
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> { fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> {
Ok(self.domain_config_store.get(domain)?) Ok(self.domain_config_store.get(domain)?)
} }
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError> { fn get_origin(
&self,
domain: &domain::Name,
) -> Result<Box<dyn origin::Origin + '_>, GetOriginError> {
let config = self.domain_config_store.get(domain)?; let config = self.domain_config_store.get(domain)?;
let origin = self let origin = self
.origin_store .origin_store
.get(config.origin_descr) .get(config.origin_descr)
// if there's a config there should be an origin, any error here is unexpected // if there's a config there should be an origin, any error here is unexpected
.map_err(|e| GetOriginError::Unexpected(Box::from(e)))?; .map_err(|e| GetOriginError::Unexpected(Box::from(e)))?;
Ok(origin) Ok(Box::from(origin))
} }
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> { fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> {
@ -192,20 +189,24 @@ where
fn sync_with_config( fn sync_with_config(
&self, &self,
domain: &domain::Name, domain: domain::Name,
config: &config::Config, config: config::Config,
) -> Result<(), SyncWithConfigError> { ) -> pin::Pin<Box<dyn Future<Output = Result<(), SyncWithConfigError>> + Send + '_>> {
Box::pin(async move {
let config_hash = config let config_hash = config
.hash() .hash()
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?; .map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?;
self.domain_checker.check_domain(&domain, &config_hash)?; self.domain_checker
.check_domain(&domain, &config_hash)
.await?;
self.origin_store self.origin_store
.sync(config.origin_descr.clone(), origin::store::Limits {})?; .sync(config.origin_descr.clone(), origin::store::Limits {})?;
self.domain_config_store.set(domain, config)?; self.domain_config_store.set(&domain, &config)?;
Ok(()) Ok(())
})
} }
} }

View File

@ -4,9 +4,11 @@ use signal_hook::consts::signal;
use signal_hook_tokio::Signals; use signal_hook_tokio::Signals;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path; use std::path;
use std::str::FromStr; use std::str::FromStr;
use std::sync;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version)] #[command(version)]
@ -74,22 +76,42 @@ fn main() {
.expect("domain config store initialized"); .expect("domain config store initialized");
let manager = domiply::domain::manager::new(origin_store, domain_config_store, domain_checker); let manager = domiply::domain::manager::new(origin_store, domain_config_store, domain_checker);
let manager = sync::Arc::new(manager);
let service = domiply::service::new( let service = domiply::service::new(
manager, manager,
config.domain_checker_target_cname, config.domain_checker_target_cname,
config.passphrase, config.passphrase,
) );
.expect("service initialized");
let service = sync::Arc::new(service);
let make_service =
hyper::service::make_service_fn(move |_conn: &hyper::server::conn::AddrStream| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio_runtime.block_on(async { tokio_runtime.block_on(async {
let (addr, server) = let addr = config.http_listen_addr;
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
println!("Listening on {addr}");
let server = hyper::Server::bind(&addr).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
stop_ch_rx.await.ok(); stop_ch_rx.await.ok();
}); });
println!("Listening on {addr}"); if let Err(e) = graceful.await {
server.await; panic!("server error: {}", e);
}
}); });
println!("Graceful shutdown complete"); println!("Graceful shutdown complete");

View File

@ -1,45 +1,68 @@
use http::status::StatusCode;
use hyper::{Body, Method, Request, Response};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::error::Error; use std::convert::Infallible;
use std::future::Future;
use std::sync; use std::sync;
use warp::Filter;
use crate::domain; use crate::domain;
pub mod http_tpl; pub mod http_tpl;
mod util; mod util;
type Handlebars<'a> = sync::Arc<handlebars::Handlebars<'a>>; type SvcResponse = Result<Response<String>, String>;
struct Renderer<'a, DM> #[derive(Clone)]
where pub struct Service<'svc> {
DM: domain::manager::Manager, domain_manager: sync::Arc<dyn domain::manager::Manager>,
{ target_cname: domain::Name,
domain_manager: sync::Arc<DM>, passphrase: String,
target_cname: sync::Arc<domain::Name>, handlebars: handlebars::Handlebars<'svc>,
passphrase: sync::Arc<String>, }
handlebars: Handlebars<'a>, pub fn new<'svc, 'mgr>(
query_args: HashMap<String, String>, domain_manager: sync::Arc<dyn domain::manager::Manager>,
target_cname: domain::Name,
passphrase: String,
) -> Service<'svc> {
Service {
domain_manager,
target_cname,
passphrase,
handlebars: self::http_tpl::get().expect("Retrieved Handlebars templates"),
}
} }
#[derive(Serialize)] #[derive(Serialize)]
struct BasePresenter<'a, T> { struct BasePresenter<'a, T> {
page_name: &'a str, page_name: &'a str,
query_args: &'a HashMap<String, String>, data: T,
data: &'a T,
} }
impl<'a, DM> Renderer<'a, DM> #[derive(Deserialize)]
where struct DomainGetArgs {
DM: domain::manager::Manager, domain: domain::Name,
{ }
// TODO make this use an io::Write, rather than warp::Reply
fn render<T>(&self, name: &'_ str, value: &'_ T) -> Box<dyn warp::Reply> #[derive(Deserialize)]
struct DomainInitArgs {
domain: domain::Name,
passphrase: String,
}
#[derive(Deserialize)]
struct DomainSyncArgs {
domain: domain::Name,
}
impl<'svc> Service<'svc> {
//// TODO make this use an io::Write, rather than SvcResponse
fn render<T>(&self, status_code: u16, name: &'_ str, value: T) -> SvcResponse
where where
T: Serialize, T: Serialize,
{ {
let rendered = match self.handlebars.render(name, value) { let rendered = match self.handlebars.render(name, &value) {
Ok(res) => res, Ok(res) => res,
Err(handlebars::RenderError { Err(handlebars::RenderError {
template_name: None, template_name: None,
@ -54,155 +77,84 @@ where
.first_or_octet_stream() .first_or_octet_stream()
.to_string(); .to_string();
let reply = warp::reply::html(rendered); match Response::builder()
.status(status_code)
Box::from(warp::reply::with_header( .header("Content-Type", content_type)
reply, .body(rendered)
"Content-Type", {
content_type, Ok(res) => Ok(res),
)) Err(err) => Err(format!("failed to build {}: {}", name, err)),
}
} }
fn render_error_page(&self, status_code: u16, e: &'_ str) -> Box<dyn warp::Reply> { fn render_error_page(&'svc self, status_code: u16, e: &'_ str) -> SvcResponse {
#[derive(Serialize)] #[derive(Serialize)]
struct Response<'a> { struct Response<'a> {
error_msg: &'a str, error_msg: &'a str,
} }
Box::from(warp::reply::with_status(
self.render( self.render(
status_code,
"/base.html", "/base.html",
&BasePresenter { &BasePresenter {
page_name: "/error.html", page_name: "/error.html",
query_args: &HashMap::default(),
data: &Response { error_msg: e }, data: &Response { error_msg: e },
}, },
), )
status_code.try_into().unwrap(),
))
} }
fn render_page<T>(&self, name: &'_ str, data: &'_ T) -> Box<dyn warp::Reply> fn render_page<T>(&self, name: &'_ str, data: T) -> SvcResponse
where where
T: Serialize, T: Serialize,
{ {
self.render( self.render(
200,
"/base.html", "/base.html",
&BasePresenter { BasePresenter {
page_name: name, page_name: name,
query_args: &self.query_args,
data, data,
}, },
) )
} }
}
pub fn new<DM>( async fn with_query_req<'a, F, In, Out>(&self, req: &'a Request<Body>, f: F) -> SvcResponse
manager: DM, where
target_cname: domain::Name, In: Deserialize<'a>,
passphrase: String, F: FnOnce(In) -> Out,
) -> Result< Out: Future<Output = SvcResponse>,
impl warp::Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone + 'static, {
Box<dyn Error>, let query = req.uri().query().unwrap_or("");
> match serde_urlencoded::from_str::<In>(query) {
where Ok(args) => f(args).await,
DM: domain::manager::Manager + 'static, Err(err) => Err(format!("failed to parse query args: {}", err)),
{ }
let manager = sync::Arc::new(manager);
let target_cname = sync::Arc::new(target_cname);
let passphrase = sync::Arc::new(passphrase);
let hbs = sync::Arc::new(self::http_tpl::get()?);
let with_renderer = warp::any()
.and(warp::query::<HashMap<String, String>>())
.map(move |query_args: HashMap<String, String>| Renderer {
domain_manager: manager.clone(),
target_cname: target_cname.clone(),
passphrase: passphrase.clone(),
handlebars: hbs.clone(),
query_args,
});
let static_dir = warp::get()
.and(with_renderer.clone())
.and(warp::path("static"))
.and(warp::path::full())
.map(|renderer: Renderer<'_, DM>, full: warp::path::FullPath| {
renderer.render(full.as_str(), &())
});
let index = warp::get()
.and(with_renderer.clone())
.and(warp::path::end())
.map(|renderer: Renderer<'_, DM>| renderer.render_page("/index.html", &()));
#[derive(Deserialize)]
struct DomainGetNewRequest {
domain: domain::Name,
} }
fn domain_get(&self, args: DomainGetArgs) -> SvcResponse {
#[derive(Serialize)] #[derive(Serialize)]
struct DomainGetNewResponse { struct Response {
domain: domain::Name, domain: domain::Name,
config: Option<domain::config::Config>, config: Option<domain::config::Config>,
} }
let domain = warp::get() match self.domain_manager.get_config(&args.domain) {
.and(with_renderer.clone()) Ok(_config) => self.render_error_page(500, "TODO not yet implemented"),
.and(warp::path!("domain.html")) Err(domain::manager::GetConfigError::NotFound) => self.render_page(
.and(warp::query::<DomainGetNewRequest>())
.and(warp::query::<util::FlatConfig>())
.map(
|renderer: Renderer<'_, DM>,
req: DomainGetNewRequest,
domain_config: util::FlatConfig| {
match renderer.domain_manager.get_config(&req.domain) {
Ok(_config) => renderer.render_error_page(500, "TODO not yet implemented"),
Err(domain::manager::GetConfigError::NotFound) => {
let domain_config = match domain_config.try_into() {
Ok(domain_config) => domain_config,
Err(e) => {
return renderer.render_error_page(
400,
format!("parsing domain configuration: {}", e).as_str(),
)
}
};
renderer.render_page(
"/domain.html", "/domain.html",
&DomainGetNewResponse { &Response {
domain: req.domain, domain: args.domain,
config: domain_config, config: None,
}, },
)
}
Err(domain::manager::GetConfigError::Unexpected(e)) => renderer
.render_error_page(
500,
format!("retrieving configuration: {}", e).as_str(),
), ),
Err(domain::manager::GetConfigError::Unexpected(e)) => {
self.render_error_page(500, format!("retrieving configuration: {}", e).as_str())
}
} }
},
);
#[derive(Deserialize)]
struct DomainInitRequest {
domain: domain::Name,
passphrase: String,
} }
let domain_init = warp::get() fn domain_init(&self, args: DomainInitArgs, domain_config: util::FlatConfig) -> SvcResponse {
.and(with_renderer.clone()) if args.passphrase != self.passphrase.as_str() {
.and(warp::path!("domain_init.html")) return self.render_error_page(401, "Incorrect passphrase");
.and(warp::query::<DomainInitRequest>())
.and(warp::query::<util::FlatConfig>())
.map(
|renderer: Renderer<'_, DM>,
req: DomainInitRequest,
domain_config: util::FlatConfig| {
if req.passphrase != renderer.passphrase.as_str() {
return renderer.render_error_page(401, "Incorrect passphrase");
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -215,83 +167,58 @@ where
let config: domain::config::Config = match domain_config.try_into() { let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config, Ok(Some(config)) => config,
Ok(None) => { Ok(None) => return self.render_error_page(400, "domain config is required"),
return renderer.render_error_page(400, "domain config is required")
}
Err(e) => { Err(e) => {
return renderer return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
.render_error_page(400, format!("invalid domain config: {e}").as_str())
} }
}; };
let config_hash = match config.hash() { let config_hash = match config.hash() {
Ok(hash) => hash, Ok(hash) => hash,
Err(e) => { Err(e) => {
return renderer.render_error_page( return self
500, .render_error_page(500, format!("failed to hash domain config: {e}").as_str())
format!("failed to hash domain config: {e}").as_str(),
)
} }
}; };
let target_cname = (*renderer.target_cname).clone(); let target_cname = self.target_cname.clone();
return renderer.render_page( return self.render_page(
"/domain_init.html", "/domain_init.html",
&Response { &Response {
domain: req.domain, domain: args.domain,
flat_config: config.into(), flat_config: config.into(),
target_cname: target_cname, target_cname: target_cname,
challenge_token: config_hash, challenge_token: config_hash,
}, },
); );
},
);
#[derive(Deserialize)]
struct DomainSyncRequest {
domain: domain::Name,
} }
let domain_sync = warp::get() async fn domain_sync(
.and(with_renderer.clone()) &self,
.and(warp::path!("domain_sync.html")) args: DomainSyncArgs,
.and(warp::query::<DomainSyncRequest>()) domain_config: util::FlatConfig,
.and(warp::query::<util::FlatConfig>()) ) -> SvcResponse {
.map(
|renderer: Renderer<'_, DM>,
req: DomainSyncRequest,
domain_config: util::FlatConfig| {
let config: domain::config::Config = match domain_config.try_into() { let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config, Ok(Some(config)) => config,
Ok(None) => { Ok(None) => return self.render_error_page(400, "domain config is required"),
return renderer.render_error_page(400, "domain config is required")
}
Err(e) => { Err(e) => {
return renderer return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
.render_error_page(400, format!("invalid domain config: {e}").as_str())
} }
}; };
let sync_result = renderer let sync_result = self
.domain_manager .domain_manager
.sync_with_config(&req.domain, &config); .sync_with_config(args.domain.clone(), config)
.await;
#[derive(Serialize)] #[derive(Serialize)]
struct Response { struct Response {
domain: domain::Name, domain: domain::Name,
flat_config: util::FlatConfig,
error_msg: Option<String>, error_msg: Option<String>,
} }
let mut response = Response { let error_msg = match sync_result {
domain: req.domain,
flat_config: config.into(),
error_msg: None,
};
response.error_msg = match sync_result
{
Ok(_) => None, Ok(_) => None,
Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()), Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()),
Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()), Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()),
@ -301,21 +228,64 @@ where
Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")), Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")),
}; };
return renderer.render_page( let response = Response {
"/domain_sync.html", domain: args.domain,
&response, error_msg,
) };
},
);
let not_found = warp::any() return self.render_page("/domain_sync.html", response);
.and(with_renderer.clone()) }
.map(|renderer: Renderer<'_, DM>| renderer.render_error_page(404, "Page not found")); }
Ok(static_dir pub async fn handle_request<'svc>(
.or(index) svc: sync::Arc<Service<'svc>>,
.or(domain) req: Request<Body>,
.or(domain_init) ) -> Result<Response<String>, Infallible> {
.or(domain_sync) match handle_request_inner(svc, req).await {
.or(not_found)) Ok(res) => Ok(res),
Err(err) => {
let mut res = Response::new(format!("failed to serve request: {}", err));
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(res)
}
}
}
pub async fn handle_request_inner<'svc>(
svc: sync::Arc<Service<'svc>>,
req: Request<Body>,
) -> SvcResponse {
let method = req.method();
let path = req.uri().path();
if method == &Method::GET && path.starts_with("/static/") {
return svc.render(200, path, ());
}
match (method, path) {
(&Method::GET, "/") | (&Method::GET, "/index.html") => svc.render_page("/index.html", ()),
(&Method::GET, "/domain.html") => {
svc.with_query_req(&req, |args: DomainGetArgs| async { svc.domain_get(args) })
.await
}
(&Method::GET, "/domain_init.html") => {
svc.with_query_req(&req, |args: DomainInitArgs| async {
svc.with_query_req(&req, |config: util::FlatConfig| async {
svc.domain_init(args, config)
})
.await
})
.await
}
(&Method::GET, "/domain_sync.html") => {
svc.with_query_req(&req, |args: DomainSyncArgs| async {
svc.with_query_req(&req, |config: util::FlatConfig| async {
svc.domain_sync(args, config).await
})
.await
})
.await
}
_ => svc.render_error_page(404, "Page not found!"),
}
} }

View File

@ -5,7 +5,7 @@ use handlebars::{Handlebars, TemplateError};
#[prefix = "/"] #[prefix = "/"]
struct Dir; struct Dir;
pub fn get() -> Result<Handlebars<'static>, TemplateError> { pub fn get<'hbs>() -> Result<Handlebars<'hbs>, TemplateError> {
let mut reg = Handlebars::new(); let mut reg = Handlebars::new();
reg.register_embed_templates::<Dir>()?; reg.register_embed_templates::<Dir>()?;
Ok(reg) Ok(reg)