diff --git a/go/cmd/entrypoint/host.go b/go/cmd/entrypoint/host.go index e2fb3df..de84c5b 100644 --- a/go/cmd/entrypoint/host.go +++ b/go/cmd/entrypoint/host.go @@ -6,7 +6,6 @@ import ( "fmt" "isle/bootstrap" "isle/daemon/network" - "isle/jsonutil" "os" "sort" ) @@ -60,15 +59,15 @@ var subCmdHostCreate = subCmd{ var subCmdHostList = subCmd{ name: "list", descr: "Lists all hosts in the network, and their IPs", - do: func(ctx subCmdCtx) error { + do: doWithOutput(func(ctx subCmdCtx) (any, error) { ctx, err := ctx.withParsedFlags() if err != nil { - return fmt.Errorf("parsing flags: %w", err) + return nil, fmt.Errorf("parsing flags: %w", err) } hostsRes, err := ctx.getHosts() if err != nil { - return fmt.Errorf("calling GetHosts: %w", err) + return nil, fmt.Errorf("calling GetHosts: %w", err) } type host struct { @@ -94,8 +93,8 @@ var subCmdHostList = subCmd{ sort.Slice(hosts, func(i, j int) bool { return hosts[i].Name < hosts[j].Name }) - return jsonutil.WriteIndented(os.Stdout, hosts) - }, + return hosts, nil + }), } var subCmdHostRemove = subCmd{ diff --git a/go/cmd/entrypoint/nebula.go b/go/cmd/entrypoint/nebula.go index 5138181..f893ef4 100644 --- a/go/cmd/entrypoint/nebula.go +++ b/go/cmd/entrypoint/nebula.go @@ -3,7 +3,6 @@ package main import ( "errors" "fmt" - "isle/jsonutil" "isle/nebula" "os" ) @@ -67,27 +66,27 @@ var subCmdNebulaCreateCert = subCmd{ var subCmdNebulaShow = subCmd{ name: "show", descr: "Writes nebula network information to stdout in JSON format", - do: func(ctx subCmdCtx) error { + do: doWithOutput(func(ctx subCmdCtx) (any, error) { ctx, err := ctx.withParsedFlags() if err != nil { - return fmt.Errorf("parsing flags: %w", err) + return nil, fmt.Errorf("parsing flags: %w", err) } hosts, err := ctx.getHosts() if err != nil { - return fmt.Errorf("getting hosts: %w", err) + return nil, fmt.Errorf("getting hosts: %w", err) } caPublicCreds, err := newDaemonRPCClient().GetNebulaCAPublicCredentials(ctx) if err != nil { - return fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err) + return nil, fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err) } caCert := caPublicCreds.Cert caCertDetails := caCert.Unwrap().Details if len(caCertDetails.Subnets) != 1 { - return fmt.Errorf( + return nil, fmt.Errorf( "malformed ca.crt, contains unexpected subnets %#v", caCertDetails.Subnets, ) @@ -120,12 +119,8 @@ var subCmdNebulaShow = subCmd{ }) } - if err := jsonutil.WriteIndented(os.Stdout, out); err != nil { - return fmt.Errorf("encoding to stdout: %w", err) - } - - return nil - }, + return out, nil + }), } var subCmdNebula = subCmd{ diff --git a/go/cmd/entrypoint/network.go b/go/cmd/entrypoint/network.go index 4269eca..9c4542a 100644 --- a/go/cmd/entrypoint/network.go +++ b/go/cmd/entrypoint/network.go @@ -5,7 +5,6 @@ import ( "fmt" "isle/daemon/network" "isle/jsonutil" - "os" ) var subCmdNetworkCreate = subCmd{ @@ -97,19 +96,14 @@ var subCmdNetworkList = subCmd{ name: "list", descr: "Lists all networks which have been joined", noNetwork: true, - do: func(ctx subCmdCtx) error { + do: doWithOutput(func(ctx subCmdCtx) (any, error) { ctx, err := ctx.withParsedFlags() if err != nil { - return fmt.Errorf("parsing flags: %w", err) + return nil, fmt.Errorf("parsing flags: %w", err) } - creationParams, err := newDaemonRPCClient().GetNetworks(ctx) - if err != nil { - return fmt.Errorf("getting joined networks: %w", err) - } - - return jsonutil.WriteIndented(os.Stdout, creationParams) - }, + return newDaemonRPCClient().GetNetworks(ctx) + }), } var subCmdNetwork = subCmd{ diff --git a/go/cmd/entrypoint/sub_cmd.go b/go/cmd/entrypoint/sub_cmd.go index 1e23dfd..dba483a 100644 --- a/go/cmd/entrypoint/sub_cmd.go +++ b/go/cmd/entrypoint/sub_cmd.go @@ -2,13 +2,16 @@ package main import ( "context" + "errors" "fmt" "isle/daemon" + "isle/jsonutil" "os" "strings" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" "github.com/spf13/pflag" + "gopkg.in/yaml.v3" ) type flagSet struct { @@ -191,3 +194,46 @@ func (ctx subCmdCtx) doSubCmd(subCmds ...subCmd) error { return nil } + +type outputFormat string + +func (f outputFormat) MarshalText() ([]byte, error) { return []byte(f), nil } + +func (f *outputFormat) UnmarshalText(b []byte) error { + *f = outputFormat(strings.ToLower(string(b))) + switch *f { + case "json", "yaml": + return nil + default: + return errors.New("invalid output format") + } +} + +// doWithOutput wraps a subCmd's do function so that it will output some value +// to stdout. The value will be formatted according to a command-line argument. +func doWithOutput(fn func(subCmdCtx) (any, error)) func(subCmdCtx) error { + return func(ctx subCmdCtx) error { + type outputFormatFlag = textUnmarshalerFlag[outputFormat, *outputFormat] + + outputFormat := outputFormatFlag{"yaml"} + ctx.flags.Var( + &outputFormat, + "format", + "How to format the output value. Can be 'json' or 'yaml'.", + ) + + res, err := fn(ctx) + if err != nil { + return err + } + + switch outputFormat.V { + case "json": + return jsonutil.WriteIndented(os.Stdout, res) + case "yaml": + return yaml.NewEncoder(os.Stdout).Encode(res) + default: + panic(fmt.Sprintf("unexpected outputFormat %q", outputFormat)) + } + } +}