diff --git a/.dev-config.yml b/.dev-config.yml index 280ee7b..407a94a 100644 --- a/.dev-config.yml +++ b/.dev-config.yml @@ -2,9 +2,8 @@ origin: store_dir_path: /tmp/domani_dev_env/origin domain: store_dir_path: /tmp/domani_dev_env/domain - dns: - target_records: - - type: A - addr: 127.0.0.1 service: passphrase: foobar + dns_records: + - type: A + addr: 127.0.0.1 diff --git a/src/domain/checker.rs b/src/domain/checker.rs index 17acc2e..a2ab465 100644 --- a/src/domain/checker.rs +++ b/src/domain/checker.rs @@ -1,4 +1,5 @@ use std::net; +use std::ops::DerefMut; use std::str::FromStr; use crate::domain; @@ -10,8 +11,8 @@ use trust_dns_client::udp; #[derive(thiserror::Error, Debug)] pub enum CheckDomainError { - #[error("target A not set")] - TargetANotSet, + #[error("no service dns records set")] + ServiceDNSRecordsNotSet, #[error("challenge token not set")] ChallengeTokenNotSet, @@ -20,31 +21,65 @@ pub enum CheckDomainError { Unexpected(#[from] unexpected::Error), } -pub struct DNSChecker { - target_a: net::Ipv4Addr, +pub enum DNSRecord { + A(net::Ipv4Addr), +} +impl DNSRecord { + async fn check_a( + client: &mut AsyncClient, + domain: &trust_dns_client::rr::Name, + addr: &net::Ipv4Addr, + ) -> Result { + let response = client + .query(domain.clone(), DNSClass::IN, RecordType::A) + .await + .or_unexpected_while("querying A record")?; + + let records = response.answers(); + + if records.len() != 1 { + return Ok(false); + } + + // if the single record isn't a A, or it's not the target A, then return + // TargetANAMENotSet + match records[0].data() { + Some(RData::A(remote_a)) if remote_a == addr => Ok(true), + _ => return Ok(false), + } + } + + async fn check( + &self, + client: &mut AsyncClient, + domain: &trust_dns_client::rr::Name, + ) -> Result { + match self { + Self::A(addr) => Self::check_a(client, domain, &addr).await, + } + } +} + +pub struct DNSChecker { // TODO we should use some kind of connection pool here, I suppose client: tokio::sync::Mutex, + service_dns_records: Vec, } impl DNSChecker { - pub async fn new(config: &domain::ConfigDNS) -> Result { - let target_a = match config - .target_records - .get(0) - .expect("at least one target record expected") - { - domain::ConfigDNSTargetRecord::A { addr } => addr.clone(), - }; - + pub async fn new( + config: &domain::ConfigDNS, + service_dns_records: Vec, + ) -> Result { let stream = udp::UdpClientStream::::new(config.resolver_addr); let (client, bg) = AsyncClient::connect(stream).await.or_unexpected()?; tokio::spawn(bg); // TODO there should be a mechanism to clean this up Ok(Self { - target_a, client: tokio::sync::Mutex::new(client), + service_dns_records, }) } @@ -55,30 +90,6 @@ impl DNSChecker { ) -> Result<(), CheckDomainError> { let domain = domain.as_rr(); - // check that the A is installed correctly on the domain - { - let response = self - .client - .lock() - .await - .query(domain.clone(), DNSClass::IN, RecordType::A) - .await - .or_unexpected_while("querying A record")?; - - let records = response.answers(); - - if records.len() != 1 { - return Err(CheckDomainError::TargetANotSet); - } - - // if the single record isn't a A, or it's not the target A, then return - // TargetANAMENotSet - match records[0].data() { - Some(RData::A(remote_a)) if remote_a == &self.target_a => (), - _ => return Err(CheckDomainError::TargetANotSet), - } - } - // check that the TXT record with the challenge token is correctly installed on the domain { let domain = Name::from_str("_domani_challenge") @@ -106,6 +117,16 @@ impl DNSChecker { } } - Ok(()) + // check that one of the possible DNS records is installed on the domain + for record in &self.service_dns_records { + let mut client = self.client.lock().await; + match record.check(client.deref_mut(), domain).await { + Ok(true) => return Ok(()), + Ok(false) => (), + Err(e) => return Err(e.into()), + } + } + + Err(CheckDomainError::ServiceDNSRecordsNotSet) } } diff --git a/src/domain/config.rs b/src/domain/config.rs index a89afa5..65855f9 100644 --- a/src/domain/config.rs +++ b/src/domain/config.rs @@ -2,12 +2,6 @@ use std::{net, path, str::FromStr}; use serde::Deserialize; -#[derive(Deserialize)] -#[serde(tag = "type")] -pub enum ConfigDNSTargetRecord { - A { addr: net::Ipv4Addr }, -} - fn default_resolver_addr() -> net::SocketAddr { net::SocketAddr::from_str("1.1.1.1:53").unwrap() } @@ -16,7 +10,14 @@ fn default_resolver_addr() -> net::SocketAddr { pub struct ConfigDNS { #[serde(default = "default_resolver_addr")] pub resolver_addr: net::SocketAddr, - pub target_records: Vec, +} + +impl Default for ConfigDNS { + fn default() -> Self { + Self { + resolver_addr: default_resolver_addr(), + } + } } #[derive(Deserialize)] @@ -27,6 +28,7 @@ pub struct ConfigACME { #[derive(Deserialize)] pub struct Config { pub store_dir_path: path::PathBuf, + #[serde(default)] pub dns: ConfigDNS, pub acme: Option, } diff --git a/src/domain/manager.rs b/src/domain/manager.rs index 937cd4c..16f369d 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -89,8 +89,8 @@ pub enum SyncWithConfigError { #[error("already in progress")] AlreadyInProgress, - #[error("target A/AAAA not set")] - TargetANotSet, + #[error("no service dns records set")] + ServiceDNSRecordsNotSet, #[error("challenge token not set")] ChallengeTokenNotSet, @@ -113,7 +113,9 @@ impl From for SyncWithConfigError { impl From for SyncWithConfigError { fn from(e: checker::CheckDomainError) -> SyncWithConfigError { match e { - checker::CheckDomainError::TargetANotSet => SyncWithConfigError::TargetANotSet, + checker::CheckDomainError::ServiceDNSRecordsNotSet => { + SyncWithConfigError::ServiceDNSRecordsNotSet + } checker::CheckDomainError::ChallengeTokenNotSet => { SyncWithConfigError::ChallengeTokenNotSet } diff --git a/src/main.rs b/src/main.rs index bf2690b..a056441 100644 --- a/src/main.rs +++ b/src/main.rs @@ -53,9 +53,15 @@ async fn main() { let origin_store = domani::origin::git::FSStore::new(&config.origin) .expect("git origin store initialization failed"); - let domain_checker = domani::domain::checker::DNSChecker::new(&config.domain.dns) + let domain_checker = { + let dns_records = config.service.dns_records.clone(); + domani::domain::checker::DNSChecker::new( + &config.domain.dns, + dns_records.into_iter().map(|r| r.into()).collect(), + ) .await - .expect("domain checker initialization failed"); + .expect("domain checker initialization failed") + }; let domain_config_store = domani::domain::store::FSStore::new(&config.domain.store_dir_path.join("domains")) @@ -95,7 +101,6 @@ async fn main() { domain_manager.clone(), domain_manager.clone(), config.service, - config.domain.dns.target_records, ); let mut signals = diff --git a/src/service.rs b/src/service.rs index dc06589..8075300 100644 --- a/src/service.rs +++ b/src/service.rs @@ -3,17 +3,32 @@ mod util; use crate::domain; use serde::Deserialize; -use std::str::FromStr; +use std::{net, str::FromStr}; fn default_primary_domain() -> domain::Name { domain::Name::from_str("localhost").unwrap() } +#[derive(Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ConfigDNSRecord { + A { addr: net::Ipv4Addr }, +} + +impl From for domain::checker::DNSRecord { + fn from(r: ConfigDNSRecord) -> Self { + match r { + ConfigDNSRecord::A { addr } => Self::A(addr), + } + } +} + #[derive(Deserialize)] pub struct Config { #[serde(default = "default_primary_domain")] pub primary_domain: domain::Name, pub passphrase: String, + pub dns_records: Vec, #[serde(default)] pub http: self::http::Config, } diff --git a/src/service/http.rs b/src/service/http.rs index 0a545eb..26607b7 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -18,7 +18,6 @@ pub struct Service { cert_resolver: sync::Arc, handlebars: handlebars::Handlebars<'static>, config: service::Config, - dns_target_records: Vec, } pub fn new( @@ -26,15 +25,14 @@ pub fn new( domain_manager: sync::Arc, cert_resolver: sync::Arc, config: service::Config, - dns_target_records: Vec, ) -> sync::Arc { let https_enabled = config.http.https_addr.is_some(); + let service = sync::Arc::new(Service { domain_manager: domain_manager.clone(), cert_resolver, handlebars: tpl::get(), config, - dns_target_records, }); task_stack.push_spawn(|canceller| tasks::listen_http(service.clone(), canceller)); @@ -252,11 +250,12 @@ impl<'svc> Service { }; let target_a = match self - .dns_target_records + .config + .dns_records .get(0) .expect("at least one target record expected") { - domain::ConfigDNSTargetRecord::A { addr } => addr.clone(), + service::ConfigDNSRecord::A { addr } => addr.clone(), }; self.render_page( @@ -303,7 +302,7 @@ impl<'svc> Service { Err(domain::manager::SyncWithConfigError::InvalidURL) => Some("Fetching the git repository failed, please double check that you input the correct URL.".to_string()), Err(domain::manager::SyncWithConfigError::InvalidBranchName) => Some("The git repository does not have a branch of the given name, please double check that you input the correct name.".to_string()), Err(domain::manager::SyncWithConfigError::AlreadyInProgress) => Some("The configuration of your domain is still in progress, please refresh in a few minutes.".to_string()), - Err(domain::manager::SyncWithConfigError::TargetANotSet) => Some("The A record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()), + Err(domain::manager::SyncWithConfigError::ServiceDNSRecordsNotSet) => Some("None of the expected service DNS records were set on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()), Err(domain::manager::SyncWithConfigError::ChallengeTokenNotSet) => Some("The TXT record is not set correctly on the domain. Please double check that you put the correct value on the record. If the value is correct, then most likely the updated records have not yet propagated. In this case you can refresh in a few minutes to try again.".to_string()), Err(domain::manager::SyncWithConfigError::Unexpected(e)) => Some(format!("An unexpected error occurred: {e}")), };