mctx: refactor implementation of Context to not use a state stored as a value, but to just wrap a context.Context

This commit is contained in:
Brian Picciano 2019-02-03 19:25:46 -05:00
parent 5dc8d92987
commit 0c2c49501e

View File

@ -10,93 +10,90 @@
package mctx package mctx
import ( import (
"context"
"sync" "sync"
"time" "time"
goctx "context"
) )
// Context is the same as the builtin type, but is used to indicate that the // Context is the same as the builtin type, but is used to indicate that the
// Context originally came from this package (aka New or ChildOf). // Context originally came from this package (aka New or ChildOf).
type Context context.Context type Context goctx.Context
// CancelFunc is a direct alias of the type from the context package, see its // CancelFunc is a direct alias of the type from the context package, see its
// docs. // docs.
type CancelFunc = context.CancelFunc type CancelFunc = goctx.CancelFunc
// WithValue mimics the function from the context package. // WithValue mimics the function from the context package.
func WithValue(parent Context, key, val interface{}) Context { func WithValue(parent Context, key, val interface{}) Context {
return Context(context.WithValue(context.Context(parent), key, val)) return Context(goctx.WithValue(goctx.Context(parent), key, val))
} }
// WithCancel mimics the function from the context package. // WithCancel mimics the function from the context package.
func WithCancel(parent Context) (Context, CancelFunc) { func WithCancel(parent Context) (Context, CancelFunc) {
ctx, fn := context.WithCancel(context.Context(parent)) ctx, fn := goctx.WithCancel(goctx.Context(parent))
return Context(ctx), fn return Context(ctx), fn
} }
// WithDeadline mimics the function from the context package. // WithDeadline mimics the function from the context package.
func WithDeadline(parent Context, t time.Time) (Context, CancelFunc) { func WithDeadline(parent Context, t time.Time) (Context, CancelFunc) {
ctx, fn := context.WithDeadline(context.Context(parent), t) ctx, fn := goctx.WithDeadline(goctx.Context(parent), t)
return Context(ctx), fn return Context(ctx), fn
} }
// WithTimeout mimics the function from the context package. // WithTimeout mimics the function from the context package.
func WithTimeout(parent Context, d time.Duration) (Context, CancelFunc) { func WithTimeout(parent Context, d time.Duration) (Context, CancelFunc) {
ctx, fn := context.WithTimeout(context.Context(parent), d) ctx, fn := goctx.WithTimeout(goctx.Context(parent), d)
return Context(ctx), fn return Context(ctx), fn
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
type ctxKey int
type mutVal struct { type mutVal struct {
l sync.RWMutex l sync.RWMutex
v interface{} v interface{}
} }
type ctxState struct { type context struct {
goctx.Context
path []string path []string
l sync.RWMutex l sync.RWMutex
parent Context parent *context
children map[string]Context children map[string]Context
mutL sync.RWMutex mutL sync.RWMutex
mutVals map[interface{}]*mutVal mutVals map[interface{}]*mutVal
} }
func getCtxState(ctx Context) *ctxState {
s, _ := ctx.Value(ctxKey(0)).(*ctxState)
if s == nil {
panic("non-conforming context used")
}
return s
}
func withCtxState(ctx Context, s *ctxState) Context {
return WithValue(ctx, ctxKey(0), s)
}
// New returns a new context which can be used as the root context for all // New returns a new context which can be used as the root context for all
// purposes in this framework. // purposes in this framework.
func New() Context { func New() Context {
return withCtxState(Context(context.Background()), &ctxState{}) return &context{Context: goctx.Background()}
}
func getCtx(Ctx Context) *context {
ctx, ok := Ctx.(*context)
if !ok {
panic("non-conforming Context used")
}
return ctx
} }
// Path returns the sequence of names which were used to produce this context // Path returns the sequence of names which were used to produce this context
// via the ChildOf function. // via the ChildOf function.
func Path(ctx Context) []string { func Path(Ctx Context) []string {
return getCtxState(ctx).path return getCtx(Ctx).path
} }
// Children returns all children of this context which have been created by // Children returns all children of this context which have been created by
// ChildOf, mapped by their name. // ChildOf, mapped by their name.
func Children(ctx Context) map[string]Context { func Children(Ctx Context) map[string]Context {
s := getCtxState(ctx) ctx := getCtx(Ctx)
out := map[string]Context{} out := map[string]Context{}
s.l.RLock() ctx.l.RLock()
defer s.l.RUnlock() defer ctx.l.RUnlock()
for name, childCtx := range s.children { for name, childCtx := range ctx.children {
out[name] = childCtx out[name] = childCtx
} }
return out return out
@ -104,21 +101,21 @@ func Children(ctx Context) map[string]Context {
// Parent returns the parent Context of the given one, or nil if this is a root // Parent returns the parent Context of the given one, or nil if this is a root
// context (i.e. returned from New). // context (i.e. returned from New).
func Parent(ctx Context) Context { func Parent(Ctx Context) Context {
return getCtxState(ctx).parent return getCtx(Ctx).parent
} }
// Root returns the root Context from which this Context and all of its parents // Root returns the root Context from which this Context and all of its parents
// were derived (i.e. the Context which was originally returned from New). // were derived (i.e. the Context which was originally returned from New).
// //
// If the given Context is the root then it is returned as-id. // If the given Context is the root then it is returned as-id.
func Root(ctx Context) Context { func Root(Ctx Context) Context {
ctx := getCtx(Ctx)
for { for {
s := getCtxState(ctx) if ctx.parent == nil {
if s.parent == nil {
return ctx return ctx
} }
ctx = s.parent ctx = ctx.parent
} }
} }
@ -129,26 +126,25 @@ func Root(ctx Context) Context {
// //
// TODO If the given Context already has a child with the given name that child // TODO If the given Context already has a child with the given name that child
// will be returned. // will be returned.
func ChildOf(ctx Context, name string) Context { func ChildOf(Ctx Context, name string) Context {
s, childS := getCtxState(ctx), new(ctxState) ctx, childCtx := getCtx(Ctx), new(context)
s.l.Lock() ctx.l.Lock()
defer s.l.Unlock() defer ctx.l.Unlock()
// set child's path field // set child's path field
childS.path = make([]string, 0, len(s.path)+1) childCtx.path = make([]string, 0, len(ctx.path)+1)
childS.path = append(childS.path, s.path...) childCtx.path = append(childCtx.path, ctx.path...)
childS.path = append(childS.path, name) childCtx.path = append(childCtx.path, name)
// set child's parent field // set child's parent field
childS.parent = ctx childCtx.parent = ctx
// create child's ctx and store it in parent // create child's ctx and store it in parent
childCtx := withCtxState(ctx, childS) if ctx.children == nil {
if s.children == nil { ctx.children = map[string]Context{}
s.children = map[string]Context{}
} }
s.children[name] = childCtx ctx.children[name] = childCtx
return childCtx return childCtx
} }
@ -157,8 +153,8 @@ func ChildOf(ctx Context, name string) Context {
// function returns without visiting any more Contexts. // function returns without visiting any more Contexts.
// //
// The exact order of visitation is non-deterministic. // The exact order of visitation is non-deterministic.
func BreadthFirstVisit(ctx Context, callback func(Context) bool) { func BreadthFirstVisit(Ctx Context, callback func(Context) bool) {
queue := []Context{ctx} queue := []Context{Ctx}
for len(queue) > 0 { for len(queue) > 0 {
if !callback(queue[0]) { if !callback(queue[0]) {
return return
@ -175,13 +171,13 @@ func BreadthFirstVisit(ctx Context, callback func(Context) bool) {
// MutableValue acts like the Value method, except that it only deals with // MutableValue acts like the Value method, except that it only deals with
// keys/values set by SetMutableValue. // keys/values set by SetMutableValue.
func MutableValue(ctx Context, key interface{}) interface{} { func MutableValue(Ctx Context, key interface{}) interface{} {
s := getCtxState(ctx) ctx := getCtx(Ctx)
s.mutL.RLock() ctx.mutL.RLock()
defer s.mutL.RUnlock() defer ctx.mutL.RUnlock()
if s.mutVals == nil { if ctx.mutVals == nil {
return nil return nil
} else if mVal, ok := s.mutVals[key]; ok { } else if mVal, ok := ctx.mutVals[key]; ok {
mVal.l.RLock() mVal.l.RLock()
defer mVal.l.RUnlock() defer mVal.l.RUnlock()
return mVal.v return mVal.v
@ -205,28 +201,28 @@ func MutableValue(ctx Context, key interface{}) interface{} {
// Context, except for those related to mutable values for this same key (e.g. // Context, except for those related to mutable values for this same key (e.g.
// MutableValue and SetMutableValue). // MutableValue and SetMutableValue).
func GetSetMutableValue( func GetSetMutableValue(
ctx Context, noCallbackIfSet bool, Ctx Context, noCallbackIfSet bool,
key interface{}, fn func(interface{}) interface{}, key interface{}, fn func(interface{}) interface{},
) interface{} { ) interface{} {
s := getCtxState(ctx)
// if noCallbackIfSet, do a fast lookup with MutableValue first. // if noCallbackIfSet, do a fast lookup with MutableValue first.
if noCallbackIfSet { if noCallbackIfSet {
if v := MutableValue(ctx, key); v != nil { if v := MutableValue(Ctx, key); v != nil {
return v return v
} }
} }
s.mutL.Lock() ctx := getCtx(Ctx)
if s.mutVals == nil { ctx.mutL.Lock()
s.mutVals = map[interface{}]*mutVal{} if ctx.mutVals == nil {
ctx.mutVals = map[interface{}]*mutVal{}
} }
mVal, ok := s.mutVals[key] mVal, ok := ctx.mutVals[key]
if !ok { if !ok {
mVal = new(mutVal) mVal = new(mutVal)
s.mutVals[key] = mVal ctx.mutVals[key] = mVal
} }
s.mutL.Unlock() ctx.mutL.Unlock()
mVal.l.Lock() mVal.l.Lock()
defer mVal.l.Unlock() defer mVal.l.Unlock()