From 7ac2f5ebb32a6098bc0d590130cc4b933db08771 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Thu, 18 Aug 2022 22:25:17 -0600 Subject: [PATCH] implement DraftStore --- srv/src/post/draft_post.go | 186 ++++++++++++++++++++++++++++++++ srv/src/post/draft_post_test.go | 130 ++++++++++++++++++++++ srv/src/post/post.go | 69 +++--------- srv/src/post/sql.go | 43 +++++++- 4 files changed, 372 insertions(+), 56 deletions(-) create mode 100644 srv/src/post/draft_post.go create mode 100644 srv/src/post/draft_post_test.go diff --git a/srv/src/post/draft_post.go b/srv/src/post/draft_post.go new file mode 100644 index 0000000..af52965 --- /dev/null +++ b/srv/src/post/draft_post.go @@ -0,0 +1,186 @@ +package post + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" +) + +type DraftStore interface { + + // Set sets the draft Post's data into the storage, keyed by the draft + // Post's ID. + Set(post Post) error + + // Get returns count draft Posts, sorted id descending, offset by the + // given page number. The returned boolean indicates if there are more pages + // or not. + Get(page, count int) ([]Post, bool, error) + + // GetByID will return the draft Post with the given ID, or ErrPostNotFound. + GetByID(id string) (Post, error) + + // Delete will delete the draft Post with the given ID. + Delete(id string) error +} + +type draftStore struct { + db *SQLDB +} + +// NewDraftStore initializes a new DraftStore using an existing SQLDB. +func NewDraftStore(db *SQLDB) DraftStore { + return &draftStore{ + db: db, + } +} + +func (s *draftStore) Set(post Post) error { + + if post.ID == "" { + return errors.New("post ID can't be empty") + } + + tagsJSON, err := json.Marshal(post.Tags) + if err != nil { + return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err) + } + + _, err = s.db.db.Exec( + `INSERT INTO post_drafts ( + id, title, description, tags, series, body + ) + VALUES + (?, ?, ?, ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + title=excluded.title, + description=excluded.description, + tags=excluded.tags, + series=excluded.series, + body=excluded.body`, + post.ID, + post.Title, + post.Description, + &sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0}, + &sql.NullString{String: post.Series, Valid: post.Series != ""}, + post.Body, + ) + + if err != nil { + return fmt.Errorf("inserting into post_drafts: %w", err) + } + + return nil +} + +func (s *draftStore) get( + querier interface { + Query(string, ...interface{}) (*sql.Rows, error) + }, + limit, offset int, + where string, whereArgs ...interface{}, +) ( + []Post, error, +) { + + query := ` + SELECT + p.id, p.title, p.description, p.tags, p.series, p.body + FROM post_drafts p + ORDER BY p.id ASC` + + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + if offset > 0 { + query += fmt.Sprintf(" OFFSET %d", offset) + } + + rows, err := querier.Query(query, whereArgs...) + + if err != nil { + return nil, fmt.Errorf("selecting: %w", err) + } + + defer rows.Close() + + var posts []Post + + for rows.Next() { + + var ( + post Post + tags, series sql.NullString + ) + + err := rows.Scan( + &post.ID, &post.Title, &post.Description, &tags, &series, + &post.Body, + ) + + if err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + post.Series = series.String + + if tags.String != "" { + + if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil { + return nil, fmt.Errorf("json parsing %q: %w", tags.String, err) + } + } + + posts = append(posts, post) + } + + return posts, nil +} + +func (s *draftStore) Get(page, count int) ([]Post, bool, error) { + + posts, err := s.get(s.db.db, count+1, page*count, ``) + + if err != nil { + return nil, false, fmt.Errorf("querying post_drafts: %w", err) + } + + var hasMore bool + + if len(posts) > count { + hasMore = true + posts = posts[:count] + } + + return posts, hasMore, nil +} + +func (s *draftStore) GetByID(id string) (Post, error) { + + posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) + + if err != nil { + return Post{}, fmt.Errorf("querying post_drafts: %w", err) + } + + if len(posts) == 0 { + return Post{}, ErrPostNotFound + } + + if len(posts) > 1 { + panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts)) + } + + return posts[0], nil +} + +func (s *draftStore) Delete(id string) error { + + if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_drafts: %w", err) + } + + return nil +} diff --git a/srv/src/post/draft_post_test.go b/srv/src/post/draft_post_test.go new file mode 100644 index 0000000..f404bb0 --- /dev/null +++ b/srv/src/post/draft_post_test.go @@ -0,0 +1,130 @@ +package post + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +type draftStoreTestHarness struct { + store DraftStore +} + +func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness { + + db := NewInMemSQLDB() + t.Cleanup(func() { db.Close() }) + + store := NewDraftStore(db) + + return draftStoreTestHarness{ + store: store, + } +} + +func TestDraftStore(t *testing.T) { + + assertPostEqual := func(t *testing.T, exp, got Post) { + t.Helper() + sort.Strings(exp.Tags) + sort.Strings(got.Tags) + assert.Equal(t, exp, got) + } + + assertPostsEqual := func(t *testing.T, exp, got []Post) { + t.Helper() + + if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { + return + } + + for i := range exp { + assertPostEqual(t, exp[i], got[i]) + } + } + + t.Run("not_found", func(t *testing.T) { + h := newDraftStoreTestHarness(t) + + _, err := h.store.GetByID("foo") + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("set_get_delete", func(t *testing.T) { + h := newDraftStoreTestHarness(t) + + post := testPost(0) + post.Tags = []string{"foo", "bar"} + + err := h.store.Set(post) + assert.NoError(t, err) + + gotPost, err := h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, post, gotPost) + + // we will now try updating the post, and ensure it updates properly + + post.Title = "something else" + post.Series = "whatever" + post.Body = "anything" + post.Tags = []string{"bar", "baz"} + + err = h.store.Set(post) + assert.NoError(t, err) + + gotPost, err = h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, post, gotPost) + + // delete the post, it should go away + assert.NoError(t, h.store.Delete(post.ID)) + + _, err = h.store.GetByID(post.ID) + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("get", func(t *testing.T) { + h := newDraftStoreTestHarness(t) + + posts := []Post{ + testPost(0), + testPost(1), + testPost(2), + testPost(3), + } + + for _, post := range posts { + err := h.store.Set(post) + assert.NoError(t, err) + } + + gotPosts, hasMore, err := h.store.Get(0, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[:2], gotPosts) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + posts = append(posts, testPost(4)) + err = h.store.Set(posts[4]) + assert.NoError(t, err) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + gotPosts, hasMore, err = h.store.Get(2, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[4:], gotPosts) + }) + +} diff --git a/srv/src/post/post.go b/srv/src/post/post.go index a39af61..03bce6c 100644 --- a/srv/src/post/post.go +++ b/srv/src/post/post.go @@ -77,44 +77,16 @@ type Store interface { } type store struct { - db *sql.DB + db *SQLDB } // NewStore initializes a new Store using an existing SQLDB. func NewStore(db *SQLDB) Store { return &store{ - db: db.db, + db: db, } } -// if the callback returns an error then the transaction is aborted. -func (s *store) withTx(cb func(*sql.Tx) error) error { - - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if err := cb(tx); err != nil { - - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf( - "rolling back transaction: %w (original error: %v)", - rollbackErr, err, - ) - } - - return fmt.Errorf("performing transaction: %w (rolled back)", err) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} - func (s *store) Set(post Post, now time.Time) (bool, error) { if post.ID == "" { @@ -123,7 +95,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) { var first bool - err := s.withTx(func(tx *sql.Tx) error { + err := s.db.withTx(func(tx *sql.Tx) error { nowTS := now.Unix() @@ -270,7 +242,7 @@ func (s *store) get( func (s *store) Get(page, count int) ([]StoredPost, bool, error) { - posts, err := s.get(s.db, count+1, page*count, ``) + posts, err := s.get(s.db.db, count+1, page*count, ``) if err != nil { return nil, false, fmt.Errorf("querying posts: %w", err) @@ -288,7 +260,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) { func (s *store) GetByID(id string) (StoredPost, error) { - posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) + posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) if err != nil { return StoredPost{}, fmt.Errorf("querying posts: %w", err) @@ -306,14 +278,14 @@ func (s *store) GetByID(id string) (StoredPost, error) { } func (s *store) GetBySeries(series string) ([]StoredPost, error) { - return s.get(s.db, 0, 0, `WHERE p.series=?`, series) + return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series) } func (s *store) GetByTag(tag string) ([]StoredPost, error) { var posts []StoredPost - err := s.withTx(func(tx *sql.Tx) error { + err := s.db.withTx(func(tx *sql.Tx) error { rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) @@ -357,7 +329,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) { func (s *store) GetTags() ([]string, error) { - rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) + rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) if err != nil { return nil, fmt.Errorf("querying all tags: %w", err) } @@ -381,23 +353,16 @@ func (s *store) GetTags() ([]string, error) { func (s *store) Delete(id string) error { - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { - return fmt.Errorf("deleting from post_tags: %w", err) - } + return s.db.withTx(func(tx *sql.Tx) error { - if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { - return fmt.Errorf("deleting from posts: %w", err) - } + if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_tags: %w", err) + } - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } + if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from posts: %w", err) + } - return nil + return nil + }) } diff --git a/srv/src/post/sql.go b/srv/src/post/sql.go index 16cdc95..c768c9a 100644 --- a/srv/src/post/sql.go +++ b/srv/src/post/sql.go @@ -38,10 +38,18 @@ var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration body BLOB NOT NULL )`, }, - Down: []string{ - "DROP TABLE assets", - "DROP TABLE post_tags", - "DROP TABLE posts", + }, + { + Id: "2", + Up: []string{ + `CREATE TABLE post_drafts ( + id TEXT NOT NULL PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + tags TEXT, + series TEXT, + body TEXT NOT NULL + )`, }, }, }} @@ -89,3 +97,30 @@ func NewInMemSQLDB() *SQLDB { func (db *SQLDB) Close() error { return db.db.Close() } + +func (db *SQLDB) withTx(cb func(*sql.Tx) error) error { + + tx, err := db.db.Begin() + + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + if err := cb(tx); err != nil { + + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf( + "rolling back transaction: %w (original error: %v)", + rollbackErr, err, + ) + } + + return fmt.Errorf("performing transaction: %w (rolled back)", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +}