|
|
|
@ -1,14 +1,15 @@ |
|
|
|
|
use std::net; |
|
|
|
|
use std::ops::DerefMut; |
|
|
|
|
use std::str::FromStr; |
|
|
|
|
|
|
|
|
|
use crate::domain; |
|
|
|
|
use crate::error::unexpected::{self, Mappable}; |
|
|
|
|
use crate::{domain, token}; |
|
|
|
|
|
|
|
|
|
use trust_dns_client::client::{AsyncClient, ClientHandle}; |
|
|
|
|
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; |
|
|
|
|
use trust_dns_client::udp; |
|
|
|
|
|
|
|
|
|
use rand::Rng; |
|
|
|
|
|
|
|
|
|
#[derive(thiserror::Error, Debug)] |
|
|
|
|
pub enum CheckDomainError { |
|
|
|
|
#[error("no service dns records set")] |
|
|
|
@ -27,123 +28,50 @@ pub enum DNSRecord { |
|
|
|
|
CNAME(domain::Name), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl DNSRecord { |
|
|
|
|
async fn check_a( |
|
|
|
|
client: &mut AsyncClient, |
|
|
|
|
domain: &trust_dns_client::rr::Name, |
|
|
|
|
addr: &net::Ipv4Addr, |
|
|
|
|
) -> Result<bool, unexpected::Error> { |
|
|
|
|
let response = client |
|
|
|
|
.query(domain.clone(), DNSClass::IN, RecordType::A) |
|
|
|
|
.await |
|
|
|
|
.or_unexpected_while("querying A record")?; |
|
|
|
|
|
|
|
|
|
let records = response.answers(); |
|
|
|
|
|
|
|
|
|
for record in records { |
|
|
|
|
if let Some(RData::A(record_addr)) = record.data() { |
|
|
|
|
if record_addr == addr { |
|
|
|
|
return Ok(true); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return Ok(false); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
async fn check_aaaa( |
|
|
|
|
client: &mut AsyncClient, |
|
|
|
|
domain: &trust_dns_client::rr::Name, |
|
|
|
|
addr: &net::Ipv6Addr, |
|
|
|
|
) -> Result<bool, unexpected::Error> { |
|
|
|
|
let response = client |
|
|
|
|
.query(domain.clone(), DNSClass::IN, RecordType::AAAA) |
|
|
|
|
.await |
|
|
|
|
.or_unexpected_while("querying AAAA record")?; |
|
|
|
|
|
|
|
|
|
let records = response.answers(); |
|
|
|
|
|
|
|
|
|
for record in records { |
|
|
|
|
if let Some(RData::AAAA(record_addr)) = record.data() { |
|
|
|
|
if record_addr == addr { |
|
|
|
|
return Ok(true); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return Ok(false); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
async fn check_cname( |
|
|
|
|
client: &mut AsyncClient, |
|
|
|
|
domain: &trust_dns_client::rr::Name, |
|
|
|
|
cname: &trust_dns_client::rr::Name, |
|
|
|
|
) -> Result<bool, unexpected::Error> { |
|
|
|
|
let response = client |
|
|
|
|
.query(domain.clone(), DNSClass::IN, RecordType::CNAME) |
|
|
|
|
.await |
|
|
|
|
.or_unexpected_while("querying CNAME record")?; |
|
|
|
|
|
|
|
|
|
let records = response.answers(); |
|
|
|
|
|
|
|
|
|
for record in records { |
|
|
|
|
if let Some(RData::CNAME(record_cname)) = record.data() { |
|
|
|
|
if record_cname == cname { |
|
|
|
|
return Ok(true); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return Ok(false); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
async fn check( |
|
|
|
|
&self, |
|
|
|
|
client: &mut AsyncClient, |
|
|
|
|
domain: &trust_dns_client::rr::Name, |
|
|
|
|
) -> Result<bool, unexpected::Error> { |
|
|
|
|
match self { |
|
|
|
|
Self::A(addr) => Self::check_a(client, domain, &addr).await, |
|
|
|
|
Self::AAAA(addr) => Self::check_aaaa(client, domain, &addr).await, |
|
|
|
|
Self::CNAME(name) => Self::check_cname(client, domain, name.as_rr()).await, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub struct DNSChecker { |
|
|
|
|
// TODO we should use some kind of connection pool here, I suppose
|
|
|
|
|
client: tokio::sync::Mutex<AsyncClient>, |
|
|
|
|
service_dns_records: Vec<DNSRecord>, |
|
|
|
|
token_store: Box<dyn token::Store + Send + Sync>, |
|
|
|
|
service_primary_domain: domain::Name, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
impl DNSChecker { |
|
|
|
|
pub async fn new( |
|
|
|
|
pub async fn new<TokenStore>( |
|
|
|
|
token_store: TokenStore, |
|
|
|
|
config: &domain::ConfigDNS, |
|
|
|
|
service_dns_records: Vec<DNSRecord>, |
|
|
|
|
) -> Result<Self, unexpected::Error> { |
|
|
|
|
service_primary_domain: domain::Name, |
|
|
|
|
) -> Result<Self, unexpected::Error> |
|
|
|
|
where |
|
|
|
|
TokenStore: token::Store + Send + Sync + 'static, |
|
|
|
|
{ |
|
|
|
|
let stream = udp::UdpClientStream::<tokio::net::UdpSocket>::new(config.resolver_addr); |
|
|
|
|
let (client, bg) = AsyncClient::connect(stream).await.or_unexpected()?; |
|
|
|
|
tokio::spawn(bg); |
|
|
|
|
// TODO there should be a mechanism to clean this up
|
|
|
|
|
|
|
|
|
|
Ok(Self { |
|
|
|
|
token_store: Box::from(token_store), |
|
|
|
|
client: tokio::sync::Mutex::new(client), |
|
|
|
|
service_dns_records, |
|
|
|
|
service_primary_domain, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub fn get_challenge_token(&self, domain: &domain::Name) -> unexpected::Result<Option<String>> { |
|
|
|
|
self.token_store.get(domain.as_str()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
pub async fn check_domain( |
|
|
|
|
&self, |
|
|
|
|
domain: &domain::Name, |
|
|
|
|
challenge_token: &str, |
|
|
|
|
) -> Result<(), CheckDomainError> { |
|
|
|
|
let domain = domain.as_rr(); |
|
|
|
|
let domain_rr = domain.as_rr(); |
|
|
|
|
|
|
|
|
|
// check that the TXT record with the challenge token is correctly installed on the domain
|
|
|
|
|
{ |
|
|
|
|
let domain = Name::from_str("_domani_challenge") |
|
|
|
|
.or_unexpected_while("parsing TXT name")? |
|
|
|
|
.append_domain(domain) |
|
|
|
|
.append_domain(domain_rr) |
|
|
|
|
.or_unexpected_while("appending domain to TXT")?; |
|
|
|
|
|
|
|
|
|
let response = self |
|
|
|
@ -166,16 +94,40 @@ impl DNSChecker { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// check that one of the possible DNS records is installed on the domain
|
|
|
|
|
for record in &self.service_dns_records { |
|
|
|
|
let mut client = self.client.lock().await; |
|
|
|
|
match record.check(client.deref_mut(), domain).await { |
|
|
|
|
Ok(true) => return Ok(()), |
|
|
|
|
Ok(false) => (), |
|
|
|
|
Err(e) => return Err(e.into()), |
|
|
|
|
} |
|
|
|
|
// check that DNS correctly resolves for the domain. This is done by serving an HTTP
|
|
|
|
|
// challenge on the domain, which we then query for here.
|
|
|
|
|
//
|
|
|
|
|
// first store the challenge token, so that the HTTP server can find it via
|
|
|
|
|
// get_challenge_token.
|
|
|
|
|
let token: String = rand::thread_rng() |
|
|
|
|
.sample_iter(rand::distributions::Alphanumeric) |
|
|
|
|
.take(16) |
|
|
|
|
.map(char::from) |
|
|
|
|
.collect(); |
|
|
|
|
|
|
|
|
|
self.token_store |
|
|
|
|
.set(domain.as_str().to_string(), token.clone()) |
|
|
|
|
.or_unexpected_while("storing challenge token")?; |
|
|
|
|
|
|
|
|
|
let body = match reqwest::get(format!( |
|
|
|
|
"http://{}/.well-known/domani-challenge", |
|
|
|
|
self.service_primary_domain.as_str() |
|
|
|
|
)) |
|
|
|
|
.await |
|
|
|
|
{ |
|
|
|
|
Err(_) => return Err(CheckDomainError::ServiceDNSRecordsNotSet), |
|
|
|
|
Ok(res) => res |
|
|
|
|
.error_for_status() |
|
|
|
|
.or(Err(CheckDomainError::ServiceDNSRecordsNotSet))? |
|
|
|
|
.text() |
|
|
|
|
.await |
|
|
|
|
.or(Err(CheckDomainError::ServiceDNSRecordsNotSet))?, |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
if body != token { |
|
|
|
|
return Err(CheckDomainError::ServiceDNSRecordsNotSet); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
Err(CheckDomainError::ServiceDNSRecordsNotSet) |
|
|
|
|
Ok(()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|