use a validated domain name type, rather than a raw string
This commit is contained in:
parent
cf3b11862c
commit
7718735215
@ -1,3 +1,72 @@
|
||||
pub mod checker;
|
||||
pub mod config;
|
||||
pub mod manager;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||
use trust_dns_client::rr as trust_dns_rr;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Validated representation of a domain name
|
||||
pub struct Name {
|
||||
inner: trust_dns_rr::Name,
|
||||
utf8_str: String,
|
||||
}
|
||||
|
||||
impl Name {
|
||||
fn as_str(&self) -> &str {
|
||||
self.utf8_str.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Name {
|
||||
type Err = <trust_dns_rr::Name as FromStr>::Err;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut n = trust_dns_rr::Name::from_str(s)?;
|
||||
let utf8_str = n.clone().to_utf8();
|
||||
|
||||
n.set_fqdn(true);
|
||||
|
||||
Ok(Name { inner: n, utf8_str })
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Name {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
struct NameVisitor;
|
||||
|
||||
impl<'de> de::Visitor<'de> for NameVisitor {
|
||||
type Value = Name;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(formatter, "a valid domain name")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
match Name::from_str(s) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) => Err(E::custom(format!("invalid domain name: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Name {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Name, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_str(NameVisitor)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
use std::error::Error;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::domain;
|
||||
|
||||
use mockall::automock;
|
||||
use trust_dns_client::client::{Client, SyncClient};
|
||||
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
|
||||
@ -35,7 +37,11 @@ pub enum CheckDomainError {
|
||||
|
||||
#[automock]
|
||||
pub trait Checker {
|
||||
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError>;
|
||||
fn check_domain(
|
||||
&self,
|
||||
domain: &domain::Name,
|
||||
challenge_token: &str,
|
||||
) -> Result<(), CheckDomainError>;
|
||||
}
|
||||
|
||||
pub struct DNSChecker {
|
||||
@ -43,10 +49,10 @@ pub struct DNSChecker {
|
||||
client: SyncClient<UdpClientConnection>,
|
||||
}
|
||||
|
||||
pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewDNSCheckerError> {
|
||||
let target_cname =
|
||||
Name::from_str(target_cname).map_err(|_| NewDNSCheckerError::InvalidTargetCNAME)?;
|
||||
|
||||
pub fn new(
|
||||
target_cname: domain::Name,
|
||||
resolver_addr: &str,
|
||||
) -> Result<impl Checker, NewDNSCheckerError> {
|
||||
let resolver_addr = resolver_addr
|
||||
.parse()
|
||||
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
|
||||
@ -57,21 +63,24 @@ pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewD
|
||||
let client = SyncClient::new(conn);
|
||||
|
||||
Ok(DNSChecker {
|
||||
target_cname,
|
||||
target_cname: target_cname.inner,
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
impl Checker for DNSChecker {
|
||||
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError> {
|
||||
let mut fqdn = Name::from_str(domain).map_err(|_| CheckDomainError::InvalidDomainName)?;
|
||||
fqdn.set_fqdn(true);
|
||||
fn check_domain(
|
||||
&self,
|
||||
domain: &domain::Name,
|
||||
challenge_token: &str,
|
||||
) -> Result<(), CheckDomainError> {
|
||||
let domain = &domain.inner;
|
||||
|
||||
// check that the CNAME is installed correctly on the domain
|
||||
{
|
||||
let response = self
|
||||
.client
|
||||
.query(&fqdn, DNSClass::IN, RecordType::CNAME)
|
||||
.query(domain, DNSClass::IN, RecordType::CNAME)
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||
|
||||
let records = response.answers();
|
||||
@ -90,14 +99,14 @@ impl Checker for DNSChecker {
|
||||
|
||||
// check that the TXT record with the challenge token is correctly installed on the domain
|
||||
{
|
||||
let fqdn = Name::from_str("_gateway")
|
||||
let domain = Name::from_str("_gateway")
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
|
||||
.append_domain(&fqdn)
|
||||
.append_domain(domain)
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.query(&fqdn, DNSClass::IN, RecordType::TXT)
|
||||
.query(&domain, DNSClass::IN, RecordType::TXT)
|
||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||
|
||||
let records = response.answers();
|
||||
|
@ -2,6 +2,7 @@ use std::error::Error;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{fs, io};
|
||||
|
||||
use crate::domain;
|
||||
use crate::origin::Descr;
|
||||
|
||||
use hex::ToHex;
|
||||
@ -39,8 +40,8 @@ pub enum SetError {
|
||||
|
||||
#[mockall::automock]
|
||||
pub trait Store {
|
||||
fn get(&self, domain: &str) -> Result<Config, GetError>;
|
||||
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError>;
|
||||
fn get(&self, domain: &domain::Name) -> Result<Config, GetError>;
|
||||
fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>;
|
||||
}
|
||||
|
||||
struct FSStore {
|
||||
@ -55,17 +56,17 @@ pub fn new(dir_path: &Path) -> io::Result<impl Store> {
|
||||
}
|
||||
|
||||
impl FSStore {
|
||||
fn config_dir_path(&self, domain: &str) -> PathBuf {
|
||||
self.dir_path.join(domain)
|
||||
fn config_dir_path(&self, domain: &domain::Name) -> PathBuf {
|
||||
self.dir_path.join(domain.as_str())
|
||||
}
|
||||
|
||||
fn config_file_path(&self, domain: &str) -> PathBuf {
|
||||
fn config_file_path(&self, domain: &domain::Name) -> PathBuf {
|
||||
self.config_dir_path(domain).join("config.json")
|
||||
}
|
||||
}
|
||||
|
||||
impl Store for FSStore {
|
||||
fn get(&self, domain: &str) -> Result<Config, GetError> {
|
||||
fn get(&self, domain: &domain::Name) -> Result<Config, GetError> {
|
||||
let config_file =
|
||||
fs::File::open(self.config_file_path(domain)).map_err(|e| match e.kind() {
|
||||
io::ErrorKind::NotFound => GetError::NotFound,
|
||||
@ -75,7 +76,7 @@ impl Store for FSStore {
|
||||
Ok(serde_json::from_reader(config_file).map_err(|e| GetError::Unexpected(Box::from(e)))?)
|
||||
}
|
||||
|
||||
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError> {
|
||||
fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError> {
|
||||
fs::create_dir_all(self.config_dir_path(domain))
|
||||
.map_err(|e| SetError::Unexpected(Box::from(e)))?;
|
||||
|
||||
@ -92,7 +93,11 @@ impl Store for FSStore {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::domain;
|
||||
use crate::origin::Descr;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use tempdir::TempDir;
|
||||
|
||||
#[test]
|
||||
@ -101,7 +106,7 @@ mod tests {
|
||||
|
||||
let store = new(tmp_dir.path()).expect("store created");
|
||||
|
||||
let domain = "foo";
|
||||
let domain = domain::Name::from_str("foo.com").expect("domain parsed");
|
||||
|
||||
let config = Config {
|
||||
origin_descr: Descr::Git {
|
||||
@ -111,12 +116,12 @@ mod tests {
|
||||
};
|
||||
|
||||
assert!(matches!(
|
||||
store.get(domain),
|
||||
store.get(&domain),
|
||||
Err::<Config, GetError>(GetError::NotFound)
|
||||
));
|
||||
|
||||
store.set(domain, &config).expect("config set");
|
||||
assert_eq!(config, store.get(domain).expect("config retrieved"));
|
||||
store.set(&domain, &config).expect("config set");
|
||||
assert_eq!(config, store.get(&domain).expect("config retrieved"));
|
||||
|
||||
let new_config = Config {
|
||||
origin_descr: Descr::Git {
|
||||
@ -125,7 +130,7 @@ mod tests {
|
||||
},
|
||||
};
|
||||
|
||||
store.set(domain, &new_config).expect("config set");
|
||||
assert_eq!(new_config, store.get(domain).expect("config retrieved"));
|
||||
store.set(&domain, &new_config).expect("config set");
|
||||
assert_eq!(new_config, store.get(&domain).expect("config retrieved"));
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::domain::{checker, config};
|
||||
use crate::domain::{self, checker, config};
|
||||
use crate::origin;
|
||||
use std::error::Error;
|
||||
|
||||
@ -121,12 +121,12 @@ pub trait Manager {
|
||||
where
|
||||
Self: 'mgr;
|
||||
|
||||
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError>;
|
||||
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError>;
|
||||
fn sync(&self, domain: &str) -> Result<(), SyncError>;
|
||||
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>;
|
||||
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError>;
|
||||
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>;
|
||||
fn sync_with_config(
|
||||
&self,
|
||||
domain: &str,
|
||||
domain: &domain::Name,
|
||||
config: &config::Config,
|
||||
) -> Result<(), SyncWithConfigError>;
|
||||
}
|
||||
@ -169,11 +169,11 @@ where
|
||||
type Origin<'mgr> = OriginStore::Origin<'mgr>
|
||||
where Self: 'mgr;
|
||||
|
||||
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError> {
|
||||
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> {
|
||||
Ok(self.domain_config_store.get(domain)?)
|
||||
}
|
||||
|
||||
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError> {
|
||||
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError> {
|
||||
let config = self.domain_config_store.get(domain)?;
|
||||
let origin = self
|
||||
.origin_store
|
||||
@ -183,7 +183,7 @@ where
|
||||
Ok(origin)
|
||||
}
|
||||
|
||||
fn sync(&self, domain: &str) -> Result<(), SyncError> {
|
||||
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> {
|
||||
let config = self.domain_config_store.get(domain)?;
|
||||
self.origin_store
|
||||
.sync(config.origin_descr, origin::store::Limits {})
|
||||
@ -196,7 +196,7 @@ where
|
||||
|
||||
fn sync_with_config(
|
||||
&self,
|
||||
domain: &str,
|
||||
domain: &domain::Name,
|
||||
config: &config::Config,
|
||||
) -> Result<(), SyncWithConfigError> {
|
||||
let config_hash = config
|
||||
|
13
src/main.rs
13
src/main.rs
@ -16,16 +16,16 @@ struct Cli {
|
||||
http_listen_addr: SocketAddr,
|
||||
|
||||
#[arg(long, required = true, env = "GATEWAY_ORIGIN_STORE_GIT_DIR_PATH")]
|
||||
origin_store_git_dir_path: Option<path::PathBuf>,
|
||||
origin_store_git_dir_path: path::PathBuf,
|
||||
|
||||
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CHECKER_TARGET_CNAME")]
|
||||
domain_checker_target_cname: Option<String>,
|
||||
domain_checker_target_cname: gateway::domain::Name,
|
||||
|
||||
#[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "GATEWAY_DOMAIN_CHECKER_RESOLVER_ADDR")]
|
||||
domain_checker_resolver_addr: String,
|
||||
|
||||
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CONFIG_STORE_DIR_PATH")]
|
||||
domain_config_store_dir_path: Option<path::PathBuf>,
|
||||
domain_config_store_dir_path: path::PathBuf,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -55,17 +55,16 @@ async fn main() {
|
||||
stop_ch_rx
|
||||
};
|
||||
|
||||
let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path.unwrap())
|
||||
let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path)
|
||||
.expect("git origin store initialized");
|
||||
|
||||
let domain_checker = gateway::domain::checker::new(
|
||||
&config.domain_checker_target_cname.unwrap(),
|
||||
config.domain_checker_target_cname,
|
||||
&config.domain_checker_resolver_addr,
|
||||
)
|
||||
.expect("domain checker initialized");
|
||||
|
||||
let domain_config_store =
|
||||
gateway::domain::config::new(&config.domain_config_store_dir_path.unwrap())
|
||||
let domain_config_store = gateway::domain::config::new(&config.domain_config_store_dir_path)
|
||||
.expect("domain config store initialized");
|
||||
|
||||
let manager = gateway::domain::manager::new(origin_store, domain_config_store, domain_checker);
|
||||
|
Loading…
Reference in New Issue
Block a user