Implement 'vpn firewall remove'
This commit is contained in:
parent
a5829a6493
commit
dd847cafe1
@ -82,6 +82,15 @@ func (h *runHarness) runAssertStdout(
|
|||||||
assert.Equal(t, want, got.Elem().Interface())
|
assert.Equal(t, want, got.Elem().Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *runHarness) runAssertErrorContains(
|
||||||
|
t *testing.T, want string, args ...string,
|
||||||
|
) {
|
||||||
|
err := h.run(t, args...)
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *runHarness) assertChangeStaged(
|
func (h *runHarness) assertChangeStaged(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
want any,
|
want any,
|
||||||
|
@ -17,6 +17,7 @@ import (
|
|||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO noNetwork and passthroughArgs should be arguments to withParsedFlags.
|
||||||
type subCmd struct {
|
type subCmd struct {
|
||||||
name string
|
name string
|
||||||
descr string
|
descr string
|
||||||
|
@ -4,7 +4,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"isle/daemon/daecommon"
|
"isle/daemon/daecommon"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
)
|
)
|
||||||
|
|
||||||
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
|
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
|
||||||
@ -129,6 +132,86 @@ var subCmdVPNFirewallAdd = subCmd{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var subCmdVPNFirewallRemove = subCmd{
|
||||||
|
name: "remove",
|
||||||
|
descr: "Remove one or more firewall rules from the staged configuration",
|
||||||
|
do: func(ctx subCmdCtx) error {
|
||||||
|
from := ctx.flags.String(
|
||||||
|
"from",
|
||||||
|
"",
|
||||||
|
"Which set of rules to remove from, either 'inbound' or 'outbound'",
|
||||||
|
)
|
||||||
|
|
||||||
|
indexes := ctx.flags.IntSlice(
|
||||||
|
"indexes",
|
||||||
|
nil,
|
||||||
|
"Comma-separated indexes of rules to remove, as returned by 'vpn firewall show --staged'",
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, err := ctx.withParsedFlags()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing flags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if *from == "" || len(*indexes) == 0 {
|
||||||
|
return errors.New("--from and --indexes are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleSetFn, err := vpnFirewallRuleSetToFn(*from)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid --from value %q: %w", *from, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
indexSet := map[int]struct{}{}
|
||||||
|
for _, index := range *indexes {
|
||||||
|
if index < 0 {
|
||||||
|
return fmt.Errorf("invalid index %d", index)
|
||||||
|
}
|
||||||
|
indexSet[index] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := vpnFirewallGetConfig(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting network config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ruleSet = ruleSetFn(&config.VPN.Firewall)
|
||||||
|
filteredRuleSet = make(
|
||||||
|
[]daecommon.ConfigFirewallRule, 0, len(*ruleSet),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, rule := range *ruleSet {
|
||||||
|
if _, remove := indexSet[i]; remove {
|
||||||
|
delete(indexSet, i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredRuleSet = append(filteredRuleSet, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(indexSet) > 0 {
|
||||||
|
invalidIndexes := maps.Keys(indexSet)
|
||||||
|
slices.Sort(invalidIndexes)
|
||||||
|
return fmt.Errorf("invalid index(es): %v", invalidIndexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
*ruleSet = filteredRuleSet
|
||||||
|
|
||||||
|
if err := config.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ctx.opts.changeStager.set(
|
||||||
|
config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("staging changes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
type firewallRuleView struct {
|
type firewallRuleView struct {
|
||||||
Index int `yaml:"index"`
|
Index int `yaml:"index"`
|
||||||
daecommon.ConfigFirewallRule `yaml:",inline"`
|
daecommon.ConfigFirewallRule `yaml:",inline"`
|
||||||
@ -201,6 +284,7 @@ var subCmdVPNFirewall = subCmd{
|
|||||||
do: func(ctx subCmdCtx) error {
|
do: func(ctx subCmdCtx) error {
|
||||||
return ctx.doSubCmd(
|
return ctx.doSubCmd(
|
||||||
subCmdVPNFirewallAdd,
|
subCmdVPNFirewallAdd,
|
||||||
|
subCmdVPNFirewallRemove,
|
||||||
subCmdVPNFirewallShow,
|
subCmdVPNFirewallShow,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
@ -107,10 +107,7 @@ func TestVPNFirewallAdd(t *testing.T) {
|
|||||||
args := append([]string{"vpn", "firewall", "add"}, test.flags...)
|
args := append([]string{"vpn", "firewall", "add"}, test.flags...)
|
||||||
|
|
||||||
if test.wantFlagErr != "" {
|
if test.wantFlagErr != "" {
|
||||||
err := h.run(t, args...)
|
h.runAssertErrorContains(t, test.wantFlagErr, args...)
|
||||||
if assert.Error(t, err) {
|
|
||||||
assert.Contains(t, err.Error(), test.wantFlagErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,10 +123,7 @@ func TestVPNFirewallAdd(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if test.wantValidateErr != "" {
|
if test.wantValidateErr != "" {
|
||||||
err := h.run(t, args...)
|
h.runAssertErrorContains(t, test.wantValidateErr, args...)
|
||||||
if assert.Error(t, err) {
|
|
||||||
assert.Contains(t, err.Error(), test.wantValidateErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,7 +149,140 @@ func TestVPNFirewallAdd(t *testing.T) {
|
|||||||
h.assertChangeStaged(
|
h.assertChangeStaged(
|
||||||
t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
|
t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
|
||||||
)
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVPNFirewallRemove(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rules := func(hosts ...string) []daecommon.ConfigFirewallRule {
|
||||||
|
out := make([]daecommon.ConfigFirewallRule, len(hosts))
|
||||||
|
for i := range hosts {
|
||||||
|
out[i] = daecommon.ConfigFirewallRule{
|
||||||
|
Port: "any",
|
||||||
|
Proto: "any",
|
||||||
|
Host: hosts[i],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
outbound, inbound []string
|
||||||
|
stagedOutbound, stagedInbound []string
|
||||||
|
flags []string
|
||||||
|
wantFlagErr string
|
||||||
|
wantValidateErr string
|
||||||
|
wantOutbound, wantInbound []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "flag error/from missing",
|
||||||
|
wantFlagErr: "--from and --indexes are required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "flag error/indexes missing",
|
||||||
|
flags: []string{"--from=what"},
|
||||||
|
wantFlagErr: "--from and --indexes are required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "flag error/from invalid",
|
||||||
|
flags: []string{"--from=what", "--indexes=1,2,3"},
|
||||||
|
wantFlagErr: "invalid --from value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "flag error/indexes invalid",
|
||||||
|
flags: []string{"--from=inbound", "--indexes=1,-2,3"},
|
||||||
|
wantFlagErr: "invalid index -2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "validate error/indexes invalid",
|
||||||
|
inbound: []string{"foo"},
|
||||||
|
flags: []string{"--from=inbound", "--indexes=0,3,4"},
|
||||||
|
wantValidateErr: "invalid index(es): [3 4]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "validate error/indexes invalid staged",
|
||||||
|
inbound: []string{"foo", "bar"},
|
||||||
|
stagedInbound: []string{"foo"},
|
||||||
|
flags: []string{"--from=inbound", "--indexes=0,1"},
|
||||||
|
wantValidateErr: "invalid index(es): [1]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success/remove inbound single",
|
||||||
|
inbound: []string{"foo", "bar", "baz"},
|
||||||
|
flags: []string{"--from=inbound", "--indexes=1"},
|
||||||
|
wantInbound: []string{"foo", "baz"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success/remove outbound multiple",
|
||||||
|
outbound: []string{"foo", "bar", "baz"},
|
||||||
|
inbound: []string{"any"},
|
||||||
|
flags: []string{"--from=outbound", "--indexes=0,2"},
|
||||||
|
wantOutbound: []string{"bar"},
|
||||||
|
wantInbound: []string{"any"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success/remove staged outbound multiple",
|
||||||
|
inbound: []string{"foo", "bar"},
|
||||||
|
outbound: []string{"foo", "bar", "baz"},
|
||||||
|
stagedOutbound: []string{"foo", "bar", "baz", "biz"},
|
||||||
|
stagedInbound: []string{"any"},
|
||||||
|
flags: []string{"--from=outbound", "--indexes=0,2,3"},
|
||||||
|
wantOutbound: []string{"bar"},
|
||||||
|
wantInbound: []string{"any"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
h = newRunHarness(t)
|
||||||
|
config, wantConfig daecommon.NetworkConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
config.VPN.Firewall = daecommon.ConfigFirewall{
|
||||||
|
Outbound: rules(test.outbound...),
|
||||||
|
Inbound: rules(test.inbound...),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(test.stagedOutbound) > 0 || len(test.stagedInbound) > 0 {
|
||||||
|
assert.NoError(t, h.changeStager.set(
|
||||||
|
daecommon.ConfigFirewall{
|
||||||
|
Outbound: rules(test.stagedOutbound...),
|
||||||
|
Inbound: rules(test.stagedInbound...),
|
||||||
|
},
|
||||||
|
vpnFirewallConfigChangeStagerName,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
wantConfig.VPN.Firewall = daecommon.ConfigFirewall{
|
||||||
|
Outbound: rules(test.wantOutbound...),
|
||||||
|
Inbound: rules(test.wantInbound...),
|
||||||
|
}
|
||||||
|
|
||||||
|
args := append([]string{"vpn", "firewall", "remove"}, test.flags...)
|
||||||
|
|
||||||
|
if test.wantFlagErr != "" {
|
||||||
|
h.runAssertErrorContains(t, test.wantFlagErr, args...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.daemonRPC.
|
||||||
|
On("GetConfig", toolkit.MockArg[context.Context]()).
|
||||||
|
Return(config, nil).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
if test.wantValidateErr != "" {
|
||||||
|
h.runAssertErrorContains(t, test.wantValidateErr, args...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, h.run(t, args...))
|
||||||
|
h.assertChangeStaged(
|
||||||
|
t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user