diff --git a/mcfg/cli.go b/mcfg/cli.go index 01302a0..61d94f3 100644 --- a/mcfg/cli.go +++ b/mcfg/cli.go @@ -13,6 +13,34 @@ import ( "github.com/mediocregopher/mediocre-go-lib/merr" ) +type cliKey int + +const ( + cliKeyTailPtr cliKey = iota +) + +// WithCLITail returns a Context which modifies the behavior of SourceCLI's +// Parse, if SourceCLI is used with that Context at all. Normally when SourceCLI +// encounters an unexpected Arg it will immediately return an error. This +// function modifies the Context to indicate to Parse that the unexpected Arg, +// and all subsequent Args (i.e. the tail) should be set to the returned +// []string value. +// +// If multiple WithCLITail calls are used then only the latest returned pointer +// will be filled. +func WithCLITail(ctx context.Context) (context.Context, *[]string) { + tailPtr := new([]string) + return context.WithValue(ctx, cliKeyTailPtr, tailPtr), tailPtr +} + +func populateCLITail(ctx context.Context, tail []string) bool { + tailPtr, ok := ctx.Value(cliKeyTailPtr).(*[]string) + if ok { + *tailPtr = tail + } + return ok +} + // SourceCLI is a Source which will parse configuration from the CLI. // // Possible CLI options are generated by joining a Param's Path and Name with @@ -37,12 +65,6 @@ type SourceCLI struct { Args []string // if nil then os.Args[1:] is used DisableHelpPage bool - - // Normally if any unexpected Arg value is encountered Parse will error out. - // If instead TailCallback is set then it will be called whenever the first - // unexpected Arg is encountered, and will not error out. TailCallback will - // be given a slice of Args starting at the first unexpected element. - TailCallback func([]string) } const ( @@ -53,7 +75,7 @@ const ( ) // Parse implements the method for the Source interface -func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) { +func (cli *SourceCLI) Parse(ctx context.Context, params []Param) ([]ParamValue, error) { args := cli.Args if cli.Args == nil { args = os.Args[1:] @@ -96,8 +118,7 @@ func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) { break } if !pOk { - if cli.TailCallback != nil { - cli.TailCallback(args[i:]) + if ok := populateCLITail(ctx, args[i:]); ok { return pvs, nil } ctx := mctx.Annotate(context.Background(), "param", arg) @@ -134,6 +155,7 @@ func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) { ctx := mctx.Annotate(p.Context, "param", key) return nil, merr.New("param expected a value", ctx) } + return pvs, nil } diff --git a/mcfg/cli_test.go b/mcfg/cli_test.go index eec1e69..3b63982 100644 --- a/mcfg/cli_test.go +++ b/mcfg/cli_test.go @@ -3,6 +3,7 @@ package mcfg import ( "bytes" "context" + "fmt" "strings" . "testing" "time" @@ -107,18 +108,11 @@ func TestSourceCLI(t *T) { } } -func TestSourceCLITailCallback(t *T) { +func TestWithCLITail(t *T) { ctx := context.Background() ctx, _ = WithInt(ctx, "foo", 5, "") ctx, _ = WithBool(ctx, "bar", "") - var tail []string - src := &SourceCLI{ - TailCallback: func(gotTail []string) { - tail = gotTail - }, - } - type testCase struct { args []string expTail []string @@ -127,7 +121,7 @@ func TestSourceCLITailCallback(t *T) { cases := []testCase{ { args: []string{"--foo", "5"}, - expTail: []string{}, + expTail: nil, }, { args: []string{"--foo", "5", "a", "b", "c"}, @@ -139,7 +133,7 @@ func TestSourceCLITailCallback(t *T) { }, { args: []string{"--foo", "5", "--bar"}, - expTail: []string{}, + expTail: nil, }, { args: []string{"--foo", "5", "--bar", "a", "b", "c"}, @@ -148,12 +142,25 @@ func TestSourceCLITailCallback(t *T) { } for _, tc := range cases { - tail = []string{} - src.Args = tc.args - err := Populate(ctx, src) + ctx, tail := WithCLITail(ctx) + err := Populate(ctx, &SourceCLI{Args: tc.args}) massert.Require(t, massert.Comment(massert.All( massert.Nil(err), - massert.Equal(tc.expTail, tail), + massert.Equal(tc.expTail, *tail), ), "tc: %#v", tc)) } } + +func ExampleWithCLITail() { + ctx := context.Background() + ctx, foo := WithInt(ctx, "foo", 1, "Description of foo.") + ctx, tail := WithCLITail(ctx) + ctx, bar := WithString(ctx, "bar", "defaultVal", "Description of bar.") + + err := Populate(ctx, &SourceCLI{ + Args: []string{"--foo=100", "BADARG", "--bar", "BAR"}, + }) + + fmt.Printf("err:%v foo:%v bar:%v tail:%#v\n", err, *foo, *bar, *tail) + // Output: err: foo:100 bar:defaultVal tail:[]string{"BADARG", "--bar", "BAR"} +} diff --git a/mcfg/env.go b/mcfg/env.go index 5b12240..47f472f 100644 --- a/mcfg/env.go +++ b/mcfg/env.go @@ -43,7 +43,7 @@ func (env *SourceEnv) expectedName(path []string, name string) string { } // Parse implements the method for the Source interface -func (env *SourceEnv) Parse(params []Param) ([]ParamValue, error) { +func (env *SourceEnv) Parse(ctx context.Context, params []Param) ([]ParamValue, error) { kvs := env.Env if kvs == nil { kvs = os.Environ() diff --git a/mcfg/mcfg.go b/mcfg/mcfg.go index 1c0ef5a..b582482 100644 --- a/mcfg/mcfg.go +++ b/mcfg/mcfg.go @@ -72,7 +72,15 @@ func paramHash(path []string, name string) string { return paramFullName(path, name) + "/" + hStr } -func populate(params []Param, src Source) error { +// Populate uses the Source to populate the values of all Params which were +// added to the given Context, and all of its children. Populate may be called +// multiple times with the same Context, each time will only affect the values +// of the Params which were provided by the respective Source. +// +// Source may be nil to indicate that no configuration is provided. Only default +// values will be used, and if any parameters are required this will error. +func Populate(ctx context.Context, src Source) error { + params := collectParams(ctx) if src == nil { src = ParamValues(nil) } @@ -89,7 +97,7 @@ func populate(params []Param, src Source) error { pM[hash] = p } - pvs, err := src.Parse(params) + pvs, err := src.Parse(ctx, params) if err != nil { return err } @@ -127,14 +135,3 @@ func populate(params []Param, src Source) error { return nil } - -// Populate uses the Source to populate the values of all Params which were -// added to the given Context, and all of its children. Populate may be called -// multiple times with the same Context, each time will only affect the values -// of the Params which were provided by the respective Source. -// -// Source may be nil to indicate that no configuration is provided. Only default -// values will be used, and if any parameters are required this will error. -func Populate(ctx context.Context, src Source) error { - return populate(collectParams(ctx), src) -} diff --git a/mcfg/source.go b/mcfg/source.go index 3bb661d..c3f9cc4 100644 --- a/mcfg/source.go +++ b/mcfg/source.go @@ -1,6 +1,7 @@ package mcfg import ( + "context" "encoding/json" ) @@ -13,7 +14,8 @@ type ParamValue struct { } // Source parses ParamValues out of a particular configuration source, given a -// sorted set of possible Params to parse. +// sorted set of possible Params to parse, and the Context from with the Params +// were extracted. // // Source should not return ParamValues which were not explicitly set to a value // by the configuration source. @@ -23,7 +25,7 @@ type ParamValue struct { // ParamValues which do not correspond to any of the passed in Params. These // will be ignored in Populate. type Source interface { - Parse([]Param) ([]ParamValue, error) + Parse(context.Context, []Param) ([]ParamValue, error) } // ParamValues is simply a slice of ParamValue elements, which implements Parse @@ -31,7 +33,7 @@ type Source interface { type ParamValues []ParamValue // Parse implements the method for the Source interface. -func (pvs ParamValues) Parse([]Param) ([]ParamValue, error) { +func (pvs ParamValues) Parse(context.Context, []Param) ([]ParamValue, error) { return pvs, nil } @@ -41,10 +43,10 @@ func (pvs ParamValues) Parse([]Param) ([]ParamValue, error) { type Sources []Source // Parse implements the method for the Source interface. -func (ss Sources) Parse(params []Param) ([]ParamValue, error) { +func (ss Sources) Parse(ctx context.Context, params []Param) ([]ParamValue, error) { var pvs []ParamValue for _, s := range ss { - innerPVs, err := s.Parse(params) + innerPVs, err := s.Parse(ctx, params) if err != nil { return nil, err } diff --git a/mcfg/source_test.go b/mcfg/source_test.go index a084de1..67b1dc6 100644 --- a/mcfg/source_test.go +++ b/mcfg/source_test.go @@ -143,7 +143,7 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState { // ParamValues func (scs srcCommonState) assert(s Source) error { root := scs.mkRoot() - gotPVs, err := s.Parse(collectParams(root)) + gotPVs, err := s.Parse(root, collectParams(root)) if err != nil { return err }