Use async dns client, pass tokio runtime into the checker constructor

This commit is contained in:
Brian Picciano 2023-05-14 11:18:36 +02:00
parent f9801af166
commit 26ebda90e8
3 changed files with 66 additions and 42 deletions

View File

@ -1,12 +1,13 @@
use std::error::Error; use std::error::Error;
use std::str::FromStr; use std::str::FromStr;
use std::sync;
use crate::domain; use crate::domain;
use mockall::automock; use mockall::automock;
use trust_dns_client::client::{Client, SyncClient}; use trust_dns_client::client::{AsyncClient, ClientHandle};
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
use trust_dns_client::udp::UdpClientConnection; use trust_dns_client::udp;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum NewDNSCheckerError { pub enum NewDNSCheckerError {
@ -42,11 +43,15 @@ pub trait Checker: std::marker::Send + std::marker::Sync {
} }
pub struct DNSChecker { pub struct DNSChecker {
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: Name, target_cname: Name,
client: SyncClient<UdpClientConnection>,
// TODO we should use some kind of connection pool here, I suppose
client: tokio::sync::Mutex<AsyncClient>,
} }
pub fn new( pub fn new(
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: domain::Name, target_cname: domain::Name,
resolver_addr: &str, resolver_addr: &str,
) -> Result<impl Checker, NewDNSCheckerError> { ) -> Result<impl Checker, NewDNSCheckerError> {
@ -54,14 +59,18 @@ pub fn new(
.parse() .parse()
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
let conn = UdpClientConnection::new(resolver_addr) let stream = udp::UdpClientStream::<tokio::net::UdpSocket>::new(resolver_addr);
let (client, bg) = tokio_runtime
.block_on(async { AsyncClient::connect(stream).await })
.map_err(|e| NewDNSCheckerError::Unexpected(Box::from(e)))?; .map_err(|e| NewDNSCheckerError::Unexpected(Box::from(e)))?;
let client = SyncClient::new(conn); tokio_runtime.spawn(bg);
Ok(DNSChecker { Ok(DNSChecker {
tokio_runtime,
target_cname: target_cname.inner, target_cname: target_cname.inner,
client, client: tokio::sync::Mutex::new(client),
}) })
} }
@ -75,10 +84,16 @@ impl Checker for DNSChecker {
// check that the CNAME is installed correctly on the domain // check that the CNAME is installed correctly on the domain
{ {
let response = self let response = match self.tokio_runtime.block_on(async {
.client self.client
.query(domain, DNSClass::IN, RecordType::CNAME) .lock()
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .await
.query(domain.clone(), DNSClass::IN, RecordType::CNAME)
.await
}) {
Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
};
let records = response.answers(); let records = response.answers();
@ -98,13 +113,19 @@ impl Checker for DNSChecker {
{ {
let domain = Name::from_str("_domiply_challenge") let domain = Name::from_str("_domiply_challenge")
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))? .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
.append_domain(domain) .append_domain(&domain)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = self let response = match self.tokio_runtime.block_on(async {
.client self.client
.query(&domain, DNSClass::IN, RecordType::TXT) .lock()
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .await
.query(domain, DNSClass::IN, RecordType::TXT)
.await
}) {
Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
};
let records = response.answers(); let records = response.answers();

View File

@ -199,7 +199,7 @@ where
.hash() .hash()
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?; .map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?;
self.domain_checker.check_domain(domain, &config_hash)?; self.domain_checker.check_domain(&domain, &config_hash)?;
self.origin_store self.origin_store
.sync(config.origin_descr.clone(), origin::store::Limits {})?; .sync(config.origin_descr.clone(), origin::store::Limits {})?;

View File

@ -31,37 +31,40 @@ struct Cli {
domain_config_store_dir_path: path::PathBuf, domain_config_store_dir_path: path::PathBuf,
} }
#[tokio::main] fn main() {
async fn main() {
let config = Cli::parse(); let config = Cli::parse();
let tokio_runtime = std::sync::Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap(),
);
let (stop_ch_tx, stop_ch_rx) = tokio_runtime.block_on(async { oneshot::channel() });
// set up signal handling, stop_ch_rx will be used to signal that the stop signal has been // set up signal handling, stop_ch_rx will be used to signal that the stop signal has been
// received // received
let stop_ch_rx = { tokio_runtime.spawn(async move {
let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT]) let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT])
.expect("initialized signals"); .expect("initialized signals");
let (stop_ch_tx, stop_ch_rx) = oneshot::channel(); if let Some(_) = signals.next().await {
println!("Gracefully shutting down...");
let _ = stop_ch_tx.send(());
}
tokio::spawn(async move { if let Some(_) = signals.next().await {
if let Some(_) = signals.next().await { println!("Forcefully shutting down");
println!("Gracefully shutting down..."); std::process::exit(1);
let _ = stop_ch_tx.send(()); };
} });
if let Some(_) = signals.next().await {
println!("Forcefully shutting down");
std::process::exit(1);
}
});
stop_ch_rx
};
let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path) let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path)
.expect("git origin store initialized"); .expect("git origin store initialized");
let domain_checker = domiply::domain::checker::new( let domain_checker = domiply::domain::checker::new(
tokio_runtime.clone(),
config.domain_checker_target_cname.clone(), config.domain_checker_target_cname.clone(),
&config.domain_checker_resolver_addr, &config.domain_checker_resolver_addr,
) )
@ -79,15 +82,15 @@ async fn main() {
) )
.expect("service initialized"); .expect("service initialized");
let (addr, server) = tokio_runtime.block_on(async {
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async { let (addr, server) =
stop_ch_rx.await.ok(); warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
}); stop_ch_rx.await.ok();
});
println!("Listening on {addr}"); println!("Listening on {addr}");
tokio::task::spawn(server) server.await;
.await });
.expect("server shutdown gracefully");
println!("Graceful shutdown complete"); println!("Graceful shutdown complete");
} }