From 1608031103496b2344ce101a482c695bbdd6a68f Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Mon, 9 Dec 2024 18:09:45 +0100 Subject: [PATCH] Give 'vpn firewall list' a --staged flag --- go/cmd/entrypoint/change_stager.go | 42 ++++++++++++ go/cmd/entrypoint/main.go | 16 +++++ go/cmd/entrypoint/main_test.go | 18 +++-- go/cmd/entrypoint/sub_cmd.go | 13 ++-- go/cmd/entrypoint/vpn_firewall.go | 29 ++++++-- go/cmd/entrypoint/vpn_firewall_test.go | 94 ++++++++++++++++++++++++-- go/toolkit/dir.go | 6 ++ 7 files changed, 199 insertions(+), 19 deletions(-) create mode 100644 go/cmd/entrypoint/change_stager.go diff --git a/go/cmd/entrypoint/change_stager.go b/go/cmd/entrypoint/change_stager.go new file mode 100644 index 0000000..cee276f --- /dev/null +++ b/go/cmd/entrypoint/change_stager.go @@ -0,0 +1,42 @@ +package main + +import ( + "errors" + "fmt" + "io/fs" + "isle/jsonutil" + "isle/toolkit" + "os" + "path/filepath" +) + +type changeStager struct { + dir toolkit.Dir +} + +func (mgr *changeStager) path(name string) string { + return filepath.Join(mgr.dir.Path, name+".json") +} + +func (mgr *changeStager) get(into any, name string) (bool, error) { + path := mgr.path(name) + if err := jsonutil.LoadFile(into, path); errors.Is(err, fs.ErrNotExist) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + +func (mgr *changeStager) set(v any, name string) error { + path := mgr.path(name) + return jsonutil.WriteFile(v, path, 0600) +} + +func (mgr *changeStager) del(name string) error { + path := mgr.path(name) + if err := os.Remove(path); err != nil { + return fmt.Errorf("removing file %q: %w", path, err) + } + return nil +} diff --git a/go/cmd/entrypoint/main.go b/go/cmd/entrypoint/main.go index 1383c33..dfd1412 100644 --- a/go/cmd/entrypoint/main.go +++ b/go/cmd/entrypoint/main.go @@ -3,13 +3,16 @@ package main import ( "context" "fmt" + "isle/toolkit" "os" "os/signal" "path/filepath" + "sync" "syscall" "dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" + "github.com/adrg/xdg" ) func getAppDirPath() string { @@ -23,6 +26,19 @@ func getAppDirPath() string { var ( envAppDirPath = getAppDirPath() envBinDirPath = filepath.Join(envAppDirPath, "bin") + envCacheDir = sync.OnceValue(func() toolkit.Dir { + cacheHome, err := toolkit.MkDir(xdg.CacheHome, true) + if err != nil { + panic(fmt.Errorf("creating cache directory %q: %w", xdg.CacheHome, err)) + } + + cacheDir, err := cacheHome.MkChildDir("isle", true) + if err != nil { + panic(fmt.Errorf("creating isle cache directory: %w", err)) + } + + return cacheDir + }) ) func binPath(name string) string { diff --git a/go/cmd/entrypoint/main_test.go b/go/cmd/entrypoint/main_test.go index 34540ac..5583e96 100644 --- a/go/cmd/entrypoint/main_test.go +++ b/go/cmd/entrypoint/main_test.go @@ -18,9 +18,10 @@ import ( type runHarness struct { ctx context.Context logger *mlog.Logger - stdout *bytes.Buffer daemonRPC *daemon.MockRPC daemonRPCServer *httptest.Server + stdout *bytes.Buffer + changeStager *changeStager } func newRunHarness(t *testing.T) *runHarness { @@ -29,18 +30,22 @@ func newRunHarness(t *testing.T) *runHarness { var ( ctx = context.Background() logger = toolkit.NewTestLogger(t) - stdout = new(bytes.Buffer) daemonRPC = daemon.NewMockRPC(t) daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler( logger.WithNamespace("rpc"), daemonRPC, )) daemonRPCServer = httptest.NewServer(daemonRPCHandler) + + stdout = new(bytes.Buffer) + changeStager = &changeStager{toolkit.TempDir(t)} ) t.Cleanup(daemonRPCServer.Close) - return &runHarness{ctx, logger, stdout, daemonRPC, daemonRPCServer} + return &runHarness{ + ctx, logger, daemonRPC, daemonRPCServer, stdout, changeStager, + } } func (h *runHarness) run(t *testing.T, args ...string) error { @@ -54,9 +59,10 @@ func (h *runHarness) run(t *testing.T, args ...string) error { ) return doRootCmd(h.ctx, h.logger, &subCmdCtxOpts{ - args: args, - daemonRPC: daemonRPCClient, - stdout: h.stdout, + args: args, + daemonRPC: daemonRPCClient, + stdout: h.stdout, + changeStager: h.changeStager, }) } diff --git a/go/cmd/entrypoint/sub_cmd.go b/go/cmd/entrypoint/sub_cmd.go index c5f8070..a2546ed 100644 --- a/go/cmd/entrypoint/sub_cmd.go +++ b/go/cmd/entrypoint/sub_cmd.go @@ -36,10 +36,11 @@ type subCmd struct { } type subCmdCtxOpts struct { - args []string // command-line arguments, excluding the subCmd itself. - subCmdNames []string // names of subCmds so far, including this one - daemonRPC daemon.RPC - stdout io.Writer + args []string // command-line arguments, excluding the subCmd itself. + subCmdNames []string // names of subCmds so far, including this one + daemonRPC daemon.RPC + stdout io.Writer + changeStager *changeStager } func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts { @@ -55,6 +56,10 @@ func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts { o.stdout = os.Stdout } + if o.changeStager == nil { + o.changeStager = &changeStager{envCacheDir()} + } + return o } diff --git a/go/cmd/entrypoint/vpn_firewall.go b/go/cmd/entrypoint/vpn_firewall.go index 2ad23ba..65ee191 100644 --- a/go/cmd/entrypoint/vpn_firewall.go +++ b/go/cmd/entrypoint/vpn_firewall.go @@ -1,10 +1,13 @@ package main import ( + "errors" "fmt" "isle/daemon/daecommon" ) +const vpnFirewallConfigChangeStagerName = "vpn-firewall-config" + type firewallRuleView struct { Index int `yaml:"index"` daecommon.ConfigFirewallRule `yaml:",inline"` @@ -39,17 +42,35 @@ var subCmdVPNFirewallList = subCmd{ name: "list", descr: "List all currently configured firewall rules", do: doWithOutput(func(ctx subCmdCtx) (any, error) { + staged := ctx.flags.Bool( + "staged", + false, + "Return the firewall configuration with staged changes included", + ) + ctx, err := ctx.withParsedFlags() if err != nil { return nil, fmt.Errorf("parsing flags: %w", err) } - config, err := ctx.getDaemonRPC().GetConfig(ctx) - if err != nil { - return nil, fmt.Errorf("getting network config: %w", err) + var firewallConfig daecommon.ConfigFirewall + if !*staged { + config, err := ctx.getDaemonRPC().GetConfig(ctx) + if err != nil { + return nil, fmt.Errorf("getting network config: %w", err) + } + + firewallConfig = config.VPN.Firewall + + } else if ok, err := ctx.opts.changeStager.get( + &firewallConfig, vpnFirewallConfigChangeStagerName, + ); err != nil { + return nil, fmt.Errorf("checking for staged changes: %w", err) + } else if !ok { + return nil, errors.New("no firewall configuration changes have been staged") } - return newFirewallView(config.VPN.Firewall), nil + return newFirewallView(firewallConfig), nil }), } diff --git a/go/cmd/entrypoint/vpn_firewall_test.go b/go/cmd/entrypoint/vpn_firewall_test.go index 2757656..9a2cd9a 100644 --- a/go/cmd/entrypoint/vpn_firewall_test.go +++ b/go/cmd/entrypoint/vpn_firewall_test.go @@ -5,10 +5,13 @@ import ( "encoding/json" "isle/daemon/daecommon" "isle/toolkit" + "os" "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" ) func TestVPNFirewallList(t *testing.T) { @@ -17,7 +20,10 @@ func TestVPNFirewallList(t *testing.T) { tests := []struct { name string outbound, inbound []string + staged string + flags []string want map[string][]any + wantErr string }{ { name: "empty", @@ -77,6 +83,64 @@ func TestVPNFirewallList(t *testing.T) { }, }, }, + { + name: "staged/nothing staged", + flags: []string{"--staged"}, + wantErr: "no firewall configuration changes have been staged", + }, + { + name: "staged/staged but no flag", + outbound: []string{ + `{"port":"any","proto":"icmp","host":"any"}`, + }, + staged: `{ + "Inbound": [ + { + "Port":"80", + "Proto":"tcp", + "Host":"some-host" + } + ] + }`, + want: map[string][]any{ + "outbound": { + map[string]any{ + "index": 0, + "port": "any", + "proto": "icmp", + "host": "any", + }, + }, + "inbound": {}, + }, + }, + { + name: "staged/staged with flag", + outbound: []string{ + `{"port":"any","proto":"icmp","host":"any"}`, + }, + staged: `{ + "Inbound": [ + { + "Port":"80", + "Proto":"tcp", + "Host":"some-host" + } + ] + }`, + flags: []string{"--staged"}, + want: map[string][]any{ + "outbound": {}, + "inbound": { + map[string]any{ + "index": 0, + "port": "80", + "proto": "tcp", + "host": "some-host", + }, + }, + }, + }, } for _, test := range tests { @@ -89,6 +153,15 @@ func TestVPNFirewallList(t *testing.T) { inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]" ) + if test.staged != "" { + require.True(t, json.Valid([]byte(test.staged))) + require.NoError(t, os.WriteFile( + h.changeStager.path(vpnFirewallConfigChangeStagerName), + []byte(test.staged), + 0600, + )) + } + require.NoError(t, json.Unmarshal( []byte(outboundRawJSON), &config.VPN.Firewall.Outbound, )) @@ -97,12 +170,23 @@ func TestVPNFirewallList(t *testing.T) { []byte(inboundRawJSON), &config.VPN.Firewall.Inbound, )) - h.daemonRPC. - On("GetConfig", toolkit.MockArg[context.Context]()). - Return(config, nil). - Once() + if !slices.Contains(test.flags, "--staged") { + h.daemonRPC. + On("GetConfig", toolkit.MockArg[context.Context]()). + Return(config, nil). + Once() + } - h.runAssertStdout(t, test.want, "vpn", "firewall", "list") + args := append([]string{"vpn", "firewall", "list"}, test.flags...) + + if test.wantErr == "" { + h.runAssertStdout(t, test.want, args...) + } else { + err := h.run(t, args...) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), test.wantErr) + } + } }) } } diff --git a/go/toolkit/dir.go b/go/toolkit/dir.go index 389c8cc..4a473e6 100644 --- a/go/toolkit/dir.go +++ b/go/toolkit/dir.go @@ -6,6 +6,7 @@ import ( "io/fs" "os" "path/filepath" + "testing" ) // Dir is a type which makes it possible to statically assert that a directory @@ -58,6 +59,11 @@ func MkDir(path string, mayExist bool) (Dir, error) { return Dir{path}, nil } +// TempDir returns a Dir based on a temporary directory created via [t.TempDir]. +func TempDir(t *testing.T) Dir { + return Dir{t.TempDir()} +} + // MkChildDir is a helper for joining Dir's path to the given name and calling // MkDir with the result. func (d Dir) MkChildDir(name string, mayExist bool) (Dir, error) {