diff --git a/cmd/totp-proxy/main.go b/cmd/totp-proxy/main.go index 7607067..6326740 100644 --- a/cmd/totp-proxy/main.go +++ b/cmd/totp-proxy/main.go @@ -12,6 +12,7 @@ package main */ import ( + "context" "net/http" "net/url" "time" @@ -19,7 +20,6 @@ import ( "github.com/mediocregopher/mediocre-go-lib/m" "github.com/mediocregopher/mediocre-go-lib/mcfg" "github.com/mediocregopher/mediocre-go-lib/mcrypto" - "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mhttp" "github.com/mediocregopher/mediocre-go-lib/mlog" "github.com/mediocregopher/mediocre-go-lib/mrand" @@ -30,27 +30,26 @@ import ( func main() { ctx := m.NewServiceCtx() - logger := mlog.From(ctx) - cookieName := mcfg.String(ctx, "cookie-name", "_totp_proxy", "String to use as the name for cookies") - cookieTimeout := mcfg.Duration(ctx, "cookie-timeout", mtime.Duration{1 * time.Hour}, "Timeout for cookies") + ctx, cookieName := mcfg.String(ctx, "cookie-name", "_totp_proxy", "String to use as the name for cookies") + ctx, cookieTimeout := mcfg.Duration(ctx, "cookie-timeout", mtime.Duration{1 * time.Hour}, "Timeout for cookies") var userSecrets map[string]string - mcfg.RequiredJSON(ctx, "users", &userSecrets, "JSON object which maps usernames to their TOTP secret strings") + ctx = mcfg.RequiredJSON(ctx, "users", &userSecrets, "JSON object which maps usernames to their TOTP secret strings") var secret mcrypto.Secret - secretStr := mcfg.String(ctx, "secret", "", "String used to sign authentication tokens. If one isn't given a new one will be generated on each startup, invalidating all previous tokens.") - mrun.OnStart(ctx, func(mctx.Context) error { + ctx, secretStr := mcfg.String(ctx, "secret", "", "String used to sign authentication tokens. If one isn't given a new one will be generated on each startup, invalidating all previous tokens.") + ctx = mrun.OnStart(ctx, func(context.Context) error { if *secretStr == "" { *secretStr = mrand.Hex(32) } - logger.Info("generating secret") + mlog.Info(ctx, "generating secret") secret = mcrypto.NewSecret([]byte(*secretStr)) return nil }) proxyHandler := new(struct{ http.Handler }) - proxyURL := mcfg.RequiredString(ctx, "dst-url", "URL to proxy requests to. Only the scheme and host should be set.") - mrun.OnStart(ctx, func(mctx.Context) error { + ctx, proxyURL := mcfg.RequiredString(ctx, "dst-url", "URL to proxy requests to. Only the scheme and host should be set.") + ctx = mrun.OnStart(ctx, func(context.Context) error { u, err := url.Parse(*proxyURL) if err != nil { return err @@ -61,8 +60,8 @@ func main() { authHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // TODO mlog.FromHTTP? - authLogger := logger.Clone() - authLogger.SetKV(mlog.CtxKV(r.Context())) + // TODO annotate this ctx + ctx := r.Context() unauthorized := func() { w.Header().Add("WWW-Authenticate", "Basic") @@ -80,7 +79,7 @@ func main() { } if cookie, _ := r.Cookie(*cookieName); cookie != nil { - authLogger.Debug("authenticating with cookie", mlog.KV{"cookie": cookie.String()}) + mlog.Debug(ctx, "authenticating with cookie", mlog.KV{"cookie": cookie.String()}) var sig mcrypto.Signature if err := sig.UnmarshalText([]byte(cookie.Value)); err == nil { err := mcrypto.VerifyString(secret, sig, "") @@ -92,7 +91,7 @@ func main() { } if user, pass, ok := r.BasicAuth(); ok && pass != "" { - logger.Debug("authenticating with user/pass", mlog.KV{ + mlog.Debug(ctx, "authenticating with user/pass", mlog.KV{ "user": user, "pass": pass, }) @@ -107,6 +106,6 @@ func main() { unauthorized() }) - mhttp.MListenAndServe(ctx, authHandler) + ctx, _ = mhttp.MListenAndServe(ctx, authHandler) m.Run(ctx) } diff --git a/m/m.go b/m/m.go index 23ef500..c50675f 100644 --- a/m/m.go +++ b/m/m.go @@ -5,11 +5,11 @@ package m import ( + "context" "os" "os/signal" "github.com/mediocregopher/mediocre-go-lib/mcfg" - "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/merr" "github.com/mediocregopher/mediocre-go-lib/mlog" "github.com/mediocregopher/mediocre-go-lib/mrun" @@ -32,20 +32,19 @@ func CfgSource() mcfg.Source { // debug information about the running process can be accessed. // // TODO set up the debug endpoint. -func NewServiceCtx() mctx.Context { - ctx := mctx.New() +func NewServiceCtx() context.Context { + ctx := context.Background() // set up log level handling - logLevelStr := mcfg.String(ctx, "log-level", "info", "Maximum log level which will be printed.") - mrun.OnStart(ctx, func(mctx.Context) error { + logger := mlog.NewLogger() + ctx = mlog.Set(ctx, logger) + ctx, logLevelStr := mcfg.String(ctx, "log-level", "info", "Maximum log level which will be printed.") + ctx = mrun.OnStart(ctx, func(context.Context) error { logLevel := mlog.LevelFromString(*logLevelStr) if logLevel == nil { return merr.New("invalid log level", "log-level", *logLevelStr) } - mlog.CtxSetAll(ctx, func(_ mctx.Context, logger *mlog.Logger) *mlog.Logger { - logger.SetMaxLevel(logLevel) - return logger - }) + logger.SetMaxLevel(logLevel) return nil }) @@ -56,23 +55,23 @@ func NewServiceCtx() mctx.Context { // start event, waiting for an interrupt, and then triggering the stop event. // Run will block until the stop event is done. If any errors are encountered a // fatal is thrown. -func Run(ctx mctx.Context) { +func Run(ctx context.Context) { log := mlog.From(ctx) if err := mcfg.Populate(ctx, CfgSource()); err != nil { - log.Fatal("error populating configuration", merr.KV(err)) + log.Fatal(ctx, "error populating configuration", merr.KV(err)) } else if err := mrun.Start(ctx); err != nil { - log.Fatal("error triggering start event", merr.KV(err)) + log.Fatal(ctx, "error triggering start event", merr.KV(err)) } { ch := make(chan os.Signal, 1) signal.Notify(ch, os.Interrupt) s := <-ch - log.Info("signal received, stopping", mlog.KV{"signal": s}) + log.Info(ctx, "signal received, stopping", mlog.KV{"signal": s}) } if err := mrun.Stop(ctx); err != nil { - log.Fatal("error triggering stop event", merr.KV(err)) + log.Fatal(ctx, "error triggering stop event", merr.KV(err)) } - log.Info("exiting process") + log.Info(ctx, "exiting process") } diff --git a/m/m_test.go b/m/m_test.go index f20288a..ad6d84a 100644 --- a/m/m_test.go +++ b/m/m_test.go @@ -26,7 +26,8 @@ func TestServiceCtx(t *T) { // create a child Context before running to ensure it the change propagates // correctly. - ctxA := mctx.ChildOf(ctx, "A") + ctxA := mctx.NewChild(ctx, "A") + ctx = mctx.WithChild(ctx, ctxA) params := mcfg.ParamValues{{Name: "log-level", Value: json.RawMessage(`"DEBUG"`)}} if err := mcfg.Populate(ctx, params); err != nil { @@ -35,14 +36,16 @@ func TestServiceCtx(t *T) { t.Fatal(err) } - mlog.From(ctxA).Info("foo") - mlog.From(ctxA).Debug("bar") + mlog.From(ctxA).Info(ctxA, "foo") + mlog.From(ctxA).Debug(ctxA, "bar") massert.Fatal(t, massert.All( massert.Len(msgs, 2), massert.Equal(msgs[0].Level.String(), "INFO"), - massert.Equal(msgs[0].Description.String(), "(/A) foo"), + massert.Equal(msgs[0].Description, "foo"), + massert.Equal(msgs[0].Context, ctxA), massert.Equal(msgs[1].Level.String(), "DEBUG"), - massert.Equal(msgs[1].Description.String(), "(/A) bar"), + massert.Equal(msgs[1].Description, "bar"), + massert.Equal(msgs[1].Context, ctxA), )) }) } diff --git a/mcfg/cli.go b/mcfg/cli.go index 28b1ed5..bd0d9ad 100644 --- a/mcfg/cli.go +++ b/mcfg/cli.go @@ -8,6 +8,7 @@ import ( "sort" "strings" + "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/merr" ) @@ -108,7 +109,7 @@ func (cli SourceCLI) Parse(params []Param) ([]ParamValue, error) { pvs = append(pvs, ParamValue{ Name: p.Name, - Path: p.Path, + Path: mctx.Path(p.Context), Value: p.fuzzyParse(pvStrVal), }) @@ -127,7 +128,7 @@ func (cli SourceCLI) Parse(params []Param) ([]ParamValue, error) { func (cli SourceCLI) cliParams(params []Param) (map[string]Param, error) { m := map[string]Param{} for _, p := range params { - key := strings.Join(append(p.Path, p.Name), cliKeyJoin) + key := strings.Join(append(mctx.Path(p.Context), p.Name), cliKeyJoin) m[cliKeyPrefix+key] = p } return m, nil diff --git a/mcfg/cli_test.go b/mcfg/cli_test.go index 09d3960..892f7a2 100644 --- a/mcfg/cli_test.go +++ b/mcfg/cli_test.go @@ -2,11 +2,11 @@ package mcfg import ( "bytes" + "context" "strings" . "testing" "time" - "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mrand" "github.com/mediocregopher/mediocre-go-lib/mtest/mchk" "github.com/stretchr/testify/assert" @@ -14,12 +14,12 @@ import ( ) func TestSourceCLIHelp(t *T) { - ctx := mctx.New() - Int(ctx, "foo", 5, "Test int param ") // trailing space should be trimmed - Bool(ctx, "bar", "Test bool param.") - String(ctx, "baz", "baz", "Test string param") - RequiredString(ctx, "baz2", "") - RequiredString(ctx, "baz3", "") + ctx := context.Background() + ctx, _ = Int(ctx, "foo", 5, "Test int param ") // trailing space should be trimmed + ctx, _ = Bool(ctx, "bar", "Test bool param.") + ctx, _ = String(ctx, "baz", "baz", "Test string param") + ctx, _ = RequiredString(ctx, "baz2", "") + ctx, _ = RequiredString(ctx, "baz3", "") src := SourceCLI{} buf := new(bytes.Buffer) diff --git a/mcfg/env.go b/mcfg/env.go index 3eb977b..f1ab14b 100644 --- a/mcfg/env.go +++ b/mcfg/env.go @@ -4,6 +4,7 @@ import ( "os" "strings" + "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/merr" ) @@ -49,7 +50,7 @@ func (env SourceEnv) Parse(params []Param) ([]ParamValue, error) { pM := map[string]Param{} for _, p := range params { - name := env.expectedName(p.Path, p.Name) + name := env.expectedName(mctx.Path(p.Context), p.Name) pM[name] = p } @@ -63,7 +64,7 @@ func (env SourceEnv) Parse(params []Param) ([]ParamValue, error) { if p, ok := pM[k]; ok { pvs = append(pvs, ParamValue{ Name: p.Name, - Path: p.Path, + Path: mctx.Path(p.Context), Value: p.fuzzyParse(v), }) } diff --git a/mcfg/mcfg.go b/mcfg/mcfg.go index b519fe2..1a23479 100644 --- a/mcfg/mcfg.go +++ b/mcfg/mcfg.go @@ -3,6 +3,7 @@ package mcfg import ( + "context" "crypto/md5" "encoding/hex" "encoding/json" @@ -17,28 +18,10 @@ import ( // - JSON file // - YAML file -type ctxCfg struct { - path []string - params map[string]Param -} - -type ctxKey int - -func get(ctx mctx.Context) *ctxCfg { - return mctx.GetSetMutableValue(ctx, true, ctxKey(0), - func(interface{}) interface{} { - return &ctxCfg{ - path: mctx.Path(ctx), - params: map[string]Param{}, - } - }, - ).(*ctxCfg) -} - func sortParams(params []Param) { sort.Slice(params, func(i, j int) bool { a, b := params[i], params[j] - aPath, bPath := a.Path, b.Path + aPath, bPath := mctx.Path(a.Context), mctx.Path(b.Context) for { switch { case len(aPath) == 0 && len(bPath) == 0: @@ -59,12 +42,12 @@ func sortParams(params []Param) { // returns all Params gathered by recursively retrieving them from this Context // and its children. Returned Params are sorted according to their Path and // Name. -func collectParams(ctx mctx.Context) []Param { +func collectParams(ctx context.Context) []Param { var params []Param - var visit func(mctx.Context) - visit = func(ctx mctx.Context) { - for _, param := range get(ctx).params { + var visit func(context.Context) + visit = func(ctx context.Context) { + for _, param := range getLocalParams(ctx) { params = append(params, param) } @@ -98,9 +81,10 @@ func populate(params []Param, src Source) error { // later. There should not be any duplicates here. pM := map[string]Param{} for _, p := range params { - hash := paramHash(p.Path, p.Name) + path := mctx.Path(p.Context) + hash := paramHash(path, p.Name) if _, ok := pM[hash]; ok { - panic("duplicate Param: " + paramFullName(p.Path, p.Name)) + panic("duplicate Param: " + paramFullName(path, p.Name)) } pM[hash] = p } @@ -127,7 +111,7 @@ func populate(params []Param, src Source) error { continue } else if _, ok := pvM[hash]; !ok { return merr.New("required parameter is not set", - "param", paramFullName(p.Path, p.Name)) + "param", paramFullName(mctx.Path(p.Context), p.Name)) } } @@ -144,12 +128,12 @@ func populate(params []Param, src Source) error { } // Populate uses the Source to populate the values of all Params which were -// added to the given mctx.Context, and all of its children. Populate may be -// called multiple times with the same mctx.Context, each time will only affect -// the values of the Params which were provided by the respective Source. +// added to the given Context, and all of its children. Populate may be called +// multiple times with the same Context, each time will only affect the values +// of the Params which were provided by the respective Source. // // Source may be nil to indicate that no configuration is provided. Only default // values will be used, and if any paramaters are required this will error. -func Populate(ctx mctx.Context, src Source) error { +func Populate(ctx context.Context, src Source) error { return populate(collectParams(ctx), src) } diff --git a/mcfg/mcfg_test.go b/mcfg/mcfg_test.go index 15e5043..9ad6df1 100644 --- a/mcfg/mcfg_test.go +++ b/mcfg/mcfg_test.go @@ -1,6 +1,7 @@ package mcfg import ( + "context" . "testing" "github.com/mediocregopher/mediocre-go-lib/mctx" @@ -9,11 +10,12 @@ import ( func TestPopulate(t *T) { { - ctx := mctx.New() - a := Int(ctx, "a", 0, "") - ctxChild := mctx.ChildOf(ctx, "foo") - b := Int(ctxChild, "b", 0, "") - c := Int(ctxChild, "c", 0, "") + ctx := context.Background() + ctx, a := Int(ctx, "a", 0, "") + ctxChild := mctx.NewChild(ctx, "foo") + ctxChild, b := Int(ctxChild, "b", 0, "") + ctxChild, c := Int(ctxChild, "c", 0, "") + ctx = mctx.WithChild(ctx, ctxChild) err := Populate(ctx, SourceCLI{ Args: []string{"--a=1", "--foo-b=2"}, @@ -25,11 +27,12 @@ func TestPopulate(t *T) { } { // test that required params are enforced - ctx := mctx.New() - a := Int(ctx, "a", 0, "") - ctxChild := mctx.ChildOf(ctx, "foo") - b := Int(ctxChild, "b", 0, "") - c := RequiredInt(ctxChild, "c", "") + ctx := context.Background() + ctx, a := Int(ctx, "a", 0, "") + ctxChild := mctx.NewChild(ctx, "foo") + ctxChild, b := Int(ctxChild, "b", 0, "") + ctxChild, c := RequiredInt(ctxChild, "c", "") + ctx = mctx.WithChild(ctx, ctxChild) err := Populate(ctx, SourceCLI{ Args: []string{"--a=1", "--foo-b=2"}, diff --git a/mcfg/param.go b/mcfg/param.go index abc6101..b0223ce 100644 --- a/mcfg/param.go +++ b/mcfg/param.go @@ -1,6 +1,7 @@ package mcfg import ( + "context" "encoding/json" "fmt" "strings" @@ -10,15 +11,16 @@ import ( ) // Param is a configuration parameter which can be populated by Populate. The -// Param will exist as part of an mctx.Context, relative to its Path. For -// example, a Param with name "addr" under an mctx.Context with Path of -// []string{"foo","bar"} will be setabble on the CLI via "--foo-bar-addr". Other -// configuration Sources may treat the path/name differently, however. +// Param will exist as part of a Context, relative to its path (see the mctx +// package for more on Context path). For example, a Param with name "addr" +// under a Context with path of []string{"foo","bar"} will be setable on the CLI +// via "--foo-bar-addr". Other configuration Sources may treat the path/name +// differently, however. // // Param values are always unmarshaled as JSON values into the Into field of the // Param, regardless of the actual Source. type Param struct { - // How the parameter will be identified within an mctx.Context. + // How the parameter will be identified within a Context. Name string // A helpful description of how a parameter is expected to be used. @@ -42,9 +44,9 @@ type Param struct { // value of the parameter. Into interface{} - // The Path field of the Cfg this Param is attached to. NOTE that this - // will be automatically filled in when the Param is added to the Cfg. - Path []string + // The Context this Param was added to. NOTE that this will be automatically + // filled in by MustAdd when the Param is added to the Context. + Context context.Context } func paramFullName(path []string, name string) string { @@ -65,118 +67,145 @@ func (p Param) fuzzyParse(v string) json.RawMessage { return json.RawMessage(v) } -// MustAdd adds the given Param to the mctx.Context. It will panic if a Param of -// the same Name already exists in the mctx.Context. -func MustAdd(ctx mctx.Context, param Param) { +type ctxKey string + +func getParam(ctx context.Context, name string) (Param, bool) { + param, ok := mctx.LocalValue(ctx, ctxKey(name)).(Param) + return param, ok +} + +// MustAdd returns a Context with the given Param added to it. It will panic if +// a Param with the same Name already exists in the Context. +func MustAdd(ctx context.Context, param Param) context.Context { param.Name = strings.ToLower(param.Name) - param.Path = mctx.Path(ctx) + param.Context = ctx - cfg := get(ctx) - if _, ok := cfg.params[param.Name]; ok { - panic(fmt.Sprintf("Context Path:%#v Name:%q already exists", param.Path, param.Name)) + if _, ok := getParam(ctx, param.Name); ok { + path := mctx.Path(ctx) + panic(fmt.Sprintf("Context Path:%#v Name:%q already exists", path, param.Name)) } - cfg.params[param.Name] = param + + return mctx.WithLocalValue(ctx, ctxKey(param.Name), param) } -// Int64 returns an *int64 which will be populated once Populate is run. -func Int64(ctx mctx.Context, name string, defaultVal int64, usage string) *int64 { +func getLocalParams(ctx context.Context) []Param { + localVals := mctx.LocalValues(ctx) + params := make([]Param, 0, len(localVals)) + for _, val := range localVals { + if param, ok := val.(Param); ok { + params = append(params, param) + } + } + return params +} + +// Int64 returns an *int64 which will be populated once Populate is run on the +// returned Context. +func Int64(ctx context.Context, name string, defaultVal int64, usage string) (context.Context, *int64) { i := defaultVal - MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i}) - return &i + ctx = MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i}) + return ctx, &i } -// RequiredInt64 returns an *int64 which will be populated once Populate is run, -// and which must be supplied by a configuration Source. -func RequiredInt64(ctx mctx.Context, name string, usage string) *int64 { +// RequiredInt64 returns an *int64 which will be populated once Populate is run +// on the returned Context, and which must be supplied by a configuration +// Source. +func RequiredInt64(ctx context.Context, name string, usage string) (context.Context, *int64) { var i int64 - MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i}) - return &i + ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i}) + return ctx, &i } -// Int returns an *int which will be populated once Populate is run. -func Int(ctx mctx.Context, name string, defaultVal int, usage string) *int { +// Int returns an *int which will be populated once Populate is run on the +// returned Context. +func Int(ctx context.Context, name string, defaultVal int, usage string) (context.Context, *int) { i := defaultVal - MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i}) - return &i + ctx = MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i}) + return ctx, &i } -// RequiredInt returns an *int which will be populated once Populate is run, and -// which must be supplied by a configuration Source. -func RequiredInt(ctx mctx.Context, name string, usage string) *int { +// RequiredInt returns an *int which will be populated once Populate is run on +// the returned Context, and which must be supplied by a configuration Source. +func RequiredInt(ctx context.Context, name string, usage string) (context.Context, *int) { var i int - MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i}) - return &i + ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i}) + return ctx, &i } -// String returns a *string which will be populated once Populate is run. -func String(ctx mctx.Context, name, defaultVal, usage string) *string { +// String returns a *string which will be populated once Populate is run on the +// returned Context. +func String(ctx context.Context, name, defaultVal, usage string) (context.Context, *string) { s := defaultVal - MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &s}) - return &s + ctx = MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &s}) + return ctx, &s } // RequiredString returns a *string which will be populated once Populate is -// run, and which must be supplied by a configuration Source. -func RequiredString(ctx mctx.Context, name, usage string) *string { +// run on the returned Context, and which must be supplied by a configuration +// Source. +func RequiredString(ctx context.Context, name, usage string) (context.Context, *string) { var s string - MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &s}) - return &s + ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &s}) + return ctx, &s } -// Bool returns a *bool which will be populated once Populate is run, and which -// defaults to false if unconfigured. +// Bool returns a *bool which will be populated once Populate is run on the +// returned Context, and which defaults to false if unconfigured. // // The default behavior of all Sources is that a boolean parameter will be set // to true unless the value is "", 0, or false. In the case of the CLI Source // the value will also be true when the parameter is used with no value at all, // as would be expected. -func Bool(ctx mctx.Context, name, usage string) *bool { +func Bool(ctx context.Context, name, usage string) (context.Context, *bool) { var b bool - MustAdd(ctx, Param{Name: name, Usage: usage, IsBool: true, Into: &b}) - return &b + ctx = MustAdd(ctx, Param{Name: name, Usage: usage, IsBool: true, Into: &b}) + return ctx, &b } -// TS returns an *mtime.TS which will be populated once Populate is run. -func TS(ctx mctx.Context, name string, defaultVal mtime.TS, usage string) *mtime.TS { +// TS returns an *mtime.TS which will be populated once Populate is run on the +// returned Context. +func TS(ctx context.Context, name string, defaultVal mtime.TS, usage string) (context.Context, *mtime.TS) { t := defaultVal - MustAdd(ctx, Param{Name: name, Usage: usage, Into: &t}) - return &t + ctx = MustAdd(ctx, Param{Name: name, Usage: usage, Into: &t}) + return ctx, &t } -// RequiredTS returns an *mtime.TS which will be populated once Populate is run, -// and which must be supplied by a configuration Source. -func RequiredTS(ctx mctx.Context, name, usage string) *mtime.TS { +// RequiredTS returns an *mtime.TS which will be populated once Populate is run +// on the returned Context, and which must be supplied by a configuration +// Source. +func RequiredTS(ctx context.Context, name, usage string) (context.Context, *mtime.TS) { var t mtime.TS - MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &t}) - return &t + ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &t}) + return ctx, &t } -// Duration returns an *mtime.Duration which will be populated once -// Populate is run. -func Duration(ctx mctx.Context, name string, defaultVal mtime.Duration, usage string) *mtime.Duration { +// Duration returns an *mtime.Duration which will be populated once Populate is +// run on the returned Context. +func Duration(ctx context.Context, name string, defaultVal mtime.Duration, usage string) (context.Context, *mtime.Duration) { d := defaultVal - MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &d}) - return &d + ctx = MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &d}) + return ctx, &d } // RequiredDuration returns an *mtime.Duration which will be populated once -// Populate is run, and which must be supplied by a configuration Source. -func RequiredDuration(ctx mctx.Context, name string, defaultVal mtime.Duration, usage string) *mtime.Duration { +// Populate is run on the returned Context, and which must be supplied by a +// configuration Source. +func RequiredDuration(ctx context.Context, name string, defaultVal mtime.Duration, usage string) (context.Context, *mtime.Duration) { var d mtime.Duration - MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &d}) - return &d + ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &d}) + return ctx, &d } // JSON reads the parameter value as a JSON value and unmarshals it into the // given interface{} (which should be a pointer). The receiver (into) is also // used to determine the default value. -func JSON(ctx mctx.Context, name string, into interface{}, usage string) { - MustAdd(ctx, Param{Name: name, Usage: usage, Into: into}) +func JSON(ctx context.Context, name string, into interface{}, usage string) context.Context { + return MustAdd(ctx, Param{Name: name, Usage: usage, Into: into}) } // RequiredJSON reads the parameter value as a JSON value and unmarshals it into // the given interface{} (which should be a pointer). The value must be supplied // by a configuration Source. -func RequiredJSON(ctx mctx.Context, name string, into interface{}, usage string) { - MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: into}) +func RequiredJSON(ctx context.Context, name string, into interface{}, usage string) context.Context { + return MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: into}) } diff --git a/mcfg/source_test.go b/mcfg/source_test.go index 3e0a5fe..b3ad605 100644 --- a/mcfg/source_test.go +++ b/mcfg/source_test.go @@ -1,6 +1,7 @@ package mcfg import ( + "context" "encoding/json" "fmt" . "testing" @@ -16,24 +17,35 @@ import ( // all the code they share type srcCommonState struct { - ctx mctx.Context - availCtxs []mctx.Context - expPVs []ParamValue + // availCtxs get updated in place as the run goes on, and mkRoot is used to + // create the latest version of the root context based on them + availCtxs []*context.Context + mkRoot func() context.Context + expPVs []ParamValue // each specific test should wrap this to add the Source itself } func newSrcCommonState() srcCommonState { var scs srcCommonState - scs.ctx = mctx.New() { - a := mctx.ChildOf(scs.ctx, "a") - b := mctx.ChildOf(scs.ctx, "b") - c := mctx.ChildOf(scs.ctx, "c") - ab := mctx.ChildOf(a, "b") - bc := mctx.ChildOf(b, "c") - abc := mctx.ChildOf(ab, "c") - scs.availCtxs = []mctx.Context{scs.ctx, a, b, c, ab, bc, abc} + root := context.Background() + a := mctx.NewChild(root, "a") + b := mctx.NewChild(root, "b") + c := mctx.NewChild(root, "c") + ab := mctx.NewChild(a, "b") + bc := mctx.NewChild(b, "c") + abc := mctx.NewChild(ab, "c") + scs.availCtxs = []*context.Context{&root, &a, &b, &c, &ab, &bc, &abc} + scs.mkRoot = func() context.Context { + ab := mctx.WithChild(ab, abc) + a := mctx.WithChild(a, ab) + b := mctx.WithChild(b, bc) + root := mctx.WithChild(root, a) + root = mctx.WithChild(root, b) + root = mctx.WithChild(root, c) + return root + } } return scs } @@ -57,10 +69,9 @@ func (scs srcCommonState) next() srcCommonParams { } p.availCtxI = mrand.Intn(len(scs.availCtxs)) - thisCtx := scs.availCtxs[p.availCtxI] - p.path = mctx.Path(thisCtx) + p.path = mctx.Path(*scs.availCtxs[p.availCtxI]) - p.isBool = mrand.Intn(2) == 0 + p.isBool = mrand.Intn(8) == 0 if !p.isBool { p.nonBoolType = mrand.Element([]string{ "int", @@ -105,11 +116,11 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState { // the Sources don't actually care about the other fields of Param, // those are only used by Populate once it has all ParamValues together } - MustAdd(thisCtx, ctxP) - ctxP = get(thisCtx).params[p.name] // get it back out to get any added fields + *thisCtx = MustAdd(*thisCtx, ctxP) + ctxP, _ = getParam(*thisCtx, ctxP.Name) // get it back out to get any added fields if !p.unset { - pv := ParamValue{Name: ctxP.Name, Path: ctxP.Path} + pv := ParamValue{Name: ctxP.Name, Path: mctx.Path(ctxP.Context)} if p.isBool { pv.Value = json.RawMessage("true") } else { @@ -131,7 +142,8 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState { // given a Source asserts that it's Parse method returns the expected // ParamValues func (scs srcCommonState) assert(s Source) error { - gotPVs, err := s.Parse(collectParams(scs.ctx)) + root := scs.mkRoot() + gotPVs, err := s.Parse(collectParams(root)) if err != nil { return err } @@ -142,10 +154,10 @@ func (scs srcCommonState) assert(s Source) error { } func TestSources(t *T) { - ctx := mctx.New() - a := RequiredInt(ctx, "a", "") - b := RequiredInt(ctx, "b", "") - c := RequiredInt(ctx, "c", "") + ctx := context.Background() + ctx, a := RequiredInt(ctx, "a", "") + ctx, b := RequiredInt(ctx, "b", "") + ctx, c := RequiredInt(ctx, "c", "") err := Populate(ctx, Sources{ SourceCLI{Args: []string{"--a=1", "--b=666"}}, @@ -160,11 +172,12 @@ func TestSources(t *T) { } func TestSourceParamValues(t *T) { - ctx := mctx.New() - a := RequiredInt(ctx, "a", "") - foo := mctx.ChildOf(ctx, "foo") - b := RequiredString(foo, "b", "") - c := Bool(foo, "c", "") + ctx := context.Background() + ctx, a := RequiredInt(ctx, "a", "") + foo := mctx.NewChild(ctx, "foo") + foo, b := RequiredString(foo, "b", "") + foo, c := Bool(foo, "c", "") + ctx = mctx.WithChild(ctx, foo) err := Populate(ctx, ParamValues{ {Name: "a", Value: json.RawMessage(`4`)}, diff --git a/mctx/ctx.go b/mctx/ctx.go index 4106391..de4e3a8 100644 --- a/mctx/ctx.go +++ b/mctx/ctx.go @@ -10,151 +10,163 @@ package mctx import ( - "sync" - "time" - - goctx "context" + "context" + "fmt" ) -// Context is the same as the builtin type, but is used to indicate that the -// Context originally came from this package (aka New or ChildOf). -type Context goctx.Context - -// CancelFunc is a direct alias of the type from the context package, see its -// docs. -type CancelFunc = goctx.CancelFunc - -// WithValue mimics the function from the context package. -func WithValue(parent Context, key, val interface{}) Context { - return Context(goctx.WithValue(goctx.Context(parent), key, val)) -} - -// WithCancel mimics the function from the context package. -func WithCancel(parent Context) (Context, CancelFunc) { - ctx, fn := goctx.WithCancel(goctx.Context(parent)) - return Context(ctx), fn -} - -// WithDeadline mimics the function from the context package. -func WithDeadline(parent Context, t time.Time) (Context, CancelFunc) { - ctx, fn := goctx.WithDeadline(goctx.Context(parent), t) - return Context(ctx), fn -} - -// WithTimeout mimics the function from the context package. -func WithTimeout(parent Context, d time.Duration) (Context, CancelFunc) { - ctx, fn := goctx.WithTimeout(goctx.Context(parent), d) - return Context(ctx), fn -} - //////////////////////////////////////////////////////////////////////////////// -type mutVal struct { - l sync.RWMutex - v interface{} -} - -type context struct { - goctx.Context - - path []string - l sync.RWMutex - parent *context - children map[string]Context - - mutL sync.RWMutex - mutVals map[interface{}]*mutVal -} - // New returns a new context which can be used as the root context for all // purposes in this framework. -func New() Context { - return &context{Context: goctx.Background()} -} +//func New() Context { +// return &context{Context: goctx.Background()} +//} -func getCtx(Ctx Context) *context { - ctx, ok := Ctx.(*context) +type ancestryKey int // 0 -> children, 1 -> parent, 2 -> path + +const ( + ancestryKeyChildren ancestryKey = iota + ancestryKeyChildrenMap + ancestryKeyParent + ancestryKeyPath +) + +// Child returns the Context of the given name which was added to parent via +// WithChild, or nil if no Context of that name was ever added. +func Child(parent context.Context, name string) context.Context { + childrenMap, _ := parent.Value(ancestryKeyChildrenMap).(map[string]int) + if len(childrenMap) == 0 { + return nil + } + i, ok := childrenMap[name] if !ok { - panic("non-conforming Context used") + return nil } - return ctx + return parent.Value(ancestryKeyChildren).([]context.Context)[i] } -// Path returns the sequence of names which were used to produce this context -// via the ChildOf function. -func Path(Ctx Context) []string { - return getCtx(Ctx).path +// Children returns all children of this Context which have been kept by +// WithChild, mapped by their name. If this Context wasn't produced by WithChild +// then this returns nil. +func Children(parent context.Context) []context.Context { + children, _ := parent.Value(ancestryKeyChildren).([]context.Context) + return children } -// Children returns all children of this context which have been created by -// ChildOf, mapped by their name. -func Children(Ctx Context) map[string]Context { - ctx := getCtx(Ctx) - out := map[string]Context{} - ctx.l.RLock() - defer ctx.l.RUnlock() - for name, childCtx := range ctx.children { - out[name] = childCtx +func childrenCP(parent context.Context) ([]context.Context, map[string]int) { + children := Children(parent) + // plus 1 because this is most commonly used in WithChild, which will append + // to it. At any rate it doesn't hurt anything. + outChildren := make([]context.Context, len(children), len(children)+1) + copy(outChildren, children) + + childrenMap, _ := parent.Value(ancestryKeyChildrenMap).(map[string]int) + outChildrenMap := make(map[string]int, len(childrenMap)+1) + for name, i := range childrenMap { + outChildrenMap[name] = i } - return out + + return outChildren, outChildrenMap } -// Parent returns the parent Context of the given one, or nil if this is a root -// context (i.e. returned from New). -func Parent(Ctx Context) Context { - return getCtx(Ctx).parent -} - -// Root returns the root Context from which this Context and all of its parents -// were derived (i.e. the Context which was originally returned from New). +// parentOf returns the Context from which this one was generated via NewChild. +// Returns nil if this Context was not generated via NewChild. // -// If the given Context is the root then it is returned as-id. -func Root(Ctx Context) Context { - ctx := getCtx(Ctx) - for { - if ctx.parent == nil { - return ctx +// This is kept private because the behavior is a bit confusing. This will +// return the Context which was passed into NewChild, but users would probably +// expect it to return the one from WithChild if they were to call this. +func parentOf(ctx context.Context) context.Context { + parent, _ := ctx.Value(ancestryKeyParent).(context.Context) + return parent +} + +// Path returns the sequence of names which were used to produce this Context +// via the NewChild function. If this Context wasn't produced by NewChild then +// this returns nil. +func Path(ctx context.Context) []string { + path, _ := ctx.Value(ancestryKeyPath).([]string) + return path +} + +func pathCP(ctx context.Context) []string { + path := Path(ctx) + // plus 1 because this is most commonly used in NewChild, which will append + // to it. At any rate it doesn't hurt anything. + outPath := make([]string, len(path), len(path)+1) + copy(outPath, path) + return outPath +} + +// Name returns the name this Context was generated with via NewChild, or false +// if this Context was not generated via NewChild. +func Name(ctx context.Context) (string, bool) { + path := Path(ctx) + if len(path) == 0 { + return "", false + } + return path[len(path)-1], true +} + +// NewChild creates a new Context based off of the parent one, and returns a new +// instance of the passed in parent and the new child. The child will have a +// path which is the parent's path with the given name appended. The parent will +// have the new child as part of its set of children (see Children function). +// +// If the parent already has a child of the given name this function panics. +func NewChild(parent context.Context, name string) context.Context { + if Child(parent, name) != nil { + panic(fmt.Sprintf("child with name %q already exists on parent", name)) + } + + childPath := append(pathCP(parent), name) + child := withoutLocalValues(parent) + child = context.WithValue(child, ancestryKeyChildren, nil) // unset children + child = context.WithValue(child, ancestryKeyChildrenMap, nil) // unset children + child = context.WithValue(child, ancestryKeyParent, parent) + child = context.WithValue(child, ancestryKeyPath, childPath) + return child +} + +func isChild(parent, child context.Context) bool { + parentPath, childPath := Path(parent), Path(child) + if len(parentPath) != len(childPath)-1 { + return false + } + + for i := range parentPath { + if parentPath[i] != childPath[i] { + return false } - ctx = ctx.parent } + return true } -// ChildOf creates a child of the given context with the given name and returns -// it. The Path of the returned context will be the path of the parent with its -// name appended to it. The Children function can be called on the parent to -// retrieve all children which have been made using this function. -// -// TODO If the given Context already has a child with the given name that child -// will be returned. -func ChildOf(Ctx Context, name string) Context { - ctx, childCtx := getCtx(Ctx), new(context) - - ctx.l.Lock() - defer ctx.l.Unlock() - - // set child's path field - childCtx.path = make([]string, 0, len(ctx.path)+1) - childCtx.path = append(childCtx.path, ctx.path...) - childCtx.path = append(childCtx.path, name) - - // set child's parent field - childCtx.parent = ctx - - // create child's ctx and store it in parent - if ctx.children == nil { - ctx.children = map[string]Context{} +// WithChild returns a modified parent which holds a reference to child in its +// Children set. If the child's name is already taken in the parent then this +// function panics. +func WithChild(parent, child context.Context) context.Context { + if !isChild(parent, child) { + panic(fmt.Sprintf("child cannot be kept by Context which is not its parent")) } - ctx.children[name] = childCtx - return childCtx + + name, _ := Name(child) + children, childrenMap := childrenCP(parent) + if _, ok := childrenMap[name]; ok { + panic(fmt.Sprintf("child with name %q already exists on parent", name)) + } + children = append(children, child) + childrenMap[name] = len(children) - 1 + + parent = context.WithValue(parent, ancestryKeyChildren, children) + parent = context.WithValue(parent, ancestryKeyChildrenMap, childrenMap) + return parent } // BreadthFirstVisit visits this Context and all of its children, and their // children, in a breadth-first order. If the callback returns false then the // function returns without visiting any more Contexts. -// -// The exact order of visitation is non-deterministic. -func BreadthFirstVisit(Ctx Context, callback func(Context) bool) { - queue := []Context{Ctx} +func BreadthFirstVisit(ctx context.Context, callback func(context.Context) bool) { + queue := []context.Context{ctx} for len(queue) > 0 { if !callback(queue[0]) { return @@ -167,78 +179,56 @@ func BreadthFirstVisit(Ctx Context, callback func(Context) bool) { } //////////////////////////////////////////////////////////////////////////////// -// code related to mutable values +// local value stuff -// MutableValue acts like the Value method, except that it only deals with -// keys/values set by SetMutableValue. -func MutableValue(Ctx Context, key interface{}) interface{} { - ctx := getCtx(Ctx) - ctx.mutL.RLock() - defer ctx.mutL.RUnlock() - if ctx.mutVals == nil { - return nil - } else if mVal, ok := ctx.mutVals[key]; ok { - mVal.l.RLock() - defer mVal.l.RUnlock() - return mVal.v - } - return nil +type localValsKey int + +type localVal struct { + prev *localVal + key, val interface{} } -// GetSetMutableValue is used to interact with a mutable value on the context in -// a thread-safe way. The key's value is retrieved and passed to the callback. -// The value returned from the callback is stored back into the context as well -// as being returned from this function. -// -// If noCallbackIfSet is set to true, then if the key is already set the value -// will be returned without calling the callback. -// -// The callback returning nil is equivalent to unsetting the value. -// -// Children of this context will _not_ inherit any of its mutable values. -// -// Within the callback it is fine to call other functions/methods on the -// Context, except for those related to mutable values for this same key (e.g. -// MutableValue and SetMutableValue). -func GetSetMutableValue( - Ctx Context, noCallbackIfSet bool, - key interface{}, fn func(interface{}) interface{}, -) interface{} { +// WithLocalValue is like WithValue, but the stored value will not be present +// on any children created via WithChild. Local values must be retrieved with +// the LocalValue function in this package. Local values share a different +// namespace than the normal WithValue/Value values (i.e. they do not overlap). +func WithLocalValue(ctx context.Context, key, val interface{}) context.Context { + prev, _ := ctx.Value(localValsKey(0)).(*localVal) + return context.WithValue(ctx, localValsKey(0), &localVal{ + prev: prev, + key: key, val: val, + }) +} - // if noCallbackIfSet, do a fast lookup with MutableValue first. - if noCallbackIfSet { - if v := MutableValue(Ctx, key); v != nil { - return v +func withoutLocalValues(ctx context.Context) context.Context { + return context.WithValue(ctx, localValsKey(0), nil) +} + +// LocalValue returns the value for the given key which was set by a call to +// WithLocalValue, or nil if no value was set for the given key. +func LocalValue(ctx context.Context, key interface{}) interface{} { + lv, _ := ctx.Value(localValsKey(0)).(*localVal) + for { + if lv == nil { + return nil + } else if lv.key == key { + return lv.val } + lv = lv.prev + } +} + +// LocalValues returns all key/value pairs which have been set on the Context +// via WithLocalValue. +func LocalValues(ctx context.Context) map[interface{}]interface{} { + m := map[interface{}]interface{}{} + lv, _ := ctx.Value(localValsKey(0)).(*localVal) + for { + if lv == nil { + return m + } else if _, ok := m[lv.key]; !ok { + m[lv.key] = lv.val + } + lv = lv.prev } - - ctx := getCtx(Ctx) - ctx.mutL.Lock() - if ctx.mutVals == nil { - ctx.mutVals = map[interface{}]*mutVal{} - } - mVal, ok := ctx.mutVals[key] - if !ok { - mVal = new(mutVal) - ctx.mutVals[key] = mVal - } - ctx.mutL.Unlock() - - mVal.l.Lock() - defer mVal.l.Unlock() - - // It's possible something happened between the first check inside the - // read-lock and now, so double check this case. It's still good to have the - // read-lock check there, it'll handle 99% of the cases. - if noCallbackIfSet && mVal.v != nil { - return mVal.v - } - - mVal.v = fn(mVal.v) - - // TODO if the new v is nil then key could be deleted out of mutVals. But - // doing so would be weird in the case that there's another routine which - // has already pulled this same mVal out of mutVals and is waiting on its - // mutex. - return mVal.v } diff --git a/mctx/ctx_test.go b/mctx/ctx_test.go index 9a71f95..5772299 100644 --- a/mctx/ctx_test.go +++ b/mctx/ctx_test.go @@ -1,18 +1,22 @@ package mctx import ( - "sync" + "context" . "testing" "github.com/mediocregopher/mediocre-go-lib/mtest/massert" ) func TestInheritance(t *T) { - ctx := New() - ctx1 := ChildOf(ctx, "1") - ctx1a := ChildOf(ctx1, "a") - ctx1b := ChildOf(ctx1, "b") - ctx2 := ChildOf(ctx, "2") + ctx := context.Background() + ctx1 := NewChild(ctx, "1") + ctx1a := NewChild(ctx1, "a") + ctx1b := NewChild(ctx1, "b") + ctx1 = WithChild(ctx1, ctx1a) + ctx1 = WithChild(ctx1, ctx1b) + ctx2 := NewChild(ctx, "2") + ctx = WithChild(ctx, ctx1) + ctx = WithChild(ctx, ctx2) massert.Fatal(t, massert.All( massert.Len(Path(ctx), 0), @@ -23,129 +27,130 @@ func TestInheritance(t *T) { )) massert.Fatal(t, massert.All( - massert.Equal( - map[string]Context{"1": ctx1, "2": ctx2}, - Children(ctx), - ), - massert.Equal( - map[string]Context{"a": ctx1a, "b": ctx1b}, - Children(ctx1), - ), - massert.Equal( - map[string]Context{}, - Children(ctx2), - ), - )) - - massert.Fatal(t, massert.All( - massert.Nil(Parent(ctx)), - massert.Equal(Parent(ctx1), ctx), - massert.Equal(Parent(ctx1a), ctx1), - massert.Equal(Parent(ctx1b), ctx1), - massert.Equal(Parent(ctx2), ctx), - )) - - massert.Fatal(t, massert.All( - massert.Equal(Root(ctx), ctx), - massert.Equal(Root(ctx1), ctx), - massert.Equal(Root(ctx1a), ctx), - massert.Equal(Root(ctx1b), ctx), - massert.Equal(Root(ctx2), ctx), + massert.Equal([]context.Context{ctx1, ctx2}, Children(ctx)), + massert.Equal([]context.Context{ctx1a, ctx1b}, Children(ctx1)), + massert.Len(Children(ctx2), 0), )) } func TestBreadFirstVisit(t *T) { - ctx := New() - ctx1 := ChildOf(ctx, "1") - ctx1a := ChildOf(ctx1, "a") - ctx1b := ChildOf(ctx1, "b") - ctx2 := ChildOf(ctx, "2") + ctx := context.Background() + ctx1 := NewChild(ctx, "1") + ctx1a := NewChild(ctx1, "a") + ctx1b := NewChild(ctx1, "b") + ctx1 = WithChild(ctx1, ctx1a) + ctx1 = WithChild(ctx1, ctx1b) + ctx2 := NewChild(ctx, "2") + ctx = WithChild(ctx, ctx1) + ctx = WithChild(ctx, ctx2) { - got := make([]Context, 0, 5) - BreadthFirstVisit(ctx, func(ctx Context) bool { + got := make([]context.Context, 0, 5) + BreadthFirstVisit(ctx, func(ctx context.Context) bool { got = append(got, ctx) return true }) - // since children are stored in a map the exact order is non-deterministic - massert.Fatal(t, massert.Any( - massert.Equal([]Context{ctx, ctx1, ctx2, ctx1a, ctx1b}, got), - massert.Equal([]Context{ctx, ctx1, ctx2, ctx1b, ctx1a}, got), - massert.Equal([]Context{ctx, ctx2, ctx1, ctx1a, ctx1b}, got), - massert.Equal([]Context{ctx, ctx2, ctx1, ctx1b, ctx1a}, got), - )) + massert.Fatal(t, + massert.Equal([]context.Context{ctx, ctx1, ctx2, ctx1a, ctx1b}, got), + ) } { - got := make([]Context, 0, 3) - BreadthFirstVisit(ctx, func(ctx Context) bool { + got := make([]context.Context, 0, 3) + BreadthFirstVisit(ctx, func(ctx context.Context) bool { if len(Path(ctx)) > 1 { return false } got = append(got, ctx) return true }) - massert.Fatal(t, massert.Any( - massert.Equal([]Context{ctx, ctx1, ctx2}, got), - massert.Equal([]Context{ctx, ctx2, ctx1}, got), - )) + massert.Fatal(t, + massert.Equal([]context.Context{ctx, ctx1, ctx2}, got), + ) } } -func TestMutableValues(t *T) { - fn := func(v interface{}) interface{} { - if v == nil { - return 0 - } - return v.(int) + 1 - } +func TestLocalValues(t *T) { - var aa []massert.Assertion + // test with no value set + ctx := context.Background() + massert.Fatal(t, massert.All( + massert.Nil(LocalValue(ctx, "foo")), + massert.Len(LocalValues(ctx), 0), + )) - ctx := New() - aa = append(aa, massert.Equal(GetSetMutableValue(ctx, false, 0, fn), 0)) - aa = append(aa, massert.Equal(GetSetMutableValue(ctx, false, 0, fn), 1)) - aa = append(aa, massert.Equal(GetSetMutableValue(ctx, true, 0, fn), 1)) + // test basic value retrieval + ctx = WithLocalValue(ctx, "foo", "bar") + massert.Fatal(t, massert.All( + massert.Equal("bar", LocalValue(ctx, "foo")), + massert.Equal( + map[interface{}]interface{}{"foo": "bar"}, + LocalValues(ctx), + ), + )) - aa = append(aa, massert.Equal(MutableValue(ctx, 0), 1)) + // test that doesn't conflict with WithValue + ctx = context.WithValue(ctx, "foo", "WithValue bar") + massert.Fatal(t, massert.All( + massert.Equal("bar", LocalValue(ctx, "foo")), + massert.Equal("WithValue bar", ctx.Value("foo")), + massert.Equal( + map[interface{}]interface{}{"foo": "bar"}, + LocalValues(ctx), + ), + )) - ctx1 := ChildOf(ctx, "one") - aa = append(aa, massert.Equal(GetSetMutableValue(ctx1, true, 0, fn), 0)) - aa = append(aa, massert.Equal(GetSetMutableValue(ctx1, false, 0, fn), 1)) - aa = append(aa, massert.Equal(GetSetMutableValue(ctx1, true, 0, fn), 1)) + // test that child doesn't get values + child := NewChild(ctx, "child") + massert.Fatal(t, massert.All( + massert.Equal("bar", LocalValue(ctx, "foo")), + massert.Nil(LocalValue(child, "foo")), + massert.Len(LocalValues(child), 0), + )) - massert.Fatal(t, massert.All(aa...)) -} - -func TestMutableValuesParallel(t *T) { - const events = 1000000 - const workers = 10 - - incr := func(v interface{}) interface{} { - if v == nil { - return 1 - } - return v.(int) + 1 - } - - ch := make(chan bool, events) - for i := 0; i < events; i++ { - ch <- true - } - close(ch) - - ctx := New() - wg := new(sync.WaitGroup) - for i := 0; i < workers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for range ch { - GetSetMutableValue(ctx, false, 0, incr) - } - }() - } - - wg.Wait() - massert.Fatal(t, massert.Equal(events, MutableValue(ctx, 0))) + // test that values on child don't affect parent values + child = WithLocalValue(child, "foo", "child bar") + ctx = WithChild(ctx, child) + massert.Fatal(t, massert.All( + massert.Equal("bar", LocalValue(ctx, "foo")), + massert.Equal("child bar", LocalValue(child, "foo")), + massert.Equal( + map[interface{}]interface{}{"foo": "bar"}, + LocalValues(ctx), + ), + massert.Equal( + map[interface{}]interface{}{"foo": "child bar"}, + LocalValues(child), + ), + )) + + // test that two With calls on the same context generate distinct contexts + childA := WithLocalValue(child, "foo2", "baz") + childB := WithLocalValue(child, "foo2", "buz") + massert.Fatal(t, massert.All( + massert.Equal("bar", LocalValue(ctx, "foo")), + massert.Equal("child bar", LocalValue(child, "foo")), + massert.Nil(LocalValue(child, "foo2")), + massert.Equal("baz", LocalValue(childA, "foo2")), + massert.Equal("buz", LocalValue(childB, "foo2")), + massert.Equal( + map[interface{}]interface{}{"foo": "child bar", "foo2": "baz"}, + LocalValues(childA), + ), + massert.Equal( + map[interface{}]interface{}{"foo": "child bar", "foo2": "buz"}, + LocalValues(childB), + ), + )) + + // if a value overwrites a previous one the newer one should show in + // LocalValues + ctx = WithLocalValue(ctx, "foo", "barbar") + massert.Fatal(t, massert.All( + massert.Equal("barbar", LocalValue(ctx, "foo")), + massert.Equal( + map[interface{}]interface{}{"foo": "barbar"}, + LocalValues(ctx), + ), + )) } diff --git a/mdb/mbigquery/bigquery.go b/mdb/mbigquery/bigquery.go index 24eeaa6..153fb57 100644 --- a/mdb/mbigquery/bigquery.go +++ b/mdb/mbigquery/bigquery.go @@ -34,7 +34,7 @@ func isErrAlreadyExists(err error) bool { type BigQuery struct { *bigquery.Client gce *mdb.GCE - log *mlog.Logger + ctx context.Context // key is dataset/tableName tablesL sync.Mutex @@ -43,35 +43,37 @@ type BigQuery struct { } // MNew returns a BigQuery instance which will be initialized and configured -// when the start event is triggered on ctx (see mrun.Start). The BigQuery -// instance will have Close called on it when the stop event is triggered on ctx -// (see mrun.Stop). +// when the start event is triggered on the returned (see mrun.Start). The +// BigQuery instance will have Close called on it when the stop event is +// triggered on the returned Context (see mrun.Stop). // // gce is optional and can be passed in if there's an existing gce object which // should be used, otherwise a new one will be created with mdb.MGCE. -func MNew(ctx mctx.Context, gce *mdb.GCE) *BigQuery { +func MNew(ctx context.Context, gce *mdb.GCE) (context.Context, *BigQuery) { if gce == nil { - gce = mdb.MGCE(ctx, "") + ctx, gce = mdb.MGCE(ctx, "") } - ctx = mctx.ChildOf(ctx, "bigquery") bq := &BigQuery{ gce: gce, tables: map[[2]string]*bigquery.Table{}, tableUploaders: map[[2]string]*bigquery.Uploader{}, - log: mlog.From(ctx), + ctx: mctx.NewChild(ctx, "bigquery"), } - bq.log.SetKV(bq) - mrun.OnStart(ctx, func(innerCtx mctx.Context) error { - bq.log.Info("connecting to bigquery") + + // TODO the equivalent functionality as here will be added with annotations + // bq.log.SetKV(bq) + + bq.ctx = mrun.OnStart(bq.ctx, func(innerCtx context.Context) error { + mlog.Info(bq.ctx, "connecting to bigquery") var err error bq.Client, err = bigquery.NewClient(innerCtx, bq.gce.Project, bq.gce.ClientOptions()...) return merr.WithKV(err, bq.KV()) }) - mrun.OnStop(ctx, func(mctx.Context) error { + bq.ctx = mrun.OnStop(bq.ctx, func(context.Context) error { return bq.Client.Close() }) - return bq + return mctx.WithChild(ctx, bq.ctx), bq } // KV implements the mlog.KVer interface. @@ -99,7 +101,7 @@ func (bq *BigQuery) Table( } kv := mlog.KV{"bigQueryDataset": dataset, "bigQueryTable": tableName} - bq.log.Debug("creating/grabbing table", kv) + mlog.Debug(bq.ctx, "creating/grabbing table", kv) schema, err := bigquery.InferSchema(schemaObj) if err != nil { diff --git a/mdb/mbigtable/bigtable.go b/mdb/mbigtable/bigtable.go index 2042169..420b62e 100644 --- a/mdb/mbigtable/bigtable.go +++ b/mdb/mbigtable/bigtable.go @@ -28,44 +28,45 @@ type Bigtable struct { Instance string gce *mdb.GCE - log *mlog.Logger + ctx context.Context } // MNew returns a Bigtable instance which will be initialized and configured -// when the start event is triggered on ctx (see mrun.Start). The Bigtable -// instance will have Close called on it when the stop event is triggered on ctx -// (see mrun.Stop). +// when the start event is triggered on the returned Context (see mrun.Start). +// The Bigtable instance will have Close called on it when the stop event is +// triggered on the returned Context (see mrun.Stop). // // gce is optional and can be passed in if there's an existing gce object which // should be used, otherwise a new one will be created with mdb.MGCE. // // defaultInstance can be given as the instance name to use as the default // parameter value. If empty the parameter will be required to be set. -func MNew(ctx mctx.Context, gce *mdb.GCE, defaultInstance string) *Bigtable { +func MNew(ctx context.Context, gce *mdb.GCE, defaultInstance string) (context.Context, *Bigtable) { if gce == nil { - gce = mdb.MGCE(ctx, "") + ctx, gce = mdb.MGCE(ctx, "") } - ctx = mctx.ChildOf(ctx, "bigtable") bt := &Bigtable{ gce: gce, - log: mlog.From(ctx), + ctx: mctx.NewChild(ctx, "bigtable"), } - bt.log.SetKV(bt) + + // TODO the equivalent functionality as here will be added with annotations + // bt.log.SetKV(bt) var inst *string { const name, descr = "instance", "name of the bigtable instance in the project to connect to" if defaultInstance != "" { - inst = mcfg.String(ctx, name, defaultInstance, descr) + bt.ctx, inst = mcfg.String(bt.ctx, name, defaultInstance, descr) } else { - inst = mcfg.RequiredString(ctx, name, descr) + bt.ctx, inst = mcfg.RequiredString(bt.ctx, name, descr) } } - mrun.OnStart(ctx, func(innerCtx mctx.Context) error { + bt.ctx = mrun.OnStart(bt.ctx, func(innerCtx context.Context) error { bt.Instance = *inst - bt.log.Info("connecting to bigtable") + mlog.Info(bt.ctx, "connecting to bigtable", bt) var err error bt.Client, err = bigtable.NewClient( innerCtx, @@ -74,10 +75,10 @@ func MNew(ctx mctx.Context, gce *mdb.GCE, defaultInstance string) *Bigtable { ) return merr.WithKV(err, bt.KV()) }) - mrun.OnStop(ctx, func(mctx.Context) error { + bt.ctx = mrun.OnStop(bt.ctx, func(context.Context) error { return bt.Client.Close() }) - return bt + return mctx.WithChild(ctx, bt.ctx), bt } // KV implements the mlog.KVer interface. @@ -93,16 +94,16 @@ func (bt *Bigtable) KV() map[string]interface{} { // This method requires admin privileges on the bigtable instance. func (bt *Bigtable) EnsureTable(ctx context.Context, name string, colFams ...string) error { kv := mlog.KV{"bigtableTable": name} - bt.log.Info("ensuring table", kv) + mlog.Info(bt.ctx, "ensuring table", kv) - bt.log.Debug("creating admin client", kv) + mlog.Debug(bt.ctx, "creating admin client", kv) adminClient, err := bigtable.NewAdminClient(ctx, bt.gce.Project, bt.Instance) if err != nil { return merr.WithKV(err, bt.KV(), kv.KV()) } defer adminClient.Close() - bt.log.Debug("creating bigtable table (if needed)", kv) + mlog.Debug(bt.ctx, "creating bigtable table (if needed)", kv) err = adminClient.CreateTable(ctx, name) if err != nil && !isErrAlreadyExists(err) { return merr.WithKV(err, bt.KV(), kv.KV()) @@ -110,7 +111,7 @@ func (bt *Bigtable) EnsureTable(ctx context.Context, name string, colFams ...str for _, colFam := range colFams { kv := kv.Set("family", colFam) - bt.log.Debug("creating bigtable column family (if needed)", kv) + mlog.Debug(bt.ctx, "creating bigtable column family (if needed)", kv) err := adminClient.CreateColumnFamily(ctx, name, colFam) if err != nil && !isErrAlreadyExists(err) { return merr.WithKV(err, bt.KV(), kv.KV()) diff --git a/mdb/mbigtable/bigtable_test.go b/mdb/mbigtable/bigtable_test.go index e304df6..ad1427c 100644 --- a/mdb/mbigtable/bigtable_test.go +++ b/mdb/mbigtable/bigtable_test.go @@ -12,8 +12,8 @@ import ( func TestBasic(t *T) { ctx := mtest.NewCtx() - mtest.SetEnv(ctx, "GCE_PROJECT", "testProject") - bt := MNew(ctx, nil, "testInstance") + ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "testProject") + ctx, bt := MNew(ctx, nil, "testInstance") mtest.Run(ctx, t, func() { tableName := "test-" + mrand.Hex(8) diff --git a/mdb/mdatastore/datastore.go b/mdb/mdatastore/datastore.go index 04730ea..c150796 100644 --- a/mdb/mdatastore/datastore.go +++ b/mdb/mdatastore/datastore.go @@ -3,6 +3,8 @@ package mdatastore import ( + "context" + "cloud.google.com/go/datastore" "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mdb" @@ -17,38 +19,39 @@ type Datastore struct { *datastore.Client gce *mdb.GCE - log *mlog.Logger + ctx context.Context } // MNew returns a Datastore instance which will be initialized and configured -// when the start event is triggered on ctx (see mrun.Start). The Datastore -// instance will have Close called on it when the stop event is triggered on ctx -// (see mrun.Stop). +// when the start event is triggered on the returned Context (see mrun.Start). +// The Datastore instance will have Close called on it when the stop event is +// triggered on the returned Context (see mrun.Stop). // // gce is optional and can be passed in if there's an existing gce object which // should be used, otherwise a new one will be created with mdb.MGCE. -func MNew(ctx mctx.Context, gce *mdb.GCE) *Datastore { +func MNew(ctx context.Context, gce *mdb.GCE) (context.Context, *Datastore) { if gce == nil { - gce = mdb.MGCE(ctx, "") + ctx, gce = mdb.MGCE(ctx, "") } - ctx = mctx.ChildOf(ctx, "datastore") ds := &Datastore{ gce: gce, - log: mlog.From(ctx), + ctx: mctx.NewChild(ctx, "datastore"), } - ds.log.SetKV(ds) - mrun.OnStart(ctx, func(innerCtx mctx.Context) error { - ds.log.Info("connecting to datastore") + // TODO the equivalent functionality as here will be added with annotations + // ds.log.SetKV(ds) + + ds.ctx = mrun.OnStart(ds.ctx, func(innerCtx context.Context) error { + mlog.Info(ds.ctx, "connecting to datastore") var err error ds.Client, err = datastore.NewClient(innerCtx, ds.gce.Project, ds.gce.ClientOptions()...) return merr.WithKV(err, ds.KV()) }) - mrun.OnStop(ctx, func(mctx.Context) error { + ds.ctx = mrun.OnStop(ds.ctx, func(context.Context) error { return ds.Client.Close() }) - return ds + return mctx.WithChild(ctx, ds.ctx), ds } // KV implements the mlog.KVer interface. diff --git a/mdb/mdatastore/datastore_test.go b/mdb/mdatastore/datastore_test.go index d6432c2..c6b8270 100644 --- a/mdb/mdatastore/datastore_test.go +++ b/mdb/mdatastore/datastore_test.go @@ -12,8 +12,8 @@ import ( // Requires datastore emulator to be running func TestBasic(t *T) { ctx := mtest.NewCtx() - mtest.SetEnv(ctx, "GCE_PROJECT", "test") - ds := MNew(ctx, nil) + ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "test") + ctx, ds := MNew(ctx, nil) mtest.Run(ctx, t, func() { name := mrand.Hex(8) key := datastore.NameKey("testKind", name, nil) diff --git a/mdb/mdb.go b/mdb/mdb.go index 89a1695..a148984 100644 --- a/mdb/mdb.go +++ b/mdb/mdb.go @@ -3,6 +3,8 @@ package mdb import ( + "context" + "github.com/mediocregopher/mediocre-go-lib/mcfg" "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mlog" @@ -18,27 +20,28 @@ type GCE struct { } // MGCE returns a GCE instance which will be initialized and configured when the -// start event is triggered on ctx (see mrun.Start). defaultProject is used as -// the default value for the mcfg parameter this function creates. -func MGCE(ctx mctx.Context, defaultProject string) *GCE { - ctx = mctx.ChildOf(ctx, "gce") - credFile := mcfg.String(ctx, "cred-file", "", "Path to GCE credientials JSON file, if any") +// start event is triggered on the returned Context (see mrun.Start). +// defaultProject is used as the default value for the mcfg parameter this +// function creates. +func MGCE(parent context.Context, defaultProject string) (context.Context, *GCE) { + ctx := mctx.NewChild(parent, "gce") + ctx, credFile := mcfg.String(ctx, "cred-file", "", "Path to GCE credientials JSON file, if any") var project *string const projectUsage = "Name of GCE project to use" if defaultProject == "" { - project = mcfg.RequiredString(ctx, "project", projectUsage) + ctx, project = mcfg.RequiredString(ctx, "project", projectUsage) } else { - project = mcfg.String(ctx, "project", defaultProject, projectUsage) + ctx, project = mcfg.String(ctx, "project", defaultProject, projectUsage) } var gce GCE - mrun.OnStart(ctx, func(mctx.Context) error { + ctx = mrun.OnStart(ctx, func(context.Context) error { gce.Project = *project gce.CredFile = *credFile return nil }) - return &gce + return mctx.WithChild(parent, ctx), &gce } // ClientOptions generates and returns the ClientOption instances which can be diff --git a/mdb/mpubsub/pubsub.go b/mdb/mpubsub/pubsub.go index 31e23be..64c06c1 100644 --- a/mdb/mpubsub/pubsub.go +++ b/mdb/mpubsub/pubsub.go @@ -19,7 +19,6 @@ import ( "google.golang.org/grpc/status" ) -// TODO this package still uses context.Context in the callback functions // TODO Consume (and probably BatchConsume) don't properly handle the Client // being closed. @@ -39,38 +38,39 @@ type PubSub struct { *pubsub.Client gce *mdb.GCE - log *mlog.Logger + ctx context.Context } // MNew returns a PubSub instance which will be initialized and configured when -// the start event is triggered on ctx (see mrun.Start). The PubSub instance -// will have Close called on it when the stop event is triggered on ctx (see -// mrun.Stop). +// the start event is triggered on the returned Context (see mrun.Start). The +// PubSub instance will have Close called on it when the stop event is triggered +// on the returned Context(see mrun.Stop). // // gce is optional and can be passed in if there's an existing gce object which // should be used, otherwise a new one will be created with mdb.MGCE. -func MNew(ctx mctx.Context, gce *mdb.GCE) *PubSub { +func MNew(ctx context.Context, gce *mdb.GCE) (context.Context, *PubSub) { if gce == nil { - gce = mdb.MGCE(ctx, "") + ctx, gce = mdb.MGCE(ctx, "") } - ctx = mctx.ChildOf(ctx, "pubsub") ps := &PubSub{ gce: gce, - log: mlog.From(ctx), + ctx: mctx.NewChild(ctx, "pubsub"), } - ps.log.SetKV(ps) - mrun.OnStart(ctx, func(innerCtx mctx.Context) error { - ps.log.Info("connecting to pubsub") + // TODO the equivalent functionality as here will be added with annotations + // ps.log.SetKV(ps) + + ps.ctx = mrun.OnStart(ps.ctx, func(innerCtx context.Context) error { + mlog.Info(ps.ctx, "connecting to pubsub") var err error ps.Client, err = pubsub.NewClient(innerCtx, ps.gce.Project, ps.gce.ClientOptions()...) return merr.WithKV(err, ps.KV()) }) - mrun.OnStop(ctx, func(mctx.Context) error { + ps.ctx = mrun.OnStop(ps.ctx, func(context.Context) error { return ps.Client.Close() }) - return ps + return mctx.WithChild(ctx, ps.ctx), ps } // KV implements the mlog.KVer interface @@ -226,7 +226,7 @@ func (s *Subscription) Consume(ctx context.Context, fn ConsumerFunc, opts Consum ok, err := fn(context.Context(innerCtx), msg) if err != nil { - s.topic.ps.log.Warn("error consuming pubsub message", s, merr.KV(err)) + mlog.Warn(s.topic.ps.ctx, "error consuming pubsub message", s, merr.KV(err)) } if ok { @@ -238,7 +238,7 @@ func (s *Subscription) Consume(ctx context.Context, fn ConsumerFunc, opts Consum if octx.Err() == context.Canceled || err == nil { return } else if err != nil { - s.topic.ps.log.Warn("error consuming from pubsub", s, merr.KV(err)) + mlog.Warn(s.topic.ps.ctx, "error consuming from pubsub", s, merr.KV(err)) time.Sleep(1 * time.Second) } } @@ -331,7 +331,7 @@ func (s *Subscription) BatchConsume( } ret, err := fn(thisCtx, msgs) if err != nil { - s.topic.ps.log.Warn("error consuming pubsub batch messages", s, merr.KV(err)) + mlog.Warn(s.topic.ps.ctx, "error consuming pubsub batch messages", s, merr.KV(err)) } for i := range thisGroup { thisGroup[i].retCh <- ret // retCh is buffered diff --git a/mdb/mpubsub/pubsub_test.go b/mdb/mpubsub/pubsub_test.go index e516d26..af0feb1 100644 --- a/mdb/mpubsub/pubsub_test.go +++ b/mdb/mpubsub/pubsub_test.go @@ -14,8 +14,8 @@ import ( // this requires the pubsub emulator to be running func TestPubSub(t *T) { ctx := mtest.NewCtx() - mtest.SetEnv(ctx, "GCE_PROJECT", "test") - ps := MNew(ctx, nil) + ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "test") + ctx, ps := MNew(ctx, nil) mtest.Run(ctx, t, func() { topicName := "testTopic_" + mrand.Hex(8) ctx := context.Background() @@ -48,8 +48,8 @@ func TestPubSub(t *T) { func TestBatchPubSub(t *T) { ctx := mtest.NewCtx() - mtest.SetEnv(ctx, "GCE_PROJECT", "test") - ps := MNew(ctx, nil) + ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "test") + ctx, ps := MNew(ctx, nil) mtest.Run(ctx, t, func() { topicName := "testBatchTopic_" + mrand.Hex(8) diff --git a/mhttp/mhttp.go b/mhttp/mhttp.go index 96ebf5f..64c50ce 100644 --- a/mhttp/mhttp.go +++ b/mhttp/mhttp.go @@ -17,27 +17,39 @@ import ( "github.com/mediocregopher/mediocre-go-lib/mrun" ) +// MServer is returned by MListenAndServe and simply wraps an *http.Server. +type MServer struct { + *http.Server + ctx context.Context +} + // MListenAndServe returns an http.Server which will be initialized and have // ListenAndServe called on it (asynchronously) when the start event is -// triggered on ctx (see mrun.Start). The Server will have Shutdown called on it -// when the stop event is triggered on ctx (see mrun.Stop). +// triggered on the returned Context (see mrun.Start). The Server will have +// Shutdown called on it when the stop event is triggered on the returned +// Context (see mrun.Stop). // // This function automatically handles setting up configuration parameters via // mcfg. The default listen address is ":0". -func MListenAndServe(ctx mctx.Context, h http.Handler) *http.Server { - ctx = mctx.ChildOf(ctx, "http") - listener := mnet.MListen(ctx, "tcp", "") +func MListenAndServe(ctx context.Context, h http.Handler) (context.Context, *MServer) { + srv := &MServer{ + Server: &http.Server{Handler: h}, + ctx: mctx.NewChild(ctx, "http"), + } + + var listener *mnet.MListener + srv.ctx, listener = mnet.MListen(srv.ctx, "tcp", "") listener.NoCloseOnStop = true // http.Server.Shutdown will do this - logger := mlog.From(ctx) - logger.SetKV(listener) + // TODO the equivalent functionality as here will be added with annotations + //logger := mlog.From(ctx) + //logger.SetKV(listener) - srv := http.Server{Handler: h} - mrun.OnStart(ctx, func(mctx.Context) error { + srv.ctx = mrun.OnStart(srv.ctx, func(context.Context) error { srv.Addr = listener.Addr().String() - mrun.Thread(ctx, func() error { + srv.ctx = mrun.Thread(srv.ctx, func() error { if err := srv.Serve(listener); err != http.ErrServerClosed { - logger.Error("error serving listener", merr.KV(err)) + mlog.Error(srv.ctx, "error serving listener", merr.KV(err)) return err } return nil @@ -45,15 +57,15 @@ func MListenAndServe(ctx mctx.Context, h http.Handler) *http.Server { return nil }) - mrun.OnStop(ctx, func(innerCtx mctx.Context) error { - logger.Info("shutting down server") - if err := srv.Shutdown(context.Context(innerCtx)); err != nil { + srv.ctx = mrun.OnStop(srv.ctx, func(innerCtx context.Context) error { + mlog.Info(srv.ctx, "shutting down server") + if err := srv.Shutdown(innerCtx); err != nil { return err } - return mrun.Wait(ctx, innerCtx.Done()) + return mrun.Wait(srv.ctx, innerCtx.Done()) }) - return &srv + return mctx.WithChild(ctx, srv.ctx), srv } // AddXForwardedFor populates the X-Forwarded-For header on the Request to diff --git a/mhttp/mhttp_test.go b/mhttp/mhttp_test.go index 621ca10..1574d9b 100644 --- a/mhttp/mhttp_test.go +++ b/mhttp/mhttp_test.go @@ -15,7 +15,7 @@ import ( func TestMListenAndServe(t *T) { ctx := mtest.NewCtx() - srv := MListenAndServe(ctx, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx, srv := MListenAndServe(ctx, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { io.Copy(rw, r.Body) })) diff --git a/mlog/ctx.go b/mlog/ctx.go index e601f1d..20a9bb2 100644 --- a/mlog/ctx.go +++ b/mlog/ctx.go @@ -1,87 +1,53 @@ package mlog -import ( - "path" - - "github.com/mediocregopher/mediocre-go-lib/mctx" -) +import "context" type ctxKey int -// CtxSet caches the Logger in the Context, overwriting any previous one which -// might have been cached there. From is the corresponding function which -// retrieves the Logger back out when needed. -// -// This function can be used to premptively set a preconfigured Logger on a root -// Context so that the default (NewLogger) isn't used when From is called for -// the first time. -func CtxSet(ctx mctx.Context, l *Logger) { - mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(interface{}) interface{} { +// Set returns the Context with the Logger carried by it. +func Set(ctx context.Context, l *Logger) context.Context { + return context.WithValue(ctx, ctxKey(0), l) +} + +// DefaultLogger is an instance of Logger which is returned by From when a +// Logger hasn't been previously Set on the Context passed in. +var DefaultLogger = NewLogger() + +// From returns the Logger carried by this Context, or DefaultLogger if none is +// being carried. +func From(ctx context.Context) *Logger { + if l, _ := ctx.Value(ctxKey(0)).(*Logger); l != nil { return l - }) + } + return DefaultLogger } -// CtxSetAll traverses the given Context's children, breadth-first. It calls the -// callback for each Context which has a Logger set on it, replacing that Logger -// with the returned one. -// -// This is useful, for example, when changing the log level of all Loggers in an -// app. -func CtxSetAll(ctx mctx.Context, callback func(mctx.Context, *Logger) *Logger) { - mctx.BreadthFirstVisit(ctx, func(ctx mctx.Context) bool { - mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(i interface{}) interface{} { - if i == nil { - return nil - } - return callback(ctx, i.(*Logger)) - }) - return true - }) +// Debug is a shortcut for +// mlog.From(ctx).Debug(ctx, descr, kvs...) +func Debug(ctx context.Context, descr string, kvs ...KVer) { + From(ctx).Debug(ctx, descr, kvs...) } -type ctxPathStringer struct { - str Stringer - pathStr string +// Info is a shortcut for +// mlog.From(ctx).Info(ctx, descr, kvs...) +func Info(ctx context.Context, descr string, kvs ...KVer) { + From(ctx).Info(ctx, descr, kvs...) } -func (cp ctxPathStringer) String() string { - return "(" + cp.pathStr + ") " + cp.str.String() +// Warn is a shortcut for +// mlog.From(ctx).Warn(ctx, descr, kvs...) +func Warn(ctx context.Context, descr string, kvs ...KVer) { + From(ctx).Warn(ctx, descr, kvs...) } -// From returns an instance of Logger which has been customized for this -// Context, primarily by adding a prefix describing the Context's path to all -// Message descriptions the Logger receives. -// -// The Context caches within it the generated Logger, so a new one isn't created -// everytime. When From is first called on a Context the Logger inherits the -// Context parent's Logger. If the parent hasn't had From called on it its -// parent will be queried instead, and so on. -func From(ctx mctx.Context) *Logger { - return mctx.GetSetMutableValue(ctx, true, ctxKey(0), func(interface{}) interface{} { - ctxPath := mctx.Path(ctx) - if len(ctxPath) == 0 { - // we're at the root node and it doesn't have a Logger set, use - // the default - return NewLogger() - } - - // set up child's logger - pathStr := "/" + path.Join(ctxPath...) - - parentL := From(mctx.Parent(ctx)) - parentH := parentL.Handler() - thisL := parentL.Clone() - thisL.SetHandler(func(msg Message) error { - // if the Description is already a ctxPathStringer it can be - // assumed this Message was passed in from a child Logger. - if _, ok := msg.Description.(ctxPathStringer); !ok { - msg.Description = ctxPathStringer{ - str: msg.Description, - pathStr: pathStr, - } - } - return parentH(msg) - }) - return thisL - }).(*Logger) +// Error is a shortcut for +// mlog.From(ctx).Error(ctx, descr, kvs...) +func Error(ctx context.Context, descr string, kvs ...KVer) { + From(ctx).Error(ctx, descr, kvs...) +} + +// Fatal is a shortcut for +// mlog.From(ctx).Fatal(ctx, descr, kvs...) +func Fatal(ctx context.Context, descr string, kvs ...KVer) { + From(ctx).Fatal(ctx, descr, kvs...) } diff --git a/mlog/ctx_test.go b/mlog/ctx_test.go index 7dba374..24fd446 100644 --- a/mlog/ctx_test.go +++ b/mlog/ctx_test.go @@ -1,60 +1,67 @@ package mlog import ( + "bytes" + "context" + "strings" . "testing" "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mtest/massert" ) -func TestContextStuff(t *T) { - ctx := mctx.New() - ctx1 := mctx.ChildOf(ctx, "1") - ctx1a := mctx.ChildOf(ctx1, "a") - ctx1b := mctx.ChildOf(ctx1, "b") +func TestContextLogging(t *T) { - var descrs []string + var lines []string l := NewLogger() l.SetHandler(func(msg Message) error { - descrs = append(descrs, msg.Description.String()) + buf := new(bytes.Buffer) + if err := DefaultFormat(buf, msg); err != nil { + t.Fatal(err) + } + lines = append(lines, strings.TrimSuffix(buf.String(), "\n")) return nil }) - CtxSet(ctx, l) - From(ctx1a).Info("ctx1a") - From(ctx1).Info("ctx1") - From(ctx).Info("ctx") - From(ctx1b).Debug("ctx1b (shouldn't show up)") - From(ctx1b).Info("ctx1b") + ctx := Set(context.Background(), l) + ctx1 := mctx.NewChild(ctx, "1") + ctx1a := mctx.NewChild(ctx1, "a") + ctx1b := mctx.NewChild(ctx1, "b") + ctx1 = mctx.WithChild(ctx1, ctx1a) + ctx1 = mctx.WithChild(ctx1, ctx1b) + ctx = mctx.WithChild(ctx, ctx1) - ctx2 := mctx.ChildOf(ctx, "2") - From(ctx2).Info("ctx2") + From(ctx).Info(ctx1a, "ctx1a") + From(ctx).Info(ctx1, "ctx1") + From(ctx).Info(ctx, "ctx") + From(ctx).Debug(ctx1b, "ctx1b (shouldn't show up)") + From(ctx).Info(ctx1b, "ctx1b") + + ctx2 := mctx.NewChild(ctx, "2") + ctx = mctx.WithChild(ctx, ctx2) + From(ctx2).Info(ctx2, "ctx2") massert.Fatal(t, massert.All( - massert.Len(descrs, 5), - massert.Equal(descrs[0], "(/1/a) ctx1a"), - massert.Equal(descrs[1], "(/1) ctx1"), - massert.Equal(descrs[2], "ctx"), - massert.Equal(descrs[3], "(/1/b) ctx1b"), - massert.Equal(descrs[4], "(/2) ctx2"), + massert.Len(lines, 5), + massert.Equal(lines[0], "~ INFO -- (/1/a) ctx1a"), + massert.Equal(lines[1], "~ INFO -- (/1) ctx1"), + massert.Equal(lines[2], "~ INFO -- ctx"), + massert.Equal(lines[3], "~ INFO -- (/1/b) ctx1b"), + massert.Equal(lines[4], "~ INFO -- (/2) ctx2"), )) - // use CtxSetAll to change all MaxLevels in-place - ctx2L := From(ctx2) - CtxSetAll(ctx, func(_ mctx.Context, l *Logger) *Logger { - l.SetMaxLevel(DebugLevel) - return l - }) + // changing MaxLevel on ctx's Logger should change it for all + From(ctx).SetMaxLevel(DebugLevel) - descrs = descrs[:0] - From(ctx).Info("ctx") - From(ctx).Debug("ctx debug") - ctx2L.Debug("ctx2L debug") + lines = lines[:0] + From(ctx).Info(ctx, "ctx") + From(ctx).Debug(ctx, "ctx debug") + From(ctx2).Debug(ctx2, "ctx2 debug") massert.Fatal(t, massert.All( - massert.Len(descrs, 3), - massert.Equal(descrs[0], "ctx"), - massert.Equal(descrs[1], "ctx debug"), - massert.Equal(descrs[2], "(/2) ctx2L debug"), + massert.Len(lines, 3), + massert.Equal(lines[0], "~ INFO -- ctx"), + massert.Equal(lines[1], "~ DEBUG -- ctx debug"), + massert.Equal(lines[2], "~ DEBUG -- (/2) ctx2 debug"), )) } diff --git a/mlog/mlog.go b/mlog/mlog.go index 54caf22..cac1f01 100644 --- a/mlog/mlog.go +++ b/mlog/mlog.go @@ -17,6 +17,7 @@ package mlog import ( "bufio" + "context" "fmt" "io" "os" @@ -25,6 +26,7 @@ import ( "strings" "sync" + "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/merr" ) @@ -204,23 +206,12 @@ func Prefix(kv KVer, prefix string) KVer { //////////////////////////////////////////////////////////////////////////////// -// Stringer generates and returns a string. -type Stringer interface { - String() string -} - -// String is simply a string which implements Stringer. -type String string - -func (str String) String() string { - return string(str) -} - // Message describes a message to be logged, after having already resolved the // KVer type Message struct { + context.Context Level - Description Stringer + Description string KVer } @@ -253,7 +244,11 @@ func DefaultFormat(w io.Writer, msg Message) error { _, err = fmt.Fprintf(w, s, args...) } } - write("~ %s -- %s", msg.Level.String(), msg.Description.String()) + write("~ %s -- ", msg.Level.String()) + if path := mctx.Path(msg.Context); len(path) > 0 { + write("(%s) ", "/"+strings.Join(path, "/")) + } + write("%s", msg.Description) if msg.KVer != nil { if kv := msg.KV(); len(kv) > 0 { write(" --") @@ -363,7 +358,7 @@ func (l *Logger) Log(msg Message) { } if err := l.h(msg); err != nil { - go l.Error("Logger.Handler returned error", merr.KV(err)) + go l.Error(context.Background(), "Logger.Handler returned error", merr.KV(err)) return } @@ -376,36 +371,37 @@ func (l *Logger) Log(msg Message) { } } -func mkMsg(lvl Level, descr string, kvs ...KVer) Message { +func mkMsg(ctx context.Context, lvl Level, descr string, kvs ...KVer) Message { return Message{ + Context: ctx, Level: lvl, - Description: String(descr), + Description: descr, KVer: Merge(kvs...), } } // Debug logs a DebugLevel message, merging the KVers together first -func (l *Logger) Debug(descr string, kvs ...KVer) { - l.Log(mkMsg(DebugLevel, descr, kvs...)) +func (l *Logger) Debug(ctx context.Context, descr string, kvs ...KVer) { + l.Log(mkMsg(ctx, DebugLevel, descr, kvs...)) } // Info logs a InfoLevel message, merging the KVers together first -func (l *Logger) Info(descr string, kvs ...KVer) { - l.Log(mkMsg(InfoLevel, descr, kvs...)) +func (l *Logger) Info(ctx context.Context, descr string, kvs ...KVer) { + l.Log(mkMsg(ctx, InfoLevel, descr, kvs...)) } // Warn logs a WarnLevel message, merging the KVers together first -func (l *Logger) Warn(descr string, kvs ...KVer) { - l.Log(mkMsg(WarnLevel, descr, kvs...)) +func (l *Logger) Warn(ctx context.Context, descr string, kvs ...KVer) { + l.Log(mkMsg(ctx, WarnLevel, descr, kvs...)) } // Error logs a ErrorLevel message, merging the KVers together first -func (l *Logger) Error(descr string, kvs ...KVer) { - l.Log(mkMsg(ErrorLevel, descr, kvs...)) +func (l *Logger) Error(ctx context.Context, descr string, kvs ...KVer) { + l.Log(mkMsg(ctx, ErrorLevel, descr, kvs...)) } // Fatal logs a FatalLevel message, merging the KVers together first. A Fatal // message automatically stops the process with an os.Exit(1) -func (l *Logger) Fatal(descr string, kvs ...KVer) { - l.Log(mkMsg(FatalLevel, descr, kvs...)) +func (l *Logger) Fatal(ctx context.Context, descr string, kvs ...KVer) { + l.Log(mkMsg(ctx, FatalLevel, descr, kvs...)) } diff --git a/mlog/mlog_test.go b/mlog/mlog_test.go index 0c5c0ec..4065de4 100644 --- a/mlog/mlog_test.go +++ b/mlog/mlog_test.go @@ -2,6 +2,7 @@ package mlog import ( "bytes" + "context" "regexp" "strings" . "testing" @@ -36,6 +37,7 @@ func TestKV(t *T) { } func TestLogger(t *T) { + ctx := context.Background() buf := new(bytes.Buffer) h := func(msg Message) error { return DefaultFormat(buf, msg) @@ -59,10 +61,10 @@ func TestLogger(t *T) { } // Default max level should be INFO - l.Debug("foo") - l.Info("bar") - l.Warn("baz") - l.Error("buz") + l.Debug(ctx, "foo") + l.Info(ctx, "bar") + l.Warn(ctx, "baz") + l.Error(ctx, "buz") massert.Fatal(t, massert.All( assertOut("~ INFO -- bar\n"), assertOut("~ WARN -- baz\n"), @@ -70,10 +72,10 @@ func TestLogger(t *T) { )) l.SetMaxLevel(WarnLevel) - l.Debug("foo") - l.Info("bar") - l.Warn("baz") - l.Error("buz", KV{"a": "b"}) + l.Debug(ctx, "foo") + l.Info(ctx, "bar") + l.Warn(ctx, "baz") + l.Error(ctx, "buz", KV{"a": "b"}) massert.Fatal(t, massert.All( assertOut("~ WARN -- baz\n"), assertOut("~ ERROR -- buz -- a=\"b\"\n"), @@ -82,12 +84,12 @@ func TestLogger(t *T) { l2 := l.Clone() l2.SetMaxLevel(InfoLevel) l2.SetHandler(func(msg Message) error { - msg.Description = String(strings.ToUpper(msg.Description.String())) + msg.Description = strings.ToUpper(msg.Description) return h(msg) }) - l2.Info("bar") - l2.Warn("baz") - l.Error("buz") + l2.Info(ctx, "bar") + l2.Warn(ctx, "baz") + l.Error(ctx, "buz") massert.Fatal(t, massert.All( assertOut("~ INFO -- BAR\n"), assertOut("~ WARN -- BAZ\n"), @@ -96,8 +98,8 @@ func TestLogger(t *T) { l3 := l2.Clone() l3.SetKV(KV{"a": 1}) - l3.Info("foo", KV{"b": 2}) - l3.Info("bar", KV{"a": 2, "b": 3}) + l3.Info(ctx, "foo", KV{"b": 2}) + l3.Info(ctx, "bar", KV{"a": 2, "b": 3}) massert.Fatal(t, massert.All( assertOut("~ INFO -- FOO -- a=\"1\" b=\"2\"\n"), assertOut("~ INFO -- BAR -- a=\"2\" b=\"3\"\n"), @@ -121,7 +123,11 @@ func TestDefaultFormat(t *T) { ) } - msg := Message{Level: InfoLevel, Description: String("this is a test")} + msg := Message{ + Context: context.Background(), + Level: InfoLevel, + Description: "this is a test", + } massert.Fatal(t, assertFormat("INFO -- this is a test", msg)) msg.KVer = KV{} diff --git a/mnet/mnet.go b/mnet/mnet.go index 9cb773d..bb0aa05 100644 --- a/mnet/mnet.go +++ b/mnet/mnet.go @@ -3,11 +3,11 @@ package mnet import ( + "context" "net" "strings" "github.com/mediocregopher/mediocre-go-lib/mcfg" - "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/merr" "github.com/mediocregopher/mediocre-go-lib/mlog" "github.com/mediocregopher/mediocre-go-lib/mrun" @@ -16,7 +16,7 @@ import ( // MListener is returned by MListen and simply wraps a net.Listener. type MListener struct { net.Listener - log *mlog.Logger + ctx context.Context // If set to true before mrun's stop event is run, the stop event will not // cause the MListener to be closed. @@ -24,44 +24,47 @@ type MListener struct { } // MListen returns an MListener which will be initialized when the start event -// is triggered on ctx (see mrun.Start), and closed when the stop event is -// triggered on ctx (see mrun.Stop). +// is triggered on the returned Context (see mrun.Start), and closed when the +// stop event is triggered on the returned Context (see mrun.Stop). // // network defaults to "tcp" if empty. defaultAddr defaults to ":0" if empty, // and will be configurable via mcfg. -func MListen(ctx mctx.Context, network, defaultAddr string) *MListener { +func MListen(ctx context.Context, network, defaultAddr string) (context.Context, *MListener) { if network == "" { network = "tcp" } if defaultAddr == "" { defaultAddr = ":0" } - addr := mcfg.String(ctx, "listen-addr", defaultAddr, strings.ToUpper(network)+" address to listen on in format [host]:port. If port is 0 then a random one will be chosen") l := new(MListener) - l.log = mlog.From(ctx) - l.log.SetKV(l) - mrun.OnStart(ctx, func(mctx.Context) error { + // TODO the equivalent functionality as here will be added with annotations + //l.log = mlog.From(ctx) + //l.log.SetKV(l) + + ctx, addr := mcfg.String(ctx, "listen-addr", defaultAddr, strings.ToUpper(network)+" address to listen on in format [host]:port. If port is 0 then a random one will be chosen") + ctx = mrun.OnStart(ctx, func(context.Context) error { var err error if l.Listener, err = net.Listen(network, *addr); err != nil { return err } - l.log.Info("listening") + mlog.Info(l.ctx, "listening") return nil }) // TODO track connections and wait for them to complete before shutting // down? - mrun.OnStop(ctx, func(mctx.Context) error { + ctx = mrun.OnStop(ctx, func(context.Context) error { if l.NoCloseOnStop { return nil } - l.log.Info("stopping listener") + mlog.Info(l.ctx, "stopping listener") return l.Close() }) - return l + l.ctx = ctx + return ctx, l } // Accept wraps a call to Accept on the underlying net.Listener, providing debug @@ -71,16 +74,16 @@ func (l *MListener) Accept() (net.Conn, error) { if err != nil { return conn, err } - l.log.Debug("connection accepted", mlog.KV{"remoteAddr": conn.RemoteAddr()}) + mlog.Debug(l.ctx, "connection accepted", mlog.KV{"remoteAddr": conn.RemoteAddr()}) return conn, nil } // Close wraps a call to Close on the underlying net.Listener, providing debug // logging. func (l *MListener) Close() error { - l.log.Debug("listener closing") + mlog.Debug(l.ctx, "listener closing") err := l.Listener.Close() - l.log.Debug("listener closed", merr.KV(err)) + mlog.Debug(l.ctx, "listener closed", merr.KV(err)) return err } diff --git a/mnet/mnet_test.go b/mnet/mnet_test.go index 5c7ba58..b9b9b7e 100644 --- a/mnet/mnet_test.go +++ b/mnet/mnet_test.go @@ -37,7 +37,7 @@ func TestIsReservedIP(t *T) { func TestMListen(t *T) { ctx := mtest.NewCtx() - l := MListen(ctx, "", "") + ctx, l := MListen(ctx, "", "") mtest.Run(ctx, t, func() { go func() { conn, err := net.Dial("tcp", l.Addr().String()) diff --git a/mrun/hook.go b/mrun/hook.go index 5991c7c..2087c4c 100644 --- a/mrun/hook.go +++ b/mrun/hook.go @@ -1,87 +1,143 @@ package mrun -import "github.com/mediocregopher/mediocre-go-lib/mctx" +import ( + "context" -type ctxEventKeyWrap struct { - key interface{} -} + "github.com/mediocregopher/mediocre-go-lib/mctx" +) // Hook describes a function which can be registered to trigger on an event via // the RegisterHook function. -type Hook func(mctx.Context) error +type Hook func(context.Context) error + +type ctxKey int + +const ( + ctxKeyHookEls ctxKey = iota + ctxKeyNumChildren + ctxKeyNumHooks +) + +type ctxKeyWrap struct { + key ctxKey + userKey interface{} +} + +// because we want Hooks to be called in the order created, taking into account +// the creation of children and their hooks as well, we create a sequence of +// elements which can either be a Hook or a child. +type hookEl struct { + hook Hook + child context.Context +} + +func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap, ctxKeyWrap) { + return ctxKeyWrap{ + key: ctxKeyHookEls, + userKey: userKey, + }, ctxKeyWrap{ + key: ctxKeyNumChildren, + userKey: userKey, + }, ctxKeyWrap{ + key: ctxKeyNumHooks, + userKey: userKey, + } +} + +// getHookEls retrieves a copy of the []hookEl in the Context and possibly +// appends more elements if more children have been added since that []hookEl +// was created. +// +// this also returns the latest numChildren and numHooks values for convenience. +func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int, int) { + hookElsKey, numChildrenKey, numHooksKey := ctxKeys(userKey) + lastNumChildren, _ := mctx.LocalValue(ctx, numChildrenKey).(int) + lastNumHooks, _ := mctx.LocalValue(ctx, numHooksKey).(int) + lastHookEls, _ := mctx.LocalValue(ctx, hookElsKey).([]hookEl) + children := mctx.Children(ctx) + + // plus 1 in case we wanna append something else outside this function + hookEls := make([]hookEl, len(lastHookEls), lastNumHooks+len(children)-lastNumChildren+1) + copy(hookEls, lastHookEls) + for _, child := range children[lastNumChildren:] { + hookEls = append(hookEls, hookEl{child: child}) + } + return hookEls, len(children), lastNumHooks +} // RegisterHook registers a Hook under a typed key. The Hook will be called when // TriggerHooks is called with that same key. Multiple Hooks can be registered // for the same key, and will be called sequentially when triggered. // -// RegisterHook registers Hooks onto the root of the given Context. Therefore, -// Hooks will be triggered in the global order they were registered. For -// example: if one Hook is registered on a Context, then one is registered on a -// child of that Context, then another one is registered on the original Context -// again, the three Hooks will be triggered in the order: parent, child, -// parent. -// // Hooks will be called with whatever Context is passed into TriggerHooks. -func RegisterHook(ctx mctx.Context, key interface{}, hook Hook) { - ctx = mctx.Root(ctx) - mctx.GetSetMutableValue(ctx, false, ctxEventKeyWrap{key}, func(v interface{}) interface{} { - hooks, _ := v.([]Hook) - return append(hooks, hook) - }) +func RegisterHook(ctx context.Context, key interface{}, hook Hook) context.Context { + hookEls, numChildren, numHooks := getHookEls(ctx, key) + hookEls = append(hookEls, hookEl{hook: hook}) + + hookElsKey, numChildrenKey, numHooksKey := ctxKeys(key) + ctx = mctx.WithLocalValue(ctx, hookElsKey, hookEls) + ctx = mctx.WithLocalValue(ctx, numChildrenKey, numChildren) + ctx = mctx.WithLocalValue(ctx, numHooksKey, numHooks+1) + return ctx } -func triggerHooks(ctx mctx.Context, key interface{}, next func([]Hook) (Hook, []Hook)) error { - rootCtx := mctx.Root(ctx) - var err error - mctx.GetSetMutableValue(rootCtx, false, ctxEventKeyWrap{key}, func(i interface{}) interface{} { - var hook Hook - hooks, _ := i.([]Hook) - for { - if len(hooks) == 0 { - break - } - hook, hooks = next(hooks) - - // err here is the var outside GetSetMutableValue, we lift it out - if err = hook(ctx); err != nil { - break - } +func triggerHooks(ctx context.Context, userKey interface{}, next func([]hookEl) (hookEl, []hookEl)) error { + hookEls, _, _ := getHookEls(ctx, userKey) + var hookEl hookEl + for { + if len(hookEls) == 0 { + break } - - // if there was an error then we want to keep all the hooks which - // weren't called. If there wasn't we want to reset the value to nil so - // the slice doesn't grow unbounded. - if err != nil { - return hooks + hookEl, hookEls = next(hookEls) + if hookEl.child != nil { + if err := triggerHooks(hookEl.child, userKey, next); err != nil { + return err + } + } else if err := hookEl.hook(ctx); err != nil { + return err } - return nil - }) - return err + } + return nil } -// TriggerHooks causes all Hooks registered with RegisterHook under the given -// key to be called in the global order they were registered, using the given -// Context as their input parameter. The given Context does not need to be the -// root Context (see RegisterHook). +// TriggerHooks causes all Hooks registered with RegisterHook on the Context +// (and its predecessors) under the given key to be called in the order they +// were registered. // // If any Hook returns an error no further Hooks will be called and that error // will be returned. // -// TriggerHooks causes all Hooks which were called to be de-registered. If an -// error caused execution to stop prematurely then any Hooks which were not -// called will remain registered. -func TriggerHooks(ctx mctx.Context, key interface{}) error { - return triggerHooks(ctx, key, func(hooks []Hook) (Hook, []Hook) { - return hooks[0], hooks[1:] +// If the Context has children (see the mctx package), and those children have +// Hooks registered under this key, then their Hooks will be called in the +// expected order. For example: +// +// // parent context has hookA registered +// ctx := context.Background() +// ctx = RegisterHook(ctx, 0, hookA) +// +// // child context has hookB registered +// childCtx := mctx.NewChild(ctx, "child") +// childCtx = RegisterHook(childCtx, 0, hookB) +// ctx = mctx.WithChild(ctx, childCtx) // needed to link childCtx to ctx +// +// // parent context has another Hook, hookC, registered +// ctx = RegisterHook(ctx, 0, hookC) +// +// // The Hooks will be triggered in the order: hookA, hookB, then hookC +// err := TriggerHooks(ctx, 0) +// +func TriggerHooks(ctx context.Context, key interface{}) error { + return triggerHooks(ctx, key, func(hookEls []hookEl) (hookEl, []hookEl) { + return hookEls[0], hookEls[1:] }) } // TriggerHooksReverse is the same as TriggerHooks except that registered Hooks // are called in the reverse order in which they were registered. -func TriggerHooksReverse(ctx mctx.Context, key interface{}) error { - return triggerHooks(ctx, key, func(hooks []Hook) (Hook, []Hook) { - last := len(hooks) - 1 - return hooks[last], hooks[:last] +func TriggerHooksReverse(ctx context.Context, key interface{}) error { + return triggerHooks(ctx, key, func(hookEls []hookEl) (hookEl, []hookEl) { + last := len(hookEls) - 1 + return hookEls[last], hookEls[:last] }) } @@ -101,24 +157,24 @@ const ( // server) will want to use the Hook only to initialize, and spawn off a // go-routine to do their actual work. Long-lived tasks should set themselves up // to stop on the stop event (see OnStop). -func OnStart(ctx mctx.Context, hook Hook) { - RegisterHook(ctx, start, hook) +func OnStart(ctx context.Context, hook Hook) context.Context { + return RegisterHook(ctx, start, hook) } // Start runs all Hooks registered using OnStart. This is a special case of // TriggerHooks. -func Start(ctx mctx.Context) error { +func Start(ctx context.Context) error { return TriggerHooks(ctx, start) } // OnStop registers the given Hook to run when Stop is called. This is a special // case of RegisterHook. -func OnStop(ctx mctx.Context, hook Hook) { - RegisterHook(ctx, stop, hook) +func OnStop(ctx context.Context, hook Hook) context.Context { + return RegisterHook(ctx, stop, hook) } // Stop runs all Hooks registered using OnStop in the reverse order in which // they were registered. This is a special case of TriggerHooks. -func Stop(ctx mctx.Context) error { +func Stop(ctx context.Context) error { return TriggerHooksReverse(ctx, stop) } diff --git a/mrun/hook_test.go b/mrun/hook_test.go index ca01a09..c679b7f 100644 --- a/mrun/hook_test.go +++ b/mrun/hook_test.go @@ -1,7 +1,7 @@ package mrun import ( - "errors" + "context" . "testing" "github.com/mediocregopher/mediocre-go-lib/mctx" @@ -9,46 +9,40 @@ import ( ) func TestHooks(t *T) { - ch := make(chan int, 10) - ctx := mctx.New() - ctxChild := mctx.ChildOf(ctx, "child") - + var out []int mkHook := func(i int) Hook { - return func(mctx.Context) error { - ch <- i + return func(context.Context) error { + out = append(out, i) return nil } } - RegisterHook(ctx, 0, mkHook(0)) - RegisterHook(ctxChild, 0, mkHook(1)) - RegisterHook(ctx, 0, mkHook(2)) + ctx := context.Background() + ctx = RegisterHook(ctx, 0, mkHook(1)) + ctx = RegisterHook(ctx, 0, mkHook(2)) - bogusErr := errors.New("bogus error") - RegisterHook(ctxChild, 0, func(mctx.Context) error { return bogusErr }) + ctxA := mctx.NewChild(ctx, "a") + ctxA = RegisterHook(ctxA, 0, mkHook(3)) + ctxA = RegisterHook(ctxA, 999, mkHook(999)) // different key + ctx = mctx.WithChild(ctx, ctxA) - RegisterHook(ctx, 0, mkHook(3)) - RegisterHook(ctx, 0, mkHook(4)) + ctx = RegisterHook(ctx, 0, mkHook(4)) - massert.Fatal(t, massert.All( - massert.Equal(bogusErr, TriggerHooks(ctx, 0)), - massert.Equal(0, <-ch), - massert.Equal(1, <-ch), - massert.Equal(2, <-ch), - )) - - // after the error the 3 and 4 Hooks should still be registered, but not - // called yet. - - select { - case <-ch: - t.Fatal("Hooks should not have been called yet") - default: - } + ctxB := mctx.NewChild(ctx, "b") + ctxB = RegisterHook(ctxB, 0, mkHook(5)) + ctxB1 := mctx.NewChild(ctxB, "1") + ctxB1 = RegisterHook(ctxB1, 0, mkHook(6)) + ctxB = mctx.WithChild(ctxB, ctxB1) + ctx = mctx.WithChild(ctx, ctxB) massert.Fatal(t, massert.All( massert.Nil(TriggerHooks(ctx, 0)), - massert.Equal(3, <-ch), - massert.Equal(4, <-ch), + massert.Equal([]int{1, 2, 3, 4, 5, 6}, out), + )) + + out = nil + massert.Fatal(t, massert.All( + massert.Nil(TriggerHooksReverse(ctx, 0)), + massert.Equal([]int{6, 5, 4, 3, 2, 1}, out), )) } diff --git a/mrun/mrun.go b/mrun/mrun.go index 6f60e3b..feb0ec1 100644 --- a/mrun/mrun.go +++ b/mrun/mrun.go @@ -3,6 +3,7 @@ package mrun import ( + "context" "errors" "github.com/mediocregopher/mediocre-go-lib/mctx" @@ -33,26 +34,24 @@ func (fe *futureErr) set(err error) { close(fe.doneCh) } -type ctxKey int +type threadCtxKey int -// Thread spawns a go-routine which executes the given function. When the passed -// in Context is canceled the Context within all threads spawned from it will -// be canceled as well. -// -// See Wait for accompanying functionality. -func Thread(ctx mctx.Context, fn func() error) { +// Thread spawns a go-routine which executes the given function. The returned +// Context tracks this go-routine, which can then be passed into the Wait +// function to block until the spawned go-routine returns. +func Thread(ctx context.Context, fn func() error) context.Context { futErr := newFutureErr() - mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(i interface{}) interface{} { - futErrs, ok := i.([]*futureErr) - if !ok { - futErrs = make([]*futureErr, 0, 1) - } - return append(futErrs, futErr) - }) + oldFutErrs, _ := ctx.Value(threadCtxKey(0)).([]*futureErr) + futErrs := make([]*futureErr, len(oldFutErrs), len(oldFutErrs)+1) + copy(futErrs, oldFutErrs) + futErrs = append(futErrs, futErr) + ctx = context.WithValue(ctx, threadCtxKey(0), futErrs) go func() { futErr.set(fn()) }() + + return ctx } // ErrDone is returned from Wait if cancelCh is closed before all threads have @@ -60,7 +59,7 @@ func Thread(ctx mctx.Context, fn func() error) { var ErrDone = errors.New("Wait is done waiting") // Wait blocks until all go-routines spawned using Thread on the passed in -// Context, and all of its children, have returned. Any number of the threads +// Context (and its predecessors) have returned. Any number of the go-routines // may have returned already when Wait is called. // // If any of the thread functions returned an error during its runtime Wait will @@ -71,10 +70,8 @@ var ErrDone = errors.New("Wait is done waiting") // this function stops waiting and returns ErrDone. // // Wait is safe to call in parallel, and will return the same result if called -// multiple times in sequence. If new Thread calls have been made since the last -// Wait call, the results of those calls will be waited upon during subsequent -// Wait calls. -func Wait(ctx mctx.Context, cancelCh <-chan struct{}) error { +// multiple times in sequence. +func Wait(ctx context.Context, cancelCh <-chan struct{}) error { // First wait for all the children, and see if any of them return an error children := mctx.Children(ctx) for _, childCtx := range children { @@ -83,7 +80,7 @@ func Wait(ctx mctx.Context, cancelCh <-chan struct{}) error { } } - futErrs, _ := mctx.MutableValue(ctx, ctxKey(0)).([]*futureErr) + futErrs, _ := ctx.Value(threadCtxKey(0)).([]*futureErr) for _, futErr := range futErrs { err, ok := futErr.get(cancelCh) if !ok { diff --git a/mrun/mrun_test.go b/mrun/mrun_test.go index 372089b..2d4102d 100644 --- a/mrun/mrun_test.go +++ b/mrun/mrun_test.go @@ -1,6 +1,7 @@ package mrun import ( + "context" "errors" . "testing" "time" @@ -12,11 +13,11 @@ func TestThreadWait(t *T) { testErr := errors.New("test error") cancelCh := func(t time.Duration) <-chan struct{} { - tCtx, _ := mctx.WithTimeout(mctx.New(), t*2) + tCtx, _ := context.WithTimeout(context.Background(), t*2) return tCtx.Done() } - wait := func(ctx mctx.Context, shouldTake time.Duration) error { + wait := func(ctx context.Context, shouldTake time.Duration) error { start := time.Now() err := Wait(ctx, cancelCh(shouldTake*2)) if took := time.Since(start); took < shouldTake || took > shouldTake*4/3 { @@ -28,16 +29,16 @@ func TestThreadWait(t *T) { t.Run("noChildren", func(t *T) { t.Run("noBlock", func(t *T) { t.Run("noErr", func(t *T) { - ctx := mctx.New() - Thread(ctx, func() error { return nil }) + ctx := context.Background() + ctx = Thread(ctx, func() error { return nil }) if err := Wait(ctx, nil); err != nil { t.Fatal(err) } }) t.Run("err", func(t *T) { - ctx := mctx.New() - Thread(ctx, func() error { return testErr }) + ctx := context.Background() + ctx = Thread(ctx, func() error { return testErr }) if err := Wait(ctx, nil); err != testErr { t.Fatalf("should have got test error, got: %v", err) } @@ -46,8 +47,8 @@ func TestThreadWait(t *T) { t.Run("block", func(t *T) { t.Run("noErr", func(t *T) { - ctx := mctx.New() - Thread(ctx, func() error { + ctx := context.Background() + ctx = Thread(ctx, func() error { time.Sleep(1 * time.Second) return nil }) @@ -57,8 +58,8 @@ func TestThreadWait(t *T) { }) t.Run("err", func(t *T) { - ctx := mctx.New() - Thread(ctx, func() error { + ctx := context.Background() + ctx = Thread(ctx, func() error { time.Sleep(1 * time.Second) return testErr }) @@ -68,8 +69,8 @@ func TestThreadWait(t *T) { }) t.Run("canceled", func(t *T) { - ctx := mctx.New() - Thread(ctx, func() error { + ctx := context.Background() + ctx = Thread(ctx, func() error { time.Sleep(5 * time.Second) return testErr }) @@ -80,16 +81,17 @@ func TestThreadWait(t *T) { }) }) - ctxWithChild := func() (mctx.Context, mctx.Context) { - ctx := mctx.New() - return ctx, mctx.ChildOf(ctx, "child") + ctxWithChild := func() (context.Context, context.Context) { + ctx := context.Background() + return ctx, mctx.NewChild(ctx, "child") } t.Run("children", func(t *T) { t.Run("noBlock", func(t *T) { t.Run("noErr", func(t *T) { ctx, childCtx := ctxWithChild() - Thread(childCtx, func() error { return nil }) + childCtx = Thread(childCtx, func() error { return nil }) + ctx = mctx.WithChild(ctx, childCtx) if err := Wait(ctx, nil); err != nil { t.Fatal(err) } @@ -97,7 +99,8 @@ func TestThreadWait(t *T) { t.Run("err", func(t *T) { ctx, childCtx := ctxWithChild() - Thread(childCtx, func() error { return testErr }) + childCtx = Thread(childCtx, func() error { return testErr }) + ctx = mctx.WithChild(ctx, childCtx) if err := Wait(ctx, nil); err != testErr { t.Fatalf("should have got test error, got: %v", err) } @@ -107,10 +110,11 @@ func TestThreadWait(t *T) { t.Run("block", func(t *T) { t.Run("noErr", func(t *T) { ctx, childCtx := ctxWithChild() - Thread(childCtx, func() error { + childCtx = Thread(childCtx, func() error { time.Sleep(1 * time.Second) return nil }) + ctx = mctx.WithChild(ctx, childCtx) if err := wait(ctx, 1*time.Second); err != nil { t.Fatal(err) } @@ -118,10 +122,11 @@ func TestThreadWait(t *T) { t.Run("err", func(t *T) { ctx, childCtx := ctxWithChild() - Thread(childCtx, func() error { + childCtx = Thread(childCtx, func() error { time.Sleep(1 * time.Second) return testErr }) + ctx = mctx.WithChild(ctx, childCtx) if err := wait(ctx, 1*time.Second); err != testErr { t.Fatalf("should have got test error, got: %v", err) } @@ -129,10 +134,11 @@ func TestThreadWait(t *T) { t.Run("canceled", func(t *T) { ctx, childCtx := ctxWithChild() - Thread(childCtx, func() error { + childCtx = Thread(childCtx, func() error { time.Sleep(5 * time.Second) return testErr }) + ctx = mctx.WithChild(ctx, childCtx) if err := Wait(ctx, cancelCh(500*time.Millisecond)); err != ErrDone { t.Fatalf("should have got ErrDone, got: %v", err) } diff --git a/mtest/mtest.go b/mtest/mtest.go index 8f20c76..4543b28 100644 --- a/mtest/mtest.go +++ b/mtest/mtest.go @@ -2,35 +2,33 @@ package mtest import ( + "context" "testing" "github.com/mediocregopher/mediocre-go-lib/mcfg" - "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mlog" "github.com/mediocregopher/mediocre-go-lib/mrun" ) -type ctxKey int +type envCtxKey int // NewCtx creates and returns a root Context suitable for testing. -func NewCtx() mctx.Context { - ctx := mctx.New() - mlog.From(ctx).SetMaxLevel(mlog.DebugLevel) - return ctx +func NewCtx() context.Context { + ctx := context.Background() + logger := mlog.NewLogger() + logger.SetMaxLevel(mlog.DebugLevel) + return mlog.Set(ctx, logger) } // SetEnv sets the given environment variable on the given Context, such that it // will be used as if it was a real environment variable when the Run function // from this package is called. -func SetEnv(ctx mctx.Context, key, val string) { - mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(i interface{}) interface{} { - m, _ := i.(map[string]string) - if m == nil { - m = map[string]string{} - } - m[key] = val - return m - }) +func SetEnv(ctx context.Context, key, val string) context.Context { + prevEnv, _ := ctx.Value(envCtxKey(0)).([][2]string) + env := make([][2]string, len(prevEnv), len(prevEnv)+1) + copy(env, prevEnv) + env = append(env, [2]string{key, val}) + return context.WithValue(ctx, envCtxKey(0), env) } // Run performs the following using the given Context: @@ -45,11 +43,11 @@ func SetEnv(ctx mctx.Context, key, val string) { // // The intention is that Run is used within a test on a Context created via // NewCtx, after any setup functions have been called (e.g. mnet.MListen). -func Run(ctx mctx.Context, t *testing.T, body func()) { - envMap, _ := mctx.MutableValue(ctx, ctxKey(0)).(map[string]string) - env := make([]string, 0, len(envMap)) - for key, val := range envMap { - env = append(env, key+"="+val) +func Run(ctx context.Context, t *testing.T, body func()) { + envTups, _ := ctx.Value(envCtxKey(0)).([][2]string) + env := make([]string, 0, len(envTups)) + for _, tup := range envTups { + env = append(env, tup[0]+"="+tup[1]) } if err := mcfg.Populate(ctx, mcfg.SourceEnv{Env: env}); err != nil { diff --git a/mtest/mtest_test.go b/mtest/mtest_test.go index 26bc111..473ecdc 100644 --- a/mtest/mtest_test.go +++ b/mtest/mtest_test.go @@ -8,8 +8,8 @@ import ( func TestRun(t *T) { ctx := NewCtx() - arg := mcfg.RequiredString(ctx, "arg", "Required by this test") - SetEnv(ctx, "ARG", "foo") + ctx, arg := mcfg.RequiredString(ctx, "arg", "Required by this test") + ctx = SetEnv(ctx, "ARG", "foo") Run(ctx, t, func() { if *arg != "foo" { t.Fatalf(`arg not set to "foo", is set to %q`, *arg)