From 6539cc2650b29440ecc25031d71ce889bbd1efb1 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Fri, 29 Dec 2023 20:11:43 +0100 Subject: [PATCH] Fix how iteration works in Store, since sqlite doesn't like concurrent access --- store.go | 161 +++++++++++++++++++++++++++++++++++--------------- store_test.go | 56 +++++++++++++++++- 2 files changed, 170 insertions(+), 47 deletions(-) diff --git a/store.go b/store.go index 0dc7567..1c2f516 100644 --- a/store.go +++ b/store.go @@ -117,6 +117,10 @@ func NewSQLiteStore(o *SQLiteStoreOpts) *SQLiteStore { panic(fmt.Errorf("opening sqlite in memory: %w", err)) } + // go-sqlite doesn't support multiple go-routines, this is equivalent to + // wrapping each call to the db in a mutex. + db.SetMaxOpenConns(1) + if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { panic(fmt.Errorf("running migrations: %w", err)) } @@ -130,6 +134,66 @@ func (s *SQLiteStore) Close() error { return s.db.Close() } +// iterate is a helper which can be used to read the results of a query in +// chunks, producing a single unified Iterator. +// +// iterate assumes that the rows are being scanned ordered by their row ID, +// which must be a number, and that each query call has some kind of limit +// applied. +func iterate[T any]( + db interface { + QueryContext(context.Context, string, ...any) (*sql.Rows, error) + }, + query string, + mkArgs func(minRowID int) []any, + scan func(*sql.Rows) (T, int, error), // returns scanned value and its rowID +) miter.Iterator[T] { + var ( + zero T + minRowID = -1 + res, resBase []T + + pop = func() T { + r := res[0] + res = res[1:] + return r + } + ) + + return miter.FromFunc(func(ctx context.Context) (T, error) { + if len(res) > 0 { + return pop(), nil + } + + res = resBase + + rows, err := db.QueryContext(ctx, query, mkArgs(minRowID)...) + if err != nil { + return zero, fmt.Errorf("executing query: %w", err) + } + defer rows.Close() + + for rows.Next() { + var r T + if r, minRowID, err = scan(rows); err != nil { + return zero, fmt.Errorf("scanning row: %w", err) + } + + res = append(res, r) + } + + resBase = res[:0] + + if len(res) == 0 { + return zero, miter.ErrEnd + } + + return pop(), nil + }) +} + +const getByStatusLimit = 16 + // GetByStatus implements the method for the Store interface. func (s *SQLiteStore) GetByStatus(status ResourceStatus) miter.Iterator[Resource] { const query = ` @@ -151,6 +215,7 @@ func (s *SQLiteStore) GetByStatus(status ResourceStatus) miter.Iterator[Resource GROUP BY from_url_id ) SELECT + resources.url_id, url, status, pinned, @@ -163,26 +228,28 @@ func (s *SQLiteStore) GetByStatus(status ResourceStatus) miter.Iterator[Resource LEFT JOIN incoming ON (incoming.url_id = resources.url_id) LEFT JOIN outgoing ON (outgoing.url_id = resources.url_id) WHERE status = ? - AND (pinned OR incoming.urls IS NOT NULL)` - - return miter.Lazily(func(ctx context.Context) (miter.Iterator[Resource], error) { - rows, err := s.db.QueryContext(ctx, query, status) - if err != nil { - return nil, fmt.Errorf("executing query: %w", err) - } - - return miter.FromFunc(func(ctx context.Context) (Resource, error) { + AND (pinned OR incoming.urls IS NOT NULL) + AND resources.url_id > ? + ORDER BY resources.url_id ASC + LIMIT ?` + + return iterate( + s.db, + query, + func(minRowID int) []any { + return []any{status, minRowID, getByStatusLimit} + }, + func(rows *sql.Rows) (Resource, int, error) { var ( r Resource + rowID int + err error lastChecked int64 incoming, outgoing sql.NullString ) - if !rows.Next() { - return Resource{}, errors.Join(rows.Close(), miter.ErrEnd) - } - if err := rows.Scan( + &rowID, &r.URL, &r.Status, &r.Pinned, @@ -191,9 +258,7 @@ func (s *SQLiteStore) GetByStatus(status ResourceStatus) miter.Iterator[Resource &incoming, &outgoing, ); err != nil { - return Resource{}, errors.Join( - rows.Close(), fmt.Errorf("scanning row: %w", err), - ) + return Resource{}, 0, fmt.Errorf("calling Scan: %w", err) } if lastChecked != 0 { @@ -204,8 +269,8 @@ func (s *SQLiteStore) GetByStatus(status ResourceStatus) miter.Iterator[Resource if r.IncomingLinkURLs, err = parseURLs( strings.Split(incoming.String, "\x00"), ); err != nil { - return Resource{}, errors.Join( - rows.Close(), fmt.Errorf("parsing incoming links: %w", err), + return Resource{}, 0, fmt.Errorf( + "parsing incoming links: %w", err, ) } } @@ -214,17 +279,19 @@ func (s *SQLiteStore) GetByStatus(status ResourceStatus) miter.Iterator[Resource if r.OutgoingLinkURLs, err = parseURLs( strings.Split(outgoing.String, "\x00"), ); err != nil { - return Resource{}, errors.Join( - rows.Close(), fmt.Errorf("parsing outgoing links: %w", err), + return Resource{}, 0, fmt.Errorf( + "parsing outgoing links: %w", err, ) } } - return r, nil - }), nil - }) + return r, rowID, nil + }, + ) } +const getURLsByLastCheckedLimit = 64 + // GetURLsByLastChecked implements the method for the Store interface. func (s *SQLiteStore) GetURLsByLastChecked( olderThan time.Time, @@ -236,42 +303,44 @@ func (s *SQLiteStore) GetURLsByLastChecked( FROM links GROUP BY to_url_id ) - SELECT url + SELECT resources.url_id, url FROM resources JOIN urls ON (urls.id = resources.url_id) LEFT JOIN incoming ON (incoming.url_id = resources.url_id) WHERE last_checked < ? - AND (pinned OR incoming.urls IS NOT NULL)` - - return miter.Lazily(func(ctx context.Context) (miter.Iterator[URL], error) { - rows, err := s.db.QueryContext(ctx, query, olderThan.Unix()) - if err != nil { - return nil, fmt.Errorf("executing query: %w", err) - } - - return miter.FromFunc(func(ctx context.Context) (URL, error) { - if !rows.Next() { - return "", errors.Join(rows.Close(), miter.ErrEnd) + AND (pinned OR incoming.urls IS NOT NULL) + AND resources.url_id > ? + ORDER BY resources.url_id ASC + LIMIT ?` + + return iterate( + s.db, + query, + func(minRowID int) []any { + return []any{ + olderThan.Unix(), minRowID, getURLsByLastCheckedLimit, } + }, + func(rows *sql.Rows) (URL, int, error) { + var ( + urlID int + urlStr string + ) - var urlStr string - if err := rows.Scan(&urlStr); err != nil { - return "", errors.Join( - rows.Close(), fmt.Errorf("scanning url: %w", err), - ) + if err := rows.Scan(&urlID, &urlStr); err != nil { + return "", 0, fmt.Errorf("scanning url: %w", err) } url, err := ParseURL(urlStr) if err != nil { - return "", errors.Join( - rows.Close(), - fmt.Errorf("parsing url %q from db: %w", urlStr, err), + return "", 0, fmt.Errorf( + "parsing url %q from db: %w", urlStr, err, ) } - return url, nil - }), nil - }) + return url, urlID, nil + }, + ) } func (s *SQLiteStore) touch(ctx context.Context, urls []URL, pinned bool) ( diff --git a/store_test.go b/store_test.go index bace044..662c03f 100644 --- a/store_test.go +++ b/store_test.go @@ -2,6 +2,7 @@ package deadlinks import ( "context" + "fmt" "sort" "testing" "time" @@ -84,7 +85,7 @@ func TestSQLiteStore(t *testing.T) { h.assertGetByStatus(t, []Resource{b}, ResourceStatusUnknown) }) - t.Run("Update", func(t *testing.T) { + t.Run("Update/general", func(t *testing.T) { t.Parallel() var ( @@ -142,6 +143,59 @@ func TestSQLiteStore(t *testing.T) { h.assertGetByStatus(t, nil, ResourceStatusError) }) + t.Run("Update/while_GetByStatus", func(t *testing.T) { + t.Parallel() + + var ( + h = newSQLiteStoreHarness() + + urlA = URL("https://a.com") + urlB = URL("https://b.com") + urlC = URL("https://c.com") + + a = Resource{ + URL: urlA, + Status: ResourceStatusOK, + Pinned: true, + LastChecked: h.now, + OutgoingLinkURLs: []URL{urlC}, + } + + b = Resource{ + URL: urlB, + Status: ResourceStatusOK, + Pinned: true, + LastChecked: h.now, + OutgoingLinkURLs: []URL{urlC}, + } + + c = Resource{ + URL: urlC, + Status: ResourceStatusOK, + LastChecked: h.now, + IncomingLinkURLs: []URL{urlA, urlB, urlC}, + OutgoingLinkURLs: []URL{urlC}, + } + ) + + assert.NoError(t, h.store.SetPinned(h.ctx, []URL{urlA, urlB})) + + iter := h.store.GetByStatus(ResourceStatusUnknown) + err := miter.ForEach(h.ctx, iter, func(r Resource) error { + err := h.store.Update( + h.ctx, h.now, r.URL, ResourceStatusOK, "", []URL{urlC}, + ) + + if err != nil { + return fmt.Errorf("updating %+v: %w", r, err) + } + return nil + }) + assert.NoError(t, err) + + h.assertGetByStatus(t, []Resource{a, b, c}, ResourceStatusOK) + }) + t.Run("GetURLsByLastChecked", func(t *testing.T) { t.Parallel()