domani/src/main.rs
2023-05-20 14:51:36 +02:00

339 lines
11 KiB
Rust

#![feature(result_option_inspect)]
use clap::Parser;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt;
use signal_hook_tokio::Signals;
use tokio::select;
use tokio::time;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::path;
use std::str::FromStr;
use std::sync;
use domiply::domain::acme::manager::Manager as AcmeManager;
use domiply::domain::manager::Manager;
// Server configuration, parsed from command-line flags; every option can alternatively be set
// through the `DOMIPLY_*` environment variable named in its `env` attribute.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into
// `--help` text, so switching these to doc comments would change the CLI's output.
#[derive(Parser, Debug)]
#[command(version)]
#[command(about = "A domiply to another dimension")]
struct Cli {
    // Public domain name of the service. Passed to the service, used in the "Listening on" log
    // lines, and (when HTTPS is enabled) included in the periodic certificate refresh.
    #[arg(long, required = true, env = "DOMIPLY_HTTP_DOMAIN")]
    http_domain: domiply::domain::Name,

    // Address the plain-HTTP server binds to; defaults to port 3030 on the unspecified IPv6
    // address `[::]`.
    #[arg(long, default_value_t = SocketAddr::from_str("[::]:3030").unwrap(), env = "DOMIPLY_HTTP_LISTEN_ADDR")]
    http_listen_addr: SocketAddr,

    // Optional HTTPS listen address. When set, clap also requires the ACME contact email and the
    // ACME store directory below, and the HTTPS server plus the hourly certificate refresh are
    // started.
    #[arg(
        long,
        help = "E.g. '[::]:443', if given then SSL certs will automatically be retrieved for all domains using LetsEncrypt",
        env = "DOMIPLY_HTTPS_LISTEN_ADDR",
        requires = "domain_acme_contact_email",
        requires = "domain_acme_store_dir_path"
    )]
    https_listen_addr: Option<SocketAddr>,

    // Secret handed to the service (presumably used to authorize requests — confirm in
    // `domiply::service`). NOTE(review): a value given as a CLI flag is visible in the process
    // list; prefer supplying it through the DOMIPLY_PASSPHRASE environment variable.
    #[arg(long, required = true, env = "DOMIPLY_PASSPHRASE")]
    passphrase: String,

    // Directory backing the git-based origin store.
    #[arg(long, required = true, env = "DOMIPLY_ORIGIN_STORE_GIT_DIR_PATH")]
    origin_store_git_dir_path: path::PathBuf,

    // IPv4 address handed to both the domain checker and the service; presumably the address a
    // domain's A record must point at — confirm in `domiply::domain::checker`.
    #[arg(long, required = true, env = "DOMIPLY_DOMAIN_CHECKER_TARGET_A")]
    domain_checker_target_a: std::net::Ipv4Addr,

    // "host:port" of the resolver handed to the domain checker (presumably a DNS resolver, given
    // the default of 1.1.1.1:53).
    #[arg(long, default_value_t = String::from("1.1.1.1:53"), env = "DOMIPLY_DOMAIN_CHECKER_RESOLVER_ADDR")]
    domain_checker_resolver_addr: String,

    // Directory backing the domain config store.
    #[arg(long, required = true, env = "DOMIPLY_DOMAIN_CONFIG_STORE_DIR_PATH")]
    domain_config_store_dir_path: path::PathBuf,

    // Directory backing the ACME store; required (by clap) whenever https_listen_addr is set.
    #[arg(long, env = "DOMIPLY_DOMAIN_ACME_STORE_DIR_PATH")]
    domain_acme_store_dir_path: Option<path::PathBuf>,

    // Contact email handed to the ACME manager; required (by clap) whenever https_listen_addr
    // is set.
    #[arg(long, env = "DOMIPLY_DOMAIN_ACME_CONTACT_EMAIL")]
    domain_acme_contact_email: Option<String>,
}
/// Everything needed to run HTTPS: the listen address plus the ACME certificate store and
/// manager. `main` only builds this when an HTTPS listen address is configured.
///
/// Trait bounds are intentionally not placed on the struct definition (Rust API guideline
/// C-STRUCT-BOUNDS). The concrete store/manager types are chosen where the value is
/// constructed, and `#[derive(Clone)]` adds the `Clone` bounds its impl requires.
#[derive(Clone)]
struct HTTPSParams<DomainAcmeStore, DomainAcmeManager> {
    /// Address the HTTPS server binds to.
    https_listen_addr: SocketAddr,
    /// ACME certificate store; `main` also uses it as the rustls certificate resolver.
    domain_acme_store: DomainAcmeStore,
    /// ACME manager used to obtain/refresh certificates for individual domains.
    domain_acme_manager: DomainAcmeManager,
}
#[tokio::main]
async fn main() {
let config = Cli::parse();
let mut wait_group = FuturesUnordered::new();
let canceller = tokio_util::sync::CancellationToken::new();
{
let canceller = canceller.clone();
tokio::spawn(async move {
let mut signals =
Signals::new(signal_hook::consts::TERM_SIGNALS).expect("initialized signals");
if (signals.next().await).is_some() {
println!("Gracefully shutting down...");
canceller.cancel();
}
if (signals.next().await).is_some() {
println!("Forcefully shutting down");
std::process::exit(1);
};
});
}
let origin_store = domiply::origin::store::git::new(config.origin_store_git_dir_path)
.expect("git origin store initialized");
let domain_checker = domiply::domain::checker::new(
config.domain_checker_target_a,
&config.domain_checker_resolver_addr,
)
.await
.expect("domain checker initialized");
let domain_config_store = domiply::domain::config::new(&config.domain_config_store_dir_path)
.expect("domain config store initialized");
let https_params = if let Some(https_listen_addr) = config.https_listen_addr {
let domain_acme_store_dir_path = config.domain_acme_store_dir_path.unwrap();
let domain_acme_store = domiply::domain::acme::store::new(&domain_acme_store_dir_path)
.expect("domain acme store initialized");
// if https_listen_addr is set then domain_acme_contact_email is required, see the Cli/clap
// settings.
let domain_acme_contact_email = config.domain_acme_contact_email.unwrap();
let domain_acme_manager = domiply::domain::acme::manager::new(
domain_acme_store.clone(),
&domain_acme_contact_email,
)
.await
.expect("domain acme manager initialized");
Some(HTTPSParams {
https_listen_addr,
domain_acme_store,
domain_acme_manager,
})
} else {
None
};
let domain_manager = domiply::domain::manager::new(
origin_store,
domain_config_store,
domain_checker,
https_params
.as_ref()
.and_then(|p| Some(p.domain_acme_manager.clone())),
);
wait_group.push({
let domain_manager = domain_manager.clone();
let canceller = canceller.clone();
tokio::spawn(async move {
let mut interval = time::interval(time::Duration::from_secs(20 * 60));
loop {
select! {
_ = interval.tick() => (),
_ = canceller.cancelled() => return,
}
let errors_iter = domain_manager.sync_all_origins();
if let Err(err) = errors_iter {
println!("Got error calling sync_all_origins: {err}");
continue;
}
errors_iter
.unwrap()
.into_iter()
.for_each(|(descr, err)| match descr {
None => println!("error while syncing unknown descr: {err}"),
Some(descr) => println!("failed to sync {descr:?}: {err}"),
});
}
})
});
let service = domiply::service::new(
domain_manager.clone(),
config.domain_checker_target_a,
config.passphrase,
config.http_domain.clone(),
);
let service = sync::Arc::new(service);
wait_group.push({
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
let service = service.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio::spawn(async move {
let addr = config.http_listen_addr;
println!(
"Listening on http://{}:{}",
http_domain.as_str(),
addr.port()
);
let server = hyper::Server::bind(&addr).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
};
})
});
if let Some(https_params) = https_params {
// Periodically refresh all domain certs, including the http_domain passed in the Cli opts
wait_group.push({
let https_params = https_params.clone();
let domain_manager = domain_manager.clone();
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
tokio::spawn(async move {
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
loop {
select! {
_ = interval.tick() => (),
_ = canceller.cancelled() => return,
}
_ = https_params
.domain_acme_manager
.sync_domain(http_domain.clone())
.await
.inspect_err(|err| {
println!(
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = domain_manager.all_domains();
if let Err(err) = domains_iter {
println!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
match domain {
Ok(domain) => {
let _ = https_params
.domain_acme_manager
.sync_domain(domain.clone())
.await
.inspect_err(|err| {
println!(
"Error while getting cert for {}: {err}",
domain.as_str(),
)
});
}
Err(err) => println!("Error iterating through domains: {err}"),
};
}
}
})
});
// HTTPS server
wait_group.push({
let https_params = https_params.clone();
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
let service = service.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio::spawn(async move {
let canceller = canceller.clone();
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
rustls::server::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
.unwrap()
.with_no_client_auth()
.with_cert_resolver(sync::Arc::from(https_params.domain_acme_store)),
)
.into();
let addr = https_params.https_listen_addr;
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
.expect("https listen socket created");
let incoming = tls_listener::TlsListener::new(server_config, addr_incoming);
println!(
"Listening on https://{}:{}",
http_domain.as_str(),
addr.port()
);
let server = hyper::Server::builder(incoming).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
};
})
})
}
while wait_group.next().await.is_some() {}
println!("Graceful shutdown complete");
}