diff --git a/TODO b/TODO index 8eef59c..4d39ffb 100644 --- a/TODO +++ b/TODO @@ -1 +1,3 @@ -- clean up main a lot +- make domain_manager implement rusttls cert resolver +- Try to switch from Arc to Box where possible +- maybe build TaskSet into some kind of defer-like replacement diff --git a/src/main.rs b/src/main.rs index e844a26..bf3f1ab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -148,7 +148,7 @@ async fn main() { let domain_manager = sync::Arc::new(domain_manager); { - let http_service = domiply::service::http::new( + let (http_service, http_service_task_set) = domiply::service::http::new( domain_manager.clone(), config.domain_checker_target_a, config.passphrase, @@ -162,7 +162,7 @@ async fn main() { canceller.cancelled().await; - sync::Arc::into_inner(http_service).unwrap().stop().await; + domiply::service::http::stop(http_service, http_service_task_set).await; } sync::Arc::into_inner(domain_manager) diff --git a/src/service/http.rs b/src/service/http.rs index a0a44c3..6e74e96 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,16 +1,15 @@ mod tasks; mod tpl; -use futures::stream::futures_unordered::FuturesUnordered; use hyper::{Body, Method, Request, Response}; use serde::{Deserialize, Serialize}; -use tokio_util::sync::CancellationToken; use std::convert::Infallible; use std::str::FromStr; use std::{future, net, sync}; -use crate::{domain, origin, service}; +use crate::error::unexpected; +use crate::{domain, origin, service, util}; type SvcResponse = Result, String>; @@ -20,9 +19,6 @@ pub struct Service { passphrase: String, http_domain: domain::Name, handlebars: handlebars::Handlebars<'static>, - - canceller: CancellationToken, - wait_group: FuturesUnordered>, } pub struct HTTPSParams { @@ -37,50 +33,52 @@ pub fn new( http_listen_addr: net::SocketAddr, http_domain: domain::Name, https_params: Option, -) -> sync::Arc { +) -> (sync::Arc, util::TaskSet) { let service = sync::Arc::new(Service { domain_manager: domain_manager.clone(), target_a, passphrase, http_domain: http_domain.clone(), handlebars: tpl::get(), - canceller: CancellationToken::new(), - wait_group: FuturesUnordered::new(), }); - service.wait_group.push(tasks::listen_http( - service.clone(), - service.canceller.clone(), - http_listen_addr, - http_domain.clone(), - )); + let task_set = util::TaskSet::new(); + + task_set.spawn(|canceller| { + tasks::listen_http( + service.clone(), + canceller, + http_listen_addr, + http_domain.clone(), + ) + }); if let Some(https_params) = https_params { - service.wait_group.push(tasks::listen_https( - service.clone(), - service.canceller.clone(), - https_params.cert_resolver, - https_params.listen_addr, - http_domain.clone(), - )); + task_set.spawn(|canceller| { + tasks::listen_https( + service.clone(), + canceller, + https_params.cert_resolver.clone(), + https_params.listen_addr, + http_domain.clone(), + ) + }); - service.wait_group.push(tasks::cert_refresher( - domain_manager, - service.canceller.clone(), - http_domain, - )) + task_set.spawn(|canceller| { + tasks::cert_refresher(domain_manager.clone(), canceller, http_domain.clone()) + }); } - return service; + return (service, task_set); } -impl Service { - pub async fn stop(self) { - self.canceller.cancel(); - for f in self.wait_group { - f.await.expect("task failed"); - } - } +pub async fn stop(service: sync::Arc, task_set: util::TaskSet) { + task_set + .stop() + .await + .iter() + .for_each(|e| log::error!("error while shutting down http service: {e}")); + sync::Arc::into_inner(service).expect("service didn't get cleaned up"); } #[derive(Serialize)] diff --git a/src/service/http/tasks.rs b/src/service/http/tasks.rs index 9a0f136..532a92c 100644 --- a/src/service/http/tasks.rs +++ b/src/service/http/tasks.rs @@ -1,3 +1,4 @@ +use crate::error::unexpected::{self, Mappable}; use crate::{domain, service}; use std::{convert, future, net, sync}; @@ -5,138 +6,127 @@ use std::{convert, future, net, sync}; use futures::StreamExt; use tokio_util::sync::CancellationToken; -pub fn listen_http( +pub async fn listen_http( service: sync::Arc, canceller: CancellationToken, addr: net::SocketAddr, domain: domain::Name, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let make_service = hyper::service::make_service_fn(move |_| { +) -> Result<(), unexpected::Error> { + let make_service = hyper::service::make_service_fn(move |_| { + let service = service.clone(); + + // Create a `Service` for responding to the request. + let hyper_service = hyper::service::service_fn(move |req| { let service = service.clone(); - - // Create a `Service` for responding to the request. - let hyper_service = hyper::service::service_fn(move |req| { - let service = service.clone(); - async move { service.handle_request(req).await } - }); - - // Return the service to hyper. - async move { Ok::<_, convert::Infallible>(hyper_service) } + async move { service.handle_request(req).await } }); - log::info!("Listening on http://{}:{}", domain.as_str(), addr.port()); - let server = hyper::Server::bind(&addr).serve(make_service); + // Return the service to hyper. + async move { Ok::<_, convert::Infallible>(hyper_service) } + }); - let graceful = server.with_graceful_shutdown(async { - canceller.cancelled().await; - }); + log::info!("Listening on http://{}:{}", domain.as_str(), addr.port()); + let server = hyper::Server::bind(&addr).serve(make_service); - if let Err(e) = graceful.await { - panic!("server error: {}", e); - }; - }) + let graceful = server.with_graceful_shutdown(async { + canceller.cancelled().await; + }); + + graceful.await.or_unexpected() } -pub fn listen_https( +pub async fn listen_https( service: sync::Arc, canceller: CancellationToken, cert_resolver: sync::Arc, addr: net::SocketAddr, domain: domain::Name, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let make_service = hyper::service::make_service_fn(move |_| { +) -> Result<(), unexpected::Error> { + let make_service = hyper::service::make_service_fn(move |_| { + let service = service.clone(); + + // Create a `Service` for responding to the request. + let hyper_service = hyper::service::service_fn(move |req| { let service = service.clone(); - - // Create a `Service` for responding to the request. - let hyper_service = hyper::service::service_fn(move |req| { - let service = service.clone(); - async move { service.handle_request(req).await } - }); - - // Return the service to hyper. - async move { Ok::<_, convert::Infallible>(hyper_service) } + async move { service.handle_request(req).await } }); - let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new( - rustls::server::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_cert_resolver(cert_resolver), - ) - .into(); + // Return the service to hyper. + async move { Ok::<_, convert::Infallible>(hyper_service) } + }); - let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr) - .expect("https listen socket creation failed"); + let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new( + rustls::server::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(cert_resolver), + ) + .into(); - let incoming = - tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| { - if let Err(err) = conn { - log::error!("Error accepting TLS connection: {:?}", err); - future::ready(false) - } else { - future::ready(true) - } - }); + let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr) + .expect("https listen socket creation failed"); - let incoming = hyper::server::accept::from_stream(incoming); + let incoming = tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| { + if let Err(err) = conn { + log::error!("Error accepting TLS connection: {:?}", err); + future::ready(false) + } else { + future::ready(true) + } + }); - log::info!("Listening on https://{}:{}", domain.as_str(), addr.port()); + let incoming = hyper::server::accept::from_stream(incoming); - let server = hyper::Server::builder(incoming).serve(make_service); + log::info!("Listening on https://{}:{}", domain.as_str(), addr.port()); - let graceful = server.with_graceful_shutdown(async { - canceller.cancelled().await; - }); + let server = hyper::Server::builder(incoming).serve(make_service); - if let Err(e) = graceful.await { - panic!("server error: {}", e); - }; - }) + let graceful = server.with_graceful_shutdown(async { + canceller.cancelled().await; + }); + + graceful.await.or_unexpected() } -pub fn cert_refresher( +pub async fn cert_refresher( domain_manager: sync::Arc, canceller: CancellationToken, http_domain: domain::Name, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - use tokio::time; +) -> Result<(), unexpected::Error> { + use tokio::time; - let mut interval = time::interval(time::Duration::from_secs(60 * 60)); + let mut interval = time::interval(time::Duration::from_secs(60 * 60)); - loop { - tokio::select! { - _ = interval.tick() => (), - _ = canceller.cancelled() => return, - } + loop { + tokio::select! { + _ = interval.tick() => (), + _ = canceller.cancelled() => return Ok(()), + } - _ = domain_manager - .sync_cert(http_domain.clone()) + _ = domain_manager + .sync_cert(http_domain.clone()) + .await + .inspect_err(|err| { + log::error!( + "Error while getting cert for {}: {err}", + http_domain.as_str() + ) + }); + + let domains_iter = domain_manager.all_domains(); + + if let Err(err) = domains_iter { + log::error!("Got error calling all_domains: {err}"); + continue; + } + + for domain in domains_iter.unwrap().into_iter() { + let _ = domain_manager + .sync_cert(domain.clone()) .await .inspect_err(|err| { - log::error!( - "Error while getting cert for {}: {err}", - http_domain.as_str() - ) + log::error!("Error while getting cert for {}: {err}", domain.as_str(),) }); - - let domains_iter = domain_manager.all_domains(); - - if let Err(err) = domains_iter { - log::error!("Got error calling all_domains: {err}"); - continue; - } - - for domain in domains_iter.unwrap().into_iter() { - let _ = domain_manager - .sync_cert(domain.clone()) - .await - .inspect_err(|err| { - log::error!("Error while getting cert for {}: {err}", domain.as_str(),) - }); - } } - }) + } } diff --git a/src/util.rs b/src/util.rs index d0eef25..cbd2696 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,7 @@ -use std::{fs, io, path}; +use std::{error, fs, io, path}; + +use futures::stream::futures_unordered::FuturesUnordered; +use tokio_util::sync::CancellationToken; pub fn open_file(path: &path::Path) -> io::Result> { match fs::File::open(path) { @@ -9,3 +12,45 @@ pub fn open_file(path: &path::Path) -> io::Result> { }, } } + +pub struct TaskSet +where + E: error::Error + Send + 'static, +{ + canceller: CancellationToken, + wait_group: FuturesUnordered>>, +} + +impl TaskSet +where + E: error::Error + Send + 'static, +{ + pub fn new() -> TaskSet { + TaskSet { + canceller: CancellationToken::new(), + wait_group: FuturesUnordered::new(), + } + } + + pub fn spawn(&self, mut f: F) + where + Fut: futures::Future> + Send + 'static, + F: FnMut(CancellationToken) -> Fut, + { + let canceller = self.canceller.clone(); + let handle = tokio::spawn(f(canceller)); + self.wait_group.push(handle); + } + + pub async fn stop(self) -> Vec { + self.canceller.cancel(); + + let mut res = Vec::new(); + for f in self.wait_group { + if let Err(err) = f.await.expect("task failed") { + res.push(err); + } + } + res + } +}