Moved service tasks internally, main crashes on shutdown though
This commit is contained in:
parent
7ea97b2617
commit
506037dcd0
@ -1,3 +1,4 @@
|
|||||||
|
#![feature(result_option_inspect)]
|
||||||
#![feature(iterator_try_collect)]
|
#![feature(iterator_try_collect)]
|
||||||
|
|
||||||
pub mod domain;
|
pub mod domain;
|
||||||
|
174
src/main.rs
174
src/main.rs
@ -1,18 +1,10 @@
|
|||||||
#![feature(result_option_inspect)]
|
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use futures::stream::futures_unordered::FuturesUnordered;
|
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use signal_hook_tokio::Signals;
|
use signal_hook_tokio::Signals;
|
||||||
use tokio::select;
|
|
||||||
use tokio::time;
|
|
||||||
|
|
||||||
use std::convert::Infallible;
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::{future, path, sync};
|
use std::{path, sync};
|
||||||
|
|
||||||
use domiply::domain::manager::Manager;
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
@ -86,7 +78,6 @@ async fn main() {
|
|||||||
)
|
)
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
let mut wait_group = FuturesUnordered::new();
|
|
||||||
let canceller = tokio_util::sync::CancellationToken::new();
|
let canceller = tokio_util::sync::CancellationToken::new();
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -156,173 +147,24 @@ async fn main() {
|
|||||||
|
|
||||||
let domain_manager = sync::Arc::new(domain_manager);
|
let domain_manager = sync::Arc::new(domain_manager);
|
||||||
|
|
||||||
|
{
|
||||||
let http_service = domiply::service::http::new(
|
let http_service = domiply::service::http::new(
|
||||||
domain_manager.clone(),
|
domain_manager.clone(),
|
||||||
config.domain_checker_target_a,
|
config.domain_checker_target_a,
|
||||||
config.passphrase,
|
config.passphrase,
|
||||||
|
config.http_listen_addr.clone(),
|
||||||
config.http_domain.clone(),
|
config.http_domain.clone(),
|
||||||
|
https_params.map(|p| domiply::service::http::HTTPSParams {
|
||||||
|
listen_addr: p.https_listen_addr,
|
||||||
|
cert_resolver: domiply::domain::acme::resolver::new(p.domain_acme_store),
|
||||||
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
let http_service = sync::Arc::new(http_service);
|
|
||||||
|
|
||||||
wait_group.push({
|
|
||||||
let http_domain = config.http_domain.clone();
|
|
||||||
let canceller = canceller.clone();
|
|
||||||
let service = http_service.clone();
|
|
||||||
|
|
||||||
let make_service = hyper::service::make_service_fn(move |_| {
|
|
||||||
let service = service.clone();
|
|
||||||
|
|
||||||
// Create a `Service` for responding to the request.
|
|
||||||
let service = hyper::service::service_fn(move |req| {
|
|
||||||
domiply::service::http::handle_request(service.clone(), req)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Return the service to hyper.
|
|
||||||
async move { Ok::<_, Infallible>(service) }
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let addr = config.http_listen_addr;
|
|
||||||
|
|
||||||
log::info!(
|
|
||||||
"Listening on http://{}:{}",
|
|
||||||
http_domain.as_str(),
|
|
||||||
addr.port()
|
|
||||||
);
|
|
||||||
let server = hyper::Server::bind(&addr).serve(make_service);
|
|
||||||
|
|
||||||
let graceful = server.with_graceful_shutdown(async {
|
|
||||||
canceller.cancelled().await;
|
canceller.cancelled().await;
|
||||||
});
|
|
||||||
|
|
||||||
if let Err(e) = graceful.await {
|
sync::Arc::into_inner(http_service).unwrap().stop().await;
|
||||||
panic!("server error: {}", e);
|
|
||||||
};
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(https_params) = https_params {
|
|
||||||
// Periodically refresh all domain certs, including the http_domain passed in the Cli opts
|
|
||||||
wait_group.push({
|
|
||||||
let domain_manager = domain_manager.clone();
|
|
||||||
let http_domain = config.http_domain.clone();
|
|
||||||
let canceller = canceller.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
|
|
||||||
|
|
||||||
loop {
|
|
||||||
select! {
|
|
||||||
_ = interval.tick() => (),
|
|
||||||
_ = canceller.cancelled() => return,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = domain_manager
|
|
||||||
.sync_cert(http_domain.clone())
|
|
||||||
.await
|
|
||||||
.inspect_err(|err| {
|
|
||||||
log::error!(
|
|
||||||
"Error while getting cert for {}: {err}",
|
|
||||||
http_domain.as_str()
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
let domains_iter = domain_manager.all_domains();
|
|
||||||
|
|
||||||
if let Err(err) = domains_iter {
|
|
||||||
log::error!("Got error calling all_domains: {err}");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
for domain in domains_iter.unwrap().into_iter() {
|
|
||||||
let _ = domain_manager
|
|
||||||
.sync_cert(domain.clone())
|
|
||||||
.await
|
|
||||||
.inspect_err(|err| {
|
|
||||||
log::error!(
|
|
||||||
"Error while getting cert for {}: {err}",
|
|
||||||
domain.as_str(),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
// HTTPS server
|
|
||||||
wait_group.push({
|
|
||||||
let http_domain = config.http_domain.clone();
|
|
||||||
let canceller = canceller.clone();
|
|
||||||
let service = http_service.clone();
|
|
||||||
|
|
||||||
let make_service = hyper::service::make_service_fn(move |_| {
|
|
||||||
let service = service.clone();
|
|
||||||
|
|
||||||
// Create a `Service` for responding to the request.
|
|
||||||
let service = hyper::service::service_fn(move |req| {
|
|
||||||
domiply::service::http::handle_request(service.clone(), req)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Return the service to hyper.
|
|
||||||
async move { Ok::<_, Infallible>(service) }
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let cert_resolver =
|
|
||||||
domiply::domain::acme::resolver::new(https_params.domain_acme_store);
|
|
||||||
let canceller = canceller.clone();
|
|
||||||
|
|
||||||
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
|
|
||||||
rustls::server::ServerConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_no_client_auth()
|
|
||||||
.with_cert_resolver(cert_resolver),
|
|
||||||
)
|
|
||||||
.into();
|
|
||||||
|
|
||||||
let addr = https_params.https_listen_addr;
|
|
||||||
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
|
|
||||||
.expect("https listen socket creation failed");
|
|
||||||
|
|
||||||
let incoming =
|
|
||||||
tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
|
|
||||||
if let Err(err) = conn {
|
|
||||||
log::error!("Error accepting TLS connection: {:?}", err);
|
|
||||||
future::ready(false)
|
|
||||||
} else {
|
|
||||||
future::ready(true)
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let incoming = hyper::server::accept::from_stream(incoming);
|
|
||||||
|
|
||||||
log::info!(
|
|
||||||
"Listening on https://{}:{}",
|
|
||||||
http_domain.as_str(),
|
|
||||||
addr.port()
|
|
||||||
);
|
|
||||||
|
|
||||||
let server = hyper::Server::builder(incoming).serve(make_service);
|
|
||||||
|
|
||||||
let graceful = server.with_graceful_shutdown(async {
|
|
||||||
canceller.cancelled().await;
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Err(e) = graceful.await {
|
|
||||||
panic!("server error: {}", e);
|
|
||||||
};
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
while wait_group.next().await.is_some() {}
|
|
||||||
|
|
||||||
// TODO this is currently required so that we can be sure domain_manager is no longer used by
|
|
||||||
// anything else, and the into_inner below works. It would be great if service could accept a
|
|
||||||
// ref to domain_manager instead, and then maybe this wouldn't be needed?
|
|
||||||
drop(http_service);
|
|
||||||
|
|
||||||
sync::Arc::into_inner(domain_manager)
|
sync::Arc::into_inner(domain_manager)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.stop()
|
.stop()
|
||||||
|
@ -1,39 +1,85 @@
|
|||||||
|
mod tasks;
|
||||||
mod tpl;
|
mod tpl;
|
||||||
|
|
||||||
|
use futures::stream::futures_unordered::FuturesUnordered;
|
||||||
use hyper::{Body, Method, Request, Response};
|
use hyper::{Body, Method, Request, Response};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::future::Future;
|
|
||||||
use std::net;
|
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync;
|
use std::{future, net, sync};
|
||||||
|
|
||||||
use crate::{domain, origin, service};
|
use crate::{domain, origin, service};
|
||||||
|
|
||||||
type SvcResponse = Result<Response<hyper::body::Body>, String>;
|
type SvcResponse = Result<Response<hyper::body::Body>, String>;
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
domain_manager: sync::Arc<dyn domain::manager::Manager>,
|
domain_manager: sync::Arc<dyn domain::manager::Manager>,
|
||||||
target_a: net::Ipv4Addr,
|
target_a: net::Ipv4Addr,
|
||||||
passphrase: String,
|
passphrase: String,
|
||||||
http_domain: domain::Name,
|
http_domain: domain::Name,
|
||||||
handlebars: handlebars::Handlebars<'static>,
|
handlebars: handlebars::Handlebars<'static>,
|
||||||
|
|
||||||
|
canceller: CancellationToken,
|
||||||
|
wait_group: FuturesUnordered<tokio::task::JoinHandle<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HTTPSParams {
|
||||||
|
pub listen_addr: net::SocketAddr,
|
||||||
|
pub cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
domain_manager: sync::Arc<dyn domain::manager::Manager>,
|
domain_manager: sync::Arc<dyn domain::manager::Manager>,
|
||||||
target_a: net::Ipv4Addr,
|
target_a: net::Ipv4Addr,
|
||||||
passphrase: String,
|
passphrase: String,
|
||||||
|
http_listen_addr: net::SocketAddr,
|
||||||
http_domain: domain::Name,
|
http_domain: domain::Name,
|
||||||
) -> Service {
|
https_params: Option<HTTPSParams>,
|
||||||
Service {
|
) -> sync::Arc<Service> {
|
||||||
domain_manager,
|
let service = sync::Arc::new(Service {
|
||||||
|
domain_manager: domain_manager.clone(),
|
||||||
target_a,
|
target_a,
|
||||||
passphrase,
|
passphrase,
|
||||||
http_domain,
|
http_domain: http_domain.clone(),
|
||||||
handlebars: tpl::get(),
|
handlebars: tpl::get(),
|
||||||
|
canceller: CancellationToken::new(),
|
||||||
|
wait_group: FuturesUnordered::new(),
|
||||||
|
});
|
||||||
|
|
||||||
|
service.wait_group.push(tasks::listen_http(
|
||||||
|
service.clone(),
|
||||||
|
service.canceller.clone(),
|
||||||
|
http_listen_addr,
|
||||||
|
http_domain.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
if let Some(https_params) = https_params {
|
||||||
|
service.wait_group.push(tasks::listen_https(
|
||||||
|
service.clone(),
|
||||||
|
service.canceller.clone(),
|
||||||
|
https_params.cert_resolver,
|
||||||
|
https_params.listen_addr,
|
||||||
|
http_domain.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
service.wait_group.push(tasks::cert_refresher(
|
||||||
|
domain_manager,
|
||||||
|
service.canceller.clone(),
|
||||||
|
http_domain,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
return service;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service {
|
||||||
|
pub async fn stop(self) {
|
||||||
|
self.canceller.cancel();
|
||||||
|
for f in self.wait_group {
|
||||||
|
f.await.expect("task failed");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +208,7 @@ impl<'svc> Service {
|
|||||||
where
|
where
|
||||||
In: Deserialize<'a>,
|
In: Deserialize<'a>,
|
||||||
F: FnOnce(In) -> Out,
|
F: FnOnce(In) -> Out,
|
||||||
Out: Future<Output = SvcResponse>,
|
Out: future::Future<Output = SvcResponse>,
|
||||||
{
|
{
|
||||||
let query = req.uri().query().unwrap_or("");
|
let query = req.uri().query().unwrap_or("");
|
||||||
match serde_urlencoded::from_str::<In>(query) {
|
match serde_urlencoded::from_str::<In>(query) {
|
||||||
|
140
src/service/http/tasks.rs
Normal file
140
src/service/http/tasks.rs
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
use crate::{domain, service};
|
||||||
|
|
||||||
|
use std::{convert, future, net, sync};
|
||||||
|
|
||||||
|
use futures::StreamExt;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
pub fn listen_http(
|
||||||
|
service: sync::Arc<service::http::Service>,
|
||||||
|
canceller: CancellationToken,
|
||||||
|
addr: net::SocketAddr,
|
||||||
|
domain: domain::Name,
|
||||||
|
) -> tokio::task::JoinHandle<()> {
|
||||||
|
let make_service = hyper::service::make_service_fn(move |_| {
|
||||||
|
let service = service.clone();
|
||||||
|
|
||||||
|
// Create a `Service` for responding to the request.
|
||||||
|
let hyper_service = hyper::service::service_fn(move |req| {
|
||||||
|
service::http::handle_request(service.clone(), req)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Return the service to hyper.
|
||||||
|
async move { Ok::<_, convert::Infallible>(hyper_service) }
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
log::info!("Listening on http://{}:{}", domain.as_str(), addr.port());
|
||||||
|
let server = hyper::Server::bind(&addr).serve(make_service);
|
||||||
|
|
||||||
|
let graceful = server.with_graceful_shutdown(async {
|
||||||
|
canceller.cancelled().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Err(e) = graceful.await {
|
||||||
|
panic!("server error: {}", e);
|
||||||
|
};
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn listen_https(
|
||||||
|
service: sync::Arc<service::http::Service>,
|
||||||
|
canceller: CancellationToken,
|
||||||
|
cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
|
||||||
|
addr: net::SocketAddr,
|
||||||
|
domain: domain::Name,
|
||||||
|
) -> tokio::task::JoinHandle<()> {
|
||||||
|
let make_service = hyper::service::make_service_fn(move |_| {
|
||||||
|
let service = service.clone();
|
||||||
|
|
||||||
|
// Create a `Service` for responding to the request.
|
||||||
|
let hyper_service = hyper::service::service_fn(move |req| {
|
||||||
|
service::http::handle_request(service.clone(), req)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Return the service to hyper.
|
||||||
|
async move { Ok::<_, convert::Infallible>(hyper_service) }
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
|
||||||
|
rustls::server::ServerConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_cert_resolver(cert_resolver),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
|
||||||
|
.expect("https listen socket creation failed");
|
||||||
|
|
||||||
|
let incoming =
|
||||||
|
tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
|
||||||
|
if let Err(err) = conn {
|
||||||
|
log::error!("Error accepting TLS connection: {:?}", err);
|
||||||
|
future::ready(false)
|
||||||
|
} else {
|
||||||
|
future::ready(true)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let incoming = hyper::server::accept::from_stream(incoming);
|
||||||
|
|
||||||
|
log::info!("Listening on https://{}:{}", domain.as_str(), addr.port());
|
||||||
|
|
||||||
|
let server = hyper::Server::builder(incoming).serve(make_service);
|
||||||
|
|
||||||
|
let graceful = server.with_graceful_shutdown(async {
|
||||||
|
canceller.cancelled().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Err(e) = graceful.await {
|
||||||
|
panic!("server error: {}", e);
|
||||||
|
};
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cert_refresher(
|
||||||
|
domain_manager: sync::Arc<dyn domain::manager::Manager>,
|
||||||
|
canceller: CancellationToken,
|
||||||
|
http_domain: domain::Name,
|
||||||
|
) -> tokio::task::JoinHandle<()> {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
use tokio::time;
|
||||||
|
|
||||||
|
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = interval.tick() => (),
|
||||||
|
_ = canceller.cancelled() => return,
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = domain_manager
|
||||||
|
.sync_cert(http_domain.clone())
|
||||||
|
.await
|
||||||
|
.inspect_err(|err| {
|
||||||
|
log::error!(
|
||||||
|
"Error while getting cert for {}: {err}",
|
||||||
|
http_domain.as_str()
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
let domains_iter = domain_manager.all_domains();
|
||||||
|
|
||||||
|
if let Err(err) = domains_iter {
|
||||||
|
log::error!("Got error calling all_domains: {err}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for domain in domains_iter.unwrap().into_iter() {
|
||||||
|
let _ = domain_manager
|
||||||
|
.sync_cert(domain.clone())
|
||||||
|
.await
|
||||||
|
.inspect_err(|err| {
|
||||||
|
log::error!("Error while getting cert for {}: {err}", domain.as_str(),)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user