From 26ebda90e81e3f4338237cdaa8fc0f992c74fad2 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sun, 14 May 2023 11:18:36 +0200 Subject: [PATCH] Use async dns client, pass tokio runtime into the checker constructor --- src/domain/checker.rs | 51 +++++++++++++++++++++++++++------------ src/domain/manager.rs | 2 +- src/main.rs | 55 +++++++++++++++++++++++-------------------- 3 files changed, 66 insertions(+), 42 deletions(-) diff --git a/src/domain/checker.rs b/src/domain/checker.rs index d89559f..85014de 100644 --- a/src/domain/checker.rs +++ b/src/domain/checker.rs @@ -1,12 +1,13 @@ use std::error::Error; use std::str::FromStr; +use std::sync; use crate::domain; use mockall::automock; -use trust_dns_client::client::{Client, SyncClient}; +use trust_dns_client::client::{AsyncClient, ClientHandle}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; -use trust_dns_client::udp::UdpClientConnection; +use trust_dns_client::udp; #[derive(thiserror::Error, Debug)] pub enum NewDNSCheckerError { @@ -42,11 +43,15 @@ pub trait Checker: std::marker::Send + std::marker::Sync { } pub struct DNSChecker { + tokio_runtime: sync::Arc, target_cname: Name, - client: SyncClient, + + // TODO we should use some kind of connection pool here, I suppose + client: tokio::sync::Mutex, } pub fn new( + tokio_runtime: sync::Arc, target_cname: domain::Name, resolver_addr: &str, ) -> Result { @@ -54,14 +59,18 @@ pub fn new( .parse() .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; - let conn = UdpClientConnection::new(resolver_addr) + let stream = udp::UdpClientStream::::new(resolver_addr); + + let (client, bg) = tokio_runtime + .block_on(async { AsyncClient::connect(stream).await }) .map_err(|e| NewDNSCheckerError::Unexpected(Box::from(e)))?; - let client = SyncClient::new(conn); + tokio_runtime.spawn(bg); Ok(DNSChecker { + tokio_runtime, target_cname: target_cname.inner, - client, + client: tokio::sync::Mutex::new(client), }) } @@ -75,10 +84,16 @@ impl Checker for DNSChecker { // check that the CNAME is installed correctly on the domain { - let response = self - .client - .query(domain, DNSClass::IN, RecordType::CNAME) - .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; + let response = match self.tokio_runtime.block_on(async { + self.client + .lock() + .await + .query(domain.clone(), DNSClass::IN, RecordType::CNAME) + .await + }) { + Ok(res) => res, + Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), + }; let records = response.answers(); @@ -98,13 +113,19 @@ impl Checker for DNSChecker { { let domain = Name::from_str("_domiply_challenge") .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))? - .append_domain(domain) + .append_domain(&domain) .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; - let response = self - .client - .query(&domain, DNSClass::IN, RecordType::TXT) - .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; + let response = match self.tokio_runtime.block_on(async { + self.client + .lock() + .await + .query(domain, DNSClass::IN, RecordType::TXT) + .await + }) { + Ok(res) => res, + Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))), + }; let records = response.answers(); diff --git a/src/domain/manager.rs b/src/domain/manager.rs index f636b95..80ae211 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -199,7 +199,7 @@ where .hash() .map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?; - self.domain_checker.check_domain(domain, &config_hash)?; + self.domain_checker.check_domain(&domain, &config_hash)?; self.origin_store .sync(config.origin_descr.clone(), origin::store::Limits {})?; diff --git a/src/main.rs b/src/main.rs index e16df92..23bff5d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,37 +31,40 @@ struct Cli { domain_config_store_dir_path: path::PathBuf, } -#[tokio::main] -async fn main() { +fn main() { let config = Cli::parse(); + let tokio_runtime = std::sync::Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(), + ); + + let (stop_ch_tx, stop_ch_rx) = tokio_runtime.block_on(async { oneshot::channel() }); + // set up signal handling, stop_ch_rx will be used to signal that the stop signal has been // received - let stop_ch_rx = { + tokio_runtime.spawn(async move { let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT]) .expect("initialized signals"); - let (stop_ch_tx, stop_ch_rx) = oneshot::channel(); + if let Some(_) = signals.next().await { + println!("Gracefully shutting down..."); + let _ = stop_ch_tx.send(()); + } - tokio::spawn(async move { - if let Some(_) = signals.next().await { - println!("Gracefully shutting down..."); - let _ = stop_ch_tx.send(()); - } - - if let Some(_) = signals.next().await { - println!("Forcefully shutting down"); - std::process::exit(1); - } - }); - - stop_ch_rx - }; + if let Some(_) = signals.next().await { + println!("Forcefully shutting down"); + std::process::exit(1); + }; + }); let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path) .expect("git origin store initialized"); let domain_checker = domiply::domain::checker::new( + tokio_runtime.clone(), config.domain_checker_target_cname.clone(), &config.domain_checker_resolver_addr, ) @@ -79,15 +82,15 @@ async fn main() { ) .expect("service initialized"); - let (addr, server) = - warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async { - stop_ch_rx.await.ok(); - }); + tokio_runtime.block_on(async { + let (addr, server) = + warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async { + stop_ch_rx.await.ok(); + }); - println!("Listening on {addr}"); - tokio::task::spawn(server) - .await - .expect("server shutdown gracefully"); + println!("Listening on {addr}"); + server.await; + }); println!("Graceful shutdown complete"); }