Move handle_request onto service as a method

This commit is contained in:
Brian Picciano 2023-06-19 20:09:25 +02:00
parent 506037dcd0
commit 43f4b98b38
2 changed files with 104 additions and 103 deletions

View File

@ -328,7 +328,7 @@ impl<'svc> Service {
self.render_page("/domain_sync.html", response) self.render_page("/domain_sync.html", response)
} }
pub fn domains(&self) -> SvcResponse { fn domains(&self) -> SvcResponse {
#[derive(Serialize)] #[derive(Serialize)]
struct Response { struct Response {
domains: Vec<String>, domains: Vec<String>,
@ -350,26 +350,8 @@ impl<'svc> Service {
self.render_page("/domains.html", Response { domains }) self.render_page("/domains.html", Response { domains })
} }
}
pub async fn handle_request( async fn handle_request_inner(&self, req: Request<Body>) -> SvcResponse {
svc: sync::Arc<Service>,
req: Request<Body>,
) -> Result<Response<Body>, Infallible> {
match handle_request_inner(svc, req).await {
Ok(res) => Ok(res),
Err(err) => panic!("unexpected error {err}"),
}
}
fn strip_port(host: &str) -> &str {
match host.rfind(':') {
None => host,
Some(i) => &host[..i],
}
}
pub async fn handle_request_inner(svc: sync::Arc<Service>, req: Request<Body>) -> SvcResponse {
let maybe_host = match ( let maybe_host = match (
req.headers() req.headers()
.get("Host") .get("Host")
@ -377,8 +359,8 @@ pub async fn handle_request_inner(svc: sync::Arc<Service>, req: Request<Body>) -
.map(strip_port), .map(strip_port),
req.uri().host().map(strip_port), req.uri().host().map(strip_port),
) { ) {
(Some(h), _) if h != svc.http_domain.as_str() => Some(h), (Some(h), _) if h != self.http_domain.as_str() => Some(h),
(_, Some(h)) if h != svc.http_domain.as_str() => Some(h), (_, Some(h)) if h != self.http_domain.as_str() => Some(h),
_ => None, _ => None,
} }
.and_then(|h| domain::Name::from_str(h).ok()); .and_then(|h| domain::Name::from_str(h).ok());
@ -391,7 +373,7 @@ pub async fn handle_request_inner(svc: sync::Arc<Service>, req: Request<Body>) -
if method == Method::GET && path.starts_with("/.well-known/acme-challenge/") { if method == Method::GET && path.starts_with("/.well-known/acme-challenge/") {
let token = path.trim_start_matches("/.well-known/acme-challenge/"); let token = path.trim_start_matches("/.well-known/acme-challenge/");
if let Ok(key) = svc.domain_manager.get_acme_http01_challenge_key(token) { if let Ok(key) = self.domain_manager.get_acme_http01_challenge_key(token) {
let body: hyper::Body = key.into(); let body: hyper::Body = key.into();
return match Response::builder().status(200).body(body) { return match Response::builder().status(200).body(body) {
Ok(res) => Ok(res), Ok(res) => Ok(res),
@ -405,40 +387,57 @@ pub async fn handle_request_inner(svc: sync::Arc<Service>, req: Request<Body>) -
// If a managed domain was given then serve that from its origin // If a managed domain was given then serve that from its origin
if let Some(domain) = maybe_host { if let Some(domain) = maybe_host {
return svc.serve_origin(domain, req.uri().path()); return self.serve_origin(domain, req.uri().path());
} }
// Serve main domiply site // Serve main domiply site
if method == Method::GET && path.starts_with("/static/") { if method == Method::GET && path.starts_with("/static/") {
return svc.render(200, path, ()); return self.render(200, path, ());
} }
match (method, path) { match (method, path) {
(&Method::GET, "/") | (&Method::GET, "/index.html") => svc.render_page("/index.html", ()), (&Method::GET, "/") | (&Method::GET, "/index.html") => {
self.render_page("/index.html", ())
}
(&Method::GET, "/domain.html") => { (&Method::GET, "/domain.html") => {
svc.with_query_req(&req, |args: DomainGetArgs| async { svc.domain_get(args) }) self.with_query_req(&req, |args: DomainGetArgs| async { self.domain_get(args) })
.await .await
} }
(&Method::GET, "/domain_init.html") => { (&Method::GET, "/domain_init.html") => {
svc.with_query_req(&req, |args: DomainInitArgs| async { self.with_query_req(&req, |args: DomainInitArgs| async {
svc.with_query_req(&req, |config: service::util::FlatConfig| async { self.with_query_req(&req, |config: service::util::FlatConfig| async {
svc.domain_init(args, config) self.domain_init(args, config)
}) })
.await .await
}) })
.await .await
} }
(&Method::GET, "/domain_sync.html") => { (&Method::GET, "/domain_sync.html") => {
svc.with_query_req(&req, |args: DomainSyncArgs| async { self.with_query_req(&req, |args: DomainSyncArgs| async {
svc.with_query_req(&req, |config: service::util::FlatConfig| async { self.with_query_req(&req, |config: service::util::FlatConfig| async {
svc.domain_sync(args, config).await self.domain_sync(args, config).await
}) })
.await .await
}) })
.await .await
} }
(&Method::GET, "/domains.html") => svc.domains(), (&Method::GET, "/domains.html") => self.domains(),
_ => svc.render_error_page(404, "Page not found!"), _ => self.render_error_page(404, "Page not found!"),
}
}
pub async fn handle_request(&self, req: Request<Body>) -> Result<Response<Body>, Infallible> {
match self.handle_request_inner(req).await {
Ok(res) => Ok(res),
Err(err) => panic!("unexpected error {err}"),
}
}
}
fn strip_port(host: &str) -> &str {
match host.rfind(':') {
None => host,
Some(i) => &host[..i],
} }
} }

View File

@ -11,19 +11,20 @@ pub fn listen_http(
addr: net::SocketAddr, addr: net::SocketAddr,
domain: domain::Name, domain: domain::Name,
) -> tokio::task::JoinHandle<()> { ) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let make_service = hyper::service::make_service_fn(move |_| { let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone(); let service = service.clone();
// Create a `Service` for responding to the request. // Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| { let hyper_service = hyper::service::service_fn(move |req| {
service::http::handle_request(service.clone(), req) let service = service.clone();
async move { service.handle_request(req).await }
}); });
// Return the service to hyper. // Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) } async move { Ok::<_, convert::Infallible>(hyper_service) }
}); });
tokio::spawn(async move {
log::info!("Listening on http://{}:{}", domain.as_str(), addr.port()); log::info!("Listening on http://{}:{}", domain.as_str(), addr.port());
let server = hyper::Server::bind(&addr).serve(make_service); let server = hyper::Server::bind(&addr).serve(make_service);
@ -44,19 +45,20 @@ pub fn listen_https(
addr: net::SocketAddr, addr: net::SocketAddr,
domain: domain::Name, domain: domain::Name,
) -> tokio::task::JoinHandle<()> { ) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let make_service = hyper::service::make_service_fn(move |_| { let make_service = hyper::service::make_service_fn(move |_| {
let service = service.clone(); let service = service.clone();
// Create a `Service` for responding to the request. // Create a `Service` for responding to the request.
let hyper_service = hyper::service::service_fn(move |req| { let hyper_service = hyper::service::service_fn(move |req| {
service::http::handle_request(service.clone(), req) let service = service.clone();
async move { service.handle_request(req).await }
}); });
// Return the service to hyper. // Return the service to hyper.
async move { Ok::<_, convert::Infallible>(hyper_service) } async move { Ok::<_, convert::Infallible>(hyper_service) }
}); });
tokio::spawn(async move {
let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new( let server_config: tokio_rustls::TlsAcceptor = sync::Arc::new(
rustls::server::ServerConfig::builder() rustls::server::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()