use a validated domain name type, rather than a raw string
This commit is contained in:
parent
cf3b11862c
commit
7718735215
@ -1,3 +1,72 @@
|
|||||||
pub mod checker;
|
pub mod checker;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod manager;
|
pub mod manager;
|
||||||
|
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
use trust_dns_client::rr as trust_dns_rr;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Validated representation of a domain name
|
||||||
|
pub struct Name {
|
||||||
|
inner: trust_dns_rr::Name,
|
||||||
|
utf8_str: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Name {
|
||||||
|
fn as_str(&self) -> &str {
|
||||||
|
self.utf8_str.as_str()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for Name {
|
||||||
|
type Err = <trust_dns_rr::Name as FromStr>::Err;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
let mut n = trust_dns_rr::Name::from_str(s)?;
|
||||||
|
let utf8_str = n.clone().to_utf8();
|
||||||
|
|
||||||
|
n.set_fqdn(true);
|
||||||
|
|
||||||
|
Ok(Name { inner: n, utf8_str })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for Name {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
serializer.serialize_str(self.as_str())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct NameVisitor;
|
||||||
|
|
||||||
|
impl<'de> de::Visitor<'de> for NameVisitor {
|
||||||
|
type Value = Name;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(formatter, "a valid domain name")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: de::Error,
|
||||||
|
{
|
||||||
|
match Name::from_str(s) {
|
||||||
|
Ok(n) => Ok(n),
|
||||||
|
Err(e) => Err(E::custom(format!("invalid domain name: {}", e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for Name {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Name, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
deserializer.deserialize_str(NameVisitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use crate::domain;
|
||||||
|
|
||||||
use mockall::automock;
|
use mockall::automock;
|
||||||
use trust_dns_client::client::{Client, SyncClient};
|
use trust_dns_client::client::{Client, SyncClient};
|
||||||
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
|
use trust_dns_client::rr::{DNSClass, Name, RData, RecordType};
|
||||||
@ -35,7 +37,11 @@ pub enum CheckDomainError {
|
|||||||
|
|
||||||
#[automock]
|
#[automock]
|
||||||
pub trait Checker {
|
pub trait Checker {
|
||||||
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError>;
|
fn check_domain(
|
||||||
|
&self,
|
||||||
|
domain: &domain::Name,
|
||||||
|
challenge_token: &str,
|
||||||
|
) -> Result<(), CheckDomainError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct DNSChecker {
|
pub struct DNSChecker {
|
||||||
@ -43,10 +49,10 @@ pub struct DNSChecker {
|
|||||||
client: SyncClient<UdpClientConnection>,
|
client: SyncClient<UdpClientConnection>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewDNSCheckerError> {
|
pub fn new(
|
||||||
let target_cname =
|
target_cname: domain::Name,
|
||||||
Name::from_str(target_cname).map_err(|_| NewDNSCheckerError::InvalidTargetCNAME)?;
|
resolver_addr: &str,
|
||||||
|
) -> Result<impl Checker, NewDNSCheckerError> {
|
||||||
let resolver_addr = resolver_addr
|
let resolver_addr = resolver_addr
|
||||||
.parse()
|
.parse()
|
||||||
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
|
.map_err(|_| NewDNSCheckerError::InvalidResolverAddress)?;
|
||||||
@ -57,21 +63,24 @@ pub fn new(target_cname: &str, resolver_addr: &str) -> Result<impl Checker, NewD
|
|||||||
let client = SyncClient::new(conn);
|
let client = SyncClient::new(conn);
|
||||||
|
|
||||||
Ok(DNSChecker {
|
Ok(DNSChecker {
|
||||||
target_cname,
|
target_cname: target_cname.inner,
|
||||||
client,
|
client,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Checker for DNSChecker {
|
impl Checker for DNSChecker {
|
||||||
fn check_domain(&self, domain: &str, challenge_token: &str) -> Result<(), CheckDomainError> {
|
fn check_domain(
|
||||||
let mut fqdn = Name::from_str(domain).map_err(|_| CheckDomainError::InvalidDomainName)?;
|
&self,
|
||||||
fqdn.set_fqdn(true);
|
domain: &domain::Name,
|
||||||
|
challenge_token: &str,
|
||||||
|
) -> Result<(), CheckDomainError> {
|
||||||
|
let domain = &domain.inner;
|
||||||
|
|
||||||
// check that the CNAME is installed correctly on the domain
|
// check that the CNAME is installed correctly on the domain
|
||||||
{
|
{
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.query(&fqdn, DNSClass::IN, RecordType::CNAME)
|
.query(domain, DNSClass::IN, RecordType::CNAME)
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||||
|
|
||||||
let records = response.answers();
|
let records = response.answers();
|
||||||
@ -90,14 +99,14 @@ impl Checker for DNSChecker {
|
|||||||
|
|
||||||
// check that the TXT record with the challenge token is correctly installed on the domain
|
// check that the TXT record with the challenge token is correctly installed on the domain
|
||||||
{
|
{
|
||||||
let fqdn = Name::from_str("_gateway")
|
let domain = Name::from_str("_gateway")
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
|
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?
|
||||||
.append_domain(&fqdn)
|
.append_domain(domain)
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.query(&fqdn, DNSClass::IN, RecordType::TXT)
|
.query(&domain, DNSClass::IN, RecordType::TXT)
|
||||||
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
.map_err(|e| CheckDomainError::Unexpected(Box::from(e)))?;
|
||||||
|
|
||||||
let records = response.answers();
|
let records = response.answers();
|
||||||
|
@ -2,6 +2,7 @@ use std::error::Error;
|
|||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::{fs, io};
|
use std::{fs, io};
|
||||||
|
|
||||||
|
use crate::domain;
|
||||||
use crate::origin::Descr;
|
use crate::origin::Descr;
|
||||||
|
|
||||||
use hex::ToHex;
|
use hex::ToHex;
|
||||||
@ -39,8 +40,8 @@ pub enum SetError {
|
|||||||
|
|
||||||
#[mockall::automock]
|
#[mockall::automock]
|
||||||
pub trait Store {
|
pub trait Store {
|
||||||
fn get(&self, domain: &str) -> Result<Config, GetError>;
|
fn get(&self, domain: &domain::Name) -> Result<Config, GetError>;
|
||||||
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError>;
|
fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FSStore {
|
struct FSStore {
|
||||||
@ -55,17 +56,17 @@ pub fn new(dir_path: &Path) -> io::Result<impl Store> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FSStore {
|
impl FSStore {
|
||||||
fn config_dir_path(&self, domain: &str) -> PathBuf {
|
fn config_dir_path(&self, domain: &domain::Name) -> PathBuf {
|
||||||
self.dir_path.join(domain)
|
self.dir_path.join(domain.as_str())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn config_file_path(&self, domain: &str) -> PathBuf {
|
fn config_file_path(&self, domain: &domain::Name) -> PathBuf {
|
||||||
self.config_dir_path(domain).join("config.json")
|
self.config_dir_path(domain).join("config.json")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Store for FSStore {
|
impl Store for FSStore {
|
||||||
fn get(&self, domain: &str) -> Result<Config, GetError> {
|
fn get(&self, domain: &domain::Name) -> Result<Config, GetError> {
|
||||||
let config_file =
|
let config_file =
|
||||||
fs::File::open(self.config_file_path(domain)).map_err(|e| match e.kind() {
|
fs::File::open(self.config_file_path(domain)).map_err(|e| match e.kind() {
|
||||||
io::ErrorKind::NotFound => GetError::NotFound,
|
io::ErrorKind::NotFound => GetError::NotFound,
|
||||||
@ -75,7 +76,7 @@ impl Store for FSStore {
|
|||||||
Ok(serde_json::from_reader(config_file).map_err(|e| GetError::Unexpected(Box::from(e)))?)
|
Ok(serde_json::from_reader(config_file).map_err(|e| GetError::Unexpected(Box::from(e)))?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set(&self, domain: &str, config: &Config) -> Result<(), SetError> {
|
fn set(&self, domain: &domain::Name, config: &Config) -> Result<(), SetError> {
|
||||||
fs::create_dir_all(self.config_dir_path(domain))
|
fs::create_dir_all(self.config_dir_path(domain))
|
||||||
.map_err(|e| SetError::Unexpected(Box::from(e)))?;
|
.map_err(|e| SetError::Unexpected(Box::from(e)))?;
|
||||||
|
|
||||||
@ -92,7 +93,11 @@ impl Store for FSStore {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::domain;
|
||||||
use crate::origin::Descr;
|
use crate::origin::Descr;
|
||||||
|
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
use tempdir::TempDir;
|
use tempdir::TempDir;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -101,7 +106,7 @@ mod tests {
|
|||||||
|
|
||||||
let store = new(tmp_dir.path()).expect("store created");
|
let store = new(tmp_dir.path()).expect("store created");
|
||||||
|
|
||||||
let domain = "foo";
|
let domain = domain::Name::from_str("foo.com").expect("domain parsed");
|
||||||
|
|
||||||
let config = Config {
|
let config = Config {
|
||||||
origin_descr: Descr::Git {
|
origin_descr: Descr::Git {
|
||||||
@ -111,12 +116,12 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
store.get(domain),
|
store.get(&domain),
|
||||||
Err::<Config, GetError>(GetError::NotFound)
|
Err::<Config, GetError>(GetError::NotFound)
|
||||||
));
|
));
|
||||||
|
|
||||||
store.set(domain, &config).expect("config set");
|
store.set(&domain, &config).expect("config set");
|
||||||
assert_eq!(config, store.get(domain).expect("config retrieved"));
|
assert_eq!(config, store.get(&domain).expect("config retrieved"));
|
||||||
|
|
||||||
let new_config = Config {
|
let new_config = Config {
|
||||||
origin_descr: Descr::Git {
|
origin_descr: Descr::Git {
|
||||||
@ -125,7 +130,7 @@ mod tests {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
store.set(domain, &new_config).expect("config set");
|
store.set(&domain, &new_config).expect("config set");
|
||||||
assert_eq!(new_config, store.get(domain).expect("config retrieved"));
|
assert_eq!(new_config, store.get(&domain).expect("config retrieved"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::domain::{checker, config};
|
use crate::domain::{self, checker, config};
|
||||||
use crate::origin;
|
use crate::origin;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
||||||
@ -121,12 +121,12 @@ pub trait Manager {
|
|||||||
where
|
where
|
||||||
Self: 'mgr;
|
Self: 'mgr;
|
||||||
|
|
||||||
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError>;
|
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError>;
|
||||||
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError>;
|
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError>;
|
||||||
fn sync(&self, domain: &str) -> Result<(), SyncError>;
|
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError>;
|
||||||
fn sync_with_config(
|
fn sync_with_config(
|
||||||
&self,
|
&self,
|
||||||
domain: &str,
|
domain: &domain::Name,
|
||||||
config: &config::Config,
|
config: &config::Config,
|
||||||
) -> Result<(), SyncWithConfigError>;
|
) -> Result<(), SyncWithConfigError>;
|
||||||
}
|
}
|
||||||
@ -169,11 +169,11 @@ where
|
|||||||
type Origin<'mgr> = OriginStore::Origin<'mgr>
|
type Origin<'mgr> = OriginStore::Origin<'mgr>
|
||||||
where Self: 'mgr;
|
where Self: 'mgr;
|
||||||
|
|
||||||
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError> {
|
fn get_config(&self, domain: &domain::Name) -> Result<config::Config, GetConfigError> {
|
||||||
Ok(self.domain_config_store.get(domain)?)
|
Ok(self.domain_config_store.get(domain)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError> {
|
fn get_origin(&self, domain: &domain::Name) -> Result<Self::Origin<'_>, GetOriginError> {
|
||||||
let config = self.domain_config_store.get(domain)?;
|
let config = self.domain_config_store.get(domain)?;
|
||||||
let origin = self
|
let origin = self
|
||||||
.origin_store
|
.origin_store
|
||||||
@ -183,7 +183,7 @@ where
|
|||||||
Ok(origin)
|
Ok(origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sync(&self, domain: &str) -> Result<(), SyncError> {
|
fn sync(&self, domain: &domain::Name) -> Result<(), SyncError> {
|
||||||
let config = self.domain_config_store.get(domain)?;
|
let config = self.domain_config_store.get(domain)?;
|
||||||
self.origin_store
|
self.origin_store
|
||||||
.sync(config.origin_descr, origin::store::Limits {})
|
.sync(config.origin_descr, origin::store::Limits {})
|
||||||
@ -196,7 +196,7 @@ where
|
|||||||
|
|
||||||
fn sync_with_config(
|
fn sync_with_config(
|
||||||
&self,
|
&self,
|
||||||
domain: &str,
|
domain: &domain::Name,
|
||||||
config: &config::Config,
|
config: &config::Config,
|
||||||
) -> Result<(), SyncWithConfigError> {
|
) -> Result<(), SyncWithConfigError> {
|
||||||
let config_hash = config
|
let config_hash = config
|
||||||
|
15
src/main.rs
15
src/main.rs
@ -16,16 +16,16 @@ struct Cli {
|
|||||||
http_listen_addr: SocketAddr,
|
http_listen_addr: SocketAddr,
|
||||||
|
|
||||||
#[arg(long, required = true, env = "GATEWAY_ORIGIN_STORE_GIT_DIR_PATH")]
|
#[arg(long, required = true, env = "GATEWAY_ORIGIN_STORE_GIT_DIR_PATH")]
|
||||||
origin_store_git_dir_path: Option<path::PathBuf>,
|
origin_store_git_dir_path: path::PathBuf,
|
||||||
|
|
||||||
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CHECKER_TARGET_CNAME")]
|
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CHECKER_TARGET_CNAME")]
|
||||||
domain_checker_target_cname: Option<String>,
|
domain_checker_target_cname: gateway::domain::Name,
|
||||||
|
|
||||||
#[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "GATEWAY_DOMAIN_CHECKER_RESOLVER_ADDR")]
|
#[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "GATEWAY_DOMAIN_CHECKER_RESOLVER_ADDR")]
|
||||||
domain_checker_resolver_addr: String,
|
domain_checker_resolver_addr: String,
|
||||||
|
|
||||||
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CONFIG_STORE_DIR_PATH")]
|
#[arg(long, required = true, env = "GATEWAY_DOMAIN_CONFIG_STORE_DIR_PATH")]
|
||||||
domain_config_store_dir_path: Option<path::PathBuf>,
|
domain_config_store_dir_path: path::PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -55,18 +55,17 @@ async fn main() {
|
|||||||
stop_ch_rx
|
stop_ch_rx
|
||||||
};
|
};
|
||||||
|
|
||||||
let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path.unwrap())
|
let origin_store = gateway::origin::store::git::new(config.origin_store_git_dir_path)
|
||||||
.expect("git origin store initialized");
|
.expect("git origin store initialized");
|
||||||
|
|
||||||
let domain_checker = gateway::domain::checker::new(
|
let domain_checker = gateway::domain::checker::new(
|
||||||
&config.domain_checker_target_cname.unwrap(),
|
config.domain_checker_target_cname,
|
||||||
&config.domain_checker_resolver_addr,
|
&config.domain_checker_resolver_addr,
|
||||||
)
|
)
|
||||||
.expect("domain checker initialized");
|
.expect("domain checker initialized");
|
||||||
|
|
||||||
let domain_config_store =
|
let domain_config_store = gateway::domain::config::new(&config.domain_config_store_dir_path)
|
||||||
gateway::domain::config::new(&config.domain_config_store_dir_path.unwrap())
|
.expect("domain config store initialized");
|
||||||
.expect("domain config store initialized");
|
|
||||||
|
|
||||||
let manager = gateway::domain::manager::new(origin_store, domain_config_store, domain_checker);
|
let manager = gateway::domain::manager::new(origin_store, domain_config_store, domain_checker);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user