package main import ( "context" "encoding/json" "fmt" "isle/daemon/daecommon" "isle/toolkit" "os" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" ) func TestVPNFirewallAdd(t *testing.T) { t.Parallel() tests := []struct { name string staged *daecommon.ConfigFirewall flags []string to string wantFlagErr string wantValidateErr string want daecommon.ConfigFirewallRule }{ { name: "flag error/to missing", wantFlagErr: "--to is required", }, { name: "flag error/to invalid", flags: []string{"--to=what"}, wantFlagErr: "invalid --to value", }, { name: "flag error/host and groups given", flags: []string{"--to=inbound", "--groups=foo,bar", "--host=baz"}, wantFlagErr: "--host and --groups are mutually exclusive", }, { name: "validate error/bad port", flags: []string{"--to=inbound", "--port=80-20"}, wantValidateErr: "start port was lower than end port", }, { name: "success/only host", flags: []string{"--to=inbound", "--host=foo"}, to: "inbound", want: daecommon.ConfigFirewallRule{ Port: "any", Proto: "any", Host: "foo", }, }, { name: "success/groups", flags: []string{"--to=outbound", "--groups=foo,bar", "--groups=baz"}, to: "outbound", want: daecommon.ConfigFirewallRule{ Port: "any", Proto: "any", Groups: []string{"foo", "bar", "baz"}, }, }, { name: "success/port and proto", flags: []string{"--to=outbound", "--port=22", "--proto=tcp"}, to: "outbound", want: daecommon.ConfigFirewallRule{ Port: "22", Proto: "tcp", Host: "any", }, }, { name: "success/with staged", staged: &daecommon.ConfigFirewall{ Inbound: []daecommon.ConfigFirewallRule{ { Port: "1", Proto: "tcp", Host: "any", }, }, }, flags: []string{"--to=inbound", "--port=2"}, to: "inbound", want: daecommon.ConfigFirewallRule{ Port: "2", Proto: "any", Host: "any", }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var ( h = newRunHarness(t) config daecommon.NetworkConfig ) args := append([]string{"vpn", "firewall", "add"}, test.flags...) if test.wantFlagErr != "" { h.runAssertErrorContains(t, test.wantFlagErr, args...) return } h.daemonRPC. On("GetConfig", toolkit.MockArg[context.Context]()). Return(config, nil). Once() if test.staged != nil { assert.NoError(t, h.changeStager.set( *test.staged, vpnFirewallConfigChangeStagerName, )) } if test.wantValidateErr != "" { h.runAssertErrorContains(t, test.wantValidateErr, args...) return } wantConfig := config if test.staged != nil { wantConfig.VPN.Firewall = *test.staged } switch test.to { case "outbound": wantConfig.VPN.Firewall.Outbound = append( wantConfig.VPN.Firewall.Outbound, test.want, ) case "inbound": wantConfig.VPN.Firewall.Inbound = append( wantConfig.VPN.Firewall.Inbound, test.want, ) default: panic(fmt.Sprintf("invalid test.to %q", test.to)) } assert.NoError(t, h.run(t, args...)) h.assertChangeStaged( t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName, ) }) } } func TestVPNFirewallRemove(t *testing.T) { t.Parallel() rules := func(hosts ...string) []daecommon.ConfigFirewallRule { out := make([]daecommon.ConfigFirewallRule, len(hosts)) for i := range hosts { out[i] = daecommon.ConfigFirewallRule{ Port: "any", Proto: "any", Host: hosts[i], } } return out } tests := []struct { name string outbound, inbound []string stagedOutbound, stagedInbound []string flags []string wantFlagErr string wantValidateErr string wantOutbound, wantInbound []string }{ { name: "flag error/from missing", wantFlagErr: "--from and --indexes are required", }, { name: "flag error/indexes missing", flags: []string{"--from=what"}, wantFlagErr: "--from and --indexes are required", }, { name: "flag error/from invalid", flags: []string{"--from=what", "--indexes=1,2,3"}, wantFlagErr: "invalid --from value", }, { name: "flag error/indexes invalid", flags: []string{"--from=inbound", "--indexes=1,-2,3"}, wantFlagErr: "invalid index -2", }, { name: "validate error/indexes invalid", inbound: []string{"foo"}, flags: []string{"--from=inbound", "--indexes=0,3,4"}, wantValidateErr: "invalid index(es): [3 4]", }, { name: "validate error/indexes invalid staged", inbound: []string{"foo", "bar"}, stagedInbound: []string{"foo"}, flags: []string{"--from=inbound", "--indexes=0,1"}, wantValidateErr: "invalid index(es): [1]", }, { name: "success/remove inbound single", inbound: []string{"foo", "bar", "baz"}, flags: []string{"--from=inbound", "--indexes=1"}, wantInbound: []string{"foo", "baz"}, }, { name: "success/remove outbound multiple", outbound: []string{"foo", "bar", "baz"}, inbound: []string{"any"}, flags: []string{"--from=outbound", "--indexes=0,2"}, wantOutbound: []string{"bar"}, wantInbound: []string{"any"}, }, { name: "success/remove staged outbound multiple", inbound: []string{"foo", "bar"}, outbound: []string{"foo", "bar", "baz"}, stagedOutbound: []string{"foo", "bar", "baz", "biz"}, stagedInbound: []string{"any"}, flags: []string{"--from=outbound", "--indexes=0,2,3"}, wantOutbound: []string{"bar"}, wantInbound: []string{"any"}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var ( h = newRunHarness(t) config, wantConfig daecommon.NetworkConfig ) config.VPN.Firewall = daecommon.ConfigFirewall{ Outbound: rules(test.outbound...), Inbound: rules(test.inbound...), } if len(test.stagedOutbound) > 0 || len(test.stagedInbound) > 0 { assert.NoError(t, h.changeStager.set( daecommon.ConfigFirewall{ Outbound: rules(test.stagedOutbound...), Inbound: rules(test.stagedInbound...), }, vpnFirewallConfigChangeStagerName, )) } wantConfig.VPN.Firewall = daecommon.ConfigFirewall{ Outbound: rules(test.wantOutbound...), Inbound: rules(test.wantInbound...), } args := append([]string{"vpn", "firewall", "remove"}, test.flags...) if test.wantFlagErr != "" { h.runAssertErrorContains(t, test.wantFlagErr, args...) return } h.daemonRPC. On("GetConfig", toolkit.MockArg[context.Context]()). Return(config, nil). Once() if test.wantValidateErr != "" { h.runAssertErrorContains(t, test.wantValidateErr, args...) return } assert.NoError(t, h.run(t, args...)) h.assertChangeStaged( t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName, ) }) } } func TestVPNFirewallShow(t *testing.T) { t.Parallel() tests := []struct { name string outbound, inbound []string staged string flags []string want map[string][]any wantErr string }{ { name: "empty", want: map[string][]any{ "outbound": {}, "inbound": {}, }, }, { name: "single", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, want: map[string][]any{ "outbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, }, "inbound": {}, }, }, { name: "multiple", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, inbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, `{"port":"22","proto":"tcp","host":"foo"}`, }, want: map[string][]any{ "outbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, }, "inbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, map[string]any{ "index": 1, "port": "22", "proto": "tcp", "host": "foo", }, }, }, }, { name: "staged/nothing staged", flags: []string{"--staged"}, wantErr: "no firewall configuration changes have been staged", }, { name: "staged/staged but no flag", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, staged: `{ "Inbound": [ { "Port":"80", "Proto":"tcp", "Host":"some-host" } ] }`, want: map[string][]any{ "outbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, }, "inbound": {}, }, }, { name: "staged/staged with flag", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, staged: `{ "Inbound": [ { "Port":"80", "Proto":"tcp", "Host":"some-host" } ] }`, flags: []string{"--staged"}, want: map[string][]any{ "outbound": {}, "inbound": { map[string]any{ "index": 0, "port": "80", "proto": "tcp", "host": "some-host", }, }, }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var ( h = newRunHarness(t) config daecommon.NetworkConfig outboundRawJSON = "[" + strings.Join(test.outbound, ",") + "]" inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]" ) if test.staged != "" { require.True(t, json.Valid([]byte(test.staged))) require.NoError(t, os.WriteFile( h.changeStager.path(vpnFirewallConfigChangeStagerName), []byte(test.staged), 0600, )) } require.NoError(t, json.Unmarshal( []byte(outboundRawJSON), &config.VPN.Firewall.Outbound, )) require.NoError(t, json.Unmarshal( []byte(inboundRawJSON), &config.VPN.Firewall.Inbound, )) if !slices.Contains(test.flags, "--staged") { h.daemonRPC. On("GetConfig", toolkit.MockArg[context.Context]()). Return(config, nil). Once() } args := append([]string{"vpn", "firewall", "show"}, test.flags...) if test.wantErr == "" { h.runAssertStdout(t, test.want, args...) } else { err := h.run(t, args...) if assert.Error(t, err) { assert.Contains(t, err.Error(), test.wantErr) } } }) } }