Implement 'vpn firewall remove'

This commit is contained in:
Brian Picciano 2024-12-10 15:17:07 +01:00
parent a5829a6493
commit dd847cafe1
4 changed files with 229 additions and 8 deletions

View File

@ -82,6 +82,15 @@ func (h *runHarness) runAssertStdout(
assert.Equal(t, want, got.Elem().Interface()) assert.Equal(t, want, got.Elem().Interface())
} }
func (h *runHarness) runAssertErrorContains(
t *testing.T, want string, args ...string,
) {
err := h.run(t, args...)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), want)
}
}
func (h *runHarness) assertChangeStaged( func (h *runHarness) assertChangeStaged(
t *testing.T, t *testing.T,
want any, want any,

View File

@ -17,6 +17,7 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
// TODO noNetwork and passthroughArgs should be arguments to withParsedFlags.
type subCmd struct { type subCmd struct {
name string name string
descr string descr string

View File

@ -4,7 +4,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"isle/daemon/daecommon" "isle/daemon/daecommon"
"slices"
"strings" "strings"
"golang.org/x/exp/maps"
) )
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config" const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
@ -129,6 +132,86 @@ var subCmdVPNFirewallAdd = subCmd{
}, },
} }
var subCmdVPNFirewallRemove = subCmd{
name: "remove",
descr: "Remove one or more firewall rules from the staged configuration",
do: func(ctx subCmdCtx) error {
from := ctx.flags.String(
"from",
"",
"Which set of rules to remove from, either 'inbound' or 'outbound'",
)
indexes := ctx.flags.IntSlice(
"indexes",
nil,
"Comma-separated indexes of rules to remove, as returned by 'vpn firewall show --staged'",
)
ctx, err := ctx.withParsedFlags()
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
if *from == "" || len(*indexes) == 0 {
return errors.New("--from and --indexes are required")
}
ruleSetFn, err := vpnFirewallRuleSetToFn(*from)
if err != nil {
return fmt.Errorf("invalid --from value %q: %w", *from, err)
}
indexSet := map[int]struct{}{}
for _, index := range *indexes {
if index < 0 {
return fmt.Errorf("invalid index %d", index)
}
indexSet[index] = struct{}{}
}
config, err := vpnFirewallGetConfig(ctx)
if err != nil {
return fmt.Errorf("getting network config: %w", err)
}
var (
ruleSet = ruleSetFn(&config.VPN.Firewall)
filteredRuleSet = make(
[]daecommon.ConfigFirewallRule, 0, len(*ruleSet),
)
)
for i, rule := range *ruleSet {
if _, remove := indexSet[i]; remove {
delete(indexSet, i)
continue
}
filteredRuleSet = append(filteredRuleSet, rule)
}
if len(indexSet) > 0 {
invalidIndexes := maps.Keys(indexSet)
slices.Sort(invalidIndexes)
return fmt.Errorf("invalid index(es): %v", invalidIndexes)
}
*ruleSet = filteredRuleSet
if err := config.Validate(); err != nil {
return err
}
if err := ctx.opts.changeStager.set(
config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
); err != nil {
return fmt.Errorf("staging changes: %w", err)
}
return nil
},
}
type firewallRuleView struct { type firewallRuleView struct {
Index int `yaml:"index"` Index int `yaml:"index"`
daecommon.ConfigFirewallRule `yaml:",inline"` daecommon.ConfigFirewallRule `yaml:",inline"`
@ -201,6 +284,7 @@ var subCmdVPNFirewall = subCmd{
do: func(ctx subCmdCtx) error { do: func(ctx subCmdCtx) error {
return ctx.doSubCmd( return ctx.doSubCmd(
subCmdVPNFirewallAdd, subCmdVPNFirewallAdd,
subCmdVPNFirewallRemove,
subCmdVPNFirewallShow, subCmdVPNFirewallShow,
) )
}, },

View File

@ -107,10 +107,7 @@ func TestVPNFirewallAdd(t *testing.T) {
args := append([]string{"vpn", "firewall", "add"}, test.flags...) args := append([]string{"vpn", "firewall", "add"}, test.flags...)
if test.wantFlagErr != "" { if test.wantFlagErr != "" {
err := h.run(t, args...) h.runAssertErrorContains(t, test.wantFlagErr, args...)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), test.wantFlagErr)
}
return return
} }
@ -126,10 +123,7 @@ func TestVPNFirewallAdd(t *testing.T) {
} }
if test.wantValidateErr != "" { if test.wantValidateErr != "" {
err := h.run(t, args...) h.runAssertErrorContains(t, test.wantValidateErr, args...)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), test.wantValidateErr)
}
return return
} }
@ -155,7 +149,140 @@ func TestVPNFirewallAdd(t *testing.T) {
h.assertChangeStaged( h.assertChangeStaged(
t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName, t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
) )
})
}
}
func TestVPNFirewallRemove(t *testing.T) {
t.Parallel()
rules := func(hosts ...string) []daecommon.ConfigFirewallRule {
out := make([]daecommon.ConfigFirewallRule, len(hosts))
for i := range hosts {
out[i] = daecommon.ConfigFirewallRule{
Port: "any",
Proto: "any",
Host: hosts[i],
}
}
return out
}
tests := []struct {
name string
outbound, inbound []string
stagedOutbound, stagedInbound []string
flags []string
wantFlagErr string
wantValidateErr string
wantOutbound, wantInbound []string
}{
{
name: "flag error/from missing",
wantFlagErr: "--from and --indexes are required",
},
{
name: "flag error/indexes missing",
flags: []string{"--from=what"},
wantFlagErr: "--from and --indexes are required",
},
{
name: "flag error/from invalid",
flags: []string{"--from=what", "--indexes=1,2,3"},
wantFlagErr: "invalid --from value",
},
{
name: "flag error/indexes invalid",
flags: []string{"--from=inbound", "--indexes=1,-2,3"},
wantFlagErr: "invalid index -2",
},
{
name: "validate error/indexes invalid",
inbound: []string{"foo"},
flags: []string{"--from=inbound", "--indexes=0,3,4"},
wantValidateErr: "invalid index(es): [3 4]",
},
{
name: "validate error/indexes invalid staged",
inbound: []string{"foo", "bar"},
stagedInbound: []string{"foo"},
flags: []string{"--from=inbound", "--indexes=0,1"},
wantValidateErr: "invalid index(es): [1]",
},
{
name: "success/remove inbound single",
inbound: []string{"foo", "bar", "baz"},
flags: []string{"--from=inbound", "--indexes=1"},
wantInbound: []string{"foo", "baz"},
},
{
name: "success/remove outbound multiple",
outbound: []string{"foo", "bar", "baz"},
inbound: []string{"any"},
flags: []string{"--from=outbound", "--indexes=0,2"},
wantOutbound: []string{"bar"},
wantInbound: []string{"any"},
},
{
name: "success/remove staged outbound multiple",
inbound: []string{"foo", "bar"},
outbound: []string{"foo", "bar", "baz"},
stagedOutbound: []string{"foo", "bar", "baz", "biz"},
stagedInbound: []string{"any"},
flags: []string{"--from=outbound", "--indexes=0,2,3"},
wantOutbound: []string{"bar"},
wantInbound: []string{"any"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
h = newRunHarness(t)
config, wantConfig daecommon.NetworkConfig
)
config.VPN.Firewall = daecommon.ConfigFirewall{
Outbound: rules(test.outbound...),
Inbound: rules(test.inbound...),
}
if len(test.stagedOutbound) > 0 || len(test.stagedInbound) > 0 {
assert.NoError(t, h.changeStager.set(
daecommon.ConfigFirewall{
Outbound: rules(test.stagedOutbound...),
Inbound: rules(test.stagedInbound...),
},
vpnFirewallConfigChangeStagerName,
))
}
wantConfig.VPN.Firewall = daecommon.ConfigFirewall{
Outbound: rules(test.wantOutbound...),
Inbound: rules(test.wantInbound...),
}
args := append([]string{"vpn", "firewall", "remove"}, test.flags...)
if test.wantFlagErr != "" {
h.runAssertErrorContains(t, test.wantFlagErr, args...)
return
}
h.daemonRPC.
On("GetConfig", toolkit.MockArg[context.Context]()).
Return(config, nil).
Once()
if test.wantValidateErr != "" {
h.runAssertErrorContains(t, test.wantValidateErr, args...)
return
}
assert.NoError(t, h.run(t, args...))
h.assertChangeStaged(
t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
)
}) })
} }
} }