diff --git a/srv/src/post/asset.go b/srv/src/post/asset.go index 3e6ae28..18af8f6 100644 --- a/srv/src/post/asset.go +++ b/srv/src/post/asset.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "sync" ) var ( @@ -84,3 +85,69 @@ func (s *assetStore) Delete(id string) error { _, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id) return err } + +//////////////////////////////////////////////////////////////////////////////// + +type cachedAssetStore struct { + inner AssetStore + m sync.Map +} + +// NewCachedAssetStore wraps an AssetStore in an in-memory cache. +func NewCachedAssetStore(assetStore AssetStore) AssetStore { + return &cachedAssetStore{ + inner: assetStore, + } +} + +func (s *cachedAssetStore) Set(id string, from io.Reader) error { + + buf := new(bytes.Buffer) + from = io.TeeReader(from, buf) + + if err := s.inner.Set(id, from); err != nil { + return err + } + + s.m.Store(id, buf.Bytes()) + return nil +} + +func (s *cachedAssetStore) Get(id string, into io.Writer) error { + + if bodyI, ok := s.m.Load(id); ok { + + if err, ok := bodyI.(error); ok { + return err + } + + if _, err := io.Copy(into, bytes.NewReader(bodyI.([]byte))); err != nil { + return fmt.Errorf("writing body to io.Writer: %w", err) + } + + return nil + } + + buf := new(bytes.Buffer) + into = io.MultiWriter(into, buf) + + if err := s.inner.Get(id, into); errors.Is(err, ErrAssetNotFound) { + s.m.Store(id, err) + return err + } else if err != nil { + return err + } + + s.m.Store(id, buf.Bytes()) + return nil +} + +func (s *cachedAssetStore) Delete(id string) error { + + if err := s.inner.Delete(id); err != nil { + return err + } + + s.m.Delete(id) + return nil +} diff --git a/srv/src/post/asset_test.go b/srv/src/post/asset_test.go index 4b88000..d0cff48 100644 --- a/srv/src/post/asset_test.go +++ b/srv/src/post/asset_test.go @@ -12,14 +12,14 @@ type assetTestHarness struct { store AssetStore } -func newAssetTestHarness(t *testing.T) assetTestHarness { +func newAssetTestHarness(t *testing.T) *assetTestHarness { db := NewInMemSQLDB() t.Cleanup(func() { db.Close() }) store := NewAssetStore(db) - return assetTestHarness{ + return &assetTestHarness{ store: store, } } @@ -40,28 +40,41 @@ func (h *assetTestHarness) assertNotFound(t *testing.T, id string) { func TestAssetStore(t *testing.T) { - h := newAssetTestHarness(t) + testAssetStore := func(t *testing.T, h *assetTestHarness) { + t.Helper() - h.assertNotFound(t, "foo") - h.assertNotFound(t, "bar") + h.assertNotFound(t, "foo") + h.assertNotFound(t, "bar") - err := h.store.Set("foo", bytes.NewBufferString("FOO")) - assert.NoError(t, err) + err := h.store.Set("foo", bytes.NewBufferString("FOO")) + assert.NoError(t, err) - h.assertGet(t, "FOO", "foo") - h.assertNotFound(t, "bar") + h.assertGet(t, "FOO", "foo") + h.assertNotFound(t, "bar") - err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) - assert.NoError(t, err) + err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) + assert.NoError(t, err) + + h.assertGet(t, "FOOFOO", "foo") + h.assertNotFound(t, "bar") - h.assertGet(t, "FOOFOO", "foo") - h.assertNotFound(t, "bar") + assert.NoError(t, h.store.Delete("foo")) + h.assertNotFound(t, "foo") + h.assertNotFound(t, "bar") + + assert.NoError(t, h.store.Delete("bar")) + h.assertNotFound(t, "foo") + h.assertNotFound(t, "bar") + } - assert.NoError(t, h.store.Delete("foo")) - h.assertNotFound(t, "foo") - h.assertNotFound(t, "bar") + t.Run("sql", func(t *testing.T) { + h := newAssetTestHarness(t) + testAssetStore(t, h) + }) - assert.NoError(t, h.store.Delete("bar")) - h.assertNotFound(t, "foo") - h.assertNotFound(t, "bar") + t.Run("mem", func(t *testing.T) { + h := newAssetTestHarness(t) + h.store = NewCachedAssetStore(h.store) + testAssetStore(t, h) + }) }