package main import ( "context" "errors" "fmt" "io" "isle/bootstrap" "isle/daemon" "isle/daemon/jsonrpc2" "isle/jsonutil" "isle/toolkit" "os" "strings" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" "github.com/spf13/pflag" "gopkg.in/yaml.v3" ) type subCmd struct { name string descr string do func(subCmdCtx) error // If set then the name will be allowed to be suffixed with this string. plural string } func (c subCmd) fullName() string { name := c.name if c.plural != "" { name += "(" + c.plural + ")" } return name } type subCmdCtxOpts struct { args []string // command-line arguments, excluding the subCmd itself. subCmdNames []string // names of subCmds so far, including this one stdout io.Writer changeStager *changeStager daemonRPC daemon.RPC // defaults to bootstrap.NewCreationParams bootstrapNewCreationParams func(name, domain string) bootstrap.CreationParams } func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts { if o == nil { o = new(subCmdCtxOpts) } if o.args == nil { o.args = os.Args[1:] } if o.stdout == nil { o.stdout = os.Stdout } if o.bootstrapNewCreationParams == nil { o.bootstrapNewCreationParams = bootstrap.NewCreationParams } return o } // subCmdCtx contains all information available to a subCmd's do method. type subCmdCtx struct { context.Context logger *mlog.Logger subCmd subCmd // the subCmd itself opts *subCmdCtxOpts flags *pflag.FlagSet } func newSubCmdCtx( ctx context.Context, logger *mlog.Logger, subCmd subCmd, opts *subCmdCtxOpts, ) subCmdCtx { opts = opts.withDefaults() return subCmdCtx{ Context: ctx, logger: logger, subCmd: subCmd, opts: opts, flags: pflag.NewFlagSet(subCmd.name, pflag.ExitOnError), } } type daemonRPCCloser struct { daemon.RPC Close func() error } func (ctx subCmdCtx) newDaemonRPC() (*daemonRPCCloser, error) { if ctx.opts.daemonRPC != nil { return &daemonRPCCloser{ ctx.opts.daemonRPC, func() error { return nil }, }, nil } socketPath := httpSocketPath() if stat, err := os.Stat(socketPath); err != nil { return nil, fmt.Errorf("checking http socket file: %w", err) } else if stat.Mode().Type() != os.ModeSocket { return nil, fmt.Errorf("%q exists but is not a socket", socketPath) } httpClient, baseURL := toolkit.NewUnixHTTPClient( ctx.logger.WithNamespace("http-client"), socketPath, ) baseURL.Path = daemonHTTPRPCPath daemonRPC := daemon.RPCFromClient( jsonrpc2.NewHTTPClient(httpClient, baseURL.String()), ) return &daemonRPCCloser{ daemonRPC, func() error { return httpClient.Close() }, }, nil } func (ctx subCmdCtx) getChangeStager() *changeStager { if ctx.opts.changeStager != nil { return ctx.opts.changeStager } return &changeStager{envCacheDir()} } func usagePrefix(subCmdNames []string) string { subCmdNamesStr := strings.Join(subCmdNames, " ") if subCmdNamesStr != "" { subCmdNamesStr += " " } return fmt.Sprintf("USAGE:\n %s %s", os.Args[0], subCmdNamesStr) } type withParsedFlagsOpts struct { // noNetwork, if true, means the call doesn't require a network to be // specified on the command-line if there are more than one networks // configured. noNetwork bool // Extra arguments on the command-line will be passed through to some // underlying command. passthroughArgs bool } func (o *withParsedFlagsOpts) withDefaults() *withParsedFlagsOpts { if o == nil { o = new(withParsedFlagsOpts) } return o } func (ctx subCmdCtx) withParsedFlags( opts *withParsedFlagsOpts, ) ( subCmdCtx, error, ) { opts = opts.withDefaults() logLevel := logLevelFlag{mlog.LevelInfo} ctx.flags.VarP( &logLevel, "log-level", "l", "Maximum log level to output. Can be DEBUG, CHILD, INFO, WARN, ERROR, or FATAL.", ) var network string if !opts.noNetwork { ctx.flags.StringVar( &network, "network", "", "Which network to perform the command against, if more than one is joined. Can be an ID, name, or domain.", ) } ctx.flags.VisitAll(func(f *pflag.Flag) { if f.Shorthand == "h" { panic(fmt.Sprintf("flag %+v has reserved shorthand `-h`", f)) } if f.Name == "help" { panic(fmt.Sprintf("flag %+v has reserved name `--help`", f)) } }) ctx.flags.Usage = func() { if ctx.subCmd.descr != "" { fmt.Fprintf( os.Stderr, "\nDESCRIPTION:\n %s\n\n", ctx.subCmd.descr, ) } var passthroughStr string if opts.passthroughArgs { passthroughStr = " [--] [args...]" } fmt.Fprintf( os.Stderr, "%s[-h|--help] [%s flags...]%s\n\n", usagePrefix(ctx.opts.subCmdNames), ctx.subCmd.name, passthroughStr, ) fmt.Fprintf( os.Stderr, "%s FLAGS:\n", strings.ToUpper(ctx.subCmd.name), ) fmt.Fprintln(os.Stderr, ctx.flags.FlagUsages()) os.Stderr.Sync() os.Exit(2) } if err := ctx.flags.Parse(ctx.opts.args); err != nil { return ctx, err } ctx.Context = daemon.WithNetwork(ctx.Context, network) ctx.logger = ctx.logger.WithMaxLevel(logLevel.Int()) return ctx, nil } func (ctx subCmdCtx) doSubCmd(subCmds ...subCmd) error { printUsageExit := func(subCmdName string) { fmt.Fprintf(os.Stderr, "unknown sub-command %q\n\n", subCmdName) if ctx.subCmd.descr != "" { fmt.Fprintf( os.Stderr, "DESCRIPTION:\n %s\n\n", ctx.subCmd.descr, ) } fmt.Fprintf( os.Stderr, "%s [-h|--help] [sub-command flags...]\n\n", usagePrefix(ctx.opts.subCmdNames), ) fmt.Fprintf(os.Stderr, "SUB-COMMANDS:\n") var maxNameLen int for _, subCmd := range subCmds { l := len(subCmd.fullName()) if l > maxNameLen { maxNameLen = l } } for _, subCmd := range subCmds { var ( name = subCmd.fullName() padding = strings.Repeat(" ", maxNameLen-len(name)+3) ) fmt.Fprintf( os.Stderr, " %s%s%s\n", name, padding, subCmd.descr, ) } fmt.Fprintf(os.Stderr, "\n") os.Stderr.Sync() os.Exit(2) } args := ctx.opts.args if len(args) == 0 { printUsageExit("") } subCmdsMap := map[string]subCmd{} for _, subCmd := range subCmds { subCmdsMap[subCmd.name] = subCmd if subCmd.plural != "" { subCmdsMap[subCmd.name+subCmd.plural] = subCmd } } subCmdName, args := args[0], args[1:] subCmd, ok := subCmdsMap[subCmdName] if !ok { printUsageExit(subCmdName) } nextSubCmdCtxOpts := *ctx.opts nextSubCmdCtxOpts.args = args nextSubCmdCtxOpts.subCmdNames = append(ctx.opts.subCmdNames, subCmdName) nextSubCmdCtx := newSubCmdCtx( ctx.Context, ctx.logger, subCmd, &nextSubCmdCtxOpts, ) if err := subCmd.do(nextSubCmdCtx); err != nil { return err } return nil } type outputFormat string func (f outputFormat) MarshalText() ([]byte, error) { return []byte(f), nil } func (f *outputFormat) UnmarshalText(b []byte) error { *f = outputFormat(strings.ToLower(string(b))) switch *f { case "json", "yaml": return nil default: return errors.New("invalid output format") } } // doWithOutput wraps a subCmd's do function so that it will output some value // to stdout. The value will be formatted according to a command-line argument. func doWithOutput(fn func(subCmdCtx) (any, error)) func(subCmdCtx) error { return func(ctx subCmdCtx) error { type outputFormatFlag = textUnmarshalerFlag[outputFormat, *outputFormat] outputFormat := outputFormatFlag{"yaml"} ctx.flags.Var( &outputFormat, "format", "How to format the output value. Can be 'json' or 'yaml'.", ) res, err := fn(ctx) if err != nil { return err } switch outputFormat.V { case "json": return jsonutil.WriteIndented(ctx.opts.stdout, res) case "yaml": return yaml.NewEncoder(ctx.opts.stdout).Encode(res) default: panic(fmt.Sprintf("unexpected outputFormat %q", outputFormat)) } } }