diff --git a/mrun/hook.go b/mrun/hook.go index 2e38374..f1a7e70 100644 --- a/mrun/hook.go +++ b/mrun/hook.go @@ -14,7 +14,6 @@ type ctxKey int const ( ctxKeyHookEls ctxKey = iota - ctxKeyNumChildren ctxKeyNumHooks ) @@ -31,13 +30,10 @@ type hookEl struct { child context.Context } -func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap, ctxKeyWrap) { +func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap) { return ctxKeyWrap{ key: ctxKeyHookEls, userKey: userKey, - }, ctxKeyWrap{ - key: ctxKeyNumChildren, - userKey: userKey, }, ctxKeyWrap{ key: ctxKeyNumHooks, userKey: userKey, @@ -48,21 +44,22 @@ func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap, ctxKeyWrap) { // appends more elements if more children have been added since that []hookEl // was created. // -// this also returns the latest numChildren and numHooks values for convenience. -func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int, int) { - hookElsKey, numChildrenKey, numHooksKey := ctxKeys(userKey) - lastNumChildren, _ := mctx.LocalValue(ctx, numChildrenKey).(int) +// this also returns the latest numHooks value for convenience. +func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int) { + hookElsKey, numHooksKey := ctxKeys(userKey) lastNumHooks, _ := mctx.LocalValue(ctx, numHooksKey).(int) lastHookEls, _ := mctx.LocalValue(ctx, hookElsKey).([]hookEl) children := mctx.Children(ctx) // plus 1 in case we wanna append something else outside this function - hookEls := make([]hookEl, len(lastHookEls), lastNumHooks+len(children)-lastNumChildren+1) + hookEls := make([]hookEl, len(lastHookEls), lastNumHooks+len(children)+1) copy(hookEls, lastHookEls) + + lastNumChildren := len(lastHookEls) - lastNumHooks for _, child := range children[lastNumChildren:] { hookEls = append(hookEls, hookEl{child: child}) } - return hookEls, len(children), lastNumHooks + return hookEls, lastNumHooks } // WithHook registers a Hook under a typed key. The Hook will be called when @@ -71,18 +68,17 @@ func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int, int) { // // Hooks will be called with whatever Context is passed into TriggerHooks. func WithHook(ctx context.Context, key interface{}, hook Hook) context.Context { - hookEls, numChildren, numHooks := getHookEls(ctx, key) + hookEls, numHooks := getHookEls(ctx, key) hookEls = append(hookEls, hookEl{hook: hook}) - hookElsKey, numChildrenKey, numHooksKey := ctxKeys(key) + hookElsKey, numHooksKey := ctxKeys(key) ctx = mctx.WithLocalValue(ctx, hookElsKey, hookEls) - ctx = mctx.WithLocalValue(ctx, numChildrenKey, numChildren) ctx = mctx.WithLocalValue(ctx, numHooksKey, numHooks+1) return ctx } func triggerHooks(ctx context.Context, userKey interface{}, next func([]hookEl) (hookEl, []hookEl)) error { - hookEls, _, _ := getHookEls(ctx, userKey) + hookEls, _ := getHookEls(ctx, userKey) var hookEl hookEl for { if len(hookEls) == 0 { diff --git a/mrun/hook_test.go b/mrun/hook_test.go index ef6d9f3..b4c5b09 100644 --- a/mrun/hook_test.go +++ b/mrun/hook_test.go @@ -35,14 +35,16 @@ func TestHooks(t *T) { ctxB = mctx.WithChild(ctxB, ctxB1) ctx = mctx.WithChild(ctx, ctxB) + ctx = WithHook(ctx, 0, mkHook(7)) + massert.Fatal(t, massert.All( massert.Nil(TriggerHooks(ctx, 0)), - massert.Equal([]int{1, 2, 3, 4, 5, 6}, out), + massert.Equal([]int{1, 2, 3, 4, 5, 6, 7}, out), )) out = nil massert.Fatal(t, massert.All( massert.Nil(TriggerHooksReverse(ctx, 0)), - massert.Equal([]int{6, 5, 4, 3, 2, 1}, out), + massert.Equal([]int{7, 6, 5, 4, 3, 2, 1}, out), )) }