Compare commits

...

11 Commits

Author SHA1 Message Date
Brian Picciano
f2374cded5 Use TaskSet to cleanly shut down the http service 2023-06-19 20:57:26 +02:00
Brian Picciano
43f4b98b38 Move handle_request onto service as a method 2023-06-19 20:12:15 +02:00
Brian Picciano
506037dcd0 Moved service tasks internally, main crashes on shutdown though 2023-06-18 15:57:51 +02:00
Brian Picciano
7ea97b2617 Get rid of lifetime on Service 2023-06-18 15:10:06 +02:00
Brian Picciano
1f9ae0038f restructure service module 2023-06-18 14:53:25 +02:00
Brian Picciano
dbc912a9d3 add sync_cert method to domain manager 2023-06-18 14:46:52 +02:00
Brian Picciano
6da68dc042 stop domain manager's inner tasks on shutdown 2023-06-18 14:28:46 +02:00
Brian Picciano
6941ceec8e Remove final Boxed types 2023-06-18 13:53:02 +02:00
Brian Picciano
3d3dfb34ed Got rid of Boxed acme types 2023-06-18 13:44:19 +02:00
Brian Picciano
52f87dc625 Get rid of origin::store::BoxedManager and domain::manager::BoxedManager 2023-06-18 13:12:26 +02:00
Brian Picciano
4317d7f282 Simplify git origin a bit 2023-06-17 16:04:26 +02:00
24 changed files with 848 additions and 887 deletions

4
TODO
View File

@ -1 +1,3 @@
- clean up main a lot - make domain_manager implement rusttls cert resolver
- Try to switch from Arc to Box where possible
- maybe build TaskSet into some kind of defer-like replacement

View File

@ -1,4 +1,5 @@
pub mod manager; pub mod manager;
pub mod resolver;
pub mod store; pub mod store;
mod private_key; mod private_key;

View File

@ -7,39 +7,24 @@ const LETS_ENCRYPT_URL: &str = "https://acme-v02.api.letsencrypt.org/directory";
pub type GetHttp01ChallengeKeyError = acme::store::GetHttp01ChallengeKeyError; pub type GetHttp01ChallengeKeyError = acme::store::GetHttp01ChallengeKeyError;
#[mockall::automock( #[mockall::automock]
type SyncDomainFuture=future::Ready<Result<(), unexpected::Error>>; pub trait Manager: Sync + Send {
)] fn sync_domain<'mgr>(
pub trait Manager { &'mgr self,
type SyncDomainFuture<'mgr>: future::Future<Output = Result<(), unexpected::Error>> domain: domain::Name,
+ Send ) -> pin::Pin<Box<dyn future::Future<Output = Result<(), unexpected::Error>> + Send + 'mgr>>;
+ Unpin
+ 'mgr
where
Self: 'mgr;
fn sync_domain(&self, domain: domain::Name) -> Self::SyncDomainFuture<'_>;
fn get_http01_challenge_key(&self, token: &str) -> Result<String, GetHttp01ChallengeKeyError>; fn get_http01_challenge_key(&self, token: &str) -> Result<String, GetHttp01ChallengeKeyError>;
} }
pub trait BoxedManager: Manager + Send + Sync + Clone + 'static {} struct ManagerImpl {
store: sync::Arc<dyn acme::store::Store>,
struct ManagerImpl<Store>
where
Store: acme::store::BoxedStore,
{
store: Store,
account: sync::Arc<acme2::Account>, account: sync::Arc<acme2::Account>,
} }
pub async fn new<Store>( pub async fn new(
store: Store, store: sync::Arc<dyn acme::store::Store>,
contact_email: &str, contact_email: &str,
) -> Result<impl BoxedManager, unexpected::Error> ) -> Result<sync::Arc<dyn Manager>, unexpected::Error> {
where
Store: acme::store::BoxedStore,
{
let dir = acme2::DirectoryBuilder::new(LETS_ENCRYPT_URL.to_string()) let dir = acme2::DirectoryBuilder::new(LETS_ENCRYPT_URL.to_string())
.build() .build()
.await .await
@ -81,16 +66,12 @@ where
Ok(sync::Arc::new(ManagerImpl { store, account })) Ok(sync::Arc::new(ManagerImpl { store, account }))
} }
impl<Store> BoxedManager for sync::Arc<ManagerImpl<Store>> where Store: acme::store::BoxedStore {} impl Manager for ManagerImpl {
fn sync_domain<'mgr>(
impl<Store> Manager for sync::Arc<ManagerImpl<Store>> &'mgr self,
where domain: domain::Name,
Store: acme::store::BoxedStore, ) -> pin::Pin<Box<dyn future::Future<Output = Result<(), unexpected::Error>> + Send + 'mgr>>
{ {
type SyncDomainFuture<'mgr> = pin::Pin<Box<dyn future::Future<Output = Result<(), unexpected::Error>> + Send + 'mgr>>
where Self: 'mgr;
fn sync_domain(&self, domain: domain::Name) -> Self::SyncDomainFuture<'_> {
Box::pin(async move { Box::pin(async move {
// if there's an existing cert, and its expiry (determined by the soonest value of // if there's an existing cert, and its expiry (determined by the soonest value of
// not_after amongst its parts) is later than 30 days from now, then we consider it to be // not_after amongst its parts) is later than 30 days from now, then we consider it to be

View File

@ -0,0 +1,44 @@
use crate::domain::acme::store;
use crate::error::unexpected::Mappable;
use std::sync;
struct CertResolver(sync::Arc<dyn store::Store>);
pub fn new(
store: sync::Arc<dyn store::Store>,
) -> sync::Arc<dyn rustls::server::ResolvesServerCert> {
return sync::Arc::new(CertResolver(store));
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<sync::Arc<rustls::sign::CertifiedKey>> {
let domain = client_hello.server_name()?;
match self.0.get_certificate(domain) {
Err(store::GetCertificateError::NotFound) => {
log::warn!("No cert found for domain {domain}");
Ok(None)
}
Err(store::GetCertificateError::Unexpected(err)) => Err(err),
Ok((key, cert)) => {
match rustls::sign::any_supported_type(&key.into()).or_unexpected() {
Err(err) => Err(err),
Ok(key) => Ok(Some(sync::Arc::new(rustls::sign::CertifiedKey {
cert: cert.into_iter().map(|cert| cert.into()).collect(),
key,
ocsp: None,
sct_list: None,
}))),
}
}
}
.unwrap_or_else(|err| {
log::error!("Unexpected error getting cert for domain {domain}: {err}");
None
})
}
}

View File

@ -38,7 +38,7 @@ pub enum GetCertificateError {
} }
#[mockall::automock] #[mockall::automock]
pub trait Store { pub trait Store: Sync + Send {
fn set_account_key(&self, k: &PrivateKey) -> Result<(), unexpected::Error>; fn set_account_key(&self, k: &PrivateKey) -> Result<(), unexpected::Error>;
fn get_account_key(&self) -> Result<PrivateKey, GetAccountKeyError>; fn get_account_key(&self) -> Result<PrivateKey, GetAccountKeyError>;
@ -60,8 +60,6 @@ pub trait Store {
) -> Result<(PrivateKey, Vec<Certificate>), GetCertificateError>; ) -> Result<(PrivateKey, Vec<Certificate>), GetCertificateError>;
} }
pub trait BoxedStore: Store + Send + Sync + Clone + 'static {}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct StoredPKeyCert { struct StoredPKeyCert {
private_key: PrivateKey, private_key: PrivateKey,
@ -72,12 +70,7 @@ struct FSStore {
dir_path: path::PathBuf, dir_path: path::PathBuf,
} }
#[derive(Clone)] pub fn new(dir_path: &path::Path) -> Result<sync::Arc<dyn Store>, unexpected::Error> {
struct BoxedFSStore(sync::Arc<FSStore>);
pub fn new(
dir_path: &path::Path,
) -> Result<impl BoxedStore + rustls::server::ResolvesServerCert, unexpected::Error> {
vec![ vec![
dir_path, dir_path,
dir_path.join("http01_challenge_keys").as_ref(), dir_path.join("http01_challenge_keys").as_ref(),
@ -89,14 +82,14 @@ pub fn new(
}) })
.try_collect()?; .try_collect()?;
Ok(BoxedFSStore(sync::Arc::new(FSStore { Ok(sync::Arc::new(FSStore {
dir_path: dir_path.into(), dir_path: dir_path.into(),
}))) }))
} }
impl BoxedFSStore { impl FSStore {
fn account_key_path(&self) -> path::PathBuf { fn account_key_path(&self) -> path::PathBuf {
self.0.dir_path.join("account.key") self.dir_path.join("account.key")
} }
fn http01_challenge_key_path(&self, token: &str) -> path::PathBuf { fn http01_challenge_key_path(&self, token: &str) -> path::PathBuf {
@ -106,20 +99,18 @@ impl BoxedFSStore {
.expect("token successfully hashed"); .expect("token successfully hashed");
let n = h.finalize().encode_hex::<String>(); let n = h.finalize().encode_hex::<String>();
self.0.dir_path.join("http01_challenge_keys").join(n) self.dir_path.join("http01_challenge_keys").join(n)
} }
fn certificate_path(&self, domain: &str) -> path::PathBuf { fn certificate_path(&self, domain: &str) -> path::PathBuf {
let mut domain = domain.to_string(); let mut domain = domain.to_string();
domain.push_str(".json"); domain.push_str(".json");
self.0.dir_path.join("certificates").join(domain) self.dir_path.join("certificates").join(domain)
} }
} }
impl BoxedStore for BoxedFSStore {} impl Store for FSStore {
impl Store for BoxedFSStore {
fn set_account_key(&self, k: &PrivateKey) -> Result<(), unexpected::Error> { fn set_account_key(&self, k: &PrivateKey) -> Result<(), unexpected::Error> {
let path = self.account_key_path(); let path = self.account_key_path();
{ {
@ -232,38 +223,6 @@ impl Store for BoxedFSStore {
} }
} }
impl rustls::server::ResolvesServerCert for BoxedFSStore {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<sync::Arc<rustls::sign::CertifiedKey>> {
let domain = client_hello.server_name()?;
match self.get_certificate(domain) {
Err(GetCertificateError::NotFound) => {
log::warn!("No cert found for domain {domain}");
Ok(None)
}
Err(GetCertificateError::Unexpected(err)) => Err(err),
Ok((key, cert)) => {
match rustls::sign::any_supported_type(&key.into()).or_unexpected() {
Err(err) => Err(err),
Ok(key) => Ok(Some(sync::Arc::new(rustls::sign::CertifiedKey {
cert: cert.into_iter().map(|cert| cert.into()).collect(),
key,
ocsp: None,
sct_list: None,
}))),
}
}
}
.unwrap_or_else(|err| {
log::error!("Unexpected error getting cert for domain {domain}: {err}");
None
})
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -38,23 +38,18 @@ pub enum SetError {
Unexpected(#[from] unexpected::Error), Unexpected(#[from] unexpected::Error),
} }
/// Used in the return from all_domains from Store.
pub type AllDomainsResult<T> = Result<T, unexpected::Error>;
#[mockall::automock] #[mockall::automock]
pub trait Store { pub trait Store: Sync + Send {
fn get(&self, domain: &domain::Name) -> Result<Config, GetError>; fn get(&self, domain: &domain::Name) -> Result<Config, GetError>;
fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>; fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>;
fn all_domains(&self) -> AllDomainsResult<Vec<AllDomainsResult<domain::Name>>>; fn all_domains(&self) -> Result<Vec<domain::Name>, unexpected::Error>;
} }
pub trait BoxedStore: Store + Send + Sync + Clone {}
struct FSStore { struct FSStore {
dir_path: PathBuf, dir_path: PathBuf,
} }
pub fn new(dir_path: &Path) -> io::Result<impl BoxedStore> { pub fn new(dir_path: &Path) -> io::Result<sync::Arc<dyn Store>> {
fs::create_dir_all(dir_path)?; fs::create_dir_all(dir_path)?;
Ok(sync::Arc::new(FSStore { Ok(sync::Arc::new(FSStore {
dir_path: dir_path.into(), dir_path: dir_path.into(),
@ -71,9 +66,7 @@ impl FSStore {
} }
} }
impl BoxedStore for sync::Arc<FSStore> {} impl Store for FSStore {
impl Store for sync::Arc<FSStore> {
fn get(&self, domain: &domain::Name) -> Result<Config, GetError> { fn get(&self, domain: &domain::Name) -> Result<Config, GetError> {
let path = self.config_file_path(domain); let path = self.config_file_path(domain);
let config_file = fs::File::open(path.as_path()).map_err(|e| match e.kind() { let config_file = fs::File::open(path.as_path()).map_err(|e| match e.kind() {
@ -103,11 +96,11 @@ impl Store for sync::Arc<FSStore> {
Ok(()) Ok(())
} }
fn all_domains(&self) -> AllDomainsResult<Vec<AllDomainsResult<domain::Name>>> { fn all_domains(&self) -> Result<Vec<domain::Name>, unexpected::Error> {
Ok(fs::read_dir(&self.dir_path) fs::read_dir(&self.dir_path)
.or_unexpected()? .or_unexpected()?
.map( .map(
|dir_entry_res: io::Result<fs::DirEntry>| -> AllDomainsResult<domain::Name> { |dir_entry_res: io::Result<fs::DirEntry>| -> Result<domain::Name, unexpected::Error> {
let domain = dir_entry_res.or_unexpected()?.file_name(); let domain = dir_entry_res.or_unexpected()?.file_name();
let domain = domain.to_str().ok_or(unexpected::Error::from( let domain = domain.to_str().ok_or(unexpected::Error::from(
"couldn't convert os string to &str", "couldn't convert os string to &str",
@ -117,7 +110,7 @@ impl Store for sync::Arc<FSStore> {
.map_unexpected_while(|| format!("parsing {domain} as domain name")) .map_unexpected_while(|| format!("parsing {domain} as domain name"))
}, },
) )
.collect()) .try_collect()
} }
} }

View File

@ -1,8 +1,9 @@
use crate::domain::{self, acme, checker, config}; use crate::domain::{self, acme, checker, config};
use crate::error::unexpected::{self, Intoable, Mappable}; use crate::error::unexpected::{self, Mappable};
use crate::origin; use crate::origin;
use std::{future, pin, sync}; use std::{future, pin, sync};
use tokio_util::sync::CancellationToken;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum GetConfigError { pub enum GetConfigError {
@ -115,112 +116,109 @@ impl From<config::SetError> for SyncWithConfigError {
pub type GetAcmeHttp01ChallengeKeyError = acme::manager::GetHttp01ChallengeKeyError; pub type GetAcmeHttp01ChallengeKeyError = acme::manager::GetHttp01ChallengeKeyError;
pub type AllDomainsResult<T> = config::AllDomainsResult<T>; #[mockall::automock]
pub trait Manager: Sync + Send {
#[mockall::automock(
type Origin=origin::MockOrigin;
type SyncWithConfigFuture=future::Ready<Result<(), SyncWithConfigError>>;
type SyncAllOriginsErrorsIter=Vec<unexpected::Error>;
)]
pub trait Manager {
type Origin<'mgr>: origin::Origin + 'mgr
where
Self: 'mgr;
type SyncWithConfigFuture<'mgr>: future::Future<Output = Result<(), SyncWithConfigError>>
+ Send
+ Unpin
+ 'mgr
where
Self: 'mgr;
type SyncAllOriginsErrorsIter<'mgr>: IntoIterator<Item = unexpected::Error> + 'mgr
where
Self: 'mgr;
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>; fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>;
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError>; fn get_origin(
fn sync_with_config(
&self, &self,
domain: &domain::Name,
) -> Result<sync::Arc<dyn origin::Origin>, GetOriginError>;
fn sync_cert<'mgr>(
&'mgr self,
domain: domain::Name,
) -> pin::Pin<Box<dyn future::Future<Output = Result<(), unexpected::Error>> + Send + 'mgr>>;
fn sync_with_config<'mgr>(
&'mgr self,
domain: domain::Name, domain: domain::Name,
config: config::Config, config: config::Config,
) -> Self::SyncWithConfigFuture<'_>; ) -> pin::Pin<Box<dyn future::Future<Output = Result<(), SyncWithConfigError>> + Send + 'mgr>>;
fn sync_all_origins(&self) -> Result<Self::SyncAllOriginsErrorsIter<'_>, unexpected::Error>;
fn get_acme_http01_challenge_key( fn get_acme_http01_challenge_key(
&self, &self,
token: &str, token: &str,
) -> Result<String, GetAcmeHttp01ChallengeKeyError>; ) -> Result<String, GetAcmeHttp01ChallengeKeyError>;
fn all_domains(&self) -> AllDomainsResult<Vec<AllDomainsResult<domain::Name>>>; fn all_domains(&self) -> Result<Vec<domain::Name>, unexpected::Error>;
} }
pub trait BoxedManager: Manager + Send + Sync + Clone {} pub struct ManagerImpl {
origin_store: sync::Arc<dyn origin::store::Store>,
struct ManagerImpl<OriginStore, DomainConfigStore, AcmeManager> domain_config_store: sync::Arc<dyn config::Store>,
where
OriginStore: origin::store::BoxedStore,
DomainConfigStore: config::BoxedStore,
AcmeManager: acme::manager::BoxedManager,
{
origin_store: OriginStore,
domain_config_store: DomainConfigStore,
domain_checker: checker::DNSChecker, domain_checker: checker::DNSChecker,
acme_manager: Option<AcmeManager>, acme_manager: Option<sync::Arc<dyn acme::manager::Manager>>,
canceller: CancellationToken,
origin_sync_handler: tokio::task::JoinHandle<()>,
} }
pub fn new<OriginStore, DomainConfigStore, AcmeManager>( fn sync_origins(origin_store: &dyn origin::store::Store) {
origin_store: OriginStore, match origin_store.all_descrs() {
domain_config_store: DomainConfigStore, Ok(iter) => iter.into_iter(),
Err(err) => {
log::error!("Error fetching origin descriptors: {err}");
return;
}
}
.for_each(|descr| {
if let Err(err) = origin_store.sync(descr.clone(), origin::store::Limits {}) {
log::error!("Failed to sync store for {:?}: {err}", descr);
return;
}
});
}
pub fn new(
origin_store: sync::Arc<dyn origin::store::Store>,
domain_config_store: sync::Arc<dyn config::Store>,
domain_checker: checker::DNSChecker, domain_checker: checker::DNSChecker,
acme_manager: Option<AcmeManager>, acme_manager: Option<sync::Arc<dyn acme::manager::Manager>>,
) -> impl BoxedManager ) -> ManagerImpl {
where let canceller = CancellationToken::new();
OriginStore: origin::store::BoxedStore,
DomainConfigStore: config::BoxedStore, let origin_sync_handler = {
AcmeManager: acme::manager::BoxedManager, let origin_store = origin_store.clone();
{ let canceller = canceller.clone();
sync::Arc::new(ManagerImpl { tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(20 * 60));
loop {
tokio::select! {
_ = interval.tick() => sync_origins(origin_store.as_ref()),
_ = canceller.cancelled() => return,
}
}
})
};
ManagerImpl {
origin_store, origin_store,
domain_config_store, domain_config_store,
domain_checker, domain_checker,
acme_manager, acme_manager,
}) canceller,
origin_sync_handler,
}
} }
impl<OriginStore, DomainConfigStore, AcmeManager> BoxedManager impl ManagerImpl {
for sync::Arc<ManagerImpl<OriginStore, DomainConfigStore, AcmeManager>> pub fn stop(self) -> tokio::task::JoinHandle<()> {
where self.canceller.cancel();
OriginStore: origin::store::BoxedStore, self.origin_sync_handler
DomainConfigStore: config::BoxedStore, }
AcmeManager: acme::manager::BoxedManager,
{
} }
impl<OriginStore, DomainConfigStore, AcmeManager> Manager impl Manager for ManagerImpl {
for sync::Arc<ManagerImpl<OriginStore, DomainConfigStore, AcmeManager>>
where
OriginStore: origin::store::BoxedStore,
DomainConfigStore: config::BoxedStore,
AcmeManager: acme::manager::BoxedManager,
{
type Origin<'mgr> = OriginStore::Origin<'mgr>
where Self: 'mgr;
type SyncWithConfigFuture<'mgr> = pin::Pin<Box<dyn future::Future<Output = Result<(), SyncWithConfigError>> + Send + 'mgr>>
where Self: 'mgr;
type SyncAllOriginsErrorsIter<'mgr> = Box<dyn Iterator<Item = unexpected::Error> + 'mgr>
where Self: 'mgr;
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> { fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> {
Ok(self.domain_config_store.get(domain)?) Ok(self.domain_config_store.get(domain)?)
} }
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError> { fn get_origin(
&self,
domain: &domain::Name,
) -> Result<sync::Arc<dyn origin::Origin>, GetOriginError> {
let config = self.domain_config_store.get(domain)?; let config = self.domain_config_store.get(domain)?;
let origin = self let origin = self
.origin_store .origin_store
@ -230,11 +228,26 @@ where
Ok(origin) Ok(origin)
} }
fn sync_with_config( fn sync_cert<'mgr>(
&self, &'mgr self,
domain: domain::Name,
) -> pin::Pin<Box<dyn future::Future<Output = Result<(), unexpected::Error>> + Send + 'mgr>>
{
Box::pin(async move {
if let Some(ref acme_manager) = self.acme_manager {
acme_manager.sync_domain(domain.clone()).await?;
}
Ok(())
})
}
fn sync_with_config<'mgr>(
&'mgr self,
domain: domain::Name, domain: domain::Name,
config: config::Config, config: config::Config,
) -> Self::SyncWithConfigFuture<'_> { ) -> pin::Pin<Box<dyn future::Future<Output = Result<(), SyncWithConfigError>> + Send + 'mgr>>
{
Box::pin(async move { Box::pin(async move {
let config_hash = config let config_hash = config
.hash() .hash()
@ -249,39 +262,12 @@ where
self.domain_config_store.set(&domain, &config)?; self.domain_config_store.set(&domain, &config)?;
if let Some(ref acme_manager) = self.acme_manager { self.sync_cert(domain).await?;
acme_manager.sync_domain(domain.clone()).await?;
}
Ok(()) Ok(())
}) })
} }
fn sync_all_origins(&self) -> Result<Self::SyncAllOriginsErrorsIter<'_>, unexpected::Error> {
let iter = self
.origin_store
.all_descrs()
.or_unexpected_while("fetching all origin descrs")?
.into_iter();
Ok(Box::from(iter.filter_map(|descr| {
if let Err(err) = descr {
return Some(err.into_unexpected());
}
let descr = descr.unwrap();
if let Err(err) = self
.origin_store
.sync(descr.clone(), origin::store::Limits {})
{
return Some(err.into_unexpected_while(format!("syncing store {:?}", descr)));
}
None
})))
}
fn get_acme_http01_challenge_key( fn get_acme_http01_challenge_key(
&self, &self,
token: &str, token: &str,
@ -293,7 +279,7 @@ where
Err(GetAcmeHttp01ChallengeKeyError::NotFound) Err(GetAcmeHttp01ChallengeKeyError::NotFound)
} }
fn all_domains(&self) -> AllDomainsResult<Vec<AllDomainsResult<domain::Name>>> { fn all_domains(&self) -> Result<Vec<domain::Name>, unexpected::Error> {
self.domain_config_store.all_domains() self.domain_config_store.all_domains()
} }
} }

View File

@ -1,3 +1,4 @@
#![feature(result_option_inspect)]
#![feature(iterator_try_collect)] #![feature(iterator_try_collect)]
pub mod domain; pub mod domain;

View File

@ -1,19 +1,10 @@
#![feature(result_option_inspect)]
use clap::Parser; use clap::Parser;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use signal_hook_tokio::Signals; use signal_hook_tokio::Signals;
use tokio::select;
use tokio::time;
use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::str::FromStr; use std::str::FromStr;
use std::{future, path, sync}; use std::{path, sync};
use domiply::domain::acme::manager::Manager as AcmeManager;
use domiply::domain::manager::Manager;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version)] #[command(version)]
@ -68,14 +59,10 @@ struct Cli {
} }
#[derive(Clone)] #[derive(Clone)]
struct HTTPSParams<DomainAcmeStore, DomainAcmeManager> struct HTTPSParams {
where
DomainAcmeStore: domiply::domain::acme::store::BoxedStore,
DomainAcmeManager: domiply::domain::acme::manager::BoxedManager,
{
https_listen_addr: SocketAddr, https_listen_addr: SocketAddr,
domain_acme_store: DomainAcmeStore, domain_acme_store: sync::Arc<dyn domiply::domain::acme::store::Store>,
domain_acme_manager: DomainAcmeManager, domain_acme_manager: sync::Arc<dyn domiply::domain::acme::manager::Manager>,
} }
#[tokio::main] #[tokio::main]
@ -91,7 +78,6 @@ async fn main() {
) )
.init(); .init();
let mut wait_group = FuturesUnordered::new();
let canceller = tokio_util::sync::CancellationToken::new(); let canceller = tokio_util::sync::CancellationToken::new();
{ {
@ -159,201 +145,31 @@ async fn main() {
https_params.as_ref().map(|p| p.domain_acme_manager.clone()), https_params.as_ref().map(|p| p.domain_acme_manager.clone()),
); );
wait_group.push({ let domain_manager = sync::Arc::new(domain_manager);
let domain_manager = domain_manager.clone();
let canceller = canceller.clone();
tokio::spawn(async move { {
let mut interval = time::interval(time::Duration::from_secs(20 * 60)); let (http_service, http_service_task_set) = domiply::service::http::new(
loop {
select! {
_ = interval.tick() => (),
_ = canceller.cancelled() => return,
}
let errors_iter = domain_manager.sync_all_origins();
if let Err(err) = errors_iter {
log::error!("Got error calling sync_all_origins: {err}");
continue;
}
errors_iter
.unwrap()
.into_iter()
.for_each(|err| log::error!("syncing failed: {err}"));
}
})
});
let service = domiply::service::new(
domain_manager.clone(), domain_manager.clone(),
config.domain_checker_target_a, config.domain_checker_target_a,
config.passphrase, config.passphrase,
config.http_listen_addr.clone(),
config.http_domain.clone(), config.http_domain.clone(),
https_params.map(|p| domiply::service::http::HTTPSParams {
listen_addr: p.https_listen_addr,
cert_resolver: domiply::domain::acme::resolver::new(p.domain_acme_store),
}),
); );
let service = sync::Arc::new(service);
wait_group.push({
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
let service = service.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio::spawn(async move {
let addr = config.http_listen_addr;
log::info!(
"Listening on http://{}:{}",
http_domain.as_str(),
addr.port()
);
let server = hyper::Server::bind(&addr).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await; canceller.cancelled().await;
});
if let Err(e) = graceful.await { domiply::service::http::stop(http_service, http_service_task_set).await;
panic!("server error: {}", e);
};
})
});
if let Some(https_params) = https_params {
// Periodically refresh all domain certs, including the http_domain passed in the Cli opts
wait_group.push({
let https_params = https_params.clone();
let domain_manager = domain_manager.clone();
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
tokio::spawn(async move {
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
loop {
select! {
_ = interval.tick() => (),
_ = canceller.cancelled() => return,
} }
_ = https_params sync::Arc::into_inner(domain_manager)
.domain_acme_manager .unwrap()
.sync_domain(http_domain.clone()) .stop()
.await .await
.inspect_err(|err| { .expect("domain manager failed to shutdown cleanly");
log::error!(
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = domain_manager.all_domains();
if let Err(err) = domains_iter {
log::error!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
match domain {
Ok(domain) => {
let _ = https_params
.domain_acme_manager
.sync_domain(domain.clone())
.await
.inspect_err(|err| {
log::error!(
"Error while getting cert for {}: {err}",
domain.as_str(),
)
});
}
Err(err) => log::error!("Error iterating through domains: {err}"),
};
}
}
})
});
// HTTPS server
wait_group.push({
let https_params = https_params;
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
let service = service.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio::spawn(async move {
let canceller = canceller.clone();
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
rustls::server::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(sync::Arc::from(https_params.domain_acme_store)),
)
.into();
let addr = https_params.https_listen_addr;
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
.expect("https listen socket creation failed");
let incoming =
tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
if let Err(err) = conn {
log::error!("Error accepting TLS connection: {:?}", err);
future::ready(false)
} else {
future::ready(true)
}
});
let incoming = hyper::server::accept::from_stream(incoming);
log::info!(
"Listening on https://{}:{}",
http_domain.as_str(),
addr.port()
);
let server = hyper::Server::builder(incoming).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
};
})
})
}
while wait_group.next().await.is_some() {}
log::info!("Graceful shutdown complete"); log::info!("Graceful shutdown complete");
} }

View File

@ -1,5 +1,6 @@
use crate::error::unexpected; use crate::error::unexpected;
use crate::origin; use crate::origin;
use std::sync;
pub mod git; pub mod git;
@ -38,29 +39,13 @@ pub enum AllDescrsError {
Unexpected(#[from] unexpected::Error), Unexpected(#[from] unexpected::Error),
} }
/// Used in the return from all_descrs from Store. #[mockall::automock]
pub type AllDescrsResult<T> = Result<T, AllDescrsError>;
#[mockall::automock(
type Origin=origin::MockOrigin;
type AllDescrsIter=Vec<AllDescrsResult<origin::Descr>>;
)]
/// Describes a storage mechanism for Origins. Each Origin is uniquely identified by its Descr. /// Describes a storage mechanism for Origins. Each Origin is uniquely identified by its Descr.
pub trait Store { pub trait Store: Sync + Send {
type Origin<'store>: origin::Origin + 'store
where
Self: 'store;
type AllDescrsIter<'store>: IntoIterator<Item = AllDescrsResult<origin::Descr>> + 'store
where
Self: 'store;
/// If the origin is of a kind which can be updated, sync will pull down the latest version of /// If the origin is of a kind which can be updated, sync will pull down the latest version of
/// the origin into the storage. /// the origin into the storage.
fn sync(&self, descr: origin::Descr, limits: Limits) -> Result<(), SyncError>; fn sync(&self, descr: origin::Descr, limits: Limits) -> Result<(), SyncError>;
fn get(&self, descr: origin::Descr) -> Result<Self::Origin<'_>, GetError>; fn get(&self, descr: origin::Descr) -> Result<sync::Arc<dyn origin::Origin>, GetError>;
fn all_descrs(&self) -> AllDescrsResult<Self::AllDescrsIter<'_>>; fn all_descrs(&self) -> Result<Vec<origin::Descr>, AllDescrsError>;
} }
pub trait BoxedStore: Store + Send + Sync + Clone {}

View File

@ -4,13 +4,14 @@ use crate::origin::{self, store};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::{collections, fs, io, sync}; use std::{collections, fs, io, sync};
#[derive(Clone)]
struct Origin { struct Origin {
descr: origin::Descr, descr: origin::Descr,
repo: gix::ThreadSafeRepository, repo: sync::Arc<gix::ThreadSafeRepository>,
tree_object_id: gix::ObjectId, tree_object_id: gix::ObjectId,
} }
impl origin::Origin for sync::Arc<Origin> { impl origin::Origin for Origin {
fn descr(&self) -> &origin::Descr { fn descr(&self) -> &origin::Descr {
&self.descr &self.descr
} }
@ -70,7 +71,7 @@ struct Store {
origins: sync::RwLock<collections::HashMap<origin::Descr, sync::Arc<Origin>>>, origins: sync::RwLock<collections::HashMap<origin::Descr, sync::Arc<Origin>>>,
} }
pub fn new(dir_path: PathBuf) -> io::Result<impl super::BoxedStore> { pub fn new(dir_path: PathBuf) -> io::Result<sync::Arc<dyn super::Store>> {
fs::create_dir_all(&dir_path)?; fs::create_dir_all(&dir_path)?;
Ok(sync::Arc::new(Store { Ok(sync::Arc::new(Store {
dir_path, dir_path,
@ -96,7 +97,7 @@ impl Store {
&self, &self,
repo: gix::Repository, repo: gix::Repository,
descr: origin::Descr, descr: origin::Descr,
) -> Result<sync::Arc<Origin>, GetOriginError> { ) -> Result<Origin, GetOriginError> {
let origin::Descr::Git { let origin::Descr::Git {
ref branch_name, .. ref branch_name, ..
} = descr; } = descr;
@ -118,11 +119,11 @@ impl Store {
.map_unexpected_while(|| format!("parsing {commit_object_id} as commit"))? .map_unexpected_while(|| format!("parsing {commit_object_id} as commit"))?
.tree(); .tree();
Ok(sync::Arc::from(Origin { Ok(Origin {
descr, descr,
repo: repo.into(), repo: sync::Arc::new(repo.into()),
tree_object_id, tree_object_id,
})) })
} }
fn sync_inner( fn sync_inner(
@ -207,15 +208,7 @@ impl Store {
} }
} }
impl super::BoxedStore for sync::Arc<Store> {} impl super::Store for Store {
impl super::Store for sync::Arc<Store> {
type Origin<'store> = sync::Arc<Origin>
where Self: 'store;
type AllDescrsIter<'store> = Box<dyn Iterator<Item = store::AllDescrsResult<origin::Descr>> + 'store>
where Self: 'store;
fn sync(&self, descr: origin::Descr, limits: store::Limits) -> Result<(), store::SyncError> { fn sync(&self, descr: origin::Descr, limits: store::Limits) -> Result<(), store::SyncError> {
// attempt to lock this descr for syncing, doing so within a new scope so the mutex // attempt to lock this descr for syncing, doing so within a new scope so the mutex
// isn't actually being held for the whole method duration. // isn't actually being held for the whole method duration.
@ -255,12 +248,12 @@ impl super::Store for sync::Arc<Store> {
})?; })?;
let mut origins = self.origins.write().unwrap(); let mut origins = self.origins.write().unwrap();
(*origins).insert(descr, origin); (*origins).insert(descr, sync::Arc::new(origin));
Ok(()) Ok(())
} }
fn get(&self, descr: origin::Descr) -> Result<Self::Origin<'_>, store::GetError> { fn get(&self, descr: origin::Descr) -> Result<sync::Arc<dyn origin::Origin>, store::GetError> {
{ {
let origins = self.origins.read().unwrap(); let origins = self.origins.read().unwrap();
if let Some(origin) = origins.get(&descr) { if let Some(origin) = origins.get(&descr) {
@ -287,16 +280,18 @@ impl super::Store for sync::Arc<Store> {
GetOriginError::Unexpected(e) => store::GetError::Unexpected(e), GetOriginError::Unexpected(e) => store::GetError::Unexpected(e),
})?; })?;
let origin = sync::Arc::new(origin.clone());
let mut origins = self.origins.write().unwrap(); let mut origins = self.origins.write().unwrap();
(*origins).insert(descr, origin.clone()); (*origins).insert(descr, origin.clone());
Ok(origin) Ok(origin)
} }
fn all_descrs(&self) -> store::AllDescrsResult<Self::AllDescrsIter<'_>> { fn all_descrs(&self) -> Result<Vec<origin::Descr>, store::AllDescrsError> {
Ok(Box::from(
fs::read_dir(&self.dir_path).or_unexpected()?.map( fs::read_dir(&self.dir_path).or_unexpected()?.map(
|dir_entry_res: io::Result<fs::DirEntry>| -> store::AllDescrsResult<origin::Descr> { |dir_entry_res: io::Result<fs::DirEntry>| -> Result<origin::Descr, store::AllDescrsError> {
let descr_id: String = dir_entry_res let descr_id: String = dir_entry_res
.or_unexpected()? .or_unexpected()?
.file_name() .file_name()
@ -322,8 +317,7 @@ impl super::Store for sync::Arc<Store> {
Ok(descr) Ok(descr)
}, },
), ).try_collect()
))
} }
} }

View File

@ -1,421 +1,2 @@
use hyper::{Body, Method, Request, Response}; pub mod http;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::future::Future;
use std::net;
use std::str::FromStr;
use std::sync;
use crate::origin::Origin;
use crate::{domain, origin};
pub mod http_tpl;
mod util; mod util;
type SvcResponse = Result<Response<hyper::body::Body>, String>;
#[derive(Clone)]
pub struct Service<'svc, DomainManager>
where
DomainManager: domain::manager::BoxedManager,
{
domain_manager: DomainManager,
target_a: net::Ipv4Addr,
passphrase: String,
http_domain: domain::Name,
handlebars: handlebars::Handlebars<'svc>,
}
pub fn new<'svc, DomainManager>(
domain_manager: DomainManager,
target_a: net::Ipv4Addr,
passphrase: String,
http_domain: domain::Name,
) -> Service<'svc, DomainManager>
where
DomainManager: domain::manager::BoxedManager,
{
Service {
domain_manager,
target_a,
passphrase,
http_domain,
handlebars: self::http_tpl::get(),
}
}
#[derive(Serialize)]
struct BasePresenter<'a, T> {
page_name: &'a str,
data: T,
}
#[derive(Deserialize)]
struct DomainGetArgs {
domain: domain::Name,
}
#[derive(Deserialize)]
struct DomainInitArgs {
domain: domain::Name,
}
#[derive(Deserialize)]
struct DomainSyncArgs {
domain: domain::Name,
passphrase: String,
}
impl<'svc, DomainManager> Service<'svc, DomainManager>
where
DomainManager: domain::manager::BoxedManager,
{
fn serve_string(&self, status_code: u16, path: &'_ str, body: Vec<u8>) -> SvcResponse {
let content_type = mime_guess::from_path(path)
.first_or_octet_stream()
.to_string();
match Response::builder()
.status(status_code)
.header("Content-Type", content_type)
.body(body.into())
{
Ok(res) => Ok(res),
Err(err) => Err(format!("failed to build {}: {}", path, err)),
}
}
//// TODO make this use an io::Write, rather than SvcResponse
fn render<T>(&self, status_code: u16, name: &'_ str, value: T) -> SvcResponse
where
T: Serialize,
{
let rendered = match self.handlebars.render(name, &value) {
Ok(res) => res,
Err(handlebars::RenderError {
template_name: None,
..
}) => return self.render_error_page(404, "Static asset not found"),
Err(err) => {
return self.render_error_page(500, format!("template error: {err}").as_str())
}
};
self.serve_string(status_code, name, rendered.into())
}
fn render_error_page(&'svc self, status_code: u16, e: &'_ str) -> SvcResponse {
#[derive(Serialize)]
struct Response<'a> {
error_msg: &'a str,
}
self.render(
status_code,
"/base.html",
BasePresenter {
page_name: "/error.html",
data: &Response { error_msg: e },
},
)
}
fn render_page<T>(&self, name: &'_ str, data: T) -> SvcResponse
where
T: Serialize,
{
self.render(
200,
"/base.html",
BasePresenter {
page_name: name,
data,
},
)
}
fn serve_origin(&self, domain: domain::Name, path: &'_ str) -> SvcResponse {
let mut path_owned;
let path = match path.ends_with('/') {
true => {
path_owned = String::from(path);
path_owned.push_str("index.html");
path_owned.as_str()
}
false => path,
};
let origin = match self.domain_manager.get_origin(&domain) {
Ok(o) => o,
Err(domain::manager::GetOriginError::NotFound) => {
return self.render_error_page(404, "Domain not found")
}
Err(domain::manager::GetOriginError::Unexpected(e)) => {
return self.render_error_page(500, format!("failed to fetch origin: {e}").as_str())
}
};
let mut buf = Vec::<u8>::new();
match origin.read_file_into(path, &mut buf) {
Ok(_) => self.serve_string(200, path, buf),
Err(origin::ReadFileIntoError::FileNotFound) => {
self.render_error_page(404, "File not found")
}
Err(origin::ReadFileIntoError::Unexpected(e)) => {
self.render_error_page(500, format!("failed to fetch file {path}: {e}").as_str())
}
}
}
async fn with_query_req<'a, F, In, Out>(&self, req: &'a Request<Body>, f: F) -> SvcResponse
where
In: Deserialize<'a>,
F: FnOnce(In) -> Out,
Out: Future<Output = SvcResponse>,
{
let query = req.uri().query().unwrap_or("");
match serde_urlencoded::from_str::<In>(query) {
Ok(args) => f(args).await,
Err(err) => Err(format!("failed to parse query args: {}", err)),
}
}
fn domain_get(&self, args: DomainGetArgs) -> SvcResponse {
#[derive(Serialize)]
struct Response {
domain: domain::Name,
config: Option<domain::config::Config>,
}
let config = match self.domain_manager.get_config(&args.domain) {
Ok(config) => Some(config),
Err(domain::manager::GetConfigError::NotFound) => None,
Err(domain::manager::GetConfigError::Unexpected(e)) => {
return self
.render_error_page(500, format!("retrieving configuration: {}", e).as_str());
}
};
self.render_page(
"/domain.html",
Response {
domain: args.domain,
config,
},
)
}
fn domain_init(&self, args: DomainInitArgs, domain_config: util::FlatConfig) -> SvcResponse {
#[derive(Serialize)]
struct Response {
domain: domain::Name,
flat_config: util::FlatConfig,
target_a: net::Ipv4Addr,
challenge_token: String,
}
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => return self.render_error_page(400, "domain config is required"),
Err(e) => {
return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let config_hash = match config.hash() {
Ok(hash) => hash,
Err(e) => {
return self
.render_error_page(500, format!("failed to hash domain config: {e}").as_str())
}
};
self.render_page(
"/domain_init.html",
Response {
domain: args.domain,
flat_config: config.into(),
target_a: self.target_a,
challenge_token: config_hash,
},
)
}
async fn domain_sync(
&self,
args: DomainSyncArgs,
domain_config: util::FlatConfig,
) -> SvcResponse {
if args.passphrase != self.passphrase.as_str() {
return self.render_error_page(401, "Incorrect passphrase");
}
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => return self.render_error_page(400, "domain config is required"),
Err(e) => {
return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let sync_result = self
.domain_manager
.sync_with_config(args.domain.clone(), config)
.await;
#[derive(Serialize)]
struct Response {
domain: domain::Name,
error_msg: Option<String>,
}
let error_msg = match sync_result {
Ok(_) => None,
Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()),
Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()),
Err(domain::manager::SyncWithConfigError::AlreadyInProgress) => Some("The configuration of your domain is still in progress, please refresh in a few minutes.".to_string()),
Err(domain::manager::SyncWithConfigError::TargetANotSet) => Some("The A record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::ChallengeTokenNotSet) => Some("The TXT record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")),
};
let response = Response {
domain: args.domain,
error_msg,
};
self.render_page("/domain_sync.html", response)
}
pub fn domains(&self) -> SvcResponse {
#[derive(Serialize)]
struct Response {
domains: Vec<String>,
}
let domains = match self.domain_manager.all_domains() {
Ok(domains) => domains,
Err(e) => {
return self.render_error_page(500, format!("failed get all domains: {e}").as_str())
}
};
let domains: Vec<domain::Name> = match domains.into_iter().try_collect() {
Ok(domains) => domains,
Err(e) => {
return self.render_error_page(500, format!("failed get all domains: {e}").as_str())
}
};
let mut domains: Vec<String> = domains
.into_iter()
.map(|domain| domain.as_str().to_string())
.collect();
domains.sort();
self.render_page("/domains.html", Response { domains })
}
}
pub async fn handle_request<DomainManager>(
svc: sync::Arc<Service<'_, DomainManager>>,
req: Request<Body>,
) -> Result<Response<Body>, Infallible>
where
DomainManager: domain::manager::BoxedManager,
{
match handle_request_inner(svc, req).await {
Ok(res) => Ok(res),
Err(err) => panic!("unexpected error {err}"),
}
}
fn strip_port(host: &str) -> &str {
match host.rfind(':') {
None => host,
Some(i) => &host[..i],
}
}
pub async fn handle_request_inner<DomainManager>(
svc: sync::Arc<Service<'_, DomainManager>>,
req: Request<Body>,
) -> SvcResponse
where
DomainManager: domain::manager::BoxedManager,
{
let maybe_host = match (
req.headers()
.get("Host")
.and_then(|v| v.to_str().ok())
.map(strip_port),
req.uri().host().map(strip_port),
) {
(Some(h), _) if h != svc.http_domain.as_str() => Some(h),
(_, Some(h)) if h != svc.http_domain.as_str() => Some(h),
_ => None,
}
.and_then(|h| domain::Name::from_str(h).ok());
let method = req.method();
let path = req.uri().path();
// Serving acme challenges always takes priority. We serve them from the same store no matter
// the domain, presumably they are cryptographically random enough that it doesn't matter.
if method == Method::GET && path.starts_with("/.well-known/acme-challenge/") {
let token = path.trim_start_matches("/.well-known/acme-challenge/");
if let Ok(key) = svc.domain_manager.get_acme_http01_challenge_key(token) {
let body: hyper::Body = key.into();
return match Response::builder().status(200).body(body) {
Ok(res) => Ok(res),
Err(err) => Err(format!(
"failed to write acme http-01 challenge key: {}",
err
)),
};
}
}
// If a managed domain was given then serve that from its origin
if let Some(domain) = maybe_host {
return svc.serve_origin(domain, req.uri().path());
}
// Serve main domiply site
if method == Method::GET && path.starts_with("/static/") {
return svc.render(200, path, ());
}
match (method, path) {
(&Method::GET, "/") | (&Method::GET, "/index.html") => svc.render_page("/index.html", ()),
(&Method::GET, "/domain.html") => {
svc.with_query_req(&req, |args: DomainGetArgs| async { svc.domain_get(args) })
.await
}
(&Method::GET, "/domain_init.html") => {
svc.with_query_req(&req, |args: DomainInitArgs| async {
svc.with_query_req(&req, |config: util::FlatConfig| async {
svc.domain_init(args, config)
})
.await
})
.await
}
(&Method::GET, "/domain_sync.html") => {
svc.with_query_req(&req, |args: DomainSyncArgs| async {
svc.with_query_req(&req, |config: util::FlatConfig| async {
svc.domain_sync(args, config).await
})
.await
})
.await
}
(&Method::GET, "/domains.html") => svc.domains(),
_ => svc.render_error_page(404, "Page not found!"),
}
}

441
src/service/http.rs Normal file
View File

@ -0,0 +1,441 @@
mod tasks;
mod tpl;
use hyper::{Body, Method, Request, Response};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::str::FromStr;
use std::{future, net, sync};
use crate::error::unexpected;
use crate::{domain, origin, service, util};
type SvcResponse = Result<Response<hyper::body::Body>, String>;
pub struct Service {
domain_manager: sync::Arc<dyn domain::manager::Manager>,
target_a: net::Ipv4Addr,
passphrase: String,
http_domain: domain::Name,
handlebars: handlebars::Handlebars<'static>,
}
pub struct HTTPSParams {
pub listen_addr: net::SocketAddr,
pub cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
}
pub fn new(
domain_manager: sync::Arc<dyn domain::manager::Manager>,
target_a: net::Ipv4Addr,
passphrase: String,
http_listen_addr: net::SocketAddr,
http_domain: domain::Name,
https_params: Option<HTTPSParams>,
) -> (sync::Arc<Service>, util::TaskSet<unexpected::Error>) {
let service = sync::Arc::new(Service {
domain_manager: domain_manager.clone(),
target_a,
passphrase,
http_domain: http_domain.clone(),
handlebars: tpl::get(),
});
let task_set = util::TaskSet::new();
task_set.spawn(|canceller| {
tasks::listen_http(
service.clone(),
canceller,
http_listen_addr,
http_domain.clone(),
)
});
if let Some(https_params) = https_params {
task_set.spawn(|canceller| {
tasks::listen_https(
service.clone(),
canceller,
https_params.cert_resolver.clone(),
https_params.listen_addr,
http_domain.clone(),
)
});
task_set.spawn(|canceller| {
tasks::cert_refresher(domain_manager.clone(), canceller, http_domain.clone())
});
}
return (service, task_set);
}
pub async fn stop(service: sync::Arc<Service>, task_set: util::TaskSet<unexpected::Error>) {
task_set
.stop()
.await
.iter()
.for_each(|e| log::error!("error while shutting down http service: {e}"));
sync::Arc::into_inner(service).expect("service didn't get cleaned up");
}
#[derive(Serialize)]
struct BasePresenter<'a, T> {
page_name: &'a str,
data: T,
}
#[derive(Deserialize)]
struct DomainGetArgs {
domain: domain::Name,
}
#[derive(Deserialize)]
struct DomainInitArgs {
domain: domain::Name,
}
#[derive(Deserialize)]
struct DomainSyncArgs {
domain: domain::Name,
passphrase: String,
}
impl<'svc> Service {
fn serve_string(&self, status_code: u16, path: &'_ str, body: Vec<u8>) -> SvcResponse {
let content_type = mime_guess::from_path(path)
.first_or_octet_stream()
.to_string();
match Response::builder()
.status(status_code)
.header("Content-Type", content_type)
.body(body.into())
{
Ok(res) => Ok(res),
Err(err) => Err(format!("failed to build {}: {}", path, err)),
}
}
//// TODO make this use an io::Write, rather than SvcResponse
fn render<T>(&self, status_code: u16, name: &'_ str, value: T) -> SvcResponse
where
T: Serialize,
{
let rendered = match self.handlebars.render(name, &value) {
Ok(res) => res,
Err(handlebars::RenderError {
template_name: None,
..
}) => return self.render_error_page(404, "Static asset not found"),
Err(err) => {
return self.render_error_page(500, format!("template error: {err}").as_str())
}
};
self.serve_string(status_code, name, rendered.into())
}
fn render_error_page(&'svc self, status_code: u16, e: &'_ str) -> SvcResponse {
#[derive(Serialize)]
struct Response<'a> {
error_msg: &'a str,
}
self.render(
status_code,
"/base.html",
BasePresenter {
page_name: "/error.html",
data: &Response { error_msg: e },
},
)
}
fn render_page<T>(&self, name: &'_ str, data: T) -> SvcResponse
where
T: Serialize,
{
self.render(
200,
"/base.html",
BasePresenter {
page_name: name,
data,
},
)
}
fn serve_origin(&self, domain: domain::Name, path: &'_ str) -> SvcResponse {
let mut path_owned;
let path = match path.ends_with('/') {
true => {
path_owned = String::from(path);
path_owned.push_str("index.html");
path_owned.as_str()
}
false => path,
};
let origin = match self.domain_manager.get_origin(&domain) {
Ok(o) => o,
Err(domain::manager::GetOriginError::NotFound) => {
return self.render_error_page(404, "Domain not found")
}
Err(domain::manager::GetOriginError::Unexpected(e)) => {
return self.render_error_page(500, format!("failed to fetch origin: {e}").as_str())
}
};
let mut buf = Vec::<u8>::new();
match origin.read_file_into(path, &mut buf) {
Ok(_) => self.serve_string(200, path, buf),
Err(origin::ReadFileIntoError::FileNotFound) => {
self.render_error_page(404, "File not found")
}
Err(origin::ReadFileIntoError::Unexpected(e)) => {
self.render_error_page(500, format!("failed to fetch file {path}: {e}").as_str())
}
}
}
async fn with_query_req<'a, F, In, Out>(&self, req: &'a Request<Body>, f: F) -> SvcResponse
where
In: Deserialize<'a>,
F: FnOnce(In) -> Out,
Out: future::Future<Output = SvcResponse>,
{
let query = req.uri().query().unwrap_or("");
match serde_urlencoded::from_str::<In>(query) {
Ok(args) => f(args).await,
Err(err) => Err(format!("failed to parse query args: {}", err)),
}
}
fn domain_get(&self, args: DomainGetArgs) -> SvcResponse {
#[derive(Serialize)]
struct Response {
domain: domain::Name,
config: Option<domain::config::Config>,
}
let config = match self.domain_manager.get_config(&args.domain) {
Ok(config) => Some(config),
Err(domain::manager::GetConfigError::NotFound) => None,
Err(domain::manager::GetConfigError::Unexpected(e)) => {
return self
.render_error_page(500, format!("retrieving configuration: {}", e).as_str());
}
};
self.render_page(
"/domain.html",
Response {
domain: args.domain,
config,
},
)
}
fn domain_init(
&self,
args: DomainInitArgs,
domain_config: service::util::FlatConfig,
) -> SvcResponse {
#[derive(Serialize)]
struct Response {
domain: domain::Name,
flat_config: service::util::FlatConfig,
target_a: net::Ipv4Addr,
challenge_token: String,
}
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => return self.render_error_page(400, "domain config is required"),
Err(e) => {
return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let config_hash = match config.hash() {
Ok(hash) => hash,
Err(e) => {
return self
.render_error_page(500, format!("failed to hash domain config: {e}").as_str())
}
};
self.render_page(
"/domain_init.html",
Response {
domain: args.domain,
flat_config: config.into(),
target_a: self.target_a,
challenge_token: config_hash,
},
)
}
async fn domain_sync(
&self,
args: DomainSyncArgs,
domain_config: service::util::FlatConfig,
) -> SvcResponse {
if args.passphrase != self.passphrase.as_str() {
return self.render_error_page(401, "Incorrect passphrase");
}
let config: domain::config::Config = match domain_config.try_into() {
Ok(Some(config)) => config,
Ok(None) => return self.render_error_page(400, "domain config is required"),
Err(e) => {
return self.render_error_page(400, format!("invalid domain config: {e}").as_str())
}
};
let sync_result = self
.domain_manager
.sync_with_config(args.domain.clone(), config)
.await;
#[derive(Serialize)]
struct Response {
domain: domain::Name,
error_msg: Option<String>,
}
let error_msg = match sync_result {
Ok(_) => None,
Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()),
Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()),
Err(domain::manager::SyncWithConfigError::AlreadyInProgress) => Some("The configuration of your domain is still in progress, please refresh in a few minutes.".to_string()),
Err(domain::manager::SyncWithConfigError::TargetANotSet) => Some("The A record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::ChallengeTokenNotSet) => Some("The TXT record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()),
Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")),
};
let response = Response {
domain: args.domain,
error_msg,
};
self.render_page("/domain_sync.html", response)
}
fn domains(&self) -> SvcResponse {
#[derive(Serialize)]
struct Response {
domains: Vec<String>,
}
let domains = match self.domain_manager.all_domains() {
Ok(domains) => domains,
Err(e) => {
return self.render_error_page(500, format!("failed get all domains: {e}").as_str())
}
};
let mut domains: Vec<String> = domains
.into_iter()
.map(|domain| domain.as_str().to_string())
.collect();
domains.sort();
self.render_page("/domains.html", Response { domains })
}
async fn handle_request_inner(&self, req: Request<Body>) -> SvcResponse {
let maybe_host = match (
req.headers()
.get("Host")
.and_then(|v| v.to_str().ok())
.map(strip_port),
req.uri().host().map(strip_port),
) {
(Some(h), _) if h != self.http_domain.as_str() => Some(h),
(_, Some(h)) if h != self.http_domain.as_str() => Some(h),
_ => None,
}
.and_then(|h| domain::Name::from_str(h).ok());
let method = req.method();
let path = req.uri().path();
// Serving acme challenges always takes priority. We serve them from the same store no matter
// the domain, presumably they are cryptographically random enough that it doesn't matter.
if method == Method::GET && path.starts_with("/.well-known/acme-challenge/") {
let token = path.trim_start_matches("/.well-known/acme-challenge/");
if let Ok(key) = self.domain_manager.get_acme_http01_challenge_key(token) {
let body: hyper::Body = key.into();
return match Response::builder().status(200).body(body) {
Ok(res) => Ok(res),
Err(err) => Err(format!(
"failed to write acme http-01 challenge key: {}",
err
)),
};
}
}
// If a managed domain was given then serve that from its origin
if let Some(domain) = maybe_host {
return self.serve_origin(domain, req.uri().path());
}
// Serve main domiply site
if method == Method::GET && path.starts_with("/static/") {
return self.render(200, path, ());
}
match (method, path) {
(&Method::GET, "/") | (&Method::GET, "/index.html") => {
self.render_page("/index.html", ())
}
(&Method::GET, "/domain.html") => {
self.with_query_req(&req, |args: DomainGetArgs| async { self.domain_get(args) })
.await
}
(&Method::GET, "/domain_init.html") => {
self.with_query_req(&req, |args: DomainInitArgs| async {
self.with_query_req(&req, |config: service::util::FlatConfig| async {
self.domain_init(args, config)
})
.await
})
.await
}
(&Method::GET, "/domain_sync.html") => {
self.with_query_req(&req, |args: DomainSyncArgs| async {
self.with_query_req(&req, |config: service::util::FlatConfig| async {
self.domain_sync(args, config).await
})
.await
})
.await
}
(&Method::GET, "/domains.html") => self.domains(),
_ => self.render_error_page(404, "Page not found!"),
}
}
pub async fn handle_request(&self, req: Request<Body>) -> Result<Response<Body>, Infallible> {
match self.handle_request_inner(req).await {
Ok(res) => Ok(res),
Err(err) => panic!("unexpected error {err}"),
}
}
}
fn strip_port(host: &str) -> &str {
match host.rfind(':') {
None => host,
Some(i) => &host[..i],
}
}

132
src/service/http/tasks.rs Normal file
View File

@ -0,0 +1,132 @@
use crate::error::unexpected::{self, Mappable};
use crate::{domain, service};
use std::{convert, future, net, sync};
use futures::StreamExt;
use tokio_util::sync::CancellationToken;
pub async fn listen_http(
service: sync::Arc<service::http::Service>,
canceller: CancellationToken,
addr: net::SocketAddr,
domain: domain::Name,
) -> Result<(), unexpected::Error> {
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone();
async move { service.handle_request(req).await }
});
// Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) }
});
log::info!("Listening on http://{}:{}", domain.as_str(), addr.port());
let server = hyper::Server::bind(&addr).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
graceful.await.or_unexpected()
}
pub async fn listen_https(
service: sync::Arc<service::http::Service>,
canceller: CancellationToken,
cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
addr: net::SocketAddr,
domain: domain::Name,
) -> Result<(), unexpected::Error> {
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone();
async move { service.handle_request(req).await }
});
// Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) }
});
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
rustls::server::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(cert_resolver),
)
.into();
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
.expect("https listen socket creation failed");
let incoming = tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
if let Err(err) = conn {
log::error!("Error accepting TLS connection: {:?}", err);
future::ready(false)
} else {
future::ready(true)
}
});
let incoming = hyper::server::accept::from_stream(incoming);
log::info!("Listening on https://{}:{}", domain.as_str(), addr.port());
let server = hyper::Server::builder(incoming).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
graceful.await.or_unexpected()
}
pub async fn cert_refresher(
domain_manager: sync::Arc<dyn domain::manager::Manager>,
canceller: CancellationToken,
http_domain: domain::Name,
) -> Result<(), unexpected::Error> {
use tokio::time;
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
loop {
tokio::select! {
_ = interval.tick() => (),
_ = canceller.cancelled() => return Ok(()),
}
_ = domain_manager
.sync_cert(http_domain.clone())
.await
.inspect_err(|err| {
log::error!(
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = domain_manager.all_domains();
if let Err(err) = domains_iter {
log::error!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
let _ = domain_manager
.sync_cert(domain.clone())
.await
.inspect_err(|err| {
log::error!("Error while getting cert for {}: {err}", domain.as_str(),)
});
}
}
}

View File

@ -1,11 +1,11 @@
use handlebars::Handlebars; use handlebars::Handlebars;
#[derive(rust_embed::RustEmbed)] #[derive(rust_embed::RustEmbed)]
#[folder = "src/service/http_tpl"] #[folder = "src/service/http/tpl"]
#[prefix = "/"] #[prefix = "/"]
struct Dir; struct Dir;
pub fn get<'hbs>() -> Handlebars<'hbs> { pub fn get() -> Handlebars<'static> {
let mut reg = Handlebars::new(); let mut reg = Handlebars::new();
reg.register_embed_templates::<Dir>() reg.register_embed_templates::<Dir>()
.expect("registered embedded templates"); .expect("registered embedded templates");

View File

@ -1,4 +1,7 @@
use std::{fs, io, path}; use std::{error, fs, io, path};
use futures::stream::futures_unordered::FuturesUnordered;
use tokio_util::sync::CancellationToken;
pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> { pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> {
match fs::File::open(path) { match fs::File::open(path) {
@ -9,3 +12,45 @@ pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> {
}, },
} }
} }
pub struct TaskSet<E>
where
E: error::Error + Send + 'static,
{
canceller: CancellationToken,
wait_group: FuturesUnordered<tokio::task::JoinHandle<Result<(), E>>>,
}
impl<E> TaskSet<E>
where
E: error::Error + Send + 'static,
{
pub fn new() -> TaskSet<E> {
TaskSet {
canceller: CancellationToken::new(),
wait_group: FuturesUnordered::new(),
}
}
pub fn spawn<F, Fut>(&self, mut f: F)
where
Fut: futures::Future<Output = Result<(), E>> + Send + 'static,
F: FnMut(CancellationToken) -> Fut,
{
let canceller = self.canceller.clone();
let handle = tokio::spawn(f(canceller));
self.wait_group.push(handle);
}
pub async fn stop(self) -> Vec<E> {
self.canceller.cancel();
let mut res = Vec::new();
for f in self.wait_group {
if let Err(err) = f.await.expect("task failed") {
res.push(err);
}
}
res
}
}