diff --git a/mrun/mrun.go b/mrun/mrun.go index 3811d02..adb23f8 100644 --- a/mrun/mrun.go +++ b/mrun/mrun.go @@ -94,3 +94,103 @@ func Wait(ctx mctx.Context, cancelCh <-chan struct{}) error { return nil } + +type ctxEventKeyWrap struct { + key interface{} +} + +// Hook describes a function which can be registered to trigger on an event via +// the OnEvent function. +type Hook func(mctx.Context) error + +// OnEvent registers a Hook under a typed key. The Hook will be called when +// TriggerEvent is called with that same key. Multiple Hooks can be registered +// for the same key, and will be called sequentially when triggered. +// +// OnEvent registers Hooks onto the root of the given Context. Therefore, Hooks +// will be triggered in the global order they were registered (i.e. if a Hook is +// registered on a Context, then one registered on a child of that Context, then +// another on the original Context again, the three Hooks will be triggered in +// the order: parent, child, parent). +// +// Hooks will be called with whatever Context is passed into TriggerEvent. +func OnEvent(ctx mctx.Context, key interface{}, hook Hook) { + ctx = mctx.Root(ctx) + mctx.GetSetMutableValue(ctx, false, ctxEventKeyWrap{key}, func(v interface{}) interface{} { + hooks, _ := v.([]Hook) + return append(hooks, hook) + }) +} + +// TriggerEvent causes all Hooks registered with OnEvent under the given key to +// be called sequentially, using the given Context as their input. The given +// Context does not need to be the root Context (see OnEvent). +// +// If any Hook returns an error no further Hooks will be called and that error +// will be returned. +// +// TriggerEvent causes all Hooks which were called to be de-registered. If an +// error caused execution to stop prematurely then any Hooks which were not +// called will remain registered. +func TriggerEvent(ctx mctx.Context, key interface{}) error { + rootCtx := mctx.Root(ctx) + var err error + mctx.GetSetMutableValue(rootCtx, false, ctxEventKeyWrap{key}, func(i interface{}) interface{} { + hooks, _ := i.([]Hook) + for _, hook := range hooks { + hooks = hooks[1:] + + // err here is the var outside GetSetMutableValue, we lift it out + if err = hook(ctx); err != nil { + break + } + } + + // if there was an error then we want to keep all the hooks which + // weren't called. If there wasn't we want to reset the value to nil so + // the slice doesn't grow unbounded. + if err != nil { + return hooks + } + return nil + }) + return err +} + +type builtinEvent int + +const ( + start builtinEvent = iota + stop +) + +// OnStart registers the given Hook to run when Start is called. This is a +// special case of OnEvent. +// +// As a convention Hooks running on the start event should block only as long as +// it takes to ensure that whatever is running can do so successfully. For +// short-lived tasks this isn't a problem, but long-lived tasks (e.g. a web +// server) will want to use the Hook only to initialize, and spawn off a +// go-routine to do their actual work. Long-lived tasks should set themselves up +// to stop on the stop event (see OnStop). +func OnStart(ctx mctx.Context, hook Hook) { + OnEvent(ctx, start, hook) +} + +// Start runs all Hooks registered using OnStart. This is a special case of +// TriggerEvent. +func Start(ctx mctx.Context) error { + return TriggerEvent(ctx, start) +} + +// OnStop registers the given Hook to run when Stop is called. This is a special +// case of OnEvent. +func OnStop(ctx mctx.Context, hook Hook) { + OnEvent(ctx, stop, hook) +} + +// Stop runs all Hooks registered using OnStop. This is a special case of +// TriggerEvent. +func Stop(ctx mctx.Context) error { + return TriggerEvent(ctx, stop) +} diff --git a/mrun/mrun_test.go b/mrun/mrun_test.go index 70981e3..a2bbc3c 100644 --- a/mrun/mrun_test.go +++ b/mrun/mrun_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/mediocregopher/mediocre-go-lib/mctx" + "github.com/mediocregopher/mediocre-go-lib/mtest/massert" ) func TestThreadWait(t *T) { @@ -140,3 +141,48 @@ func TestThreadWait(t *T) { }) }) } + +func TestEvent(t *T) { + ch := make(chan int, 10) + ctx := mctx.New() + ctxChild := mctx.ChildOf(ctx, "child") + + mkHook := func(i int) Hook { + return func(mctx.Context) error { + ch <- i + return nil + } + } + + OnEvent(ctx, 0, mkHook(0)) + OnEvent(ctxChild, 0, mkHook(1)) + OnEvent(ctx, 0, mkHook(2)) + + bogusErr := errors.New("bogus error") + OnEvent(ctxChild, 0, func(mctx.Context) error { return bogusErr }) + + OnEvent(ctx, 0, mkHook(3)) + OnEvent(ctx, 0, mkHook(4)) + + massert.Fatal(t, massert.All( + massert.Equal(bogusErr, TriggerEvent(ctx, 0)), + massert.Equal(0, <-ch), + massert.Equal(1, <-ch), + massert.Equal(2, <-ch), + )) + + // after the error the 3 and 4 Hooks should still be registered, but not + // called yet. + + select { + case <-ch: + t.Fatal("Hooks should not have been called yet") + default: + } + + massert.Fatal(t, massert.All( + massert.Nil(TriggerEvent(ctx, 0)), + massert.Equal(3, <-ch), + massert.Equal(4, <-ch), + )) +}