got HTTPS fully working

This commit is contained in:
Brian Picciano 2023-05-20 14:28:02 +02:00
parent 4f98a9a244
commit e29de0d29c
6 changed files with 215 additions and 106 deletions

30
Cargo.lock generated
View File

@ -470,7 +470,9 @@ dependencies = [
"signal-hook-tokio", "signal-hook-tokio",
"tempdir", "tempdir",
"thiserror", "thiserror",
"tls-listener",
"tokio", "tokio",
"tokio-rustls 0.24.0",
"tokio-util", "tokio-util",
"trust-dns-client", "trust-dns-client",
] ]
@ -1500,7 +1502,7 @@ dependencies = [
"hyper", "hyper",
"rustls 0.20.8", "rustls 0.20.8",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls 0.23.4",
] ]
[[package]] [[package]]
@ -2322,7 +2324,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls 0.23.4",
"tower-service", "tower-service",
"trust-dns-resolver", "trust-dns-resolver",
"url", "url",
@ -2757,6 +2759,20 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tls-listener"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81294c017957a1a69794f506723519255879e15a870507faf45dfed288b763dd"
dependencies = [
"futures-util",
"hyper",
"pin-project-lite",
"thiserror",
"tokio",
"tokio-rustls 0.24.0",
]
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.28.1" version = "1.28.1"
@ -2798,6 +2814,16 @@ dependencies = [
"webpki", "webpki",
] ]
[[package]]
name = "tokio-rustls"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5"
dependencies = [
"rustls 0.21.1",
"tokio",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.8" version = "0.7.8"

View File

@ -28,7 +28,7 @@ clap = { version = "4.2.7", features = ["derive", "env"] }
handlebars = { version = "4.3.7", features = [ "rust-embed" ]} handlebars = { version = "4.3.7", features = [ "rust-embed" ]}
rust-embed = "6.6.1" rust-embed = "6.6.1"
mime_guess = "2.0.4" mime_guess = "2.0.4"
hyper = { version = "0.14.26", features = [ "server" ]} hyper = { version = "0.14.26", features = [ "server", "stream" ]}
http = "0.2.9" http = "0.2.9"
serde_urlencoded = "0.7.1" serde_urlencoded = "0.7.1"
tokio-util = "0.7.8" tokio-util = "0.7.8"
@ -37,3 +37,5 @@ openssl = "0.10.52"
rustls = "0.21.1" rustls = "0.21.1"
pem = "2.0.1" pem = "2.0.1"
serde_with = "3.0.0" serde_with = "3.0.0"
tls-listener = { version = "0.7.0", features = [ "rustls", "hyper-h1" ]}
tokio-rustls = "0.24.0"

12
TODO
View File

@ -1,11 +1,3 @@
- make acme store implement https://docs.rs/rustls/latest/rustls/server/trait.ResolvesServerCert.html
- pass that into https://docs.rs/rustls/latest/rustls/struct.ConfigBuilder.html#
- turn that into a TlsAcceptor (From is implemented here:
https://docs.rs/tokio-rustls/latest/tokio_rustls/struct.TlsAcceptor.html#impl-From%3CArc%3CServerConfig%3E%3E-for-TlsAcceptor)
- use tls-listener crate to wrap hyper accepter: https://github.com/tmccombs/tls-listener/blob/main/examples/http.rs#L24
- https://github.com/tmccombs/tls-listener/blob/main/examples/tls_config/mod.rs
- logging - logging
- expect statements (pretend it's "expected", not "expect")
- map_unexpected annotation string

View File

@ -72,19 +72,24 @@ struct FSStore {
dir_path: path::PathBuf, dir_path: path::PathBuf,
} }
pub fn new(dir_path: &path::Path) -> Result<impl BoxedStore, error::Unexpected> { #[derive(Clone)]
struct BoxedFSStore(sync::Arc<FSStore>);
pub fn new(
dir_path: &path::Path,
) -> Result<impl BoxedStore + rustls::server::ResolvesServerCert, error::Unexpected> {
fs::create_dir_all(dir_path).map_unexpected()?; fs::create_dir_all(dir_path).map_unexpected()?;
fs::create_dir_all(dir_path.join("http01_challenge_keys")).map_unexpected()?; fs::create_dir_all(dir_path.join("http01_challenge_keys")).map_unexpected()?;
fs::create_dir_all(dir_path.join("certificates")).map_unexpected()?; fs::create_dir_all(dir_path.join("certificates")).map_unexpected()?;
Ok(sync::Arc::new(FSStore { Ok(BoxedFSStore(sync::Arc::new(FSStore {
dir_path: dir_path.into(), dir_path: dir_path.into(),
})) })))
} }
impl FSStore { impl BoxedFSStore {
fn account_key_path(&self) -> path::PathBuf { fn account_key_path(&self) -> path::PathBuf {
self.dir_path.join("account.key") self.0.dir_path.join("account.key")
} }
fn http01_challenge_key_path(&self, token: &str) -> path::PathBuf { fn http01_challenge_key_path(&self, token: &str) -> path::PathBuf {
@ -94,20 +99,20 @@ impl FSStore {
.expect("token successfully hashed"); .expect("token successfully hashed");
let n = h.finalize().encode_hex::<String>(); let n = h.finalize().encode_hex::<String>();
self.dir_path.join("http01_challenge_keys").join(n) self.0.dir_path.join("http01_challenge_keys").join(n)
} }
fn certificate_path(&self, domain: &str) -> path::PathBuf { fn certificate_path(&self, domain: &str) -> path::PathBuf {
self.dir_path let mut domain = domain.to_string();
.join("certificates") domain.push_str(".json");
.join(domain)
.with_extension("json") self.0.dir_path.join("certificates").join(domain)
} }
} }
impl BoxedStore for sync::Arc<FSStore> {} impl BoxedStore for BoxedFSStore {}
impl Store for sync::Arc<FSStore> { impl Store for BoxedFSStore {
fn set_account_key(&self, k: &PrivateKey) -> Result<(), error::Unexpected> { fn set_account_key(&self, k: &PrivateKey) -> Result<(), error::Unexpected> {
let mut file = fs::File::create(self.account_key_path()).map_unexpected()?; let mut file = fs::File::create(self.account_key_path()).map_unexpected()?;
file.write_all(k.to_string().as_bytes()).map_unexpected()?; file.write_all(k.to_string().as_bytes()).map_unexpected()?;
@ -184,6 +189,39 @@ impl Store for sync::Arc<FSStore> {
} }
} }
impl rustls::server::ResolvesServerCert for BoxedFSStore {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<sync::Arc<rustls::sign::CertifiedKey>> {
let domain = if let Some(domain) = client_hello.server_name() {
domain
} else {
return None;
};
match self.get_certificate(domain) {
Err(GetCertificateError::NotFound) => Ok(None),
Err(GetCertificateError::Unexpected(err)) => Err(err),
Ok((key, cert)) => {
match rustls::sign::any_supported_type(&key.into()).map_unexpected() {
Err(err) => Err(err),
Ok(key) => Ok(Some(sync::Arc::new(rustls::sign::CertifiedKey {
cert: cert.into_iter().map(|cert| cert.into()).collect(),
key: key,
ocsp: None,
sct_list: None,
}))),
}
}
}
.unwrap_or_else(|err| {
println!("Unexpected error getting cert for domain {domain}: {err}");
None
})
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -1,6 +1,5 @@
use std::net; use std::net;
use std::str::FromStr; use std::str::FromStr;
use std::sync;
use crate::error::MapUnexpected; use crate::error::MapUnexpected;
use crate::{domain, error}; use crate::{domain, error};
@ -37,8 +36,7 @@ pub struct DNSChecker {
client: tokio::sync::Mutex<AsyncClient>, client: tokio::sync::Mutex<AsyncClient>,
} }
pub fn new( pub async fn new(
tokio_runtime: sync::Arc<tokio::runtime::Runtime>,
target_a: net::Ipv4Addr, target_a: net::Ipv4Addr,
resolver_addr: &str, resolver_addr: &str,
) -> Result<DNSChecker, NewDNSCheckerError> { ) -> Result<DNSChecker, NewDNSCheckerError> {
@ -48,11 +46,9 @@ pub fn new(
let stream = udp::UdpClientStream::<tokio::net::UdpSocket>::new(resolver_addr); let stream = udp::UdpClientStream::<tokio::net::UdpSocket>::new(resolver_addr);
let (client, bg) = tokio_runtime let (client, bg) = AsyncClient::connect(stream).await.map_unexpected()?;
.block_on(async { AsyncClient::connect(stream).await })
.map_unexpected()?;
tokio_runtime.spawn(bg); tokio::spawn(bg);
Ok(DNSChecker { Ok(DNSChecker {
target_a, target_a,

View File

@ -56,23 +56,17 @@ struct Cli {
domain_acme_contact_email: Option<String>, domain_acme_contact_email: Option<String>,
} }
fn main() { #[tokio::main]
async fn main() {
let config = Cli::parse(); let config = Cli::parse();
let mut wait_group = FuturesUnordered::new(); let mut wait_group = FuturesUnordered::new();
let canceller = tokio_util::sync::CancellationToken::new();
let tokio_runtime = std::sync::Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap(),
);
let canceller = tokio_runtime.block_on(async { tokio_util::sync::CancellationToken::new() });
{ {
let canceller = canceller.clone(); let canceller = canceller.clone();
tokio_runtime.spawn(async move {
tokio::spawn(async move {
let mut signals = let mut signals =
Signals::new(signal_hook::consts::TERM_SIGNALS).expect("initialized signals"); Signals::new(signal_hook::consts::TERM_SIGNALS).expect("initialized signals");
@ -92,16 +86,16 @@ fn main() {
.expect("git origin store initialized"); .expect("git origin store initialized");
let domain_checker = domiply::domain::checker::new( let domain_checker = domiply::domain::checker::new(
tokio_runtime.clone(),
config.domain_checker_target_a, config.domain_checker_target_a,
&config.domain_checker_resolver_addr, &config.domain_checker_resolver_addr,
) )
.await
.expect("domain checker initialized"); .expect("domain checker initialized");
let domain_config_store = domiply::domain::config::new(&config.domain_config_store_dir_path) let domain_config_store = domiply::domain::config::new(&config.domain_config_store_dir_path)
.expect("domain config store initialized"); .expect("domain config store initialized");
let domain_acme_manager = config.https_listen_addr.and_then(|_addr| { let (domain_acme_store, domain_acme_manager) = if config.https_listen_addr.is_some() {
let domain_acme_store = let domain_acme_store =
domiply::domain::acme::store::new(&config.domain_acme_store_dir_path) domiply::domain::acme::store::new(&config.domain_acme_store_dir_path)
.expect("domain acme store initialized"); .expect("domain acme store initialized");
@ -110,14 +104,17 @@ fn main() {
// settings. // settings.
let domain_acme_contact_email = config.domain_acme_contact_email.unwrap(); let domain_acme_contact_email = config.domain_acme_contact_email.unwrap();
let domain_acme_manager = tokio_runtime.block_on(async { let domain_acme_manager = domiply::domain::acme::manager::new(
domiply::domain::acme::manager::new(domain_acme_store, &domain_acme_contact_email) domain_acme_store.clone(),
.await &domain_acme_contact_email,
.expect("domain acme manager initialized") )
}); .await
.expect("domain acme manager initialized");
Some(domain_acme_manager) (Some(domain_acme_store), Some(domain_acme_manager))
}); } else {
(None, None)
};
let manager = domiply::domain::manager::new( let manager = domiply::domain::manager::new(
origin_store, origin_store,
@ -130,7 +127,7 @@ fn main() {
let manager = manager.clone(); let manager = manager.clone();
let canceller = canceller.clone(); let canceller = canceller.clone();
tokio_runtime.spawn(async move { tokio::spawn(async move {
let mut interval = time::interval(time::Duration::from_secs(20 * 60)); let mut interval = time::interval(time::Duration::from_secs(20 * 60));
loop { loop {
@ -166,8 +163,12 @@ fn main() {
let service = sync::Arc::new(service); let service = sync::Arc::new(service);
let make_service = wait_group.push({
hyper::service::make_service_fn(move |_conn: &hyper::server::conn::AddrStream| { let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
let service = service.clone();
let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone(); let service = service.clone();
// Create a `Service` for responding to the request. // Create a `Service` for responding to the request.
@ -179,11 +180,7 @@ fn main() {
async move { Ok::<_, Infallible>(service) } async move { Ok::<_, Infallible>(service) }
}); });
wait_group.push({ tokio::spawn(async move {
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
tokio_runtime.spawn(async move {
let addr = config.http_listen_addr; let addr = config.http_listen_addr;
println!( println!(
@ -203,61 +200,119 @@ fn main() {
}) })
}); });
// if there's an acme manager then it means that https is enabled, and we should ensure that // if there's an acme manager then it means that https is enabled
// the http domain for domiply itself has a valid certificate. if let (Some(domain_acme_store), Some(domain_acme_manager)) =
if let Some(domain_acme_manager) = domain_acme_manager { (domain_acme_store, domain_acme_manager)
let manager = manager.clone(); {
let canceller = canceller.clone(); // Periodically refresh all domain certs, including the http_domain passed in the Cli opts
let http_domain = config.http_domain.clone(); wait_group.push({
let manager = manager.clone();
let http_domain = config.http_domain.clone();
let canceller = canceller.clone();
// Periodically refresh all domain certs tokio::spawn(async move {
wait_group.push(tokio_runtime.spawn(async move { let mut interval = time::interval(time::Duration::from_secs(60 * 60));
let mut interval = time::interval(time::Duration::from_secs(60 * 60));
loop { loop {
select! { select! {
_ = interval.tick() => (), _ = interval.tick() => (),
_ = canceller.cancelled() => return, _ = canceller.cancelled() => return,
}
_ = domain_acme_manager
.sync_domain(http_domain.clone())
.await
.inspect_err(|err| {
println!(
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = manager.all_domains();
if let Err(err) = domains_iter {
println!("Got error calling all_domains: {err}");
continue;
}
for domain in domains_iter.unwrap().into_iter() {
match domain {
Ok(domain) => {
let _ = domain_acme_manager
.sync_domain(domain.clone())
.await
.inspect_err(|err| {
println!(
"Error while getting cert for {}: {err}",
domain.as_str(),
)
});
}
Err(err) => println!("Error iterating through domains: {err}"),
};
}
} }
})
});
_ = domain_acme_manager // HTTPS server
.sync_domain(http_domain.clone()) wait_group.push({
.await let http_domain = config.http_domain.clone();
.inspect_err(|err| { let canceller = canceller.clone();
println!( let service = service.clone();
"Error while getting cert for {}: {err}",
http_domain.as_str()
)
});
let domains_iter = manager.all_domains(); let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone();
if let Err(err) = domains_iter { // Create a `Service` for responding to the request.
println!("Got error calling all_domains: {err}"); let service = hyper::service::service_fn(move |req| {
continue; domiply::service::handle_request(service.clone(), req)
} });
for domain in domains_iter.unwrap().into_iter() { // Return the service to hyper.
match domain { async move { Ok::<_, Infallible>(service) }
Ok(domain) => { });
let _ = domain_acme_manager
.sync_domain(domain.clone()) tokio::spawn(async move {
.await let canceller = canceller.clone();
.inspect_err(|err| { let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
println!( rustls::server::ServerConfig::builder()
"Error while getting cert for {}: {err}", .with_safe_default_cipher_suites()
domain.as_str(), .with_safe_default_kx_groups()
) .with_safe_default_protocol_versions()
}); .unwrap()
} .with_no_client_auth()
Err(err) => println!("Error iterating through domains: {err}"), .with_cert_resolver(sync::Arc::from(domain_acme_store)),
}; )
} .into();
}
})); let addr = config.https_listen_addr.unwrap();
let addr_incoming = hyper::server::conn::AddrIncoming::bind(&addr)
.expect("https listen socket created");
let incoming = tls_listener::TlsListener::new(server_config, addr_incoming);
println!(
"Listening on https://{}:{}",
http_domain.as_str(),
addr.port()
);
let server = hyper::Server::builder(incoming).serve(make_service);
let graceful = server.with_graceful_shutdown(async {
canceller.cancelled().await;
});
if let Err(e) = graceful.await {
panic!("server error: {}", e);
};
})
})
} }
tokio_runtime.block_on(async { while let Some(_) = wait_group.next().await {} }); while let Some(_) = wait_group.next().await {}
println!("Graceful shutdown complete"); println!("Graceful shutdown complete");
} }