Implement post.Store

This commit is contained in:
Brian Picciano 2022-05-06 17:22:17 -06:00
parent d8b12cf17a
commit f7d72adfb5
5 changed files with 713 additions and 39 deletions

View File

@ -83,7 +83,8 @@ type store struct {
db *sql.DB db *sql.DB
} }
// NewStore initializes a new store using the given SQL DB instance. // NewStore initializes a new Store using a sqlite3 database at the given file
// path.
func NewStore(dbFile string) (Store, error) { func NewStore(dbFile string) (Store, error) {
db, err := sql.Open("sqlite3", dbFile) db, err := sql.Open("sqlite3", dbFile)

59
srv/src/post/date.go Normal file
View File

@ -0,0 +1,59 @@
package post
import (
"database/sql/driver"
"fmt"
"time"
)
// Date represents a calendar date with no timezone information attached.
type Date struct {
Year int
Month time.Month
Day int
}
// DateFromTime converts a Time into a Date, truncating all non-date
// information.
func DateFromTime(t time.Time) Date {
t = t.UTC()
return Date{
Year: t.Year(),
Month: t.Month(),
Day: t.Day(),
}
}
// ToTime converts a Date into a Time. The returned time will be UTC midnight of
// the Date.
func (d *Date) ToTime() time.Time {
return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, time.UTC)
}
// Scan implements the sql.Scanner interface.
func (d *Date) Scan(src interface{}) error {
if src == nil {
*d = Date{}
return nil
}
ts, ok := src.(int64)
if !ok {
return fmt.Errorf("cannot scan value %#v into Date", src)
}
*d = DateFromTime(time.Unix(ts, 0))
return nil
}
// Value implements the driver.Valuer interface.
func (d Date) Value() (driver.Value, error) {
if d == (Date{}) {
return nil, nil
}
return d.ToTime().Unix(), nil
}

View File

@ -2,30 +2,10 @@
package post package post
import ( import (
"fmt"
"path"
"regexp" "regexp"
"strings" "strings"
"time"
) )
// Date represents a calendar date with no timezone information attached.
type Date struct {
Year int
Month time.Month
Day int
}
// DateFromTime converts a Time into a Date, truncating all non-date
// information.
func DateFromTime(t time.Time) Date {
return Date{
Year: t.Year(),
Month: t.Month(),
Day: t.Day(),
}
}
var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`) var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`)
// NewID generates a (hopefully) unique ID based on the given title. // NewID generates a (hopefully) unique ID based on the given title.
@ -43,22 +23,5 @@ type Post struct {
Description string Description string
Tags []string Tags []string
Series string Series string
Body string
PublishedAt Date
LastUpdatedAt Date
Body string
}
// URL returns the relative URL of the Post.
func (p Post) URL() string {
return path.Join(
fmt.Sprintf(
"%d/%0d/%0d",
p.PublishedAt.Year,
p.PublishedAt.Month,
p.PublishedAt.Day,
),
p.ID+".html",
)
} }

View File

@ -1 +1,404 @@
package post package post
import (
"database/sql"
"errors"
"fmt"
"path"
"strings"
_ "github.com/mattn/go-sqlite3"
migrate "github.com/rubenv/sql-migrate"
"github.com/tilinna/clock"
)
var (
// ErrNotFound is used to indicate a Post could not be found in the
// database.
ErrNotFound = errors.New("not found")
)
// StoredPost is a Post which has been stored in a Store, and has been given
// some extra fields as a result.
type StoredPost struct {
Post
PublishedAt Date
LastUpdatedAt Date
}
// URL returns the relative URL of the StoredPost.
func (p StoredPost) URL() string {
return path.Join(
fmt.Sprintf(
"%d/%0d/%0d",
p.PublishedAt.Year,
p.PublishedAt.Month,
p.PublishedAt.Day,
),
p.ID+".html",
)
}
// Store is used for storing posts to a persistent storage.
type Store interface {
// Set sets the Post data into the storage, keyed by the Post's ID. It
// overwrites a previous Post with the same ID, if there was one.
Set(Post) error
// Get returns count StoredPosts, sorted time descending, offset by the given page
// number. The returned boolean indicates if there are more pages or not.
Get(page, count int) ([]StoredPost, bool, error)
// GetByID will return the StoredPost with the given ID, or ErrNotFound.
GetByID(id string) (StoredPost, error)
// GetBySeries returns all StoredPosts with the given series, sorted time
// ascending, or empty slice.
GetBySeries(series string) ([]StoredPost, error)
// GetByTag returns all StoredPosts with the given tag, sorted time
// ascending, or empty slice.
GetByTag(tag string) ([]StoredPost, error)
// Delete will delete the StoredPost with the given ID.
Delete(id string) error
Close() error
}
var migrations = []*migrate.Migration{
&migrate.Migration{
Id: "1",
Up: []string{
`CREATE TABLE posts (
id TEXT NOT NULL PRIMARY KEY,
title TEXT NOT NULL,
description TEXT NOT NULL,
series TEXT,
published_at INTEGER NOT NULL,
last_updated_at INTEGER,
body TEXT NOT NULL
)`,
`CREATE TABLE post_tags (
post_id TEXT NOT NULL,
tag TEXT NOT NULL,
UNIQUE(post_id, tag)
)`,
},
Down: []string{
"DROP TABLE post_tags",
"DROP TABLE posts",
},
},
}
// Params are parameters used to initialize a new Store. All fields are required
// unless otherwise noted.
type StoreParams struct {
// Path to the file the database will be stored at.
DBFilePath string
Clock clock.Clock
}
type store struct {
params StoreParams
db *sql.DB
}
// NewStore initializes a new Store using a sqlite3 database at the given file
// path.
func NewStore(params StoreParams) (Store, error) {
db, err := sql.Open("sqlite3", params.DBFilePath)
if err != nil {
return nil, fmt.Errorf("opening sqlite file: %w", err)
}
migrations := &migrate.MemoryMigrationSource{Migrations: migrations}
if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
return nil, fmt.Errorf("running migrations: %w", err)
}
return &store{
params: params,
db: db,
}, nil
}
func (s *store) Close() error {
return s.db.Close()
}
// if the callback returns an error then the transaction is aborted
func (s *store) withTx(cb func(*sql.Tx) error) error {
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
if err := cb(tx); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf(
"rolling back transaction: %w (original error: %v)",
rollbackErr, err,
)
}
return fmt.Errorf("performing transaction: %w (rolled back)", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}
func (s *store) Set(post Post) error {
return s.withTx(func(tx *sql.Tx) error {
currentDate := DateFromTime(s.params.Clock.Now())
_, err := tx.Exec(
`INSERT INTO posts (
id, title, description, series, published_at, body
)
VALUES
(?, ?, ?, ?, ?, ?)
ON CONFLICT (id) DO UPDATE SET
title=excluded.title,
description=excluded.description,
series=excluded.series,
last_updated_at=?,
body=excluded.body`,
post.ID,
post.Title,
post.Description,
&sql.NullString{String: post.Series, Valid: post.Series != ""},
currentDate,
post.Body,
currentDate,
)
if err != nil {
return fmt.Errorf("inserting into posts: %w", err)
}
// this is a bit of a hack, but it allows us to update the tagset without
// doing a diff.
_, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID)
if err != nil {
return fmt.Errorf("clearning post tags: %w", err)
}
for _, tag := range post.Tags {
_, err = tx.Exec(
`INSERT INTO post_tags (post_id, tag) VALUES (?, ?)
ON CONFLICT DO NOTHING`,
post.ID,
tag,
)
if err != nil {
return fmt.Errorf("inserting tag %q: %w", tag, err)
}
}
return nil
})
}
func (s *store) get(
querier interface {
Query(string, ...interface{}) (*sql.Rows, error)
},
limit, offset int,
where string, whereArgs ...interface{},
) (
[]StoredPost, error,
) {
query := `SELECT
p.id, p.title, p.description, p.series, pt.tag,
p.published_at, p.last_updated_at,
p.body
FROM posts p
LEFT JOIN post_tags pt ON (p.id = pt.post_id)
` + where + `
ORDER BY p.published_at ASC, p.title ASC`
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
}
if offset > 0 {
query += fmt.Sprintf(" OFFSET %d", offset)
}
rows, err := querier.Query(query, whereArgs...)
if err != nil {
return nil, fmt.Errorf("selecting: %w", err)
}
var posts []StoredPost
for rows.Next() {
var (
post StoredPost
series, tag sql.NullString
)
err := rows.Scan(
&post.ID, &post.Title, &post.Description, &series, &tag,
&post.PublishedAt, &post.LastUpdatedAt,
&post.Body,
)
if err != nil {
return nil, fmt.Errorf("scanning row: %w", err)
}
if tag.Valid {
if l := len(posts); l > 0 && posts[l-1].ID == post.ID {
posts[l-1].Tags = append(posts[l-1].Tags, tag.String)
continue
}
post.Tags = append(post.Tags, tag.String)
}
post.Series = series.String
posts = append(posts, post)
}
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("closing row iterator: %w", err)
}
return posts, nil
}
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
posts, err := s.get(s.db, count+1, page*count, ``)
if err != nil {
return nil, false, fmt.Errorf("querying posts: %w", err)
}
var hasMore bool
if len(posts) > count {
hasMore = true
posts = posts[:count]
}
return posts, hasMore, nil
}
func (s *store) GetByID(id string) (StoredPost, error) {
posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
if err != nil {
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
}
if len(posts) == 0 {
return StoredPost{}, ErrNotFound
}
if len(posts) > 1 {
panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts))
}
return posts[0], nil
}
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
}
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
var posts []StoredPost
err := s.withTx(func(tx *sql.Tx) error {
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
if err != nil {
return fmt.Errorf("querying post_tags by tag: %w", err)
}
var (
placeholders []string
whereArgs []interface{}
)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
rows.Close()
return fmt.Errorf("scanning id: %w", err)
}
whereArgs = append(whereArgs, id)
placeholders = append(placeholders, "?")
}
if err := rows.Close(); err != nil {
return fmt.Errorf("closing row iterator: %w", err)
}
where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ","))
if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil {
return fmt.Errorf("querying for ids %+v: %w", whereArgs, err)
}
return nil
})
return posts, err
}
func (s *store) Delete(id string) error {
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
return fmt.Errorf("deleting from post_tags: %w", err)
}
if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
return fmt.Errorf("deleting from posts: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}

248
srv/src/post/store_test.go Normal file
View File

@ -0,0 +1,248 @@
package post
import (
"io/ioutil"
"os"
"sort"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tilinna/clock"
)
func testPost(i int) Post {
istr := strconv.Itoa(i)
return Post{
ID: istr,
Title: istr,
Description: istr,
Body: istr,
}
}
type storeTestHarness struct {
clock *clock.Mock
store Store
}
func newStoreTestHarness(t *testing.T) storeTestHarness {
clock := clock.NewMock(time.Now().Truncate(1 * time.Hour))
tmpFile, err := ioutil.TempFile(os.TempDir(), "mediocre-blog-post-store-test-")
if err != nil {
t.Fatal("Cannot create temporary file", err)
}
tmpFilePath := tmpFile.Name()
tmpFile.Close()
t.Logf("using temporary sqlite file at %q", tmpFilePath)
t.Cleanup(func() {
if err := os.Remove(tmpFilePath); err != nil {
panic(err)
}
})
store, err := NewStore(StoreParams{
DBFilePath: tmpFilePath,
Clock: clock,
})
assert.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, store.Close())
})
return storeTestHarness{
clock: clock,
store: store,
}
}
func (h *storeTestHarness) testStoredPost(i int) StoredPost {
post := testPost(i)
return StoredPost{
Post: post,
PublishedAt: DateFromTime(h.clock.Now()),
}
}
func TestStore(t *testing.T) {
assertPostEqual := func(t *testing.T, exp, got StoredPost) {
t.Helper()
sort.Strings(exp.Tags)
sort.Strings(got.Tags)
assert.Equal(t, exp, got)
}
assertPostsEqual := func(t *testing.T, exp, got []StoredPost) {
t.Helper()
if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
return
}
for i := range exp {
assertPostEqual(t, exp[i], got[i])
}
}
t.Run("not_found", func(t *testing.T) {
h := newStoreTestHarness(t)
_, err := h.store.GetByID("foo")
assert.ErrorIs(t, err, ErrNotFound)
})
t.Run("set_get_delete", func(t *testing.T) {
h := newStoreTestHarness(t)
nowDate := DateFromTime(h.clock.Now())
post := testPost(0)
post.Tags = []string{"foo", "bar"}
assert.NoError(t, h.store.Set(post))
gotPost, err := h.store.GetByID(post.ID)
assert.NoError(t, err)
assertPostEqual(t, StoredPost{
Post: post,
PublishedAt: nowDate,
}, gotPost)
// we will now try updating the post on a different day, and ensure it
// updates properly
h.clock.Add(24 * time.Hour)
newNowDate := DateFromTime(h.clock.Now())
post.Title = "something else"
post.Series = "whatever"
post.Body = "anything"
post.Tags = []string{"bar", "baz"}
assert.NoError(t, h.store.Set(post))
gotPost, err = h.store.GetByID(post.ID)
assert.NoError(t, err)
assertPostEqual(t, StoredPost{
Post: post,
PublishedAt: nowDate,
LastUpdatedAt: newNowDate,
}, gotPost)
// delete the post, it should go away
assert.NoError(t, h.store.Delete(post.ID))
_, err = h.store.GetByID(post.ID)
assert.ErrorIs(t, err, ErrNotFound)
})
t.Run("get", func(t *testing.T) {
h := newStoreTestHarness(t)
posts := []StoredPost{
h.testStoredPost(0),
h.testStoredPost(1),
h.testStoredPost(2),
h.testStoredPost(3),
}
for _, post := range posts {
assert.NoError(t, h.store.Set(post.Post))
}
gotPosts, hasMore, err := h.store.Get(0, 2)
assert.NoError(t, err)
assert.True(t, hasMore)
assertPostsEqual(t, posts[:2], gotPosts)
gotPosts, hasMore, err = h.store.Get(1, 2)
assert.NoError(t, err)
assert.False(t, hasMore)
assertPostsEqual(t, posts[2:4], gotPosts)
posts = append(posts, h.testStoredPost(4))
assert.NoError(t, h.store.Set(posts[4].Post))
gotPosts, hasMore, err = h.store.Get(1, 2)
assert.NoError(t, err)
assert.True(t, hasMore)
assertPostsEqual(t, posts[2:4], gotPosts)
gotPosts, hasMore, err = h.store.Get(2, 2)
assert.NoError(t, err)
assert.False(t, hasMore)
assertPostsEqual(t, posts[4:], gotPosts)
})
t.Run("get_by_series", func(t *testing.T) {
h := newStoreTestHarness(t)
posts := []StoredPost{
h.testStoredPost(0),
h.testStoredPost(1),
h.testStoredPost(2),
h.testStoredPost(3),
}
posts[0].Series = "foo"
posts[1].Series = "bar"
posts[2].Series = "bar"
for _, post := range posts {
assert.NoError(t, h.store.Set(post.Post))
}
fooPosts, err := h.store.GetBySeries("foo")
assert.NoError(t, err)
assertPostsEqual(t, posts[:1], fooPosts)
barPosts, err := h.store.GetBySeries("bar")
assert.NoError(t, err)
assertPostsEqual(t, posts[1:3], barPosts)
bazPosts, err := h.store.GetBySeries("baz")
assert.NoError(t, err)
assert.Empty(t, bazPosts)
})
t.Run("get_by_tag", func(t *testing.T) {
h := newStoreTestHarness(t)
posts := []StoredPost{
h.testStoredPost(0),
h.testStoredPost(1),
h.testStoredPost(2),
h.testStoredPost(3),
}
posts[0].Tags = []string{"foo"}
posts[1].Tags = []string{"foo", "bar"}
posts[2].Tags = []string{"bar"}
for _, post := range posts {
assert.NoError(t, h.store.Set(post.Post))
}
fooPosts, err := h.store.GetByTag("foo")
assert.NoError(t, err)
assertPostsEqual(t, posts[:2], fooPosts)
barPosts, err := h.store.GetByTag("bar")
assert.NoError(t, err)
assertPostsEqual(t, posts[1:3], barPosts)
bazPosts, err := h.store.GetByTag("baz")
assert.NoError(t, err)
assert.Empty(t, bazPosts)
})
}