diff --git a/Cargo.lock b/Cargo.lock index 1f9b7e6..fe0c74a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -470,7 +470,9 @@ dependencies = [ "signal-hook-tokio", "tempdir", "thiserror", + "tls-listener", "tokio", + "tokio-rustls 0.24.0", "tokio-util", "trust-dns-client", ] @@ -1500,7 +1502,7 @@ dependencies = [ "hyper", "rustls 0.20.8", "tokio", - "tokio-rustls", + "tokio-rustls 0.23.4", ] [[package]] @@ -2322,7 +2324,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "tokio", - "tokio-rustls", + "tokio-rustls 0.23.4", "tower-service", "trust-dns-resolver", "url", @@ -2757,6 +2759,20 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tls-listener" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81294c017957a1a69794f506723519255879e15a870507faf45dfed288b763dd" +dependencies = [ + "futures-util", + "hyper", + "pin-project-lite", + "thiserror", + "tokio", + "tokio-rustls 0.24.0", +] + [[package]] name = "tokio" version = "1.28.1" @@ -2798,6 +2814,16 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-rustls" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" +dependencies = [ + "rustls 0.21.1", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.8" diff --git a/Cargo.toml b/Cargo.toml index 6fa46f1..6c5c1c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ clap = { version = "4.2.7", features = ["derive", "env"] } handlebars = { version = "4.3.7", features = [ "rust-embed" ]} rust-embed = "6.6.1" mime_guess = "2.0.4" -hyper = { version = "0.14.26", features = [ "server" ]} +hyper = { version = "0.14.26", features = [ "server", "stream" ]} http = "0.2.9" serde_urlencoded = "0.7.1" tokio-util = "0.7.8" @@ -37,3 +37,5 @@ openssl = "0.10.52" rustls = "0.21.1" pem = "2.0.1" serde_with = "3.0.0" +tls-listener = { version = "0.7.0", features = [ "rustls", "hyper-h1" ]} +tokio-rustls = "0.24.0" diff --git a/TODO b/TODO index e87a43e..fa0ff8d 100644 --- a/TODO +++ b/TODO @@ -1,11 +1,3 @@ -- make acme store implement https://docs.rs/rustls/latest/rustls/server/trait.ResolvesServerCert.html - -- pass that into https://docs.rs/rustls/latest/rustls/struct.ConfigBuilder.html# - -- turn that into a TlsAcceptor (From is implemented here: - https://docs.rs/tokio-rustls/latest/tokio_rustls/struct.TlsAcceptor.html#impl-From%3CArc%3CServerConfig%3E%3E-for-TlsAcceptor) - -- use tls-listener crate to wrap hyper accepter: https://github.com/tmccombs/tls-listener/blob/main/examples/http.rs#L24 - - https://github.com/tmccombs/tls-listener/blob/main/examples/tls_config/mod.rs - - logging +- expect statements (pretend it's "expected", not "expect") +- map_unexpected annotation string diff --git a/src/domain/acme/store.rs b/src/domain/acme/store.rs index e58d197..8589dfe 100644 --- a/src/domain/acme/store.rs +++ b/src/domain/acme/store.rs @@ -72,19 +72,24 @@ struct FSStore { dir_path: path::PathBuf, } -pub fn new(dir_path: &path::Path) -> Result { +#[derive(Clone)] +struct BoxedFSStore(sync::Arc); + +pub fn new( + dir_path: &path::Path, +) -> Result { fs::create_dir_all(dir_path).map_unexpected()?; fs::create_dir_all(dir_path.join("http01_challenge_keys")).map_unexpected()?; fs::create_dir_all(dir_path.join("certificates")).map_unexpected()?; - Ok(sync::Arc::new(FSStore { + Ok(BoxedFSStore(sync::Arc::new(FSStore { dir_path: dir_path.into(), - })) + }))) } -impl FSStore { +impl BoxedFSStore { fn account_key_path(&self) -> path::PathBuf { - self.dir_path.join("account.key") + self.0.dir_path.join("account.key") } fn http01_challenge_key_path(&self, token: &str) -> path::PathBuf { @@ -94,20 +99,20 @@ impl FSStore { .expect("token successfully hashed"); let n = h.finalize().encode_hex::(); - self.dir_path.join("http01_challenge_keys").join(n) + self.0.dir_path.join("http01_challenge_keys").join(n) } fn certificate_path(&self, domain: &str) -> path::PathBuf { - self.dir_path - .join("certificates") - .join(domain) - .with_extension("json") + let mut domain = domain.to_string(); + domain.push_str(".json"); + + self.0.dir_path.join("certificates").join(domain) } } -impl BoxedStore for sync::Arc {} +impl BoxedStore for BoxedFSStore {} -impl Store for sync::Arc { +impl Store for BoxedFSStore { fn set_account_key(&self, k: &PrivateKey) -> Result<(), error::Unexpected> { let mut file = fs::File::create(self.account_key_path()).map_unexpected()?; file.write_all(k.to_string().as_bytes()).map_unexpected()?; @@ -184,6 +189,39 @@ impl Store for sync::Arc { } } +impl rustls::server::ResolvesServerCert for BoxedFSStore { + fn resolve( + &self, + client_hello: rustls::server::ClientHello<'_>, + ) -> Option> { + let domain = if let Some(domain) = client_hello.server_name() { + domain + } else { + return None; + }; + + match self.get_certificate(domain) { + Err(GetCertificateError::NotFound) => Ok(None), + Err(GetCertificateError::Unexpected(err)) => Err(err), + Ok((key, cert)) => { + match rustls::sign::any_supported_type(&key.into()).map_unexpected() { + Err(err) => Err(err), + Ok(key) => Ok(Some(sync::Arc::new(rustls::sign::CertifiedKey { + cert: cert.into_iter().map(|cert| cert.into()).collect(), + key: key, + ocsp: None, + sct_list: None, + }))), + } + } + } + .unwrap_or_else(|err| { + println!("Unexpected error getting cert for domain {domain}: {err}"); + None + }) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/domain/checker.rs b/src/domain/checker.rs index 6423eba..9fdd6c2 100644 --- a/src/domain/checker.rs +++ b/src/domain/checker.rs @@ -1,6 +1,5 @@ use std::net; use std::str::FromStr; -use std::sync; use crate::error::MapUnexpected; use crate::{domain, error}; @@ -37,8 +36,7 @@ pub struct DNSChecker { client: tokio::sync::Mutex, } -pub fn new( - tokio_runtime: sync::Arc, +pub async fn new( target_a: net::Ipv4Addr, resolver_addr: &str, ) -> Result { @@ -48,11 +46,9 @@ pub fn new( let stream = udp::UdpClientStream::::new(resolver_addr); - let (client, bg) = tokio_runtime - .block_on(async { AsyncClient::connect(stream).await }) - .map_unexpected()?; + let (client, bg) = AsyncClient::connect(stream).await.map_unexpected()?; - tokio_runtime.spawn(bg); + tokio::spawn(bg); Ok(DNSChecker { target_a, diff --git a/src/main.rs b/src/main.rs index 9598638..0fc6ba3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -56,23 +56,17 @@ struct Cli { domain_acme_contact_email: Option, } -fn main() { +#[tokio::main] +async fn main() { let config = Cli::parse(); let mut wait_group = FuturesUnordered::new(); - - let tokio_runtime = std::sync::Arc::new( - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(), - ); - - let canceller = tokio_runtime.block_on(async { tokio_util::sync::CancellationToken::new() }); + let canceller = tokio_util::sync::CancellationToken::new(); { let canceller = canceller.clone(); - tokio_runtime.spawn(async move { + + tokio::spawn(async move { let mut signals = Signals::new(signal_hook::consts::TERM_SIGNALS).expect("initialized signals"); @@ -92,16 +86,16 @@ fn main() { .expect("git origin store initialized"); let domain_checker = domiply::domain::checker::new( - tokio_runtime.clone(), config.domain_checker_target_a, &config.domain_checker_resolver_addr, ) + .await .expect("domain checker initialized"); let domain_config_store = domiply::domain::config::new(&config.domain_config_store_dir_path) .expect("domain config store initialized"); - let domain_acme_manager = config.https_listen_addr.and_then(|_addr| { + let (domain_acme_store, domain_acme_manager) = if config.https_listen_addr.is_some() { let domain_acme_store = domiply::domain::acme::store::new(&config.domain_acme_store_dir_path) .expect("domain acme store initialized"); @@ -110,14 +104,17 @@ fn main() { // settings. let domain_acme_contact_email = config.domain_acme_contact_email.unwrap(); - let domain_acme_manager = tokio_runtime.block_on(async { - domiply::domain::acme::manager::new(domain_acme_store, &domain_acme_contact_email) - .await - .expect("domain acme manager initialized") - }); + let domain_acme_manager = domiply::domain::acme::manager::new( + domain_acme_store.clone(), + &domain_acme_contact_email, + ) + .await + .expect("domain acme manager initialized"); - Some(domain_acme_manager) - }); + (Some(domain_acme_store), Some(domain_acme_manager)) + } else { + (None, None) + }; let manager = domiply::domain::manager::new( origin_store, @@ -130,7 +127,7 @@ fn main() { let manager = manager.clone(); let canceller = canceller.clone(); - tokio_runtime.spawn(async move { + tokio::spawn(async move { let mut interval = time::interval(time::Duration::from_secs(20 * 60)); loop { @@ -166,8 +163,12 @@ fn main() { let service = sync::Arc::new(service); - let make_service = - hyper::service::make_service_fn(move |_conn: &hyper::server::conn::AddrStream| { + wait_group.push({ + let http_domain = config.http_domain.clone(); + let canceller = canceller.clone(); + let service = service.clone(); + + let make_service = hyper::service::make_service_fn(move |_| { let service = service.clone(); // Create a `Service` for responding to the request. @@ -179,11 +180,7 @@ fn main() { async move { Ok::<_, Infallible>(service) } }); - wait_group.push({ - let http_domain = config.http_domain.clone(); - let canceller = canceller.clone(); - - tokio_runtime.spawn(async move { + tokio::spawn(async move { let addr = config.http_listen_addr; println!( @@ -203,61 +200,119 @@ fn main() { }) }); - // if there's an acme manager then it means that https is enabled, and we should ensure that - // the http domain for domiply itself has a valid certificate. - if let Some(domain_acme_manager) = domain_acme_manager { - let manager = manager.clone(); - let canceller = canceller.clone(); - let http_domain = config.http_domain.clone(); + // if there's an acme manager then it means that https is enabled + if let (Some(domain_acme_store), Some(domain_acme_manager)) = + (domain_acme_store, domain_acme_manager) + { + // Periodically refresh all domain certs, including the http_domain passed in the Cli opts + wait_group.push({ + let manager = manager.clone(); + let http_domain = config.http_domain.clone(); + let canceller = canceller.clone(); - // Periodically refresh all domain certs - wait_group.push(tokio_runtime.spawn(async move { - let mut interval = time::interval(time::Duration::from_secs(60 * 60)); + tokio::spawn(async move { + let mut interval = time::interval(time::Duration::from_secs(60 * 60)); - loop { - select! { - _ = interval.tick() => (), - _ = canceller.cancelled() => return, + loop { + select! { + _ = interval.tick() => (), + _ = canceller.cancelled() => return, + } + + _ = domain_acme_manager + .sync_domain(http_domain.clone()) + .await + .inspect_err(|err| { + println!( + "Error while getting cert for {}: {err}", + http_domain.as_str() + ) + }); + + let domains_iter = manager.all_domains(); + + if let Err(err) = domains_iter { + println!("Got error calling all_domains: {err}"); + continue; + } + + for domain in domains_iter.unwrap().into_iter() { + match domain { + Ok(domain) => { + let _ = domain_acme_manager + .sync_domain(domain.clone()) + .await + .inspect_err(|err| { + println!( + "Error while getting cert for {}: {err}", + domain.as_str(), + ) + }); + } + Err(err) => println!("Error iterating through domains: {err}"), + }; + } } + }) + }); - _ = domain_acme_manager - .sync_domain(http_domain.clone()) - .await - .inspect_err(|err| { - println!( - "Error while getting cert for {}: {err}", - http_domain.as_str() - ) - }); + // HTTPS server + wait_group.push({ + let http_domain = config.http_domain.clone(); + let canceller = canceller.clone(); + let service = service.clone(); - let domains_iter = manager.all_domains(); + let make_service = hyper::service::make_service_fn(move |_| { + let service = service.clone(); - if let Err(err) = domains_iter { - println!("Got error calling all_domains: {err}"); - continue; - } + // Create a `Service` for responding to the request. + let service = hyper::service::service_fn(move |req| { + domiply::service::handle_request(service.clone(), req) + }); - for domain in domains_iter.unwrap().into_iter() { - match domain { - Ok(domain) => { - let _ = domain_acme_manager - .sync_domain(domain.clone()) - .await - .inspect_err(|err| { - println!( - "Error while getting cert for {}: {err}", - domain.as_str(), - ) - }); - } - Err(err) => println!("Error iterating through domains: {err}"), - }; - } - } - })); + // Return the service to hyper. + async move { Ok::<_, Infallible>(service) } + }); + + tokio::spawn(async move { + let canceller = canceller.clone(); + let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new( + rustls::server::ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_cert_resolver(sync::Arc::from(domain_acme_store)), + ) + .into(); + + let addr = config.https_listen_addr.unwrap(); + let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr) + .expect("https listen socket created"); + + let incoming = tls_listener::TlsListener::new(server_config, addr_incoming); + + println!( + "Listening on https://{}:{}", + http_domain.as_str(), + addr.port() + ); + + let server = hyper::Server::builder(incoming).serve(make_service); + + let graceful = server.with_graceful_shutdown(async { + canceller.cancelled().await; + }); + + if let Err(e) = graceful.await { + panic!("server error: {}", e); + }; + }) + }) } - tokio_runtime.block_on(async { while let Some(_) = wait_group.next().await {} }); + while let Some(_) = wait_group.next().await {} println!("Graceful shutdown complete"); }