use a validated domain name type, rather than a raw string

This commit is contained in:
Brian Picciano 2023-05-12 15:19:24 +02:00
parent cf3b11862c
commit 7718735215
5 changed files with 125 additions and 43 deletions

View File

@ -1,3 +1,72 @@
pub mod checker;
pub mod config;
pub mod manager;
use std::str::FromStr;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use trust_dns_client::rr as trust_dns_rr;
#[derive(Debug, Clone)]
/// Validated representation of a domain name
pub struct Name {
inner: trust_dns_rr::Name,
utf8_str: String,
}
impl Name {
fn as_str(&self) -> &str {
self.utf8_str.as_str()
}
}
impl FromStr for Name {
type Err = <trust_dns_rr::Name as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut n = trust_dns_rr::Name::from_str(s)?;
let utf8_str = n.clone().to_utf8();
n.set_fqdn(true);
Ok(Name { inner: n, utf8_str })
}
}
impl Serialize for Name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.as_str())
}
}
struct NameVisitor;
impl<'de> de::Visitor<'de> for NameVisitor {
type Value = Name;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a valid domain name")
}
fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
match Name::from_str(s) {
Ok(n) => Ok(n),
Err(e) => Err(E::custom(format!("invalid domain name: {}", e))),
}
}
}
impl<'de> Deserialize<'de> for Name {
fn deserialize<D>(deserializer: D) -> Result<Name, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(NameVisitor)
}
}

View File

@ -1,6 +1,8 @@
use std::error::Error;
use std::str::FromStr;
use crate::domain;
use mockall::automock;
use trust_dns_client::client::{Client, SyncClient};
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
@ -35,7 +37,11 @@ pub enum CheckDomainError {
#[automock]
pub trait Checker {
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError>;
fn check_domain(
&self,
domain: &domain::Name,
challenge_token: &str,
) -> Result<(), CheckDomainError>;
}
pub struct DNSChecker {
@ -43,10 +49,10 @@ pub struct DNSChecker {
client: SyncClient<UdpClientConnection>,
}
pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewDNSCheckerError> {
let target_cname =
Name::from_str(target_cname).map_err(|_| NewDNSCheckerError::InvalidTargetCNAME)?;
pub fn new(
target_cname: domain::Name,
resolver_addr: &str,
) -> Result<impl Checker, NewDNSCheckerError> {
let resolver_addr = resolver_addr
.parse()
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
@ -57,21 +63,24 @@ pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewD
let client = SyncClient::new(conn);
Ok(DNSChecker {
target_cname,
target_cname: target_cname.inner,
client,
})
}
impl Checker for DNSChecker {
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError> {
let mut fqdn = Name::from_str(domain).map_err(|_| CheckDomainError::InvalidDomainName)?;
fqdn.set_fqdn(true);
fn check_domain(
&self,
domain: &domain::Name,
challenge_token: &str,
) -> Result<(), CheckDomainError> {
let domain = &domain.inner;
// check that the CNAME is installed correctly on the domain
{
let response = self
.client
.query(&fqdn, DNSClass::IN, RecordType::CNAME)
.query(domain, DNSClass::IN, RecordType::CNAME)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let records = response.answers();
@ -90,14 +99,14 @@ impl Checker for DNSChecker {
// check that the TXT record with the challenge token is correctly installed on the domain
{
let fqdn = Name::from_str("_gateway")
let domain = Name::from_str("_gateway")
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
.append_domain(&fqdn)
.append_domain(domain)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let response = self
.client
.query(&fqdn, DNSClass::IN, RecordType::TXT)
.query(&domain, DNSClass::IN, RecordType::TXT)
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
let records = response.answers();

View File

@ -2,6 +2,7 @@ use std::error::Error;
use std::path::{Path, PathBuf};
use std::{fs, io};
use crate::domain;
use crate::origin::Descr;
use hex::ToHex;
@ -39,8 +40,8 @@ pub enum SetError {
#[mockall::automock]
pub trait Store {
fn get(&self, domain: &str) -> Result<Config, GetError>;
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError>;
fn get(&self, domain: &domain::Name) -> Result<Config, GetError>;
fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>;
}
struct FSStore {
@ -55,17 +56,17 @@ pub fn new(dir_path: &Path) -> io::Result<impl Store> {
}
impl FSStore {
fn config_dir_path(&self, domain: &str) -> PathBuf {
self.dir_path.join(domain)
fn config_dir_path(&self, domain: &domain::Name) -> PathBuf {
self.dir_path.join(domain.as_str())
}
fn config_file_path(&self, domain: &str) -> PathBuf {
fn config_file_path(&self, domain: &domain::Name) -> PathBuf {
self.config_dir_path(domain).join("config.json")
}
}
impl Store for FSStore {
fn get(&self, domain: &str) -> Result<Config, GetError> {
fn get(&self, domain: &domain::Name) -> Result<Config, GetError> {
let config_file =
fs::File::open(self.config_file_path(domain)).map_err(|e| match e.kind() {
io::ErrorKind::NotFound => GetError::NotFound,
@ -75,7 +76,7 @@ impl Store for FSStore {
Ok(serde_json::from_reader(config_file).map_err(|e| GetError::Unexpected(Box::from(e)))?)
}
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError> {
fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError> {
fs::create_dir_all(self.config_dir_path(domain))
.map_err(|e| SetError::Unexpected(Box::from(e)))?;
@ -92,7 +93,11 @@ impl Store for FSStore {
#[cfg(test)]
mod tests {
use super::*;
use crate::domain;
use crate::origin::Descr;
use std::str::FromStr;
use tempdir::TempDir;
#[test]
@ -101,7 +106,7 @@ mod tests {
let store = new(tmp_dir.path()).expect("store created");
let domain = "foo";
let domain = domain::Name::from_str("foo.com").expect("domain parsed");
let config = Config {
origin_descr: Descr::Git {
@ -111,12 +116,12 @@ mod tests {
};
assert!(matches!(
store.get(domain),
store.get(&domain),
Err::<Config, GetError>(GetError::NotFound)
));
store.set(domain, &config).expect("config set");
assert_eq!(config, store.get(domain).expect("config retrieved"));
store.set(&domain, &config).expect("config set");
assert_eq!(config, store.get(&domain).expect("config retrieved"));
let new_config = Config {
origin_descr: Descr::Git {
@ -125,7 +130,7 @@ mod tests {
},
};
store.set(domain, &new_config).expect("config set");
assert_eq!(new_config, store.get(domain).expect("config retrieved"));
store.set(&domain, &new_config).expect("config set");
assert_eq!(new_config, store.get(&domain).expect("config retrieved"));
}
}

View File

@ -1,4 +1,4 @@
use crate::domain::{checker, config};
use crate::domain::{self, checker, config};
use crate::origin;
use std::error::Error;
@ -121,12 +121,12 @@ pub trait Manager {
where
Self: 'mgr;
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError>;
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError>;
fn sync(&self, domain: &str) -> Result<(), SyncError>;
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>;
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError>;
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>;
fn sync_with_config(
&self,
domain: &str,
domain: &domain::Name,
config: &config::Config,
) -> Result<(), SyncWithConfigError>;
}
@ -169,11 +169,11 @@ where
type Origin<'mgr> = OriginStore::Origin<'mgr>
where Self: 'mgr;
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError> {
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> {
Ok(self.domain_config_store.get(domain)?)
}
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError> {
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError> {
let config = self.domain_config_store.get(domain)?;
let origin = self
.origin_store
@ -183,7 +183,7 @@ where
Ok(origin)
}
fn sync(&self, domain: &str) -> Result<(), SyncError> {
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> {
let config = self.domain_config_store.get(domain)?;
self.origin_store
.sync(config.origin_descr, origin::store::Limits {})
@ -196,7 +196,7 @@ where
fn sync_with_config(
&self,
domain: &str,
domain: &domain::Name,
config: &config::Config,
) -> Result<(), SyncWithConfigError> {
let config_hash = config

View File

@ -16,16 +16,16 @@ struct Cli {
http_listen_addr: SocketAddr,
#[arg(long, required = true, env = "GATEWAY_ORIGIN_STORE_GIT_DIR_PATH")]
origin_store_git_dir_path: Option<path::PathBuf>,
origin_store_git_dir_path: path::PathBuf,
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CHECKER_TARGET_CNAME")]
domain_checker_target_cname: Option<String>,
domain_checker_target_cname: gateway::domain::Name,
#[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "GATEWAY_DOMAIN_CHECKER_RESOLVER_ADDR")]
domain_checker_resolver_addr: String,
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CONFIG_STORE_DIR_PATH")]
domain_config_store_dir_path: Option<path::PathBuf>,
domain_config_store_dir_path: path::PathBuf,
}
#[tokio::main]
@ -55,18 +55,17 @@ async fn main() {
stop_ch_rx
};
let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path.unwrap())
let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path)
.expect("git origin store initialized");
let domain_checker = gateway::domain::checker::new(
&config.domain_checker_target_cname.unwrap(),
config.domain_checker_target_cname,
&config.domain_checker_resolver_addr,
)
.expect("domain checker initialized");
let domain_config_store =
gateway::domain::config::new(&config.domain_config_store_dir_path.unwrap())
.expect("domain config store initialized");
let domain_config_store = gateway::domain::config::new(&config.domain_config_store_dir_path)
.expect("domain config store initialized");
let manager = gateway::domain::manager::new(origin_store, domain_config_store, domain_checker);