mrun: fix bug in retrieving set of hooks/children from a context, and simplify the logic a bit

This commit is contained in:
Brian Picciano 2019-03-03 15:34:44 -05:00
parent 015edcd69a
commit 3d939a1e80
2 changed files with 15 additions and 17 deletions

View File

@ -14,7 +14,6 @@ type ctxKey int
const ( const (
ctxKeyHookEls ctxKey = iota ctxKeyHookEls ctxKey = iota
ctxKeyNumChildren
ctxKeyNumHooks ctxKeyNumHooks
) )
@ -31,13 +30,10 @@ type hookEl struct {
child context.Context child context.Context
} }
func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap, ctxKeyWrap) { func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap) {
return ctxKeyWrap{ return ctxKeyWrap{
key: ctxKeyHookEls, key: ctxKeyHookEls,
userKey: userKey, userKey: userKey,
}, ctxKeyWrap{
key: ctxKeyNumChildren,
userKey: userKey,
}, ctxKeyWrap{ }, ctxKeyWrap{
key: ctxKeyNumHooks, key: ctxKeyNumHooks,
userKey: userKey, userKey: userKey,
@ -48,21 +44,22 @@ func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap, ctxKeyWrap) {
// appends more elements if more children have been added since that []hookEl // appends more elements if more children have been added since that []hookEl
// was created. // was created.
// //
// this also returns the latest numChildren and numHooks values for convenience. // this also returns the latest numHooks value for convenience.
func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int, int) { func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int) {
hookElsKey, numChildrenKey, numHooksKey := ctxKeys(userKey) hookElsKey, numHooksKey := ctxKeys(userKey)
lastNumChildren, _ := mctx.LocalValue(ctx, numChildrenKey).(int)
lastNumHooks, _ := mctx.LocalValue(ctx, numHooksKey).(int) lastNumHooks, _ := mctx.LocalValue(ctx, numHooksKey).(int)
lastHookEls, _ := mctx.LocalValue(ctx, hookElsKey).([]hookEl) lastHookEls, _ := mctx.LocalValue(ctx, hookElsKey).([]hookEl)
children := mctx.Children(ctx) children := mctx.Children(ctx)
// plus 1 in case we wanna append something else outside this function // plus 1 in case we wanna append something else outside this function
hookEls := make([]hookEl, len(lastHookEls), lastNumHooks+len(children)-lastNumChildren+1) hookEls := make([]hookEl, len(lastHookEls), lastNumHooks+len(children)+1)
copy(hookEls, lastHookEls) copy(hookEls, lastHookEls)
lastNumChildren := len(lastHookEls) - lastNumHooks
for _, child := range children[lastNumChildren:] { for _, child := range children[lastNumChildren:] {
hookEls = append(hookEls, hookEl{child: child}) hookEls = append(hookEls, hookEl{child: child})
} }
return hookEls, len(children), lastNumHooks return hookEls, lastNumHooks
} }
// WithHook registers a Hook under a typed key. The Hook will be called when // WithHook registers a Hook under a typed key. The Hook will be called when
@ -71,18 +68,17 @@ func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int, int) {
// //
// Hooks will be called with whatever Context is passed into TriggerHooks. // Hooks will be called with whatever Context is passed into TriggerHooks.
func WithHook(ctx context.Context, key interface{}, hook Hook) context.Context { func WithHook(ctx context.Context, key interface{}, hook Hook) context.Context {
hookEls, numChildren, numHooks := getHookEls(ctx, key) hookEls, numHooks := getHookEls(ctx, key)
hookEls = append(hookEls, hookEl{hook: hook}) hookEls = append(hookEls, hookEl{hook: hook})
hookElsKey, numChildrenKey, numHooksKey := ctxKeys(key) hookElsKey, numHooksKey := ctxKeys(key)
ctx = mctx.WithLocalValue(ctx, hookElsKey, hookEls) ctx = mctx.WithLocalValue(ctx, hookElsKey, hookEls)
ctx = mctx.WithLocalValue(ctx, numChildrenKey, numChildren)
ctx = mctx.WithLocalValue(ctx, numHooksKey, numHooks+1) ctx = mctx.WithLocalValue(ctx, numHooksKey, numHooks+1)
return ctx return ctx
} }
func triggerHooks(ctx context.Context, userKey interface{}, next func([]hookEl) (hookEl, []hookEl)) error { func triggerHooks(ctx context.Context, userKey interface{}, next func([]hookEl) (hookEl, []hookEl)) error {
hookEls, _, _ := getHookEls(ctx, userKey) hookEls, _ := getHookEls(ctx, userKey)
var hookEl hookEl var hookEl hookEl
for { for {
if len(hookEls) == 0 { if len(hookEls) == 0 {

View File

@ -35,14 +35,16 @@ func TestHooks(t *T) {
ctxB = mctx.WithChild(ctxB, ctxB1) ctxB = mctx.WithChild(ctxB, ctxB1)
ctx = mctx.WithChild(ctx, ctxB) ctx = mctx.WithChild(ctx, ctxB)
ctx = WithHook(ctx, 0, mkHook(7))
massert.Fatal(t, massert.All( massert.Fatal(t, massert.All(
massert.Nil(TriggerHooks(ctx, 0)), massert.Nil(TriggerHooks(ctx, 0)),
massert.Equal([]int{1, 2, 3, 4, 5, 6}, out), massert.Equal([]int{1, 2, 3, 4, 5, 6, 7}, out),
)) ))
out = nil out = nil
massert.Fatal(t, massert.All( massert.Fatal(t, massert.All(
massert.Nil(TriggerHooksReverse(ctx, 0)), massert.Nil(TriggerHooksReverse(ctx, 0)),
massert.Equal([]int{6, 5, 4, 3, 2, 1}, out), massert.Equal([]int{7, 6, 5, 4, 3, 2, 1}, out),
)) ))
} }