mcfg: Parse now takes a Context, which allows for implementing WithCLITail

This commit is contained in:
Brian Picciano 2019-04-04 12:20:40 -04:00
parent a596306d9e
commit c0afa3d126
6 changed files with 71 additions and 43 deletions

View File

@ -13,6 +13,34 @@ import (
"github.com/mediocregopher/mediocre-go-lib/merr"
)
type cliKey int
const (
cliKeyTailPtr cliKey = iota
)
// WithCLITail returns a Context which modifies the behavior of SourceCLI's
// Parse, if SourceCLI is used with that Context at all. Normally when SourceCLI
// encounters an unexpected Arg it will immediately return an error. This
// function modifies the Context to indicate to Parse that the unexpected Arg,
// and all subsequent Args (i.e. the tail) should be set to the returned
// []string value.
//
// If multiple WithCLITail calls are used then only the latest returned pointer
// will be filled.
func WithCLITail(ctx context.Context) (context.Context, *[]string) {
tailPtr := new([]string)
return context.WithValue(ctx, cliKeyTailPtr, tailPtr), tailPtr
}
func populateCLITail(ctx context.Context, tail []string) bool {
tailPtr, ok := ctx.Value(cliKeyTailPtr).(*[]string)
if ok {
*tailPtr = tail
}
return ok
}
// SourceCLI is a Source which will parse configuration from the CLI.
//
// Possible CLI options are generated by joining a Param's Path and Name with
@ -37,12 +65,6 @@ type SourceCLI struct {
Args []string // if nil then os.Args[1:] is used
DisableHelpPage bool
// Normally if any unexpected Arg value is encountered Parse will error out.
// If instead TailCallback is set then it will be called whenever the first
// unexpected Arg is encountered, and will not error out. TailCallback will
// be given a slice of Args starting at the first unexpected element.
TailCallback func([]string)
}
const (
@ -53,7 +75,7 @@ const (
)
// Parse implements the method for the Source interface
func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) {
func (cli *SourceCLI) Parse(ctx context.Context, params []Param) ([]ParamValue, error) {
args := cli.Args
if cli.Args == nil {
args = os.Args[1:]
@ -96,8 +118,7 @@ func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) {
break
}
if !pOk {
if cli.TailCallback != nil {
cli.TailCallback(args[i:])
if ok := populateCLITail(ctx, args[i:]); ok {
return pvs, nil
}
ctx := mctx.Annotate(context.Background(), "param", arg)
@ -134,6 +155,7 @@ func (cli *SourceCLI) Parse(params []Param) ([]ParamValue, error) {
ctx := mctx.Annotate(p.Context, "param", key)
return nil, merr.New("param expected a value", ctx)
}
return pvs, nil
}

View File

@ -3,6 +3,7 @@ package mcfg
import (
"bytes"
"context"
"fmt"
"strings"
. "testing"
"time"
@ -107,18 +108,11 @@ func TestSourceCLI(t *T) {
}
}
func TestSourceCLITailCallback(t *T) {
func TestWithCLITail(t *T) {
ctx := context.Background()
ctx, _ = WithInt(ctx, "foo", 5, "")
ctx, _ = WithBool(ctx, "bar", "")
var tail []string
src := &SourceCLI{
TailCallback: func(gotTail []string) {
tail = gotTail
},
}
type testCase struct {
args []string
expTail []string
@ -127,7 +121,7 @@ func TestSourceCLITailCallback(t *T) {
cases := []testCase{
{
args: []string{"--foo", "5"},
expTail: []string{},
expTail: nil,
},
{
args: []string{"--foo", "5", "a", "b", "c"},
@ -139,7 +133,7 @@ func TestSourceCLITailCallback(t *T) {
},
{
args: []string{"--foo", "5", "--bar"},
expTail: []string{},
expTail: nil,
},
{
args: []string{"--foo", "5", "--bar", "a", "b", "c"},
@ -148,12 +142,25 @@ func TestSourceCLITailCallback(t *T) {
}
for _, tc := range cases {
tail = []string{}
src.Args = tc.args
err := Populate(ctx, src)
ctx, tail := WithCLITail(ctx)
err := Populate(ctx, &SourceCLI{Args: tc.args})
massert.Require(t, massert.Comment(massert.All(
massert.Nil(err),
massert.Equal(tc.expTail, tail),
massert.Equal(tc.expTail, *tail),
), "tc: %#v", tc))
}
}
func ExampleWithCLITail() {
ctx := context.Background()
ctx, foo := WithInt(ctx, "foo", 1, "Description of foo.")
ctx, tail := WithCLITail(ctx)
ctx, bar := WithString(ctx, "bar", "defaultVal", "Description of bar.")
err := Populate(ctx, &SourceCLI{
Args: []string{"--foo=100", "BADARG", "--bar", "BAR"},
})
fmt.Printf("err:%v foo:%v bar:%v tail:%#v\n", err, *foo, *bar, *tail)
// Output: err:<nil> foo:100 bar:defaultVal tail:[]string{"BADARG", "--bar", "BAR"}
}

View File

@ -43,7 +43,7 @@ func (env *SourceEnv) expectedName(path []string, name string) string {
}
// Parse implements the method for the Source interface
func (env *SourceEnv) Parse(params []Param) ([]ParamValue, error) {
func (env *SourceEnv) Parse(ctx context.Context, params []Param) ([]ParamValue, error) {
kvs := env.Env
if kvs == nil {
kvs = os.Environ()

View File

@ -72,7 +72,15 @@ func paramHash(path []string, name string) string {
return paramFullName(path, name) + "/" + hStr
}
func populate(params []Param, src Source) error {
// Populate uses the Source to populate the values of all Params which were
// added to the given Context, and all of its children. Populate may be called
// multiple times with the same Context, each time will only affect the values
// of the Params which were provided by the respective Source.
//
// Source may be nil to indicate that no configuration is provided. Only default
// values will be used, and if any parameters are required this will error.
func Populate(ctx context.Context, src Source) error {
params := collectParams(ctx)
if src == nil {
src = ParamValues(nil)
}
@ -89,7 +97,7 @@ func populate(params []Param, src Source) error {
pM[hash] = p
}
pvs, err := src.Parse(params)
pvs, err := src.Parse(ctx, params)
if err != nil {
return err
}
@ -127,14 +135,3 @@ func populate(params []Param, src Source) error {
return nil
}
// Populate uses the Source to populate the values of all Params which were
// added to the given Context, and all of its children. Populate may be called
// multiple times with the same Context, each time will only affect the values
// of the Params which were provided by the respective Source.
//
// Source may be nil to indicate that no configuration is provided. Only default
// values will be used, and if any parameters are required this will error.
func Populate(ctx context.Context, src Source) error {
return populate(collectParams(ctx), src)
}

View File

@ -1,6 +1,7 @@
package mcfg
import (
"context"
"encoding/json"
)
@ -13,7 +14,8 @@ type ParamValue struct {
}
// Source parses ParamValues out of a particular configuration source, given a
// sorted set of possible Params to parse.
// sorted set of possible Params to parse, and the Context from with the Params
// were extracted.
//
// Source should not return ParamValues which were not explicitly set to a value
// by the configuration source.
@ -23,7 +25,7 @@ type ParamValue struct {
// ParamValues which do not correspond to any of the passed in Params. These
// will be ignored in Populate.
type Source interface {
Parse([]Param) ([]ParamValue, error)
Parse(context.Context, []Param) ([]ParamValue, error)
}
// ParamValues is simply a slice of ParamValue elements, which implements Parse
@ -31,7 +33,7 @@ type Source interface {
type ParamValues []ParamValue
// Parse implements the method for the Source interface.
func (pvs ParamValues) Parse([]Param) ([]ParamValue, error) {
func (pvs ParamValues) Parse(context.Context, []Param) ([]ParamValue, error) {
return pvs, nil
}
@ -41,10 +43,10 @@ func (pvs ParamValues) Parse([]Param) ([]ParamValue, error) {
type Sources []Source
// Parse implements the method for the Source interface.
func (ss Sources) Parse(params []Param) ([]ParamValue, error) {
func (ss Sources) Parse(ctx context.Context, params []Param) ([]ParamValue, error) {
var pvs []ParamValue
for _, s := range ss {
innerPVs, err := s.Parse(params)
innerPVs, err := s.Parse(ctx, params)
if err != nil {
return nil, err
}

View File

@ -143,7 +143,7 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState {
// ParamValues
func (scs srcCommonState) assert(s Source) error {
root := scs.mkRoot()
gotPVs, err := s.Parse(collectParams(root))
gotPVs, err := s.Parse(root, collectParams(root))
if err != nil {
return err
}