diff --git a/src/domain/config.rs b/src/domain/config.rs index ceaa33a..2c1b45b 100644 --- a/src/domain/config.rs +++ b/src/domain/config.rs @@ -48,7 +48,7 @@ pub trait Store { fn all_domains(&self) -> AllDomainsResult>>; } -pub trait BoxedStore: Store + Send + Sync + Clone {} +pub trait BoxedStore: Store + Send + Sync + Clone + 'static {} struct FSStore { dir_path: PathBuf, diff --git a/src/domain/manager.rs b/src/domain/manager.rs index 37b37df..ebb66e9 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -1,8 +1,9 @@ use crate::domain::{self, acme, checker, config}; -use crate::error::unexpected::{self, Intoable, Mappable}; +use crate::error::unexpected::{self, Mappable}; use crate::origin; use std::{future, pin, sync}; +use tokio_util::sync::CancellationToken; #[derive(thiserror::Error, Debug)] pub enum GetConfigError { @@ -117,38 +118,20 @@ pub type GetAcmeHttp01ChallengeKeyError = acme::manager::GetHttp01ChallengeKeyEr pub type AllDomainsResult = config::AllDomainsResult; -#[mockall::automock( - type Origin=origin::MockOrigin; - type SyncWithConfigFuture=future::Ready>; - type SyncAllOriginsErrorsIter=Vec; -)] -pub trait Manager { - type Origin<'mgr>: origin::Origin + 'mgr - where - Self: 'mgr; - - type SyncWithConfigFuture<'mgr>: future::Future> - + Send - + Unpin - + 'mgr - where - Self: 'mgr; - - type SyncAllOriginsErrorsIter<'mgr>: IntoIterator + 'mgr - where - Self: 'mgr; - +#[mockall::automock] +pub trait Manager: Sync + Send { fn get_config(&self, domain: &domain::Name) -> Result; - fn get_origin(&self, domain: &domain::Name) -> Result, GetOriginError>; - - fn sync_with_config( + fn get_origin( &self, + domain: &domain::Name, + ) -> Result, GetOriginError>; + + fn sync_with_config<'mgr>( + &'mgr self, domain: domain::Name, config: config::Config, - ) -> Self::SyncWithConfigFuture<'_>; - - fn sync_all_origins(&self) -> Result, unexpected::Error>; + ) -> pin::Pin> + Send + 'mgr>>; fn get_acme_http01_challenge_key( &self, @@ -158,69 +141,99 @@ pub trait Manager { fn all_domains(&self) -> AllDomainsResult>>; } -pub trait BoxedManager: Manager + Send + Sync + Clone {} - -struct ManagerImpl +struct ManagerImpl where - OriginStore: origin::store::BoxedStore, DomainConfigStore: config::BoxedStore, AcmeManager: acme::manager::BoxedManager, { - origin_store: OriginStore, + origin_store: sync::Arc, domain_config_store: DomainConfigStore, domain_checker: checker::DNSChecker, acme_manager: Option, + + canceller: CancellationToken, + origin_sync_handler: tokio::task::JoinHandle<()>, } -pub fn new( - origin_store: OriginStore, +fn sync_origins(origin_store: &dyn origin::store::Store) { + match origin_store.all_descrs() { + Ok(iter) => iter.into_iter(), + Err(err) => { + log::error!("Error fetching origin descriptors: {err}"); + return; + } + } + .for_each(|descr| { + if let Err(err) = origin_store.sync(descr.clone(), origin::store::Limits {}) { + log::error!("Failed to sync store for {:?}: {err}", descr); + return; + } + }); +} + +pub fn new<'mgr, DomainConfigStore, AcmeManager>( + origin_store: sync::Arc, domain_config_store: DomainConfigStore, domain_checker: checker::DNSChecker, acme_manager: Option, -) -> impl BoxedManager +) -> sync::Arc where - OriginStore: origin::store::BoxedStore, DomainConfigStore: config::BoxedStore, AcmeManager: acme::manager::BoxedManager, { + let canceller = CancellationToken::new(); + + let origin_sync_handler = { + let origin_store = origin_store.clone(); + let canceller = canceller.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(20 * 60)); + + loop { + tokio::select! { + _ = interval.tick() => sync_origins(origin_store.as_ref()), + _ = canceller.cancelled() => return, + } + } + }) + }; + sync::Arc::new(ManagerImpl { origin_store, domain_config_store, domain_checker, acme_manager, + canceller, + origin_sync_handler, }) } -impl BoxedManager - for sync::Arc> +impl ManagerImpl where - OriginStore: origin::store::BoxedStore, DomainConfigStore: config::BoxedStore, AcmeManager: acme::manager::BoxedManager, { + pub async fn stop(self) { + self.canceller.cancel(); + self.origin_sync_handler + .await + .expect("origin_sync_handler errored"); + } } -impl Manager - for sync::Arc> +impl Manager for ManagerImpl where - OriginStore: origin::store::BoxedStore, DomainConfigStore: config::BoxedStore, AcmeManager: acme::manager::BoxedManager, { - type Origin<'mgr> = OriginStore::Origin<'mgr> - where Self: 'mgr; - - type SyncWithConfigFuture<'mgr> = pin::Pin> + Send + 'mgr>> - where Self: 'mgr; - - type SyncAllOriginsErrorsIter<'mgr> = Box + 'mgr> - where Self: 'mgr; - fn get_config(&self, domain: &domain::Name) -> Result { Ok(self.domain_config_store.get(domain)?) } - fn get_origin(&self, domain: &domain::Name) -> Result, GetOriginError> { + fn get_origin( + &self, + domain: &domain::Name, + ) -> Result, GetOriginError> { let config = self.domain_config_store.get(domain)?; let origin = self .origin_store @@ -230,11 +243,12 @@ where Ok(origin) } - fn sync_with_config( - &self, + fn sync_with_config<'mgr>( + &'mgr self, domain: domain::Name, config: config::Config, - ) -> Self::SyncWithConfigFuture<'_> { + ) -> pin::Pin> + Send + 'mgr>> + { Box::pin(async move { let config_hash = config .hash() @@ -257,31 +271,6 @@ where }) } - fn sync_all_origins(&self) -> Result, unexpected::Error> { - let iter = self - .origin_store - .all_descrs() - .or_unexpected_while("fetching all origin descrs")? - .into_iter(); - - Ok(Box::from(iter.filter_map(|descr| { - if let Err(err) = descr { - return Some(err.into_unexpected()); - } - - let descr = descr.unwrap(); - - if let Err(err) = self - .origin_store - .sync(descr.clone(), origin::store::Limits {}) - { - return Some(err.into_unexpected_while(format!("syncing store {:?}", descr))); - } - - None - }))) - } - fn get_acme_http01_challenge_key( &self, token: &str, diff --git a/src/main.rs b/src/main.rs index a216ef4..b745941 100644 --- a/src/main.rs +++ b/src/main.rs @@ -159,34 +159,6 @@ async fn main() { https_params.as_ref().map(|p| p.domain_acme_manager.clone()), ); - wait_group.push({ - let domain_manager = domain_manager.clone(); - let canceller = canceller.clone(); - - tokio::spawn(async move { - let mut interval = time::interval(time::Duration::from_secs(20 * 60)); - - loop { - select! { - _ = interval.tick() => (), - _ = canceller.cancelled() => return, - } - - let errors_iter = domain_manager.sync_all_origins(); - - if let Err(err) = errors_iter { - log::error!("Got error calling sync_all_origins: {err}"); - continue; - } - - errors_iter - .unwrap() - .into_iter() - .for_each(|err| log::error!("syncing failed: {err}")); - } - }) - }); - let service = domiply::service::new( domain_manager.clone(), config.domain_checker_target_a, diff --git a/src/origin/store.rs b/src/origin/store.rs index 66e8f7e..f376656 100644 --- a/src/origin/store.rs +++ b/src/origin/store.rs @@ -1,5 +1,6 @@ use crate::error::unexpected; use crate::origin; +use std::sync; pub mod git; @@ -38,29 +39,13 @@ pub enum AllDescrsError { Unexpected(#[from] unexpected::Error), } -/// Used in the return from all_descrs from Store. -pub type AllDescrsResult = Result; - -#[mockall::automock( - type Origin=origin::MockOrigin; - type AllDescrsIter=Vec>; - )] +#[mockall::automock] /// Describes a storage mechanism for Origins. Each Origin is uniquely identified by its Descr. -pub trait Store { - type Origin<'store>: origin::Origin + 'store - where - Self: 'store; - - type AllDescrsIter<'store>: IntoIterator> + 'store - where - Self: 'store; - +pub trait Store: Sync + Send { /// If the origin is of a kind which can be updated, sync will pull down the latest version of /// the origin into the storage. fn sync(&self, descr: origin::Descr, limits: Limits) -> Result<(), SyncError>; - fn get(&self, descr: origin::Descr) -> Result, GetError>; - fn all_descrs(&self) -> AllDescrsResult>; + fn get(&self, descr: origin::Descr) -> Result, GetError>; + fn all_descrs(&self) -> Result, AllDescrsError>; } - -pub trait BoxedStore: Store + Send + Sync + Clone {} diff --git a/src/origin/store/git.rs b/src/origin/store/git.rs index ff12273..1f1768a 100644 --- a/src/origin/store/git.rs +++ b/src/origin/store/git.rs @@ -68,10 +68,10 @@ struct Store { // more than one origin to be syncing at a time sync_guard: sync::Mutex>, - origins: sync::RwLock>, + origins: sync::RwLock>>, } -pub fn new(dir_path: PathBuf) -> io::Result { +pub fn new(dir_path: PathBuf) -> io::Result> { fs::create_dir_all(&dir_path)?; Ok(sync::Arc::new(Store { dir_path, @@ -208,15 +208,7 @@ impl Store { } } -impl super::BoxedStore for sync::Arc {} - -impl super::Store for sync::Arc { - type Origin<'store> = Origin - where Self: 'store; - - type AllDescrsIter<'store> = Box> + 'store> - where Self: 'store; - +impl super::Store for Store { fn sync(&self, descr: origin::Descr, limits: store::Limits) -> Result<(), store::SyncError> { // attempt to lock this descr for syncing, doing so within a new scope so the mutex // isn't actually being held for the whole method duration. @@ -256,12 +248,12 @@ impl super::Store for sync::Arc { })?; let mut origins = self.origins.write().unwrap(); - (*origins).insert(descr, origin); + (*origins).insert(descr, sync::Arc::new(origin)); Ok(()) } - fn get(&self, descr: origin::Descr) -> Result, store::GetError> { + fn get(&self, descr: origin::Descr) -> Result, store::GetError> { { let origins = self.origins.read().unwrap(); if let Some(origin) = origins.get(&descr) { @@ -288,16 +280,18 @@ impl super::Store for sync::Arc { GetOriginError::Unexpected(e) => store::GetError::Unexpected(e), })?; + let origin = sync::Arc::new(origin.clone()); + let mut origins = self.origins.write().unwrap(); + (*origins).insert(descr, origin.clone()); Ok(origin) } - fn all_descrs(&self) -> store::AllDescrsResult> { - Ok(Box::from( - fs::read_dir(&self.dir_path).or_unexpected()?.map( - |dir_entry_res: io::Result| -> store::AllDescrsResult { + fn all_descrs(&self) -> Result, store::AllDescrsError> { + fs::read_dir(&self.dir_path).or_unexpected()?.map( + |dir_entry_res: io::Result| -> Result { let descr_id: String = dir_entry_res .or_unexpected()? .file_name() @@ -323,8 +317,7 @@ impl super::Store for sync::Arc { Ok(descr) }, - ), - )) + ).try_collect() } } diff --git a/src/service.rs b/src/service.rs index ab5c9ba..a821132 100644 --- a/src/service.rs +++ b/src/service.rs @@ -7,7 +7,6 @@ use std::net; use std::str::FromStr; use std::sync; -use crate::origin::Origin; use crate::{domain, origin}; pub mod http_tpl; @@ -16,26 +15,20 @@ mod util; type SvcResponse = Result, String>; #[derive(Clone)] -pub struct Service<'svc, DomainManager> -where - DomainManager: domain::manager::BoxedManager, -{ - domain_manager: DomainManager, +pub struct Service<'svc> { + domain_manager: sync::Arc, target_a: net::Ipv4Addr, passphrase: String, http_domain: domain::Name, handlebars: handlebars::Handlebars<'svc>, } -pub fn new<'svc, DomainManager>( - domain_manager: DomainManager, +pub fn new<'svc>( + domain_manager: sync::Arc, target_a: net::Ipv4Addr, passphrase: String, http_domain: domain::Name, -) -> Service<'svc, DomainManager> -where - DomainManager: domain::manager::BoxedManager, -{ +) -> Service<'svc> { Service { domain_manager, target_a, @@ -67,10 +60,7 @@ struct DomainSyncArgs { passphrase: String, } -impl<'svc, DomainManager> Service<'svc, DomainManager> -where - DomainManager: domain::manager::BoxedManager, -{ +impl<'svc> Service<'svc> { fn serve_string(&self, status_code: u16, path: &'_ str, body: Vec) -> SvcResponse { let content_type = mime_guess::from_path(path) .first_or_octet_stream() @@ -320,13 +310,10 @@ where } } -pub async fn handle_request( - svc: sync::Arc>, +pub async fn handle_request( + svc: sync::Arc>, req: Request, -) -> Result, Infallible> -where - DomainManager: domain::manager::BoxedManager, -{ +) -> Result, Infallible> { match handle_request_inner(svc, req).await { Ok(res) => Ok(res), Err(err) => panic!("unexpected error {err}"), @@ -340,13 +327,7 @@ fn strip_port(host: &str) -> &str { } } -pub async fn handle_request_inner( - svc: sync::Arc>, - req: Request, -) -> SvcResponse -where - DomainManager: domain::manager::BoxedManager, -{ +pub async fn handle_request_inner(svc: sync::Arc>, req: Request) -> SvcResponse { let maybe_host = match ( req.headers() .get("Host")