Use async dns client, pass tokio runtime into the checker constructor

This commit is contained in:
Brian Picciano 2023-05-14 11:18:36 +02:00
parent f9801af166
commit 26ebda90e8
3 changed files with 66 additions and 42 deletions

View File

@ -1,12 +1,13 @@
use std::error::Error;
use std::str::FromStr;
use std::sync;
use crate::domain;
use mockall::automock;
use trust_dns_client::client::{Client, SyncClient};
use trust_dns_client::client::{AsyncClient, ClientHandle};
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
use trust_dns_client::udp::UdpClientConnection;
use trust_dns_client::udp;
#[derive(thiserror::Error, Debug)]
pub enum NewDNSCheckerError {
@ -42,11 +43,15 @@ pub trait Checker: std::marker::Send + std::marker::Sync {
}
pub struct DNSChecker {
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: Name,
client: SyncClient<UdpClientConnection>,
// TODO we should use some kind of connection pool here, I suppose
client: tokio::sync::Mutex<AsyncClient>,
}
pub fn new(
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_cname: domain::Name,
resolver_addr: &str,
) -> Result<impl Checker, NewDNSCheckerError> {
@ -54,14 +59,18 @@ pub fn new(
.parse()
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
let conn = UdpClientConnection::new(resolver_addr)
let stream = udp::UdpClientStream::<tokio::net::UdpSocket>::new(resolver_addr);
let (client, bg) = tokio_runtime
.block_on(async { AsyncClient::connect(stream).await })
.map_err(|e| NewDNSCheckerError::Unexpected(Box::from(e)))?;
let client = SyncClient::new(conn);
tokio_runtime.spawn(bg);
Ok(DNSChecker {
tokio_runtime,
target_cname: target_cname.inner,
client,
client: tokio::sync::Mutex::new(client),
})
}
@ -75,10 +84,16 @@ impl Checker for DNSChecker {
// check that the CNAME is installed correctly on the domain
{
let response = self
.client
.query(domain, DNSClass::IN, RecordType::CNAME)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = match self.tokio_runtime.block_on(async {
self.client
.lock()
.await
.query(domain.clone(), DNSClass::IN, RecordType::CNAME)
.await
}) {
Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
};
let records = response.answers();
@ -98,13 +113,19 @@ impl Checker for DNSChecker {
{
let domain = Name::from_str("_domiply_challenge")
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
.append_domain(domain)
.append_domain(&domain)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = self
.client
.query(&domain, DNSClass::IN, RecordType::TXT)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = match self.tokio_runtime.block_on(async {
self.client
.lock()
.await
.query(domain, DNSClass::IN, RecordType::TXT)
.await
}) {
Ok(res) => res,
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
};
let records = response.answers();

View File

@ -199,7 +199,7 @@ where
.hash()
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?;
self.domain_checker.check_domain(domain, &config_hash)?;
self.domain_checker.check_domain(&domain, &config_hash)?;
self.origin_store
.sync(config.origin_descr.clone(), origin::store::Limits {})?;

View File

@ -31,37 +31,40 @@ struct Cli {
domain_config_store_dir_path: path::PathBuf,
}
#[tokio::main]
async fn main() {
fn main() {
let config = Cli::parse();
let tokio_runtime = std::sync::Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap(),
);
let (stop_ch_tx, stop_ch_rx) = tokio_runtime.block_on(async { oneshot::channel() });
// set up signal handling, stop_ch_rx will be used to signal that the stop signal has been
// received
let stop_ch_rx = {
tokio_runtime.spawn(async move {
let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT])
.expect("initialized signals");
let (stop_ch_tx, stop_ch_rx) = oneshot::channel();
if let Some(_) = signals.next().await {
println!("Gracefully shutting down...");
let _ = stop_ch_tx.send(());
}
tokio::spawn(async move {
if let Some(_) = signals.next().await {
println!("Gracefully shutting down...");
let _ = stop_ch_tx.send(());
}
if let Some(_) = signals.next().await {
println!("Forcefully shutting down");
std::process::exit(1);
}
});
stop_ch_rx
};
if let Some(_) = signals.next().await {
println!("Forcefully shutting down");
std::process::exit(1);
};
});
let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path)
.expect("git origin store initialized");
let domain_checker = domiply::domain::checker::new(
tokio_runtime.clone(),
config.domain_checker_target_cname.clone(),
&config.domain_checker_resolver_addr,
)
@ -79,15 +82,15 @@ async fn main() {
)
.expect("service initialized");
let (addr, server) =
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
stop_ch_rx.await.ok();
});
tokio_runtime.block_on(async {
let (addr, server) =
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
stop_ch_rx.await.ok();
});
println!("Listening on {addr}");
tokio::task::spawn(server)
.await
.expect("server shutdown gracefully");
println!("Listening on {addr}");
server.await;
});
println!("Graceful shutdown complete");
}