From 06cda77772e29d4f19a44a41a91b45ca869496a6 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Fri, 19 May 2023 13:26:27 +0200 Subject: [PATCH] Periodically refresh certs for all domains --- src/domain/config.rs | 21 ++++++++++++++++ src/domain/manager.rs | 8 ++++++ src/main.rs | 57 +++++++++++++++++++++++++++++++++++-------- 3 files changed, 76 insertions(+), 10 deletions(-) diff --git a/src/domain/config.rs b/src/domain/config.rs index b7f231b..8a69ed1 100644 --- a/src/domain/config.rs +++ b/src/domain/config.rs @@ -1,4 +1,5 @@ use std::path::{Path, PathBuf}; +use std::str::FromStr; use std::{fs, io, sync}; use crate::error::{MapUnexpected, ToUnexpected}; @@ -37,10 +38,14 @@ pub enum SetError { Unexpected(#[from] error::Unexpected), } +/// Used in the return from all_domains from Store. +pub type AllDomainsResult = Result; + #[mockall::automock] pub trait Store { fn get(&self, domain: &domain::Name) -> Result; fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>; + fn all_domains(&self) -> AllDomainsResult>>; } pub trait BoxedStore: Store + Send + Sync + Clone {} @@ -89,6 +94,22 @@ impl Store for sync::Arc { Ok(()) } + + fn all_domains(&self) -> AllDomainsResult>> { + Ok(fs::read_dir(&self.dir_path) + .map_unexpected()? + .map( + |dir_entry_res: io::Result| -> AllDomainsResult { + let domain = dir_entry_res.map_unexpected()?.file_name(); + let domain = domain.to_str().ok_or_else(|| { + error::Unexpected::from("couldn't convert os string to &str") + })?; + + Ok(domain::Name::from_str(domain).map_unexpected()?) + }, + ) + .collect()) + } } #[cfg(test)] diff --git a/src/domain/manager.rs b/src/domain/manager.rs index 8b9515b..26db366 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -115,6 +115,8 @@ impl From for SyncWithConfigError { pub type GetAcmeHttp01ChallengeKeyError = acme::manager::GetHttp01ChallengeKeyError; +pub type AllDomainsResult = config::AllDomainsResult; + #[mockall::automock( type Origin=origin::MockOrigin; type SyncWithConfigFuture=future::Ready>; @@ -153,6 +155,8 @@ pub trait Manager { &self, token: &str, ) -> Result; + + fn all_domains(&self) -> AllDomainsResult>>; } pub trait BoxedManager: Manager + Send + Sync + Clone {} @@ -283,4 +287,8 @@ where Err(GetAcmeHttp01ChallengeKeyError::NotFound) } + + fn all_domains(&self) -> AllDomainsResult>> { + self.domain_config_store.all_domains() + } } diff --git a/src/main.rs b/src/main.rs index 5f3de3e..9598638 100644 --- a/src/main.rs +++ b/src/main.rs @@ -158,7 +158,7 @@ fn main() { }); let service = domiply::service::new( - manager, + manager.clone(), config.domain_checker_target_a, config.passphrase, config.http_domain.clone(), @@ -181,6 +181,7 @@ fn main() { wait_group.push({ let http_domain = config.http_domain.clone(); + let canceller = canceller.clone(); tokio_runtime.spawn(async move { let addr = config.http_listen_addr; @@ -205,18 +206,54 @@ fn main() { // if there's an acme manager then it means that https is enabled, and we should ensure that // the http domain for domiply itself has a valid certificate. if let Some(domain_acme_manager) = domain_acme_manager { + let manager = manager.clone(); + let canceller = canceller.clone(); let http_domain = config.http_domain.clone(); + // Periodically refresh all domain certs wait_group.push(tokio_runtime.spawn(async move { - _ = domain_acme_manager - .sync_domain(http_domain.clone()) - .await - .inspect_err(|err| { - println!( - "Error while getting cert for {}: {err}", - http_domain.as_str() - ) - }); + let mut interval = time::interval(time::Duration::from_secs(60 * 60)); + + loop { + select! { + _ = interval.tick() => (), + _ = canceller.cancelled() => return, + } + + _ = domain_acme_manager + .sync_domain(http_domain.clone()) + .await + .inspect_err(|err| { + println!( + "Error while getting cert for {}: {err}", + http_domain.as_str() + ) + }); + + let domains_iter = manager.all_domains(); + + if let Err(err) = domains_iter { + println!("Got error calling all_domains: {err}"); + continue; + } + + for domain in domains_iter.unwrap().into_iter() { + match domain { + Ok(domain) => { + let _ = domain_acme_manager + .sync_domain(domain.clone()) + .await + .inspect_err(|err| { + println!( + "Error while getting cert for {}: {err}", + domain.as_str(), + ) + }); + } + Err(err) => println!("Error iterating through domains: {err}"), + }; + } + } })); }