From 7dd52839b16d84df8bdcb423460eda42cddaac19 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Wed, 21 Jun 2023 13:15:42 +0200 Subject: [PATCH] Use TaskStack to clean up startup/shutdown logic significantly --- TODO | 1 - src/domain/manager.rs | 73 +++++++++++++++++------------------------ src/error/unexpected.rs | 22 +++++++++++++ src/main.rs | 68 ++++++++++++++++---------------------- src/service/http.rs | 22 ++++--------- src/util.rs | 26 ++++++++------- 6 files changed, 100 insertions(+), 112 deletions(-) diff --git a/TODO b/TODO index 4d39ffb..e55da09 100644 --- a/TODO +++ b/TODO @@ -1,3 +1,2 @@ - make domain_manager implement rusttls cert resolver - Try to switch from Arc to Box where possible -- maybe build TaskSet into some kind of defer-like replacement diff --git a/src/domain/manager.rs b/src/domain/manager.rs index 0353de7..024805c 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -1,6 +1,7 @@ use crate::domain::{self, acme, checker, config}; use crate::error::unexpected::{self, Mappable}; use crate::origin; +use crate::util; use std::{future, pin, sync}; use tokio_util::sync::CancellationToken; @@ -144,70 +145,56 @@ pub trait Manager: Sync + Send { fn all_domains(&self) -> Result, unexpected::Error>; } -pub struct ManagerImpl { +struct ManagerImpl { origin_store: sync::Arc, domain_config_store: sync::Arc, domain_checker: checker::DNSChecker, acme_manager: Option>, - - canceller: CancellationToken, - origin_sync_handler: tokio::task::JoinHandle<()>, } -fn sync_origins(origin_store: &dyn origin::store::Store) { - match origin_store.all_descrs() { - Ok(iter) => iter.into_iter(), - Err(err) => { - log::error!("Error fetching origin descriptors: {err}"); - return; +async fn sync_origins(origin_store: &dyn origin::store::Store, canceller: CancellationToken) { + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(20 * 60)); + loop { + tokio::select! { + _ = interval.tick() => { + + match origin_store.all_descrs() { + Ok(iter) => iter.into_iter(), + Err(err) => { + log::error!("Error fetching origin descriptors: {err}"); + return; + } } + .for_each(|descr| { + if let Err(err) = origin_store.sync(descr.clone(), origin::store::Limits {}) { + log::error!("Failed to sync store for {:?}: {err}", descr); + return; + } + }); + }, + _ = canceller.cancelled() => return, + } } - .for_each(|descr| { - if let Err(err) = origin_store.sync(descr.clone(), origin::store::Limits {}) { - log::error!("Failed to sync store for {:?}: {err}", descr); - return; - } - }); } pub fn new( + task_stack: &mut util::TaskStack, origin_store: sync::Arc, domain_config_store: sync::Arc, domain_checker: checker::DNSChecker, acme_manager: Option>, -) -> ManagerImpl { - let canceller = CancellationToken::new(); - - let origin_sync_handler = { +) -> sync::Arc { + task_stack.spawn(|canceller| { let origin_store = origin_store.clone(); - let canceller = canceller.clone(); - tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(20 * 60)); + async move { Ok(sync_origins(origin_store.as_ref(), canceller).await) } + }); - loop { - tokio::select! { - _ = interval.tick() => sync_origins(origin_store.as_ref()), - _ = canceller.cancelled() => return, - } - } - }) - }; - - ManagerImpl { + sync::Arc::new(ManagerImpl { origin_store, domain_config_store, domain_checker, acme_manager, - canceller, - origin_sync_handler, - } -} - -impl ManagerImpl { - pub fn stop(self) -> tokio::task::JoinHandle<()> { - self.canceller.cancel(); - self.origin_sync_handler - } + }) } impl Manager for ManagerImpl { diff --git a/src/error/unexpected.rs b/src/error/unexpected.rs index ca31ab9..fc5e082 100644 --- a/src/error/unexpected.rs +++ b/src/error/unexpected.rs @@ -103,6 +103,28 @@ impl Mappable for Result { } } +static OPTION_NONE_ERROR: &'static str = "expected Some but got None"; + +impl Mappable for Option { + fn or_unexpected(self) -> Result { + self.ok_or(Error::from(OPTION_NONE_ERROR)).or_unexpected() + } + + fn or_unexpected_while(self, prefix: D) -> Result { + self.ok_or(Error::from(OPTION_NONE_ERROR)) + .or_unexpected_while(prefix) + } + + fn map_unexpected_while(self, f: F) -> Result + where + F: FnOnce() -> D, + D: fmt::Display, + { + self.ok_or(Error::from(OPTION_NONE_ERROR)) + .map_unexpected_while(f) + } +} + pub trait Intoable { fn into_unexpected(self) -> Error; diff --git a/src/main.rs b/src/main.rs index bf3f1ab..c5eaa12 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,27 +78,6 @@ async fn main() { ) .init(); - let canceller = tokio_util::sync::CancellationToken::new(); - - { - let canceller = canceller.clone(); - - tokio::spawn(async move { - let mut signals = Signals::new(signal_hook::consts::TERM_SIGNALS) - .expect("initializing signals failed"); - - if (signals.next().await).is_some() { - log::info!("Gracefully shutting down..."); - canceller.cancel(); - } - - if (signals.next().await).is_some() { - log::warn!("Forcefully shutting down"); - std::process::exit(1); - }; - }); - } - let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path) .expect("git origin store initialization failed"); @@ -138,38 +117,47 @@ async fn main() { None }; + let mut task_stack = domiply::util::TaskStack::new(); + let domain_manager = domiply::domain::manager::new( + &mut task_stack, origin_store, domain_config_store, domain_checker, https_params.as_ref().map(|p| p.domain_acme_manager.clone()), ); - let domain_manager = sync::Arc::new(domain_manager); + let _ = domiply::service::http::new( + &mut task_stack, + domain_manager.clone(), + config.domain_checker_target_a, + config.passphrase, + config.http_listen_addr.clone(), + config.http_domain.clone(), + https_params.map(|p| domiply::service::http::HTTPSParams { + listen_addr: p.https_listen_addr, + cert_resolver: domiply::domain::acme::resolver::new(p.domain_acme_store), + }), + ); - { - let (http_service, http_service_task_set) = domiply::service::http::new( - domain_manager.clone(), - config.domain_checker_target_a, - config.passphrase, - config.http_listen_addr.clone(), - config.http_domain.clone(), - https_params.map(|p| domiply::service::http::HTTPSParams { - listen_addr: p.https_listen_addr, - cert_resolver: domiply::domain::acme::resolver::new(p.domain_acme_store), - }), - ); + let mut signals = + Signals::new(signal_hook::consts::TERM_SIGNALS).expect("initializing signals failed"); - canceller.cancelled().await; - - domiply::service::http::stop(http_service, http_service_task_set).await; + if (signals.next().await).is_some() { + log::info!("Gracefully shutting down..."); } - sync::Arc::into_inner(domain_manager) - .unwrap() + tokio::spawn(async move { + if (signals.next().await).is_some() { + log::warn!("Forcefully shutting down"); + std::process::exit(1); + }; + }); + + task_stack .stop() .await - .expect("domain manager failed to shutdown cleanly"); + .expect("failed to stop all background tasks"); log::info!("Graceful shutdown complete"); } diff --git a/src/service/http.rs b/src/service/http.rs index 6e74e96..fadcce0 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -27,13 +27,14 @@ pub struct HTTPSParams { } pub fn new( + task_stack: &mut util::TaskStack, domain_manager: sync::Arc, target_a: net::Ipv4Addr, passphrase: String, http_listen_addr: net::SocketAddr, http_domain: domain::Name, https_params: Option, -) -> (sync::Arc, util::TaskSet) { +) -> sync::Arc { let service = sync::Arc::new(Service { domain_manager: domain_manager.clone(), target_a, @@ -42,9 +43,7 @@ pub fn new( handlebars: tpl::get(), }); - let task_set = util::TaskSet::new(); - - task_set.spawn(|canceller| { + task_stack.spawn(|canceller| { tasks::listen_http( service.clone(), canceller, @@ -54,7 +53,7 @@ pub fn new( }); if let Some(https_params) = https_params { - task_set.spawn(|canceller| { + task_stack.spawn(|canceller| { tasks::listen_https( service.clone(), canceller, @@ -64,21 +63,12 @@ pub fn new( ) }); - task_set.spawn(|canceller| { + task_stack.spawn(|canceller| { tasks::cert_refresher(domain_manager.clone(), canceller, http_domain.clone()) }); } - return (service, task_set); -} - -pub async fn stop(service: sync::Arc, task_set: util::TaskSet) { - task_set - .stop() - .await - .iter() - .for_each(|e| log::error!("error while shutting down http service: {e}")); - sync::Arc::into_inner(service).expect("service didn't get cleaned up"); + return service; } #[derive(Serialize)] diff --git a/src/util.rs b/src/util.rs index cbd2696..3efb067 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,6 +1,5 @@ use std::{error, fs, io, path}; -use futures::stream::futures_unordered::FuturesUnordered; use tokio_util::sync::CancellationToken; pub fn open_file(path: &path::Path) -> io::Result> { @@ -13,26 +12,26 @@ pub fn open_file(path: &path::Path) -> io::Result> { } } -pub struct TaskSet +pub struct TaskStack where E: error::Error + Send + 'static, { canceller: CancellationToken, - wait_group: FuturesUnordered>>, + wait_group: Vec>>, } -impl TaskSet +impl TaskStack where E: error::Error + Send + 'static, { - pub fn new() -> TaskSet { - TaskSet { + pub fn new() -> TaskStack { + TaskStack { canceller: CancellationToken::new(), - wait_group: FuturesUnordered::new(), + wait_group: Vec::new(), } } - pub fn spawn(&self, mut f: F) + pub fn spawn(&mut self, mut f: F) where Fut: futures::Future> + Send + 'static, F: FnMut(CancellationToken) -> Fut, @@ -42,15 +41,18 @@ where self.wait_group.push(handle); } - pub async fn stop(self) -> Vec { + pub async fn stop(mut self) -> Result<(), E> { self.canceller.cancel(); - let mut res = Vec::new(); + // reverse wait_group in place, so we stop the most recently added first. Since this method + // consumes self this is fine. + self.wait_group.reverse(); + for f in self.wait_group { if let Err(err) = f.await.expect("task failed") { - res.push(err); + return Err(err); } } - res + Ok(()) } }