diff --git a/Cargo.lock b/Cargo.lock index 6b80734..633152a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -379,6 +379,7 @@ dependencies = [ "tempdir", "thiserror", "tokio", + "tokio-util", "trust-dns-client", ] diff --git a/Cargo.toml b/Cargo.toml index 5c11751..7ec4574 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,3 +31,4 @@ mime_guess = "2.0.4" hyper = { version = "0.14.26", features = [ "server" ]} http = "0.2.9" serde_urlencoded = "0.7.1" +tokio-util = "0.7.8" diff --git a/src/main.rs b/src/main.rs index 44d7df1..64565e2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,8 @@ use clap::Parser; use futures::stream::StreamExt; use signal_hook::consts::signal; use signal_hook_tokio::Signals; -use tokio::sync::oneshot; +use tokio::select; +use tokio::time; use std::convert::Infallible; use std::net::SocketAddr; @@ -10,6 +11,8 @@ use std::path; use std::str::FromStr; use std::sync; +use domiply::origin::store::Store; + #[derive(Parser, Debug)] #[command(version)] #[command(about = "A domiply to another dimension")] @@ -46,28 +49,70 @@ fn main() { .unwrap(), ); - let (stop_ch_tx, stop_ch_rx) = tokio_runtime.block_on(async { oneshot::channel() }); + let canceller = tokio_runtime.block_on(async { tokio_util::sync::CancellationToken::new() }); - // set up signal handling, stop_ch_rx will be used to signal that the stop signal has been - // received - tokio_runtime.spawn(async move { - let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT]) - .expect("initialized signals"); + { + let canceller = canceller.clone(); + tokio_runtime.spawn(async move { + let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT]) + .expect("initialized signals"); - if let Some(_) = signals.next().await { - println!("Gracefully shutting down..."); - let _ = stop_ch_tx.send(()); - } + if let Some(_) = signals.next().await { + println!("Gracefully shutting down..."); + canceller.cancel(); + } - if let Some(_) = signals.next().await { - println!("Forcefully shutting down"); - std::process::exit(1); - }; - }); + if let Some(_) = signals.next().await { + println!("Forcefully shutting down"); + std::process::exit(1); + }; + }); + } let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path) .expect("git origin store initialized"); + let origin_syncer_handler = { + let origin_store = origin_store.clone(); + let canceller = canceller.clone(); + + tokio_runtime.spawn(async move { + let mut interval = time::interval(time::Duration::from_secs(20 * 60)); + interval.tick().await; + + loop { + origin_store + .all_descrs() + .expect("got all_descrs iter") + .into_iter() + .for_each(|descr| { + if canceller.is_cancelled() { + return; + } + + if let Err(err) = descr { + println!("failed iterating origins: {err}"); + return; + } + + let descr = descr.unwrap(); + + println!("syncing origin: {descr:?}"); + if let Err(err) = + origin_store.sync(descr.clone(), domiply::origin::store::Limits {}) + { + println!("error syncing origin {descr:?}: {err}"); + } + }); + + select! { + _ = interval.tick() => continue, + _ = canceller.cancelled() => return, + } + } + }) + }; + let domain_checker = domiply::domain::checker::new( tokio_runtime.clone(), config.domain_checker_target_aaaa, @@ -103,20 +148,27 @@ fn main() { async move { Ok::<_, Infallible>(service) } }); - tokio_runtime.block_on(async { - let addr = config.http_listen_addr; + let server_handler = { + let canceller = canceller.clone(); + tokio_runtime.spawn(async move { + let addr = config.http_listen_addr; - println!("Listening on {addr}"); - let server = hyper::Server::bind(&addr).serve(make_service); + println!("Listening on {addr}"); + let server = hyper::Server::bind(&addr).serve(make_service); - let graceful = server.with_graceful_shutdown(async { - stop_ch_rx.await.ok(); - }); + let graceful = server.with_graceful_shutdown(async { + canceller.cancelled().await; + }); - if let Err(e) = graceful.await { - panic!("server error: {}", e); - } - }); + if let Err(e) = graceful.await { + panic!("server error: {}", e); + }; + }) + }; + + tokio_runtime + .block_on(async { futures::try_join!(origin_syncer_handler, server_handler) }) + .unwrap(); println!("Graceful shutdown complete"); } diff --git a/src/origin/store.rs b/src/origin/store.rs index aa09229..6ddc424 100644 --- a/src/origin/store.rs +++ b/src/origin/store.rs @@ -41,17 +41,13 @@ pub enum AllDescrsError { /// Used in the return from all_descrs from Store. pub type AllDescrsResult = Result; -#[mockall::automock( - type Origin=origin::MockOrigin; - type AllDescrsIter=Vec>; - )] /// Describes a storage mechanism for Origins. Each Origin is uniquely identified by its Descr. -pub trait Store: std::marker::Send + std::marker::Sync { +pub trait Store: Send + Sync + Clone { type Origin<'store>: origin::Origin + 'store where Self: 'store; - type AllDescrsIter<'store>: IntoIterator> + 'store + type AllDescrsIter<'store>: IntoIterator> + Send + 'store where Self: 'store; diff --git a/src/origin/store/git.rs b/src/origin/store/git.rs index f9f8d54..75a161c 100644 --- a/src/origin/store/git.rs +++ b/src/origin/store/git.rs @@ -67,11 +67,11 @@ struct Store { pub fn new(dir_path: PathBuf) -> io::Result { fs::create_dir_all(&dir_path)?; - Ok(Store { + Ok(sync::Arc::new(Store { dir_path, sync_guard: sync::Mutex::new(collections::HashMap::new()), origins: sync::RwLock::new(collections::HashMap::new()), - }) + })) } impl Store { @@ -191,11 +191,11 @@ impl Store { } } -impl super::Store for Store { +impl super::Store for sync::Arc { type Origin<'store> = sync::Arc where Self: 'store; - type AllDescrsIter<'store> = Box> + 'store> + type AllDescrsIter<'store> = Box> + Send + 'store> where Self: 'store; fn sync(&self, descr: origin::Descr, limits: store::Limits) -> Result<(), store::SyncError> {