Use async dns client, pass tokio runtime into the checker constructor
This commit is contained in:
parent
f9801af166
commit
26ebda90e8
@ -1,12 +1,13 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
use std::sync;
|
||||||
|
|
||||||
use crate::domain;
|
use crate::domain;
|
||||||
|
|
||||||
use mockall::automock;
|
use mockall::automock;
|
||||||
use trust_dns_client::client::{Client, SyncClient};
|
use trust_dns_client::client::{AsyncClient, ClientHandle};
|
||||||
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
|
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
|
||||||
use trust_dns_client::udp::UdpClientConnection;
|
use trust_dns_client::udp;
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum NewDNSCheckerError {
|
pub enum NewDNSCheckerError {
|
||||||
@ -42,11 +43,15 @@ pub trait Checker: std::marker::Send + std::marker::Sync {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct DNSChecker {
|
pub struct DNSChecker {
|
||||||
|
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
|
||||||
target_cname: Name,
|
target_cname: Name,
|
||||||
client: SyncClient<UdpClientConnection>,
|
|
||||||
|
// TODO we should use some kind of connection pool here, I suppose
|
||||||
|
client: tokio::sync::Mutex<AsyncClient>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
|
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
|
||||||
target_cname: domain::Name,
|
target_cname: domain::Name,
|
||||||
resolver_addr: &str,
|
resolver_addr: &str,
|
||||||
) -> Result<impl Checker, NewDNSCheckerError> {
|
) -> Result<impl Checker, NewDNSCheckerError> {
|
||||||
@ -54,14 +59,18 @@ pub fn new(
|
|||||||
.parse()
|
.parse()
|
||||||
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
|
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
|
||||||
|
|
||||||
let conn = UdpClientConnection::new(resolver_addr)
|
let stream = udp::UdpClientStream::<tokio::net::UdpSocket>::new(resolver_addr);
|
||||||
|
|
||||||
|
let (client, bg) = tokio_runtime
|
||||||
|
.block_on(async { AsyncClient::connect(stream).await })
|
||||||
.map_err(|e| NewDNSCheckerError::Unexpected(Box::from(e)))?;
|
.map_err(|e| NewDNSCheckerError::Unexpected(Box::from(e)))?;
|
||||||
|
|
||||||
let client = SyncClient::new(conn);
|
tokio_runtime.spawn(bg);
|
||||||
|
|
||||||
Ok(DNSChecker {
|
Ok(DNSChecker {
|
||||||
|
tokio_runtime,
|
||||||
target_cname: target_cname.inner,
|
target_cname: target_cname.inner,
|
||||||
client,
|
client: tokio::sync::Mutex::new(client),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,10 +84,16 @@ impl Checker for DNSChecker {
|
|||||||
|
|
||||||
// check that the CNAME is installed correctly on the domain
|
// check that the CNAME is installed correctly on the domain
|
||||||
{
|
{
|
||||||
let response = self
|
let response = match self.tokio_runtime.block_on(async {
|
||||||
.client
|
self.client
|
||||||
.query(domain, DNSClass::IN, RecordType::CNAME)
|
.lock()
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
.await
|
||||||
|
.query(domain.clone(), DNSClass::IN, RecordType::CNAME)
|
||||||
|
.await
|
||||||
|
}) {
|
||||||
|
Ok(res) => res,
|
||||||
|
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
|
||||||
|
};
|
||||||
|
|
||||||
let records = response.answers();
|
let records = response.answers();
|
||||||
|
|
||||||
@ -98,13 +113,19 @@ impl Checker for DNSChecker {
|
|||||||
{
|
{
|
||||||
let domain = Name::from_str("_domiply_challenge")
|
let domain = Name::from_str("_domiply_challenge")
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
|
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
|
||||||
.append_domain(domain)
|
.append_domain(&domain)
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||||
|
|
||||||
let response = self
|
let response = match self.tokio_runtime.block_on(async {
|
||||||
.client
|
self.client
|
||||||
.query(&domain, DNSClass::IN, RecordType::TXT)
|
.lock()
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
.await
|
||||||
|
.query(domain, DNSClass::IN, RecordType::TXT)
|
||||||
|
.await
|
||||||
|
}) {
|
||||||
|
Ok(res) => res,
|
||||||
|
Err(e) => return Err(CheckDomainError::Unexpected(Box::from(e))),
|
||||||
|
};
|
||||||
|
|
||||||
let records = response.answers();
|
let records = response.answers();
|
||||||
|
|
||||||
|
@ -199,7 +199,7 @@ where
|
|||||||
.hash()
|
.hash()
|
||||||
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?;
|
.map_err(|e| SyncWithConfigError::Unexpected(Box::from(e)))?;
|
||||||
|
|
||||||
self.domain_checker.check_domain(domain, &config_hash)?;
|
self.domain_checker.check_domain(&domain, &config_hash)?;
|
||||||
|
|
||||||
self.origin_store
|
self.origin_store
|
||||||
.sync(config.origin_descr.clone(), origin::store::Limits {})?;
|
.sync(config.origin_descr.clone(), origin::store::Limits {})?;
|
||||||
|
55
src/main.rs
55
src/main.rs
@ -31,37 +31,40 @@ struct Cli {
|
|||||||
domain_config_store_dir_path: path::PathBuf,
|
domain_config_store_dir_path: path::PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
fn main() {
|
||||||
async fn main() {
|
|
||||||
let config = Cli::parse();
|
let config = Cli::parse();
|
||||||
|
|
||||||
|
let tokio_runtime = std::sync::Arc::new(
|
||||||
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let (stop_ch_tx, stop_ch_rx) = tokio_runtime.block_on(async { oneshot::channel() });
|
||||||
|
|
||||||
// set up signal handling, stop_ch_rx will be used to signal that the stop signal has been
|
// set up signal handling, stop_ch_rx will be used to signal that the stop signal has been
|
||||||
// received
|
// received
|
||||||
let stop_ch_rx = {
|
tokio_runtime.spawn(async move {
|
||||||
let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT])
|
let mut signals = Signals::new(&[signal::SIGTERM, signal::SIGINT, signal::SIGQUIT])
|
||||||
.expect("initialized signals");
|
.expect("initialized signals");
|
||||||
|
|
||||||
let (stop_ch_tx, stop_ch_rx) = oneshot::channel();
|
if let Some(_) = signals.next().await {
|
||||||
|
println!("Gracefully shutting down...");
|
||||||
|
let _ = stop_ch_tx.send(());
|
||||||
|
}
|
||||||
|
|
||||||
tokio::spawn(async move {
|
if let Some(_) = signals.next().await {
|
||||||
if let Some(_) = signals.next().await {
|
println!("Forcefully shutting down");
|
||||||
println!("Gracefully shutting down...");
|
std::process::exit(1);
|
||||||
let _ = stop_ch_tx.send(());
|
};
|
||||||
}
|
});
|
||||||
|
|
||||||
if let Some(_) = signals.next().await {
|
|
||||||
println!("Forcefully shutting down");
|
|
||||||
std::process::exit(1);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
stop_ch_rx
|
|
||||||
};
|
|
||||||
|
|
||||||
let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path)
|
let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path)
|
||||||
.expect("git origin store initialized");
|
.expect("git origin store initialized");
|
||||||
|
|
||||||
let domain_checker = domiply::domain::checker::new(
|
let domain_checker = domiply::domain::checker::new(
|
||||||
|
tokio_runtime.clone(),
|
||||||
config.domain_checker_target_cname.clone(),
|
config.domain_checker_target_cname.clone(),
|
||||||
&config.domain_checker_resolver_addr,
|
&config.domain_checker_resolver_addr,
|
||||||
)
|
)
|
||||||
@ -79,15 +82,15 @@ async fn main() {
|
|||||||
)
|
)
|
||||||
.expect("service initialized");
|
.expect("service initialized");
|
||||||
|
|
||||||
let (addr, server) =
|
tokio_runtime.block_on(async {
|
||||||
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
|
let (addr, server) =
|
||||||
stop_ch_rx.await.ok();
|
warp::serve(service).bind_with_graceful_shutdown(config.http_listen_addr, async {
|
||||||
});
|
stop_ch_rx.await.ok();
|
||||||
|
});
|
||||||
|
|
||||||
println!("Listening on {addr}");
|
println!("Listening on {addr}");
|
||||||
tokio::task::spawn(server)
|
server.await;
|
||||||
.await
|
});
|
||||||
.expect("server shutdown gracefully");
|
|
||||||
|
|
||||||
println!("Graceful shutdown complete");
|
println!("Graceful shutdown complete");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user