Use GATs to remove dynamic dispatch when using origin store

This commit is contained in:
Brian Picciano 2023-05-11 11:47:38 +02:00
parent d1842943cd
commit 79ba735fd6
2 changed files with 43 additions and 53 deletions

View File

@ -115,10 +115,14 @@ impl From<config::SetError> for SyncWithConfigError {
} }
} }
#[mockall::automock] #[mockall::automock(type Origin=origin::MockOrigin;)]
pub trait Manager { pub trait Manager {
type Origin<'mgr>: origin::Origin + 'mgr
where
Self: 'mgr;
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError>; fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError>;
fn get_origin(&self, domain: &str) -> Result<Box<dyn origin::Origin>, GetOriginError>; fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError>;
fn sync(&self, domain: &str) -> Result<(), SyncError>; fn sync(&self, domain: &str) -> Result<(), SyncError>;
fn sync_with_config( fn sync_with_config(
&self, &self,
@ -133,7 +137,7 @@ pub fn new_manager<OriginStore, DomainConfigStore, DomainChecker>(
domain_checker: DomainChecker, domain_checker: DomainChecker,
) -> impl Manager ) -> impl Manager
where where
OriginStore: for<'a> origin::store::Store<'a>, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker, DomainChecker: checker::Checker,
{ {
@ -146,7 +150,7 @@ where
struct ManagerImpl<OriginStore, DomainConfigStore, DomainChecker> struct ManagerImpl<OriginStore, DomainConfigStore, DomainChecker>
where where
OriginStore: for<'a> origin::store::Store<'a>, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker, DomainChecker: checker::Checker,
{ {
@ -158,15 +162,18 @@ where
impl<OriginStore, DomainConfigStore, DomainChecker> Manager impl<OriginStore, DomainConfigStore, DomainChecker> Manager
for ManagerImpl<OriginStore, DomainConfigStore, DomainChecker> for ManagerImpl<OriginStore, DomainConfigStore, DomainChecker>
where where
OriginStore: for<'a> origin::store::Store<'a>, OriginStore: origin::store::Store,
DomainConfigStore: config::Store, DomainConfigStore: config::Store,
DomainChecker: checker::Checker, DomainChecker: checker::Checker,
{ {
type Origin<'mgr> = OriginStore::Origin<'mgr>
where Self: 'mgr;
fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError> { fn get_config(&self, domain: &str) -> Result<config::Config, GetConfigError> {
Ok(self.domain_config_store.get(domain)?) Ok(self.domain_config_store.get(domain)?)
} }
fn get_origin(&self, domain: &str) -> Result<Box<dyn origin::Origin>, GetOriginError> { fn get_origin(&self, domain: &str) -> Result<Self::Origin<'_>, GetOriginError> {
let config = self.domain_config_store.get(domain)?; let config = self.domain_config_store.get(domain)?;
let origin = self let origin = self
.origin_store .origin_store

View File

@ -39,48 +39,26 @@ pub enum AllDescrsError {
/// Used in the return from all_descrs from Store. /// Used in the return from all_descrs from Store.
pub type AllDescrsResult<T> = Result<T, AllDescrsError>; pub type AllDescrsResult<T> = Result<T, AllDescrsError>;
#[mockall::automock(
type Origin=origin::MockOrigin;
type AllDescrsIter=Vec<AllDescrsResult<origin::Descr>>;
)]
/// Describes a storage mechanism for Origins. Each Origin is uniquely identified by its Descr. /// Describes a storage mechanism for Origins. Each Origin is uniquely identified by its Descr.
pub trait Store<'a> { pub trait Store {
type AllDescrsIter: IntoIterator<Item = AllDescrsResult<origin::Descr>>; type Origin<'store>: origin::Origin + 'store
where
Self: 'store;
type AllDescrsIter<'store>: IntoIterator<Item = AllDescrsResult<origin::Descr>> + 'store
where
Self: 'store;
/// If the origin is of a kind which can be updated, sync will pull down the latest version of /// If the origin is of a kind which can be updated, sync will pull down the latest version of
/// the origin into the storage. /// the origin into the storage.
fn sync(&'a self, descr: origin::Descr, limits: Limits) -> Result<(), SyncError>; fn sync(&self, descr: origin::Descr, limits: Limits) -> Result<(), SyncError>;
fn get(&'a self, descr: origin::Descr) -> Result<Box<dyn origin::Origin>, GetError>; fn get(&self, descr: origin::Descr) -> Result<Self::Origin<'_>, GetError>;
fn all_descrs(&'a self) -> AllDescrsResult<Self::AllDescrsIter>; fn all_descrs(&self) -> AllDescrsResult<Self::AllDescrsIter<'_>>;
}
pub struct MockStore<SyncFn, GetFn, AllDescrsFn>
where
SyncFn: Fn(origin::Descr, Limits) -> Result<(), SyncError>,
GetFn: Fn(origin::Descr) -> Result<Box<dyn origin::Origin>, GetError>,
AllDescrsFn: Fn() -> AllDescrsResult<Vec<AllDescrsResult<origin::Descr>>>,
{
pub sync_fn: SyncFn,
pub get_fn: GetFn,
pub all_descrs_fn: AllDescrsFn,
}
impl<'a, SyncFn, GetFn, AllDescrsFn> Store<'a> for MockStore<SyncFn, GetFn, AllDescrsFn>
where
SyncFn: Fn(origin::Descr, Limits) -> Result<(), SyncError>,
GetFn: Fn(origin::Descr) -> Result<Box<dyn origin::Origin>, GetError>,
AllDescrsFn: Fn() -> AllDescrsResult<Vec<AllDescrsResult<origin::Descr>>>,
{
type AllDescrsIter = Vec<AllDescrsResult<origin::Descr>>;
fn sync(&'a self, descr: origin::Descr, limits: Limits) -> Result<(), SyncError> {
(self.sync_fn)(descr, limits)
}
fn get(&'a self, descr: origin::Descr) -> Result<Box<dyn origin::Origin>, GetError> {
(self.get_fn)(descr)
}
fn all_descrs(&'a self) -> AllDescrsResult<Self::AllDescrsIter> {
(self.all_descrs_fn)()
}
} }
pub mod git { pub mod git {
@ -92,7 +70,7 @@ pub mod git {
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::{collections, fs, io, sync}; use std::{collections, fs, io, sync};
struct Origin { pub struct Origin {
descr: origin::Descr, descr: origin::Descr,
repo: gix::ThreadSafeRepository, repo: gix::ThreadSafeRepository,
tree_object_id: gix::ObjectId, tree_object_id: gix::ObjectId,
@ -279,11 +257,15 @@ pub mod git {
} }
} }
impl<'a> super::Store<'a> for Store { impl super::Store for Store {
type AllDescrsIter = Box<dyn Iterator<Item = store::AllDescrsResult<origin::Descr>> + 'a>; type Origin<'store> = sync::Arc<Origin>
where Self: 'store;
type AllDescrsIter<'store> = Box<dyn Iterator<Item = store::AllDescrsResult<origin::Descr>> + 'store>
where Self: 'store;
fn sync( fn sync(
&'a self, &self,
descr: origin::Descr, descr: origin::Descr,
limits: store::Limits, limits: store::Limits,
) -> Result<(), store::SyncError> { ) -> Result<(), store::SyncError> {
@ -330,11 +312,11 @@ pub mod git {
Ok(()) Ok(())
} }
fn get(&'a self, descr: origin::Descr) -> Result<Box<dyn origin::Origin>, store::GetError> { fn get(&self, descr: origin::Descr) -> Result<Self::Origin<'_>, store::GetError> {
{ {
let origins = self.origins.read().unwrap(); let origins = self.origins.read().unwrap();
if let Some(origin) = origins.get(&descr) { if let Some(origin) = origins.get(&descr) {
return Ok(Box::from(origin.clone())); return Ok(origin.clone());
} }
} }
@ -358,10 +340,10 @@ pub mod git {
let mut origins = self.origins.write().unwrap(); let mut origins = self.origins.write().unwrap();
(*origins).insert(descr, origin.clone()); (*origins).insert(descr, origin.clone());
Ok(Box::from(origin)) Ok(origin)
} }
fn all_descrs(&'a self) -> store::AllDescrsResult<Self::AllDescrsIter> { fn all_descrs(&self) -> store::AllDescrsResult<Self::AllDescrsIter<'_>> {
Ok(Box::from( Ok(Box::from(
fs::read_dir(&self.dir_path) fs::read_dir(&self.dir_path)
.map_err(|e| store::AllDescrsError::Unexpected(Box::from(e)))? .map_err(|e| store::AllDescrsError::Unexpected(Box::from(e)))?
@ -395,9 +377,10 @@ pub mod git {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::origin;
use crate::origin::store; use crate::origin::store;
use crate::origin::store::Store; use crate::origin::store::Store;
use crate::origin::{self, Origin};
use std::sync;
use tempdir::TempDir; use tempdir::TempDir;
#[test] #[test]
@ -429,7 +412,7 @@ pub mod git {
assert!(matches!( assert!(matches!(
store.get(other_descr), store.get(other_descr),
Err::<Box<dyn origin::Origin>, store::GetError>(store::GetError::NotFound), Err::<sync::Arc<super::Origin>, store::GetError>(store::GetError::NotFound),
)); ));
let origin = store.get(descr.clone()).expect("origin retrieved"); let origin = store.get(descr.clone()).expect("origin retrieved");