use a validated domain name type, rather than a raw string

This commit is contained in:
Brian Picciano 2023-05-12 15:19:24 +02:00
parent cf3b11862c
commit 7718735215
5 changed files with 125 additions and 43 deletions

View File

@ -1,3 +1,72 @@
pub mod checker; pub mod checker;
pub mod config; pub mod config;
pub mod manager; pub mod manager;
use std::str::FromStr;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use trust_dns_client::rr as trust_dns_rr;
#[derive(Debug, Clone)]
/// Validated representation of a domain name
pub struct Name {
inner: trust_dns_rr::Name,
utf8_str: String,
}
impl Name {
fn as_str(&self) -> &str {
self.utf8_str.as_str()
}
}
impl FromStr for Name {
type Err = <trust_dns_rr::Name as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut n = trust_dns_rr::Name::from_str(s)?;
let utf8_str = n.clone().to_utf8();
n.set_fqdn(true);
Ok(Name { inner: n, utf8_str })
}
}
impl Serialize for Name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.as_str())
}
}
struct NameVisitor;
impl<'de> de::Visitor<'de> for NameVisitor {
type Value = Name;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a valid domain name")
}
fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
match Name::from_str(s) {
Ok(n) => Ok(n),
Err(e) => Err(E::custom(format!("invalid domain name: {}", e))),
}
}
}
impl<'de> Deserialize<'de> for Name {
fn deserialize<D>(deserializer: D) -> Result<Name, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(NameVisitor)
}
}

View File

@ -1,6 +1,8 @@
use std::error::Error; use std::error::Error;
use std::str::FromStr; use std::str::FromStr;
use crate::domain;
use mockall::automock; use mockall::automock;
use trust_dns_client::client::{Client, SyncClient}; use trust_dns_client::client::{Client, SyncClient};
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType}; use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
@ -35,7 +37,11 @@ pub enum CheckDomainError {
#[automock] #[automock]
pub trait Checker { pub trait Checker {
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError>; fn check_domain(
&self,
domain: &domain::Name,
challenge_token: &str,
) -> Result<(), CheckDomainError>;
} }
pub struct DNSChecker { pub struct DNSChecker {
@ -43,10 +49,10 @@ pub struct DNSChecker {
client: SyncClient<UdpClientConnection>, client: SyncClient<UdpClientConnection>,
} }
pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewDNSCheckerError> { pub fn new(
let target_cname = target_cname: domain::Name,
Name::from_str(target_cname).map_err(|_| NewDNSCheckerError::InvalidTargetCNAME)?; resolver_addr: &str,
) -> Result<impl Checker, NewDNSCheckerError> {
let resolver_addr = resolver_addr let resolver_addr = resolver_addr
.parse() .parse()
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?; .map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
@ -57,21 +63,24 @@ pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewD
let client = SyncClient::new(conn); let client = SyncClient::new(conn);
Ok(DNSChecker { Ok(DNSChecker {
target_cname, target_cname: target_cname.inner,
client, client,
}) })
} }
impl Checker for DNSChecker { impl Checker for DNSChecker {
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError> { fn check_domain(
let mut fqdn = Name::from_str(domain).map_err(|_| CheckDomainError::InvalidDomainName)?; &self,
fqdn.set_fqdn(true); domain: &domain::Name,
challenge_token: &str,
) -> Result<(), CheckDomainError> {
let domain = &domain.inner;
// check that the CNAME is installed correctly on the domain // check that the CNAME is installed correctly on the domain
{ {
let response = self let response = self
.client .client
.query(&fqdn, DNSClass::IN, RecordType::CNAME) .query(domain, DNSClass::IN, RecordType::CNAME)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let records = response.answers(); let records = response.answers();
@ -90,14 +99,14 @@ impl Checker for DNSChecker {
// check that the TXT record with the challenge token is correctly installed on the domain // check that the TXT record with the challenge token is correctly installed on the domain
{ {
let fqdn = Name::from_str("_gateway") let domain = Name::from_str("_gateway")
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))? .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
.append_domain(&fqdn) .append_domain(domain)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = self let response = self
.client .client
.query(&fqdn, DNSClass::IN, RecordType::TXT) .query(&domain, DNSClass::IN, RecordType::TXT)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?; .map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let records = response.answers(); let records = response.answers();

View File

@ -2,6 +2,7 @@ use std::error::Error;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::{fs, io}; use std::{fs, io};
use crate::domain;
use crate::origin::Descr; use crate::origin::Descr;
use hex::ToHex; use hex::ToHex;
@ -39,8 +40,8 @@ pub enum SetError {
#[mockall::automock] #[mockall::automock]
pub trait Store { pub trait Store {
fn get(&self, domain: &str) -> Result<Config, GetError>; fn get(&self, domain: &domain::Name) -> Result<Config, GetError>;
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError>; fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>;
} }
struct FSStore { struct FSStore {
@ -55,17 +56,17 @@ pub fn new(dir_path: &Path) -> io::Result<impl Store> {
} }
impl FSStore { impl FSStore {
fn config_dir_path(&self, domain: &str) -> PathBuf { fn config_dir_path(&self, domain: &domain::Name) -> PathBuf {
self.dir_path.join(domain) self.dir_path.join(domain.as_str())
} }
fn config_file_path(&self, domain: &str) -> PathBuf { fn config_file_path(&self, domain: &domain::Name) -> PathBuf {
self.config_dir_path(domain).join("config.json") self.config_dir_path(domain).join("config.json")
} }
} }
impl Store for FSStore { impl Store for FSStore {
fn get(&self, domain: &str) -> Result<Config, GetError> { fn get(&self, domain: &domain::Name) -> Result<Config, GetError> {
let config_file = let config_file =
fs::File::open(self.config_file_path(domain)).map_err(|e| match e.kind() { fs::File::open(self.config_file_path(domain)).map_err(|e| match e.kind() {
io::ErrorKind::NotFound => GetError::NotFound, io::ErrorKind::NotFound => GetError::NotFound,
@ -75,7 +76,7 @@ impl Store for FSStore {
Ok(serde_json::from_reader(config_file).map_err(|e| GetError::Unexpected(Box::from(e)))?) Ok(serde_json::from_reader(config_file).map_err(|e| GetError::Unexpected(Box::from(e)))?)
} }
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError> { fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError> {
fs::create_dir_all(self.config_dir_path(domain)) fs::create_dir_all(self.config_dir_path(domain))
.map_err(|e| SetError::Unexpected(Box::from(e)))?; .map_err(|e| SetError::Unexpected(Box::from(e)))?;
@ -92,7 +93,11 @@ impl Store for FSStore {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::domain;
use crate::origin::Descr; use crate::origin::Descr;
use std::str::FromStr;
use tempdir::TempDir; use tempdir::TempDir;
#[test] #[test]
@ -101,7 +106,7 @@ mod tests {
let store = new(tmp_dir.path()).expect("store created"); let store = new(tmp_dir.path()).expect("store created");
let domain = "foo"; let domain = domain::Name::from_str("foo.com").expect("domain parsed");
let config = Config { let config = Config {
origin_descr: Descr::Git { origin_descr: Descr::Git {
@ -111,12 +116,12 @@ mod tests {
}; };
assert!(matches!( assert!(matches!(
store.get(domain), store.get(&domain),
Err::<Config, GetError>(GetError::NotFound) Err::<Config, GetError>(GetError::NotFound)
)); ));
store.set(domain, &config).expect("config set"); store.set(&domain, &config).expect("config set");
assert_eq!(config, store.get(domain).expect("config retrieved")); assert_eq!(config, store.get(&domain).expect("config retrieved"));
let new_config = Config { let new_config = Config {
origin_descr: Descr::Git { origin_descr: Descr::Git {
@ -125,7 +130,7 @@ mod tests {
}, },
}; };
store.set(domain, &new_config).expect("config set"); store.set(&domain, &new_config).expect("config set");
assert_eq!(new_config, store.get(domain).expect("config retrieved")); assert_eq!(new_config, store.get(&domain).expect("config retrieved"));
} }
} }

View File

@ -1,4 +1,4 @@
use crate::domain::{checker, config}; use crate::domain::{self, checker, config};
use crate::origin; use crate::origin;
use std::error::Error; use std::error::Error;
@ -121,12 +121,12 @@ pub trait Manager {
where where
Self: 'mgr; Self: 'mgr;
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError>; fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>;
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError>; fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError>;
fn sync(&self, domain: &str) -> Result<(), SyncError>; fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>;
fn sync_with_config( fn sync_with_config(
&self, &self,
domain: &str, domain: &domain::Name,
config: &config::Config, config: &config::Config,
) -> Result<(), SyncWithConfigError>; ) -> Result<(), SyncWithConfigError>;
} }
@ -169,11 +169,11 @@ where
type Origin<'mgr> = OriginStore::Origin<'mgr> type Origin<'mgr> = OriginStore::Origin<'mgr>
where Self: 'mgr; where Self: 'mgr;
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError> { fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> {
Ok(self.domain_config_store.get(domain)?) Ok(self.domain_config_store.get(domain)?)
} }
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError> { fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError> {
let config = self.domain_config_store.get(domain)?; let config = self.domain_config_store.get(domain)?;
let origin = self let origin = self
.origin_store .origin_store
@ -183,7 +183,7 @@ where
Ok(origin) Ok(origin)
} }
fn sync(&self, domain: &str) -> Result<(), SyncError> { fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> {
let config = self.domain_config_store.get(domain)?; let config = self.domain_config_store.get(domain)?;
self.origin_store self.origin_store
.sync(config.origin_descr, origin::store::Limits {}) .sync(config.origin_descr, origin::store::Limits {})
@ -196,7 +196,7 @@ where
fn sync_with_config( fn sync_with_config(
&self, &self,
domain: &str, domain: &domain::Name,
config: &config::Config, config: &config::Config,
) -> Result<(), SyncWithConfigError> { ) -> Result<(), SyncWithConfigError> {
let config_hash = config let config_hash = config

View File

@ -16,16 +16,16 @@ struct Cli {
http_listen_addr: SocketAddr, http_listen_addr: SocketAddr,
#[arg(long, required = true, env = "GATEWAY_ORIGIN_STORE_GIT_DIR_PATH")] #[arg(long, required = true, env = "GATEWAY_ORIGIN_STORE_GIT_DIR_PATH")]
origin_store_git_dir_path: Option<path::PathBuf>, origin_store_git_dir_path: path::PathBuf,
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CHECKER_TARGET_CNAME")] #[arg(long, required = true, env = "GATEWAY_DOMAIN_CHECKER_TARGET_CNAME")]
domain_checker_target_cname: Option<String>, domain_checker_target_cname: gateway::domain::Name,
#[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "GATEWAY_DOMAIN_CHECKER_RESOLVER_ADDR")] #[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "GATEWAY_DOMAIN_CHECKER_RESOLVER_ADDR")]
domain_checker_resolver_addr: String, domain_checker_resolver_addr: String,
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CONFIG_STORE_DIR_PATH")] #[arg(long, required = true, env = "GATEWAY_DOMAIN_CONFIG_STORE_DIR_PATH")]
domain_config_store_dir_path: Option<path::PathBuf>, domain_config_store_dir_path: path::PathBuf,
} }
#[tokio::main] #[tokio::main]
@ -55,18 +55,17 @@ async fn main() {
stop_ch_rx stop_ch_rx
}; };
let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path.unwrap()) let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path)
.expect("git origin store initialized"); .expect("git origin store initialized");
let domain_checker = gateway::domain::checker::new( let domain_checker = gateway::domain::checker::new(
&config.domain_checker_target_cname.unwrap(), config.domain_checker_target_cname,
&config.domain_checker_resolver_addr, &config.domain_checker_resolver_addr,
) )
.expect("domain checker initialized"); .expect("domain checker initialized");
let domain_config_store = let domain_config_store = gateway::domain::config::new(&config.domain_config_store_dir_path)
gateway::domain::config::new(&config.domain_config_store_dir_path.unwrap()) .expect("domain config store initialized");
.expect("domain config store initialized");
let manager = gateway::domain::manager::new(origin_store, domain_config_store, domain_checker); let manager = gateway::domain::manager::new(origin_store, domain_config_store, domain_checker);