isle/go/cmd/entrypoint/sub_cmd.go

356 lines
7.5 KiB
Go
Raw Permalink Normal View History

package main
import (
2022-10-26 22:37:03 +00:00
"context"
"errors"
"fmt"
2024-11-14 20:49:35 +00:00
"io"
"isle/bootstrap"
"isle/daemon"
"isle/daemon/jsonrpc2"
"isle/jsonutil"
"isle/toolkit"
"os"
"strings"
2024-06-22 15:49:56 +00:00
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3"
)
type subCmd struct {
2024-07-12 14:13:44 +00:00
name string
descr string
do func(subCmdCtx) error
2024-07-22 13:52:51 +00:00
// If set then the name will be allowed to be suffixed with this string.
plural string
}
2024-12-10 15:35:14 +00:00
// fullName returns the sub-command's name as displayed in usage output. If a
// plural suffix is configured it is appended in parentheses, e.g. "host(s)".
func (c subCmd) fullName() string {
	if c.plural == "" {
		return c.name
	}
	return c.name + "(" + c.plural + ")"
}
2024-11-14 20:49:35 +00:00
// subCmdCtxOpts holds optional parameters used when constructing a subCmdCtx.
// Unset fields are filled in by withDefaults.
type subCmdCtxOpts struct {
	// Command-line arguments, excluding the subCmd itself. Defaults to
	// os.Args[1:] when nil.
	args []string

	// Names of the subCmds selected so far, including this one.
	subCmdNames []string

	// Destination for command output. Defaults to os.Stdout when nil.
	stdout io.Writer

	// When nil, getChangeStager constructs a changeStager itself.
	changeStager *changeStager

	// When nil, newDaemonRPC connects to the daemon over its HTTP socket.
	daemonRPC daemon.RPC

	// Defaults to bootstrap.NewCreationParams when nil.
	bootstrapNewCreationParams func(name, domain string) bootstrap.CreationParams
}
2024-11-14 20:49:35 +00:00
func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
if o == nil {
o = new(subCmdCtxOpts)
}
2024-11-14 20:49:35 +00:00
if o.args == nil {
o.args = os.Args[1:]
}
2024-11-14 20:49:35 +00:00
if o.stdout == nil {
o.stdout = os.Stdout
}
if o.bootstrapNewCreationParams == nil {
o.bootstrapNewCreationParams = bootstrap.NewCreationParams
}
2024-11-14 20:49:35 +00:00
return o
}
// subCmdCtx contains all information available to a subCmd's do method.
type subCmdCtx struct {
context.Context
logger *mlog.Logger
subCmd subCmd // the subCmd itself
opts *subCmdCtxOpts
2024-11-14 20:49:35 +00:00
flags *pflag.FlagSet
}
// newSubCmdCtx constructs the subCmdCtx handed to subCmd's do function. opts
// may be nil; either way its defaults are applied via withDefaults. A new flag
// set named after the sub-command, using pflag.ExitOnError error handling, is
// created for it.
func newSubCmdCtx(
	ctx context.Context,
	logger *mlog.Logger,
	subCmd subCmd,
	opts *subCmdCtxOpts,
) subCmdCtx {
	flags := pflag.NewFlagSet(subCmd.name, pflag.ExitOnError)
	return subCmdCtx{
		Context: ctx,
		logger:  logger,
		subCmd:  subCmd,
		opts:    opts.withDefaults(),
		flags:   flags,
	}
}
// daemonRPCCloser pairs a daemon.RPC with a function which releases whatever
// resources back it. Callers should invoke Close when finished with the RPC.
type daemonRPCCloser struct {
	daemon.RPC

	// Close releases resources held by the RPC client. It may be a no-op.
	Close func() error
}
// newDaemonRPC returns a client for the daemon's RPC interface.
//
// If ctx.opts.daemonRPC is set it is returned as-is, paired with a no-op
// Close. Otherwise the path returned by httpSocketPath must exist and be a
// unix socket; a JSON-RPC client is built on an HTTP client over that socket
// (targeting daemonHTTPRPCPath), and Close closes the HTTP client.
func (ctx subCmdCtx) newDaemonRPC() (*daemonRPCCloser, error) {
	if rpc := ctx.opts.daemonRPC; rpc != nil {
		return &daemonRPCCloser{
			RPC:   rpc,
			Close: func() error { return nil },
		}, nil
	}

	socketPath := httpSocketPath()
	stat, err := os.Stat(socketPath)
	if err != nil {
		return nil, fmt.Errorf("checking http socket file: %w", err)
	}
	if stat.Mode().Type() != os.ModeSocket {
		return nil, fmt.Errorf("%q exists but is not a socket", socketPath)
	}

	httpClient, baseURL := toolkit.NewUnixHTTPClient(
		ctx.logger.WithNamespace("http-client"), socketPath,
	)
	baseURL.Path = daemonHTTPRPCPath

	rpcClient := jsonrpc2.NewHTTPClient(httpClient, baseURL.String())
	return &daemonRPCCloser{
		RPC:   daemon.RPCFromClient(rpcClient),
		Close: func() error { return httpClient.Close() },
	}, nil
}
2024-12-13 09:56:43 +00:00
// getChangeStager returns the changeStager supplied via ctx.opts, falling back
// to a new one constructed from envCacheDir() when none was supplied.
func (ctx subCmdCtx) getChangeStager() *changeStager {
	if stager := ctx.opts.changeStager; stager != nil {
		return stager
	}
	return &changeStager{envCacheDir()}
}
// usagePrefix returns the start of a usage line: "USAGE:\n <program> ",
// followed by subCmdNames joined with single spaces plus one trailing space
// whenever that joined string is non-empty.
func usagePrefix(subCmdNames []string) string {
	prefix := fmt.Sprintf("USAGE:\n %s ", os.Args[0])
	if joined := strings.Join(subCmdNames, " "); joined != "" {
		prefix += joined + " "
	}
	return prefix
}
// withParsedFlagsOpts holds optional parameters for withParsedFlags. Both the
// zero value and a nil pointer are valid.
type withParsedFlagsOpts struct {
	// noNetwork, if true, means the call doesn't require a network to be
	// specified on the command-line even if more than one network is
	// configured; the --network flag is then not registered at all.
	noNetwork bool

	// passthroughArgs, if true, indicates that extra arguments on the
	// command-line will be passed through to some underlying command.
	passthroughArgs bool
}

// withDefaults returns o unchanged, or a pointer to a zero-valued
// withParsedFlagsOpts when o is nil.
func (o *withParsedFlagsOpts) withDefaults() *withParsedFlagsOpts {
	if o != nil {
		return o
	}
	return &withParsedFlagsOpts{}
}
func (ctx subCmdCtx) withParsedFlags(
opts *withParsedFlagsOpts,
) (
subCmdCtx, error,
) {
opts = opts.withDefaults()
2024-11-14 20:49:35 +00:00
logLevel := logLevelFlag{mlog.LevelInfo}
ctx.flags.VarP(
&logLevel,
"log-level", "l",
"Maximum log level to output. Can be DEBUG, CHILD, INFO, WARN, ERROR, or FATAL.",
)
var network string
if !opts.noNetwork {
2024-11-14 20:49:35 +00:00
ctx.flags.StringVar(
&network,
"network", "",
"Which network to perform the command against, if more than one is joined. Can be an ID, name, or domain.",
)
}
ctx.flags.VisitAll(func(f *pflag.Flag) {
if f.Shorthand == "h" {
panic(fmt.Sprintf("flag %+v has reserved shorthand `-h`", f))
}
if f.Name == "help" {
panic(fmt.Sprintf("flag %+v has reserved name `--help`", f))
}
})
2024-11-14 20:49:35 +00:00
ctx.flags.Usage = func() {
2024-12-10 15:35:14 +00:00
if ctx.subCmd.descr != "" {
fmt.Fprintf(
os.Stderr, "\nDESCRIPTION:\n %s\n\n", ctx.subCmd.descr,
)
}
2024-11-14 20:49:35 +00:00
var passthroughStr string
if opts.passthroughArgs {
2024-11-14 20:49:35 +00:00
passthroughStr = " [--] [args...]"
}
fmt.Fprintf(
os.Stderr, "%s[-h|--help] [%s flags...]%s\n\n",
usagePrefix(ctx.opts.subCmdNames), ctx.subCmd.name, passthroughStr,
)
2024-12-10 15:35:14 +00:00
2024-11-14 20:49:35 +00:00
fmt.Fprintf(
2024-12-10 15:35:14 +00:00
os.Stderr, "%s FLAGS:\n", strings.ToUpper(ctx.subCmd.name),
2024-11-14 20:49:35 +00:00
)
fmt.Fprintln(os.Stderr, ctx.flags.FlagUsages())
os.Stderr.Sync()
os.Exit(2)
}
if err := ctx.flags.Parse(ctx.opts.args); err != nil {
return ctx, err
}
2024-11-14 20:49:35 +00:00
ctx.Context = daemon.WithNetwork(ctx.Context, network)
ctx.logger = ctx.logger.WithMaxLevel(logLevel.Int())
return ctx, nil
}
func (ctx subCmdCtx) doSubCmd(subCmds ...subCmd) error {
printUsageExit := func(subCmdName string) {
2024-12-10 15:35:14 +00:00
fmt.Fprintf(os.Stderr, "unknown sub-command %q\n\n", subCmdName)
if ctx.subCmd.descr != "" {
fmt.Fprintf(
os.Stderr, "DESCRIPTION:\n %s\n\n", ctx.subCmd.descr,
)
}
fmt.Fprintf(
os.Stderr,
2024-12-10 15:35:14 +00:00
"%s<subCmd> [-h|--help] [sub-command flags...]\n\n",
2024-11-14 20:49:35 +00:00
usagePrefix(ctx.opts.subCmdNames),
)
2024-12-10 15:35:14 +00:00
fmt.Fprintf(os.Stderr, "SUB-COMMANDS:\n")
2024-12-10 15:35:14 +00:00
var maxNameLen int
for _, subCmd := range subCmds {
2024-12-10 15:35:14 +00:00
l := len(subCmd.fullName())
if l > maxNameLen {
maxNameLen = l
2024-07-22 13:52:51 +00:00
}
2024-12-10 15:35:14 +00:00
}
for _, subCmd := range subCmds {
var (
name = subCmd.fullName()
padding = strings.Repeat(" ", maxNameLen-len(name)+3)
)
fmt.Fprintf(
os.Stderr, " %s%s%s\n", name, padding, subCmd.descr,
)
}
fmt.Fprintf(os.Stderr, "\n")
os.Stderr.Sync()
os.Exit(2)
}
2024-11-14 20:49:35 +00:00
args := ctx.opts.args
if len(args) == 0 {
printUsageExit("")
}
subCmdsMap := map[string]subCmd{}
for _, subCmd := range subCmds {
subCmdsMap[subCmd.name] = subCmd
2024-07-22 13:52:51 +00:00
if subCmd.plural != "" {
subCmdsMap[subCmd.name+subCmd.plural] = subCmd
}
}
subCmdName, args := args[0], args[1:]
subCmd, ok := subCmdsMap[subCmdName]
if !ok {
printUsageExit(subCmdName)
}
2024-11-14 20:49:35 +00:00
nextSubCmdCtxOpts := *ctx.opts
nextSubCmdCtxOpts.args = args
nextSubCmdCtxOpts.subCmdNames = append(ctx.opts.subCmdNames, subCmdName)
nextSubCmdCtx := newSubCmdCtx(
ctx.Context, ctx.logger, subCmd, &nextSubCmdCtxOpts,
)
if err := subCmd.do(nextSubCmdCtx); err != nil {
return err
}
return nil
}
// outputFormat names the serialization used by doWithOutput. The only accepted
// values are "json" and "yaml".
type outputFormat string

// MarshalText implements encoding.TextMarshaler, returning the format name.
func (f outputFormat) MarshalText() ([]byte, error) { return []byte(f), nil }

// UnmarshalText implements encoding.TextUnmarshaler. The input is matched
// case-insensitively and stored lowercased. If the input is not a supported
// format an error is returned and the receiver is left unmodified.
func (f *outputFormat) UnmarshalText(b []byte) error {
	v := outputFormat(strings.ToLower(string(b)))
	switch v {
	case "json", "yaml":
		*f = v
		return nil
	default:
		return errors.New("invalid output format")
	}
}
// doWithOutput wraps a subCmd's do function so that it will output some value
// to stdout. The value will be formatted according to a command-line argument.
func doWithOutput(fn func(subCmdCtx) (any, error)) func(subCmdCtx) error {
return func(ctx subCmdCtx) error {
type outputFormatFlag = textUnmarshalerFlag[outputFormat, *outputFormat]
outputFormat := outputFormatFlag{"yaml"}
ctx.flags.Var(
&outputFormat,
"format",
"How to format the output value. Can be 'json' or 'yaml'.",
)
res, err := fn(ctx)
if err != nil {
return err
}
switch outputFormat.V {
case "json":
2024-11-14 20:49:35 +00:00
return jsonutil.WriteIndented(ctx.opts.stdout, res)
case "yaml":
2024-11-14 20:49:35 +00:00
return yaml.NewEncoder(ctx.opts.stdout).Encode(res)
default:
panic(fmt.Sprintf("unexpected outputFormat %q", outputFormat))
}
}
}