diff --git a/src/domain.rs b/src/domain.rs index e19a35f..15848e4 100644 --- a/src/domain.rs +++ b/src/domain.rs @@ -1,3 +1,72 @@ pub mod checker; pub mod config; pub mod manager; + +use std::str::FromStr; + +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use trust_dns_client::rr as trust_dns_rr; + +#[derive(Debug, Clone)] +/// Validated representation of a domain name +pub struct Name { + inner: trust_dns_rr::Name, + utf8_str: String, +} + +impl Name { + fn as_str(&self) -> &str { + self.utf8_str.as_str() + } +} + +impl FromStr for Name { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + let mut n = trust_dns_rr::Name::from_str(s)?; + let utf8_str = n.clone().to_utf8(); + + n.set_fqdn(true); + + Ok(Name { inner: n, utf8_str }) + } +} + +impl Serialize for Name { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +struct NameVisitor; + +impl<'de> de::Visitor<'de> for NameVisitor { + type Value = Name; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a valid domain name") + } + + fn visit_str(self, s: &str) -> Result + where + E: de::Error, + { + match Name::from_str(s) { + Ok(n) => Ok(n), + Err(e) => Err(E::custom(format!("invalid domain name: {}", e))), + } + } +} + +impl<'de> Deserialize<'de> for Name { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(NameVisitor) + } +} diff --git a/src/domain/checker.rs b/src/domain/checker.rs index 05518d2..c0f0014 100644 --- a/src/domain/checker.rs +++ b/src/domain/checker.rs @@ -1,6 +1,8 @@ use std::error::Error; use std::str::FromStr; +use crate::domain; + use mockall::automock; use trust_dns_client::client::{Client, SyncClient}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; @@ -35,7 +37,11 @@ pub enum CheckDomainError { #[automock] pub trait Checker { - fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError>; + fn check_domain( + &self, + domain: &domain::Name, + challenge_token: &str, + ) -> Result<(), CheckDomainError>; } pub struct DNSChecker { @@ -43,10 +49,10 @@ pub struct DNSChecker { client: SyncClient, } -pub fn new(target_cname: &str, resolver_addr: &str) -> Result { - let target_cname = - Name::from_str(target_cname).map_err(|_| NewDNSCheckerError::InvalidTargetCNAME)?; - +pub fn new( + target_cname: domain::Name, + resolver_addr: &str, +) -> Result { let resolver_addr = resolver_addr .parse() .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; @@ -57,21 +63,24 @@ pub fn new(target_cname: &str, resolver_addr: &str) -> Result Result<(), CheckDomainError> { - let mut fqdn = Name::from_str(domain).map_err(|_| CheckDomainError::InvalidDomainName)?; - fqdn.set_fqdn(true); + fn check_domain( + &self, + domain: &domain::Name, + challenge_token: &str, + ) -> Result<(), CheckDomainError> { + let domain = &domain.inner; // check that the CNAME is installed correctly on the domain { let response = self .client - .query(&fqdn, DNSClass::IN, RecordType::CNAME) + .query(domain, DNSClass::IN, RecordType::CNAME) .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; let records = response.answers(); @@ -90,14 +99,14 @@ impl Checker for DNSChecker { // check that the TXT record with the challenge token is correctly installed on the domain { - let fqdn = Name::from_str("_gateway") + let domain = Name::from_str("_gateway") .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))? - .append_domain(&fqdn) + .append_domain(domain) .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; let response = self .client - .query(&fqdn, DNSClass::IN, RecordType::TXT) + .query(&domain, DNSClass::IN, RecordType::TXT) .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; let records = response.answers(); diff --git a/src/domain/config.rs b/src/domain/config.rs index a66106c..e89fe5b 100644 --- a/src/domain/config.rs +++ b/src/domain/config.rs @@ -2,6 +2,7 @@ use std::error::Error; use std::path::{Path, PathBuf}; use std::{fs, io}; +use crate::domain; use crate::origin::Descr; use hex::ToHex; @@ -39,8 +40,8 @@ pub enum SetError { #[mockall::automock] pub trait Store { - fn get(&self, domain: &str) -> Result; - fn set(&self, domain: &str, config: &Config) -> Result<(), SetError>; + fn get(&self, domain: &domain::Name) -> Result; + fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>; } struct FSStore { @@ -55,17 +56,17 @@ pub fn new(dir_path: &Path) -> io::Result { } impl FSStore { - fn config_dir_path(&self, domain: &str) -> PathBuf { - self.dir_path.join(domain) + fn config_dir_path(&self, domain: &domain::Name) -> PathBuf { + self.dir_path.join(domain.as_str()) } - fn config_file_path(&self, domain: &str) -> PathBuf { + fn config_file_path(&self, domain: &domain::Name) -> PathBuf { self.config_dir_path(domain).join("config.json") } } impl Store for FSStore { - fn get(&self, domain: &str) -> Result { + fn get(&self, domain: &domain::Name) -> Result { let config_file = fs::File::open(self.config_file_path(domain)).map_err(|e| match e.kind() { io::ErrorKind::NotFound => GetError::NotFound, @@ -75,7 +76,7 @@ impl Store for FSStore { Ok(serde_json::from_reader(config_file).map_err(|e| GetError::Unexpected(Box::from(e)))?) } - fn set(&self, domain: &str, config: &Config) -> Result<(), SetError> { + fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError> { fs::create_dir_all(self.config_dir_path(domain)) .map_err(|e| SetError::Unexpected(Box::from(e)))?; @@ -92,7 +93,11 @@ impl Store for FSStore { #[cfg(test)] mod tests { use super::*; + use crate::domain; use crate::origin::Descr; + + use std::str::FromStr; + use tempdir::TempDir; #[test] @@ -101,7 +106,7 @@ mod tests { let store = new(tmp_dir.path()).expect("store created"); - let domain = "foo"; + let domain = domain::Name::from_str("foo.com").expect("domain parsed"); let config = Config { origin_descr: Descr::Git { @@ -111,12 +116,12 @@ mod tests { }; assert!(matches!( - store.get(domain), + store.get(&domain), Err::(GetError::NotFound) )); - store.set(domain, &config).expect("config set"); - assert_eq!(config, store.get(domain).expect("config retrieved")); + store.set(&domain, &config).expect("config set"); + assert_eq!(config, store.get(&domain).expect("config retrieved")); let new_config = Config { origin_descr: Descr::Git { @@ -125,7 +130,7 @@ mod tests { }, }; - store.set(domain, &new_config).expect("config set"); - assert_eq!(new_config, store.get(domain).expect("config retrieved")); + store.set(&domain, &new_config).expect("config set"); + assert_eq!(new_config, store.get(&domain).expect("config retrieved")); } } diff --git a/src/domain/manager.rs b/src/domain/manager.rs index 2e297f6..3690b97 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -1,4 +1,4 @@ -use crate::domain::{checker, config}; +use crate::domain::{self, checker, config}; use crate::origin; use std::error::Error; @@ -121,12 +121,12 @@ pub trait Manager { where Self: 'mgr; - fn get_config(&self, domain: &str) -> Result; - fn get_origin(&self, domain: &str) -> Result, GetOriginError>; - fn sync(&self, domain: &str) -> Result<(), SyncError>; + fn get_config(&self, domain: &domain::Name) -> Result; + fn get_origin(&self, domain: &domain::Name) -> Result, GetOriginError>; + fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>; fn sync_with_config( &self, - domain: &str, + domain: &domain::Name, config: &config::Config, ) -> Result<(), SyncWithConfigError>; } @@ -169,11 +169,11 @@ where type Origin<'mgr> = OriginStore::Origin<'mgr> where Self: 'mgr; - fn get_config(&self, domain: &str) -> Result { + fn get_config(&self, domain: &domain::Name) -> Result { Ok(self.domain_config_store.get(domain)?) } - fn get_origin(&self, domain: &str) -> Result, GetOriginError> { + fn get_origin(&self, domain: &domain::Name) -> Result, GetOriginError> { let config = self.domain_config_store.get(domain)?; let origin = self .origin_store @@ -183,7 +183,7 @@ where Ok(origin) } - fn sync(&self, domain: &str) -> Result<(), SyncError> { + fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> { let config = self.domain_config_store.get(domain)?; self.origin_store .sync(config.origin_descr, origin::store::Limits {}) @@ -196,7 +196,7 @@ where fn sync_with_config( &self, - domain: &str, + domain: &domain::Name, config: &config::Config, ) -> Result<(), SyncWithConfigError> { let config_hash = config diff --git a/src/main.rs b/src/main.rs index 2b74e13..8bdec32 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,16 +16,16 @@ struct Cli { http_listen_addr: SocketAddr, #[arg(long, required = true, env = "GATEWAY_ORIGIN_STORE_GIT_DIR_PATH")] - origin_store_git_dir_path: Option, + origin_store_git_dir_path: path::PathBuf, #[arg(long, required = true, env = "GATEWAY_DOMAIN_CHECKER_TARGET_CNAME")] - domain_checker_target_cname: Option, + domain_checker_target_cname: gateway::domain::Name, #[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "GATEWAY_DOMAIN_CHECKER_RESOLVER_ADDR")] domain_checker_resolver_addr: String, #[arg(long, required = true, env = "GATEWAY_DOMAIN_CONFIG_STORE_DIR_PATH")] - domain_config_store_dir_path: Option, + domain_config_store_dir_path: path::PathBuf, } #[tokio::main] @@ -55,18 +55,17 @@ async fn main() { stop_ch_rx }; - let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path.unwrap()) + let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path) .expect("git origin store initialized"); let domain_checker = gateway::domain::checker::new( - &config.domain_checker_target_cname.unwrap(), + config.domain_checker_target_cname, &config.domain_checker_resolver_addr, ) .expect("domain checker initialized"); - let domain_config_store = - gateway::domain::config::new(&config.domain_config_store_dir_path.unwrap()) - .expect("domain config store initialized"); + let domain_config_store = gateway::domain::config::new(&config.domain_config_store_dir_path) + .expect("domain config store initialized"); let manager = gateway::domain::manager::new(origin_store, domain_config_store, domain_checker);