Use async dns client, pass tokio runtime into the checker constructor
This commit is contained in:
parent
f9801af166
commit
26ebda90e8
@ -1,12 +1,13 @@
|
||||
use std::error::Error;
|
||||
use std::str::FromStr;
|
||||
use std::sync;
|
||||
|
||||
use crate::domain;
|
||||
|
||||
use mockall::automock;
|
||||
use trust_dns_client::client::{Client, SyncClient};
|
||||
use trust_dns_client::client::{AsyncClient, ClientHandle};
|
||||
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
|
||||
use trust_dns_client::udp::UdpClientConnection;
|
||||
use trust_dns_client::udp;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum NewDNSCheckerError {
|
||||
@ -42,11 +43,15 @@ pub trait Checker: std::marker::Send + std::marker::Sync {
|
||||
}
|
||||
|
||||
pub struct DNSChecker {
|
||||
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
|
||||
target_cname: Name,
|
||||
client: SyncClient<UdpClientConnection>,
|
||||
|
||||
// TODO we should use some kind of connection pool here, I suppose
|
||||
client: tokio::sync::Mutex<AsyncClient>,
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
|
||||
target_cname: domain::Name,
|
||||
resolver_addr: &str,
|
||||
) -> Result<impl Checker, NewDNSCheckerError> {
|
||||
@ -54,14 +59,18 @@ pub fn new(
|
||||
.parse()
|
||||
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
|
||||
|
||||
let conn = UdpClientConnection::new(resolver_addr)
|
||||
let stream = udp::UdpClientStream::<tokio::net::UdpSocket>::new(resolver_addr);
|
||||
|
||||
let (client, bg) = tokio_runtime
|
||||
.block_on(async { AsyncClient::connect(stream).await })
|
||||
.map_err(|e| NewDNSCheckerError::Unexpected(Box::from(e)))?;
|
||||
|
||||
let client = SyncClient::new(conn);
|
||||
tokio_runtime.spawn(bg);
|
||||
|
||||
Ok(DNSChecker {
|
||||
tokio_runtime,
|
||||
target_cname: target_cname.inner,
|
||||
client,
|
||||
client: tokio::sync::Mutex::new(client),
|
||||
})
|
||||
}
|
||||
|
||||
@ -75,10 +84,16 @@ impl Checker for DNSChecker {
|
||||
|
||||
// check that the CNAME is installed correctly on the domain
|
||||
{
|
||||
let response = self
|
||||
.client
|
||||
.query(domain, DNSClass::IN, RecordType::CNAME)
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||
let response = match self.tokio_runtime.block_on(async {
|
||||
self.client
|
||||
.lock()
|
||||
.await
|
||||
.query(domain.clone(), DNSClass::IN, RecordType::CNAME)
|
||||
.await
|
||||
}) {
|
||||
Ok(res) => res,
|
||||
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
|
||||
};
|
||||
|
||||
let records = response.answers();
|
||||
|
||||
@ -98,13 +113,19 @@ impl Checker for DNSChecker {
|
||||
{
|
||||
let domain = Name::from_str("_domiply_challenge")
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
|
||||
.append_domain(domain)
|
||||
.append_domain(&domain)
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.query(&domain, DNSClass::IN, RecordType::TXT)
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||
let response = match self.tokio_runtime.block_on(async {
|
||||
self.client
|
||||
.lock()
|
||||
.await
|
||||
.query(domain, DNSClass::IN, RecordType::TXT)
|
||||
.await
|
||||
}) {
|
||||
Ok(res) => res,
|
||||
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
|
||||
};
|
||||
|
||||
let records = response.answers();
|
||||
|
||||
|
@ -199,7 +199,7 @@ where
|
||||
.hash()
|
||||
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?;
|
||||
|
||||
self.domain_checker.check_domain(domain, &config_hash)?;
|
||||
self.domain_checker.check_domain(&domain, &config_hash)?;
|
||||
|
||||
self.origin_store
|
||||
.sync(config.origin_descr.clone(), origin::store::Limits {})?;
|
||||
|
55
src/main.rs
55
src/main.rs
@ -31,37 +31,40 @@ struct Cli {
|
||||
domain_config_store_dir_path: path::PathBuf,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
fn main() {
|
||||
let config = Cli::parse();
|
||||
|
||||
let tokio_runtime = std::sync::Arc::new(
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let (stop_ch_tx, stop_ch_rx) = tokio_runtime.block_on(async { oneshot::channel() });
|
||||
|
||||
// set up signal handling, stop_ch_rx will be used to signal that the stop signal has been
|
||||
// received
|
||||
let stop_ch_rx = {
|
||||
tokio_runtime.spawn(async move {
|
||||
let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT])
|
||||
.expect("initialized signals");
|
||||
|
||||
let (stop_ch_tx, stop_ch_rx) = oneshot::channel();
|
||||
if let Some(_) = signals.next().await {
|
||||
println!("Gracefully shutting down...");
|
||||
let _ = stop_ch_tx.send(());
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(_) = signals.next().await {
|
||||
println!("Gracefully shutting down...");
|
||||
let _ = stop_ch_tx.send(());
|
||||
}
|
||||
|
||||
if let Some(_) = signals.next().await {
|
||||
println!("Forcefully shutting down");
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
|
||||
stop_ch_rx
|
||||
};
|
||||
if let Some(_) = signals.next().await {
|
||||
println!("Forcefully shutting down");
|
||||
std::process::exit(1);
|
||||
};
|
||||
});
|
||||
|
||||
let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path)
|
||||
.expect("git origin store initialized");
|
||||
|
||||
let domain_checker = domiply::domain::checker::new(
|
||||
tokio_runtime.clone(),
|
||||
config.domain_checker_target_cname.clone(),
|
||||
&config.domain_checker_resolver_addr,
|
||||
)
|
||||
@ -79,15 +82,15 @@ async fn main() {
|
||||
)
|
||||
.expect("service initialized");
|
||||
|
||||
let (addr, server) =
|
||||
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
|
||||
stop_ch_rx.await.ok();
|
||||
});
|
||||
tokio_runtime.block_on(async {
|
||||
let (addr, server) =
|
||||
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
|
||||
stop_ch_rx.await.ok();
|
||||
});
|
||||
|
||||
println!("Listening on {addr}");
|
||||
tokio::task::spawn(server)
|
||||
.await
|
||||
.expect("server shutdown gracefully");
|
||||
println!("Listening on {addr}");
|
||||
server.await;
|
||||
});
|
||||
|
||||
println!("Graceful shutdown complete");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user