mcfg: Parse now takes a Context, which allows for implementing WithCLITail

This commit is contained in:
Brian Picciano 2019-04-04 12:20:40 -04:00
parent a596306d9e
commit c0afa3d126
6 changed files with 71 additions and 43 deletions

View File

@ -13,6 +13,34 @@ import (
"github.com/mediocregopher/mediocre-go-lib/merr" "github.com/mediocregopher/mediocre-go-lib/merr"
) )
type cliKey int
const (
cliKeyTailPtr cliKey = iota
)
// WithCLITail returns a Context which modifies the behavior of SourceCLI's
// Parse, if SourceCLI is used with that Context at all. Normally when SourceCLI
// encounters an unexpected Arg it will immediately return an error. This
// function modifies the Context to indicate to Parse that the unexpected Arg,
// and all subsequent Args (i.e. the tail) should be set to the returned
// []string value.
//
// If multiple WithCLITail calls are used then only the latest returned pointer
// will be filled.
func WithCLITail(ctx context.Context) (context.Context, *[]string) {
tailPtr := new([]string)
return context.WithValue(ctx, cliKeyTailPtr, tailPtr), tailPtr
}
func populateCLITail(ctx context.Context, tail []string) bool {
tailPtr, ok := ctx.Value(cliKeyTailPtr).(*[]string)
if ok {
*tailPtr = tail
}
return ok
}
// SourceCLI is a Source which will parse configuration from the CLI. // SourceCLI is a Source which will parse configuration from the CLI.
// //
// Possible CLI options are generated by joining a Param's Path and Name with // Possible CLI options are generated by joining a Param's Path and Name with
@ -37,12 +65,6 @@ type SourceCLI struct {
Args []string // if nil then os.Args[1:] is used Args []string // if nil then os.Args[1:] is used
DisableHelpPage bool DisableHelpPage bool
// Normally if any unexpected Arg value is encountered Parse will error out.
// If instead TailCallback is set then it will be called whenever the first
// unexpected Arg is encountered, and will not error out. TailCallback will
// be given a slice of Args starting at the first unexpected element.
TailCallback func([]string)
} }
const ( const (
@ -53,7 +75,7 @@ const (
) )
// Parse implements the method for the Source interface // Parse implements the method for the Source interface
func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) { func (cli *SourceCLI) Parse(ctx context.Context, params []Param) ([]ParamValue, error) {
args := cli.Args args := cli.Args
if cli.Args == nil { if cli.Args == nil {
args = os.Args[1:] args = os.Args[1:]
@ -96,8 +118,7 @@ func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) {
break break
} }
if !pOk { if !pOk {
if cli.TailCallback != nil { if ok := populateCLITail(ctx, args[i:]); ok {
cli.TailCallback(args[i:])
return pvs, nil return pvs, nil
} }
ctx := mctx.Annotate(context.Background(), "param", arg) ctx := mctx.Annotate(context.Background(), "param", arg)
@ -134,6 +155,7 @@ func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) {
ctx := mctx.Annotate(p.Context, "param", key) ctx := mctx.Annotate(p.Context, "param", key)
return nil, merr.New("param expected a value", ctx) return nil, merr.New("param expected a value", ctx)
} }
return pvs, nil return pvs, nil
} }

View File

@ -3,6 +3,7 @@ package mcfg
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"strings" "strings"
. "testing" . "testing"
"time" "time"
@ -107,18 +108,11 @@ func TestSourceCLI(t *T) {
} }
} }
func TestSourceCLITailCallback(t *T) { func TestWithCLITail(t *T) {
ctx := context.Background() ctx := context.Background()
ctx, _ = WithInt(ctx, "foo", 5, "") ctx, _ = WithInt(ctx, "foo", 5, "")
ctx, _ = WithBool(ctx, "bar", "") ctx, _ = WithBool(ctx, "bar", "")
var tail []string
src := &SourceCLI{
TailCallback: func(gotTail []string) {
tail = gotTail
},
}
type testCase struct { type testCase struct {
args []string args []string
expTail []string expTail []string
@ -127,7 +121,7 @@ func TestSourceCLITailCallback(t *T) {
cases := []testCase{ cases := []testCase{
{ {
args: []string{"--foo", "5"}, args: []string{"--foo", "5"},
expTail: []string{}, expTail: nil,
}, },
{ {
args: []string{"--foo", "5", "a", "b", "c"}, args: []string{"--foo", "5", "a", "b", "c"},
@ -139,7 +133,7 @@ func TestSourceCLITailCallback(t *T) {
}, },
{ {
args: []string{"--foo", "5", "--bar"}, args: []string{"--foo", "5", "--bar"},
expTail: []string{}, expTail: nil,
}, },
{ {
args: []string{"--foo", "5", "--bar", "a", "b", "c"}, args: []string{"--foo", "5", "--bar", "a", "b", "c"},
@ -148,12 +142,25 @@ func TestSourceCLITailCallback(t *T) {
} }
for _, tc := range cases { for _, tc := range cases {
tail = []string{} ctx, tail := WithCLITail(ctx)
src.Args = tc.args err := Populate(ctx, &SourceCLI{Args: tc.args})
err := Populate(ctx, src)
massert.Require(t, massert.Comment(massert.All( massert.Require(t, massert.Comment(massert.All(
massert.Nil(err), massert.Nil(err),
massert.Equal(tc.expTail, tail), massert.Equal(tc.expTail, *tail),
), "tc: %#v", tc)) ), "tc: %#v", tc))
} }
} }
func ExampleWithCLITail() {
ctx := context.Background()
ctx, foo := WithInt(ctx, "foo", 1, "Description of foo.")
ctx, tail := WithCLITail(ctx)
ctx, bar := WithString(ctx, "bar", "defaultVal", "Description of bar.")
err := Populate(ctx, &SourceCLI{
Args: []string{"--foo=100", "BADARG", "--bar", "BAR"},
})
fmt.Printf("err:%v foo:%v bar:%v tail:%#v\n", err, *foo, *bar, *tail)
// Output: err:<nil> foo:100 bar:defaultVal tail:[]string{"BADARG", "--bar", "BAR"}
}

View File

@ -43,7 +43,7 @@ func (env *SourceEnv) expectedName(path []string, name string) string {
} }
// Parse implements the method for the Source interface // Parse implements the method for the Source interface
func (env *SourceEnv) Parse(params []Param) ([]ParamValue, error) { func (env *SourceEnv) Parse(ctx context.Context, params []Param) ([]ParamValue, error) {
kvs := env.Env kvs := env.Env
if kvs == nil { if kvs == nil {
kvs = os.Environ() kvs = os.Environ()

View File

@ -72,7 +72,15 @@ func paramHash(path []string, name string) string {
return paramFullName(path, name) + "/" + hStr return paramFullName(path, name) + "/" + hStr
} }
func populate(params []Param, src Source) error { // Populate uses the Source to populate the values of all Params which were
// added to the given Context, and all of its children. Populate may be called
// multiple times with the same Context, each time will only affect the values
// of the Params which were provided by the respective Source.
//
// Source may be nil to indicate that no configuration is provided. Only default
// values will be used, and if any parameters are required this will error.
func Populate(ctx context.Context, src Source) error {
params := collectParams(ctx)
if src == nil { if src == nil {
src = ParamValues(nil) src = ParamValues(nil)
} }
@ -89,7 +97,7 @@ func populate(params []Param, src Source) error {
pM[hash] = p pM[hash] = p
} }
pvs, err := src.Parse(params) pvs, err := src.Parse(ctx, params)
if err != nil { if err != nil {
return err return err
} }
@ -127,14 +135,3 @@ func populate(params []Param, src Source) error {
return nil return nil
} }
// Populate uses the Source to populate the values of all Params which were
// added to the given Context, and all of its children. Populate may be called
// multiple times with the same Context, each time will only affect the values
// of the Params which were provided by the respective Source.
//
// Source may be nil to indicate that no configuration is provided. Only default
// values will be used, and if any parameters are required this will error.
func Populate(ctx context.Context, src Source) error {
return populate(collectParams(ctx), src)
}

View File

@ -1,6 +1,7 @@
package mcfg package mcfg
import ( import (
"context"
"encoding/json" "encoding/json"
) )
@ -13,7 +14,8 @@ type ParamValue struct {
} }
// Source parses ParamValues out of a particular configuration source, given a // Source parses ParamValues out of a particular configuration source, given a
// sorted set of possible Params to parse. // sorted set of possible Params to parse, and the Context from with the Params
// were extracted.
// //
// Source should not return ParamValues which were not explicitly set to a value // Source should not return ParamValues which were not explicitly set to a value
// by the configuration source. // by the configuration source.
@ -23,7 +25,7 @@ type ParamValue struct {
// ParamValues which do not correspond to any of the passed in Params. These // ParamValues which do not correspond to any of the passed in Params. These
// will be ignored in Populate. // will be ignored in Populate.
type Source interface { type Source interface {
Parse([]Param) ([]ParamValue, error) Parse(context.Context, []Param) ([]ParamValue, error)
} }
// ParamValues is simply a slice of ParamValue elements, which implements Parse // ParamValues is simply a slice of ParamValue elements, which implements Parse
@ -31,7 +33,7 @@ type Source interface {
type ParamValues []ParamValue type ParamValues []ParamValue
// Parse implements the method for the Source interface. // Parse implements the method for the Source interface.
func (pvs ParamValues) Parse([]Param) ([]ParamValue, error) { func (pvs ParamValues) Parse(context.Context, []Param) ([]ParamValue, error) {
return pvs, nil return pvs, nil
} }
@ -41,10 +43,10 @@ func (pvs ParamValues) Parse([]Param) ([]ParamValue, error) {
type Sources []Source type Sources []Source
// Parse implements the method for the Source interface. // Parse implements the method for the Source interface.
func (ss Sources) Parse(params []Param) ([]ParamValue, error) { func (ss Sources) Parse(ctx context.Context, params []Param) ([]ParamValue, error) {
var pvs []ParamValue var pvs []ParamValue
for _, s := range ss { for _, s := range ss {
innerPVs, err := s.Parse(params) innerPVs, err := s.Parse(ctx, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -143,7 +143,7 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState {
// ParamValues // ParamValues
func (scs srcCommonState) assert(s Source) error { func (scs srcCommonState) assert(s Source) error {
root := scs.mkRoot() root := scs.mkRoot()
gotPVs, err := s.Parse(collectParams(root)) gotPVs, err := s.Parse(root, collectParams(root))
if err != nil { if err != nil {
return err return err
} }