diff --git a/go/cmd/entrypoint/change_stager.go b/go/cmd/entrypoint/change_stager.go index cee276f..4616ffd 100644 --- a/go/cmd/entrypoint/change_stager.go +++ b/go/cmd/entrypoint/change_stager.go @@ -35,7 +35,7 @@ func (mgr *changeStager) set(v any, name string) error { func (mgr *changeStager) del(name string) error { path := mgr.path(name) - if err := os.Remove(path); err != nil { + if err := os.Remove(path); err != nil && !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("removing file %q: %w", path, err) } return nil diff --git a/go/cmd/entrypoint/vpn_firewall.go b/go/cmd/entrypoint/vpn_firewall.go index c8be3ec..94589e9 100644 --- a/go/cmd/entrypoint/vpn_firewall.go +++ b/go/cmd/entrypoint/vpn_firewall.go @@ -132,6 +132,36 @@ var subCmdVPNFirewallAdd = subCmd{ }, } +var subCmdVPNFirewallCommit = subCmd{ + name: "commit", + descr: "Commit all changes which were staged using 'add' or 'remove'", + do: func(ctx subCmdCtx) error { + ctx, err := ctx.withParsedFlags() + if err != nil { + return fmt.Errorf("parsing flags: %w", err) + } + + var firewallConfig daecommon.ConfigFirewall + ok, err := ctx.opts.changeStager.get( + &firewallConfig, vpnFirewallConfigChangeStagerName, + ) + if err != nil { + return fmt.Errorf("checking for staged changes: %w", err) + } else if !ok { + return errors.New("no changes staged, use 'add' or 'remove' to stage changes") + } + + config, err := ctx.getDaemonRPC().GetConfig(ctx) + if err != nil { + return fmt.Errorf("getting network config: %w", err) + } + + config.VPN.Firewall = firewallConfig + + return ctx.getDaemonRPC().SetConfig(ctx, config) + }, +} + var subCmdVPNFirewallRemove = subCmd{ name: "remove", descr: "Remove one or more firewall rules from the staged configuration", @@ -145,7 +175,7 @@ var subCmdVPNFirewallRemove = subCmd{ indexes := ctx.flags.IntSlice( "indexes", nil, - "Comma-separated indexes of rules to remove, as returned by 'vpn firewall show --staged'", + "Comma-separated indexes of rules to remove, as returned by 'show --staged'", ) ctx, err := ctx.withParsedFlags() @@ -212,6 +242,14 @@ var subCmdVPNFirewallRemove = subCmd{ }, } +var subCmdVPNFirewallReset = subCmd{ + name: "reset", + descr: "Discard all changes which have been staged", + do: func(ctx subCmdCtx) error { + return ctx.opts.changeStager.del(vpnFirewallConfigChangeStagerName) + }, +} + type firewallRuleView struct { Index int `yaml:"index"` daecommon.ConfigFirewallRule `yaml:",inline"` @@ -289,7 +327,9 @@ var subCmdVPNFirewall = subCmd{ do: func(ctx subCmdCtx) error { return ctx.doSubCmd( subCmdVPNFirewallAdd, + subCmdVPNFirewallCommit, subCmdVPNFirewallRemove, + subCmdVPNFirewallReset, subCmdVPNFirewallShow, ) }, diff --git a/go/cmd/entrypoint/vpn_firewall_test.go b/go/cmd/entrypoint/vpn_firewall_test.go index 4d3480d..e72aff1 100644 --- a/go/cmd/entrypoint/vpn_firewall_test.go +++ b/go/cmd/entrypoint/vpn_firewall_test.go @@ -153,6 +153,77 @@ func TestVPNFirewallAdd(t *testing.T) { } } +func TestVPNFirewallCommit(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + staged *daecommon.ConfigFirewall + }{ + { + name: "error/nothing staged", + }, + { + name: "success", + staged: &daecommon.ConfigFirewall{ + Outbound: []daecommon.ConfigFirewallRule{ + { + Port: "any", + Proto: "any", + Host: "any", + }, + }, + Inbound: []daecommon.ConfigFirewallRule{ + { + Port: "22", + Proto: "tcp", + Host: "foo", + }, + { + Port: "80", + Proto: "any", + Host: "any", + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + h = newRunHarness(t) + config daecommon.NetworkConfig + ) + + args := []string{"vpn", "firewall", "commit"} + + if test.staged == nil { + h.runAssertErrorContains(t, "no changes staged", args...) + return + } + + assert.NoError(t, h.changeStager.set( + *test.staged, vpnFirewallConfigChangeStagerName, + )) + + h.daemonRPC. + On("GetConfig", toolkit.MockArg[context.Context]()). + Return(config, nil). + Once() + + config.VPN.Firewall = *test.staged + + h.daemonRPC. + On("SetConfig", toolkit.MockArg[context.Context](), config). + Return(nil). + Once() + + assert.NoError(t, h.run(t, args...)) + }) + } +} + func TestVPNFirewallRemove(t *testing.T) { t.Parallel()