package post

import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
)

// DraftStore stores and retrieves draft Posts.
type DraftStore interface {

	// Set sets the draft Post's data into the storage, keyed by the draft
	// Post's ID.
	Set(post Post) error

	// Get returns count draft Posts, sorted id ascending, offset by the
	// given (zero-based) page number. The returned boolean indicates if there
	// are more pages or not.
	Get(page, count int) ([]Post, bool, error)

	// GetByID will return the draft Post with the given ID, or ErrPostNotFound.
	GetByID(id string) (Post, error)

	// Delete will delete the draft Post with the given ID.
	Delete(id string) error
}

// draftStore is the DraftStore implementation backed by the post_drafts table
// of a SQLDB.
type draftStore struct {
	db *SQLDB
}

// NewDraftStore initializes a new DraftStore using an existing SQLDB.
func NewDraftStore(db *SQLDB) DraftStore {
	return &draftStore{
		db: db,
	}
}

// Set implements DraftStore. It inserts the draft into post_drafts, or
// overwrites every non-ID column if a row with the same ID already exists.
//
// An empty post.ID is rejected with an error. An empty Description, an empty
// Tags slice, and an empty Series are each stored as NULL rather than as an
// empty value. Tags are stored as a JSON-encoded string.
func (s *draftStore) Set(post Post) error {
	if post.ID == "" {
		return errors.New("post ID can't be empty")
	}

	tagsJSON, err := json.Marshal(post.Tags)
	if err != nil {
		return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err)
	}

	// Upsert keyed on id: a conflicting row has all of its other columns
	// replaced by the newly supplied values.
	const query = "INSERT INTO post_drafts ( id, title, description, tags, series, body ) " +
		"VALUES (?, ?, ?, ?, ?, ?) \n" +
		"ON CONFLICT (id) DO UPDATE SET " +
		"title=excluded.title, " +
		"description=excluded.description, " +
		"tags=excluded.tags, " +
		"series=excluded.series, " +
		"body=excluded.body"

	_, err = s.db.db.Exec(
		query,
		post.ID,
		post.Title,
		// Valid=false makes the column NULL.
		&sql.NullString{String: post.Description, Valid: len(post.Description) > 0},
		&sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0},
		&sql.NullString{String: post.Series, Valid: post.Series != ""},
		post.Body,
	)
	if err != nil {
		return fmt.Errorf("inserting into post_drafts: %w", err)
	}

	return nil
}

// get runs a SELECT against post_drafts using the given querier and decodes
// every returned row into a Post. Results are ordered by id ascending.
//
// A LIMIT clause is only added when limit > 0, and an OFFSET clause only when
// offset > 0. where is concatenated verbatim into the query, so it must be a
// trusted SQL fragment (or empty); values belonging to it are passed
// separately as whereArgs so that they are bound as query parameters.
//
// NULL description/series columns decode to empty strings, and a NULL or empty
// tags column leaves Post.Tags unset.
func (s *draftStore) get(
	querier interface {
		Query(string, ...interface{}) (*sql.Rows, error)
	},
	limit, offset int,
	where string, whereArgs ...interface{},
) (
	[]Post, error,
) {
	query := ` SELECT p.id, p.title, p.description, p.tags, p.series, p.body FROM post_drafts p ` +
		where +
		` ORDER BY p.id ASC`

	if limit > 0 {
		query += fmt.Sprintf(" LIMIT %d", limit)
	}

	if offset > 0 {
		query += fmt.Sprintf(" OFFSET %d", offset)
	}

	rows, err := querier.Query(query, whereArgs...)
	if err != nil {
		return nil, fmt.Errorf("selecting: %w", err)
	}
	defer rows.Close()

	var posts []Post

	for rows.Next() {
		var (
			post                      Post
			description, tags, series sql.NullString
		)

		err := rows.Scan(
			&post.ID, &post.Title, &description, &tags, &series, &post.Body,
		)
		if err != nil {
			return nil, fmt.Errorf("scanning row: %w", err)
		}

		post.Description = description.String
		post.Series = series.String

		if tags.String != "" {
			if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil {
				return nil, fmt.Errorf("json parsing %q: %w", tags.String, err)
			}
		}

		posts = append(posts, post)
	}

	// rows.Next returns false both at the end of the result set and when
	// iteration fails; without this check a failure part-way through would be
	// returned to the caller as a complete, successful result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating rows: %w", err)
	}

	return posts, nil
}

// Get implements DraftStore. page is zero-based: the first page*count drafts
// (by ascending id) are skipped and at most count drafts are returned. The
// boolean is true when at least one further draft exists past the returned
// page.
//
// A negative count is rejected with an error.
func (s *draftStore) Get(page, count int) ([]Post, bool, error) {
	// A negative count would disable the LIMIT clause below and then be used
	// as a slice bound, which panics.
	if count < 0 {
		return nil, false, fmt.Errorf("count must not be negative, got %d", count)
	}

	// One extra row is requested so that the existence of a following page can
	// be detected without a second query.
	posts, err := s.get(s.db.db, count+1, page*count, ``)
	if err != nil {
		return nil, false, fmt.Errorf("querying post_drafts: %w", err)
	}

	var hasMore bool

	if len(posts) > count {
		hasMore = true
		posts = posts[:count]
	}

	return posts, hasMore, nil
}

// GetByID implements DraftStore. It returns ErrPostNotFound if no draft has
// the given ID.
//
// It panics if more than one row comes back for the ID; ids are expected to be
// unique within post_drafts (Set upserts on a conflict on id), so that case
// indicates a broken invariant rather than a recoverable condition.
func (s *draftStore) GetByID(id string) (Post, error) {
	posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
	if err != nil {
		return Post{}, fmt.Errorf("querying post_drafts: %w", err)
	}

	if len(posts) == 0 {
		return Post{}, ErrPostNotFound
	}

	if len(posts) > 1 {
		panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts))
	}

	return posts[0], nil
}

// Delete implements DraftStore. It issues a DELETE for the given ID and
// returns an error only if executing the statement fails.
func (s *draftStore) Delete(id string) error {
	if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil {
		return fmt.Errorf("deleting from post_drafts: %w", err)
	}

	return nil
}