mrun: change WithThread to WithThreads

This commit is contained in:
Brian Picciano 2019-02-24 17:38:05 -05:00
parent dc57aadb54
commit 553a4854ea
3 changed files with 34 additions and 26 deletions

View File

@ -43,7 +43,7 @@ func WithListeningServer(ctx context.Context, h http.Handler) (context.Context,
srv.ctx = mrun.WithStartHook(srv.ctx, func(context.Context) error { srv.ctx = mrun.WithStartHook(srv.ctx, func(context.Context) error {
srv.Addr = listener.Addr().String() srv.Addr = listener.Addr().String()
srv.ctx = mrun.WithThread(srv.ctx, func() error { srv.ctx = mrun.WithThreads(srv.ctx, 1, func() error {
mlog.Info("serving requests", srv.ctx) mlog.Info("serving requests", srv.ctx)
if err := srv.Serve(listener); !merr.Equal(err, http.ErrServerClosed) { if err := srv.Serve(listener); !merr.Equal(err, http.ErrServerClosed) {
mlog.Error("error serving listener", srv.ctx, merr.Context(err)) mlog.Error("error serving listener", srv.ctx, merr.Context(err))

View File

@ -36,31 +36,39 @@ func (fe *futureErr) set(err error) {
type threadCtxKey int type threadCtxKey int
// WithThread spawns a go-routine which executes the given function. The // WithThreads spawns n go-routines, each of which executes the given function. The
// returned Context tracks this go-routine, which can then be passed into the // returned Context tracks these go-routines, and can then be passed into the
// Wait function to block until the spawned go-routine returns. // Wait function to block until the spawned go-routines all return.
func WithThread(ctx context.Context, fn func() error) context.Context { func WithThreads(ctx context.Context, n uint, fn func() error) context.Context {
futErr := newFutureErr() // I dunno why this would happen, but it wouldn't actually hurt anything
if n == 0 {
return ctx
}
oldFutErrs, _ := ctx.Value(threadCtxKey(0)).([]*futureErr) oldFutErrs, _ := ctx.Value(threadCtxKey(0)).([]*futureErr)
futErrs := make([]*futureErr, len(oldFutErrs), len(oldFutErrs)+1) futErrs := make([]*futureErr, len(oldFutErrs), len(oldFutErrs)+int(n))
copy(futErrs, oldFutErrs) copy(futErrs, oldFutErrs)
for i := uint(0); i < n; i++ {
futErr := newFutureErr()
futErrs = append(futErrs, futErr) futErrs = append(futErrs, futErr)
ctx = context.WithValue(ctx, threadCtxKey(0), futErrs)
go func() { go func() {
futErr.set(fn()) futErr.set(fn())
}() }()
}
return ctx return context.WithValue(ctx, threadCtxKey(0), futErrs)
} }
// ErrDone is returned from Wait if cancelCh is closed before all threads have // ErrDone is returned from Wait if cancelCh is closed before all threads have
// returned. // returned.
var ErrDone = errors.New("Wait is done waiting") var ErrDone = errors.New("Wait is done waiting")
// Wait blocks until all go-routines spawned using Thread on the passed in // Wait blocks until all go-routines spawned using WithThreads on the passed in
// Context (and its predecessors) have returned. Any number of the go-routines // Context (and its predecessors) have returned. Any number of the go-routines
// may have returned already when Wait is called. // may have returned already when Wait is called, and not all go-routines need
// be from the same WithThreads call.
// //
// If any of the thread functions returned an error during its runtime Wait will // If any of the thread functions returned an error during its runtime Wait will
// return that error. If multiple returned an error only one of those will be // return that error. If multiple returned an error only one of those will be
@ -70,7 +78,7 @@ var ErrDone = errors.New("Wait is done waiting")
// this function stops waiting and returns ErrDone. // this function stops waiting and returns ErrDone.
// //
// Wait is safe to call in parallel, and will return the same result if called // Wait is safe to call in parallel, and will return the same result if called
// multiple times in sequence. // multiple times.
func Wait(ctx context.Context, cancelCh <-chan struct{}) error { func Wait(ctx context.Context, cancelCh <-chan struct{}) error {
// First wait for all the children, and see if any of them return an error // First wait for all the children, and see if any of them return an error
children := mctx.Children(ctx) children := mctx.Children(ctx)

View File

@ -30,7 +30,7 @@ func TestThreadWait(t *T) {
t.Run("noBlock", func(t *T) { t.Run("noBlock", func(t *T) {
t.Run("noErr", func(t *T) { t.Run("noErr", func(t *T) {
ctx := context.Background() ctx := context.Background()
ctx = WithThread(ctx, func() error { return nil }) ctx = WithThreads(ctx, 1, func() error { return nil })
if err := Wait(ctx, nil); err != nil { if err := Wait(ctx, nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -38,7 +38,7 @@ func TestThreadWait(t *T) {
t.Run("err", func(t *T) { t.Run("err", func(t *T) {
ctx := context.Background() ctx := context.Background()
ctx = WithThread(ctx, func() error { return testErr }) ctx = WithThreads(ctx, 1, func() error { return testErr })
if err := Wait(ctx, nil); err != testErr { if err := Wait(ctx, nil); err != testErr {
t.Fatalf("should have got test error, got: %v", err) t.Fatalf("should have got test error, got: %v", err)
} }
@ -48,7 +48,7 @@ func TestThreadWait(t *T) {
t.Run("block", func(t *T) { t.Run("block", func(t *T) {
t.Run("noErr", func(t *T) { t.Run("noErr", func(t *T) {
ctx := context.Background() ctx := context.Background()
ctx = WithThread(ctx, func() error { ctx = WithThreads(ctx, 1, func() error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return nil return nil
}) })
@ -59,7 +59,7 @@ func TestThreadWait(t *T) {
t.Run("err", func(t *T) { t.Run("err", func(t *T) {
ctx := context.Background() ctx := context.Background()
ctx = WithThread(ctx, func() error { ctx = WithThreads(ctx, 1, func() error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return testErr return testErr
}) })
@ -70,7 +70,7 @@ func TestThreadWait(t *T) {
t.Run("canceled", func(t *T) { t.Run("canceled", func(t *T) {
ctx := context.Background() ctx := context.Background()
ctx = WithThread(ctx, func() error { ctx = WithThreads(ctx, 1, func() error {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
return testErr return testErr
}) })
@ -90,7 +90,7 @@ func TestThreadWait(t *T) {
t.Run("noBlock", func(t *T) { t.Run("noBlock", func(t *T) {
t.Run("noErr", func(t *T) { t.Run("noErr", func(t *T) {
ctx, childCtx := ctxWithChild() ctx, childCtx := ctxWithChild()
childCtx = WithThread(childCtx, func() error { return nil }) childCtx = WithThreads(childCtx, 1, func() error { return nil })
ctx = mctx.WithChild(ctx, childCtx) ctx = mctx.WithChild(ctx, childCtx)
if err := Wait(ctx, nil); err != nil { if err := Wait(ctx, nil); err != nil {
t.Fatal(err) t.Fatal(err)
@ -99,7 +99,7 @@ func TestThreadWait(t *T) {
t.Run("err", func(t *T) { t.Run("err", func(t *T) {
ctx, childCtx := ctxWithChild() ctx, childCtx := ctxWithChild()
childCtx = WithThread(childCtx, func() error { return testErr }) childCtx = WithThreads(childCtx, 1, func() error { return testErr })
ctx = mctx.WithChild(ctx, childCtx) ctx = mctx.WithChild(ctx, childCtx)
if err := Wait(ctx, nil); err != testErr { if err := Wait(ctx, nil); err != testErr {
t.Fatalf("should have got test error, got: %v", err) t.Fatalf("should have got test error, got: %v", err)
@ -110,7 +110,7 @@ func TestThreadWait(t *T) {
t.Run("block", func(t *T) { t.Run("block", func(t *T) {
t.Run("noErr", func(t *T) { t.Run("noErr", func(t *T) {
ctx, childCtx := ctxWithChild() ctx, childCtx := ctxWithChild()
childCtx = WithThread(childCtx, func() error { childCtx = WithThreads(childCtx, 1, func() error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return nil return nil
}) })
@ -122,7 +122,7 @@ func TestThreadWait(t *T) {
t.Run("err", func(t *T) { t.Run("err", func(t *T) {
ctx, childCtx := ctxWithChild() ctx, childCtx := ctxWithChild()
childCtx = WithThread(childCtx, func() error { childCtx = WithThreads(childCtx, 1, func() error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return testErr return testErr
}) })
@ -134,7 +134,7 @@ func TestThreadWait(t *T) {
t.Run("canceled", func(t *T) { t.Run("canceled", func(t *T) {
ctx, childCtx := ctxWithChild() ctx, childCtx := ctxWithChild()
childCtx = WithThread(childCtx, func() error { childCtx = WithThreads(childCtx, 1, func() error {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
return testErr return testErr
}) })