diff --git a/go/cmd/entrypoint/main_test.go b/go/cmd/entrypoint/main_test.go index b8561c0..84a663c 100644 --- a/go/cmd/entrypoint/main_test.go +++ b/go/cmd/entrypoint/main_test.go @@ -82,6 +82,15 @@ func (h *runHarness) runAssertStdout( assert.Equal(t, want, got.Elem().Interface()) } +func (h *runHarness) runAssertErrorContains( + t *testing.T, want string, args ...string, +) { + err := h.run(t, args...) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), want) + } +} + func (h *runHarness) assertChangeStaged( t *testing.T, want any, diff --git a/go/cmd/entrypoint/sub_cmd.go b/go/cmd/entrypoint/sub_cmd.go index a2546ed..d4a423b 100644 --- a/go/cmd/entrypoint/sub_cmd.go +++ b/go/cmd/entrypoint/sub_cmd.go @@ -17,6 +17,7 @@ import ( "gopkg.in/yaml.v3" ) +// TODO noNetwork and passthroughArgs should be arguments to withParsedFlags. type subCmd struct { name string descr string diff --git a/go/cmd/entrypoint/vpn_firewall.go b/go/cmd/entrypoint/vpn_firewall.go index 5290b62..328c8b8 100644 --- a/go/cmd/entrypoint/vpn_firewall.go +++ b/go/cmd/entrypoint/vpn_firewall.go @@ -4,7 +4,10 @@ import ( "errors" "fmt" "isle/daemon/daecommon" + "slices" "strings" + + "golang.org/x/exp/maps" ) const vpnFirewallConfigChangeStagerName = "vpn-firewall-config" @@ -129,6 +132,86 @@ var subCmdVPNFirewallAdd = subCmd{ }, } +var subCmdVPNFirewallRemove = subCmd{ + name: "remove", + descr: "Remove one or more firewall rules from the staged configuration", + do: func(ctx subCmdCtx) error { + from := ctx.flags.String( + "from", + "", + "Which set of rules to remove from, either 'inbound' or 'outbound'", + ) + + indexes := ctx.flags.IntSlice( + "indexes", + nil, + "Comma-separated indexes of rules to remove, as returned by 'vpn firewall show --staged'", + ) + + ctx, err := ctx.withParsedFlags() + if err != nil { + return fmt.Errorf("parsing flags: %w", err) + } + + if *from == "" || len(*indexes) == 0 { + return errors.New("--from and --indexes are required") + } + + ruleSetFn, err := vpnFirewallRuleSetToFn(*from) + if err != nil { + return fmt.Errorf("invalid --from value %q: %w", *from, err) + } + + indexSet := map[int]struct{}{} + for _, index := range *indexes { + if index < 0 { + return fmt.Errorf("invalid index %d", index) + } + indexSet[index] = struct{}{} + } + + config, err := vpnFirewallGetConfig(ctx) + if err != nil { + return fmt.Errorf("getting network config: %w", err) + } + + var ( + ruleSet = ruleSetFn(&config.VPN.Firewall) + filteredRuleSet = make( + []daecommon.ConfigFirewallRule, 0, len(*ruleSet), + ) + ) + + for i, rule := range *ruleSet { + if _, remove := indexSet[i]; remove { + delete(indexSet, i) + continue + } + filteredRuleSet = append(filteredRuleSet, rule) + } + + if len(indexSet) > 0 { + invalidIndexes := maps.Keys(indexSet) + slices.Sort(invalidIndexes) + return fmt.Errorf("invalid index(es): %v", invalidIndexes) + } + + *ruleSet = filteredRuleSet + + if err := config.Validate(); err != nil { + return err + } + + if err := ctx.opts.changeStager.set( + config.VPN.Firewall, vpnFirewallConfigChangeStagerName, + ); err != nil { + return fmt.Errorf("staging changes: %w", err) + } + + return nil + }, +} + type firewallRuleView struct { Index int `yaml:"index"` daecommon.ConfigFirewallRule `yaml:",inline"` @@ -201,6 +284,7 @@ var subCmdVPNFirewall = subCmd{ do: func(ctx subCmdCtx) error { return ctx.doSubCmd( subCmdVPNFirewallAdd, + subCmdVPNFirewallRemove, subCmdVPNFirewallShow, ) }, diff --git a/go/cmd/entrypoint/vpn_firewall_test.go b/go/cmd/entrypoint/vpn_firewall_test.go index 7e578bf..2f3427c 100644 --- a/go/cmd/entrypoint/vpn_firewall_test.go +++ b/go/cmd/entrypoint/vpn_firewall_test.go @@ -107,10 +107,7 @@ func TestVPNFirewallAdd(t *testing.T) { args := append([]string{"vpn", "firewall", "add"}, test.flags...) if test.wantFlagErr != "" { - err := h.run(t, args...) - if assert.Error(t, err) { - assert.Contains(t, err.Error(), test.wantFlagErr) - } + h.runAssertErrorContains(t, test.wantFlagErr, args...) return } @@ -126,10 +123,7 @@ func TestVPNFirewallAdd(t *testing.T) { } if test.wantValidateErr != "" { - err := h.run(t, args...) - if assert.Error(t, err) { - assert.Contains(t, err.Error(), test.wantValidateErr) - } + h.runAssertErrorContains(t, test.wantValidateErr, args...) return } @@ -155,7 +149,140 @@ func TestVPNFirewallAdd(t *testing.T) { h.assertChangeStaged( t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName, ) + }) + } +} +func TestVPNFirewallRemove(t *testing.T) { + t.Parallel() + + rules := func(hosts ...string) []daecommon.ConfigFirewallRule { + out := make([]daecommon.ConfigFirewallRule, len(hosts)) + for i := range hosts { + out[i] = daecommon.ConfigFirewallRule{ + Port: "any", + Proto: "any", + Host: hosts[i], + } + } + return out + } + + tests := []struct { + name string + outbound, inbound []string + stagedOutbound, stagedInbound []string + flags []string + wantFlagErr string + wantValidateErr string + wantOutbound, wantInbound []string + }{ + { + name: "flag error/from missing", + wantFlagErr: "--from and --indexes are required", + }, + { + name: "flag error/indexes missing", + flags: []string{"--from=what"}, + wantFlagErr: "--from and --indexes are required", + }, + { + name: "flag error/from invalid", + flags: []string{"--from=what", "--indexes=1,2,3"}, + wantFlagErr: "invalid --from value", + }, + { + name: "flag error/indexes invalid", + flags: []string{"--from=inbound", "--indexes=1,-2,3"}, + wantFlagErr: "invalid index -2", + }, + { + name: "validate error/indexes invalid", + inbound: []string{"foo"}, + flags: []string{"--from=inbound", "--indexes=0,3,4"}, + wantValidateErr: "invalid index(es): [3 4]", + }, + { + name: "validate error/indexes invalid staged", + inbound: []string{"foo", "bar"}, + stagedInbound: []string{"foo"}, + flags: []string{"--from=inbound", "--indexes=0,1"}, + wantValidateErr: "invalid index(es): [1]", + }, + { + name: "success/remove inbound single", + inbound: []string{"foo", "bar", "baz"}, + flags: []string{"--from=inbound", "--indexes=1"}, + wantInbound: []string{"foo", "baz"}, + }, + { + name: "success/remove outbound multiple", + outbound: []string{"foo", "bar", "baz"}, + inbound: []string{"any"}, + flags: []string{"--from=outbound", "--indexes=0,2"}, + wantOutbound: []string{"bar"}, + wantInbound: []string{"any"}, + }, + { + name: "success/remove staged outbound multiple", + inbound: []string{"foo", "bar"}, + outbound: []string{"foo", "bar", "baz"}, + stagedOutbound: []string{"foo", "bar", "baz", "biz"}, + stagedInbound: []string{"any"}, + flags: []string{"--from=outbound", "--indexes=0,2,3"}, + wantOutbound: []string{"bar"}, + wantInbound: []string{"any"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + h = newRunHarness(t) + config, wantConfig daecommon.NetworkConfig + ) + + config.VPN.Firewall = daecommon.ConfigFirewall{ + Outbound: rules(test.outbound...), + Inbound: rules(test.inbound...), + } + + if len(test.stagedOutbound) > 0 || len(test.stagedInbound) > 0 { + assert.NoError(t, h.changeStager.set( + daecommon.ConfigFirewall{ + Outbound: rules(test.stagedOutbound...), + Inbound: rules(test.stagedInbound...), + }, + vpnFirewallConfigChangeStagerName, + )) + } + + wantConfig.VPN.Firewall = daecommon.ConfigFirewall{ + Outbound: rules(test.wantOutbound...), + Inbound: rules(test.wantInbound...), + } + + args := append([]string{"vpn", "firewall", "remove"}, test.flags...) + + if test.wantFlagErr != "" { + h.runAssertErrorContains(t, test.wantFlagErr, args...) + return + } + + h.daemonRPC. + On("GetConfig", toolkit.MockArg[context.Context]()). + Return(config, nil). + Once() + + if test.wantValidateErr != "" { + h.runAssertErrorContains(t, test.wantValidateErr, args...) + return + } + + assert.NoError(t, h.run(t, args...)) + h.assertChangeStaged( + t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName, + ) }) } }