use std::net; use std::str::FromStr; use crate::domain; use crate::error::unexpected::{self, Mappable}; use trust_dns_client::client::{AsyncClient, ClientHandle}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; use trust_dns_client::udp; #[derive(thiserror::Error, Debug)] pub enum NewDNSCheckerError { #[error("invalid resolver address")] InvalidResolverAddress, #[error(transparent)] Unexpected(#[from] unexpected::Error), } #[derive(thiserror::Error, Debug)] pub enum CheckDomainError { #[error("target A not set")] TargetANotSet, #[error("challenge token not set")] ChallengeTokenNotSet, #[error(transparent)] Unexpected(#[from] unexpected::Error), } pub struct DNSChecker { target_a: net::Ipv4Addr, // TODO we should use some kind of connection pool here, I suppose client: tokio::sync::Mutex, } pub async fn new( target_a: net::Ipv4Addr, resolver_addr: &str, ) -> Result { let resolver_addr = resolver_addr .parse() .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; let stream = udp::UdpClientStream::::new(resolver_addr); let (client, bg) = AsyncClient::connect(stream).await.or_unexpected()?; tokio::spawn(bg); Ok(DNSChecker { target_a, client: tokio::sync::Mutex::new(client), }) } impl DNSChecker { pub async fn check_domain( &self, domain: &domain::Name, challenge_token: &str, ) -> Result<(), CheckDomainError> { let domain = &domain.inner; // check that the A is installed correctly on the domain { let response = self .client .lock() .await .query(domain.clone(), DNSClass::IN, RecordType::A) .await .or_unexpected_while("querying A record")?; let records = response.answers(); if records.len() != 1 { return Err(CheckDomainError::TargetANotSet); } // if the single record isn't a A, or it's not the target A, then return // TargetANAMENotSet match records[0].data() { Some(RData::A(remote_a)) if remote_a == &self.target_a => (), _ => return Err(CheckDomainError::TargetANotSet), } } // check that the TXT record with the challenge token is correctly installed on the domain { let domain = Name::from_str("_domani_challenge") .or_unexpected_while("parsing TXT name")? .append_domain(domain) .or_unexpected_while("appending domain to TXT")?; let response = self .client .lock() .await .query(domain, DNSClass::IN, RecordType::TXT) .await .or_unexpected_while("querying TXT record")?; let records = response.answers(); if !records.iter().any(|record| -> bool { match record.data() { Some(RData::TXT(txt)) => txt.to_string().contains(challenge_token), _ => false, } }) { return Err(CheckDomainError::ChallengeTokenNotSet); } } Ok(()) } }