Compare commits

...

6 Commits

Author SHA1 Message Date
Brian Picciano
f365b09757 implemented draft publishing and removed New Posts link/capability 2022-08-19 21:47:38 -06:00
Brian Picciano
33a81f73e1 load published posts test data in reverse order 2022-08-19 21:15:39 -06:00
Brian Picciano
c3135306b3 drafts functionality added, needs a publish button still 2022-08-18 23:07:09 -06:00
Brian Picciano
dfa9bcb9e2 WIP 2022-08-18 22:34:38 -06:00
Brian Picciano
7ac2f5ebb3 implement DraftStore 2022-08-18 22:25:17 -06:00
Brian Picciano
76ff79f470 add admin page, and spruce up posts and assets 2022-08-18 21:11:42 -06:00
15 changed files with 709 additions and 99 deletions

View File

@ -93,7 +93,7 @@ func main() {
logger.Info(ctx, "set post")
now = now.Add(1 * time.Hour)
now = now.Add(-1 * time.Hour)
}
}

View File

@ -121,11 +121,13 @@ func main() {
postStore := post.NewStore(postSQLDB)
postAssetStore := post.NewAssetStore(postSQLDB)
postDraftStore := post.NewDraftStore(postSQLDB)
httpParams.Logger = logger.WithNamespace("http")
httpParams.PowManager = powMgr
httpParams.PostStore = postStore
httpParams.PostAssetStore = postAssetStore
httpParams.PostDraftStore = postDraftStore
httpParams.MailingList = ml
httpParams.GlobalRoom = chatGlobalRoom
httpParams.UserIDCalculator = chatUserIDCalc

View File

@ -37,6 +37,7 @@ type Params struct {
PostStore post.Store
PostAssetStore post.AssetStore
PostDraftStore post.DraftStore
MailingList mailinglist.MailingList
@ -202,7 +203,7 @@ func (a *api) blogHandler() http.Handler {
apiutil.MethodMux(map[string]http.Handler{
"GET": a.renderPostHandler(),
"POST": a.postPostHandler(),
"DELETE": a.deletePostHandler(),
"DELETE": a.deletePostHandler(false),
"PREVIEW": a.previewPostHandler(),
}),
))
@ -215,8 +216,23 @@ func (a *api) blogHandler() http.Handler {
}),
))
mux.Handle("/drafts/", http.StripPrefix("/drafts",
// everything to do with drafts is protected
authMiddleware(a.auther)(
apiutil.MethodMux(map[string]http.Handler{
"GET": a.renderDraftPostHandler(),
"POST": a.postDraftPostHandler(),
"DELETE": a.deletePostHandler(true),
"PREVIEW": a.previewPostHandler(),
}),
),
))
mux.Handle("/static/", http.FileServer(http.FS(staticFS)))
mux.Handle("/follow", a.renderDumbTplHandler("follow.html"))
mux.Handle("/admin", a.renderDumbTplHandler("admin.html"))
mux.Handle("/mailinglist/unsubscribe", a.renderDumbTplHandler("unsubscribe.html"))
mux.Handle("/mailinglist/finalize", a.renderDumbTplHandler("finalize.html"))
mux.Handle("/feed.xml", a.renderFeedHandler())

130
srv/src/http/drafts.go Normal file
View File

@ -0,0 +1,130 @@
package http
import (
"errors"
"fmt"
"net/http"
"path/filepath"
"strings"
"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
"github.com/mediocregopher/blog.mediocregopher.com/srv/post"
)
func (a *api) renderDraftPostHandler() http.Handler {
tpl := a.mustParseBasedTpl("post.html")
renderDraftPostsIndexHandler := a.renderDraftPostsIndexHandler()
renderDraftEditPostHandler := a.renderEditPostHandler(true)
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
id := strings.TrimSuffix(filepath.Base(r.URL.Path), ".html")
if id == "/" {
renderDraftPostsIndexHandler.ServeHTTP(rw, r)
return
}
if _, ok := r.URL.Query()["edit"]; ok {
renderDraftEditPostHandler.ServeHTTP(rw, r)
return
}
p, err := a.params.PostDraftStore.GetByID(id)
if errors.Is(err, post.ErrPostNotFound) {
http.Error(rw, "Post not found", 404)
return
} else if err != nil {
apiutil.InternalServerError(
rw, r, fmt.Errorf("fetching post with id %q: %w", id, err),
)
return
}
tplPayload, err := a.postToPostTplPayload(post.StoredPost{Post: p})
if err != nil {
apiutil.InternalServerError(
rw, r, fmt.Errorf(
"generating template payload for post with id %q: %w",
id, err,
),
)
return
}
executeTemplate(rw, r, tpl, tplPayload)
})
}
func (a *api) renderDraftPostsIndexHandler() http.Handler {
renderEditPostHandler := a.renderEditPostHandler(true)
tpl := a.mustParseBasedTpl("draft-posts.html")
const pageCount = 20
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if _, ok := r.URL.Query()["edit"]; ok {
renderEditPostHandler.ServeHTTP(rw, r)
return
}
page, err := apiutil.StrToInt(r.FormValue("p"), 0)
if err != nil {
apiutil.BadRequest(
rw, r, fmt.Errorf("invalid page number: %w", err),
)
return
}
posts, hasMore, err := a.params.PostDraftStore.Get(page, pageCount)
if err != nil {
apiutil.InternalServerError(
rw, r, fmt.Errorf("fetching page %d of posts: %w", page, err),
)
return
}
tplPayload := struct {
Posts []post.Post
PrevPage, NextPage int
}{
Posts: posts,
PrevPage: -1,
NextPage: -1,
}
if page > 0 {
tplPayload.PrevPage = page - 1
}
if hasMore {
tplPayload.NextPage = page + 1
}
executeTemplate(rw, r, tpl, tplPayload)
})
}
func (a *api) postDraftPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
p, err := postFromPostReq(r)
if err != nil {
apiutil.BadRequest(rw, r, err)
return
}
if err := a.params.PostDraftStore.Set(p); err != nil {
apiutil.InternalServerError(
rw, r, fmt.Errorf("storing post with id %q: %w", p.ID, err),
)
return
}
a.executeRedirectTpl(rw, r, a.draftURL(p.ID, false)+"?edit")
})
}

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"errors"
"fmt"
"html/template"
@ -16,9 +17,10 @@ import (
"github.com/gomarkdown/markdown/parser"
"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
"github.com/mediocregopher/blog.mediocregopher.com/srv/post"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
)
func (a *api) parsePostBody(storedPost post.StoredPost) (*txttpl.Template, error) {
func (a *api) parsePostBody(post post.Post) (*txttpl.Template, error) {
tpl := txttpl.New("root")
tpl = tpl.Funcs(txttpl.FuncMap(a.tplFuncs()))
@ -43,7 +45,7 @@ func (a *api) parsePostBody(storedPost post.StoredPost) (*txttpl.Template, error
},
})
tpl, err := tpl.New(storedPost.ID + "-body.html").Parse(storedPost.Body)
tpl, err := tpl.New(post.ID + "-body.html").Parse(post.Body)
if err != nil {
return nil, err
@ -60,7 +62,7 @@ type postTplPayload struct {
func (a *api) postToPostTplPayload(storedPost post.StoredPost) (postTplPayload, error) {
bodyTpl, err := a.parsePostBody(storedPost)
bodyTpl, err := a.parsePostBody(storedPost.Post)
if err != nil {
return postTplPayload{}, fmt.Errorf("parsing post body as template: %w", err)
}
@ -125,7 +127,7 @@ func (a *api) renderPostHandler() http.Handler {
tpl := a.mustParseBasedTpl("post.html")
renderPostsIndexHandler := a.renderPostsIndexHandler()
renderEditPostHandler := a.renderEditPostHandler()
renderEditPostHandler := a.renderEditPostHandler(false)
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -171,7 +173,7 @@ func (a *api) renderPostHandler() http.Handler {
func (a *api) renderPostsIndexHandler() http.Handler {
renderEditPostHandler := a.renderEditPostHandler()
renderEditPostHandler := a.renderEditPostHandler(false)
tpl := a.mustParseBasedTpl("posts.html")
const pageCount = 20
@ -219,7 +221,7 @@ func (a *api) renderPostsIndexHandler() http.Handler {
})
}
func (a *api) renderEditPostHandler() http.Handler {
func (a *api) renderEditPostHandler(isDraft bool) http.Handler {
tpl := a.mustParseBasedTpl("edit-post.html")
@ -232,7 +234,12 @@ func (a *api) renderEditPostHandler() http.Handler {
if id != "/" {
var err error
if isDraft {
storedPost.Post, err = a.params.PostDraftStore.GetByID(id)
} else {
storedPost, err = a.params.PostStore.GetByID(id)
}
if errors.Is(err, post.ErrPostNotFound) {
http.Error(rw, "Post not found", 404)
@ -243,6 +250,10 @@ func (a *api) renderEditPostHandler() http.Handler {
)
return
}
} else if !isDraft {
http.Error(rw, "Post ID required in URL", 400)
return
}
tags, err := a.params.PostStore.GetTags()
@ -254,9 +265,11 @@ func (a *api) renderEditPostHandler() http.Handler {
tplPayload := struct {
Post post.StoredPost
Tags []string
IsDraft bool
}{
Post: storedPost,
Tags: tags,
IsDraft: isDraft,
}
executeTemplate(rw, r, tpl, tplPayload)
@ -289,43 +302,58 @@ func postFromPostReq(r *http.Request) (post.Post, error) {
return p, nil
}
func (a *api) storeAndPublishPost(ctx context.Context, p post.Post) error {
first, err := a.params.PostStore.Set(p, time.Now())
if err != nil {
return fmt.Errorf("storing post with id %q: %w", p.ID, err)
}
if !first {
return nil
}
a.params.Logger.Info(ctx, "publishing blog post to mailing list")
urlStr := a.postURL(p.ID, true)
if err := a.params.MailingList.Publish(p.Title, urlStr); err != nil {
return fmt.Errorf("publishing post to mailing list: %w", err)
}
if err := a.params.PostDraftStore.Delete(p.ID); err != nil {
return fmt.Errorf("deleting draft: %w", err)
}
return nil
}
func (a *api) postPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
p, err := postFromPostReq(r)
if err != nil {
apiutil.BadRequest(rw, r, err)
return
}
first, err := a.params.PostStore.Set(p, time.Now())
ctx = mctx.Annotate(ctx, "postID", p.ID)
if err != nil {
if err := a.storeAndPublishPost(ctx, p); err != nil {
apiutil.InternalServerError(
rw, r, fmt.Errorf("storing post with id %q: %w", p.ID, err),
rw, r, fmt.Errorf("storing/publishing post with id %q: %w", p.ID, err),
)
return
}
if first {
a.params.Logger.Info(r.Context(), "publishing blog post to mailing list")
urlStr := a.postURL(p.ID, true)
if err := a.params.MailingList.Publish(p.Title, urlStr); err != nil {
apiutil.InternalServerError(
rw, r, fmt.Errorf("publishing post with id %q: %w", p.ID, err),
)
return
}
}
a.executeRedirectTpl(rw, r, a.postURL(p.ID, false)+"?edit")
a.executeRedirectTpl(rw, r, a.postURL(p.ID, false))
})
}
func (a *api) deletePostHandler() http.Handler {
func (a *api) deletePostHandler(isDraft bool) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -336,7 +364,13 @@ func (a *api) deletePostHandler() http.Handler {
return
}
err := a.params.PostStore.Delete(id)
var err error
if isDraft {
err = a.params.PostDraftStore.Delete(id)
} else {
err = a.params.PostStore.Delete(id)
}
if errors.Is(err, post.ErrPostNotFound) {
http.Error(rw, "Post not found", 404)
@ -348,8 +382,11 @@ func (a *api) deletePostHandler() http.Handler {
return
}
if isDraft {
a.executeRedirectTpl(rw, r, a.draftsURL(false))
} else {
a.executeRedirectTpl(rw, r, a.postsURL(false))
}
})
}

View File

@ -57,6 +57,15 @@ func (a *api) assetsURL(abs bool) string {
return a.blogURL("assets", abs)
}
func (a *api) draftURL(id string, abs bool) string {
path := filepath.Join("drafts", id)
return a.blogURL(path, abs)
}
func (a *api) draftsURL(abs bool) string {
return a.blogURL("drafts", abs)
}
func (a *api) tplFuncs() template.FuncMap {
return template.FuncMap{
"BlogURL": func(path string) string {
@ -71,12 +80,15 @@ func (a *api) tplFuncs() template.FuncMap {
b, err := staticFS.ReadFile(path)
return template.CSS(b), err
},
"PostURL": func(id string) string {
return a.postURL(id, false)
},
"AssetURL": func(id string) string {
path := filepath.Join("assets", id)
return a.blogURL(path, false)
},
"PostURL": func(id string) string {
return a.postURL(id, false)
"DraftURL": func(id string) string {
return a.draftURL(id, false)
},
"DateTimeFormat": func(t time.Time) string {
return t.Format("2006-01-02")

View File

@ -0,0 +1,18 @@
{{ define "body" }}
<h1>Admin</h1>
This is a directory of pages which are used for managing blog content. They are
mostly left open to inspection, but you will not able to change
anything without providing credentials.
<ul>
<li><a href="{{ BlogURL "posts" }}">Posts</a></li>
<li><a href="{{ BlogURL "assets" }}">Assets</a></li>
<li><a href="{{ BlogURL "drafts" }}">Drafts (private)</a></li>
</ul>
{{ end }}
{{ template "base.html" . }}

View File

@ -1,5 +1,7 @@
{{ define "body" }}
<h1>Assets</h1>
<h2>Upload Asset</h2>
<p>
@ -21,6 +23,8 @@
</div>
</form>
{{ if gt (len .Payload.IDs) 0 }}
<h2>Existing Assets</h2>
<table>
@ -44,4 +48,6 @@
{{ end }}
{{ end }}
{{ template "base.html" . }}

View File

@ -0,0 +1,48 @@
{{ define "body" }}
<h1>Drafts</h1>
<p>
<a href="{{ BlogURL "drafts/" }}?edit">
New Draft
</a>
</p>
{{ if ge .Payload.PrevPage 0 }}
<p>
<a href="?p={{ .Payload.PrevPage}}">&lt; &lt; Previous Page</a>
</p>
{{ end }}
<table>
{{ range .Payload.Posts }}
<tr>
<td><a href="{{ DraftURL .ID }}">{{ .Title }}</a></td>
<td>
<a href="{{ DraftURL .ID }}?edit">
Edit
</a>
</td>
<td>
<form
action="{{ DraftURL .ID }}?method=delete"
method="POST"
>
<input type="submit" value="Delete" />
</form>
</td>
</tr>
{{ end }}
</table>
{{ if ge .Payload.NextPage 0 }}
<p>
<a href="?p={{ .Payload.NextPage}}">Next Page &gt; &gt;</a>
</p>
{{ end }}
{{ end }}
{{ template "base.html" . }}

View File

@ -15,6 +15,9 @@
type="text"
placeholder="e.g. how-to-fly-a-kite"
value="{{ .Payload.Post.ID }}" />
{{ else if .Payload.IsDraft }}
{{ .Payload.Post.ID }}
<input name="id" type="hidden" value="{{ .Payload.Post.ID }}" />
{{ else }}
<a href="{{ PostURL .Payload.Post.ID }}">{{ .Payload.Post.ID }}</a>
<input name="id" type="hidden" value="{{ .Payload.Post.ID }}" />
@ -87,6 +90,7 @@
</p>
<p>
<input
type="submit"
value="Preview"
@ -94,8 +98,25 @@
formtarget="_blank"
/>
{{ if eq .Payload.Post.ID "" }}
<input type="submit" value="Publish" formaction="{{ BlogURL "posts/" }}" />
{{ if .Payload.IsDraft }}
<input type="submit" value="Save" formaction="{{ BlogURL "drafts/" }}" />
<script>
function confirmPublish(event) {
if (!confirm("Are you sure you're ready to publish?"))
event.preventDefault();
}
</script>
<input
type="submit"
value="Publish"
formaction="{{ BlogURL "posts/" }}"
onclick="confirmPublish(event)"
/>
{{ else }}
<input type="submit" value="Update" formaction="{{ BlogURL "posts/" }}" />
{{ end }}
@ -105,10 +126,17 @@
</form>
<p>
{{ if .Payload.IsDraft }}
<a href="{{ BlogURL "drafts/" }}">
Back to Drafts
</a>
{{ else }}
<a href="{{ BlogURL "posts/" }}">
Back to Posts
</a>
{{ end }}
</p>
{{ end }}
{{ template "base.html" . }}

View File

@ -1,10 +1,6 @@
{{ define "body" }}
<p>
<a href="{{ BlogURL "posts/" }}?edit">
New Post
</a>
</p>
<h1>Posts</h1>
{{ if ge .Payload.PrevPage 0 }}
<p>

187
srv/src/post/draft_post.go Normal file
View File

@ -0,0 +1,187 @@
package post
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
)
type DraftStore interface {
// Set sets the draft Post's data into the storage, keyed by the draft
// Post's ID.
Set(post Post) error
// Get returns count draft Posts, sorted id descending, offset by the
// given page number. The returned boolean indicates if there are more pages
// or not.
Get(page, count int) ([]Post, bool, error)
// GetByID will return the draft Post with the given ID, or ErrPostNotFound.
GetByID(id string) (Post, error)
// Delete will delete the draft Post with the given ID.
Delete(id string) error
}
type draftStore struct {
db *SQLDB
}
// NewDraftStore initializes a new DraftStore using an existing SQLDB.
func NewDraftStore(db *SQLDB) DraftStore {
return &draftStore{
db: db,
}
}
func (s *draftStore) Set(post Post) error {
if post.ID == "" {
return errors.New("post ID can't be empty")
}
tagsJSON, err := json.Marshal(post.Tags)
if err != nil {
return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err)
}
_, err = s.db.db.Exec(
`INSERT INTO post_drafts (
id, title, description, tags, series, body
)
VALUES
(?, ?, ?, ?, ?, ?)
ON CONFLICT (id) DO UPDATE SET
title=excluded.title,
description=excluded.description,
tags=excluded.tags,
series=excluded.series,
body=excluded.body`,
post.ID,
post.Title,
post.Description,
&sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0},
&sql.NullString{String: post.Series, Valid: post.Series != ""},
post.Body,
)
if err != nil {
return fmt.Errorf("inserting into post_drafts: %w", err)
}
return nil
}
func (s *draftStore) get(
querier interface {
Query(string, ...interface{}) (*sql.Rows, error)
},
limit, offset int,
where string, whereArgs ...interface{},
) (
[]Post, error,
) {
query := `
SELECT
p.id, p.title, p.description, p.tags, p.series, p.body
FROM post_drafts p
` + where + `
ORDER BY p.id ASC`
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
}
if offset > 0 {
query += fmt.Sprintf(" OFFSET %d", offset)
}
rows, err := querier.Query(query, whereArgs...)
if err != nil {
return nil, fmt.Errorf("selecting: %w", err)
}
defer rows.Close()
var posts []Post
for rows.Next() {
var (
post Post
tags, series sql.NullString
)
err := rows.Scan(
&post.ID, &post.Title, &post.Description, &tags, &series,
&post.Body,
)
if err != nil {
return nil, fmt.Errorf("scanning row: %w", err)
}
post.Series = series.String
if tags.String != "" {
if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil {
return nil, fmt.Errorf("json parsing %q: %w", tags.String, err)
}
}
posts = append(posts, post)
}
return posts, nil
}
func (s *draftStore) Get(page, count int) ([]Post, bool, error) {
posts, err := s.get(s.db.db, count+1, page*count, ``)
if err != nil {
return nil, false, fmt.Errorf("querying post_drafts: %w", err)
}
var hasMore bool
if len(posts) > count {
hasMore = true
posts = posts[:count]
}
return posts, hasMore, nil
}
func (s *draftStore) GetByID(id string) (Post, error) {
posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
if err != nil {
return Post{}, fmt.Errorf("querying post_drafts: %w", err)
}
if len(posts) == 0 {
return Post{}, ErrPostNotFound
}
if len(posts) > 1 {
panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts))
}
return posts[0], nil
}
func (s *draftStore) Delete(id string) error {
if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil {
return fmt.Errorf("deleting from post_drafts: %w", err)
}
return nil
}

View File

@ -0,0 +1,130 @@
package post
import (
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
type draftStoreTestHarness struct {
store DraftStore
}
func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness {
db := NewInMemSQLDB()
t.Cleanup(func() { db.Close() })
store := NewDraftStore(db)
return draftStoreTestHarness{
store: store,
}
}
func TestDraftStore(t *testing.T) {
assertPostEqual := func(t *testing.T, exp, got Post) {
t.Helper()
sort.Strings(exp.Tags)
sort.Strings(got.Tags)
assert.Equal(t, exp, got)
}
assertPostsEqual := func(t *testing.T, exp, got []Post) {
t.Helper()
if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
return
}
for i := range exp {
assertPostEqual(t, exp[i], got[i])
}
}
t.Run("not_found", func(t *testing.T) {
h := newDraftStoreTestHarness(t)
_, err := h.store.GetByID("foo")
assert.ErrorIs(t, err, ErrPostNotFound)
})
t.Run("set_get_delete", func(t *testing.T) {
h := newDraftStoreTestHarness(t)
post := testPost(0)
post.Tags = []string{"foo", "bar"}
err := h.store.Set(post)
assert.NoError(t, err)
gotPost, err := h.store.GetByID(post.ID)
assert.NoError(t, err)
assertPostEqual(t, post, gotPost)
// we will now try updating the post, and ensure it updates properly
post.Title = "something else"
post.Series = "whatever"
post.Body = "anything"
post.Tags = []string{"bar", "baz"}
err = h.store.Set(post)
assert.NoError(t, err)
gotPost, err = h.store.GetByID(post.ID)
assert.NoError(t, err)
assertPostEqual(t, post, gotPost)
// delete the post, it should go away
assert.NoError(t, h.store.Delete(post.ID))
_, err = h.store.GetByID(post.ID)
assert.ErrorIs(t, err, ErrPostNotFound)
})
t.Run("get", func(t *testing.T) {
h := newDraftStoreTestHarness(t)
posts := []Post{
testPost(0),
testPost(1),
testPost(2),
testPost(3),
}
for _, post := range posts {
err := h.store.Set(post)
assert.NoError(t, err)
}
gotPosts, hasMore, err := h.store.Get(0, 2)
assert.NoError(t, err)
assert.True(t, hasMore)
assertPostsEqual(t, posts[:2], gotPosts)
gotPosts, hasMore, err = h.store.Get(1, 2)
assert.NoError(t, err)
assert.False(t, hasMore)
assertPostsEqual(t, posts[2:4], gotPosts)
posts = append(posts, testPost(4))
err = h.store.Set(posts[4])
assert.NoError(t, err)
gotPosts, hasMore, err = h.store.Get(1, 2)
assert.NoError(t, err)
assert.True(t, hasMore)
assertPostsEqual(t, posts[2:4], gotPosts)
gotPosts, hasMore, err = h.store.Get(2, 2)
assert.NoError(t, err)
assert.False(t, hasMore)
assertPostsEqual(t, posts[4:], gotPosts)
})
}

View File

@ -77,44 +77,16 @@ type Store interface {
}
type store struct {
db *sql.DB
db *SQLDB
}
// NewStore initializes a new Store using an existing SQLDB.
func NewStore(db *SQLDB) Store {
return &store{
db: db.db,
db: db,
}
}
// if the callback returns an error then the transaction is aborted.
func (s *store) withTx(cb func(*sql.Tx) error) error {
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
if err := cb(tx); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf(
"rolling back transaction: %w (original error: %v)",
rollbackErr, err,
)
}
return fmt.Errorf("performing transaction: %w (rolled back)", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}
func (s *store) Set(post Post, now time.Time) (bool, error) {
if post.ID == "" {
@ -123,7 +95,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) {
var first bool
err := s.withTx(func(tx *sql.Tx) error {
err := s.db.withTx(func(tx *sql.Tx) error {
nowTS := now.Unix()
@ -270,7 +242,7 @@ func (s *store) get(
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
posts, err := s.get(s.db, count+1, page*count, ``)
posts, err := s.get(s.db.db, count+1, page*count, ``)
if err != nil {
return nil, false, fmt.Errorf("querying posts: %w", err)
@ -288,7 +260,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
func (s *store) GetByID(id string) (StoredPost, error) {
posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
if err != nil {
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
@ -306,14 +278,14 @@ func (s *store) GetByID(id string) (StoredPost, error) {
}
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series)
}
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
var posts []StoredPost
err := s.withTx(func(tx *sql.Tx) error {
err := s.db.withTx(func(tx *sql.Tx) error {
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
@ -357,7 +329,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) {
func (s *store) GetTags() ([]string, error) {
rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
if err != nil {
return nil, fmt.Errorf("querying all tags: %w", err)
}
@ -381,11 +353,7 @@ func (s *store) GetTags() ([]string, error) {
func (s *store) Delete(id string) error {
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
return s.db.withTx(func(tx *sql.Tx) error {
if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
return fmt.Errorf("deleting from post_tags: %w", err)
@ -395,9 +363,6 @@ func (s *store) Delete(id string) error {
return fmt.Errorf("deleting from posts: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
})
}

View File

@ -38,10 +38,18 @@ var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration
body BLOB NOT NULL
)`,
},
Down: []string{
"DROP TABLE assets",
"DROP TABLE post_tags",
"DROP TABLE posts",
},
{
Id: "2",
Up: []string{
`CREATE TABLE post_drafts (
id TEXT NOT NULL PRIMARY KEY,
title TEXT NOT NULL,
description TEXT NOT NULL,
tags TEXT,
series TEXT,
body TEXT NOT NULL
)`,
},
},
}}
@ -89,3 +97,30 @@ func NewInMemSQLDB() *SQLDB {
func (db *SQLDB) Close() error {
return db.db.Close()
}
func (db *SQLDB) withTx(cb func(*sql.Tx) error) error {
tx, err := db.db.Begin()
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
if err := cb(tx); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf(
"rolling back transaction: %w (original error: %v)",
rollbackErr, err,
)
}
return fmt.Errorf("performing transaction: %w (rolled back)", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}