Implement 'vpn firewall reset' and 'vpn firewall commit'

This commit is contained in:
Brian Picciano 2024-12-10 16:14:48 +01:00
parent 9b27676521
commit 10758f11a2
3 changed files with 113 additions and 2 deletions

View File

@ -35,7 +35,7 @@ func (mgr *changeStager) set(v any, name string) error {
func (mgr *changeStager) del(name string) error { func (mgr *changeStager) del(name string) error {
path := mgr.path(name) path := mgr.path(name)
if err := os.Remove(path); err != nil { if err := os.Remove(path); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("removing file %q: %w", path, err) return fmt.Errorf("removing file %q: %w", path, err)
} }
return nil return nil

View File

@ -132,6 +132,36 @@ var subCmdVPNFirewallAdd = subCmd{
}, },
} }
var subCmdVPNFirewallCommit = subCmd{
name: "commit",
descr: "Commit all changes which were staged using 'add' or 'remove'",
do: func(ctx subCmdCtx) error {
ctx, err := ctx.withParsedFlags()
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
var firewallConfig daecommon.ConfigFirewall
ok, err := ctx.opts.changeStager.get(
&firewallConfig, vpnFirewallConfigChangeStagerName,
)
if err != nil {
return fmt.Errorf("checking for staged changes: %w", err)
} else if !ok {
return errors.New("no changes staged, use 'add' or 'remove' to stage changes")
}
config, err := ctx.getDaemonRPC().GetConfig(ctx)
if err != nil {
return fmt.Errorf("getting network config: %w", err)
}
config.VPN.Firewall = firewallConfig
return ctx.getDaemonRPC().SetConfig(ctx, config)
},
}
var subCmdVPNFirewallRemove = subCmd{ var subCmdVPNFirewallRemove = subCmd{
name: "remove", name: "remove",
descr: "Remove one or more firewall rules from the staged configuration", descr: "Remove one or more firewall rules from the staged configuration",
@ -145,7 +175,7 @@ var subCmdVPNFirewallRemove = subCmd{
indexes := ctx.flags.IntSlice( indexes := ctx.flags.IntSlice(
"indexes", "indexes",
nil, nil,
"Comma-separated indexes of rules to remove, as returned by 'vpn firewall show --staged'", "Comma-separated indexes of rules to remove, as returned by 'show --staged'",
) )
ctx, err := ctx.withParsedFlags() ctx, err := ctx.withParsedFlags()
@ -212,6 +242,14 @@ var subCmdVPNFirewallRemove = subCmd{
}, },
} }
var subCmdVPNFirewallReset = subCmd{
name: "reset",
descr: "Discard all changes which have been staged",
do: func(ctx subCmdCtx) error {
return ctx.opts.changeStager.del(vpnFirewallConfigChangeStagerName)
},
}
type firewallRuleView struct { type firewallRuleView struct {
Index int `yaml:"index"` Index int `yaml:"index"`
daecommon.ConfigFirewallRule `yaml:",inline"` daecommon.ConfigFirewallRule `yaml:",inline"`
@ -289,7 +327,9 @@ var subCmdVPNFirewall = subCmd{
do: func(ctx subCmdCtx) error { do: func(ctx subCmdCtx) error {
return ctx.doSubCmd( return ctx.doSubCmd(
subCmdVPNFirewallAdd, subCmdVPNFirewallAdd,
subCmdVPNFirewallCommit,
subCmdVPNFirewallRemove, subCmdVPNFirewallRemove,
subCmdVPNFirewallReset,
subCmdVPNFirewallShow, subCmdVPNFirewallShow,
) )
}, },

View File

@ -153,6 +153,77 @@ func TestVPNFirewallAdd(t *testing.T) {
} }
} }
func TestVPNFirewallCommit(t *testing.T) {
t.Parallel()
tests := []struct {
name string
staged *daecommon.ConfigFirewall
}{
{
name: "error/nothing staged",
},
{
name: "success",
staged: &daecommon.ConfigFirewall{
Outbound: []daecommon.ConfigFirewallRule{
{
Port: "any",
Proto: "any",
Host: "any",
},
},
Inbound: []daecommon.ConfigFirewallRule{
{
Port: "22",
Proto: "tcp",
Host: "foo",
},
{
Port: "80",
Proto: "any",
Host: "any",
},
},
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
h = newRunHarness(t)
config daecommon.NetworkConfig
)
args := []string{"vpn", "firewall", "commit"}
if test.staged == nil {
h.runAssertErrorContains(t, "no changes staged", args...)
return
}
assert.NoError(t, h.changeStager.set(
*test.staged, vpnFirewallConfigChangeStagerName,
))
h.daemonRPC.
On("GetConfig", toolkit.MockArg[context.Context]()).
Return(config, nil).
Once()
config.VPN.Firewall = *test.staged
h.daemonRPC.
On("SetConfig", toolkit.MockArg[context.Context](), config).
Return(nil).
Once()
assert.NoError(t, h.run(t, args...))
})
}
}
func TestVPNFirewallRemove(t *testing.T) { func TestVPNFirewallRemove(t *testing.T) {
t.Parallel() t.Parallel()