package main import ( "context" "encoding/json" "fmt" "isle/daemon/daecommon" "isle/toolkit" "os" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" ) func TestVPNFirewallAdd(t *testing.T) { t.Parallel() tests := []struct { name string staged *daecommon.ConfigFirewall flags []string to string wantFlagErr string wantValidateErr string want daecommon.ConfigFirewallRule }{ { name: "flag error/to missing", wantFlagErr: "--to is required", }, { name: "flag error/to invalid", flags: []string{"--to=what"}, wantFlagErr: "invalid --to value", }, { name: "flag error/host and groups given", flags: []string{"--to=inbound", "--groups=foo,bar", "--host=baz"}, wantFlagErr: "--host and --groups are mutually exclusive", }, { name: "validate error/bad port", flags: []string{"--to=inbound", "--port=80-20"}, wantValidateErr: "start port was lower than end port", }, { name: "success/only host", flags: []string{"--to=inbound", "--host=foo"}, to: "inbound", want: daecommon.ConfigFirewallRule{ Port: "any", Proto: "any", Host: "foo", }, }, { name: "success/groups", flags: []string{"--to=outbound", "--groups=foo,bar", "--groups=baz"}, to: "outbound", want: daecommon.ConfigFirewallRule{ Port: "any", Proto: "any", Groups: []string{"foo", "bar", "baz"}, }, }, { name: "success/port and proto", flags: []string{"--to=outbound", "--port=22", "--proto=tcp"}, to: "outbound", want: daecommon.ConfigFirewallRule{ Port: "22", Proto: "tcp", Host: "any", }, }, { name: "success/with staged", staged: &daecommon.ConfigFirewall{ Inbound: []daecommon.ConfigFirewallRule{ { Port: "1", Proto: "tcp", Host: "any", }, }, }, flags: []string{"--to=inbound", "--port=2"}, to: "inbound", want: daecommon.ConfigFirewallRule{ Port: "2", Proto: "any", Host: "any", }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var ( h = newRunHarness(t) config daecommon.NetworkConfig ) args := append([]string{"vpn", "firewall", "add"}, test.flags...) if test.wantFlagErr != "" { err := h.run(t, args...) if assert.Error(t, err) { assert.Contains(t, err.Error(), test.wantFlagErr) } return } h.daemonRPC. On("GetConfig", toolkit.MockArg[context.Context]()). Return(config, nil). Once() if test.staged != nil { assert.NoError(t, h.changeStager.set( *test.staged, vpnFirewallConfigChangeStagerName, )) } if test.wantValidateErr != "" { err := h.run(t, args...) if assert.Error(t, err) { assert.Contains(t, err.Error(), test.wantValidateErr) } return } wantConfig := config if test.staged != nil { wantConfig.VPN.Firewall = *test.staged } switch test.to { case "outbound": wantConfig.VPN.Firewall.Outbound = append( wantConfig.VPN.Firewall.Outbound, test.want, ) case "inbound": wantConfig.VPN.Firewall.Inbound = append( wantConfig.VPN.Firewall.Inbound, test.want, ) default: panic(fmt.Sprintf("invalid test.to %q", test.to)) } assert.NoError(t, h.run(t, args...)) h.assertChangeStaged( t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName, ) }) } } func TestVPNFirewallShow(t *testing.T) { t.Parallel() tests := []struct { name string outbound, inbound []string staged string flags []string want map[string][]any wantErr string }{ { name: "empty", want: map[string][]any{ "outbound": {}, "inbound": {}, }, }, { name: "single", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, want: map[string][]any{ "outbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, }, "inbound": {}, }, }, { name: "multiple", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, inbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, `{"port":"22","proto":"tcp","host":"foo"}`, }, want: map[string][]any{ "outbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, }, "inbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, map[string]any{ "index": 1, "port": "22", "proto": "tcp", "host": "foo", }, }, }, }, { name: "staged/nothing staged", flags: []string{"--staged"}, wantErr: "no firewall configuration changes have been staged", }, { name: "staged/staged but no flag", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, staged: `{ "Inbound": [ { "Port":"80", "Proto":"tcp", "Host":"some-host" } ] }`, want: map[string][]any{ "outbound": { map[string]any{ "index": 0, "port": "any", "proto": "icmp", "host": "any", }, }, "inbound": {}, }, }, { name: "staged/staged with flag", outbound: []string{ `{"port":"any","proto":"icmp","host":"any"}`, }, staged: `{ "Inbound": [ { "Port":"80", "Proto":"tcp", "Host":"some-host" } ] }`, flags: []string{"--staged"}, want: map[string][]any{ "outbound": {}, "inbound": { map[string]any{ "index": 0, "port": "80", "proto": "tcp", "host": "some-host", }, }, }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var ( h = newRunHarness(t) config daecommon.NetworkConfig outboundRawJSON = "[" + strings.Join(test.outbound, ",") + "]" inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]" ) if test.staged != "" { require.True(t, json.Valid([]byte(test.staged))) require.NoError(t, os.WriteFile( h.changeStager.path(vpnFirewallConfigChangeStagerName), []byte(test.staged), 0600, )) } require.NoError(t, json.Unmarshal( []byte(outboundRawJSON), &config.VPN.Firewall.Outbound, )) require.NoError(t, json.Unmarshal( []byte(inboundRawJSON), &config.VPN.Firewall.Inbound, )) if !slices.Contains(test.flags, "--staged") { h.daemonRPC. On("GetConfig", toolkit.MockArg[context.Context]()). Return(config, nil). Once() } args := append([]string{"vpn", "firewall", "show"}, test.flags...) if test.wantErr == "" { h.runAssertStdout(t, test.want, args...) } else { err := h.run(t, args...) if assert.Error(t, err) { assert.Contains(t, err.Error(), test.wantErr) } } }) } }