From a86020eedf1e2210632b2752cae6dc0976de382f Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sun, 16 Jul 2023 15:40:20 +0200 Subject: [PATCH] Have get_file accept and return structs, which will be easier to extend going forward --- src/domain/manager.rs | 29 +++++++++++++---------- src/origin.rs | 29 ++++++++++++++++------- src/origin/git.rs | 50 +++++++++++++++++++++++++++++---------- src/origin/mux.rs | 16 ++++++------- src/service/http.rs | 21 +++++++++++----- src/service/http/tasks.rs | 12 ++++++---- 6 files changed, 106 insertions(+), 51 deletions(-) diff --git a/src/domain/manager.rs b/src/domain/manager.rs index febb374..1c53795 100644 --- a/src/domain/manager.rs +++ b/src/domain/manager.rs @@ -24,6 +24,9 @@ impl From for GetSettingsError { } } +pub type GetFileRequest<'req> = origin::GetFileRequest<'req>; +pub type GetFileResponse = origin::GetFileResponse; + #[derive(thiserror::Error, Debug)] pub enum GetFileError { #[error("domain not found")] @@ -142,11 +145,11 @@ pub type GetAcmeHttp01ChallengeKeyError = acme::manager::GetHttp01ChallengeKeyEr pub trait Manager: Sync + Send + rustls::server::ResolvesServerCert { fn get_settings(&self, domain: &domain::Name) -> Result; - fn get_file<'store>( - &'store self, - domain: &domain::Name, - path: &str, - ) -> Result; + fn get_file<'req>( + &self, + domain: &'req domain::Name, + req: GetFileRequest<'req>, + ) -> Result; fn sync_cert<'mgr>( &'mgr self, @@ -239,14 +242,14 @@ impl Manager for ManagerImpl { Ok(self.domain_store.get(domain)?) } - fn get_file<'store>( - &'store self, - domain: &domain::Name, - path: &str, - ) -> Result { - let config = self.domain_store.get(domain)?; - let f = self.origin_store.get_file(&config.origin_descr, path)?; - Ok(f) + fn get_file<'req>( + &self, + domain: &'req domain::Name, + req: GetFileRequest<'req>, + ) -> Result { + let settings = self.domain_store.get(domain)?; + let res = self.origin_store.get_file(&settings.origin_descr, req)?; + Ok(res) } fn sync_cert<'mgr>( diff --git a/src/origin.rs b/src/origin.rs index f817f18..8ae2ba8 100644 --- a/src/origin.rs +++ b/src/origin.rs @@ -8,7 +8,7 @@ pub use descr::Descr; use crate::error::unexpected; use crate::util; -use std::sync; +use std::{net, sync}; #[derive(thiserror::Error, Clone, Debug, PartialEq)] pub enum SyncError { @@ -31,6 +31,15 @@ pub enum AllDescrsError { Unexpected(#[from] unexpected::Error), } +pub struct GetFileRequest<'a> { + pub path: &'a str, + pub client_ip: &'a net::IpAddr, +} + +pub struct GetFileResponse { + pub body: util::BoxByteStream, +} + #[derive(thiserror::Error, Debug)] pub enum GetFileError { #[error("descr not synced")] @@ -52,7 +61,11 @@ pub trait Store { fn all_descrs(&self) -> Result, AllDescrsError>; - fn get_file(&self, descr: &Descr, path: &str) -> Result; + fn get_file<'req>( + &self, + descr: &'req Descr, + req: GetFileRequest<'req>, + ) -> Result; } pub fn new_mock() -> sync::Arc> { @@ -68,11 +81,11 @@ impl Store for sync::Arc> { self.lock().unwrap().all_descrs() } - fn get_file<'store>( - &'store self, - descr: &Descr, - path: &str, - ) -> Result { - self.lock().unwrap().get_file(descr, path) + fn get_file<'req>( + &self, + descr: &'req Descr, + req: GetFileRequest<'req>, + ) -> Result { + self.lock().unwrap().get_file(descr, req) } } diff --git a/src/origin/git.rs b/src/origin/git.rs index b466e3b..8266790 100644 --- a/src/origin/git.rs +++ b/src/origin/git.rs @@ -1,5 +1,5 @@ use crate::error::unexpected::{self, Intoable, Mappable}; -use crate::{origin, util}; +use crate::origin; use std::path::{Path, PathBuf}; use std::{collections, fs, io, sync}; @@ -297,18 +297,18 @@ impl super::Store for FSStore { ).try_collect() } - fn get_file<'store>( - &'store self, - descr: &origin::Descr, - path: &str, - ) -> Result { + fn get_file<'req>( + &self, + descr: &'req origin::Descr, + req: origin::GetFileRequest<'req>, + ) -> Result { let repo_snapshot = match self.get_repo_snapshot(descr) { Ok(Some(repo_snapshot)) => repo_snapshot, Ok(None) => return Err(origin::GetFileError::DescrNotSynced), Err(e) => return Err(e.into()), }; - let mut clean_path = Path::new(path); + let mut clean_path = Path::new(req.path); clean_path = clean_path.strip_prefix("/").unwrap_or(clean_path); let repo = repo_snapshot.repo.to_thread_local(); @@ -337,7 +337,9 @@ impl super::Store for FSStore { // TODO this is very not ideal, the whole file is first read totally into memory, and then // that is cloned. let data = file_object.data.clone(); - Ok(Box::pin(stream::once(async move { Ok(data) }))) + Ok(origin::GetFileResponse { + body: Box::pin(stream::once(async move { Ok(data) })), + }) } } @@ -345,10 +347,13 @@ impl super::Store for FSStore { mod tests { use crate::origin::{self, Config, Store}; use futures::StreamExt; + use std::{net, str::FromStr}; use tempdir::TempDir; #[tokio::test] async fn basic() { + let client_ip = net::IpAddr::from_str("127.0.0.1").unwrap(); + let tmp_dir = TempDir::new("origin_store_git").unwrap(); let config = Config { store_dir_path: tmp_dir.path().to_path_buf(), @@ -372,20 +377,41 @@ mod tests { store.sync(&descr).expect("second sync should succeed"); // RepoSnapshot doesn't exist - match store.get_file(&other_descr, "DNE") { + match store.get_file( + &other_descr, + origin::GetFileRequest { + path: "DNE", + client_ip: &client_ip, + }, + ) { Err(origin::GetFileError::DescrNotSynced) => (), _ => assert!(false, "descr should have not been found"), }; - let assert_file_dne = |path: &str| match store.get_file(&descr, path) { + let assert_file_dne = |path: &str| match store.get_file( + &descr, + origin::GetFileRequest { + path, + client_ip: &client_ip, + }, + ) { Err(origin::GetFileError::FileNotFound) => (), _ => assert!(false, "file should have not been found"), }; let assert_file_not_empty = |path: &str| { - let f = store.get_file(&descr, path).expect("file not retrieved"); + let origin::GetFileResponse { body } = store + .get_file( + &descr, + origin::GetFileRequest { + path, + client_ip: &client_ip, + }, + ) + .expect("file not retrieved"); + async move { - let body = f.map(|r| r.unwrap()).concat().await; + let body = body.map(|r| r.unwrap()).concat().await; assert!(body.len() > 0); } }; diff --git a/src/origin/mux.rs b/src/origin/mux.rs index 2ae26b0..739b2bd 100644 --- a/src/origin/mux.rs +++ b/src/origin/mux.rs @@ -1,5 +1,5 @@ use crate::error::unexpected::Mappable; -use crate::{origin, util}; +use crate::origin; pub struct Store where @@ -41,14 +41,14 @@ where Ok(res) } - fn get_file<'store>( - &'store self, - descr: &origin::Descr, - path: &str, - ) -> Result { + fn get_file<'req>( + &self, + descr: &'req origin::Descr, + req: origin::GetFileRequest<'req>, + ) -> Result { (self.mapping_fn)(descr) - .or_unexpected_while(format!("mapping {:?} to store", &descr))? - .get_file(descr, path) + .or_unexpected_while(format!("mapping {:?} to store", descr))? + .get_file(descr, req) } } diff --git a/src/service/http.rs b/src/service/http.rs index bc5368c..ad29688 100644 --- a/src/service/http.rs +++ b/src/service/http.rs @@ -9,7 +9,7 @@ use hyper::{Body, Method, Request, Response}; use serde::{Deserialize, Serialize}; use std::str::FromStr; -use std::{future, sync}; +use std::{future, net, sync}; use crate::error::unexpected; use crate::{domain, service, util}; @@ -158,7 +158,12 @@ impl<'svc> Service { ) } - fn serve_origin(&self, domain: domain::Name, path: &str) -> Response { + fn serve_origin( + &self, + domain: &domain::Name, + path: &str, + client_ip: &net::IpAddr, + ) -> Response { let mut path_owned; let path = match path.ends_with('/') { @@ -170,8 +175,12 @@ impl<'svc> Service { false => path, }; - match self.domain_manager.get_file(&domain, path) { - Ok(f) => self.serve(200, path, Body::wrap_stream(f)), + let req = domain::manager::GetFileRequest { path, client_ip }; + + match self.domain_manager.get_file(&domain, req) { + Ok(domain::manager::GetFileResponse { body }) => { + self.serve(200, path, Body::wrap_stream(body)) + } Err(domain::manager::GetFileError::DomainNotFound) => { return self.render_error_page(404, "Domain not found") } @@ -366,7 +375,7 @@ impl<'svc> Service { self.render_page("/domains.html", Response { domains }) } - async fn handle_request(&self, req: Request) -> Response { + async fn handle_request(&self, client_ip: net::IpAddr, req: Request) -> Response { let (req, body) = req.into_parts(); let maybe_host = match ( @@ -415,7 +424,7 @@ impl<'svc> Service { // If a managed domain was given then serve that from its origin if let Some(domain) = maybe_host { - return self.serve_origin(domain, req.uri.path()); + return self.serve_origin(&domain, req.uri.path(), &client_ip); } // Serve main domani site diff --git a/src/service/http/tasks.rs b/src/service/http/tasks.rs index 8598e43..de3bc30 100644 --- a/src/service/http/tasks.rs +++ b/src/service/http/tasks.rs @@ -4,6 +4,8 @@ use crate::service; use std::{convert, future, sync}; use futures::StreamExt; +use hyper::server::conn::AddrStream; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; pub async fn listen_http( @@ -13,13 +15,14 @@ pub async fn listen_http( let addr = service.config.http.http_addr.clone(); let primary_domain = service.config.primary_domain.clone(); - let make_service = hyper::service::make_service_fn(move |_| { + let make_service = hyper::service::make_service_fn(move |conn: &AddrStream| { let service = service.clone(); + let client_ip = conn.remote_addr().ip(); // Create a `Service` for responding to the request. let hyper_service = hyper::service::service_fn(move |req| { let service = service.clone(); - async move { Ok::<_, convert::Infallible>(service.handle_request(req).await) } + async move { Ok::<_, convert::Infallible>(service.handle_request(client_ip, req).await) } }); // Return the service to hyper. @@ -48,13 +51,14 @@ pub async fn listen_https( let addr = service.config.http.https_addr.unwrap().clone(); let primary_domain = service.config.primary_domain.clone(); - let make_service = hyper::service::make_service_fn(move |_| { + let make_service = hyper::service::make_service_fn(move |conn: &TlsStream| { let service = service.clone(); + let client_ip = conn.get_ref().0.remote_addr().ip(); // Create a `Service` for responding to the request. let hyper_service = hyper::service::service_fn(move |req| { let service = service.clone(); - async move { Ok::<_, convert::Infallible>(service.handle_request(req).await) } + async move { Ok::<_, convert::Infallible>(service.handle_request(client_ip, req).await) } }); // Return the service to hyper.