mctx: refactor so that contexts no longer carry mutable data

This change required refactoring nearly every package in this project,
but it does a lot to simplify mctx and make other code using it easier
to think about.

Other code, such as mlog and mcfg, had to be slightly modified for this
change to work as well.
This commit is contained in:
Brian Picciano 2019-02-05 15:18:17 -05:00
parent 0c2c49501e
commit 4b446a0efc
34 changed files with 969 additions and 892 deletions

View File

@ -12,6 +12,7 @@ package main
*/
import (
"context"
"net/http"
"net/url"
"time"
@ -19,7 +20,6 @@ import (
"github.com/mediocregopher/mediocre-go-lib/m"
"github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mcrypto"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mhttp"
"github.com/mediocregopher/mediocre-go-lib/mlog"
"github.com/mediocregopher/mediocre-go-lib/mrand"
@ -30,27 +30,26 @@ import (
func main() {
ctx := m.NewServiceCtx()
logger := mlog.From(ctx)
cookieName := mcfg.String(ctx, "cookie-name", "_totp_proxy", "String to use as the name for cookies")
cookieTimeout := mcfg.Duration(ctx, "cookie-timeout", mtime.Duration{1 * time.Hour}, "Timeout for cookies")
ctx, cookieName := mcfg.String(ctx, "cookie-name", "_totp_proxy", "String to use as the name for cookies")
ctx, cookieTimeout := mcfg.Duration(ctx, "cookie-timeout", mtime.Duration{1 * time.Hour}, "Timeout for cookies")
var userSecrets map[string]string
mcfg.RequiredJSON(ctx, "users", &userSecrets, "JSON object which maps usernames to their TOTP secret strings")
ctx = mcfg.RequiredJSON(ctx, "users", &userSecrets, "JSON object which maps usernames to their TOTP secret strings")
var secret mcrypto.Secret
secretStr := mcfg.String(ctx, "secret", "", "String used to sign authentication tokens. If one isn't given a new one will be generated on each startup, invalidating all previous tokens.")
mrun.OnStart(ctx, func(mctx.Context) error {
ctx, secretStr := mcfg.String(ctx, "secret", "", "String used to sign authentication tokens. If one isn't given a new one will be generated on each startup, invalidating all previous tokens.")
ctx = mrun.OnStart(ctx, func(context.Context) error {
if *secretStr == "" {
*secretStr = mrand.Hex(32)
}
logger.Info("generating secret")
mlog.Info(ctx, "generating secret")
secret = mcrypto.NewSecret([]byte(*secretStr))
return nil
})
proxyHandler := new(struct{ http.Handler })
proxyURL := mcfg.RequiredString(ctx, "dst-url", "URL to proxy requests to. Only the scheme and host should be set.")
mrun.OnStart(ctx, func(mctx.Context) error {
ctx, proxyURL := mcfg.RequiredString(ctx, "dst-url", "URL to proxy requests to. Only the scheme and host should be set.")
ctx = mrun.OnStart(ctx, func(context.Context) error {
u, err := url.Parse(*proxyURL)
if err != nil {
return err
@ -61,8 +60,8 @@ func main() {
authHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// TODO mlog.FromHTTP?
authLogger := logger.Clone()
authLogger.SetKV(mlog.CtxKV(r.Context()))
// TODO annotate this ctx
ctx := r.Context()
unauthorized := func() {
w.Header().Add("WWW-Authenticate", "Basic")
@ -80,7 +79,7 @@ func main() {
}
if cookie, _ := r.Cookie(*cookieName); cookie != nil {
authLogger.Debug("authenticating with cookie", mlog.KV{"cookie": cookie.String()})
mlog.Debug(ctx, "authenticating with cookie", mlog.KV{"cookie": cookie.String()})
var sig mcrypto.Signature
if err := sig.UnmarshalText([]byte(cookie.Value)); err == nil {
err := mcrypto.VerifyString(secret, sig, "")
@ -92,7 +91,7 @@ func main() {
}
if user, pass, ok := r.BasicAuth(); ok && pass != "" {
logger.Debug("authenticating with user/pass", mlog.KV{
mlog.Debug(ctx, "authenticating with user/pass", mlog.KV{
"user": user,
"pass": pass,
})
@ -107,6 +106,6 @@ func main() {
unauthorized()
})
mhttp.MListenAndServe(ctx, authHandler)
ctx, _ = mhttp.MListenAndServe(ctx, authHandler)
m.Run(ctx)
}

29
m/m.go
View File

@ -5,11 +5,11 @@
package m
import (
"context"
"os"
"os/signal"
"github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/merr"
"github.com/mediocregopher/mediocre-go-lib/mlog"
"github.com/mediocregopher/mediocre-go-lib/mrun"
@ -32,20 +32,19 @@ func CfgSource() mcfg.Source {
// debug information about the running process can be accessed.
//
// TODO set up the debug endpoint.
func NewServiceCtx() mctx.Context {
ctx := mctx.New()
func NewServiceCtx() context.Context {
ctx := context.Background()
// set up log level handling
logLevelStr := mcfg.String(ctx, "log-level", "info", "Maximum log level which will be printed.")
mrun.OnStart(ctx, func(mctx.Context) error {
logger := mlog.NewLogger()
ctx = mlog.Set(ctx, logger)
ctx, logLevelStr := mcfg.String(ctx, "log-level", "info", "Maximum log level which will be printed.")
ctx = mrun.OnStart(ctx, func(context.Context) error {
logLevel := mlog.LevelFromString(*logLevelStr)
if logLevel == nil {
return merr.New("invalid log level", "log-level", *logLevelStr)
}
mlog.CtxSetAll(ctx, func(_ mctx.Context, logger *mlog.Logger) *mlog.Logger {
logger.SetMaxLevel(logLevel)
return logger
})
logger.SetMaxLevel(logLevel)
return nil
})
@ -56,23 +55,23 @@ func NewServiceCtx() mctx.Context {
// start event, waiting for an interrupt, and then triggering the stop event.
// Run will block until the stop event is done. If any errors are encountered a
// fatal is thrown.
func Run(ctx mctx.Context) {
func Run(ctx context.Context) {
log := mlog.From(ctx)
if err := mcfg.Populate(ctx, CfgSource()); err != nil {
log.Fatal("error populating configuration", merr.KV(err))
log.Fatal(ctx, "error populating configuration", merr.KV(err))
} else if err := mrun.Start(ctx); err != nil {
log.Fatal("error triggering start event", merr.KV(err))
log.Fatal(ctx, "error triggering start event", merr.KV(err))
}
{
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
s := <-ch
log.Info("signal received, stopping", mlog.KV{"signal": s})
log.Info(ctx, "signal received, stopping", mlog.KV{"signal": s})
}
if err := mrun.Stop(ctx); err != nil {
log.Fatal("error triggering stop event", merr.KV(err))
log.Fatal(ctx, "error triggering stop event", merr.KV(err))
}
log.Info("exiting process")
log.Info(ctx, "exiting process")
}

View File

@ -26,7 +26,8 @@ func TestServiceCtx(t *T) {
// create a child Context before running to ensure it the change propagates
// correctly.
ctxA := mctx.ChildOf(ctx, "A")
ctxA := mctx.NewChild(ctx, "A")
ctx = mctx.WithChild(ctx, ctxA)
params := mcfg.ParamValues{{Name: "log-level", Value: json.RawMessage(`"DEBUG"`)}}
if err := mcfg.Populate(ctx, params); err != nil {
@ -35,14 +36,16 @@ func TestServiceCtx(t *T) {
t.Fatal(err)
}
mlog.From(ctxA).Info("foo")
mlog.From(ctxA).Debug("bar")
mlog.From(ctxA).Info(ctxA, "foo")
mlog.From(ctxA).Debug(ctxA, "bar")
massert.Fatal(t, massert.All(
massert.Len(msgs, 2),
massert.Equal(msgs[0].Level.String(), "INFO"),
massert.Equal(msgs[0].Description.String(), "(/A) foo"),
massert.Equal(msgs[0].Description, "foo"),
massert.Equal(msgs[0].Context, ctxA),
massert.Equal(msgs[1].Level.String(), "DEBUG"),
massert.Equal(msgs[1].Description.String(), "(/A) bar"),
massert.Equal(msgs[1].Description, "bar"),
massert.Equal(msgs[1].Context, ctxA),
))
})
}

View File

@ -8,6 +8,7 @@ import (
"sort"
"strings"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/merr"
)
@ -108,7 +109,7 @@ func (cli SourceCLI) Parse(params []Param) ([]ParamValue, error) {
pvs = append(pvs, ParamValue{
Name: p.Name,
Path: p.Path,
Path: mctx.Path(p.Context),
Value: p.fuzzyParse(pvStrVal),
})
@ -127,7 +128,7 @@ func (cli SourceCLI) Parse(params []Param) ([]ParamValue, error) {
func (cli SourceCLI) cliParams(params []Param) (map[string]Param, error) {
m := map[string]Param{}
for _, p := range params {
key := strings.Join(append(p.Path, p.Name), cliKeyJoin)
key := strings.Join(append(mctx.Path(p.Context), p.Name), cliKeyJoin)
m[cliKeyPrefix+key] = p
}
return m, nil

View File

@ -2,11 +2,11 @@ package mcfg
import (
"bytes"
"context"
"strings"
. "testing"
"time"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mrand"
"github.com/mediocregopher/mediocre-go-lib/mtest/mchk"
"github.com/stretchr/testify/assert"
@ -14,12 +14,12 @@ import (
)
func TestSourceCLIHelp(t *T) {
ctx := mctx.New()
Int(ctx, "foo", 5, "Test int param ") // trailing space should be trimmed
Bool(ctx, "bar", "Test bool param.")
String(ctx, "baz", "baz", "Test string param")
RequiredString(ctx, "baz2", "")
RequiredString(ctx, "baz3", "")
ctx := context.Background()
ctx, _ = Int(ctx, "foo", 5, "Test int param ") // trailing space should be trimmed
ctx, _ = Bool(ctx, "bar", "Test bool param.")
ctx, _ = String(ctx, "baz", "baz", "Test string param")
ctx, _ = RequiredString(ctx, "baz2", "")
ctx, _ = RequiredString(ctx, "baz3", "")
src := SourceCLI{}
buf := new(bytes.Buffer)

View File

@ -4,6 +4,7 @@ import (
"os"
"strings"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/merr"
)
@ -49,7 +50,7 @@ func (env SourceEnv) Parse(params []Param) ([]ParamValue, error) {
pM := map[string]Param{}
for _, p := range params {
name := env.expectedName(p.Path, p.Name)
name := env.expectedName(mctx.Path(p.Context), p.Name)
pM[name] = p
}
@ -63,7 +64,7 @@ func (env SourceEnv) Parse(params []Param) ([]ParamValue, error) {
if p, ok := pM[k]; ok {
pvs = append(pvs, ParamValue{
Name: p.Name,
Path: p.Path,
Path: mctx.Path(p.Context),
Value: p.fuzzyParse(v),
})
}

View File

@ -3,6 +3,7 @@
package mcfg
import (
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
@ -17,28 +18,10 @@ import (
// - JSON file
// - YAML file
type ctxCfg struct {
path []string
params map[string]Param
}
type ctxKey int
func get(ctx mctx.Context) *ctxCfg {
return mctx.GetSetMutableValue(ctx, true, ctxKey(0),
func(interface{}) interface{} {
return &ctxCfg{
path: mctx.Path(ctx),
params: map[string]Param{},
}
},
).(*ctxCfg)
}
func sortParams(params []Param) {
sort.Slice(params, func(i, j int) bool {
a, b := params[i], params[j]
aPath, bPath := a.Path, b.Path
aPath, bPath := mctx.Path(a.Context), mctx.Path(b.Context)
for {
switch {
case len(aPath) == 0 && len(bPath) == 0:
@ -59,12 +42,12 @@ func sortParams(params []Param) {
// returns all Params gathered by recursively retrieving them from this Context
// and its children. Returned Params are sorted according to their Path and
// Name.
func collectParams(ctx mctx.Context) []Param {
func collectParams(ctx context.Context) []Param {
var params []Param
var visit func(mctx.Context)
visit = func(ctx mctx.Context) {
for _, param := range get(ctx).params {
var visit func(context.Context)
visit = func(ctx context.Context) {
for _, param := range getLocalParams(ctx) {
params = append(params, param)
}
@ -98,9 +81,10 @@ func populate(params []Param, src Source) error {
// later. There should not be any duplicates here.
pM := map[string]Param{}
for _, p := range params {
hash := paramHash(p.Path, p.Name)
path := mctx.Path(p.Context)
hash := paramHash(path, p.Name)
if _, ok := pM[hash]; ok {
panic("duplicate Param: " + paramFullName(p.Path, p.Name))
panic("duplicate Param: " + paramFullName(path, p.Name))
}
pM[hash] = p
}
@ -127,7 +111,7 @@ func populate(params []Param, src Source) error {
continue
} else if _, ok := pvM[hash]; !ok {
return merr.New("required parameter is not set",
"param", paramFullName(p.Path, p.Name))
"param", paramFullName(mctx.Path(p.Context), p.Name))
}
}
@ -144,12 +128,12 @@ func populate(params []Param, src Source) error {
}
// Populate uses the Source to populate the values of all Params which were
// added to the given mctx.Context, and all of its children. Populate may be
// called multiple times with the same mctx.Context, each time will only affect
// the values of the Params which were provided by the respective Source.
// added to the given Context, and all of its children. Populate may be called
// multiple times with the same Context, each time will only affect the values
// of the Params which were provided by the respective Source.
//
// Source may be nil to indicate that no configuration is provided. Only default
// values will be used, and if any paramaters are required this will error.
func Populate(ctx mctx.Context, src Source) error {
func Populate(ctx context.Context, src Source) error {
return populate(collectParams(ctx), src)
}

View File

@ -1,6 +1,7 @@
package mcfg
import (
"context"
. "testing"
"github.com/mediocregopher/mediocre-go-lib/mctx"
@ -9,11 +10,12 @@ import (
func TestPopulate(t *T) {
{
ctx := mctx.New()
a := Int(ctx, "a", 0, "")
ctxChild := mctx.ChildOf(ctx, "foo")
b := Int(ctxChild, "b", 0, "")
c := Int(ctxChild, "c", 0, "")
ctx := context.Background()
ctx, a := Int(ctx, "a", 0, "")
ctxChild := mctx.NewChild(ctx, "foo")
ctxChild, b := Int(ctxChild, "b", 0, "")
ctxChild, c := Int(ctxChild, "c", 0, "")
ctx = mctx.WithChild(ctx, ctxChild)
err := Populate(ctx, SourceCLI{
Args: []string{"--a=1", "--foo-b=2"},
@ -25,11 +27,12 @@ func TestPopulate(t *T) {
}
{ // test that required params are enforced
ctx := mctx.New()
a := Int(ctx, "a", 0, "")
ctxChild := mctx.ChildOf(ctx, "foo")
b := Int(ctxChild, "b", 0, "")
c := RequiredInt(ctxChild, "c", "")
ctx := context.Background()
ctx, a := Int(ctx, "a", 0, "")
ctxChild := mctx.NewChild(ctx, "foo")
ctxChild, b := Int(ctxChild, "b", 0, "")
ctxChild, c := RequiredInt(ctxChild, "c", "")
ctx = mctx.WithChild(ctx, ctxChild)
err := Populate(ctx, SourceCLI{
Args: []string{"--a=1", "--foo-b=2"},

View File

@ -1,6 +1,7 @@
package mcfg
import (
"context"
"encoding/json"
"fmt"
"strings"
@ -10,15 +11,16 @@ import (
)
// Param is a configuration parameter which can be populated by Populate. The
// Param will exist as part of an mctx.Context, relative to its Path. For
// example, a Param with name "addr" under an mctx.Context with Path of
// []string{"foo","bar"} will be setabble on the CLI via "--foo-bar-addr". Other
// configuration Sources may treat the path/name differently, however.
// Param will exist as part of a Context, relative to its path (see the mctx
// package for more on Context path). For example, a Param with name "addr"
// under a Context with path of []string{"foo","bar"} will be setable on the CLI
// via "--foo-bar-addr". Other configuration Sources may treat the path/name
// differently, however.
//
// Param values are always unmarshaled as JSON values into the Into field of the
// Param, regardless of the actual Source.
type Param struct {
// How the parameter will be identified within an mctx.Context.
// How the parameter will be identified within a Context.
Name string
// A helpful description of how a parameter is expected to be used.
@ -42,9 +44,9 @@ type Param struct {
// value of the parameter.
Into interface{}
// The Path field of the Cfg this Param is attached to. NOTE that this
// will be automatically filled in when the Param is added to the Cfg.
Path []string
// The Context this Param was added to. NOTE that this will be automatically
// filled in by MustAdd when the Param is added to the Context.
Context context.Context
}
func paramFullName(path []string, name string) string {
@ -65,118 +67,145 @@ func (p Param) fuzzyParse(v string) json.RawMessage {
return json.RawMessage(v)
}
// MustAdd adds the given Param to the mctx.Context. It will panic if a Param of
// the same Name already exists in the mctx.Context.
func MustAdd(ctx mctx.Context, param Param) {
type ctxKey string
func getParam(ctx context.Context, name string) (Param, bool) {
param, ok := mctx.LocalValue(ctx, ctxKey(name)).(Param)
return param, ok
}
// MustAdd returns a Context with the given Param added to it. It will panic if
// a Param with the same Name already exists in the Context.
func MustAdd(ctx context.Context, param Param) context.Context {
param.Name = strings.ToLower(param.Name)
param.Path = mctx.Path(ctx)
param.Context = ctx
cfg := get(ctx)
if _, ok := cfg.params[param.Name]; ok {
panic(fmt.Sprintf("Context Path:%#v Name:%q already exists", param.Path, param.Name))
if _, ok := getParam(ctx, param.Name); ok {
path := mctx.Path(ctx)
panic(fmt.Sprintf("Context Path:%#v Name:%q already exists", path, param.Name))
}
cfg.params[param.Name] = param
return mctx.WithLocalValue(ctx, ctxKey(param.Name), param)
}
// Int64 returns an *int64 which will be populated once Populate is run.
func Int64(ctx mctx.Context, name string, defaultVal int64, usage string) *int64 {
func getLocalParams(ctx context.Context) []Param {
localVals := mctx.LocalValues(ctx)
params := make([]Param, 0, len(localVals))
for _, val := range localVals {
if param, ok := val.(Param); ok {
params = append(params, param)
}
}
return params
}
// Int64 returns an *int64 which will be populated once Populate is run on the
// returned Context.
func Int64(ctx context.Context, name string, defaultVal int64, usage string) (context.Context, *int64) {
i := defaultVal
MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i})
return &i
ctx = MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i})
return ctx, &i
}
// RequiredInt64 returns an *int64 which will be populated once Populate is run,
// and which must be supplied by a configuration Source.
func RequiredInt64(ctx mctx.Context, name string, usage string) *int64 {
// RequiredInt64 returns an *int64 which will be populated once Populate is run
// on the returned Context, and which must be supplied by a configuration
// Source.
func RequiredInt64(ctx context.Context, name string, usage string) (context.Context, *int64) {
var i int64
MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i})
return &i
ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i})
return ctx, &i
}
// Int returns an *int which will be populated once Populate is run.
func Int(ctx mctx.Context, name string, defaultVal int, usage string) *int {
// Int returns an *int which will be populated once Populate is run on the
// returned Context.
func Int(ctx context.Context, name string, defaultVal int, usage string) (context.Context, *int) {
i := defaultVal
MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i})
return &i
ctx = MustAdd(ctx, Param{Name: name, Usage: usage, Into: &i})
return ctx, &i
}
// RequiredInt returns an *int which will be populated once Populate is run, and
// which must be supplied by a configuration Source.
func RequiredInt(ctx mctx.Context, name string, usage string) *int {
// RequiredInt returns an *int which will be populated once Populate is run on
// the returned Context, and which must be supplied by a configuration Source.
func RequiredInt(ctx context.Context, name string, usage string) (context.Context, *int) {
var i int
MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i})
return &i
ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &i})
return ctx, &i
}
// String returns a *string which will be populated once Populate is run.
func String(ctx mctx.Context, name, defaultVal, usage string) *string {
// String returns a *string which will be populated once Populate is run on the
// returned Context.
func String(ctx context.Context, name, defaultVal, usage string) (context.Context, *string) {
s := defaultVal
MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &s})
return &s
ctx = MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &s})
return ctx, &s
}
// RequiredString returns a *string which will be populated once Populate is
// run, and which must be supplied by a configuration Source.
func RequiredString(ctx mctx.Context, name, usage string) *string {
// run on the returned Context, and which must be supplied by a configuration
// Source.
func RequiredString(ctx context.Context, name, usage string) (context.Context, *string) {
var s string
MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &s})
return &s
ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &s})
return ctx, &s
}
// Bool returns a *bool which will be populated once Populate is run, and which
// defaults to false if unconfigured.
// Bool returns a *bool which will be populated once Populate is run on the
// returned Context, and which defaults to false if unconfigured.
//
// The default behavior of all Sources is that a boolean parameter will be set
// to true unless the value is "", 0, or false. In the case of the CLI Source
// the value will also be true when the parameter is used with no value at all,
// as would be expected.
func Bool(ctx mctx.Context, name, usage string) *bool {
func Bool(ctx context.Context, name, usage string) (context.Context, *bool) {
var b bool
MustAdd(ctx, Param{Name: name, Usage: usage, IsBool: true, Into: &b})
return &b
ctx = MustAdd(ctx, Param{Name: name, Usage: usage, IsBool: true, Into: &b})
return ctx, &b
}
// TS returns an *mtime.TS which will be populated once Populate is run.
func TS(ctx mctx.Context, name string, defaultVal mtime.TS, usage string) *mtime.TS {
// TS returns an *mtime.TS which will be populated once Populate is run on the
// returned Context.
func TS(ctx context.Context, name string, defaultVal mtime.TS, usage string) (context.Context, *mtime.TS) {
t := defaultVal
MustAdd(ctx, Param{Name: name, Usage: usage, Into: &t})
return &t
ctx = MustAdd(ctx, Param{Name: name, Usage: usage, Into: &t})
return ctx, &t
}
// RequiredTS returns an *mtime.TS which will be populated once Populate is run,
// and which must be supplied by a configuration Source.
func RequiredTS(ctx mctx.Context, name, usage string) *mtime.TS {
// RequiredTS returns an *mtime.TS which will be populated once Populate is run
// on the returned Context, and which must be supplied by a configuration
// Source.
func RequiredTS(ctx context.Context, name, usage string) (context.Context, *mtime.TS) {
var t mtime.TS
MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &t})
return &t
ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: &t})
return ctx, &t
}
// Duration returns an *mtime.Duration which will be populated once
// Populate is run.
func Duration(ctx mctx.Context, name string, defaultVal mtime.Duration, usage string) *mtime.Duration {
// Duration returns an *mtime.Duration which will be populated once Populate is
// run on the returned Context.
func Duration(ctx context.Context, name string, defaultVal mtime.Duration, usage string) (context.Context, *mtime.Duration) {
d := defaultVal
MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &d})
return &d
ctx = MustAdd(ctx, Param{Name: name, Usage: usage, IsString: true, Into: &d})
return ctx, &d
}
// RequiredDuration returns an *mtime.Duration which will be populated once
// Populate is run, and which must be supplied by a configuration Source.
func RequiredDuration(ctx mctx.Context, name string, defaultVal mtime.Duration, usage string) *mtime.Duration {
// Populate is run on the returned Context, and which must be supplied by a
// configuration Source.
func RequiredDuration(ctx context.Context, name string, defaultVal mtime.Duration, usage string) (context.Context, *mtime.Duration) {
var d mtime.Duration
MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &d})
return &d
ctx = MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, IsString: true, Into: &d})
return ctx, &d
}
// JSON reads the parameter value as a JSON value and unmarshals it into the
// given interface{} (which should be a pointer). The receiver (into) is also
// used to determine the default value.
func JSON(ctx mctx.Context, name string, into interface{}, usage string) {
MustAdd(ctx, Param{Name: name, Usage: usage, Into: into})
func JSON(ctx context.Context, name string, into interface{}, usage string) context.Context {
return MustAdd(ctx, Param{Name: name, Usage: usage, Into: into})
}
// RequiredJSON reads the parameter value as a JSON value and unmarshals it into
// the given interface{} (which should be a pointer). The value must be supplied
// by a configuration Source.
func RequiredJSON(ctx mctx.Context, name string, into interface{}, usage string) {
MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: into})
func RequiredJSON(ctx context.Context, name string, into interface{}, usage string) context.Context {
return MustAdd(ctx, Param{Name: name, Required: true, Usage: usage, Into: into})
}

View File

@ -1,6 +1,7 @@
package mcfg
import (
"context"
"encoding/json"
"fmt"
. "testing"
@ -16,24 +17,35 @@ import (
// all the code they share
type srcCommonState struct {
ctx mctx.Context
availCtxs []mctx.Context
expPVs []ParamValue
// availCtxs get updated in place as the run goes on, and mkRoot is used to
// create the latest version of the root context based on them
availCtxs []*context.Context
mkRoot func() context.Context
expPVs []ParamValue
// each specific test should wrap this to add the Source itself
}
func newSrcCommonState() srcCommonState {
var scs srcCommonState
scs.ctx = mctx.New()
{
a := mctx.ChildOf(scs.ctx, "a")
b := mctx.ChildOf(scs.ctx, "b")
c := mctx.ChildOf(scs.ctx, "c")
ab := mctx.ChildOf(a, "b")
bc := mctx.ChildOf(b, "c")
abc := mctx.ChildOf(ab, "c")
scs.availCtxs = []mctx.Context{scs.ctx, a, b, c, ab, bc, abc}
root := context.Background()
a := mctx.NewChild(root, "a")
b := mctx.NewChild(root, "b")
c := mctx.NewChild(root, "c")
ab := mctx.NewChild(a, "b")
bc := mctx.NewChild(b, "c")
abc := mctx.NewChild(ab, "c")
scs.availCtxs = []*context.Context{&root, &a, &b, &c, &ab, &bc, &abc}
scs.mkRoot = func() context.Context {
ab := mctx.WithChild(ab, abc)
a := mctx.WithChild(a, ab)
b := mctx.WithChild(b, bc)
root := mctx.WithChild(root, a)
root = mctx.WithChild(root, b)
root = mctx.WithChild(root, c)
return root
}
}
return scs
}
@ -57,10 +69,9 @@ func (scs srcCommonState) next() srcCommonParams {
}
p.availCtxI = mrand.Intn(len(scs.availCtxs))
thisCtx := scs.availCtxs[p.availCtxI]
p.path = mctx.Path(thisCtx)
p.path = mctx.Path(*scs.availCtxs[p.availCtxI])
p.isBool = mrand.Intn(2) == 0
p.isBool = mrand.Intn(8) == 0
if !p.isBool {
p.nonBoolType = mrand.Element([]string{
"int",
@ -105,11 +116,11 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState {
// the Sources don't actually care about the other fields of Param,
// those are only used by Populate once it has all ParamValues together
}
MustAdd(thisCtx, ctxP)
ctxP = get(thisCtx).params[p.name] // get it back out to get any added fields
*thisCtx = MustAdd(*thisCtx, ctxP)
ctxP, _ = getParam(*thisCtx, ctxP.Name) // get it back out to get any added fields
if !p.unset {
pv := ParamValue{Name: ctxP.Name, Path: ctxP.Path}
pv := ParamValue{Name: ctxP.Name, Path: mctx.Path(ctxP.Context)}
if p.isBool {
pv.Value = json.RawMessage("true")
} else {
@ -131,7 +142,8 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState {
// given a Source asserts that it's Parse method returns the expected
// ParamValues
func (scs srcCommonState) assert(s Source) error {
gotPVs, err := s.Parse(collectParams(scs.ctx))
root := scs.mkRoot()
gotPVs, err := s.Parse(collectParams(root))
if err != nil {
return err
}
@ -142,10 +154,10 @@ func (scs srcCommonState) assert(s Source) error {
}
func TestSources(t *T) {
ctx := mctx.New()
a := RequiredInt(ctx, "a", "")
b := RequiredInt(ctx, "b", "")
c := RequiredInt(ctx, "c", "")
ctx := context.Background()
ctx, a := RequiredInt(ctx, "a", "")
ctx, b := RequiredInt(ctx, "b", "")
ctx, c := RequiredInt(ctx, "c", "")
err := Populate(ctx, Sources{
SourceCLI{Args: []string{"--a=1", "--b=666"}},
@ -160,11 +172,12 @@ func TestSources(t *T) {
}
func TestSourceParamValues(t *T) {
ctx := mctx.New()
a := RequiredInt(ctx, "a", "")
foo := mctx.ChildOf(ctx, "foo")
b := RequiredString(foo, "b", "")
c := Bool(foo, "c", "")
ctx := context.Background()
ctx, a := RequiredInt(ctx, "a", "")
foo := mctx.NewChild(ctx, "foo")
foo, b := RequiredString(foo, "b", "")
foo, c := Bool(foo, "c", "")
ctx = mctx.WithChild(ctx, foo)
err := Populate(ctx, ParamValues{
{Name: "a", Value: json.RawMessage(`4`)},

View File

@ -10,151 +10,163 @@
package mctx
import (
"sync"
"time"
goctx "context"
"context"
"fmt"
)
// Context is the same as the builtin type, but is used to indicate that the
// Context originally came from this package (aka New or ChildOf).
type Context goctx.Context
// CancelFunc is a direct alias of the type from the context package, see its
// docs.
type CancelFunc = goctx.CancelFunc
// WithValue mimics the function from the context package.
func WithValue(parent Context, key, val interface{}) Context {
return Context(goctx.WithValue(goctx.Context(parent), key, val))
}
// WithCancel mimics the function from the context package.
func WithCancel(parent Context) (Context, CancelFunc) {
ctx, fn := goctx.WithCancel(goctx.Context(parent))
return Context(ctx), fn
}
// WithDeadline mimics the function from the context package.
func WithDeadline(parent Context, t time.Time) (Context, CancelFunc) {
ctx, fn := goctx.WithDeadline(goctx.Context(parent), t)
return Context(ctx), fn
}
// WithTimeout mimics the function from the context package.
func WithTimeout(parent Context, d time.Duration) (Context, CancelFunc) {
ctx, fn := goctx.WithTimeout(goctx.Context(parent), d)
return Context(ctx), fn
}
////////////////////////////////////////////////////////////////////////////////
type mutVal struct {
l sync.RWMutex
v interface{}
}
type context struct {
goctx.Context
path []string
l sync.RWMutex
parent *context
children map[string]Context
mutL sync.RWMutex
mutVals map[interface{}]*mutVal
}
// New returns a new context which can be used as the root context for all
// purposes in this framework.
func New() Context {
return &context{Context: goctx.Background()}
}
//func New() Context {
// return &context{Context: goctx.Background()}
//}
func getCtx(Ctx Context) *context {
ctx, ok := Ctx.(*context)
type ancestryKey int // 0 -> children, 1 -> parent, 2 -> path
const (
ancestryKeyChildren ancestryKey = iota
ancestryKeyChildrenMap
ancestryKeyParent
ancestryKeyPath
)
// Child returns the Context of the given name which was added to parent via
// WithChild, or nil if no Context of that name was ever added.
func Child(parent context.Context, name string) context.Context {
childrenMap, _ := parent.Value(ancestryKeyChildrenMap).(map[string]int)
if len(childrenMap) == 0 {
return nil
}
i, ok := childrenMap[name]
if !ok {
panic("non-conforming Context used")
return nil
}
return ctx
return parent.Value(ancestryKeyChildren).([]context.Context)[i]
}
// Path returns the sequence of names which were used to produce this context
// via the ChildOf function.
func Path(Ctx Context) []string {
return getCtx(Ctx).path
// Children returns all children of this Context which have been kept by
// WithChild, mapped by their name. If this Context wasn't produced by WithChild
// then this returns nil.
func Children(parent context.Context) []context.Context {
children, _ := parent.Value(ancestryKeyChildren).([]context.Context)
return children
}
// Children returns all children of this context which have been created by
// ChildOf, mapped by their name.
func Children(Ctx Context) map[string]Context {
ctx := getCtx(Ctx)
out := map[string]Context{}
ctx.l.RLock()
defer ctx.l.RUnlock()
for name, childCtx := range ctx.children {
out[name] = childCtx
func childrenCP(parent context.Context) ([]context.Context, map[string]int) {
children := Children(parent)
// plus 1 because this is most commonly used in WithChild, which will append
// to it. At any rate it doesn't hurt anything.
outChildren := make([]context.Context, len(children), len(children)+1)
copy(outChildren, children)
childrenMap, _ := parent.Value(ancestryKeyChildrenMap).(map[string]int)
outChildrenMap := make(map[string]int, len(childrenMap)+1)
for name, i := range childrenMap {
outChildrenMap[name] = i
}
return out
return outChildren, outChildrenMap
}
// Parent returns the parent Context of the given one, or nil if this is a root
// context (i.e. returned from New).
func Parent(Ctx Context) Context {
return getCtx(Ctx).parent
}
// Root returns the root Context from which this Context and all of its parents
// were derived (i.e. the Context which was originally returned from New).
// parentOf returns the Context from which this one was generated via NewChild.
// Returns nil if this Context was not generated via NewChild.
//
// If the given Context is the root then it is returned as-id.
func Root(Ctx Context) Context {
ctx := getCtx(Ctx)
for {
if ctx.parent == nil {
return ctx
// This is kept private because the behavior is a bit confusing. This will
// return the Context which was passed into NewChild, but users would probably
// expect it to return the one from WithChild if they were to call this.
func parentOf(ctx context.Context) context.Context {
parent, _ := ctx.Value(ancestryKeyParent).(context.Context)
return parent
}
// Path returns the sequence of names which were used to produce this Context
// via the NewChild function. If this Context wasn't produced by NewChild then
// this returns nil.
func Path(ctx context.Context) []string {
path, _ := ctx.Value(ancestryKeyPath).([]string)
return path
}
func pathCP(ctx context.Context) []string {
path := Path(ctx)
// plus 1 because this is most commonly used in NewChild, which will append
// to it. At any rate it doesn't hurt anything.
outPath := make([]string, len(path), len(path)+1)
copy(outPath, path)
return outPath
}
// Name returns the name this Context was generated with via NewChild, or false
// if this Context was not generated via NewChild.
func Name(ctx context.Context) (string, bool) {
path := Path(ctx)
if len(path) == 0 {
return "", false
}
return path[len(path)-1], true
}
// NewChild creates a new Context based off of the parent one, and returns a new
// instance of the passed in parent and the new child. The child will have a
// path which is the parent's path with the given name appended. The parent will
// have the new child as part of its set of children (see Children function).
//
// If the parent already has a child of the given name this function panics.
func NewChild(parent context.Context, name string) context.Context {
if Child(parent, name) != nil {
panic(fmt.Sprintf("child with name %q already exists on parent", name))
}
childPath := append(pathCP(parent), name)
child := withoutLocalValues(parent)
child = context.WithValue(child, ancestryKeyChildren, nil) // unset children
child = context.WithValue(child, ancestryKeyChildrenMap, nil) // unset children
child = context.WithValue(child, ancestryKeyParent, parent)
child = context.WithValue(child, ancestryKeyPath, childPath)
return child
}
func isChild(parent, child context.Context) bool {
parentPath, childPath := Path(parent), Path(child)
if len(parentPath) != len(childPath)-1 {
return false
}
for i := range parentPath {
if parentPath[i] != childPath[i] {
return false
}
ctx = ctx.parent
}
return true
}
// ChildOf creates a child of the given context with the given name and returns
// it. The Path of the returned context will be the path of the parent with its
// name appended to it. The Children function can be called on the parent to
// retrieve all children which have been made using this function.
//
// TODO If the given Context already has a child with the given name that child
// will be returned.
func ChildOf(Ctx Context, name string) Context {
ctx, childCtx := getCtx(Ctx), new(context)
ctx.l.Lock()
defer ctx.l.Unlock()
// set child's path field
childCtx.path = make([]string, 0, len(ctx.path)+1)
childCtx.path = append(childCtx.path, ctx.path...)
childCtx.path = append(childCtx.path, name)
// set child's parent field
childCtx.parent = ctx
// create child's ctx and store it in parent
if ctx.children == nil {
ctx.children = map[string]Context{}
// WithChild returns a modified parent which holds a reference to child in its
// Children set. If the child's name is already taken in the parent then this
// function panics.
func WithChild(parent, child context.Context) context.Context {
if !isChild(parent, child) {
panic(fmt.Sprintf("child cannot be kept by Context which is not its parent"))
}
ctx.children[name] = childCtx
return childCtx
name, _ := Name(child)
children, childrenMap := childrenCP(parent)
if _, ok := childrenMap[name]; ok {
panic(fmt.Sprintf("child with name %q already exists on parent", name))
}
children = append(children, child)
childrenMap[name] = len(children) - 1
parent = context.WithValue(parent, ancestryKeyChildren, children)
parent = context.WithValue(parent, ancestryKeyChildrenMap, childrenMap)
return parent
}
// BreadthFirstVisit visits this Context and all of its children, and their
// children, in a breadth-first order. If the callback returns false then the
// function returns without visiting any more Contexts.
//
// The exact order of visitation is non-deterministic.
func BreadthFirstVisit(Ctx Context, callback func(Context) bool) {
queue := []Context{Ctx}
func BreadthFirstVisit(ctx context.Context, callback func(context.Context) bool) {
queue := []context.Context{ctx}
for len(queue) > 0 {
if !callback(queue[0]) {
return
@ -167,78 +179,56 @@ func BreadthFirstVisit(Ctx Context, callback func(Context) bool) {
}
////////////////////////////////////////////////////////////////////////////////
// code related to mutable values
// local value stuff
// MutableValue acts like the Value method, except that it only deals with
// keys/values set by SetMutableValue.
func MutableValue(Ctx Context, key interface{}) interface{} {
ctx := getCtx(Ctx)
ctx.mutL.RLock()
defer ctx.mutL.RUnlock()
if ctx.mutVals == nil {
return nil
} else if mVal, ok := ctx.mutVals[key]; ok {
mVal.l.RLock()
defer mVal.l.RUnlock()
return mVal.v
}
return nil
type localValsKey int
type localVal struct {
prev *localVal
key, val interface{}
}
// GetSetMutableValue is used to interact with a mutable value on the context in
// a thread-safe way. The key's value is retrieved and passed to the callback.
// The value returned from the callback is stored back into the context as well
// as being returned from this function.
//
// If noCallbackIfSet is set to true, then if the key is already set the value
// will be returned without calling the callback.
//
// The callback returning nil is equivalent to unsetting the value.
//
// Children of this context will _not_ inherit any of its mutable values.
//
// Within the callback it is fine to call other functions/methods on the
// Context, except for those related to mutable values for this same key (e.g.
// MutableValue and SetMutableValue).
func GetSetMutableValue(
Ctx Context, noCallbackIfSet bool,
key interface{}, fn func(interface{}) interface{},
) interface{} {
// WithLocalValue is like WithValue, but the stored value will not be present
// on any children created via WithChild. Local values must be retrieved with
// the LocalValue function in this package. Local values share a different
// namespace than the normal WithValue/Value values (i.e. they do not overlap).
func WithLocalValue(ctx context.Context, key, val interface{}) context.Context {
prev, _ := ctx.Value(localValsKey(0)).(*localVal)
return context.WithValue(ctx, localValsKey(0), &localVal{
prev: prev,
key: key, val: val,
})
}
// if noCallbackIfSet, do a fast lookup with MutableValue first.
if noCallbackIfSet {
if v := MutableValue(Ctx, key); v != nil {
return v
func withoutLocalValues(ctx context.Context) context.Context {
return context.WithValue(ctx, localValsKey(0), nil)
}
// LocalValue returns the value for the given key which was set by a call to
// WithLocalValue, or nil if no value was set for the given key.
func LocalValue(ctx context.Context, key interface{}) interface{} {
lv, _ := ctx.Value(localValsKey(0)).(*localVal)
for {
if lv == nil {
return nil
} else if lv.key == key {
return lv.val
}
lv = lv.prev
}
}
// LocalValues returns all key/value pairs which have been set on the Context
// via WithLocalValue.
func LocalValues(ctx context.Context) map[interface{}]interface{} {
m := map[interface{}]interface{}{}
lv, _ := ctx.Value(localValsKey(0)).(*localVal)
for {
if lv == nil {
return m
} else if _, ok := m[lv.key]; !ok {
m[lv.key] = lv.val
}
lv = lv.prev
}
ctx := getCtx(Ctx)
ctx.mutL.Lock()
if ctx.mutVals == nil {
ctx.mutVals = map[interface{}]*mutVal{}
}
mVal, ok := ctx.mutVals[key]
if !ok {
mVal = new(mutVal)
ctx.mutVals[key] = mVal
}
ctx.mutL.Unlock()
mVal.l.Lock()
defer mVal.l.Unlock()
// It's possible something happened between the first check inside the
// read-lock and now, so double check this case. It's still good to have the
// read-lock check there, it'll handle 99% of the cases.
if noCallbackIfSet && mVal.v != nil {
return mVal.v
}
mVal.v = fn(mVal.v)
// TODO if the new v is nil then key could be deleted out of mutVals. But
// doing so would be weird in the case that there's another routine which
// has already pulled this same mVal out of mutVals and is waiting on its
// mutex.
return mVal.v
}

View File

@ -1,18 +1,22 @@
package mctx
import (
"sync"
"context"
. "testing"
"github.com/mediocregopher/mediocre-go-lib/mtest/massert"
)
func TestInheritance(t *T) {
ctx := New()
ctx1 := ChildOf(ctx, "1")
ctx1a := ChildOf(ctx1, "a")
ctx1b := ChildOf(ctx1, "b")
ctx2 := ChildOf(ctx, "2")
ctx := context.Background()
ctx1 := NewChild(ctx, "1")
ctx1a := NewChild(ctx1, "a")
ctx1b := NewChild(ctx1, "b")
ctx1 = WithChild(ctx1, ctx1a)
ctx1 = WithChild(ctx1, ctx1b)
ctx2 := NewChild(ctx, "2")
ctx = WithChild(ctx, ctx1)
ctx = WithChild(ctx, ctx2)
massert.Fatal(t, massert.All(
massert.Len(Path(ctx), 0),
@ -23,129 +27,130 @@ func TestInheritance(t *T) {
))
massert.Fatal(t, massert.All(
massert.Equal(
map[string]Context{"1": ctx1, "2": ctx2},
Children(ctx),
),
massert.Equal(
map[string]Context{"a": ctx1a, "b": ctx1b},
Children(ctx1),
),
massert.Equal(
map[string]Context{},
Children(ctx2),
),
))
massert.Fatal(t, massert.All(
massert.Nil(Parent(ctx)),
massert.Equal(Parent(ctx1), ctx),
massert.Equal(Parent(ctx1a), ctx1),
massert.Equal(Parent(ctx1b), ctx1),
massert.Equal(Parent(ctx2), ctx),
))
massert.Fatal(t, massert.All(
massert.Equal(Root(ctx), ctx),
massert.Equal(Root(ctx1), ctx),
massert.Equal(Root(ctx1a), ctx),
massert.Equal(Root(ctx1b), ctx),
massert.Equal(Root(ctx2), ctx),
massert.Equal([]context.Context{ctx1, ctx2}, Children(ctx)),
massert.Equal([]context.Context{ctx1a, ctx1b}, Children(ctx1)),
massert.Len(Children(ctx2), 0),
))
}
func TestBreadFirstVisit(t *T) {
ctx := New()
ctx1 := ChildOf(ctx, "1")
ctx1a := ChildOf(ctx1, "a")
ctx1b := ChildOf(ctx1, "b")
ctx2 := ChildOf(ctx, "2")
ctx := context.Background()
ctx1 := NewChild(ctx, "1")
ctx1a := NewChild(ctx1, "a")
ctx1b := NewChild(ctx1, "b")
ctx1 = WithChild(ctx1, ctx1a)
ctx1 = WithChild(ctx1, ctx1b)
ctx2 := NewChild(ctx, "2")
ctx = WithChild(ctx, ctx1)
ctx = WithChild(ctx, ctx2)
{
got := make([]Context, 0, 5)
BreadthFirstVisit(ctx, func(ctx Context) bool {
got := make([]context.Context, 0, 5)
BreadthFirstVisit(ctx, func(ctx context.Context) bool {
got = append(got, ctx)
return true
})
// since children are stored in a map the exact order is non-deterministic
massert.Fatal(t, massert.Any(
massert.Equal([]Context{ctx, ctx1, ctx2, ctx1a, ctx1b}, got),
massert.Equal([]Context{ctx, ctx1, ctx2, ctx1b, ctx1a}, got),
massert.Equal([]Context{ctx, ctx2, ctx1, ctx1a, ctx1b}, got),
massert.Equal([]Context{ctx, ctx2, ctx1, ctx1b, ctx1a}, got),
))
massert.Fatal(t,
massert.Equal([]context.Context{ctx, ctx1, ctx2, ctx1a, ctx1b}, got),
)
}
{
got := make([]Context, 0, 3)
BreadthFirstVisit(ctx, func(ctx Context) bool {
got := make([]context.Context, 0, 3)
BreadthFirstVisit(ctx, func(ctx context.Context) bool {
if len(Path(ctx)) > 1 {
return false
}
got = append(got, ctx)
return true
})
massert.Fatal(t, massert.Any(
massert.Equal([]Context{ctx, ctx1, ctx2}, got),
massert.Equal([]Context{ctx, ctx2, ctx1}, got),
))
massert.Fatal(t,
massert.Equal([]context.Context{ctx, ctx1, ctx2}, got),
)
}
}
func TestMutableValues(t *T) {
fn := func(v interface{}) interface{} {
if v == nil {
return 0
}
return v.(int) + 1
}
func TestLocalValues(t *T) {
var aa []massert.Assertion
// test with no value set
ctx := context.Background()
massert.Fatal(t, massert.All(
massert.Nil(LocalValue(ctx, "foo")),
massert.Len(LocalValues(ctx), 0),
))
ctx := New()
aa = append(aa, massert.Equal(GetSetMutableValue(ctx, false, 0, fn), 0))
aa = append(aa, massert.Equal(GetSetMutableValue(ctx, false, 0, fn), 1))
aa = append(aa, massert.Equal(GetSetMutableValue(ctx, true, 0, fn), 1))
// test basic value retrieval
ctx = WithLocalValue(ctx, "foo", "bar")
massert.Fatal(t, massert.All(
massert.Equal("bar", LocalValue(ctx, "foo")),
massert.Equal(
map[interface{}]interface{}{"foo": "bar"},
LocalValues(ctx),
),
))
aa = append(aa, massert.Equal(MutableValue(ctx, 0), 1))
// test that doesn't conflict with WithValue
ctx = context.WithValue(ctx, "foo", "WithValue bar")
massert.Fatal(t, massert.All(
massert.Equal("bar", LocalValue(ctx, "foo")),
massert.Equal("WithValue bar", ctx.Value("foo")),
massert.Equal(
map[interface{}]interface{}{"foo": "bar"},
LocalValues(ctx),
),
))
ctx1 := ChildOf(ctx, "one")
aa = append(aa, massert.Equal(GetSetMutableValue(ctx1, true, 0, fn), 0))
aa = append(aa, massert.Equal(GetSetMutableValue(ctx1, false, 0, fn), 1))
aa = append(aa, massert.Equal(GetSetMutableValue(ctx1, true, 0, fn), 1))
// test that child doesn't get values
child := NewChild(ctx, "child")
massert.Fatal(t, massert.All(
massert.Equal("bar", LocalValue(ctx, "foo")),
massert.Nil(LocalValue(child, "foo")),
massert.Len(LocalValues(child), 0),
))
massert.Fatal(t, massert.All(aa...))
}
func TestMutableValuesParallel(t *T) {
const events = 1000000
const workers = 10
incr := func(v interface{}) interface{} {
if v == nil {
return 1
}
return v.(int) + 1
}
ch := make(chan bool, events)
for i := 0; i < events; i++ {
ch <- true
}
close(ch)
ctx := New()
wg := new(sync.WaitGroup)
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range ch {
GetSetMutableValue(ctx, false, 0, incr)
}
}()
}
wg.Wait()
massert.Fatal(t, massert.Equal(events, MutableValue(ctx, 0)))
// test that values on child don't affect parent values
child = WithLocalValue(child, "foo", "child bar")
ctx = WithChild(ctx, child)
massert.Fatal(t, massert.All(
massert.Equal("bar", LocalValue(ctx, "foo")),
massert.Equal("child bar", LocalValue(child, "foo")),
massert.Equal(
map[interface{}]interface{}{"foo": "bar"},
LocalValues(ctx),
),
massert.Equal(
map[interface{}]interface{}{"foo": "child bar"},
LocalValues(child),
),
))
// test that two With calls on the same context generate distinct contexts
childA := WithLocalValue(child, "foo2", "baz")
childB := WithLocalValue(child, "foo2", "buz")
massert.Fatal(t, massert.All(
massert.Equal("bar", LocalValue(ctx, "foo")),
massert.Equal("child bar", LocalValue(child, "foo")),
massert.Nil(LocalValue(child, "foo2")),
massert.Equal("baz", LocalValue(childA, "foo2")),
massert.Equal("buz", LocalValue(childB, "foo2")),
massert.Equal(
map[interface{}]interface{}{"foo": "child bar", "foo2": "baz"},
LocalValues(childA),
),
massert.Equal(
map[interface{}]interface{}{"foo": "child bar", "foo2": "buz"},
LocalValues(childB),
),
))
// if a value overwrites a previous one the newer one should show in
// LocalValues
ctx = WithLocalValue(ctx, "foo", "barbar")
massert.Fatal(t, massert.All(
massert.Equal("barbar", LocalValue(ctx, "foo")),
massert.Equal(
map[interface{}]interface{}{"foo": "barbar"},
LocalValues(ctx),
),
))
}

View File

@ -34,7 +34,7 @@ func isErrAlreadyExists(err error) bool {
type BigQuery struct {
*bigquery.Client
gce *mdb.GCE
log *mlog.Logger
ctx context.Context
// key is dataset/tableName
tablesL sync.Mutex
@ -43,35 +43,37 @@ type BigQuery struct {
}
// MNew returns a BigQuery instance which will be initialized and configured
// when the start event is triggered on ctx (see mrun.Start). The BigQuery
// instance will have Close called on it when the stop event is triggered on ctx
// (see mrun.Stop).
// when the start event is triggered on the returned (see mrun.Start). The
// BigQuery instance will have Close called on it when the stop event is
// triggered on the returned Context (see mrun.Stop).
//
// gce is optional and can be passed in if there's an existing gce object which
// should be used, otherwise a new one will be created with mdb.MGCE.
func MNew(ctx mctx.Context, gce *mdb.GCE) *BigQuery {
func MNew(ctx context.Context, gce *mdb.GCE) (context.Context, *BigQuery) {
if gce == nil {
gce = mdb.MGCE(ctx, "")
ctx, gce = mdb.MGCE(ctx, "")
}
ctx = mctx.ChildOf(ctx, "bigquery")
bq := &BigQuery{
gce: gce,
tables: map[[2]string]*bigquery.Table{},
tableUploaders: map[[2]string]*bigquery.Uploader{},
log: mlog.From(ctx),
ctx: mctx.NewChild(ctx, "bigquery"),
}
bq.log.SetKV(bq)
mrun.OnStart(ctx, func(innerCtx mctx.Context) error {
bq.log.Info("connecting to bigquery")
// TODO the equivalent functionality as here will be added with annotations
// bq.log.SetKV(bq)
bq.ctx = mrun.OnStart(bq.ctx, func(innerCtx context.Context) error {
mlog.Info(bq.ctx, "connecting to bigquery")
var err error
bq.Client, err = bigquery.NewClient(innerCtx, bq.gce.Project, bq.gce.ClientOptions()...)
return merr.WithKV(err, bq.KV())
})
mrun.OnStop(ctx, func(mctx.Context) error {
bq.ctx = mrun.OnStop(bq.ctx, func(context.Context) error {
return bq.Client.Close()
})
return bq
return mctx.WithChild(ctx, bq.ctx), bq
}
// KV implements the mlog.KVer interface.
@ -99,7 +101,7 @@ func (bq *BigQuery) Table(
}
kv := mlog.KV{"bigQueryDataset": dataset, "bigQueryTable": tableName}
bq.log.Debug("creating/grabbing table", kv)
mlog.Debug(bq.ctx, "creating/grabbing table", kv)
schema, err := bigquery.InferSchema(schemaObj)
if err != nil {

View File

@ -28,44 +28,45 @@ type Bigtable struct {
Instance string
gce *mdb.GCE
log *mlog.Logger
ctx context.Context
}
// MNew returns a Bigtable instance which will be initialized and configured
// when the start event is triggered on ctx (see mrun.Start). The Bigtable
// instance will have Close called on it when the stop event is triggered on ctx
// (see mrun.Stop).
// when the start event is triggered on the returned Context (see mrun.Start).
// The Bigtable instance will have Close called on it when the stop event is
// triggered on the returned Context (see mrun.Stop).
//
// gce is optional and can be passed in if there's an existing gce object which
// should be used, otherwise a new one will be created with mdb.MGCE.
//
// defaultInstance can be given as the instance name to use as the default
// parameter value. If empty the parameter will be required to be set.
func MNew(ctx mctx.Context, gce *mdb.GCE, defaultInstance string) *Bigtable {
func MNew(ctx context.Context, gce *mdb.GCE, defaultInstance string) (context.Context, *Bigtable) {
if gce == nil {
gce = mdb.MGCE(ctx, "")
ctx, gce = mdb.MGCE(ctx, "")
}
ctx = mctx.ChildOf(ctx, "bigtable")
bt := &Bigtable{
gce: gce,
log: mlog.From(ctx),
ctx: mctx.NewChild(ctx, "bigtable"),
}
bt.log.SetKV(bt)
// TODO the equivalent functionality as here will be added with annotations
// bt.log.SetKV(bt)
var inst *string
{
const name, descr = "instance", "name of the bigtable instance in the project to connect to"
if defaultInstance != "" {
inst = mcfg.String(ctx, name, defaultInstance, descr)
bt.ctx, inst = mcfg.String(bt.ctx, name, defaultInstance, descr)
} else {
inst = mcfg.RequiredString(ctx, name, descr)
bt.ctx, inst = mcfg.RequiredString(bt.ctx, name, descr)
}
}
mrun.OnStart(ctx, func(innerCtx mctx.Context) error {
bt.ctx = mrun.OnStart(bt.ctx, func(innerCtx context.Context) error {
bt.Instance = *inst
bt.log.Info("connecting to bigtable")
mlog.Info(bt.ctx, "connecting to bigtable", bt)
var err error
bt.Client, err = bigtable.NewClient(
innerCtx,
@ -74,10 +75,10 @@ func MNew(ctx mctx.Context, gce *mdb.GCE, defaultInstance string) *Bigtable {
)
return merr.WithKV(err, bt.KV())
})
mrun.OnStop(ctx, func(mctx.Context) error {
bt.ctx = mrun.OnStop(bt.ctx, func(context.Context) error {
return bt.Client.Close()
})
return bt
return mctx.WithChild(ctx, bt.ctx), bt
}
// KV implements the mlog.KVer interface.
@ -93,16 +94,16 @@ func (bt *Bigtable) KV() map[string]interface{} {
// This method requires admin privileges on the bigtable instance.
func (bt *Bigtable) EnsureTable(ctx context.Context, name string, colFams ...string) error {
kv := mlog.KV{"bigtableTable": name}
bt.log.Info("ensuring table", kv)
mlog.Info(bt.ctx, "ensuring table", kv)
bt.log.Debug("creating admin client", kv)
mlog.Debug(bt.ctx, "creating admin client", kv)
adminClient, err := bigtable.NewAdminClient(ctx, bt.gce.Project, bt.Instance)
if err != nil {
return merr.WithKV(err, bt.KV(), kv.KV())
}
defer adminClient.Close()
bt.log.Debug("creating bigtable table (if needed)", kv)
mlog.Debug(bt.ctx, "creating bigtable table (if needed)", kv)
err = adminClient.CreateTable(ctx, name)
if err != nil && !isErrAlreadyExists(err) {
return merr.WithKV(err, bt.KV(), kv.KV())
@ -110,7 +111,7 @@ func (bt *Bigtable) EnsureTable(ctx context.Context, name string, colFams ...str
for _, colFam := range colFams {
kv := kv.Set("family", colFam)
bt.log.Debug("creating bigtable column family (if needed)", kv)
mlog.Debug(bt.ctx, "creating bigtable column family (if needed)", kv)
err := adminClient.CreateColumnFamily(ctx, name, colFam)
if err != nil && !isErrAlreadyExists(err) {
return merr.WithKV(err, bt.KV(), kv.KV())

View File

@ -12,8 +12,8 @@ import (
func TestBasic(t *T) {
ctx := mtest.NewCtx()
mtest.SetEnv(ctx, "GCE_PROJECT", "testProject")
bt := MNew(ctx, nil, "testInstance")
ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "testProject")
ctx, bt := MNew(ctx, nil, "testInstance")
mtest.Run(ctx, t, func() {
tableName := "test-" + mrand.Hex(8)

View File

@ -3,6 +3,8 @@
package mdatastore
import (
"context"
"cloud.google.com/go/datastore"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mdb"
@ -17,38 +19,39 @@ type Datastore struct {
*datastore.Client
gce *mdb.GCE
log *mlog.Logger
ctx context.Context
}
// MNew returns a Datastore instance which will be initialized and configured
// when the start event is triggered on ctx (see mrun.Start). The Datastore
// instance will have Close called on it when the stop event is triggered on ctx
// (see mrun.Stop).
// when the start event is triggered on the returned Context (see mrun.Start).
// The Datastore instance will have Close called on it when the stop event is
// triggered on the returned Context (see mrun.Stop).
//
// gce is optional and can be passed in if there's an existing gce object which
// should be used, otherwise a new one will be created with mdb.MGCE.
func MNew(ctx mctx.Context, gce *mdb.GCE) *Datastore {
func MNew(ctx context.Context, gce *mdb.GCE) (context.Context, *Datastore) {
if gce == nil {
gce = mdb.MGCE(ctx, "")
ctx, gce = mdb.MGCE(ctx, "")
}
ctx = mctx.ChildOf(ctx, "datastore")
ds := &Datastore{
gce: gce,
log: mlog.From(ctx),
ctx: mctx.NewChild(ctx, "datastore"),
}
ds.log.SetKV(ds)
mrun.OnStart(ctx, func(innerCtx mctx.Context) error {
ds.log.Info("connecting to datastore")
// TODO the equivalent functionality as here will be added with annotations
// ds.log.SetKV(ds)
ds.ctx = mrun.OnStart(ds.ctx, func(innerCtx context.Context) error {
mlog.Info(ds.ctx, "connecting to datastore")
var err error
ds.Client, err = datastore.NewClient(innerCtx, ds.gce.Project, ds.gce.ClientOptions()...)
return merr.WithKV(err, ds.KV())
})
mrun.OnStop(ctx, func(mctx.Context) error {
ds.ctx = mrun.OnStop(ds.ctx, func(context.Context) error {
return ds.Client.Close()
})
return ds
return mctx.WithChild(ctx, ds.ctx), ds
}
// KV implements the mlog.KVer interface.

View File

@ -12,8 +12,8 @@ import (
// Requires datastore emulator to be running
func TestBasic(t *T) {
ctx := mtest.NewCtx()
mtest.SetEnv(ctx, "GCE_PROJECT", "test")
ds := MNew(ctx, nil)
ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "test")
ctx, ds := MNew(ctx, nil)
mtest.Run(ctx, t, func() {
name := mrand.Hex(8)
key := datastore.NameKey("testKind", name, nil)

View File

@ -3,6 +3,8 @@
package mdb
import (
"context"
"github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mlog"
@ -18,27 +20,28 @@ type GCE struct {
}
// MGCE returns a GCE instance which will be initialized and configured when the
// start event is triggered on ctx (see mrun.Start). defaultProject is used as
// the default value for the mcfg parameter this function creates.
func MGCE(ctx mctx.Context, defaultProject string) *GCE {
ctx = mctx.ChildOf(ctx, "gce")
credFile := mcfg.String(ctx, "cred-file", "", "Path to GCE credientials JSON file, if any")
// start event is triggered on the returned Context (see mrun.Start).
// defaultProject is used as the default value for the mcfg parameter this
// function creates.
func MGCE(parent context.Context, defaultProject string) (context.Context, *GCE) {
ctx := mctx.NewChild(parent, "gce")
ctx, credFile := mcfg.String(ctx, "cred-file", "", "Path to GCE credientials JSON file, if any")
var project *string
const projectUsage = "Name of GCE project to use"
if defaultProject == "" {
project = mcfg.RequiredString(ctx, "project", projectUsage)
ctx, project = mcfg.RequiredString(ctx, "project", projectUsage)
} else {
project = mcfg.String(ctx, "project", defaultProject, projectUsage)
ctx, project = mcfg.String(ctx, "project", defaultProject, projectUsage)
}
var gce GCE
mrun.OnStart(ctx, func(mctx.Context) error {
ctx = mrun.OnStart(ctx, func(context.Context) error {
gce.Project = *project
gce.CredFile = *credFile
return nil
})
return &gce
return mctx.WithChild(parent, ctx), &gce
}
// ClientOptions generates and returns the ClientOption instances which can be

View File

@ -19,7 +19,6 @@ import (
"google.golang.org/grpc/status"
)
// TODO this package still uses context.Context in the callback functions
// TODO Consume (and probably BatchConsume) don't properly handle the Client
// being closed.
@ -39,38 +38,39 @@ type PubSub struct {
*pubsub.Client
gce *mdb.GCE
log *mlog.Logger
ctx context.Context
}
// MNew returns a PubSub instance which will be initialized and configured when
// the start event is triggered on ctx (see mrun.Start). The PubSub instance
// will have Close called on it when the stop event is triggered on ctx (see
// mrun.Stop).
// the start event is triggered on the returned Context (see mrun.Start). The
// PubSub instance will have Close called on it when the stop event is triggered
// on the returned Context(see mrun.Stop).
//
// gce is optional and can be passed in if there's an existing gce object which
// should be used, otherwise a new one will be created with mdb.MGCE.
func MNew(ctx mctx.Context, gce *mdb.GCE) *PubSub {
func MNew(ctx context.Context, gce *mdb.GCE) (context.Context, *PubSub) {
if gce == nil {
gce = mdb.MGCE(ctx, "")
ctx, gce = mdb.MGCE(ctx, "")
}
ctx = mctx.ChildOf(ctx, "pubsub")
ps := &PubSub{
gce: gce,
log: mlog.From(ctx),
ctx: mctx.NewChild(ctx, "pubsub"),
}
ps.log.SetKV(ps)
mrun.OnStart(ctx, func(innerCtx mctx.Context) error {
ps.log.Info("connecting to pubsub")
// TODO the equivalent functionality as here will be added with annotations
// ps.log.SetKV(ps)
ps.ctx = mrun.OnStart(ps.ctx, func(innerCtx context.Context) error {
mlog.Info(ps.ctx, "connecting to pubsub")
var err error
ps.Client, err = pubsub.NewClient(innerCtx, ps.gce.Project, ps.gce.ClientOptions()...)
return merr.WithKV(err, ps.KV())
})
mrun.OnStop(ctx, func(mctx.Context) error {
ps.ctx = mrun.OnStop(ps.ctx, func(context.Context) error {
return ps.Client.Close()
})
return ps
return mctx.WithChild(ctx, ps.ctx), ps
}
// KV implements the mlog.KVer interface
@ -226,7 +226,7 @@ func (s *Subscription) Consume(ctx context.Context, fn ConsumerFunc, opts Consum
ok, err := fn(context.Context(innerCtx), msg)
if err != nil {
s.topic.ps.log.Warn("error consuming pubsub message", s, merr.KV(err))
mlog.Warn(s.topic.ps.ctx, "error consuming pubsub message", s, merr.KV(err))
}
if ok {
@ -238,7 +238,7 @@ func (s *Subscription) Consume(ctx context.Context, fn ConsumerFunc, opts Consum
if octx.Err() == context.Canceled || err == nil {
return
} else if err != nil {
s.topic.ps.log.Warn("error consuming from pubsub", s, merr.KV(err))
mlog.Warn(s.topic.ps.ctx, "error consuming from pubsub", s, merr.KV(err))
time.Sleep(1 * time.Second)
}
}
@ -331,7 +331,7 @@ func (s *Subscription) BatchConsume(
}
ret, err := fn(thisCtx, msgs)
if err != nil {
s.topic.ps.log.Warn("error consuming pubsub batch messages", s, merr.KV(err))
mlog.Warn(s.topic.ps.ctx, "error consuming pubsub batch messages", s, merr.KV(err))
}
for i := range thisGroup {
thisGroup[i].retCh <- ret // retCh is buffered

View File

@ -14,8 +14,8 @@ import (
// this requires the pubsub emulator to be running
func TestPubSub(t *T) {
ctx := mtest.NewCtx()
mtest.SetEnv(ctx, "GCE_PROJECT", "test")
ps := MNew(ctx, nil)
ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "test")
ctx, ps := MNew(ctx, nil)
mtest.Run(ctx, t, func() {
topicName := "testTopic_" + mrand.Hex(8)
ctx := context.Background()
@ -48,8 +48,8 @@ func TestPubSub(t *T) {
func TestBatchPubSub(t *T) {
ctx := mtest.NewCtx()
mtest.SetEnv(ctx, "GCE_PROJECT", "test")
ps := MNew(ctx, nil)
ctx = mtest.SetEnv(ctx, "GCE_PROJECT", "test")
ctx, ps := MNew(ctx, nil)
mtest.Run(ctx, t, func() {
topicName := "testBatchTopic_" + mrand.Hex(8)

View File

@ -17,27 +17,39 @@ import (
"github.com/mediocregopher/mediocre-go-lib/mrun"
)
// MServer is returned by MListenAndServe and simply wraps an *http.Server.
type MServer struct {
*http.Server
ctx context.Context
}
// MListenAndServe returns an http.Server which will be initialized and have
// ListenAndServe called on it (asynchronously) when the start event is
// triggered on ctx (see mrun.Start). The Server will have Shutdown called on it
// when the stop event is triggered on ctx (see mrun.Stop).
// triggered on the returned Context (see mrun.Start). The Server will have
// Shutdown called on it when the stop event is triggered on the returned
// Context (see mrun.Stop).
//
// This function automatically handles setting up configuration parameters via
// mcfg. The default listen address is ":0".
func MListenAndServe(ctx mctx.Context, h http.Handler) *http.Server {
ctx = mctx.ChildOf(ctx, "http")
listener := mnet.MListen(ctx, "tcp", "")
func MListenAndServe(ctx context.Context, h http.Handler) (context.Context, *MServer) {
srv := &MServer{
Server: &http.Server{Handler: h},
ctx: mctx.NewChild(ctx, "http"),
}
var listener *mnet.MListener
srv.ctx, listener = mnet.MListen(srv.ctx, "tcp", "")
listener.NoCloseOnStop = true // http.Server.Shutdown will do this
logger := mlog.From(ctx)
logger.SetKV(listener)
// TODO the equivalent functionality as here will be added with annotations
//logger := mlog.From(ctx)
//logger.SetKV(listener)
srv := http.Server{Handler: h}
mrun.OnStart(ctx, func(mctx.Context) error {
srv.ctx = mrun.OnStart(srv.ctx, func(context.Context) error {
srv.Addr = listener.Addr().String()
mrun.Thread(ctx, func() error {
srv.ctx = mrun.Thread(srv.ctx, func() error {
if err := srv.Serve(listener); err != http.ErrServerClosed {
logger.Error("error serving listener", merr.KV(err))
mlog.Error(srv.ctx, "error serving listener", merr.KV(err))
return err
}
return nil
@ -45,15 +57,15 @@ func MListenAndServe(ctx mctx.Context, h http.Handler) *http.Server {
return nil
})
mrun.OnStop(ctx, func(innerCtx mctx.Context) error {
logger.Info("shutting down server")
if err := srv.Shutdown(context.Context(innerCtx)); err != nil {
srv.ctx = mrun.OnStop(srv.ctx, func(innerCtx context.Context) error {
mlog.Info(srv.ctx, "shutting down server")
if err := srv.Shutdown(innerCtx); err != nil {
return err
}
return mrun.Wait(ctx, innerCtx.Done())
return mrun.Wait(srv.ctx, innerCtx.Done())
})
return &srv
return mctx.WithChild(ctx, srv.ctx), srv
}
// AddXForwardedFor populates the X-Forwarded-For header on the Request to

View File

@ -15,7 +15,7 @@ import (
func TestMListenAndServe(t *T) {
ctx := mtest.NewCtx()
srv := MListenAndServe(ctx, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx, srv := MListenAndServe(ctx, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
io.Copy(rw, r.Body)
}))

View File

@ -1,87 +1,53 @@
package mlog
import (
"path"
"github.com/mediocregopher/mediocre-go-lib/mctx"
)
import "context"
type ctxKey int
// CtxSet caches the Logger in the Context, overwriting any previous one which
// might have been cached there. From is the corresponding function which
// retrieves the Logger back out when needed.
//
// This function can be used to premptively set a preconfigured Logger on a root
// Context so that the default (NewLogger) isn't used when From is called for
// the first time.
func CtxSet(ctx mctx.Context, l *Logger) {
mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(interface{}) interface{} {
// Set returns the Context with the Logger carried by it.
func Set(ctx context.Context, l *Logger) context.Context {
return context.WithValue(ctx, ctxKey(0), l)
}
// DefaultLogger is an instance of Logger which is returned by From when a
// Logger hasn't been previously Set on the Context passed in.
var DefaultLogger = NewLogger()
// From returns the Logger carried by this Context, or DefaultLogger if none is
// being carried.
func From(ctx context.Context) *Logger {
if l, _ := ctx.Value(ctxKey(0)).(*Logger); l != nil {
return l
})
}
return DefaultLogger
}
// CtxSetAll traverses the given Context's children, breadth-first. It calls the
// callback for each Context which has a Logger set on it, replacing that Logger
// with the returned one.
//
// This is useful, for example, when changing the log level of all Loggers in an
// app.
func CtxSetAll(ctx mctx.Context, callback func(mctx.Context, *Logger) *Logger) {
mctx.BreadthFirstVisit(ctx, func(ctx mctx.Context) bool {
mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(i interface{}) interface{} {
if i == nil {
return nil
}
return callback(ctx, i.(*Logger))
})
return true
})
// Debug is a shortcut for
// mlog.From(ctx).Debug(ctx, descr, kvs...)
func Debug(ctx context.Context, descr string, kvs ...KVer) {
From(ctx).Debug(ctx, descr, kvs...)
}
type ctxPathStringer struct {
str Stringer
pathStr string
// Info is a shortcut for
// mlog.From(ctx).Info(ctx, descr, kvs...)
func Info(ctx context.Context, descr string, kvs ...KVer) {
From(ctx).Info(ctx, descr, kvs...)
}
func (cp ctxPathStringer) String() string {
return "(" + cp.pathStr + ") " + cp.str.String()
// Warn is a shortcut for
// mlog.From(ctx).Warn(ctx, descr, kvs...)
func Warn(ctx context.Context, descr string, kvs ...KVer) {
From(ctx).Warn(ctx, descr, kvs...)
}
// From returns an instance of Logger which has been customized for this
// Context, primarily by adding a prefix describing the Context's path to all
// Message descriptions the Logger receives.
//
// The Context caches within it the generated Logger, so a new one isn't created
// everytime. When From is first called on a Context the Logger inherits the
// Context parent's Logger. If the parent hasn't had From called on it its
// parent will be queried instead, and so on.
func From(ctx mctx.Context) *Logger {
return mctx.GetSetMutableValue(ctx, true, ctxKey(0), func(interface{}) interface{} {
ctxPath := mctx.Path(ctx)
if len(ctxPath) == 0 {
// we're at the root node and it doesn't have a Logger set, use
// the default
return NewLogger()
}
// set up child's logger
pathStr := "/" + path.Join(ctxPath...)
parentL := From(mctx.Parent(ctx))
parentH := parentL.Handler()
thisL := parentL.Clone()
thisL.SetHandler(func(msg Message) error {
// if the Description is already a ctxPathStringer it can be
// assumed this Message was passed in from a child Logger.
if _, ok := msg.Description.(ctxPathStringer); !ok {
msg.Description = ctxPathStringer{
str: msg.Description,
pathStr: pathStr,
}
}
return parentH(msg)
})
return thisL
}).(*Logger)
// Error is a shortcut for
// mlog.From(ctx).Error(ctx, descr, kvs...)
func Error(ctx context.Context, descr string, kvs ...KVer) {
From(ctx).Error(ctx, descr, kvs...)
}
// Fatal is a shortcut for
// mlog.From(ctx).Fatal(ctx, descr, kvs...)
func Fatal(ctx context.Context, descr string, kvs ...KVer) {
From(ctx).Fatal(ctx, descr, kvs...)
}

View File

@ -1,60 +1,67 @@
package mlog
import (
"bytes"
"context"
"strings"
. "testing"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mtest/massert"
)
func TestContextStuff(t *T) {
ctx := mctx.New()
ctx1 := mctx.ChildOf(ctx, "1")
ctx1a := mctx.ChildOf(ctx1, "a")
ctx1b := mctx.ChildOf(ctx1, "b")
func TestContextLogging(t *T) {
var descrs []string
var lines []string
l := NewLogger()
l.SetHandler(func(msg Message) error {
descrs = append(descrs, msg.Description.String())
buf := new(bytes.Buffer)
if err := DefaultFormat(buf, msg); err != nil {
t.Fatal(err)
}
lines = append(lines, strings.TrimSuffix(buf.String(), "\n"))
return nil
})
CtxSet(ctx, l)
From(ctx1a).Info("ctx1a")
From(ctx1).Info("ctx1")
From(ctx).Info("ctx")
From(ctx1b).Debug("ctx1b (shouldn't show up)")
From(ctx1b).Info("ctx1b")
ctx := Set(context.Background(), l)
ctx1 := mctx.NewChild(ctx, "1")
ctx1a := mctx.NewChild(ctx1, "a")
ctx1b := mctx.NewChild(ctx1, "b")
ctx1 = mctx.WithChild(ctx1, ctx1a)
ctx1 = mctx.WithChild(ctx1, ctx1b)
ctx = mctx.WithChild(ctx, ctx1)
ctx2 := mctx.ChildOf(ctx, "2")
From(ctx2).Info("ctx2")
From(ctx).Info(ctx1a, "ctx1a")
From(ctx).Info(ctx1, "ctx1")
From(ctx).Info(ctx, "ctx")
From(ctx).Debug(ctx1b, "ctx1b (shouldn't show up)")
From(ctx).Info(ctx1b, "ctx1b")
ctx2 := mctx.NewChild(ctx, "2")
ctx = mctx.WithChild(ctx, ctx2)
From(ctx2).Info(ctx2, "ctx2")
massert.Fatal(t, massert.All(
massert.Len(descrs, 5),
massert.Equal(descrs[0], "(/1/a) ctx1a"),
massert.Equal(descrs[1], "(/1) ctx1"),
massert.Equal(descrs[2], "ctx"),
massert.Equal(descrs[3], "(/1/b) ctx1b"),
massert.Equal(descrs[4], "(/2) ctx2"),
massert.Len(lines, 5),
massert.Equal(lines[0], "~ INFO -- (/1/a) ctx1a"),
massert.Equal(lines[1], "~ INFO -- (/1) ctx1"),
massert.Equal(lines[2], "~ INFO -- ctx"),
massert.Equal(lines[3], "~ INFO -- (/1/b) ctx1b"),
massert.Equal(lines[4], "~ INFO -- (/2) ctx2"),
))
// use CtxSetAll to change all MaxLevels in-place
ctx2L := From(ctx2)
CtxSetAll(ctx, func(_ mctx.Context, l *Logger) *Logger {
l.SetMaxLevel(DebugLevel)
return l
})
// changing MaxLevel on ctx's Logger should change it for all
From(ctx).SetMaxLevel(DebugLevel)
descrs = descrs[:0]
From(ctx).Info("ctx")
From(ctx).Debug("ctx debug")
ctx2L.Debug("ctx2L debug")
lines = lines[:0]
From(ctx).Info(ctx, "ctx")
From(ctx).Debug(ctx, "ctx debug")
From(ctx2).Debug(ctx2, "ctx2 debug")
massert.Fatal(t, massert.All(
massert.Len(descrs, 3),
massert.Equal(descrs[0], "ctx"),
massert.Equal(descrs[1], "ctx debug"),
massert.Equal(descrs[2], "(/2) ctx2L debug"),
massert.Len(lines, 3),
massert.Equal(lines[0], "~ INFO -- ctx"),
massert.Equal(lines[1], "~ DEBUG -- ctx debug"),
massert.Equal(lines[2], "~ DEBUG -- (/2) ctx2 debug"),
))
}

View File

@ -17,6 +17,7 @@ package mlog
import (
"bufio"
"context"
"fmt"
"io"
"os"
@ -25,6 +26,7 @@ import (
"strings"
"sync"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/merr"
)
@ -204,23 +206,12 @@ func Prefix(kv KVer, prefix string) KVer {
////////////////////////////////////////////////////////////////////////////////
// Stringer generates and returns a string.
type Stringer interface {
String() string
}
// String is simply a string which implements Stringer.
type String string
func (str String) String() string {
return string(str)
}
// Message describes a message to be logged, after having already resolved the
// KVer
type Message struct {
context.Context
Level
Description Stringer
Description string
KVer
}
@ -253,7 +244,11 @@ func DefaultFormat(w io.Writer, msg Message) error {
_, err = fmt.Fprintf(w, s, args...)
}
}
write("~ %s -- %s", msg.Level.String(), msg.Description.String())
write("~ %s -- ", msg.Level.String())
if path := mctx.Path(msg.Context); len(path) > 0 {
write("(%s) ", "/"+strings.Join(path, "/"))
}
write("%s", msg.Description)
if msg.KVer != nil {
if kv := msg.KV(); len(kv) > 0 {
write(" --")
@ -363,7 +358,7 @@ func (l *Logger) Log(msg Message) {
}
if err := l.h(msg); err != nil {
go l.Error("Logger.Handler returned error", merr.KV(err))
go l.Error(context.Background(), "Logger.Handler returned error", merr.KV(err))
return
}
@ -376,36 +371,37 @@ func (l *Logger) Log(msg Message) {
}
}
func mkMsg(lvl Level, descr string, kvs ...KVer) Message {
func mkMsg(ctx context.Context, lvl Level, descr string, kvs ...KVer) Message {
return Message{
Context: ctx,
Level: lvl,
Description: String(descr),
Description: descr,
KVer: Merge(kvs...),
}
}
// Debug logs a DebugLevel message, merging the KVers together first
func (l *Logger) Debug(descr string, kvs ...KVer) {
l.Log(mkMsg(DebugLevel, descr, kvs...))
func (l *Logger) Debug(ctx context.Context, descr string, kvs ...KVer) {
l.Log(mkMsg(ctx, DebugLevel, descr, kvs...))
}
// Info logs a InfoLevel message, merging the KVers together first
func (l *Logger) Info(descr string, kvs ...KVer) {
l.Log(mkMsg(InfoLevel, descr, kvs...))
func (l *Logger) Info(ctx context.Context, descr string, kvs ...KVer) {
l.Log(mkMsg(ctx, InfoLevel, descr, kvs...))
}
// Warn logs a WarnLevel message, merging the KVers together first
func (l *Logger) Warn(descr string, kvs ...KVer) {
l.Log(mkMsg(WarnLevel, descr, kvs...))
func (l *Logger) Warn(ctx context.Context, descr string, kvs ...KVer) {
l.Log(mkMsg(ctx, WarnLevel, descr, kvs...))
}
// Error logs a ErrorLevel message, merging the KVers together first
func (l *Logger) Error(descr string, kvs ...KVer) {
l.Log(mkMsg(ErrorLevel, descr, kvs...))
func (l *Logger) Error(ctx context.Context, descr string, kvs ...KVer) {
l.Log(mkMsg(ctx, ErrorLevel, descr, kvs...))
}
// Fatal logs a FatalLevel message, merging the KVers together first. A Fatal
// message automatically stops the process with an os.Exit(1)
func (l *Logger) Fatal(descr string, kvs ...KVer) {
l.Log(mkMsg(FatalLevel, descr, kvs...))
func (l *Logger) Fatal(ctx context.Context, descr string, kvs ...KVer) {
l.Log(mkMsg(ctx, FatalLevel, descr, kvs...))
}

View File

@ -2,6 +2,7 @@ package mlog
import (
"bytes"
"context"
"regexp"
"strings"
. "testing"
@ -36,6 +37,7 @@ func TestKV(t *T) {
}
func TestLogger(t *T) {
ctx := context.Background()
buf := new(bytes.Buffer)
h := func(msg Message) error {
return DefaultFormat(buf, msg)
@ -59,10 +61,10 @@ func TestLogger(t *T) {
}
// Default max level should be INFO
l.Debug("foo")
l.Info("bar")
l.Warn("baz")
l.Error("buz")
l.Debug(ctx, "foo")
l.Info(ctx, "bar")
l.Warn(ctx, "baz")
l.Error(ctx, "buz")
massert.Fatal(t, massert.All(
assertOut("~ INFO -- bar\n"),
assertOut("~ WARN -- baz\n"),
@ -70,10 +72,10 @@ func TestLogger(t *T) {
))
l.SetMaxLevel(WarnLevel)
l.Debug("foo")
l.Info("bar")
l.Warn("baz")
l.Error("buz", KV{"a": "b"})
l.Debug(ctx, "foo")
l.Info(ctx, "bar")
l.Warn(ctx, "baz")
l.Error(ctx, "buz", KV{"a": "b"})
massert.Fatal(t, massert.All(
assertOut("~ WARN -- baz\n"),
assertOut("~ ERROR -- buz -- a=\"b\"\n"),
@ -82,12 +84,12 @@ func TestLogger(t *T) {
l2 := l.Clone()
l2.SetMaxLevel(InfoLevel)
l2.SetHandler(func(msg Message) error {
msg.Description = String(strings.ToUpper(msg.Description.String()))
msg.Description = strings.ToUpper(msg.Description)
return h(msg)
})
l2.Info("bar")
l2.Warn("baz")
l.Error("buz")
l2.Info(ctx, "bar")
l2.Warn(ctx, "baz")
l.Error(ctx, "buz")
massert.Fatal(t, massert.All(
assertOut("~ INFO -- BAR\n"),
assertOut("~ WARN -- BAZ\n"),
@ -96,8 +98,8 @@ func TestLogger(t *T) {
l3 := l2.Clone()
l3.SetKV(KV{"a": 1})
l3.Info("foo", KV{"b": 2})
l3.Info("bar", KV{"a": 2, "b": 3})
l3.Info(ctx, "foo", KV{"b": 2})
l3.Info(ctx, "bar", KV{"a": 2, "b": 3})
massert.Fatal(t, massert.All(
assertOut("~ INFO -- FOO -- a=\"1\" b=\"2\"\n"),
assertOut("~ INFO -- BAR -- a=\"2\" b=\"3\"\n"),
@ -121,7 +123,11 @@ func TestDefaultFormat(t *T) {
)
}
msg := Message{Level: InfoLevel, Description: String("this is a test")}
msg := Message{
Context: context.Background(),
Level: InfoLevel,
Description: "this is a test",
}
massert.Fatal(t, assertFormat("INFO -- this is a test", msg))
msg.KVer = KV{}

View File

@ -3,11 +3,11 @@
package mnet
import (
"context"
"net"
"strings"
"github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/merr"
"github.com/mediocregopher/mediocre-go-lib/mlog"
"github.com/mediocregopher/mediocre-go-lib/mrun"
@ -16,7 +16,7 @@ import (
// MListener is returned by MListen and simply wraps a net.Listener.
type MListener struct {
net.Listener
log *mlog.Logger
ctx context.Context
// If set to true before mrun's stop event is run, the stop event will not
// cause the MListener to be closed.
@ -24,44 +24,47 @@ type MListener struct {
}
// MListen returns an MListener which will be initialized when the start event
// is triggered on ctx (see mrun.Start), and closed when the stop event is
// triggered on ctx (see mrun.Stop).
// is triggered on the returned Context (see mrun.Start), and closed when the
// stop event is triggered on the returned Context (see mrun.Stop).
//
// network defaults to "tcp" if empty. defaultAddr defaults to ":0" if empty,
// and will be configurable via mcfg.
func MListen(ctx mctx.Context, network, defaultAddr string) *MListener {
func MListen(ctx context.Context, network, defaultAddr string) (context.Context, *MListener) {
if network == "" {
network = "tcp"
}
if defaultAddr == "" {
defaultAddr = ":0"
}
addr := mcfg.String(ctx, "listen-addr", defaultAddr, strings.ToUpper(network)+" address to listen on in format [host]:port. If port is 0 then a random one will be chosen")
l := new(MListener)
l.log = mlog.From(ctx)
l.log.SetKV(l)
mrun.OnStart(ctx, func(mctx.Context) error {
// TODO the equivalent functionality as here will be added with annotations
//l.log = mlog.From(ctx)
//l.log.SetKV(l)
ctx, addr := mcfg.String(ctx, "listen-addr", defaultAddr, strings.ToUpper(network)+" address to listen on in format [host]:port. If port is 0 then a random one will be chosen")
ctx = mrun.OnStart(ctx, func(context.Context) error {
var err error
if l.Listener, err = net.Listen(network, *addr); err != nil {
return err
}
l.log.Info("listening")
mlog.Info(l.ctx, "listening")
return nil
})
// TODO track connections and wait for them to complete before shutting
// down?
mrun.OnStop(ctx, func(mctx.Context) error {
ctx = mrun.OnStop(ctx, func(context.Context) error {
if l.NoCloseOnStop {
return nil
}
l.log.Info("stopping listener")
mlog.Info(l.ctx, "stopping listener")
return l.Close()
})
return l
l.ctx = ctx
return ctx, l
}
// Accept wraps a call to Accept on the underlying net.Listener, providing debug
@ -71,16 +74,16 @@ func (l *MListener) Accept() (net.Conn, error) {
if err != nil {
return conn, err
}
l.log.Debug("connection accepted", mlog.KV{"remoteAddr": conn.RemoteAddr()})
mlog.Debug(l.ctx, "connection accepted", mlog.KV{"remoteAddr": conn.RemoteAddr()})
return conn, nil
}
// Close wraps a call to Close on the underlying net.Listener, providing debug
// logging.
func (l *MListener) Close() error {
l.log.Debug("listener closing")
mlog.Debug(l.ctx, "listener closing")
err := l.Listener.Close()
l.log.Debug("listener closed", merr.KV(err))
mlog.Debug(l.ctx, "listener closed", merr.KV(err))
return err
}

View File

@ -37,7 +37,7 @@ func TestIsReservedIP(t *T) {
func TestMListen(t *T) {
ctx := mtest.NewCtx()
l := MListen(ctx, "", "")
ctx, l := MListen(ctx, "", "")
mtest.Run(ctx, t, func() {
go func() {
conn, err := net.Dial("tcp", l.Addr().String())

View File

@ -1,87 +1,143 @@
package mrun
import "github.com/mediocregopher/mediocre-go-lib/mctx"
import (
"context"
type ctxEventKeyWrap struct {
key interface{}
}
"github.com/mediocregopher/mediocre-go-lib/mctx"
)
// Hook describes a function which can be registered to trigger on an event via
// the RegisterHook function.
type Hook func(mctx.Context) error
type Hook func(context.Context) error
type ctxKey int
const (
ctxKeyHookEls ctxKey = iota
ctxKeyNumChildren
ctxKeyNumHooks
)
type ctxKeyWrap struct {
key ctxKey
userKey interface{}
}
// because we want Hooks to be called in the order created, taking into account
// the creation of children and their hooks as well, we create a sequence of
// elements which can either be a Hook or a child.
type hookEl struct {
hook Hook
child context.Context
}
func ctxKeys(userKey interface{}) (ctxKeyWrap, ctxKeyWrap, ctxKeyWrap) {
return ctxKeyWrap{
key: ctxKeyHookEls,
userKey: userKey,
}, ctxKeyWrap{
key: ctxKeyNumChildren,
userKey: userKey,
}, ctxKeyWrap{
key: ctxKeyNumHooks,
userKey: userKey,
}
}
// getHookEls retrieves a copy of the []hookEl in the Context and possibly
// appends more elements if more children have been added since that []hookEl
// was created.
//
// this also returns the latest numChildren and numHooks values for convenience.
func getHookEls(ctx context.Context, userKey interface{}) ([]hookEl, int, int) {
hookElsKey, numChildrenKey, numHooksKey := ctxKeys(userKey)
lastNumChildren, _ := mctx.LocalValue(ctx, numChildrenKey).(int)
lastNumHooks, _ := mctx.LocalValue(ctx, numHooksKey).(int)
lastHookEls, _ := mctx.LocalValue(ctx, hookElsKey).([]hookEl)
children := mctx.Children(ctx)
// plus 1 in case we wanna append something else outside this function
hookEls := make([]hookEl, len(lastHookEls), lastNumHooks+len(children)-lastNumChildren+1)
copy(hookEls, lastHookEls)
for _, child := range children[lastNumChildren:] {
hookEls = append(hookEls, hookEl{child: child})
}
return hookEls, len(children), lastNumHooks
}
// RegisterHook registers a Hook under a typed key. The Hook will be called when
// TriggerHooks is called with that same key. Multiple Hooks can be registered
// for the same key, and will be called sequentially when triggered.
//
// RegisterHook registers Hooks onto the root of the given Context. Therefore,
// Hooks will be triggered in the global order they were registered. For
// example: if one Hook is registered on a Context, then one is registered on a
// child of that Context, then another one is registered on the original Context
// again, the three Hooks will be triggered in the order: parent, child,
// parent.
//
// Hooks will be called with whatever Context is passed into TriggerHooks.
func RegisterHook(ctx mctx.Context, key interface{}, hook Hook) {
ctx = mctx.Root(ctx)
mctx.GetSetMutableValue(ctx, false, ctxEventKeyWrap{key}, func(v interface{}) interface{} {
hooks, _ := v.([]Hook)
return append(hooks, hook)
})
func RegisterHook(ctx context.Context, key interface{}, hook Hook) context.Context {
hookEls, numChildren, numHooks := getHookEls(ctx, key)
hookEls = append(hookEls, hookEl{hook: hook})
hookElsKey, numChildrenKey, numHooksKey := ctxKeys(key)
ctx = mctx.WithLocalValue(ctx, hookElsKey, hookEls)
ctx = mctx.WithLocalValue(ctx, numChildrenKey, numChildren)
ctx = mctx.WithLocalValue(ctx, numHooksKey, numHooks+1)
return ctx
}
func triggerHooks(ctx mctx.Context, key interface{}, next func([]Hook) (Hook, []Hook)) error {
rootCtx := mctx.Root(ctx)
var err error
mctx.GetSetMutableValue(rootCtx, false, ctxEventKeyWrap{key}, func(i interface{}) interface{} {
var hook Hook
hooks, _ := i.([]Hook)
for {
if len(hooks) == 0 {
break
}
hook, hooks = next(hooks)
// err here is the var outside GetSetMutableValue, we lift it out
if err = hook(ctx); err != nil {
break
}
func triggerHooks(ctx context.Context, userKey interface{}, next func([]hookEl) (hookEl, []hookEl)) error {
hookEls, _, _ := getHookEls(ctx, userKey)
var hookEl hookEl
for {
if len(hookEls) == 0 {
break
}
// if there was an error then we want to keep all the hooks which
// weren't called. If there wasn't we want to reset the value to nil so
// the slice doesn't grow unbounded.
if err != nil {
return hooks
hookEl, hookEls = next(hookEls)
if hookEl.child != nil {
if err := triggerHooks(hookEl.child, userKey, next); err != nil {
return err
}
} else if err := hookEl.hook(ctx); err != nil {
return err
}
return nil
})
return err
}
return nil
}
// TriggerHooks causes all Hooks registered with RegisterHook under the given
// key to be called in the global order they were registered, using the given
// Context as their input parameter. The given Context does not need to be the
// root Context (see RegisterHook).
// TriggerHooks causes all Hooks registered with RegisterHook on the Context
// (and its predecessors) under the given key to be called in the order they
// were registered.
//
// If any Hook returns an error no further Hooks will be called and that error
// will be returned.
//
// TriggerHooks causes all Hooks which were called to be de-registered. If an
// error caused execution to stop prematurely then any Hooks which were not
// called will remain registered.
func TriggerHooks(ctx mctx.Context, key interface{}) error {
return triggerHooks(ctx, key, func(hooks []Hook) (Hook, []Hook) {
return hooks[0], hooks[1:]
// If the Context has children (see the mctx package), and those children have
// Hooks registered under this key, then their Hooks will be called in the
// expected order. For example:
//
// // parent context has hookA registered
// ctx := context.Background()
// ctx = RegisterHook(ctx, 0, hookA)
//
// // child context has hookB registered
// childCtx := mctx.NewChild(ctx, "child")
// childCtx = RegisterHook(childCtx, 0, hookB)
// ctx = mctx.WithChild(ctx, childCtx) // needed to link childCtx to ctx
//
// // parent context has another Hook, hookC, registered
// ctx = RegisterHook(ctx, 0, hookC)
//
// // The Hooks will be triggered in the order: hookA, hookB, then hookC
// err := TriggerHooks(ctx, 0)
//
func TriggerHooks(ctx context.Context, key interface{}) error {
return triggerHooks(ctx, key, func(hookEls []hookEl) (hookEl, []hookEl) {
return hookEls[0], hookEls[1:]
})
}
// TriggerHooksReverse is the same as TriggerHooks except that registered Hooks
// are called in the reverse order in which they were registered.
func TriggerHooksReverse(ctx mctx.Context, key interface{}) error {
return triggerHooks(ctx, key, func(hooks []Hook) (Hook, []Hook) {
last := len(hooks) - 1
return hooks[last], hooks[:last]
func TriggerHooksReverse(ctx context.Context, key interface{}) error {
return triggerHooks(ctx, key, func(hookEls []hookEl) (hookEl, []hookEl) {
last := len(hookEls) - 1
return hookEls[last], hookEls[:last]
})
}
@ -101,24 +157,24 @@ const (
// server) will want to use the Hook only to initialize, and spawn off a
// go-routine to do their actual work. Long-lived tasks should set themselves up
// to stop on the stop event (see OnStop).
func OnStart(ctx mctx.Context, hook Hook) {
RegisterHook(ctx, start, hook)
func OnStart(ctx context.Context, hook Hook) context.Context {
return RegisterHook(ctx, start, hook)
}
// Start runs all Hooks registered using OnStart. This is a special case of
// TriggerHooks.
func Start(ctx mctx.Context) error {
func Start(ctx context.Context) error {
return TriggerHooks(ctx, start)
}
// OnStop registers the given Hook to run when Stop is called. This is a special
// case of RegisterHook.
func OnStop(ctx mctx.Context, hook Hook) {
RegisterHook(ctx, stop, hook)
func OnStop(ctx context.Context, hook Hook) context.Context {
return RegisterHook(ctx, stop, hook)
}
// Stop runs all Hooks registered using OnStop in the reverse order in which
// they were registered. This is a special case of TriggerHooks.
func Stop(ctx mctx.Context) error {
func Stop(ctx context.Context) error {
return TriggerHooksReverse(ctx, stop)
}

View File

@ -1,7 +1,7 @@
package mrun
import (
"errors"
"context"
. "testing"
"github.com/mediocregopher/mediocre-go-lib/mctx"
@ -9,46 +9,40 @@ import (
)
func TestHooks(t *T) {
ch := make(chan int, 10)
ctx := mctx.New()
ctxChild := mctx.ChildOf(ctx, "child")
var out []int
mkHook := func(i int) Hook {
return func(mctx.Context) error {
ch <- i
return func(context.Context) error {
out = append(out, i)
return nil
}
}
RegisterHook(ctx, 0, mkHook(0))
RegisterHook(ctxChild, 0, mkHook(1))
RegisterHook(ctx, 0, mkHook(2))
ctx := context.Background()
ctx = RegisterHook(ctx, 0, mkHook(1))
ctx = RegisterHook(ctx, 0, mkHook(2))
bogusErr := errors.New("bogus error")
RegisterHook(ctxChild, 0, func(mctx.Context) error { return bogusErr })
ctxA := mctx.NewChild(ctx, "a")
ctxA = RegisterHook(ctxA, 0, mkHook(3))
ctxA = RegisterHook(ctxA, 999, mkHook(999)) // different key
ctx = mctx.WithChild(ctx, ctxA)
RegisterHook(ctx, 0, mkHook(3))
RegisterHook(ctx, 0, mkHook(4))
ctx = RegisterHook(ctx, 0, mkHook(4))
massert.Fatal(t, massert.All(
massert.Equal(bogusErr, TriggerHooks(ctx, 0)),
massert.Equal(0, <-ch),
massert.Equal(1, <-ch),
massert.Equal(2, <-ch),
))
// after the error the 3 and 4 Hooks should still be registered, but not
// called yet.
select {
case <-ch:
t.Fatal("Hooks should not have been called yet")
default:
}
ctxB := mctx.NewChild(ctx, "b")
ctxB = RegisterHook(ctxB, 0, mkHook(5))
ctxB1 := mctx.NewChild(ctxB, "1")
ctxB1 = RegisterHook(ctxB1, 0, mkHook(6))
ctxB = mctx.WithChild(ctxB, ctxB1)
ctx = mctx.WithChild(ctx, ctxB)
massert.Fatal(t, massert.All(
massert.Nil(TriggerHooks(ctx, 0)),
massert.Equal(3, <-ch),
massert.Equal(4, <-ch),
massert.Equal([]int{1, 2, 3, 4, 5, 6}, out),
))
out = nil
massert.Fatal(t, massert.All(
massert.Nil(TriggerHooksReverse(ctx, 0)),
massert.Equal([]int{6, 5, 4, 3, 2, 1}, out),
))
}

View File

@ -3,6 +3,7 @@
package mrun
import (
"context"
"errors"
"github.com/mediocregopher/mediocre-go-lib/mctx"
@ -33,26 +34,24 @@ func (fe *futureErr) set(err error) {
close(fe.doneCh)
}
type ctxKey int
type threadCtxKey int
// Thread spawns a go-routine which executes the given function. When the passed
// in Context is canceled the Context within all threads spawned from it will
// be canceled as well.
//
// See Wait for accompanying functionality.
func Thread(ctx mctx.Context, fn func() error) {
// Thread spawns a go-routine which executes the given function. The returned
// Context tracks this go-routine, which can then be passed into the Wait
// function to block until the spawned go-routine returns.
func Thread(ctx context.Context, fn func() error) context.Context {
futErr := newFutureErr()
mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(i interface{}) interface{} {
futErrs, ok := i.([]*futureErr)
if !ok {
futErrs = make([]*futureErr, 0, 1)
}
return append(futErrs, futErr)
})
oldFutErrs, _ := ctx.Value(threadCtxKey(0)).([]*futureErr)
futErrs := make([]*futureErr, len(oldFutErrs), len(oldFutErrs)+1)
copy(futErrs, oldFutErrs)
futErrs = append(futErrs, futErr)
ctx = context.WithValue(ctx, threadCtxKey(0), futErrs)
go func() {
futErr.set(fn())
}()
return ctx
}
// ErrDone is returned from Wait if cancelCh is closed before all threads have
@ -60,7 +59,7 @@ func Thread(ctx mctx.Context, fn func() error) {
var ErrDone = errors.New("Wait is done waiting")
// Wait blocks until all go-routines spawned using Thread on the passed in
// Context, and all of its children, have returned. Any number of the threads
// Context (and its predecessors) have returned. Any number of the go-routines
// may have returned already when Wait is called.
//
// If any of the thread functions returned an error during its runtime Wait will
@ -71,10 +70,8 @@ var ErrDone = errors.New("Wait is done waiting")
// this function stops waiting and returns ErrDone.
//
// Wait is safe to call in parallel, and will return the same result if called
// multiple times in sequence. If new Thread calls have been made since the last
// Wait call, the results of those calls will be waited upon during subsequent
// Wait calls.
func Wait(ctx mctx.Context, cancelCh <-chan struct{}) error {
// multiple times in sequence.
func Wait(ctx context.Context, cancelCh <-chan struct{}) error {
// First wait for all the children, and see if any of them return an error
children := mctx.Children(ctx)
for _, childCtx := range children {
@ -83,7 +80,7 @@ func Wait(ctx mctx.Context, cancelCh <-chan struct{}) error {
}
}
futErrs, _ := mctx.MutableValue(ctx, ctxKey(0)).([]*futureErr)
futErrs, _ := ctx.Value(threadCtxKey(0)).([]*futureErr)
for _, futErr := range futErrs {
err, ok := futErr.get(cancelCh)
if !ok {

View File

@ -1,6 +1,7 @@
package mrun
import (
"context"
"errors"
. "testing"
"time"
@ -12,11 +13,11 @@ func TestThreadWait(t *T) {
testErr := errors.New("test error")
cancelCh := func(t time.Duration) <-chan struct{} {
tCtx, _ := mctx.WithTimeout(mctx.New(), t*2)
tCtx, _ := context.WithTimeout(context.Background(), t*2)
return tCtx.Done()
}
wait := func(ctx mctx.Context, shouldTake time.Duration) error {
wait := func(ctx context.Context, shouldTake time.Duration) error {
start := time.Now()
err := Wait(ctx, cancelCh(shouldTake*2))
if took := time.Since(start); took < shouldTake || took > shouldTake*4/3 {
@ -28,16 +29,16 @@ func TestThreadWait(t *T) {
t.Run("noChildren", func(t *T) {
t.Run("noBlock", func(t *T) {
t.Run("noErr", func(t *T) {
ctx := mctx.New()
Thread(ctx, func() error { return nil })
ctx := context.Background()
ctx = Thread(ctx, func() error { return nil })
if err := Wait(ctx, nil); err != nil {
t.Fatal(err)
}
})
t.Run("err", func(t *T) {
ctx := mctx.New()
Thread(ctx, func() error { return testErr })
ctx := context.Background()
ctx = Thread(ctx, func() error { return testErr })
if err := Wait(ctx, nil); err != testErr {
t.Fatalf("should have got test error, got: %v", err)
}
@ -46,8 +47,8 @@ func TestThreadWait(t *T) {
t.Run("block", func(t *T) {
t.Run("noErr", func(t *T) {
ctx := mctx.New()
Thread(ctx, func() error {
ctx := context.Background()
ctx = Thread(ctx, func() error {
time.Sleep(1 * time.Second)
return nil
})
@ -57,8 +58,8 @@ func TestThreadWait(t *T) {
})
t.Run("err", func(t *T) {
ctx := mctx.New()
Thread(ctx, func() error {
ctx := context.Background()
ctx = Thread(ctx, func() error {
time.Sleep(1 * time.Second)
return testErr
})
@ -68,8 +69,8 @@ func TestThreadWait(t *T) {
})
t.Run("canceled", func(t *T) {
ctx := mctx.New()
Thread(ctx, func() error {
ctx := context.Background()
ctx = Thread(ctx, func() error {
time.Sleep(5 * time.Second)
return testErr
})
@ -80,16 +81,17 @@ func TestThreadWait(t *T) {
})
})
ctxWithChild := func() (mctx.Context, mctx.Context) {
ctx := mctx.New()
return ctx, mctx.ChildOf(ctx, "child")
ctxWithChild := func() (context.Context, context.Context) {
ctx := context.Background()
return ctx, mctx.NewChild(ctx, "child")
}
t.Run("children", func(t *T) {
t.Run("noBlock", func(t *T) {
t.Run("noErr", func(t *T) {
ctx, childCtx := ctxWithChild()
Thread(childCtx, func() error { return nil })
childCtx = Thread(childCtx, func() error { return nil })
ctx = mctx.WithChild(ctx, childCtx)
if err := Wait(ctx, nil); err != nil {
t.Fatal(err)
}
@ -97,7 +99,8 @@ func TestThreadWait(t *T) {
t.Run("err", func(t *T) {
ctx, childCtx := ctxWithChild()
Thread(childCtx, func() error { return testErr })
childCtx = Thread(childCtx, func() error { return testErr })
ctx = mctx.WithChild(ctx, childCtx)
if err := Wait(ctx, nil); err != testErr {
t.Fatalf("should have got test error, got: %v", err)
}
@ -107,10 +110,11 @@ func TestThreadWait(t *T) {
t.Run("block", func(t *T) {
t.Run("noErr", func(t *T) {
ctx, childCtx := ctxWithChild()
Thread(childCtx, func() error {
childCtx = Thread(childCtx, func() error {
time.Sleep(1 * time.Second)
return nil
})
ctx = mctx.WithChild(ctx, childCtx)
if err := wait(ctx, 1*time.Second); err != nil {
t.Fatal(err)
}
@ -118,10 +122,11 @@ func TestThreadWait(t *T) {
t.Run("err", func(t *T) {
ctx, childCtx := ctxWithChild()
Thread(childCtx, func() error {
childCtx = Thread(childCtx, func() error {
time.Sleep(1 * time.Second)
return testErr
})
ctx = mctx.WithChild(ctx, childCtx)
if err := wait(ctx, 1*time.Second); err != testErr {
t.Fatalf("should have got test error, got: %v", err)
}
@ -129,10 +134,11 @@ func TestThreadWait(t *T) {
t.Run("canceled", func(t *T) {
ctx, childCtx := ctxWithChild()
Thread(childCtx, func() error {
childCtx = Thread(childCtx, func() error {
time.Sleep(5 * time.Second)
return testErr
})
ctx = mctx.WithChild(ctx, childCtx)
if err := Wait(ctx, cancelCh(500*time.Millisecond)); err != ErrDone {
t.Fatalf("should have got ErrDone, got: %v", err)
}

View File

@ -2,35 +2,33 @@
package mtest
import (
"context"
"testing"
"github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/mlog"
"github.com/mediocregopher/mediocre-go-lib/mrun"
)
type ctxKey int
type envCtxKey int
// NewCtx creates and returns a root Context suitable for testing.
func NewCtx() mctx.Context {
ctx := mctx.New()
mlog.From(ctx).SetMaxLevel(mlog.DebugLevel)
return ctx
func NewCtx() context.Context {
ctx := context.Background()
logger := mlog.NewLogger()
logger.SetMaxLevel(mlog.DebugLevel)
return mlog.Set(ctx, logger)
}
// SetEnv sets the given environment variable on the given Context, such that it
// will be used as if it was a real environment variable when the Run function
// from this package is called.
func SetEnv(ctx mctx.Context, key, val string) {
mctx.GetSetMutableValue(ctx, false, ctxKey(0), func(i interface{}) interface{} {
m, _ := i.(map[string]string)
if m == nil {
m = map[string]string{}
}
m[key] = val
return m
})
func SetEnv(ctx context.Context, key, val string) context.Context {
prevEnv, _ := ctx.Value(envCtxKey(0)).([][2]string)
env := make([][2]string, len(prevEnv), len(prevEnv)+1)
copy(env, prevEnv)
env = append(env, [2]string{key, val})
return context.WithValue(ctx, envCtxKey(0), env)
}
// Run performs the following using the given Context:
@ -45,11 +43,11 @@ func SetEnv(ctx mctx.Context, key, val string) {
//
// The intention is that Run is used within a test on a Context created via
// NewCtx, after any setup functions have been called (e.g. mnet.MListen).
func Run(ctx mctx.Context, t *testing.T, body func()) {
envMap, _ := mctx.MutableValue(ctx, ctxKey(0)).(map[string]string)
env := make([]string, 0, len(envMap))
for key, val := range envMap {
env = append(env, key+"="+val)
func Run(ctx context.Context, t *testing.T, body func()) {
envTups, _ := ctx.Value(envCtxKey(0)).([][2]string)
env := make([]string, 0, len(envTups))
for _, tup := range envTups {
env = append(env, tup[0]+"="+tup[1])
}
if err := mcfg.Populate(ctx, mcfg.SourceEnv{Env: env}); err != nil {

View File

@ -8,8 +8,8 @@ import (
func TestRun(t *T) {
ctx := NewCtx()
arg := mcfg.RequiredString(ctx, "arg", "Required by this test")
SetEnv(ctx, "ARG", "foo")
ctx, arg := mcfg.RequiredString(ctx, "arg", "Required by this test")
ctx = SetEnv(ctx, "ARG", "foo")
Run(ctx, t, func() {
if *arg != "foo" {
t.Fatalf(`arg not set to "foo", is set to %q`, *arg)