package post import ( "sort" "testing" "github.com/stretchr/testify/assert" ) type draftStoreTestHarness struct { store DraftStore } func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness { db := NewInMemSQLDB() t.Cleanup(func() { db.Close() }) store := NewDraftStore(db) return draftStoreTestHarness{ store: store, } } func TestDraftStore(t *testing.T) { assertPostEqual := func(t *testing.T, exp, got Post) { t.Helper() sort.Strings(exp.Tags) sort.Strings(got.Tags) assert.Equal(t, exp, got) } assertPostsEqual := func(t *testing.T, exp, got []Post) { t.Helper() if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { return } for i := range exp { assertPostEqual(t, exp[i], got[i]) } } t.Run("not_found", func(t *testing.T) { h := newDraftStoreTestHarness(t) _, err := h.store.GetByID("foo") assert.ErrorIs(t, err, ErrPostNotFound) }) t.Run("set_get_delete", func(t *testing.T) { h := newDraftStoreTestHarness(t) post := testPost(0) post.Tags = []string{"foo", "bar"} err := h.store.Set(post) assert.NoError(t, err) gotPost, err := h.store.GetByID(post.ID) assert.NoError(t, err) assertPostEqual(t, post, gotPost) // we will now try updating the post, and ensure it updates properly post.Title = "something else" post.Series = "whatever" post.Body = "anything" post.Tags = []string{"bar", "baz"} err = h.store.Set(post) assert.NoError(t, err) gotPost, err = h.store.GetByID(post.ID) assert.NoError(t, err) assertPostEqual(t, post, gotPost) // delete the post, it should go away assert.NoError(t, h.store.Delete(post.ID)) _, err = h.store.GetByID(post.ID) assert.ErrorIs(t, err, ErrPostNotFound) }) t.Run("get", func(t *testing.T) { h := newDraftStoreTestHarness(t) posts := []Post{ testPost(0), testPost(1), testPost(2), testPost(3), } for _, post := range posts { err := h.store.Set(post) assert.NoError(t, err) } gotPosts, hasMore, err := h.store.Get(0, 2) assert.NoError(t, err) assert.True(t, hasMore) assertPostsEqual(t, posts[:2], gotPosts) gotPosts, hasMore, err = h.store.Get(1, 2) assert.NoError(t, err) assert.False(t, hasMore) assertPostsEqual(t, posts[2:4], gotPosts) posts = append(posts, testPost(4)) err = h.store.Set(posts[4]) assert.NoError(t, err) gotPosts, hasMore, err = h.store.Get(1, 2) assert.NoError(t, err) assert.True(t, hasMore) assertPostsEqual(t, posts[2:4], gotPosts) gotPosts, hasMore, err = h.store.Get(2, 2) assert.NoError(t, err) assert.False(t, hasMore) assertPostsEqual(t, posts[4:], gotPosts) }) }