Use TaskSet to cleanly shut down the http service

This commit is contained in:
Brian Picciano 2023-06-19 20:56:14 +02:00
parent 43f4b98b38
commit f2374cded5
5 changed files with 170 additions and 135 deletions

4
TODO
View File

@ -1 +1,3 @@
- clean up main a lot - make domain_manager implement rusttls cert resolver
- Try to switch from Arc to Box where possible
- maybe build TaskSet into some kind of defer-like replacement

View File

@ -148,7 +148,7 @@ async fn main() {
let domain_manager = sync::Arc::new(domain_manager); let domain_manager = sync::Arc::new(domain_manager);
{ {
let http_service = domiply::service::http::new( let (http_service, http_service_task_set) = domiply::service::http::new(
domain_manager.clone(), domain_manager.clone(),
config.domain_checker_target_a, config.domain_checker_target_a,
config.passphrase, config.passphrase,
@ -162,7 +162,7 @@ async fn main() {
canceller.cancelled().await; canceller.cancelled().await;
sync::Arc::into_inner(http_service).unwrap().stop().await; domiply::service::http::stop(http_service, http_service_task_set).await;
} }
sync::Arc::into_inner(domain_manager) sync::Arc::into_inner(domain_manager)

View File

@ -1,16 +1,15 @@
mod tasks; mod tasks;
mod tpl; mod tpl;
use futures::stream::futures_unordered::FuturesUnordered;
use hyper::{Body, Method, Request, Response}; use hyper::{Body, Method, Request, Response};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use std::convert::Infallible; use std::convert::Infallible;
use std::str::FromStr; use std::str::FromStr;
use std::{future, net, sync}; use std::{future, net, sync};
use crate::{domain, origin, service}; use crate::error::unexpected;
use crate::{domain, origin, service, util};
type SvcResponse = Result<Response<hyper::body::Body>, String>; type SvcResponse = Result<Response<hyper::body::Body>, String>;
@ -20,9 +19,6 @@ pub struct Service {
passphrase: String, passphrase: String,
http_domain: domain::Name, http_domain: domain::Name,
handlebars: handlebars::Handlebars<'static>, handlebars: handlebars::Handlebars<'static>,
canceller: CancellationToken,
wait_group: FuturesUnordered<tokio::task::JoinHandle<()>>,
} }
pub struct HTTPSParams { pub struct HTTPSParams {
@ -37,50 +33,52 @@ pub fn new(
http_listen_addr: net::SocketAddr, http_listen_addr: net::SocketAddr,
http_domain: domain::Name, http_domain: domain::Name,
https_params: Option<HTTPSParams>, https_params: Option<HTTPSParams>,
) -> sync::Arc<Service> { ) -> (sync::Arc<Service>, util::TaskSet<unexpected::Error>) {
let service = sync::Arc::new(Service { let service = sync::Arc::new(Service {
domain_manager: domain_manager.clone(), domain_manager: domain_manager.clone(),
target_a, target_a,
passphrase, passphrase,
http_domain: http_domain.clone(), http_domain: http_domain.clone(),
handlebars: tpl::get(), handlebars: tpl::get(),
canceller: CancellationToken::new(),
wait_group: FuturesUnordered::new(),
}); });
service.wait_group.push(tasks::listen_http( let task_set = util::TaskSet::new();
service.clone(),
service.canceller.clone(), task_set.spawn(|canceller| {
http_listen_addr, tasks::listen_http(
http_domain.clone(), service.clone(),
)); canceller,
http_listen_addr,
http_domain.clone(),
)
});
if let Some(https_params) = https_params { if let Some(https_params) = https_params {
service.wait_group.push(tasks::listen_https( task_set.spawn(|canceller| {
service.clone(), tasks::listen_https(
service.canceller.clone(), service.clone(),
https_params.cert_resolver, canceller,
https_params.listen_addr, https_params.cert_resolver.clone(),
http_domain.clone(), https_params.listen_addr,
)); http_domain.clone(),
)
});
service.wait_group.push(tasks::cert_refresher( task_set.spawn(|canceller| {
domain_manager, tasks::cert_refresher(domain_manager.clone(), canceller, http_domain.clone())
service.canceller.clone(), });
http_domain,
))
} }
return service; return (service, task_set);
} }
impl Service { pub async fn stop(service: sync::Arc<Service>, task_set: util::TaskSet<unexpected::Error>) {
pub async fn stop(self) { task_set
self.canceller.cancel(); .stop()
for f in self.wait_group { .await
f.await.expect("task failed"); .iter()
} .for_each(|e| log::error!("error while shutting down http service: {e}"));
} sync::Arc::into_inner(service).expect("service didn't get cleaned up");
} }
#[derive(Serialize)] #[derive(Serialize)]

View File

@ -1,3 +1,4 @@
use crate::error::unexpected::{self, Mappable};
use crate::{domain, service}; use crate::{domain, service};
use std::{convert, future, net, sync}; use std::{convert, future, net, sync};
@ -5,138 +6,127 @@ use std::{convert, future, net, sync};
use futures::StreamExt; use futures::StreamExt;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
pub fn listen_http( pub async fn listen_http(
service: sync::Arc<service::http::Service>, service: sync::Arc<service::http::Service>,
canceller: CancellationToken, canceller: CancellationToken,
addr: net::SocketAddr, addr: net::SocketAddr,
domain: domain::Name, domain: domain::Name,
) -> tokio::task::JoinHandle<()> { ) -> Result<(), unexpected::Error> {
tokio::spawn(async move { let make_service = hyper::service::make_service_fn(move |_| {
let make_service = hyper::service::make_service_fn(move |_| { let service = service.clone();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone(); let service = service.clone();
async move { service.handle_request(req).await }
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone();
async move { service.handle_request(req).await }
});
// Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) }
}); });
log::info!("Listening on http://{}:{}", domain.as_str(), addr.port()); // Return the service to hyper.
let server = hyper::Server::bind(&addr).serve(make_service); async move { Ok::<_, convert::Infallible>(hyper_service) }
});
let graceful = server.with_graceful_shutdown(async { log::info!("Listening on http://{}:{}", domain.as_str(), addr.port());
canceller.cancelled().await; let server = hyper::Server::bind(&addr).serve(make_service);
});
if let Err(e) = graceful.await { let graceful = server.with_graceful_shutdown(async {
panic!("server error: {}", e); canceller.cancelled().await;
}; });
})
graceful.await.or_unexpected()
} }
pub fn listen_https( pub async fn listen_https(
service: sync::Arc<service::http::Service>, service: sync::Arc<service::http::Service>,
canceller: CancellationToken, canceller: CancellationToken,
cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>, cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
addr: net::SocketAddr, addr: net::SocketAddr,
domain: domain::Name, domain: domain::Name,
) -> tokio::task::JoinHandle<()> { ) -> Result<(), unexpected::Error> {
tokio::spawn(async move { let make_service = hyper::service::make_service_fn(move |_| {
let make_service = hyper::service::make_service_fn(move |_| { let service = service.clone();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone(); let service = service.clone();
async move { service.handle_request(req).await }
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
let service = service.clone();
async move { service.handle_request(req).await }
});
// Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) }
}); });
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new( // Return the service to hyper.
rustls::server::ServerConfig::builder() async move { Ok::<_, convert::Infallible>(hyper_service) }
.with_safe_defaults() });
.with_no_client_auth()
.with_cert_resolver(cert_resolver),
)
.into();
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr) let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
.expect("https listen socket creation failed"); rustls::server::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(cert_resolver),
)
.into();
let incoming = let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| { .expect("https listen socket creation failed");
if let Err(err) = conn {
log::error!("Error accepting TLS connection: {:?}", err);
future::ready(false)
} else {
future::ready(true)
}
});
let incoming = hyper::server::accept::from_stream(incoming); let incoming = tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
if let Err(err) = conn {
log::error!("Error accepting TLS connection: {:?}", err);
future::ready(false)
} else {
future::ready(true)
}
});
log::info!("Listening on https://{}:{}", domain.as_str(), addr.port()); let incoming = hyper::server::accept::from_stream(incoming);
let server = hyper::Server::builder(incoming).serve(make_service); log::info!("Listening on https://{}:{}", domain.as_str(), addr.port());
let graceful = server.with_graceful_shutdown(async { let server = hyper::Server::builder(incoming).serve(make_service);
canceller.cancelled().await;
});
if let Err(e) = graceful.await { let graceful = server.with_graceful_shutdown(async {
panic!("server error: {}", e); canceller.cancelled().await;
}; });
})
graceful.await.or_unexpected()
} }
pub fn cert_refresher( pub async fn cert_refresher(
domain_manager: sync::Arc<dyn domain::manager::Manager>, domain_manager: sync::Arc<dyn domain::manager::Manager>,
canceller: CancellationToken, canceller: CancellationToken,
http_domain: domain::Name, http_domain: domain::Name,
) -> tokio::task::JoinHandle<()> { ) -> Result<(), unexpected::Error> {
tokio::spawn(async move { use tokio::time;
use tokio::time;
let mut interval = time::interval(time::Duration::from_secs(60 * 60)); let mut interval = time::interval(time::Duration::from_secs(60 * 60));
loop { loop {
tokio::select! { tokio::select! {
_ = interval.tick() => (), _ = interval.tick() => (),
_ = canceller.cancelled() => return, _ = canceller.cancelled() => return Ok(()),
} }
_ = domain_manager _ = domain_manager
.sync_cert(http_domain.clone()) .sync_cert(http_domain.clone())
.await
.inspect_err(|err| {
log::error!(
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = domain_manager.all_domains();
if let Err(err) = domains_iter {
log::error!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
let _ = domain_manager
.sync_cert(domain.clone())
.await .await
.inspect_err(|err| { .inspect_err(|err| {
log::error!( log::error!("Error while getting cert for {}: {err}", domain.as_str(),)
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
}); });
let domains_iter = domain_manager.all_domains();
if let Err(err) = domains_iter {
log::error!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
let _ = domain_manager
.sync_cert(domain.clone())
.await
.inspect_err(|err| {
log::error!("Error while getting cert for {}: {err}", domain.as_str(),)
});
}
} }
}) }
} }

View File

@ -1,4 +1,7 @@
use std::{fs, io, path}; use std::{error, fs, io, path};
use futures::stream::futures_unordered::FuturesUnordered;
use tokio_util::sync::CancellationToken;
pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> { pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> {
match fs::File::open(path) { match fs::File::open(path) {
@ -9,3 +12,45 @@ pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> {
}, },
} }
} }
pub struct TaskSet<E>
where
E: error::Error + Send + 'static,
{
canceller: CancellationToken,
wait_group: FuturesUnordered<tokio::task::JoinHandle<Result<(), E>>>,
}
impl<E> TaskSet<E>
where
E: error::Error + Send + 'static,
{
pub fn new() -> TaskSet<E> {
TaskSet {
canceller: CancellationToken::new(),
wait_group: FuturesUnordered::new(),
}
}
pub fn spawn<F, Fut>(&self, mut f: F)
where
Fut: futures::Future<Output = Result<(), E>> + Send + 'static,
F: FnMut(CancellationToken) -> Fut,
{
let canceller = self.canceller.clone();
let handle = tokio::spawn(f(canceller));
self.wait_group.push(handle);
}
pub async fn stop(self) -> Vec<E> {
self.canceller.cancel();
let mut res = Vec::new();
for f in self.wait_group {
if let Err(err) = f.await.expect("task failed") {
res.push(err);
}
}
res
}
}