Moved service tasks internally, main crashes on shutdown though

This commit is contained in:
Brian Picciano 2023-06-18 15:57:51 +02:00
parent 7ea97b2617
commit 506037dcd0
4 changed files with 211 additions and 182 deletions

View File

@ -1,3 +1,4 @@
#![feature(result_option_inspect)]
#![feature(iterator_try_collect)] #![feature(iterator_try_collect)]
pub mod domain; pub mod domain;

View File

@ -1,18 +1,10 @@
#![feature(result_option_inspect)]
use clap::Parser; use clap::Parser;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use signal_hook_tokio::Signals; use signal_hook_tokio::Signals;
use tokio::select;
use tokio::time;
use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::str::FromStr; use std::str::FromStr;
use std::{future, path, sync}; use std::{path, sync};
use domiply::domain::manager::Manager;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version)] #[command(version)]
@ -86,7 +78,6 @@ async fn main() {
) )
.init(); .init();
let mut wait_group = FuturesUnordered::new();
let canceller = tokio_util::sync::CancellationToken::new(); let canceller = tokio_util::sync::CancellationToken::new();
{ {
@ -156,173 +147,24 @@ async fn main() {
let domain_manager = sync::Arc::new(domain_manager); let domain_manager = sync::Arc::new(domain_manager);
{
let http_service = domiply::service::http::new( let http_service = domiply::service::http::new(
domain_manager.clone(), domain_manager.clone(),
config.domain_checker_target_a, config.domain_checker_target_a,
config.passphrase, config.passphrase,
config.http_listen_addr.clone(),
config.http_domain.clone(), config.http_domain.clone(),
https_params.map(|p| domiply::service::http::HTTPSParams {
listen_addr: p.https_listen_addr,
cert_resolver: domiply::domain::acme::resolver::new(p.domain_acme_store),
}),
); );
let http_service = sync::Arc::new(http_service);
wait_group.push({
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
let service = http_service.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::http::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio::spawn(async move {
let addr = config.http_listen_addr;
log::info!(
"Listening on http://{}:{}",
http_domain.as_str(),
addr.port()
);
let server = hyper::Server::bind(&addr).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await; canceller.cancelled().await;
});
if let Err(e) = graceful.await { sync::Arc::into_inner(http_service).unwrap().stop().await;
panic!("server error: {}", e);
};
})
});
if let Some(https_params) = https_params {
// Periodically refresh all domain certs, including the http_domain passed in the Cli opts
wait_group.push({
let domain_manager = domain_manager.clone();
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
tokio::spawn(async move {
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
loop {
select! {
_ = interval.tick() => (),
_ = canceller.cancelled() => return,
} }
_ = domain_manager
.sync_cert(http_domain.clone())
.await
.inspect_err(|err| {
log::error!(
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = domain_manager.all_domains();
if let Err(err) = domains_iter {
log::error!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
let _ = domain_manager
.sync_cert(domain.clone())
.await
.inspect_err(|err| {
log::error!(
"Error while getting cert for {}: {err}",
domain.as_str(),
)
});
}
}
})
});
// HTTPS server
wait_group.push({
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
let service = http_service.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let service = hyper::service::service_fn(move |req| {
domiply::service::http::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, Infallible>(service) }
});
tokio::spawn(async move {
let cert_resolver =
domiply::domain::acme::resolver::new(https_params.domain_acme_store);
let canceller = canceller.clone();
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
rustls::server::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(cert_resolver),
)
.into();
let addr = https_params.https_listen_addr;
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
.expect("https listen socket creation failed");
let incoming =
tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
if let Err(err) = conn {
log::error!("Error accepting TLS connection: {:?}", err);
future::ready(false)
} else {
future::ready(true)
}
});
let incoming = hyper::server::accept::from_stream(incoming);
log::info!(
"Listening on https://{}:{}",
http_domain.as_str(),
addr.port()
);
let server = hyper::Server::builder(incoming).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
};
})
})
}
while wait_group.next().await.is_some() {}
// TODO this is currently required so that we can be sure domain_manager is no longer used by
// anything else, and the into_inner below works. It would be great if service could accept a
// ref to domain_manager instead, and then maybe this wouldn't be needed?
drop(http_service);
sync::Arc::into_inner(domain_manager) sync::Arc::into_inner(domain_manager)
.unwrap() .unwrap()
.stop() .stop()

View File

@ -1,39 +1,85 @@
mod tasks;
mod tpl; mod tpl;
use futures::stream::futures_unordered::FuturesUnordered;
use hyper::{Body, Method, Request, Response}; use hyper::{Body, Method, Request, Response};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use std::convert::Infallible; use std::convert::Infallible;
use std::future::Future;
use std::net;
use std::str::FromStr; use std::str::FromStr;
use std::sync; use std::{future, net, sync};
use crate::{domain, origin, service}; use crate::{domain, origin, service};
type SvcResponse = Result<Response<hyper::body::Body>, String>; type SvcResponse = Result<Response<hyper::body::Body>, String>;
#[derive(Clone)]
pub struct Service { pub struct Service {
domain_manager: sync::Arc<dyn domain::manager::Manager>, domain_manager: sync::Arc<dyn domain::manager::Manager>,
target_a: net::Ipv4Addr, target_a: net::Ipv4Addr,
passphrase: String, passphrase: String,
http_domain: domain::Name, http_domain: domain::Name,
handlebars: handlebars::Handlebars<'static>, handlebars: handlebars::Handlebars<'static>,
canceller: CancellationToken,
wait_group: FuturesUnordered<tokio::task::JoinHandle<()>>,
}
pub struct HTTPSParams {
pub listen_addr: net::SocketAddr,
pub cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
} }
pub fn new( pub fn new(
domain_manager: sync::Arc<dyn domain::manager::Manager>, domain_manager: sync::Arc<dyn domain::manager::Manager>,
target_a: net::Ipv4Addr, target_a: net::Ipv4Addr,
passphrase: String, passphrase: String,
http_listen_addr: net::SocketAddr,
http_domain: domain::Name, http_domain: domain::Name,
) -> Service { https_params: Option<HTTPSParams>,
Service { ) -> sync::Arc<Service> {
domain_manager, let service = sync::Arc::new(Service {
domain_manager: domain_manager.clone(),
target_a, target_a,
passphrase, passphrase,
http_domain, http_domain: http_domain.clone(),
handlebars: tpl::get(), handlebars: tpl::get(),
canceller: CancellationToken::new(),
wait_group: FuturesUnordered::new(),
});
service.wait_group.push(tasks::listen_http(
service.clone(),
service.canceller.clone(),
http_listen_addr,
http_domain.clone(),
));
if let Some(https_params) = https_params {
service.wait_group.push(tasks::listen_https(
service.clone(),
service.canceller.clone(),
https_params.cert_resolver,
https_params.listen_addr,
http_domain.clone(),
));
service.wait_group.push(tasks::cert_refresher(
domain_manager,
service.canceller.clone(),
http_domain,
))
}
return service;
}
impl Service {
pub async fn stop(self) {
self.canceller.cancel();
for f in self.wait_group {
f.await.expect("task failed");
}
} }
} }
@ -162,7 +208,7 @@ impl<'svc> Service {
where where
In: Deserialize<'a>, In: Deserialize<'a>,
F: FnOnce(In) -> Out, F: FnOnce(In) -> Out,
Out: Future<Output = SvcResponse>, Out: future::Future<Output = SvcResponse>,
{ {
let query = req.uri().query().unwrap_or(""); let query = req.uri().query().unwrap_or("");
match serde_urlencoded::from_str::<In>(query) { match serde_urlencoded::from_str::<In>(query) {

140
src/service/http/tasks.rs Normal file
View File

@ -0,0 +1,140 @@
use crate::{domain, service};
use std::{convert, future, net, sync};
use futures::StreamExt;
use tokio_util::sync::CancellationToken;
pub fn listen_http(
service: sync::Arc<service::http::Service>,
canceller: CancellationToken,
addr: net::SocketAddr,
domain: domain::Name,
) -> tokio::task::JoinHandle<()> {
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
service::http::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) }
});
tokio::spawn(async move {
log::info!("Listening on http://{}:{}", domain.as_str(), addr.port());
let server = hyper::Server::bind(&addr).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
};
})
}
pub fn listen_https(
service: sync::Arc<service::http::Service>,
canceller: CancellationToken,
cert_resolver: sync::Arc<dyn rustls::server::ResolvesServerCert>,
addr: net::SocketAddr,
domain: domain::Name,
) -> tokio::task::JoinHandle<()> {
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
// Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| {
service::http::handle_request(service.clone(), req)
});
// Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) }
});
tokio::spawn(async move {
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
rustls::server::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(cert_resolver),
)
.into();
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
.expect("https listen socket creation failed");
let incoming =
tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| {
if let Err(err) = conn {
log::error!("Error accepting TLS connection: {:?}", err);
future::ready(false)
} else {
future::ready(true)
}
});
let incoming = hyper::server::accept::from_stream(incoming);
log::info!("Listening on https://{}:{}", domain.as_str(), addr.port());
let server = hyper::Server::builder(incoming).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
};
})
}
pub fn cert_refresher(
domain_manager: sync::Arc<dyn domain::manager::Manager>,
canceller: CancellationToken,
http_domain: domain::Name,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
use tokio::time;
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
loop {
tokio::select! {
_ = interval.tick() => (),
_ = canceller.cancelled() => return,
}
_ = domain_manager
.sync_cert(http_domain.clone())
.await
.inspect_err(|err| {
log::error!(
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = domain_manager.all_domains();
if let Err(err) = domains_iter {
log::error!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
let _ = domain_manager
.sync_cert(domain.clone())
.await
.inspect_err(|err| {
log::error!("Error while getting cert for {}: {err}", domain.as_str(),)
});
}
}
})
}