diff --git a/mctx/ctx.go b/mctx/ctx.go index b89d0cf..4106391 100644 --- a/mctx/ctx.go +++ b/mctx/ctx.go @@ -10,93 +10,90 @@ package mctx import ( - "context" "sync" "time" + + goctx "context" ) // Context is the same as the builtin type, but is used to indicate that the // Context originally came from this package (aka New or ChildOf). -type Context context.Context +type Context goctx.Context // CancelFunc is a direct alias of the type from the context package, see its // docs. -type CancelFunc = context.CancelFunc +type CancelFunc = goctx.CancelFunc // WithValue mimics the function from the context package. func WithValue(parent Context, key, val interface{}) Context { - return Context(context.WithValue(context.Context(parent), key, val)) + return Context(goctx.WithValue(goctx.Context(parent), key, val)) } // WithCancel mimics the function from the context package. func WithCancel(parent Context) (Context, CancelFunc) { - ctx, fn := context.WithCancel(context.Context(parent)) + ctx, fn := goctx.WithCancel(goctx.Context(parent)) return Context(ctx), fn } // WithDeadline mimics the function from the context package. func WithDeadline(parent Context, t time.Time) (Context, CancelFunc) { - ctx, fn := context.WithDeadline(context.Context(parent), t) + ctx, fn := goctx.WithDeadline(goctx.Context(parent), t) return Context(ctx), fn } // WithTimeout mimics the function from the context package. func WithTimeout(parent Context, d time.Duration) (Context, CancelFunc) { - ctx, fn := context.WithTimeout(context.Context(parent), d) + ctx, fn := goctx.WithTimeout(goctx.Context(parent), d) return Context(ctx), fn } //////////////////////////////////////////////////////////////////////////////// -type ctxKey int - type mutVal struct { l sync.RWMutex v interface{} } -type ctxState struct { +type context struct { + goctx.Context + path []string l sync.RWMutex - parent Context + parent *context children map[string]Context mutL sync.RWMutex mutVals map[interface{}]*mutVal } -func getCtxState(ctx Context) *ctxState { - s, _ := ctx.Value(ctxKey(0)).(*ctxState) - if s == nil { - panic("non-conforming context used") - } - return s -} - -func withCtxState(ctx Context, s *ctxState) Context { - return WithValue(ctx, ctxKey(0), s) -} - // New returns a new context which can be used as the root context for all // purposes in this framework. func New() Context { - return withCtxState(Context(context.Background()), &ctxState{}) + return &context{Context: goctx.Background()} +} + +func getCtx(Ctx Context) *context { + ctx, ok := Ctx.(*context) + if !ok { + panic("non-conforming Context used") + } + return ctx } // Path returns the sequence of names which were used to produce this context // via the ChildOf function. -func Path(ctx Context) []string { - return getCtxState(ctx).path +func Path(Ctx Context) []string { + return getCtx(Ctx).path } // Children returns all children of this context which have been created by // ChildOf, mapped by their name. -func Children(ctx Context) map[string]Context { - s := getCtxState(ctx) +func Children(Ctx Context) map[string]Context { + ctx := getCtx(Ctx) out := map[string]Context{} - s.l.RLock() - defer s.l.RUnlock() - for name, childCtx := range s.children { + ctx.l.RLock() + defer ctx.l.RUnlock() + for name, childCtx := range ctx.children { out[name] = childCtx } return out @@ -104,21 +101,21 @@ func Children(ctx Context) map[string]Context { // Parent returns the parent Context of the given one, or nil if this is a root // context (i.e. returned from New). -func Parent(ctx Context) Context { - return getCtxState(ctx).parent +func Parent(Ctx Context) Context { + return getCtx(Ctx).parent } // Root returns the root Context from which this Context and all of its parents // were derived (i.e. the Context which was originally returned from New). // // If the given Context is the root then it is returned as-id. -func Root(ctx Context) Context { +func Root(Ctx Context) Context { + ctx := getCtx(Ctx) for { - s := getCtxState(ctx) - if s.parent == nil { + if ctx.parent == nil { return ctx } - ctx = s.parent + ctx = ctx.parent } } @@ -129,26 +126,25 @@ func Root(ctx Context) Context { // // TODO If the given Context already has a child with the given name that child // will be returned. -func ChildOf(ctx Context, name string) Context { - s, childS := getCtxState(ctx), new(ctxState) +func ChildOf(Ctx Context, name string) Context { + ctx, childCtx := getCtx(Ctx), new(context) - s.l.Lock() - defer s.l.Unlock() + ctx.l.Lock() + defer ctx.l.Unlock() // set child's path field - childS.path = make([]string, 0, len(s.path)+1) - childS.path = append(childS.path, s.path...) - childS.path = append(childS.path, name) + childCtx.path = make([]string, 0, len(ctx.path)+1) + childCtx.path = append(childCtx.path, ctx.path...) + childCtx.path = append(childCtx.path, name) // set child's parent field - childS.parent = ctx + childCtx.parent = ctx // create child's ctx and store it in parent - childCtx := withCtxState(ctx, childS) - if s.children == nil { - s.children = map[string]Context{} + if ctx.children == nil { + ctx.children = map[string]Context{} } - s.children[name] = childCtx + ctx.children[name] = childCtx return childCtx } @@ -157,8 +153,8 @@ func ChildOf(ctx Context, name string) Context { // function returns without visiting any more Contexts. // // The exact order of visitation is non-deterministic. -func BreadthFirstVisit(ctx Context, callback func(Context) bool) { - queue := []Context{ctx} +func BreadthFirstVisit(Ctx Context, callback func(Context) bool) { + queue := []Context{Ctx} for len(queue) > 0 { if !callback(queue[0]) { return @@ -175,13 +171,13 @@ func BreadthFirstVisit(ctx Context, callback func(Context) bool) { // MutableValue acts like the Value method, except that it only deals with // keys/values set by SetMutableValue. -func MutableValue(ctx Context, key interface{}) interface{} { - s := getCtxState(ctx) - s.mutL.RLock() - defer s.mutL.RUnlock() - if s.mutVals == nil { +func MutableValue(Ctx Context, key interface{}) interface{} { + ctx := getCtx(Ctx) + ctx.mutL.RLock() + defer ctx.mutL.RUnlock() + if ctx.mutVals == nil { return nil - } else if mVal, ok := s.mutVals[key]; ok { + } else if mVal, ok := ctx.mutVals[key]; ok { mVal.l.RLock() defer mVal.l.RUnlock() return mVal.v @@ -205,28 +201,28 @@ func MutableValue(ctx Context, key interface{}) interface{} { // Context, except for those related to mutable values for this same key (e.g. // MutableValue and SetMutableValue). func GetSetMutableValue( - ctx Context, noCallbackIfSet bool, + Ctx Context, noCallbackIfSet bool, key interface{}, fn func(interface{}) interface{}, ) interface{} { - s := getCtxState(ctx) // if noCallbackIfSet, do a fast lookup with MutableValue first. if noCallbackIfSet { - if v := MutableValue(ctx, key); v != nil { + if v := MutableValue(Ctx, key); v != nil { return v } } - s.mutL.Lock() - if s.mutVals == nil { - s.mutVals = map[interface{}]*mutVal{} + ctx := getCtx(Ctx) + ctx.mutL.Lock() + if ctx.mutVals == nil { + ctx.mutVals = map[interface{}]*mutVal{} } - mVal, ok := s.mutVals[key] + mVal, ok := ctx.mutVals[key] if !ok { mVal = new(mutVal) - s.mutVals[key] = mVal + ctx.mutVals[key] = mVal } - s.mutL.Unlock() + ctx.mutL.Unlock() mVal.l.Lock() defer mVal.l.Unlock()