Fix concurrent client usage: each worker now gets its own Client

This commit is contained in:
Brian Picciano 2023-12-30 12:41:08 +01:00
parent 307e311b61
commit eaccb83a7b

View File

@ -33,13 +33,13 @@ import (
// Opts are optional fields which can be provided to New. A nil Opts is // Opts are optional fields which can be provided to New. A nil Opts is
// equivalent to an empty one. // equivalent to an empty one.
type Opts struct { type Opts struct {
Client Client // Defaults to `NewClient(nil)` NewClient func() Client // Defaults to `func () Client { return NewClient(nil) }`
Parser Parser // Defaults to `NewParser()` Parser Parser // Defaults to `NewParser()`
// Concurrency determines the maximum number of URLs which can be checked // Concurrency determines the maximum number of URLs which can be checked
// simultaneously. // simultaneously.
// //
// Default: `runtime.NumCPU()` // Default: `runtime.NumCPU() / 2`
Concurrency int Concurrency int
// OnError, if set, will be called whenever DeadLinks encounters an error // OnError, if set, will be called whenever DeadLinks encounters an error
@ -58,8 +58,8 @@ func (o *Opts) withDefaults() *Opts {
o = new(Opts) o = new(Opts)
} }
if o.Client == nil { if o.NewClient == nil {
o.Client = NewClient(nil) o.NewClient = func() Client { return NewClient(nil) }
} }
if o.Parser == nil { if o.Parser == nil {
@ -67,7 +67,7 @@ func (o *Opts) withDefaults() *Opts {
} }
if o.Concurrency == 0 { if o.Concurrency == 0 {
o.Concurrency = runtime.NumCPU() o.Concurrency = runtime.NumCPU() / 2
} }
if o.RequestTimeout == 0 { if o.RequestTimeout == 0 {
@ -87,6 +87,7 @@ type DeadLinks struct {
opts Opts opts Opts
store Store store Store
patterns []*regexp.Regexp patterns []*regexp.Regexp
clients []Client
} }
// New initializes and returns a DeadLinks instance which will track the // New initializes and returns a DeadLinks instance which will track the
@ -129,6 +130,11 @@ func New(
patterns: patterns, patterns: patterns,
} }
d.clients = make([]Client, d.opts.Concurrency)
for i := range d.clients {
d.clients[i] = d.opts.NewClient()
}
if err := d.store.SetPinned(ctx, pinnedURLs); err != nil { if err := d.store.SetPinned(ctx, pinnedURLs); err != nil {
return nil, fmt.Errorf("pinning URLs: %w", err) return nil, fmt.Errorf("pinning URLs: %w", err)
} }
@ -155,11 +161,15 @@ func (d *DeadLinks) shouldFollowURL(url URL) bool {
return false return false
} }
func (d *DeadLinks) getURL(ctx context.Context, url URL) ([]URL, error) { func (d *DeadLinks) getURL(
ctx context.Context, client Client, url URL,
) (
[]URL, error,
) {
ctx, cancel := context.WithTimeout(ctx, d.opts.RequestTimeout) ctx, cancel := context.WithTimeout(ctx, d.opts.RequestTimeout)
defer cancel() defer cancel()
mimeType, body, err := d.opts.Client.Get(ctx, url) mimeType, body, err := client.Get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -178,14 +188,16 @@ func (d *DeadLinks) getURL(ctx context.Context, url URL) ([]URL, error) {
} }
// checkURL only returns an error if storing the results of the check fails. // checkURL only returns an error if storing the results of the check fails.
func (d *DeadLinks) checkURL(ctx context.Context, url URL) error { func (d *DeadLinks) checkURL(
ctx context.Context, client Client, url URL,
) error {
var ( var (
now = time.Now() now = time.Now()
status = ResourceStatusOK status = ResourceStatusOK
errorStr string errorStr string
) )
outgoingURLs, err := d.getURL(ctx, url) outgoingURLs, err := d.getURL(ctx, client, url)
if err != nil { if err != nil {
status = ResourceStatusError status = ResourceStatusError
errorStr = err.Error() errorStr = err.Error()
@ -219,14 +231,14 @@ func (d *DeadLinks) update(
wg.Add(d.opts.Concurrency) wg.Add(d.opts.Concurrency)
for i := 0; i < d.opts.Concurrency; i++ { for i := 0; i < d.opts.Concurrency; i++ {
go func() { go func(client Client) {
defer wg.Done() defer wg.Done()
for url := range ch { for url := range ch {
if err := d.checkURL(ctx, url); err != nil { if err := d.checkURL(ctx, client, url); err != nil {
d.onError(ctx, fmt.Errorf("checking url %q: %w", url, err)) d.onError(ctx, fmt.Errorf("checking url %q: %w", url, err))
} }
} }
}() }(d.clients[i])
} }
var ( var (