diff --git a/m/m.go b/m/m.go index ec15973..f456167 100644 --- a/m/m.go +++ b/m/m.go @@ -67,7 +67,11 @@ func ServiceContext() context.Context { // Start performs the work of populating configuration parameters and triggering // the start event. It will return once the Start event has completed running. -func Start(ctx context.Context) { +// +// This function returns a Context because there are cases where the Context +// will be modified during Start, such as if WithSubCommand was used. If the +// Context was not modified then the passed in Context will be returned. +func Start(ctx context.Context) context.Context { src, _ := ctx.Value(cfgSrcKey(0)).(mcfg.Source) if src == nil { mlog.Fatal("ctx not sourced from m package", ctx) @@ -76,12 +80,14 @@ func Start(ctx context.Context) { // no logging should happen before populate, primarily because log-level // hasn't been populated yet, but also because it makes help output on cli // look weird. - if err := mcfg.Populate(ctx, src); err != nil { + ctx, err := mcfg.Populate(ctx, src) + if err != nil { mlog.Fatal("error populating configuration", ctx, merr.Context(err)) } else if err := mrun.Start(ctx); err != nil { mlog.Fatal("error triggering start event", ctx, merr.Context(err)) } mlog.Info("start hooks completed", ctx) + return ctx } // StartWaitStop performs the work of populating configuration parameters, @@ -89,7 +95,7 @@ func Start(ctx context.Context) { // stop event. Run will block until the stop event is done. If any errors are // encountered a fatal is thrown. func StartWaitStop(ctx context.Context) { - Start(ctx) + ctx = Start(ctx) { ch := make(chan os.Signal, 1) diff --git a/m/m_test.go b/m/m_test.go index 0756cb4..e15d508 100644 --- a/m/m_test.go +++ b/m/m_test.go @@ -31,7 +31,7 @@ func TestServiceCtx(t *T) { ctx = mctx.WithChild(ctx, ctxA) params := mcfg.ParamValues{{Name: "log-level", Value: json.RawMessage(`"DEBUG"`)}} - if err := mcfg.Populate(ctx, params); err != nil { + if _, err := mcfg.Populate(ctx, params); err != nil { t.Fatal(err) } else if err := mrun.Start(ctx); err != nil { t.Fatal(err) diff --git a/mcfg/cli.go b/mcfg/cli.go index ae104ea..c1cb9c2 100644 --- a/mcfg/cli.go +++ b/mcfg/cli.go @@ -17,6 +17,7 @@ type cliKey int const ( cliKeyTailPtr cliKey = iota + cliKeySubCmdM ) // WithCLITail returns a Context which modifies the behavior of SourceCLI's @@ -41,6 +42,45 @@ func populateCLITail(ctx context.Context, tail []string) bool { return ok } +type subCmd struct { + name, descr string + flag *bool + callback func(context.Context) context.Context +} + +// WithCLISubCommand establishes a sub-command which can be activated on the +// command-line. When a sub-command is given on the command-line, the bool +// returned for that sub-command will be set to true. +// +// Additionally, the Context which was passed into Parse (i.e. the one passed +// into Populate) will be passed into the given callback, and the returned one +// used for subsequent parsing. This allows for setting sub-command specific +// Params, sub-command specific runtime behavior (via mrun.WithStartHook), +// support for sub-sub-commands, and more. The callback may be nil. +// +// If any sub-commands have been defined on a Context which is passed into +// Parse, it is assumed that a sub-command is required on the command-line. The +// exception is if a sub-command with a name of "" has been defined; if so, it +// will be used as the intended sub-command if none is specified. +// +// Sub-commands must be specified before any other options on the command-line. +func WithCLISubCommand(ctx context.Context, name, descr string, callback func(context.Context) context.Context) (context.Context, *bool) { + m, _ := ctx.Value(cliKeySubCmdM).(map[string]subCmd) + if m == nil { + m = map[string]subCmd{} + ctx = context.WithValue(ctx, cliKeySubCmdM, m) + } + + flag := new(bool) + m[name] = subCmd{ + name: name, + descr: descr, + flag: flag, + callback: callback, + } + return ctx, flag +} + // SourceCLI is a Source which will parse configuration from the CLI. // // Possible CLI options are generated by joining a Param's Path and Name with @@ -75,16 +115,56 @@ const ( ) // Parse implements the method for the Source interface -func (cli *SourceCLI) Parse(ctx context.Context, params []Param) (context.Context, []ParamValue, error) { +func (cli *SourceCLI) Parse(ctx context.Context) (context.Context, []ParamValue, error) { args := cli.Args if cli.Args == nil { args = os.Args[1:] } + return cli.parse(ctx, nil, args) +} - pM, err := cli.cliParams(params) +func (cli *SourceCLI) parse( + ctx context.Context, + subCmdPrefix, args []string, +) ( + context.Context, + []ParamValue, + error, +) { + pM, err := cli.cliParams(CollectParams(ctx)) if err != nil { return nil, nil, err } + subCmdM, _ := ctx.Value(cliKeySubCmdM).(map[string]subCmd) + + printHelpAndExit := func() { + cli.printHelp(os.Stderr, subCmdPrefix, subCmdM, pM) + os.Stderr.Sync() + os.Exit(1) + } + + // if sub-commands were defined on this Context then handle that first. One + // of them should have been given, in which case send the Context through + // the callback to obtain a new one (which presumably has further config + // options the previous didn't) and call parse again. + if len(subCmdM) > 0 { + subCmd, args, ok := cli.getSubCmd(subCmdM, args) + if !ok { + printHelpAndExit() + } + ctx = context.WithValue(ctx, cliKeySubCmdM, nil) + if subCmd.callback != nil { + ctx = subCmd.callback(ctx) + } + if subCmd.name != "" { + subCmdPrefix = append(subCmdPrefix, subCmd.name) + } + *subCmd.flag = true + return cli.parse(ctx, subCmdPrefix, args) + } + + // if sub-commands were not set, then proceed with normal command-line arg + // processing. pvs := make([]ParamValue, 0, len(args)) var ( key string @@ -98,9 +178,7 @@ func (cli *SourceCLI) Parse(ctx context.Context, params []Param) (context.Contex pvStrVal = arg pvStrValOk = true } else if !cli.DisableHelpPage && arg == cliHelpArg { - cli.printHelp(os.Stdout, pM) - os.Stdout.Sync() - os.Exit(1) + printHelpAndExit() } else { for key, p = range pM { if arg == key { @@ -159,6 +237,23 @@ func (cli *SourceCLI) Parse(ctx context.Context, params []Param) (context.Contex return ctx, pvs, nil } +func (cli *SourceCLI) getSubCmd(subCmdM map[string]subCmd, args []string) (subCmd, []string, bool) { + // if a proper sub-command is given then great, return that + if len(args) > 0 { + if subCmd, ok := subCmdM[args[0]]; ok { + return subCmd, args[1:], true + } + } + + // if the empty subCmd is set in the map it means an absent sub-command is + // allowed, check if that's the case + if subCmd, ok := subCmdM[""]; ok { + return subCmd, args, true + } + + return subCmd{}, args, false +} + func (cli *SourceCLI) cliParams(params []Param) (map[string]Param, error) { m := map[string]Param{} for _, p := range params { @@ -168,7 +263,12 @@ func (cli *SourceCLI) cliParams(params []Param) (map[string]Param, error) { return m, nil } -func (cli *SourceCLI) printHelp(w io.Writer, pM map[string]Param) { +func (cli *SourceCLI) printHelp( + w io.Writer, + subCmdPrefix []string, + subCmdM map[string]subCmd, + pM map[string]Param, +) { type pEntry struct { arg string Param @@ -200,24 +300,68 @@ func (cli *SourceCLI) printHelp(w io.Writer, pM map[string]Param) { return fmt.Sprint(val.Interface()) } - for _, p := range pA { - fmt.Fprintf(w, "\n%s", p.arg) - if p.IsBool { - fmt.Fprintf(w, " (Flag)") - } else if p.Required { - fmt.Fprintf(w, " (Required)") - } else if defVal := fmtDefaultVal(p.Into); defVal != "" { - fmt.Fprintf(w, " (Default: %s)", defVal) + type subCmdEntry struct { + name string + subCmd + } + + subCmdA := make([]subCmdEntry, 0, len(subCmdM)) + for name, subCmd := range subCmdM { + if name == "" { + name = "" } - fmt.Fprintf(w, "\n") - if usage := p.Usage; usage != "" { - // make all usages end with a period, because I say so - usage = strings.TrimSpace(usage) - if !strings.HasSuffix(usage, ".") { - usage += "." - } - fmt.Fprintln(w, "\t"+usage) + subCmdA = append(subCmdA, subCmdEntry{name: name, subCmd: subCmd}) + } + + sort.Slice(subCmdA, func(i, j int) bool { + return subCmdA[i].name < subCmdA[j].name + }) + + fmt.Fprintf(w, "Usage: %s", os.Args[0]) + if len(subCmdPrefix) > 0 { + fmt.Fprintf(w, " %s", strings.Join(subCmdPrefix, " ")) + } + if len(subCmdA) > 0 { + if _, ok := subCmdM[""]; ok { + fmt.Fprint(w, " [sub-command]") + } else { + fmt.Fprint(w, " ") + } + } + if len(pA) > 0 { + fmt.Fprint(w, " [options]") + } + fmt.Fprint(w, "\n\n") + + if len(subCmdA) > 0 { + fmt.Fprint(w, "Sub-commands:\n\n") + for _, subCmd := range subCmdA { + fmt.Fprintf(w, "\t%s\t%s\n", subCmd.name, subCmd.descr) + } + fmt.Fprint(w, "\n") + } + + if len(pA) > 0 { + fmt.Fprint(w, "Options:\n\n") + for _, p := range pA { + fmt.Fprintf(w, "\t%s", p.arg) + if p.IsBool { + fmt.Fprintf(w, " (Flag)") + } else if p.Required { + fmt.Fprintf(w, " (Required)") + } else if defVal := fmtDefaultVal(p.Into); defVal != "" { + fmt.Fprintf(w, " (Default: %s)", defVal) + } + fmt.Fprint(w, "\n") + if usage := p.Usage; usage != "" { + // make all usages end with a period, because I say so + usage = strings.TrimSpace(usage) + if !strings.HasSuffix(usage, ".") { + usage += "." + } + fmt.Fprintln(w, "\t\t"+usage) + } + fmt.Fprint(w, "\n") } } - fmt.Fprintf(w, "\n") } diff --git a/mcfg/cli_test.go b/mcfg/cli_test.go index e194743..0bdc79f 100644 --- a/mcfg/cli_test.go +++ b/mcfg/cli_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "regexp" "strings" . "testing" "time" @@ -16,35 +17,122 @@ import ( ) func TestSourceCLIHelp(t *T) { + assertHelp := func(ctx context.Context, subCmdPrefix []string, exp string) { + buf := new(bytes.Buffer) + src := &SourceCLI{} + pM, err := src.cliParams(CollectParams(ctx)) + require.NoError(t, err) + subCmdM, _ := ctx.Value(cliKeySubCmdM).(map[string]subCmd) + src.printHelp(buf, subCmdPrefix, subCmdM, pM) + + out := buf.String() + ok := regexp.MustCompile(exp).MatchString(out) + assert.True(t, ok, "exp:%s (%q)\ngot:%s (%q)", exp, exp, out, out) + } + ctx := context.Background() + assertHelp(ctx, nil, `^Usage: \S+ + +$`) + assertHelp(ctx, []string{"foo", "bar"}, `^Usage: \S+ foo bar + +$`) + ctx, _ = WithInt(ctx, "foo", 5, "Test int param ") // trailing space should be trimmed ctx, _ = WithBool(ctx, "bar", "Test bool param.") ctx, _ = WithString(ctx, "baz", "baz", "Test string param") ctx, _ = WithRequiredString(ctx, "baz2", "") ctx, _ = WithRequiredString(ctx, "baz3", "") - src := SourceCLI{} - buf := new(bytes.Buffer) - pM, err := src.cliParams(collectParams(ctx)) - require.NoError(t, err) - new(SourceCLI).printHelp(buf, pM) + assertHelp(ctx, nil, `^Usage: \S+ \[options\] - exp := ` ---baz2 (Required) +Options: ---baz3 (Required) + --baz2 \(Required\) ---bar (Flag) - Test bool param. + --baz3 \(Required\) ---baz (Default: "baz") - Test string param. + --bar \(Flag\) + Test bool param. ---foo (Default: 5) - Test int param. + --baz \(Default: "baz"\) + Test string param. -` - assert.Equal(t, exp, buf.String()) + --foo \(Default: 5\) + Test int param. + +$`) + + assertHelp(ctx, []string{"foo", "bar"}, `^Usage: \S+ foo bar \[options\] + +Options: + + --baz2 \(Required\) + + --baz3 \(Required\) + + --bar \(Flag\) + Test bool param. + + --baz \(Default: "baz"\) + Test string param. + + --foo \(Default: 5\) + Test int param. + +$`) + + ctx, _ = WithCLISubCommand(ctx, "first", "First sub-command", nil) + ctx, _ = WithCLISubCommand(ctx, "second", "Second sub-command", nil) + assertHelp(ctx, []string{"foo", "bar"}, `^Usage: \S+ foo bar \[options\] + +Sub-commands: + + first First sub-command + second Second sub-command + +Options: + + --baz2 \(Required\) + + --baz3 \(Required\) + + --bar \(Flag\) + Test bool param. + + --baz \(Default: "baz"\) + Test string param. + + --foo \(Default: 5\) + Test int param. + +$`) + + ctx, _ = WithCLISubCommand(ctx, "", "No sub-command", nil) + assertHelp(ctx, []string{"foo", "bar"}, `^Usage: \S+ foo bar \[sub-command\] \[options\] + +Sub-commands: + + No sub-command + first First sub-command + second Second sub-command + +Options: + + --baz2 \(Required\) + + --baz3 \(Required\) + + --bar \(Flag\) + Test bool param. + + --baz \(Default: "baz"\) + Test string param. + + --foo \(Default: 5\) + Test int param. + +$`) } func TestSourceCLI(t *T) { @@ -164,3 +252,91 @@ func ExampleWithCLITail() { fmt.Printf("err:%v foo:%v bar:%v tail:%#v\n", err, *foo, *bar, *tail) // Output: err: foo:100 bar:defaultVal tail:[]string{"BADARG", "--bar", "BAR"} } + +func TestWithCLISubCommand(t *T) { + var ( + ctx context.Context + foo *int + bar *int + baz *int + aFlag *bool + defaultFlag *bool + ) + reset := func() { + foo, bar, baz, aFlag, defaultFlag = nil, nil, nil, nil, nil + ctx = context.Background() + ctx, foo = WithInt(ctx, "foo", 0, "Description of foo.") + ctx, aFlag = WithCLISubCommand(ctx, "a", "Description of a.", + func(ctx context.Context) context.Context { + ctx, bar = WithInt(ctx, "bar", 0, "Description of bar.") + return ctx + }) + ctx, defaultFlag = WithCLISubCommand(ctx, "", "Description of default.", + func(ctx context.Context) context.Context { + ctx, baz = WithInt(ctx, "baz", 0, "Description of baz.") + return ctx + }) + } + + reset() + _, err := Populate(ctx, &SourceCLI{ + Args: []string{"a", "--foo=1", "--bar=2"}, + }) + massert.Require(t, + massert.Comment(massert.Nil(err), "%v", err), + massert.Equal(1, *foo), + massert.Equal(2, *bar), + massert.Nil(baz), + massert.Equal(true, *aFlag), + massert.Equal(false, *defaultFlag), + ) + + reset() + _, err = Populate(ctx, &SourceCLI{ + Args: []string{"--foo=1", "--baz=3"}, + }) + massert.Require(t, + massert.Comment(massert.Nil(err), "%v", err), + massert.Equal(1, *foo), + massert.Nil(bar), + massert.Equal(3, *baz), + massert.Equal(false, *aFlag), + massert.Equal(true, *defaultFlag), + ) +} + +func ExampleWithCLISubCommand() { + ctx := context.Background() + ctx, foo := WithInt(ctx, "foo", 0, "Description of foo.") + + var bar *int + ctx, aFlag := WithCLISubCommand(ctx, "a", "Description of a.", + func(ctx context.Context) context.Context { + ctx, bar = WithInt(ctx, "bar", 0, "Description of bar.") + return ctx + }) + + var baz *int + ctx, defaultFlag := WithCLISubCommand(ctx, "", "Description of default.", + func(ctx context.Context) context.Context { + ctx, baz = WithInt(ctx, "baz", 0, "Description of baz.") + return ctx + }) + + args := []string{"a", "--foo=1", "--bar=2"} + if _, err := Populate(ctx, &SourceCLI{Args: args}); err != nil { + panic(err) + } + fmt.Printf("foo:%d bar:%d aFlag:%v defaultFlag:%v\n", *foo, *bar, *aFlag, *defaultFlag) + + // reset output for another Populate + *aFlag = false + args = []string{"--foo=1", "--baz=3"} + if _, err := Populate(ctx, &SourceCLI{Args: args}); err != nil { + panic(err) + } + fmt.Printf("foo:%d baz:%d aFlag:%v defaultFlag:%v\n", *foo, *baz, *aFlag, *defaultFlag) + + // Output: foo:1 bar:2 aFlag:true defaultFlag:false + // foo:1 baz:3 aFlag:false defaultFlag:true +} diff --git a/mcfg/env.go b/mcfg/env.go index 8864b98..b94aec6 100644 --- a/mcfg/env.go +++ b/mcfg/env.go @@ -43,12 +43,13 @@ func (env *SourceEnv) expectedName(path []string, name string) string { } // Parse implements the method for the Source interface -func (env *SourceEnv) Parse(ctx context.Context, params []Param) (context.Context, []ParamValue, error) { +func (env *SourceEnv) Parse(ctx context.Context) (context.Context, []ParamValue, error) { kvs := env.Env if kvs == nil { kvs = os.Environ() } + params := CollectParams(ctx) pM := map[string]Param{} for _, p := range params { name := env.expectedName(mctx.Path(p.Context), p.Name) diff --git a/mcfg/mcfg.go b/mcfg/mcfg.go index 79e63a6..ba570e1 100644 --- a/mcfg/mcfg.go +++ b/mcfg/mcfg.go @@ -18,6 +18,12 @@ import ( // - JSON file // - YAML file +// TODO WithCLISubCommand does not play nice with the expected use-case of +// having CLI params overwrite Env ones. If Env is specified first in the +// Sources slice then it won't know about any extra Params which might get added +// due to a sub-command, but if it's specified second then Env values will +// overwrite CLI ones. + func sortParams(params []Param) { sort.Slice(params, func(i, j int) bool { a, b := params[i], params[j] @@ -39,10 +45,10 @@ func sortParams(params []Param) { }) } -// returns all Params gathered by recursively retrieving them from this Context -// and its children. Returned Params are sorted according to their Path and -// Name. -func collectParams(ctx context.Context) []Param { +// CollectParams returns all Params gathered by recursively retrieving them from +// this Context and its children. Returned Params are sorted according to their +// Path and Name. +func CollectParams(ctx context.Context) []Param { var params []Param var visit func(context.Context) @@ -85,13 +91,18 @@ func paramHash(path []string, name string) string { // Source may be nil to indicate that no configuration is provided. Only default // values will be used, and if any parameters are required this will error. func Populate(ctx context.Context, src Source) (context.Context, error) { - params := collectParams(ctx) if src == nil { src = ParamValues(nil) } - // map Params to their hash, so we can match them to their ParamValues + ctx, pvs, err := src.Parse(ctx) + if err != nil { + return nil, err + } + + // map Params to their hash, so we can match them to their ParamValues. // later. There should not be any duplicates here. + params := CollectParams(ctx) pM := map[string]Param{} for _, p := range params { path := mctx.Path(p.Context) @@ -102,11 +113,6 @@ func Populate(ctx context.Context, src Source) (context.Context, error) { pM[hash] = p } - ctx, pvs, err := src.Parse(ctx, params) - if err != nil { - return nil, err - } - // dedupe the ParamValues based on their hashes, with the last ParamValue // taking precedence. Also filter out those with no corresponding Param. pvM := map[string]ParamValue{} diff --git a/mcfg/source.go b/mcfg/source.go index 29f4266..92642cf 100644 --- a/mcfg/source.go +++ b/mcfg/source.go @@ -13,14 +13,13 @@ type ParamValue struct { Value json.RawMessage } -// Source parses ParamValues out of a particular configuration source, given a -// sorted set of possible Params to parse, and the Context from with the Params -// were extracted. +// Source parses ParamValues out of a particular configuration source, given the +// Context which the Params were added to (via WithInt, WithString, etc...). +// CollectParams can be used to retrieve these Params. // // It's possible for Parsing to affect the Context itself, for example in the // case of sub-commands. For this reason Parse can return a Context, which will -// get used for subsequent Parse commands inside, and then returned from, -// Populate. +// get used for subsequent Parse commands inside Populate. // // Source should not return ParamValues which were not explicitly set to a value // by the configuration source. @@ -30,7 +29,7 @@ type ParamValue struct { // ParamValues which do not correspond to any of the passed in Params. These // will be ignored in Populate. type Source interface { - Parse(context.Context, []Param) (context.Context, []ParamValue, error) + Parse(context.Context) (context.Context, []ParamValue, error) } // ParamValues is simply a slice of ParamValue elements, which implements Parse @@ -38,7 +37,7 @@ type Source interface { type ParamValues []ParamValue // Parse implements the method for the Source interface. -func (pvs ParamValues) Parse(ctx context.Context, _ []Param) (context.Context, []ParamValue, error) { +func (pvs ParamValues) Parse(ctx context.Context) (context.Context, []ParamValue, error) { return ctx, pvs, nil } @@ -48,12 +47,12 @@ func (pvs ParamValues) Parse(ctx context.Context, _ []Param) (context.Context, [ type Sources []Source // Parse implements the method for the Source interface. -func (ss Sources) Parse(ctx context.Context, params []Param) (context.Context, []ParamValue, error) { +func (ss Sources) Parse(ctx context.Context) (context.Context, []ParamValue, error) { var pvs []ParamValue for _, s := range ss { var innerPVs []ParamValue var err error - if ctx, innerPVs, err = s.Parse(ctx, params); err != nil { + if ctx, innerPVs, err = s.Parse(ctx); err != nil { return nil, nil, err } pvs = append(pvs, innerPVs...) diff --git a/mcfg/source_test.go b/mcfg/source_test.go index 2dafeca..e72483b 100644 --- a/mcfg/source_test.go +++ b/mcfg/source_test.go @@ -143,7 +143,7 @@ func (scs srcCommonState) applyCtxAndPV(p srcCommonParams) srcCommonState { // ParamValues func (scs srcCommonState) assert(s Source) error { root := scs.mkRoot() - _, gotPVs, err := s.Parse(root, collectParams(root)) + _, gotPVs, err := s.Parse(root) if err != nil { return err } diff --git a/mtest/mtest.go b/mtest/mtest.go index 6eea8d1..216abed 100644 --- a/mtest/mtest.go +++ b/mtest/mtest.go @@ -50,7 +50,8 @@ func Run(ctx context.Context, t *testing.T, body func()) { env = append(env, tup[0]+"="+tup[1]) } - if err := mcfg.Populate(ctx, &mcfg.SourceEnv{Env: env}); err != nil { + ctx, err := mcfg.Populate(ctx, &mcfg.SourceEnv{Env: env}) + if err != nil { t.Fatal(err) } else if err := mrun.Start(ctx); err != nil { t.Fatal(err)