Use TaskSet to cleanly shut down the http service
This commit is contained in:
parent
43f4b98b38
commit
f2374cded5
4
TODO
4
TODO
@ -1 +1,3 @@
|
||||
- clean up main a lot
|
||||
- make domain_manager implement rusttls cert resolver
|
||||
- Try to switch from Arc to Box where possible
|
||||
- maybe build TaskSet into some kind of defer-like replacement
|
||||
|
@ -148,7 +148,7 @@ async fn main() {
|
||||
let domain_manager = sync::Arc::new(domain_manager);
|
||||
|
||||
{
|
||||
let http_service = domiply::service::http::new(
|
||||
let (http_service, http_service_task_set) = domiply::service::http::new(
|
||||
domain_manager.clone(),
|
||||
config.domain_checker_target_a,
|
||||
config.passphrase,
|
||||
@ -162,7 +162,7 @@ async fn main() {
|
||||
|
||||
canceller.cancelled().await;
|
||||
|
||||
sync::Arc::into_inner(http_service).unwrap().stop().await;
|
||||
domiply::service::http::stop(http_service, http_service_task_set).await;
|
||||
}
|
||||
|
||||
sync::Arc::into_inner(domain_manager)
|
||||
|
@ -1,16 +1,15 @@
|
||||
mod tasks;
|
||||
mod tpl;
|
||||
|
||||
use futures::stream::futures_unordered::FuturesUnordered;
|
||||
use hyper::{Body, Method, Request, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::str::FromStr;
|
||||
use std::{future, net, sync};
|
||||
|
||||
use crate::{domain, origin, service};
|
||||
use crate::error::unexpected;
|
||||
use crate::{domain, origin, service, util};
|
||||
|
||||
type SvcResponse = Result<Response<hyper::body::Body>, String>;
|
||||
|
||||
@ -20,9 +19,6 @@ pub struct Service {
|
||||
passphrase: String,
|
||||
http_domain: domain::Name,
|
||||
handlebars: handlebars::Handlebars<'static>,
|
||||
|
||||
canceller: CancellationToken,
|
||||
wait_group: FuturesUnordered<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
pub struct HTTPSParams {
|
||||
@ -37,50 +33,52 @@ pub fn new(
|
||||
http_listen_addr: net::SocketAddr,
|
||||
http_domain: domain::Name,
|
||||
https_params: Option<HTTPSParams>,
|
||||
) -> sync::Arc<Service> {
|
||||
) -> (sync::Arc<Service>, util::TaskSet<unexpected::Error>) {
|
||||
let service = sync::Arc::new(Service {
|
||||
domain_manager: domain_manager.clone(),
|
||||
target_a,
|
||||
passphrase,
|
||||
http_domain: http_domain.clone(),
|
||||
handlebars: tpl::get(),
|
||||
canceller: CancellationToken::new(),
|
||||
wait_group: FuturesUnordered::new(),
|
||||
});
|
||||
|
||||
service.wait_group.push(tasks::listen_http(
|
||||
service.clone(),
|
||||
service.canceller.clone(),
|
||||
http_listen_addr,
|
||||
http_domain.clone(),
|
||||
));
|
||||
let task_set = util::TaskSet::new();
|
||||
|
||||
task_set.spawn(|canceller| {
|
||||
tasks::listen_http(
|
||||
service.clone(),
|
||||
canceller,
|
||||
http_listen_addr,
|
||||
http_domain.clone(),
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(https_params) = https_params {
|
||||
service.wait_group.push(tasks::listen_https(
|
||||
service.clone(),
|
||||
service.canceller.clone(),
|
||||
https_params.cert_resolver,
|
||||
https_params.listen_addr,
|
||||
http_domain.clone(),
|
||||
));
|
||||
task_set.spawn(|canceller| {
|
||||
tasks::listen_https(
|
||||
service.clone(),
|
||||
canceller,
|
||||
https_params.cert_resolver.clone(),
|
||||
https_params.listen_addr,
|
||||
http_domain.clone(),
|
||||
)
|
||||
});
|
||||
|
||||
service.wait_group.push(tasks::cert_refresher(
|
||||
domain_manager,
|
||||
service.canceller.clone(),
|
||||
http_domain,
|
||||
))
|
||||
task_set.spawn(|canceller| {
|
||||
tasks::cert_refresher(domain_manager.clone(), canceller, http_domain.clone())
|
||||
});
|
||||
}
|
||||
|
||||
return service;
|
||||
return (service, task_set);
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub async fn stop(self) {
|
||||
self.canceller.cancel();
|
||||
for f in self.wait_group {
|
||||
f.await.expect("task failed");
|
||||
}
|
||||
}
|
||||
pub async fn stop(service: sync::Arc<Service>, task_set: util::TaskSet<unexpected::Error>) {
|
||||
task_set
|
||||
.stop()
|
||||
.await
|
||||
.iter()
|
||||
.for_each(|e| log::error!("error while shutting down http service: {e}"));
|
||||
sync::Arc::into_inner(service).expect("service didn't get cleaned up");
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -1,3 +1,4 @@
|
||||
use crate::error::unexpected::{self, Mappable};
|
||||
use crate::{domain, service};
|
||||
|
||||
use std::{convert, future, net, sync};
|
||||
@ -5,138 +6,127 @@ use std::{convert, future, net, sync};
|
||||
use futures::StreamExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub fn listen_http(
|
||||
pub async fn listen_http(
|
||||
service: sync::Arc<service::http::Service>,
|
||||
canceller: CancellationToken,
|
||||
addr: net::SocketAddr,
|
||||
domain: domain::Name,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let make_service = hyper::service::make_service_fn(move |_| {
|
||||
) -> Result<(), unexpected::Error> {
|
||||
let make_service = hyper::service::make_service_fn(move |_| {
|
||||
let service = service.clone();
|
||||
|
||||
// Create a `Service` for responding to the request.
|
||||
let hyper_service = hyper::service::service_fn(move |req| {
|
||||
let service = service.clone();
|
||||
|
||||
// Create a `Service` for responding to the request.
|
||||
let hyper_service = hyper::service::service_fn(move |req| {
|
||||
let service = service.clone();
|
||||
async move { service.handle_request(req).await }
|
||||
});
|
||||
|
||||
// Return the service to hyper.
|
||||
async move { Ok::<_, convert::Infallible>(hyper_service) }
|
||||
async move { service.handle_request(req).await }
|
||||
});
|
||||
|
||||
log::info!("Listening on http://{}:{}", domain.as_str(), addr.port());
|
||||
let server = hyper::Server::bind(&addr).serve(make_service);
|
||||
// Return the service to hyper.
|
||||
async move { Ok::<_, convert::Infallible>(hyper_service) }
|
||||
});
|
||||
|
||||
let graceful = server.with_graceful_shutdown(async {
|
||||
canceller.cancelled().await;
|
||||
});
|
||||
log::info!("Listening on http://{}:{}", domain.as_str(), addr.port());
|
||||
let server = hyper::Server::bind(&addr).serve(make_service);
|
||||
|
||||
if let Err(e) = graceful.await {
|
||||
panic!("server error: {}", e);
|
||||
};
|
||||
})
|
||||
let graceful = server.with_graceful_shutdown(async {
|
||||
canceller.cancelled().await;
|
||||
});
|
||||
|
||||
graceful.await.or_unexpected()
|
||||
}
|
||||
|
||||
pub fn listen_https(
|
||||
pub async fn listen_https(
|
||||
service: sync::Arc<service::http::Service>,
|
||||
canceller: CancellationToken,
|
||||
cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
|
||||
addr: net::SocketAddr,
|
||||
domain: domain::Name,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let make_service = hyper::service::make_service_fn(move |_| {
|
||||
) -> Result<(), unexpected::Error> {
|
||||
let make_service = hyper::service::make_service_fn(move |_| {
|
||||
let service = service.clone();
|
||||
|
||||
// Create a `Service` for responding to the request.
|
||||
let hyper_service = hyper::service::service_fn(move |req| {
|
||||
let service = service.clone();
|
||||
|
||||
// Create a `Service` for responding to the request.
|
||||
let hyper_service = hyper::service::service_fn(move |req| {
|
||||
let service = service.clone();
|
||||
async move { service.handle_request(req).await }
|
||||
});
|
||||
|
||||
// Return the service to hyper.
|
||||
async move { Ok::<_, convert::Infallible>(hyper_service) }
|
||||
async move { service.handle_request(req).await }
|
||||
});
|
||||
|
||||
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
|
||||
rustls::server::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver),
|
||||
)
|
||||
.into();
|
||||
// Return the service to hyper.
|
||||
async move { Ok::<_, convert::Infallible>(hyper_service) }
|
||||
});
|
||||
|
||||
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
|
||||
.expect("https listen socket creation failed");
|
||||
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
|
||||
rustls::server::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver),
|
||||
)
|
||||
.into();
|
||||
|
||||
let incoming =
|
||||
tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
log::error!("Error accepting TLS connection: {:?}", err);
|
||||
future::ready(false)
|
||||
} else {
|
||||
future::ready(true)
|
||||
}
|
||||
});
|
||||
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
|
||||
.expect("https listen socket creation failed");
|
||||
|
||||
let incoming = hyper::server::accept::from_stream(incoming);
|
||||
let incoming = tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
log::error!("Error accepting TLS connection: {:?}", err);
|
||||
future::ready(false)
|
||||
} else {
|
||||
future::ready(true)
|
||||
}
|
||||
});
|
||||
|
||||
log::info!("Listening on https://{}:{}", domain.as_str(), addr.port());
|
||||
let incoming = hyper::server::accept::from_stream(incoming);
|
||||
|
||||
let server = hyper::Server::builder(incoming).serve(make_service);
|
||||
log::info!("Listening on https://{}:{}", domain.as_str(), addr.port());
|
||||
|
||||
let graceful = server.with_graceful_shutdown(async {
|
||||
canceller.cancelled().await;
|
||||
});
|
||||
let server = hyper::Server::builder(incoming).serve(make_service);
|
||||
|
||||
if let Err(e) = graceful.await {
|
||||
panic!("server error: {}", e);
|
||||
};
|
||||
})
|
||||
let graceful = server.with_graceful_shutdown(async {
|
||||
canceller.cancelled().await;
|
||||
});
|
||||
|
||||
graceful.await.or_unexpected()
|
||||
}
|
||||
|
||||
pub fn cert_refresher(
|
||||
pub async fn cert_refresher(
|
||||
domain_manager: sync::Arc<dyn domain::manager::Manager>,
|
||||
canceller: CancellationToken,
|
||||
http_domain: domain::Name,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
use tokio::time;
|
||||
) -> Result<(), unexpected::Error> {
|
||||
use tokio::time;
|
||||
|
||||
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
|
||||
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = interval.tick() => (),
|
||||
_ = canceller.cancelled() => return,
|
||||
}
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = interval.tick() => (),
|
||||
_ = canceller.cancelled() => return Ok(()),
|
||||
}
|
||||
|
||||
_ = domain_manager
|
||||
.sync_cert(http_domain.clone())
|
||||
_ = domain_manager
|
||||
.sync_cert(http_domain.clone())
|
||||
.await
|
||||
.inspect_err(|err| {
|
||||
log::error!(
|
||||
"Error while getting cert for {}: {err}",
|
||||
http_domain.as_str()
|
||||
)
|
||||
});
|
||||
|
||||
let domains_iter = domain_manager.all_domains();
|
||||
|
||||
if let Err(err) = domains_iter {
|
||||
log::error!("Got error calling all_domains: {err}");
|
||||
continue;
|
||||
}
|
||||
|
||||
for domain in domains_iter.unwrap().into_iter() {
|
||||
let _ = domain_manager
|
||||
.sync_cert(domain.clone())
|
||||
.await
|
||||
.inspect_err(|err| {
|
||||
log::error!(
|
||||
"Error while getting cert for {}: {err}",
|
||||
http_domain.as_str()
|
||||
)
|
||||
log::error!("Error while getting cert for {}: {err}", domain.as_str(),)
|
||||
});
|
||||
|
||||
let domains_iter = domain_manager.all_domains();
|
||||
|
||||
if let Err(err) = domains_iter {
|
||||
log::error!("Got error calling all_domains: {err}");
|
||||
continue;
|
||||
}
|
||||
|
||||
for domain in domains_iter.unwrap().into_iter() {
|
||||
let _ = domain_manager
|
||||
.sync_cert(domain.clone())
|
||||
.await
|
||||
.inspect_err(|err| {
|
||||
log::error!("Error while getting cert for {}: {err}", domain.as_str(),)
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
47
src/util.rs
47
src/util.rs
@ -1,4 +1,7 @@
|
||||
use std::{fs, io, path};
|
||||
use std::{error, fs, io, path};
|
||||
|
||||
use futures::stream::futures_unordered::FuturesUnordered;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> {
|
||||
match fs::File::open(path) {
|
||||
@ -9,3 +12,45 @@ pub fn open_file(path: &path::Path) -> io::Result<Option<fs::File>> {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TaskSet<E>
|
||||
where
|
||||
E: error::Error + Send + 'static,
|
||||
{
|
||||
canceller: CancellationToken,
|
||||
wait_group: FuturesUnordered<tokio::task::JoinHandle<Result<(), E>>>,
|
||||
}
|
||||
|
||||
impl<E> TaskSet<E>
|
||||
where
|
||||
E: error::Error + Send + 'static,
|
||||
{
|
||||
pub fn new() -> TaskSet<E> {
|
||||
TaskSet {
|
||||
canceller: CancellationToken::new(),
|
||||
wait_group: FuturesUnordered::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spawn<F, Fut>(&self, mut f: F)
|
||||
where
|
||||
Fut: futures::Future<Output = Result<(), E>> + Send + 'static,
|
||||
F: FnMut(CancellationToken) -> Fut,
|
||||
{
|
||||
let canceller = self.canceller.clone();
|
||||
let handle = tokio::spawn(f(canceller));
|
||||
self.wait_group.push(handle);
|
||||
}
|
||||
|
||||
pub async fn stop(self) -> Vec<E> {
|
||||
self.canceller.cancel();
|
||||
|
||||
let mut res = Vec::new();
|
||||
for f in self.wait_group {
|
||||
if let Err(err) = f.await.expect("task failed") {
|
||||
res.push(err);
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user