diff --git a/go/cmd/entrypoint/client.go b/go/cmd/entrypoint/client.go index 249226c..a8406dd 100644 --- a/go/cmd/entrypoint/client.go +++ b/go/cmd/entrypoint/client.go @@ -6,7 +6,7 @@ import ( ) func (ctx subCmdCtx) getHosts() ([]bootstrap.Host, error) { - res, err := newDaemonRPCClient().GetHosts(ctx) + res, err := ctx.getDaemonRPC().GetHosts(ctx) if err != nil { return nil, fmt.Errorf("calling GetHosts: %w", err) } diff --git a/go/cmd/entrypoint/daemon.go b/go/cmd/entrypoint/daemon.go index 906a25a..a1803e2 100644 --- a/go/cmd/entrypoint/daemon.go +++ b/go/cmd/entrypoint/daemon.go @@ -38,9 +38,6 @@ var subCmdDaemon = subCmd{ return daecommon.CopyDefaultConfig(os.Stdout) } - logger := ctx.logger() - defer logger.Close() - // TODO check that daemon is either running as root, or that the // required linux capabilities are set. // TODO check that the tun module is loaded (for nebula). @@ -52,7 +49,7 @@ var subCmdDaemon = subCmd{ networkLoader, err := network.NewLoader( ctx, - logger.WithNamespace("loader"), + ctx.logger.WithNamespace("loader"), envBinDirPath, nil, ) @@ -60,20 +57,22 @@ var subCmdDaemon = subCmd{ return fmt.Errorf("instantiating network loader: %w", err) } - daemonInst, err := daemon.New(ctx, logger, networkLoader, daemonConfig) + daemonInst, err := daemon.New( + ctx, ctx.logger, networkLoader, daemonConfig, + ) if err != nil { return fmt.Errorf("starting daemon: %w", err) } defer func() { - logger.Info(ctx, "Stopping child processes") + ctx.logger.Info(ctx, "Stopping child processes") if err := daemonInst.Shutdown(); err != nil { - logger.Error(ctx, "Shutting down daemon cleanly failed, there may be orphaned child processes", err) + ctx.logger.Error(ctx, "Shutting down daemon cleanly failed, there may be orphaned child processes", err) } - logger.Info(ctx, "Child processes successfully stopped") + ctx.logger.Info(ctx, "Child processes successfully stopped") }() { - logger := logger.WithNamespace("http") + logger := ctx.logger.WithNamespace("http") httpSrv, err := newHTTPServer( ctx, logger, daemonInst, ) diff --git a/go/cmd/entrypoint/daemon_util.go b/go/cmd/entrypoint/daemon_util.go index d6ddaca..3f415b5 100644 --- a/go/cmd/entrypoint/daemon_util.go +++ b/go/cmd/entrypoint/daemon_util.go @@ -6,7 +6,6 @@ import ( "fmt" "io/fs" "isle/daemon" - "isle/daemon/jsonrpc2" "net" "net/http" "os" @@ -17,14 +16,6 @@ import ( const daemonHTTPRPCPath = "/rpc/v0.json" -func newDaemonRPCClient() daemon.RPC { - return daemon.RPCFromClient( - jsonrpc2.NewUnixHTTPClient( - daemon.HTTPSocketPath(), daemonHTTPRPCPath, - ), - ) -} - func newHTTPServer( ctx context.Context, logger *mlog.Logger, daemonInst *daemon.Daemon, ) ( diff --git a/go/cmd/entrypoint/garage.go b/go/cmd/entrypoint/garage.go index b881ea6..68c317e 100644 --- a/go/cmd/entrypoint/garage.go +++ b/go/cmd/entrypoint/garage.go @@ -51,7 +51,7 @@ var subCmdGarageMC = subCmd{ return fmt.Errorf("parsing flags: %w", err) } - clientParams, err := newDaemonRPCClient().GetGarageClientParams(ctx) + clientParams, err := ctx.getDaemonRPC().GetGarageClientParams(ctx) if err != nil { return fmt.Errorf("calling GetGarageClientParams: %w", err) } @@ -113,15 +113,16 @@ var subCmdGarageMC = subCmd{ } var subCmdGarageCLI = subCmd{ - name: "cli", - descr: "Runs the garage binary, automatically configured to point to the garage sub-process of a running isle daemon", + name: "cli", + descr: "Runs the garage binary, automatically configured to point to the garage sub-process of a running isle daemon", + passthroughArgs: true, do: func(ctx subCmdCtx) error { ctx, err := ctx.withParsedFlags() if err != nil { return fmt.Errorf("parsing flags: %w", err) } - clientParams, err := newDaemonRPCClient().GetGarageClientParams(ctx) + clientParams, err := ctx.getDaemonRPC().GetGarageClientParams(ctx) if err != nil { return fmt.Errorf("calling GetGarageClientParams: %w", err) } @@ -132,7 +133,7 @@ var subCmdGarageCLI = subCmd{ var ( binPath = binPath("garage") - args = append([]string{"garage"}, ctx.args...) + args = append([]string{"garage"}, ctx.opts.args...) cliEnv = append( os.Environ(), "GARAGE_RPC_HOST="+clientParams.Node.RPCNodeAddr(), diff --git a/go/cmd/entrypoint/host.go b/go/cmd/entrypoint/host.go index de84c5b..650a4c5 100644 --- a/go/cmd/entrypoint/host.go +++ b/go/cmd/entrypoint/host.go @@ -42,7 +42,7 @@ var subCmdHostCreate = subCmd{ return errors.New("--hostname is required") } - res, err := newDaemonRPCClient().CreateHost( + res, err := ctx.getDaemonRPC().CreateHost( ctx, hostName.V, network.CreateHostOpts{ IP: ip.V, CanCreateHosts: *canCreateHosts, @@ -120,7 +120,7 @@ var subCmdHostRemove = subCmd{ return errors.New("--hostname is required") } - if err := newDaemonRPCClient().RemoveHost(ctx, hostName.V); err != nil { + if err := ctx.getDaemonRPC().RemoveHost(ctx, hostName.V); err != nil { return fmt.Errorf("calling RemoveHost: %w", err) } diff --git a/go/cmd/entrypoint/main.go b/go/cmd/entrypoint/main.go index 522ac16..dd2852a 100644 --- a/go/cmd/entrypoint/main.go +++ b/go/cmd/entrypoint/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "os" "os/signal" "path/filepath" @@ -28,6 +29,32 @@ func binPath(name string) string { return filepath.Join(envBinDirPath, name) } +var rootCmd = subCmd{ + name: "isle", + descr: "All Isle sub-commands", + noNetwork: true, + do: func(ctx subCmdCtx) error { + return ctx.doSubCmd( + subCmdDaemon, + subCmdGarage, + subCmdHost, + subCmdNebula, + subCmdNetwork, + subCmdStorage, + subCmdVersion, + ) + }, +} + +func doRootCmd( + ctx context.Context, + logger *mlog.Logger, + opts *subCmdCtxOpts, +) error { + subCmdCtx := newSubCmdCtx(ctx, logger, rootCmd, opts) + return subCmdCtx.subCmd.do(subCmdCtx) +} + func main() { logger := mlog.NewLogger(nil) defer logger.Close() @@ -50,20 +77,7 @@ func main() { logger.FatalString(ctx, "second signal received, force quitting, there may be zombie children left behind, good luck!") }() - err := subCmdCtx{ - Context: ctx, - args: os.Args[1:], - }.doSubCmd( - subCmdDaemon, - subCmdGarage, - subCmdHost, - subCmdNebula, - subCmdNetwork, - subCmdStorage, - subCmdVersion, - ) - - if err != nil { - logger.Fatal(ctx, "error running command", err) + if err := doRootCmd(ctx, logger, nil); err != nil { + fmt.Fprintln(os.Stderr, err) } } diff --git a/go/cmd/entrypoint/nebula.go b/go/cmd/entrypoint/nebula.go index f893ef4..f6cea20 100644 --- a/go/cmd/entrypoint/nebula.go +++ b/go/cmd/entrypoint/nebula.go @@ -43,7 +43,7 @@ var subCmdNebulaCreateCert = subCmd{ return fmt.Errorf("unmarshaling public key as PEM: %w", err) } - res, err := newDaemonRPCClient().CreateNebulaCertificate( + res, err := ctx.getDaemonRPC().CreateNebulaCertificate( ctx, hostName.V, hostPub, ) if err != nil { @@ -77,7 +77,7 @@ var subCmdNebulaShow = subCmd{ return nil, fmt.Errorf("getting hosts: %w", err) } - caPublicCreds, err := newDaemonRPCClient().GetNebulaCAPublicCredentials(ctx) + caPublicCreds, err := ctx.getDaemonRPC().GetNebulaCAPublicCredentials(ctx) if err != nil { return nil, fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err) } diff --git a/go/cmd/entrypoint/network.go b/go/cmd/entrypoint/network.go index 955db7b..603b3d8 100644 --- a/go/cmd/entrypoint/network.go +++ b/go/cmd/entrypoint/network.go @@ -52,7 +52,7 @@ var subCmdNetworkCreate = subCmd{ return errors.New("--name, --domain, --ip-net, and --hostname are required") } - err = newDaemonRPCClient().CreateNetwork( + err = ctx.getDaemonRPC().CreateNetwork( ctx, *name, *domain, ipNet.V, hostName.V, ) if err != nil { @@ -88,7 +88,7 @@ var subCmdNetworkJoin = subCmd{ ) } - return newDaemonRPCClient().JoinNetwork(ctx, newBootstrap) + return ctx.getDaemonRPC().JoinNetwork(ctx, newBootstrap) }, } @@ -102,7 +102,7 @@ var subCmdNetworkList = subCmd{ return nil, fmt.Errorf("parsing flags: %w", err) } - return newDaemonRPCClient().GetNetworks(ctx) + return ctx.getDaemonRPC().GetNetworks(ctx) }), } @@ -115,7 +115,7 @@ var subCmdNetworkGetConfig = subCmd{ return nil, fmt.Errorf("parsing flags: %w", err) } - return newDaemonRPCClient().GetConfig(ctx) + return ctx.getDaemonRPC().GetConfig(ctx) }), } diff --git a/go/cmd/entrypoint/storage.go b/go/cmd/entrypoint/storage.go index bca78e1..1c8ba90 100644 --- a/go/cmd/entrypoint/storage.go +++ b/go/cmd/entrypoint/storage.go @@ -17,7 +17,7 @@ var subCmdStorageAllocationList = subCmd{ return nil, fmt.Errorf("parsing flags: %w", err) } - config, err := newDaemonRPCClient().GetConfig(ctx) + config, err := ctx.getDaemonRPC().GetConfig(ctx) if err != nil { return nil, fmt.Errorf("getting network config: %w", err) } diff --git a/go/cmd/entrypoint/sub_cmd.go b/go/cmd/entrypoint/sub_cmd.go index dba483a..011b649 100644 --- a/go/cmd/entrypoint/sub_cmd.go +++ b/go/cmd/entrypoint/sub_cmd.go @@ -4,7 +4,9 @@ import ( "context" "errors" "fmt" + "io" "isle/daemon" + "isle/daemon/jsonrpc2" "isle/jsonutil" "os" "strings" @@ -14,13 +16,6 @@ import ( "gopkg.in/yaml.v3" ) -type flagSet struct { - *pflag.FlagSet - - network string - logLevel logLevelFlag -} - type subCmd struct { name string descr string @@ -39,64 +34,53 @@ type subCmd struct { passthroughArgs bool } +type subCmdCtxOpts struct { + args []string // command-line arguments, excluding the subCmd itself. + subCmdNames []string // names of subCmds so far, including this one + daemonRPC daemon.RPC + stdout io.Writer +} + +func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts { + if o == nil { + o = new(subCmdCtxOpts) + } + + if o.args == nil { + o.args = os.Args[1:] + } + + if o.stdout == nil { + o.stdout = os.Stdout + } + + return o +} + // subCmdCtx contains all information available to a subCmd's do method. type subCmdCtx struct { context.Context + logger *mlog.Logger + subCmd subCmd // the subCmd itself + opts *subCmdCtxOpts - subCmd subCmd // the subCmd itself - args []string // command-line arguments, excluding the subCmd itself. - subCmdNames []string // names of subCmds so far, including this one - - flags *flagSet + flags *pflag.FlagSet } func newSubCmdCtx( ctx context.Context, + logger *mlog.Logger, subCmd subCmd, - args []string, - subCmdNames []string, + opts *subCmdCtxOpts, ) subCmdCtx { - flags := pflag.NewFlagSet(subCmd.name, pflag.ExitOnError) - flags.Usage = func() { - var passthroughStr string - if subCmd.passthroughArgs { - passthroughStr = " [--] [args...]" - } - - fmt.Fprintf( - os.Stderr, "%s[-h|--help] [%s flags...]%s\n\n", - usagePrefix(subCmdNames), subCmd.name, passthroughStr, - ) - fmt.Fprintf(os.Stderr, "%s FLAGS:\n\n", strings.ToUpper(subCmd.name)) - fmt.Fprintln(os.Stderr, flags.FlagUsages()) - - os.Stderr.Sync() - os.Exit(2) - } - - fs := &flagSet{ - FlagSet: flags, - logLevel: logLevelFlag{mlog.LevelInfo}, - } - - if !subCmd.noNetwork { - fs.FlagSet.StringVar( - &fs.network, "network", "", "Which network to perform the command against, if more than one is joined. Can be an ID, name, or domain.", - ) - } - - fs.FlagSet.VarP( - &fs.logLevel, - "log-level", "l", - "Maximum log level to output. Can be DEBUG, CHILD, INFO, WARN, ERROR, or FATAL.", - ) + opts = opts.withDefaults() return subCmdCtx{ - Context: ctx, - subCmd: subCmd, - args: args, - subCmdNames: subCmdNames, - flags: fs, + Context: ctx, + logger: logger, + subCmd: subCmd, + opts: opts, + flags: pflag.NewFlagSet(subCmd.name, pflag.ExitOnError), } } @@ -109,13 +93,34 @@ func usagePrefix(subCmdNames []string) string { return fmt.Sprintf("\nUSAGE: %s %s", os.Args[0], subCmdNamesStr) } -func (ctx subCmdCtx) logger() *mlog.Logger { - return mlog.NewLogger(&mlog.LoggerOpts{ - MaxLevel: ctx.flags.logLevel.Int(), - }) +func (ctx subCmdCtx) getDaemonRPC() daemon.RPC { + if ctx.opts.daemonRPC == nil { + ctx.opts.daemonRPC = daemon.RPCFromClient( + jsonrpc2.NewUnixHTTPClient( + daemon.HTTPSocketPath(), daemonHTTPRPCPath, + ), + ) + } + return ctx.opts.daemonRPC } func (ctx subCmdCtx) withParsedFlags() (subCmdCtx, error) { + logLevel := logLevelFlag{mlog.LevelInfo} + ctx.flags.VarP( + &logLevel, + "log-level", "l", + "Maximum log level to output. Can be DEBUG, CHILD, INFO, WARN, ERROR, or FATAL.", + ) + + var network string + if !ctx.subCmd.noNetwork { + ctx.flags.StringVar( + &network, + "network", "", + "Which network to perform the command against, if more than one is joined. Can be an ID, name, or domain.", + ) + } + ctx.flags.VisitAll(func(f *pflag.Flag) { if f.Shorthand == "h" { panic(fmt.Sprintf("flag %+v has reserved shorthand `-h`", f)) @@ -125,11 +130,32 @@ func (ctx subCmdCtx) withParsedFlags() (subCmdCtx, error) { } }) - if err := ctx.flags.Parse(ctx.args); err != nil { + ctx.flags.Usage = func() { + var passthroughStr string + if ctx.subCmd.passthroughArgs { + passthroughStr = " [--] [args...]" + } + + fmt.Fprintf( + os.Stderr, "%s[-h|--help] [%s flags...]%s\n\n", + usagePrefix(ctx.opts.subCmdNames), ctx.subCmd.name, passthroughStr, + ) + fmt.Fprintf( + os.Stderr, "%s FLAGS:\n\n", strings.ToUpper(ctx.subCmd.name), + ) + fmt.Fprintln(os.Stderr, ctx.flags.FlagUsages()) + + os.Stderr.Sync() + os.Exit(2) + } + + if err := ctx.flags.Parse(ctx.opts.args); err != nil { return ctx, err } - ctx.Context = daemon.WithNetwork(ctx.Context, ctx.flags.network) + ctx.Context = daemon.WithNetwork(ctx.Context, network) + ctx.logger = ctx.logger.WithMaxLevel(logLevel.Int()) + return ctx, nil } @@ -142,7 +168,7 @@ func (ctx subCmdCtx) doSubCmd(subCmds ...subCmd) error { fmt.Fprintf( os.Stderr, "%s [-h|--help] [sub-command flags...]\n", - usagePrefix(ctx.subCmdNames), + usagePrefix(ctx.opts.subCmdNames), ) fmt.Fprintf(os.Stderr, "\nSUB-COMMANDS:\n\n") @@ -160,7 +186,7 @@ func (ctx subCmdCtx) doSubCmd(subCmds ...subCmd) error { os.Exit(2) } - args := ctx.args + args := ctx.opts.args if len(args) == 0 { printUsageExit("") @@ -181,11 +207,12 @@ func (ctx subCmdCtx) doSubCmd(subCmds ...subCmd) error { printUsageExit(subCmdName) } + nextSubCmdCtxOpts := *ctx.opts + nextSubCmdCtxOpts.args = args + nextSubCmdCtxOpts.subCmdNames = append(ctx.opts.subCmdNames, subCmdName) + nextSubCmdCtx := newSubCmdCtx( - ctx.Context, - subCmd, - args, - append(ctx.subCmdNames, subCmdName), + ctx.Context, ctx.logger, subCmd, &nextSubCmdCtxOpts, ) if err := subCmd.do(nextSubCmdCtx); err != nil { @@ -229,9 +256,9 @@ func doWithOutput(fn func(subCmdCtx) (any, error)) func(subCmdCtx) error { switch outputFormat.V { case "json": - return jsonutil.WriteIndented(os.Stdout, res) + return jsonutil.WriteIndented(ctx.opts.stdout, res) case "yaml": - return yaml.NewEncoder(os.Stdout).Encode(res) + return yaml.NewEncoder(ctx.opts.stdout).Encode(res) default: panic(fmt.Sprintf("unexpected outputFormat %q", outputFormat)) } diff --git a/go/daemon/rpc.go b/go/daemon/rpc.go index 698cc62..f7ea426 100644 --- a/go/daemon/rpc.go +++ b/go/daemon/rpc.go @@ -1,3 +1,5 @@ +//go:generate mockery --name RPC --inpackage --filename rpc_mock.go + package daemon import (