diff --git a/src/lib.rs b/src/lib.rs index 884cf00..1e3d3e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(result_option_inspect)] #![feature(iterator_try_collect)] pub mod domain; diff --git a/src/main.rs b/src/main.rs index 8d724a5..e844a26 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,10 @@ -#![feature(result_option_inspect)] - use clap::Parser; -use futures::stream::futures_unordered::FuturesUnordered; use futures::stream::StreamExt; use signal_hook_tokio::Signals; -use tokio::select; -use tokio::time; -use std::convert::Infallible; use std::net::SocketAddr; use std::str::FromStr; -use std::{future, path, sync}; - -use domiply::domain::manager::Manager; +use std::{path, sync}; #[derive(Parser, Debug)] #[command(version)] @@ -86,7 +78,6 @@ async fn main() { ) .init(); - let mut wait_group = FuturesUnordered::new(); let canceller = tokio_util::sync::CancellationToken::new(); { @@ -156,173 +147,24 @@ async fn main() { let domain_manager = sync::Arc::new(domain_manager); - let http_service = domiply::service::http::new( - domain_manager.clone(), - config.domain_checker_target_a, - config.passphrase, - config.http_domain.clone(), - ); - - let http_service = sync::Arc::new(http_service); - - wait_group.push({ - let http_domain = config.http_domain.clone(); - let canceller = canceller.clone(); - let service = http_service.clone(); - - let make_service = hyper::service::make_service_fn(move |_| { - let service = service.clone(); - - // Create a `Service` for responding to the request. - let service = hyper::service::service_fn(move |req| { - domiply::service::http::handle_request(service.clone(), req) - }); - - // Return the service to hyper. - async move { Ok::<_, Infallible>(service) } - }); - - tokio::spawn(async move { - let addr = config.http_listen_addr; - - log::info!( - "Listening on http://{}:{}", - http_domain.as_str(), - addr.port() - ); - let server = hyper::Server::bind(&addr).serve(make_service); - - let graceful = server.with_graceful_shutdown(async { - canceller.cancelled().await; - }); - - if let Err(e) = graceful.await { - panic!("server error: {}", e); - }; - }) - }); - - if let Some(https_params) = https_params { - // Periodically refresh all domain certs, including the http_domain passed in the Cli opts - wait_group.push({ - let domain_manager = domain_manager.clone(); - let http_domain = config.http_domain.clone(); - let canceller = canceller.clone(); - - tokio::spawn(async move { - let mut interval = time::interval(time::Duration::from_secs(60 * 60)); - - loop { - select! { - _ = interval.tick() => (), - _ = canceller.cancelled() => return, - } - - _ = domain_manager - .sync_cert(http_domain.clone()) - .await - .inspect_err(|err| { - log::error!( - "Error while getting cert for {}: {err}", - http_domain.as_str() - ) - }); - - let domains_iter = domain_manager.all_domains(); - - if let Err(err) = domains_iter { - log::error!("Got error calling all_domains: {err}"); - continue; - } - - for domain in domains_iter.unwrap().into_iter() { - let _ = domain_manager - .sync_cert(domain.clone()) - .await - .inspect_err(|err| { - log::error!( - "Error while getting cert for {}: {err}", - domain.as_str(), - ) - }); - } - } - }) - }); - - // HTTPS server - wait_group.push({ - let http_domain = config.http_domain.clone(); - let canceller = canceller.clone(); - let service = http_service.clone(); - - let make_service = hyper::service::make_service_fn(move |_| { - let service = service.clone(); - - // Create a `Service` for responding to the request. - let service = hyper::service::service_fn(move |req| { - domiply::service::http::handle_request(service.clone(), req) - }); - - // Return the service to hyper. - async move { Ok::<_, Infallible>(service) } - }); - - tokio::spawn(async move { - let cert_resolver = - domiply::domain::acme::resolver::new(https_params.domain_acme_store); - let canceller = canceller.clone(); - - let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new( - rustls::server::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_cert_resolver(cert_resolver), - ) - .into(); - - let addr = https_params.https_listen_addr; - let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr) - .expect("https listen socket creation failed"); - - let incoming = - tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| { - if let Err(err) = conn { - log::error!("Error accepting TLS connection: {:?}", err); - future::ready(false) - } else { - future::ready(true) - } - }); - - let incoming = hyper::server::accept::from_stream(incoming); - - log::info!( - "Listening on https://{}:{}", - http_domain.as_str(), - addr.port() - ); - - let server = hyper::Server::builder(incoming).serve(make_service); - - let graceful = server.with_graceful_shutdown(async { - canceller.cancelled().await; - }); - - if let Err(e) = graceful.await { - panic!("server error: {}", e); - }; - }) - }) + { + let http_service = domiply::service::http::new( + domain_manager.clone(), + config.domain_checker_target_a, + config.passphrase, + config.http_listen_addr.clone(), + config.http_domain.clone(), + https_params.map(|p| domiply::service::http::HTTPSParams { + listen_addr: p.https_listen_addr, + cert_resolver: domiply::domain::acme::resolver::new(p.domain_acme_store), + }), + ); + + canceller.cancelled().await; + + sync::Arc::into_inner(http_service).unwrap().stop().await; } - while wait_group.next().await.is_some() {} - - // TODO this is currently required so that we can be sure domain_manager is no longer used by - // anything else, and the into_inner below works. It would be great if service could accept a - // ref to domain_manager instead, and then maybe this wouldn't be needed? - drop(http_service); - sync::Arc::into_inner(domain_manager) .unwrap() .stop() diff --git a/src/service/http.rs b/src/service/http.rs index 8c2160f..8ea2692 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -1,39 +1,85 @@ +mod tasks; mod tpl; +use futures::stream::futures_unordered::FuturesUnordered; use hyper::{Body, Method, Request, Response}; use serde::{Deserialize, Serialize}; +use tokio_util::sync::CancellationToken; use std::convert::Infallible; -use std::future::Future; -use std::net; use std::str::FromStr; -use std::sync; +use std::{future, net, sync}; use crate::{domain, origin, service}; type SvcResponse = Result, String>; -#[derive(Clone)] pub struct Service { domain_manager: sync::Arc, target_a: net::Ipv4Addr, passphrase: String, http_domain: domain::Name, handlebars: handlebars::Handlebars<'static>, + + canceller: CancellationToken, + wait_group: FuturesUnordered>, +} + +pub struct HTTPSParams { + pub listen_addr: net::SocketAddr, + pub cert_resolver: sync::Arc, } pub fn new( domain_manager: sync::Arc, target_a: net::Ipv4Addr, passphrase: String, + http_listen_addr: net::SocketAddr, http_domain: domain::Name, -) -> Service { - Service { - domain_manager, + https_params: Option, +) -> sync::Arc { + let service = sync::Arc::new(Service { + domain_manager: domain_manager.clone(), target_a, passphrase, - http_domain, + http_domain: http_domain.clone(), handlebars: tpl::get(), + canceller: CancellationToken::new(), + wait_group: FuturesUnordered::new(), + }); + + service.wait_group.push(tasks::listen_http( + service.clone(), + service.canceller.clone(), + http_listen_addr, + http_domain.clone(), + )); + + if let Some(https_params) = https_params { + service.wait_group.push(tasks::listen_https( + service.clone(), + service.canceller.clone(), + https_params.cert_resolver, + https_params.listen_addr, + http_domain.clone(), + )); + + service.wait_group.push(tasks::cert_refresher( + domain_manager, + service.canceller.clone(), + http_domain, + )) + } + + return service; +} + +impl Service { + pub async fn stop(self) { + self.canceller.cancel(); + for f in self.wait_group { + f.await.expect("task failed"); + } } } @@ -162,7 +208,7 @@ impl<'svc> Service { where In: Deserialize<'a>, F: FnOnce(In) -> Out, - Out: Future, + Out: future::Future, { let query = req.uri().query().unwrap_or(""); match serde_urlencoded::from_str::(query) { diff --git a/src/service/http/tasks.rs b/src/service/http/tasks.rs new file mode 100644 index 0000000..29750ce --- /dev/null +++ b/src/service/http/tasks.rs @@ -0,0 +1,140 @@ +use crate::{domain, service}; + +use std::{convert, future, net, sync}; + +use futures::StreamExt; +use tokio_util::sync::CancellationToken; + +pub fn listen_http( + service: sync::Arc, + canceller: CancellationToken, + addr: net::SocketAddr, + domain: domain::Name, +) -> tokio::task::JoinHandle<()> { + let make_service = hyper::service::make_service_fn(move |_| { + let service = service.clone(); + + // Create a `Service` for responding to the request. + let hyper_service = hyper::service::service_fn(move |req| { + service::http::handle_request(service.clone(), req) + }); + + // Return the service to hyper. + async move { Ok::<_, convert::Infallible>(hyper_service) } + }); + + tokio::spawn(async move { + log::info!("Listening on http://{}:{}", domain.as_str(), addr.port()); + let server = hyper::Server::bind(&addr).serve(make_service); + + let graceful = server.with_graceful_shutdown(async { + canceller.cancelled().await; + }); + + if let Err(e) = graceful.await { + panic!("server error: {}", e); + }; + }) +} + +pub fn listen_https( + service: sync::Arc, + canceller: CancellationToken, + cert_resolver: sync::Arc, + addr: net::SocketAddr, + domain: domain::Name, +) -> tokio::task::JoinHandle<()> { + let make_service = hyper::service::make_service_fn(move |_| { + let service = service.clone(); + + // Create a `Service` for responding to the request. + let hyper_service = hyper::service::service_fn(move |req| { + service::http::handle_request(service.clone(), req) + }); + + // Return the service to hyper. + async move { Ok::<_, convert::Infallible>(hyper_service) } + }); + + tokio::spawn(async move { + let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new( + rustls::server::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(cert_resolver), + ) + .into(); + + let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr) + .expect("https listen socket creation failed"); + + let incoming = + tls_listener::TlsListener::new(server_config, addr_incoming).filter(|conn| { + if let Err(err) = conn { + log::error!("Error accepting TLS connection: {:?}", err); + future::ready(false) + } else { + future::ready(true) + } + }); + + let incoming = hyper::server::accept::from_stream(incoming); + + log::info!("Listening on https://{}:{}", domain.as_str(), addr.port()); + + let server = hyper::Server::builder(incoming).serve(make_service); + + let graceful = server.with_graceful_shutdown(async { + canceller.cancelled().await; + }); + + if let Err(e) = graceful.await { + panic!("server error: {}", e); + }; + }) +} + +pub fn cert_refresher( + domain_manager: sync::Arc, + canceller: CancellationToken, + http_domain: domain::Name, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + use tokio::time; + + let mut interval = time::interval(time::Duration::from_secs(60 * 60)); + + loop { + tokio::select! { + _ = interval.tick() => (), + _ = canceller.cancelled() => return, + } + + _ = domain_manager + .sync_cert(http_domain.clone()) + .await + .inspect_err(|err| { + log::error!( + "Error while getting cert for {}: {err}", + http_domain.as_str() + ) + }); + + let domains_iter = domain_manager.all_domains(); + + if let Err(err) = domains_iter { + log::error!("Got error calling all_domains: {err}"); + continue; + } + + for domain in domains_iter.unwrap().into_iter() { + let _ = domain_manager + .sync_cert(domain.clone()) + .await + .inspect_err(|err| { + log::error!("Error while getting cert for {}: {err}", domain.as_str(),) + }); + } + } + }) +}