implement DraftStore
This commit is contained in:
parent
76ff79f470
commit
7ac2f5ebb3
186
srv/src/post/draft_post.go
Normal file
186
srv/src/post/draft_post.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
package post
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DraftStore interface {
|
||||||
|
|
||||||
|
// Set sets the draft Post's data into the storage, keyed by the draft
|
||||||
|
// Post's ID.
|
||||||
|
Set(post Post) error
|
||||||
|
|
||||||
|
// Get returns count draft Posts, sorted id descending, offset by the
|
||||||
|
// given page number. The returned boolean indicates if there are more pages
|
||||||
|
// or not.
|
||||||
|
Get(page, count int) ([]Post, bool, error)
|
||||||
|
|
||||||
|
// GetByID will return the draft Post with the given ID, or ErrPostNotFound.
|
||||||
|
GetByID(id string) (Post, error)
|
||||||
|
|
||||||
|
// Delete will delete the draft Post with the given ID.
|
||||||
|
Delete(id string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type draftStore struct {
|
||||||
|
db *SQLDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDraftStore initializes a new DraftStore using an existing SQLDB.
|
||||||
|
func NewDraftStore(db *SQLDB) DraftStore {
|
||||||
|
return &draftStore{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *draftStore) Set(post Post) error {
|
||||||
|
|
||||||
|
if post.ID == "" {
|
||||||
|
return errors.New("post ID can't be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
tagsJSON, err := json.Marshal(post.Tags)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.db.db.Exec(
|
||||||
|
`INSERT INTO post_drafts (
|
||||||
|
id, title, description, tags, series, body
|
||||||
|
)
|
||||||
|
VALUES
|
||||||
|
(?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT (id) DO UPDATE SET
|
||||||
|
title=excluded.title,
|
||||||
|
description=excluded.description,
|
||||||
|
tags=excluded.tags,
|
||||||
|
series=excluded.series,
|
||||||
|
body=excluded.body`,
|
||||||
|
post.ID,
|
||||||
|
post.Title,
|
||||||
|
post.Description,
|
||||||
|
&sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0},
|
||||||
|
&sql.NullString{String: post.Series, Valid: post.Series != ""},
|
||||||
|
post.Body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting into post_drafts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *draftStore) get(
|
||||||
|
querier interface {
|
||||||
|
Query(string, ...interface{}) (*sql.Rows, error)
|
||||||
|
},
|
||||||
|
limit, offset int,
|
||||||
|
where string, whereArgs ...interface{},
|
||||||
|
) (
|
||||||
|
[]Post, error,
|
||||||
|
) {
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
p.id, p.title, p.description, p.tags, p.series, p.body
|
||||||
|
FROM post_drafts p
|
||||||
|
ORDER BY p.id ASC`
|
||||||
|
|
||||||
|
if limit > 0 {
|
||||||
|
query += fmt.Sprintf(" LIMIT %d", limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset > 0 {
|
||||||
|
query += fmt.Sprintf(" OFFSET %d", offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := querier.Query(query, whereArgs...)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("selecting: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var posts []Post
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
|
||||||
|
var (
|
||||||
|
post Post
|
||||||
|
tags, series sql.NullString
|
||||||
|
)
|
||||||
|
|
||||||
|
err := rows.Scan(
|
||||||
|
&post.ID, &post.Title, &post.Description, &tags, &series,
|
||||||
|
&post.Body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
post.Series = series.String
|
||||||
|
|
||||||
|
if tags.String != "" {
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil {
|
||||||
|
return nil, fmt.Errorf("json parsing %q: %w", tags.String, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
posts = append(posts, post)
|
||||||
|
}
|
||||||
|
|
||||||
|
return posts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *draftStore) Get(page, count int) ([]Post, bool, error) {
|
||||||
|
|
||||||
|
posts, err := s.get(s.db.db, count+1, page*count, ``)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("querying post_drafts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var hasMore bool
|
||||||
|
|
||||||
|
if len(posts) > count {
|
||||||
|
hasMore = true
|
||||||
|
posts = posts[:count]
|
||||||
|
}
|
||||||
|
|
||||||
|
return posts, hasMore, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *draftStore) GetByID(id string) (Post, error) {
|
||||||
|
|
||||||
|
posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Post{}, fmt.Errorf("querying post_drafts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(posts) == 0 {
|
||||||
|
return Post{}, ErrPostNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(posts) > 1 {
|
||||||
|
panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts))
|
||||||
|
}
|
||||||
|
|
||||||
|
return posts[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *draftStore) Delete(id string) error {
|
||||||
|
|
||||||
|
if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil {
|
||||||
|
return fmt.Errorf("deleting from post_drafts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
130
srv/src/post/draft_post_test.go
Normal file
130
srv/src/post/draft_post_test.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
package post
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type draftStoreTestHarness struct {
|
||||||
|
store DraftStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness {
|
||||||
|
|
||||||
|
db := NewInMemSQLDB()
|
||||||
|
t.Cleanup(func() { db.Close() })
|
||||||
|
|
||||||
|
store := NewDraftStore(db)
|
||||||
|
|
||||||
|
return draftStoreTestHarness{
|
||||||
|
store: store,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDraftStore(t *testing.T) {
|
||||||
|
|
||||||
|
assertPostEqual := func(t *testing.T, exp, got Post) {
|
||||||
|
t.Helper()
|
||||||
|
sort.Strings(exp.Tags)
|
||||||
|
sort.Strings(got.Tags)
|
||||||
|
assert.Equal(t, exp, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertPostsEqual := func(t *testing.T, exp, got []Post) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range exp {
|
||||||
|
assertPostEqual(t, exp[i], got[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("not_found", func(t *testing.T) {
|
||||||
|
h := newDraftStoreTestHarness(t)
|
||||||
|
|
||||||
|
_, err := h.store.GetByID("foo")
|
||||||
|
assert.ErrorIs(t, err, ErrPostNotFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("set_get_delete", func(t *testing.T) {
|
||||||
|
h := newDraftStoreTestHarness(t)
|
||||||
|
|
||||||
|
post := testPost(0)
|
||||||
|
post.Tags = []string{"foo", "bar"}
|
||||||
|
|
||||||
|
err := h.store.Set(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
gotPost, err := h.store.GetByID(post.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assertPostEqual(t, post, gotPost)
|
||||||
|
|
||||||
|
// we will now try updating the post, and ensure it updates properly
|
||||||
|
|
||||||
|
post.Title = "something else"
|
||||||
|
post.Series = "whatever"
|
||||||
|
post.Body = "anything"
|
||||||
|
post.Tags = []string{"bar", "baz"}
|
||||||
|
|
||||||
|
err = h.store.Set(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
gotPost, err = h.store.GetByID(post.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assertPostEqual(t, post, gotPost)
|
||||||
|
|
||||||
|
// delete the post, it should go away
|
||||||
|
assert.NoError(t, h.store.Delete(post.ID))
|
||||||
|
|
||||||
|
_, err = h.store.GetByID(post.ID)
|
||||||
|
assert.ErrorIs(t, err, ErrPostNotFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get", func(t *testing.T) {
|
||||||
|
h := newDraftStoreTestHarness(t)
|
||||||
|
|
||||||
|
posts := []Post{
|
||||||
|
testPost(0),
|
||||||
|
testPost(1),
|
||||||
|
testPost(2),
|
||||||
|
testPost(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, post := range posts {
|
||||||
|
err := h.store.Set(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotPosts, hasMore, err := h.store.Get(0, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[:2], gotPosts)
|
||||||
|
|
||||||
|
gotPosts, hasMore, err = h.store.Get(1, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[2:4], gotPosts)
|
||||||
|
|
||||||
|
posts = append(posts, testPost(4))
|
||||||
|
err = h.store.Set(posts[4])
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
gotPosts, hasMore, err = h.store.Get(1, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[2:4], gotPosts)
|
||||||
|
|
||||||
|
gotPosts, hasMore, err = h.store.Get(2, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[4:], gotPosts)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
@ -77,44 +77,16 @@ type Store interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type store struct {
|
type store struct {
|
||||||
db *sql.DB
|
db *SQLDB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStore initializes a new Store using an existing SQLDB.
|
// NewStore initializes a new Store using an existing SQLDB.
|
||||||
func NewStore(db *SQLDB) Store {
|
func NewStore(db *SQLDB) Store {
|
||||||
return &store{
|
return &store{
|
||||||
db: db.db,
|
db: db,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the callback returns an error then the transaction is aborted.
|
|
||||||
func (s *store) withTx(cb func(*sql.Tx) error) error {
|
|
||||||
|
|
||||||
tx, err := s.db.Begin()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("starting transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cb(tx); err != nil {
|
|
||||||
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
return fmt.Errorf(
|
|
||||||
"rolling back transaction: %w (original error: %v)",
|
|
||||||
rollbackErr, err,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("performing transaction: %w (rolled back)", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("committing transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) Set(post Post, now time.Time) (bool, error) {
|
func (s *store) Set(post Post, now time.Time) (bool, error) {
|
||||||
|
|
||||||
if post.ID == "" {
|
if post.ID == "" {
|
||||||
@ -123,7 +95,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) {
|
|||||||
|
|
||||||
var first bool
|
var first bool
|
||||||
|
|
||||||
err := s.withTx(func(tx *sql.Tx) error {
|
err := s.db.withTx(func(tx *sql.Tx) error {
|
||||||
|
|
||||||
nowTS := now.Unix()
|
nowTS := now.Unix()
|
||||||
|
|
||||||
@ -270,7 +242,7 @@ func (s *store) get(
|
|||||||
|
|
||||||
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
|
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
|
||||||
|
|
||||||
posts, err := s.get(s.db, count+1, page*count, ``)
|
posts, err := s.get(s.db.db, count+1, page*count, ``)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("querying posts: %w", err)
|
return nil, false, fmt.Errorf("querying posts: %w", err)
|
||||||
@ -288,7 +260,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
|
|||||||
|
|
||||||
func (s *store) GetByID(id string) (StoredPost, error) {
|
func (s *store) GetByID(id string) (StoredPost, error) {
|
||||||
|
|
||||||
posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
|
posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
|
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
|
||||||
@ -306,14 +278,14 @@ func (s *store) GetByID(id string) (StoredPost, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
|
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
|
||||||
return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
|
return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
|
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
|
||||||
|
|
||||||
var posts []StoredPost
|
var posts []StoredPost
|
||||||
|
|
||||||
err := s.withTx(func(tx *sql.Tx) error {
|
err := s.db.withTx(func(tx *sql.Tx) error {
|
||||||
|
|
||||||
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
|
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
|
||||||
|
|
||||||
@ -357,7 +329,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) {
|
|||||||
|
|
||||||
func (s *store) GetTags() ([]string, error) {
|
func (s *store) GetTags() ([]string, error) {
|
||||||
|
|
||||||
rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
|
rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("querying all tags: %w", err)
|
return nil, fmt.Errorf("querying all tags: %w", err)
|
||||||
}
|
}
|
||||||
@ -381,23 +353,16 @@ func (s *store) GetTags() ([]string, error) {
|
|||||||
|
|
||||||
func (s *store) Delete(id string) error {
|
func (s *store) Delete(id string) error {
|
||||||
|
|
||||||
tx, err := s.db.Begin()
|
return s.db.withTx(func(tx *sql.Tx) error {
|
||||||
|
|
||||||
if err != nil {
|
if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
|
||||||
return fmt.Errorf("starting transaction: %w", err)
|
return fmt.Errorf("deleting from post_tags: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
|
if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
|
||||||
return fmt.Errorf("deleting from post_tags: %w", err)
|
return fmt.Errorf("deleting from posts: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
|
return nil
|
||||||
return fmt.Errorf("deleting from posts: %w", err)
|
})
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("committing transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -38,10 +38,18 @@ var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration
|
|||||||
body BLOB NOT NULL
|
body BLOB NOT NULL
|
||||||
)`,
|
)`,
|
||||||
},
|
},
|
||||||
Down: []string{
|
},
|
||||||
"DROP TABLE assets",
|
{
|
||||||
"DROP TABLE post_tags",
|
Id: "2",
|
||||||
"DROP TABLE posts",
|
Up: []string{
|
||||||
|
`CREATE TABLE post_drafts (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
description TEXT NOT NULL,
|
||||||
|
tags TEXT,
|
||||||
|
series TEXT,
|
||||||
|
body TEXT NOT NULL
|
||||||
|
)`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
@ -89,3 +97,30 @@ func NewInMemSQLDB() *SQLDB {
|
|||||||
func (db *SQLDB) Close() error {
|
func (db *SQLDB) Close() error {
|
||||||
return db.db.Close()
|
return db.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *SQLDB) withTx(cb func(*sql.Tx) error) error {
|
||||||
|
|
||||||
|
tx, err := db.db.Begin()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cb(tx); err != nil {
|
||||||
|
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"rolling back transaction: %w (original error: %v)",
|
||||||
|
rollbackErr, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("performing transaction: %w (rolled back)", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user