Cleanup various small issues with post package
This commit is contained in:
parent
07806c6942
commit
cfb633b3b5
@ -1,59 +0,0 @@
|
|||||||
package post
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Date represents a calendar date with no timezone information attached.
|
|
||||||
type Date struct {
|
|
||||||
Year int
|
|
||||||
Month time.Month
|
|
||||||
Day int
|
|
||||||
}
|
|
||||||
|
|
||||||
// DateFromTime converts a Time into a Date, truncating all non-date
|
|
||||||
// information.
|
|
||||||
func DateFromTime(t time.Time) Date {
|
|
||||||
t = t.UTC()
|
|
||||||
return Date{
|
|
||||||
Year: t.Year(),
|
|
||||||
Month: t.Month(),
|
|
||||||
Day: t.Day(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToTime converts a Date into a Time. The returned time will be UTC midnight of
|
|
||||||
// the Date.
|
|
||||||
func (d *Date) ToTime() time.Time {
|
|
||||||
return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, time.UTC)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the sql.Scanner interface.
|
|
||||||
func (d *Date) Scan(src interface{}) error {
|
|
||||||
|
|
||||||
if src == nil {
|
|
||||||
*d = Date{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ts, ok := src.(int64)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("cannot scan value %#v into Date", src)
|
|
||||||
}
|
|
||||||
|
|
||||||
*d = DateFromTime(time.Unix(ts, 0))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value implements the driver.Valuer interface.
|
|
||||||
func (d Date) Value() (driver.Value, error) {
|
|
||||||
|
|
||||||
if d == (Date{}) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.ToTime().Unix(), nil
|
|
||||||
}
|
|
@ -1,9 +1,20 @@
|
|||||||
// Package post deals with the storage and rending of blog post.
|
// Package post deals with the storage and rendering of blog posts.
|
||||||
package post
|
package post
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrPostNotFound is used to indicate a Post could not be found in the
|
||||||
|
// Store.
|
||||||
|
ErrPostNotFound = errors.New("post not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`)
|
var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`)
|
||||||
@ -25,3 +36,342 @@ type Post struct {
|
|||||||
Series string
|
Series string
|
||||||
Body string
|
Body string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StoredPost is a Post which has been stored in a Store, and has been given
|
||||||
|
// some extra fields as a result.
|
||||||
|
type StoredPost struct {
|
||||||
|
Post
|
||||||
|
|
||||||
|
PublishedAt time.Time
|
||||||
|
LastUpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// URL returns the relative URL of the StoredPost.
|
||||||
|
func (p StoredPost) URL() string {
|
||||||
|
return path.Join(
|
||||||
|
fmt.Sprintf(
|
||||||
|
"%d/%0d/%0d",
|
||||||
|
p.PublishedAt.Year(),
|
||||||
|
p.PublishedAt.Month(),
|
||||||
|
p.PublishedAt.Day(),
|
||||||
|
),
|
||||||
|
p.ID+".html",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store is used for storing posts to a persistent storage.
|
||||||
|
type Store interface {
|
||||||
|
|
||||||
|
// Set sets the Post data into the storage, keyed by the Post's ID. It
|
||||||
|
// overwrites a previous Post with the same ID, if there was one.
|
||||||
|
Set(post Post, now time.Time) error
|
||||||
|
|
||||||
|
// Get returns count StoredPosts, sorted time descending, offset by the given page
|
||||||
|
// number. The returned boolean indicates if there are more pages or not.
|
||||||
|
Get(page, count int) ([]StoredPost, bool, error)
|
||||||
|
|
||||||
|
// GetByID will return the StoredPost with the given ID, or ErrPostNotFound.
|
||||||
|
GetByID(id string) (StoredPost, error)
|
||||||
|
|
||||||
|
// GetBySeries returns all StoredPosts with the given series, sorted time
|
||||||
|
// ascending, or empty slice.
|
||||||
|
GetBySeries(series string) ([]StoredPost, error)
|
||||||
|
|
||||||
|
// GetByTag returns all StoredPosts with the given tag, sorted time
|
||||||
|
// ascending, or empty slice.
|
||||||
|
GetByTag(tag string) ([]StoredPost, error)
|
||||||
|
|
||||||
|
// Delete will delete the StoredPost with the given ID.
|
||||||
|
Delete(id string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type store struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStore initializes a new Store using an existing SQLDB.
|
||||||
|
func NewStore(db *SQLDB) Store {
|
||||||
|
return &store{
|
||||||
|
db: db.db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the callback returns an error then the transaction is aborted.
|
||||||
|
func (s *store) withTx(cb func(*sql.Tx) error) error {
|
||||||
|
|
||||||
|
tx, err := s.db.Begin()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cb(tx); err != nil {
|
||||||
|
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"rolling back transaction: %w (original error: %v)",
|
||||||
|
rollbackErr, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("performing transaction: %w (rolled back)", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *store) Set(post Post, now time.Time) error {
|
||||||
|
return s.withTx(func(tx *sql.Tx) error {
|
||||||
|
|
||||||
|
nowTS := now.Unix()
|
||||||
|
|
||||||
|
nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()}
|
||||||
|
|
||||||
|
_, err := tx.Exec(
|
||||||
|
`INSERT INTO posts (
|
||||||
|
id, title, description, series, published_at, body
|
||||||
|
)
|
||||||
|
VALUES
|
||||||
|
(?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT (id) DO UPDATE SET
|
||||||
|
title=excluded.title,
|
||||||
|
description=excluded.description,
|
||||||
|
series=excluded.series,
|
||||||
|
last_updated_at=?,
|
||||||
|
body=excluded.body`,
|
||||||
|
post.ID,
|
||||||
|
post.Title,
|
||||||
|
post.Description,
|
||||||
|
&sql.NullString{String: post.Series, Valid: post.Series != ""},
|
||||||
|
nowSQL,
|
||||||
|
post.Body,
|
||||||
|
nowSQL,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting into posts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is a bit of a hack, but it allows us to update the tagset without
|
||||||
|
// doing a diff.
|
||||||
|
_, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("clearning post tags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tag := range post.Tags {
|
||||||
|
|
||||||
|
_, err = tx.Exec(
|
||||||
|
`INSERT INTO post_tags (post_id, tag) VALUES (?, ?)
|
||||||
|
ON CONFLICT DO NOTHING`,
|
||||||
|
post.ID,
|
||||||
|
tag,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting tag %q: %w", tag, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *store) get(
|
||||||
|
querier interface {
|
||||||
|
Query(string, ...interface{}) (*sql.Rows, error)
|
||||||
|
},
|
||||||
|
limit, offset int,
|
||||||
|
where string, whereArgs ...interface{},
|
||||||
|
) (
|
||||||
|
[]StoredPost, error,
|
||||||
|
) {
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
p.id, p.title, p.description, p.series, pt.tag,
|
||||||
|
p.published_at, p.last_updated_at,
|
||||||
|
p.body
|
||||||
|
FROM posts p
|
||||||
|
LEFT JOIN post_tags pt ON (p.id = pt.post_id)
|
||||||
|
` + where + `
|
||||||
|
ORDER BY p.published_at ASC, p.title ASC`
|
||||||
|
|
||||||
|
if limit > 0 {
|
||||||
|
query += fmt.Sprintf(" LIMIT %d", limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset > 0 {
|
||||||
|
query += fmt.Sprintf(" OFFSET %d", offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := querier.Query(query, whereArgs...)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("selecting: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var posts []StoredPost
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
|
||||||
|
var (
|
||||||
|
post StoredPost
|
||||||
|
series, tag sql.NullString
|
||||||
|
publishedAt, lastUpdatedAt sql.NullInt64
|
||||||
|
)
|
||||||
|
|
||||||
|
err := rows.Scan(
|
||||||
|
&post.ID, &post.Title, &post.Description, &series, &tag,
|
||||||
|
&publishedAt, &lastUpdatedAt,
|
||||||
|
&post.Body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tag.Valid {
|
||||||
|
|
||||||
|
if l := len(posts); l > 0 && posts[l-1].ID == post.ID {
|
||||||
|
posts[l-1].Tags = append(posts[l-1].Tags, tag.String)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
post.Tags = append(post.Tags, tag.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
post.Series = series.String
|
||||||
|
|
||||||
|
if publishedAt.Valid {
|
||||||
|
post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastUpdatedAt.Valid {
|
||||||
|
post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
posts = append(posts, post)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, fmt.Errorf("closing row iterator: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return posts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
|
||||||
|
|
||||||
|
posts, err := s.get(s.db, count+1, page*count, ``)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("querying posts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var hasMore bool
|
||||||
|
|
||||||
|
if len(posts) > count {
|
||||||
|
hasMore = true
|
||||||
|
posts = posts[:count]
|
||||||
|
}
|
||||||
|
|
||||||
|
return posts, hasMore, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *store) GetByID(id string) (StoredPost, error) {
|
||||||
|
|
||||||
|
posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(posts) == 0 {
|
||||||
|
return StoredPost{}, ErrPostNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(posts) > 1 {
|
||||||
|
panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts))
|
||||||
|
}
|
||||||
|
|
||||||
|
return posts[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
|
||||||
|
return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
|
||||||
|
|
||||||
|
var posts []StoredPost
|
||||||
|
|
||||||
|
err := s.withTx(func(tx *sql.Tx) error {
|
||||||
|
|
||||||
|
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("querying post_tags by tag: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
placeholders []string
|
||||||
|
whereArgs []interface{}
|
||||||
|
)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
|
||||||
|
var id string
|
||||||
|
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
rows.Close()
|
||||||
|
return fmt.Errorf("scanning id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
whereArgs = append(whereArgs, id)
|
||||||
|
placeholders = append(placeholders, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return fmt.Errorf("closing row iterator: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ","))
|
||||||
|
|
||||||
|
if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil {
|
||||||
|
return fmt.Errorf("querying for ids %+v: %w", whereArgs, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return posts, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *store) Delete(id string) error {
|
||||||
|
|
||||||
|
tx, err := s.db.Begin()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
|
||||||
|
return fmt.Errorf("deleting from post_tags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
|
||||||
|
return fmt.Errorf("deleting from posts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
package post
|
package post
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tilinna/clock"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewID(t *testing.T) {
|
func TestNewID(t *testing.T) {
|
||||||
@ -30,3 +33,224 @@ func TestNewID(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testPost(i int) Post {
|
||||||
|
istr := strconv.Itoa(i)
|
||||||
|
return Post{
|
||||||
|
ID: istr,
|
||||||
|
Title: istr,
|
||||||
|
Description: istr,
|
||||||
|
Body: istr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type storeTestHarness struct {
|
||||||
|
clock *clock.Mock
|
||||||
|
store Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStoreTestHarness(t *testing.T) storeTestHarness {
|
||||||
|
|
||||||
|
clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour))
|
||||||
|
|
||||||
|
db := NewInMemSQLDB()
|
||||||
|
t.Cleanup(func() { db.Close() })
|
||||||
|
|
||||||
|
store := NewStore(db)
|
||||||
|
|
||||||
|
return storeTestHarness{
|
||||||
|
clock: clock,
|
||||||
|
store: store,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *storeTestHarness) testStoredPost(i int) StoredPost {
|
||||||
|
post := testPost(i)
|
||||||
|
return StoredPost{
|
||||||
|
Post: post,
|
||||||
|
PublishedAt: h.clock.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore(t *testing.T) {
|
||||||
|
|
||||||
|
assertPostEqual := func(t *testing.T, exp, got StoredPost) {
|
||||||
|
t.Helper()
|
||||||
|
sort.Strings(exp.Tags)
|
||||||
|
sort.Strings(got.Tags)
|
||||||
|
assert.Equal(t, exp, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertPostsEqual := func(t *testing.T, exp, got []StoredPost) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range exp {
|
||||||
|
assertPostEqual(t, exp[i], got[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("not_found", func(t *testing.T) {
|
||||||
|
h := newStoreTestHarness(t)
|
||||||
|
|
||||||
|
_, err := h.store.GetByID("foo")
|
||||||
|
assert.ErrorIs(t, err, ErrPostNotFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("set_get_delete", func(t *testing.T) {
|
||||||
|
h := newStoreTestHarness(t)
|
||||||
|
|
||||||
|
now := h.clock.Now().UTC()
|
||||||
|
|
||||||
|
post := testPost(0)
|
||||||
|
post.Tags = []string{"foo", "bar"}
|
||||||
|
|
||||||
|
assert.NoError(t, h.store.Set(post, now))
|
||||||
|
|
||||||
|
gotPost, err := h.store.GetByID(post.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assertPostEqual(t, StoredPost{
|
||||||
|
Post: post,
|
||||||
|
PublishedAt: now,
|
||||||
|
}, gotPost)
|
||||||
|
|
||||||
|
// we will now try updating the post on a different day, and ensure it
|
||||||
|
// updates properly
|
||||||
|
|
||||||
|
h.clock.Add(24 * time.Hour)
|
||||||
|
newNow := h.clock.Now().UTC()
|
||||||
|
|
||||||
|
post.Title = "something else"
|
||||||
|
post.Series = "whatever"
|
||||||
|
post.Body = "anything"
|
||||||
|
post.Tags = []string{"bar", "baz"}
|
||||||
|
|
||||||
|
assert.NoError(t, h.store.Set(post, newNow))
|
||||||
|
|
||||||
|
gotPost, err = h.store.GetByID(post.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assertPostEqual(t, StoredPost{
|
||||||
|
Post: post,
|
||||||
|
PublishedAt: now,
|
||||||
|
LastUpdatedAt: newNow,
|
||||||
|
}, gotPost)
|
||||||
|
|
||||||
|
// delete the post, it should go away
|
||||||
|
assert.NoError(t, h.store.Delete(post.ID))
|
||||||
|
|
||||||
|
_, err = h.store.GetByID(post.ID)
|
||||||
|
assert.ErrorIs(t, err, ErrPostNotFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get", func(t *testing.T) {
|
||||||
|
h := newStoreTestHarness(t)
|
||||||
|
|
||||||
|
now := h.clock.Now().UTC()
|
||||||
|
|
||||||
|
posts := []StoredPost{
|
||||||
|
h.testStoredPost(0),
|
||||||
|
h.testStoredPost(1),
|
||||||
|
h.testStoredPost(2),
|
||||||
|
h.testStoredPost(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, post := range posts {
|
||||||
|
assert.NoError(t, h.store.Set(post.Post, now))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotPosts, hasMore, err := h.store.Get(0, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[:2], gotPosts)
|
||||||
|
|
||||||
|
gotPosts, hasMore, err = h.store.Get(1, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[2:4], gotPosts)
|
||||||
|
|
||||||
|
posts = append(posts, h.testStoredPost(4))
|
||||||
|
assert.NoError(t, h.store.Set(posts[4].Post, now))
|
||||||
|
|
||||||
|
gotPosts, hasMore, err = h.store.Get(1, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[2:4], gotPosts)
|
||||||
|
|
||||||
|
gotPosts, hasMore, err = h.store.Get(2, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, hasMore)
|
||||||
|
assertPostsEqual(t, posts[4:], gotPosts)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get_by_series", func(t *testing.T) {
|
||||||
|
h := newStoreTestHarness(t)
|
||||||
|
|
||||||
|
now := h.clock.Now().UTC()
|
||||||
|
|
||||||
|
posts := []StoredPost{
|
||||||
|
h.testStoredPost(0),
|
||||||
|
h.testStoredPost(1),
|
||||||
|
h.testStoredPost(2),
|
||||||
|
h.testStoredPost(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
posts[0].Series = "foo"
|
||||||
|
posts[1].Series = "bar"
|
||||||
|
posts[2].Series = "bar"
|
||||||
|
|
||||||
|
for _, post := range posts {
|
||||||
|
assert.NoError(t, h.store.Set(post.Post, now))
|
||||||
|
}
|
||||||
|
|
||||||
|
fooPosts, err := h.store.GetBySeries("foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assertPostsEqual(t, posts[:1], fooPosts)
|
||||||
|
|
||||||
|
barPosts, err := h.store.GetBySeries("bar")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assertPostsEqual(t, posts[1:3], barPosts)
|
||||||
|
|
||||||
|
bazPosts, err := h.store.GetBySeries("baz")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, bazPosts)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get_by_tag", func(t *testing.T) {
|
||||||
|
|
||||||
|
h := newStoreTestHarness(t)
|
||||||
|
|
||||||
|
now := h.clock.Now().UTC()
|
||||||
|
|
||||||
|
posts := []StoredPost{
|
||||||
|
h.testStoredPost(0),
|
||||||
|
h.testStoredPost(1),
|
||||||
|
h.testStoredPost(2),
|
||||||
|
h.testStoredPost(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
posts[0].Tags = []string{"foo"}
|
||||||
|
posts[1].Tags = []string{"foo", "bar"}
|
||||||
|
posts[2].Tags = []string{"bar"}
|
||||||
|
|
||||||
|
for _, post := range posts {
|
||||||
|
assert.NoError(t, h.store.Set(post.Post, now))
|
||||||
|
}
|
||||||
|
|
||||||
|
fooPosts, err := h.store.GetByTag("foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assertPostsEqual(t, posts[:2], fooPosts)
|
||||||
|
|
||||||
|
barPosts, err := h.store.GetByTag("bar")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assertPostsEqual(t, posts[1:3], barPosts)
|
||||||
|
|
||||||
|
bazPosts, err := h.store.GetByTag("baz")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, bazPosts)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -7,10 +7,12 @@ import (
|
|||||||
|
|
||||||
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
|
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
|
||||||
migrate "github.com/rubenv/sql-migrate"
|
migrate "github.com/rubenv/sql-migrate"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3" // we need dis
|
||||||
)
|
)
|
||||||
|
|
||||||
var migrations = []*migrate.Migration{
|
var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration{
|
||||||
&migrate.Migration{
|
{
|
||||||
Id: "1",
|
Id: "1",
|
||||||
Up: []string{
|
Up: []string{
|
||||||
`CREATE TABLE posts (
|
`CREATE TABLE posts (
|
||||||
@ -24,18 +26,25 @@ var migrations = []*migrate.Migration{
|
|||||||
|
|
||||||
body TEXT NOT NULL
|
body TEXT NOT NULL
|
||||||
)`,
|
)`,
|
||||||
|
|
||||||
`CREATE TABLE post_tags (
|
`CREATE TABLE post_tags (
|
||||||
post_id TEXT NOT NULL,
|
post_id TEXT NOT NULL,
|
||||||
tag TEXT NOT NULL,
|
tag TEXT NOT NULL,
|
||||||
UNIQUE(post_id, tag)
|
UNIQUE(post_id, tag)
|
||||||
)`,
|
)`,
|
||||||
|
|
||||||
|
`CREATE TABLE assets (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
body BLOB NOT NULL
|
||||||
|
)`,
|
||||||
},
|
},
|
||||||
Down: []string{
|
Down: []string{
|
||||||
|
"DROP TABLE assets",
|
||||||
"DROP TABLE post_tags",
|
"DROP TABLE post_tags",
|
||||||
"DROP TABLE posts",
|
"DROP TABLE posts",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}}
|
||||||
|
|
||||||
// SQLDB is a sqlite3 database which can be used by storage interfaces within
|
// SQLDB is a sqlite3 database which can be used by storage interfaces within
|
||||||
// this package.
|
// this package.
|
||||||
@ -43,7 +52,7 @@ type SQLDB struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSQLDB initializes and returns a new sqlite3 database for post storage
|
// NewSQLDB initializes and returns a new sqlite3 database for storage
|
||||||
// intefaces. The db will be created within the given data directory.
|
// intefaces. The db will be created within the given data directory.
|
||||||
func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
|
func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
|
||||||
|
|
||||||
@ -54,8 +63,6 @@ func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
|
|||||||
return nil, fmt.Errorf("opening sqlite file at %q: %w", path, err)
|
return nil, fmt.Errorf("opening sqlite file at %q: %w", path, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
migrations := &migrate.MemoryMigrationSource{Migrations: migrations}
|
|
||||||
|
|
||||||
if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
|
if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
|
||||||
return nil, fmt.Errorf("running migrations: %w", err)
|
return nil, fmt.Errorf("running migrations: %w", err)
|
||||||
}
|
}
|
||||||
@ -63,6 +70,21 @@ func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
|
|||||||
return &SQLDB{db}, nil
|
return &SQLDB{db}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewSQLDB is like NewSQLDB, but the database will be initialized in memory.
|
||||||
|
func NewInMemSQLDB() *SQLDB {
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("opening sqlite in memory: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
|
||||||
|
panic(fmt.Errorf("running migrations: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SQLDB{db}
|
||||||
|
}
|
||||||
|
|
||||||
// Close cleans up loose resources being held by the db.
|
// Close cleans up loose resources being held by the db.
|
||||||
func (db *SQLDB) Close() error {
|
func (db *SQLDB) Close() error {
|
||||||
return db.db.Close()
|
return db.db.Close()
|
||||||
|
@ -1,362 +0,0 @@
|
|||||||
package post
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"path"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // we need dis
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrNotFound is used to indicate a Post could not be found in the
|
|
||||||
// database.
|
|
||||||
ErrNotFound = errors.New("not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
// StoredPost is a Post which has been stored in a Store, and has been given
|
|
||||||
// some extra fields as a result.
|
|
||||||
type StoredPost struct {
|
|
||||||
Post
|
|
||||||
|
|
||||||
PublishedAt time.Time
|
|
||||||
LastUpdatedAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// URL returns the relative URL of the StoredPost.
|
|
||||||
func (p StoredPost) URL() string {
|
|
||||||
return path.Join(
|
|
||||||
fmt.Sprintf(
|
|
||||||
"%d/%0d/%0d",
|
|
||||||
p.PublishedAt.Year(),
|
|
||||||
p.PublishedAt.Month(),
|
|
||||||
p.PublishedAt.Day(),
|
|
||||||
),
|
|
||||||
p.ID+".html",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store is used for storing posts to a persistent storage.
|
|
||||||
type Store interface {
|
|
||||||
|
|
||||||
// Set sets the Post data into the storage, keyed by the Post's ID. It
|
|
||||||
// overwrites a previous Post with the same ID, if there was one.
|
|
||||||
Set(post Post, now time.Time) error
|
|
||||||
|
|
||||||
// Get returns count StoredPosts, sorted time descending, offset by the given page
|
|
||||||
// number. The returned boolean indicates if there are more pages or not.
|
|
||||||
Get(page, count int) ([]StoredPost, bool, error)
|
|
||||||
|
|
||||||
// GetByID will return the StoredPost with the given ID, or ErrNotFound.
|
|
||||||
GetByID(id string) (StoredPost, error)
|
|
||||||
|
|
||||||
// GetBySeries returns all StoredPosts with the given series, sorted time
|
|
||||||
// ascending, or empty slice.
|
|
||||||
GetBySeries(series string) ([]StoredPost, error)
|
|
||||||
|
|
||||||
// GetByTag returns all StoredPosts with the given tag, sorted time
|
|
||||||
// ascending, or empty slice.
|
|
||||||
GetByTag(tag string) ([]StoredPost, error)
|
|
||||||
|
|
||||||
// Delete will delete the StoredPost with the given ID.
|
|
||||||
Delete(id string) error
|
|
||||||
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
type store struct {
|
|
||||||
db *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStore initializes a new Store using an existing SQLDB.
|
|
||||||
func NewStore(db *SQLDB) Store {
|
|
||||||
return &store{
|
|
||||||
db: db.db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) Close() error {
|
|
||||||
return s.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the callback returns an error then the transaction is aborted.
|
|
||||||
func (s *store) withTx(cb func(*sql.Tx) error) error {
|
|
||||||
|
|
||||||
tx, err := s.db.Begin()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("starting transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cb(tx); err != nil {
|
|
||||||
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
return fmt.Errorf(
|
|
||||||
"rolling back transaction: %w (original error: %v)",
|
|
||||||
rollbackErr, err,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("performing transaction: %w (rolled back)", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("committing transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) Set(post Post, now time.Time) error {
|
|
||||||
return s.withTx(func(tx *sql.Tx) error {
|
|
||||||
|
|
||||||
nowTS := now.Unix()
|
|
||||||
|
|
||||||
nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()}
|
|
||||||
|
|
||||||
_, err := tx.Exec(
|
|
||||||
`INSERT INTO posts (
|
|
||||||
id, title, description, series, published_at, body
|
|
||||||
)
|
|
||||||
VALUES
|
|
||||||
(?, ?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT (id) DO UPDATE SET
|
|
||||||
title=excluded.title,
|
|
||||||
description=excluded.description,
|
|
||||||
series=excluded.series,
|
|
||||||
last_updated_at=?,
|
|
||||||
body=excluded.body`,
|
|
||||||
post.ID,
|
|
||||||
post.Title,
|
|
||||||
post.Description,
|
|
||||||
&sql.NullString{String: post.Series, Valid: post.Series != ""},
|
|
||||||
nowSQL,
|
|
||||||
post.Body,
|
|
||||||
nowSQL,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("inserting into posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// this is a bit of a hack, but it allows us to update the tagset without
|
|
||||||
// doing a diff.
|
|
||||||
_, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("clearning post tags: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tag := range post.Tags {
|
|
||||||
|
|
||||||
_, err = tx.Exec(
|
|
||||||
`INSERT INTO post_tags (post_id, tag) VALUES (?, ?)
|
|
||||||
ON CONFLICT DO NOTHING`,
|
|
||||||
post.ID,
|
|
||||||
tag,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("inserting tag %q: %w", tag, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) get(
|
|
||||||
querier interface {
|
|
||||||
Query(string, ...interface{}) (*sql.Rows, error)
|
|
||||||
},
|
|
||||||
limit, offset int,
|
|
||||||
where string, whereArgs ...interface{},
|
|
||||||
) (
|
|
||||||
[]StoredPost, error,
|
|
||||||
) {
|
|
||||||
|
|
||||||
query := `SELECT
|
|
||||||
p.id, p.title, p.description, p.series, pt.tag,
|
|
||||||
p.published_at, p.last_updated_at,
|
|
||||||
p.body
|
|
||||||
FROM posts p
|
|
||||||
LEFT JOIN post_tags pt ON (p.id = pt.post_id)
|
|
||||||
` + where + `
|
|
||||||
ORDER BY p.published_at ASC, p.title ASC`
|
|
||||||
|
|
||||||
if limit > 0 {
|
|
||||||
query += fmt.Sprintf(" LIMIT %d", limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
if offset > 0 {
|
|
||||||
query += fmt.Sprintf(" OFFSET %d", offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := querier.Query(query, whereArgs...)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("selecting: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var posts []StoredPost
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
|
|
||||||
var (
|
|
||||||
post StoredPost
|
|
||||||
series, tag sql.NullString
|
|
||||||
publishedAt, lastUpdatedAt sql.NullInt64
|
|
||||||
)
|
|
||||||
|
|
||||||
err := rows.Scan(
|
|
||||||
&post.ID, &post.Title, &post.Description, &series, &tag,
|
|
||||||
&publishedAt, &lastUpdatedAt,
|
|
||||||
&post.Body,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("scanning row: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tag.Valid {
|
|
||||||
|
|
||||||
if l := len(posts); l > 0 && posts[l-1].ID == post.ID {
|
|
||||||
posts[l-1].Tags = append(posts[l-1].Tags, tag.String)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
post.Tags = append(post.Tags, tag.String)
|
|
||||||
}
|
|
||||||
|
|
||||||
post.Series = series.String
|
|
||||||
|
|
||||||
if publishedAt.Valid {
|
|
||||||
post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC()
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastUpdatedAt.Valid {
|
|
||||||
post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC()
|
|
||||||
}
|
|
||||||
|
|
||||||
posts = append(posts, post)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
return nil, fmt.Errorf("closing row iterator: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return posts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
|
|
||||||
|
|
||||||
posts, err := s.get(s.db, count+1, page*count, ``)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, fmt.Errorf("querying posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var hasMore bool
|
|
||||||
|
|
||||||
if len(posts) > count {
|
|
||||||
hasMore = true
|
|
||||||
posts = posts[:count]
|
|
||||||
}
|
|
||||||
|
|
||||||
return posts, hasMore, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) GetByID(id string) (StoredPost, error) {
|
|
||||||
|
|
||||||
posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(posts) == 0 {
|
|
||||||
return StoredPost{}, ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(posts) > 1 {
|
|
||||||
panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts))
|
|
||||||
}
|
|
||||||
|
|
||||||
return posts[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
|
|
||||||
return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
|
|
||||||
|
|
||||||
var posts []StoredPost
|
|
||||||
|
|
||||||
err := s.withTx(func(tx *sql.Tx) error {
|
|
||||||
|
|
||||||
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("querying post_tags by tag: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
placeholders []string
|
|
||||||
whereArgs []interface{}
|
|
||||||
)
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
|
|
||||||
var id string
|
|
||||||
|
|
||||||
if err := rows.Scan(&id); err != nil {
|
|
||||||
rows.Close()
|
|
||||||
return fmt.Errorf("scanning id: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
whereArgs = append(whereArgs, id)
|
|
||||||
placeholders = append(placeholders, "?")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
return fmt.Errorf("closing row iterator: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ","))
|
|
||||||
|
|
||||||
if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil {
|
|
||||||
return fmt.Errorf("querying for ids %+v: %w", whereArgs, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return posts, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *store) Delete(id string) error {
|
|
||||||
|
|
||||||
tx, err := s.db.Begin()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("starting transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
|
|
||||||
return fmt.Errorf("deleting from post_tags: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
|
|
||||||
return fmt.Errorf("deleting from posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return fmt.Errorf("committing transaction: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,245 +0,0 @@
|
|||||||
package post
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/tilinna/clock"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testPost(i int) Post {
|
|
||||||
istr := strconv.Itoa(i)
|
|
||||||
return Post{
|
|
||||||
ID: istr,
|
|
||||||
Title: istr,
|
|
||||||
Description: istr,
|
|
||||||
Body: istr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type storeTestHarness struct {
|
|
||||||
clock *clock.Mock
|
|
||||||
store Store
|
|
||||||
}
|
|
||||||
|
|
||||||
func newStoreTestHarness(t *testing.T) storeTestHarness {
|
|
||||||
|
|
||||||
var dataDir cfg.DataDir
|
|
||||||
|
|
||||||
if err := dataDir.Init(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() { dataDir.Close() })
|
|
||||||
|
|
||||||
clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour))
|
|
||||||
|
|
||||||
db, err := NewSQLDB(dataDir)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() { db.Close() })
|
|
||||||
|
|
||||||
store := NewStore(db)
|
|
||||||
|
|
||||||
return storeTestHarness{
|
|
||||||
clock: clock,
|
|
||||||
store: store,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *storeTestHarness) testStoredPost(i int) StoredPost {
|
|
||||||
post := testPost(i)
|
|
||||||
return StoredPost{
|
|
||||||
Post: post,
|
|
||||||
PublishedAt: h.clock.Now(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStore(t *testing.T) {
|
|
||||||
|
|
||||||
assertPostEqual := func(t *testing.T, exp, got StoredPost) {
|
|
||||||
t.Helper()
|
|
||||||
sort.Strings(exp.Tags)
|
|
||||||
sort.Strings(got.Tags)
|
|
||||||
assert.Equal(t, exp, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPostsEqual := func(t *testing.T, exp, got []StoredPost) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range exp {
|
|
||||||
assertPostEqual(t, exp[i], got[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("not_found", func(t *testing.T) {
|
|
||||||
h := newStoreTestHarness(t)
|
|
||||||
|
|
||||||
_, err := h.store.GetByID("foo")
|
|
||||||
assert.ErrorIs(t, err, ErrNotFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("set_get_delete", func(t *testing.T) {
|
|
||||||
h := newStoreTestHarness(t)
|
|
||||||
|
|
||||||
now := h.clock.Now().UTC()
|
|
||||||
|
|
||||||
post := testPost(0)
|
|
||||||
post.Tags = []string{"foo", "bar"}
|
|
||||||
|
|
||||||
assert.NoError(t, h.store.Set(post, now))
|
|
||||||
|
|
||||||
gotPost, err := h.store.GetByID(post.ID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assertPostEqual(t, StoredPost{
|
|
||||||
Post: post,
|
|
||||||
PublishedAt: now,
|
|
||||||
}, gotPost)
|
|
||||||
|
|
||||||
// we will now try updating the post on a different day, and ensure it
|
|
||||||
// updates properly
|
|
||||||
|
|
||||||
h.clock.Add(24 * time.Hour)
|
|
||||||
newNow := h.clock.Now().UTC()
|
|
||||||
|
|
||||||
post.Title = "something else"
|
|
||||||
post.Series = "whatever"
|
|
||||||
post.Body = "anything"
|
|
||||||
post.Tags = []string{"bar", "baz"}
|
|
||||||
|
|
||||||
assert.NoError(t, h.store.Set(post, newNow))
|
|
||||||
|
|
||||||
gotPost, err = h.store.GetByID(post.ID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assertPostEqual(t, StoredPost{
|
|
||||||
Post: post,
|
|
||||||
PublishedAt: now,
|
|
||||||
LastUpdatedAt: newNow,
|
|
||||||
}, gotPost)
|
|
||||||
|
|
||||||
// delete the post, it should go away
|
|
||||||
assert.NoError(t, h.store.Delete(post.ID))
|
|
||||||
|
|
||||||
_, err = h.store.GetByID(post.ID)
|
|
||||||
assert.ErrorIs(t, err, ErrNotFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("get", func(t *testing.T) {
|
|
||||||
h := newStoreTestHarness(t)
|
|
||||||
|
|
||||||
now := h.clock.Now().UTC()
|
|
||||||
|
|
||||||
posts := []StoredPost{
|
|
||||||
h.testStoredPost(0),
|
|
||||||
h.testStoredPost(1),
|
|
||||||
h.testStoredPost(2),
|
|
||||||
h.testStoredPost(3),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, post := range posts {
|
|
||||||
assert.NoError(t, h.store.Set(post.Post, now))
|
|
||||||
}
|
|
||||||
|
|
||||||
gotPosts, hasMore, err := h.store.Get(0, 2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, hasMore)
|
|
||||||
assertPostsEqual(t, posts[:2], gotPosts)
|
|
||||||
|
|
||||||
gotPosts, hasMore, err = h.store.Get(1, 2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.False(t, hasMore)
|
|
||||||
assertPostsEqual(t, posts[2:4], gotPosts)
|
|
||||||
|
|
||||||
posts = append(posts, h.testStoredPost(4))
|
|
||||||
assert.NoError(t, h.store.Set(posts[4].Post, now))
|
|
||||||
|
|
||||||
gotPosts, hasMore, err = h.store.Get(1, 2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, hasMore)
|
|
||||||
assertPostsEqual(t, posts[2:4], gotPosts)
|
|
||||||
|
|
||||||
gotPosts, hasMore, err = h.store.Get(2, 2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.False(t, hasMore)
|
|
||||||
assertPostsEqual(t, posts[4:], gotPosts)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("get_by_series", func(t *testing.T) {
|
|
||||||
h := newStoreTestHarness(t)
|
|
||||||
|
|
||||||
now := h.clock.Now().UTC()
|
|
||||||
|
|
||||||
posts := []StoredPost{
|
|
||||||
h.testStoredPost(0),
|
|
||||||
h.testStoredPost(1),
|
|
||||||
h.testStoredPost(2),
|
|
||||||
h.testStoredPost(3),
|
|
||||||
}
|
|
||||||
|
|
||||||
posts[0].Series = "foo"
|
|
||||||
posts[1].Series = "bar"
|
|
||||||
posts[2].Series = "bar"
|
|
||||||
|
|
||||||
for _, post := range posts {
|
|
||||||
assert.NoError(t, h.store.Set(post.Post, now))
|
|
||||||
}
|
|
||||||
|
|
||||||
fooPosts, err := h.store.GetBySeries("foo")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assertPostsEqual(t, posts[:1], fooPosts)
|
|
||||||
|
|
||||||
barPosts, err := h.store.GetBySeries("bar")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assertPostsEqual(t, posts[1:3], barPosts)
|
|
||||||
|
|
||||||
bazPosts, err := h.store.GetBySeries("baz")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, bazPosts)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("get_by_tag", func(t *testing.T) {
|
|
||||||
|
|
||||||
h := newStoreTestHarness(t)
|
|
||||||
|
|
||||||
now := h.clock.Now().UTC()
|
|
||||||
|
|
||||||
posts := []StoredPost{
|
|
||||||
h.testStoredPost(0),
|
|
||||||
h.testStoredPost(1),
|
|
||||||
h.testStoredPost(2),
|
|
||||||
h.testStoredPost(3),
|
|
||||||
}
|
|
||||||
|
|
||||||
posts[0].Tags = []string{"foo"}
|
|
||||||
posts[1].Tags = []string{"foo", "bar"}
|
|
||||||
posts[2].Tags = []string{"bar"}
|
|
||||||
|
|
||||||
for _, post := range posts {
|
|
||||||
assert.NoError(t, h.store.Set(post.Post, now))
|
|
||||||
}
|
|
||||||
|
|
||||||
fooPosts, err := h.store.GetByTag("foo")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assertPostsEqual(t, posts[:2], fooPosts)
|
|
||||||
|
|
||||||
barPosts, err := h.store.GetByTag("bar")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assertPostsEqual(t, posts[1:3], barPosts)
|
|
||||||
|
|
||||||
bazPosts, err := h.store.GetByTag("baz")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, bazPosts)
|
|
||||||
})
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user