diff --git a/mhttp/mhttp.go b/mhttp/mhttp.go index 2766060..dc97019 100644 --- a/mhttp/mhttp.go +++ b/mhttp/mhttp.go @@ -43,7 +43,7 @@ func WithListeningServer(ctx context.Context, h http.Handler) (context.Context, srv.ctx = mrun.WithStartHook(srv.ctx, func(context.Context) error { srv.Addr = listener.Addr().String() - srv.ctx = mrun.WithThread(srv.ctx, func() error { + srv.ctx = mrun.WithThreads(srv.ctx, 1, func() error { mlog.Info("serving requests", srv.ctx) if err := srv.Serve(listener); !merr.Equal(err, http.ErrServerClosed) { mlog.Error("error serving listener", srv.ctx, merr.Context(err)) diff --git a/mrun/mrun.go b/mrun/mrun.go index d5c44c4..d5cd751 100644 --- a/mrun/mrun.go +++ b/mrun/mrun.go @@ -36,31 +36,39 @@ func (fe *futureErr) set(err error) { type threadCtxKey int -// WithThread spawns a go-routine which executes the given function. The -// returned Context tracks this go-routine, which can then be passed into the -// Wait function to block until the spawned go-routine returns. -func WithThread(ctx context.Context, fn func() error) context.Context { - futErr := newFutureErr() +// WithThreads spawns n go-routines, each of which executes the given function. The +// returned Context tracks these go-routines, and can then be passed into the +// Wait function to block until the spawned go-routines all return. +func WithThreads(ctx context.Context, n uint, fn func() error) context.Context { + // I dunno why this would happen, but it wouldn't actually hurt anything + if n == 0 { + return ctx + } + oldFutErrs, _ := ctx.Value(threadCtxKey(0)).([]*futureErr) - futErrs := make([]*futureErr, len(oldFutErrs), len(oldFutErrs)+1) + futErrs := make([]*futureErr, len(oldFutErrs), len(oldFutErrs)+int(n)) copy(futErrs, oldFutErrs) - futErrs = append(futErrs, futErr) - ctx = context.WithValue(ctx, threadCtxKey(0), futErrs) - go func() { - futErr.set(fn()) - }() + for i := uint(0); i < n; i++ { + futErr := newFutureErr() + futErrs = append(futErrs, futErr) - return ctx + go func() { + futErr.set(fn()) + }() + } + + return context.WithValue(ctx, threadCtxKey(0), futErrs) } // ErrDone is returned from Wait if cancelCh is closed before all threads have // returned. var ErrDone = errors.New("Wait is done waiting") -// Wait blocks until all go-routines spawned using Thread on the passed in +// Wait blocks until all go-routines spawned using WithThreads on the passed in // Context (and its predecessors) have returned. Any number of the go-routines -// may have returned already when Wait is called. +// may have returned already when Wait is called, and not all go-routines need +// be from the same WithThreads call. // // If any of the thread functions returned an error during its runtime Wait will // return that error. If multiple returned an error only one of those will be @@ -70,7 +78,7 @@ var ErrDone = errors.New("Wait is done waiting") // this function stops waiting and returns ErrDone. // // Wait is safe to call in parallel, and will return the same result if called -// multiple times in sequence. +// multiple times. func Wait(ctx context.Context, cancelCh <-chan struct{}) error { // First wait for all the children, and see if any of them return an error children := mctx.Children(ctx) diff --git a/mrun/mrun_test.go b/mrun/mrun_test.go index 164dc62..1fc0c11 100644 --- a/mrun/mrun_test.go +++ b/mrun/mrun_test.go @@ -30,7 +30,7 @@ func TestThreadWait(t *T) { t.Run("noBlock", func(t *T) { t.Run("noErr", func(t *T) { ctx := context.Background() - ctx = WithThread(ctx, func() error { return nil }) + ctx = WithThreads(ctx, 1, func() error { return nil }) if err := Wait(ctx, nil); err != nil { t.Fatal(err) } @@ -38,7 +38,7 @@ func TestThreadWait(t *T) { t.Run("err", func(t *T) { ctx := context.Background() - ctx = WithThread(ctx, func() error { return testErr }) + ctx = WithThreads(ctx, 1, func() error { return testErr }) if err := Wait(ctx, nil); err != testErr { t.Fatalf("should have got test error, got: %v", err) } @@ -48,7 +48,7 @@ func TestThreadWait(t *T) { t.Run("block", func(t *T) { t.Run("noErr", func(t *T) { ctx := context.Background() - ctx = WithThread(ctx, func() error { + ctx = WithThreads(ctx, 1, func() error { time.Sleep(1 * time.Second) return nil }) @@ -59,7 +59,7 @@ func TestThreadWait(t *T) { t.Run("err", func(t *T) { ctx := context.Background() - ctx = WithThread(ctx, func() error { + ctx = WithThreads(ctx, 1, func() error { time.Sleep(1 * time.Second) return testErr }) @@ -70,7 +70,7 @@ func TestThreadWait(t *T) { t.Run("canceled", func(t *T) { ctx := context.Background() - ctx = WithThread(ctx, func() error { + ctx = WithThreads(ctx, 1, func() error { time.Sleep(5 * time.Second) return testErr }) @@ -90,7 +90,7 @@ func TestThreadWait(t *T) { t.Run("noBlock", func(t *T) { t.Run("noErr", func(t *T) { ctx, childCtx := ctxWithChild() - childCtx = WithThread(childCtx, func() error { return nil }) + childCtx = WithThreads(childCtx, 1, func() error { return nil }) ctx = mctx.WithChild(ctx, childCtx) if err := Wait(ctx, nil); err != nil { t.Fatal(err) @@ -99,7 +99,7 @@ func TestThreadWait(t *T) { t.Run("err", func(t *T) { ctx, childCtx := ctxWithChild() - childCtx = WithThread(childCtx, func() error { return testErr }) + childCtx = WithThreads(childCtx, 1, func() error { return testErr }) ctx = mctx.WithChild(ctx, childCtx) if err := Wait(ctx, nil); err != testErr { t.Fatalf("should have got test error, got: %v", err) @@ -110,7 +110,7 @@ func TestThreadWait(t *T) { t.Run("block", func(t *T) { t.Run("noErr", func(t *T) { ctx, childCtx := ctxWithChild() - childCtx = WithThread(childCtx, func() error { + childCtx = WithThreads(childCtx, 1, func() error { time.Sleep(1 * time.Second) return nil }) @@ -122,7 +122,7 @@ func TestThreadWait(t *T) { t.Run("err", func(t *T) { ctx, childCtx := ctxWithChild() - childCtx = WithThread(childCtx, func() error { + childCtx = WithThreads(childCtx, 1, func() error { time.Sleep(1 * time.Second) return testErr }) @@ -134,7 +134,7 @@ func TestThreadWait(t *T) { t.Run("canceled", func(t *T) { ctx, childCtx := ctxWithChild() - childCtx = WithThread(childCtx, func() error { + childCtx = WithThreads(childCtx, 1, func() error { time.Sleep(5 * time.Second) return testErr })