From eaccb83a7bcc215f70fc10a0a565f742925cf7b5 Mon Sep 17 00:00:00 2001
From: Brian Picciano
Date: Sat, 30 Dec 2023 12:41:08 +0100
Subject: [PATCH] Fix concurrent client usage, now each worker gets its own Client

---
Note (hand-edited after export): the Concurrency default is clamped to a
minimum of 1, because `runtime.NumCPU() / 2` is 0 on a single-CPU host, in
which case New would create no clients and update would start no workers.
The post-image blob id on the `index` line predates this edit.

 deadlinks.go | 43 ++++++++++++++++++++++++++++++-------------
 1 file changed, 30 insertions(+), 13 deletions(-)

diff --git a/deadlinks.go b/deadlinks.go
index 6ff4a28..ba51eb6 100644
--- a/deadlinks.go
+++ b/deadlinks.go
@@ -33,13 +33,13 @@ import (
 // Opts are optional fields which can be provided to New. A nil Opts is
 // equivalent to an empty one.
 type Opts struct {
-	Client Client // Defaults to `NewClient(nil)`
-	Parser Parser // Defaults to `NewParser()`
+	NewClient func() Client // Defaults to `func () Client { return NewClient(nil) }`
+	Parser    Parser        // Defaults to `NewParser()`
 
 	// Concurrency determines the maximum number of URLs which can be checked
 	// simultaneously.
 	//
-	// Default: `runtime.NumCPU()`
+	// Default: `runtime.NumCPU() / 2`, but never less than 1
 	Concurrency int
 
 	// OnError, if set, will be called whenever DeadLinks encounters an error
@@ -58,8 +58,8 @@ func (o *Opts) withDefaults() *Opts {
 		o = new(Opts)
 	}
 
-	if o.Client == nil {
-		o.Client = NewClient(nil)
+	if o.NewClient == nil {
+		o.NewClient = func() Client { return NewClient(nil) }
 	}
 
 	if o.Parser == nil {
@@ -67,7 +67,12 @@ func (o *Opts) withDefaults() *Opts {
 	}
 
 	if o.Concurrency == 0 {
-		o.Concurrency = runtime.NumCPU()
+		// Integer division yields 0 on a single-CPU host; always keep at
+		// least one worker so that URLs can actually be checked.
+		o.Concurrency = runtime.NumCPU() / 2
+		if o.Concurrency < 1 {
+			o.Concurrency = 1
+		}
 	}
 
 	if o.RequestTimeout == 0 {
@@ -87,6 +92,7 @@ type DeadLinks struct {
 	opts     Opts
 	store    Store
 	patterns []*regexp.Regexp
+	clients  []Client
 }
 
 // New initializes and returns a DeadLinks instance which will track the
@@ -129,6 +135,11 @@ func New(
 		patterns: patterns,
 	}
 
+	d.clients = make([]Client, d.opts.Concurrency)
+	for i := range d.clients {
+		d.clients[i] = d.opts.NewClient()
+	}
+
 	if err := d.store.SetPinned(ctx, pinnedURLs); err != nil {
 		return nil, fmt.Errorf("pinning URLs: %w", err)
 	}
@@ -155,11 +166,15 @@ func (d *DeadLinks) shouldFollowURL(url URL) bool {
 	return false
 }
 
-func (d *DeadLinks) getURL(ctx context.Context, url URL) ([]URL, error) {
+func (d *DeadLinks) getURL(
+	ctx context.Context, client Client, url URL,
+) (
+	[]URL, error,
+) {
 	ctx, cancel := context.WithTimeout(ctx, d.opts.RequestTimeout)
 	defer cancel()
 
-	mimeType, body, err := d.opts.Client.Get(ctx, url)
+	mimeType, body, err := client.Get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -178,14 +193,16 @@ func (d *DeadLinks) getURL(ctx context.Context, url URL) ([]URL, error) {
 }
 
 // checkURL only returns an error if storing the results of the check fails.
-func (d *DeadLinks) checkURL(ctx context.Context, url URL) error {
+func (d *DeadLinks) checkURL(
+	ctx context.Context, client Client, url URL,
+) error {
 	var (
 		now      = time.Now()
 		status   = ResourceStatusOK
 		errorStr string
 	)
 
-	outgoingURLs, err := d.getURL(ctx, url)
+	outgoingURLs, err := d.getURL(ctx, client, url)
 	if err != nil {
 		status = ResourceStatusError
 		errorStr = err.Error()
@@ -219,14 +236,14 @@ func (d *DeadLinks) update(
 	wg.Add(d.opts.Concurrency)
 
 	for i := 0; i < d.opts.Concurrency; i++ {
-		go func() {
+		go func(client Client) {
 			defer wg.Done()
 			for url := range ch {
-				if err := d.checkURL(ctx, url); err != nil {
+				if err := d.checkURL(ctx, client, url); err != nil {
 					d.onError(ctx, fmt.Errorf("checking url %q: %w", url, err))
 				}
 			}
-		}()
+		}(d.clients[i])
 	}
 
 	var (