isle/go/cmd/entrypoint/vpn_firewall_test.go

193 lines
3.7 KiB
Go
Raw Normal View History

package main
import (
"context"
"encoding/json"
"isle/daemon/daecommon"
"isle/toolkit"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
// TestVPNFirewallList exercises the `vpn firewall list` subcommand.
//
// Each case:
//   - builds a daecommon.NetworkConfig whose VPN firewall outbound/inbound
//     rules are decoded from the given raw JSON objects;
//   - optionally writes a staged firewall configuration to the change
//     stager's path for vpnFirewallConfigChangeStagerName;
//   - runs `vpn firewall list` with the given flags;
//   - checks either the JSON written to stdout (want) or that the returned
//     error contains wantErr.
//
// The daemon's GetConfig RPC is only mocked when "--staged" is absent from
// the flags.
func TestVPNFirewallList(t *testing.T) {
	t.Parallel()

	// Raw rule fixtures shared between cases.
	const (
		icmpAnyRuleJSON = `{"port":"any","proto":"icmp","host":"any"}`
		sshFooRuleJSON  = `{"port":"22","proto":"tcp","host":"foo"}`

		// A staged firewall configuration containing a single inbound
		// rule (tcp port 80 to some-host) and no outbound rules.
		stagedInboundHTTP = `{
"Inbound": [
{
"Port":"80",
"Proto":"tcp",
"Host":"some-host"
}
]
}`
	)

	// rule builds the expected stdout representation of a single listed
	// firewall rule at the given position.
	rule := func(index int, port, proto, host string) map[string]any {
		return map[string]any{
			"index": index,
			"port":  port,
			"proto": proto,
			"host":  host,
		}
	}

	cases := []struct {
		name              string
		outbound, inbound []string
		staged            string
		flags             []string
		want              map[string][]any
		wantErr           string
	}{
		{
			name: "empty",
			want: map[string][]any{
				"outbound": {},
				"inbound":  {},
			},
		},
		{
			name:     "single",
			outbound: []string{icmpAnyRuleJSON},
			want: map[string][]any{
				"outbound": {rule(0, "any", "icmp", "any")},
				"inbound":  {},
			},
		},
		{
			name:     "multiple",
			outbound: []string{icmpAnyRuleJSON},
			inbound:  []string{icmpAnyRuleJSON, sshFooRuleJSON},
			want: map[string][]any{
				"outbound": {rule(0, "any", "icmp", "any")},
				"inbound": {
					rule(0, "any", "icmp", "any"),
					rule(1, "22", "tcp", "foo"),
				},
			},
		},
		{
			name:    "staged/nothing staged",
			flags:   []string{"--staged"},
			wantErr: "no firewall configuration changes have been staged",
		},
		{
			name:     "staged/staged but no flag",
			outbound: []string{icmpAnyRuleJSON},
			staged:   stagedInboundHTTP,
			want: map[string][]any{
				"outbound": {rule(0, "any", "icmp", "any")},
				"inbound":  {},
			},
		},
		{
			name:     "staged/staged with flag",
			outbound: []string{icmpAnyRuleJSON},
			staged:   stagedInboundHTTP,
			flags:    []string{"--staged"},
			want: map[string][]any{
				"outbound": {},
				"inbound":  {rule(0, "80", "tcp", "some-host")},
			},
		},
	}

	// jsonArray wraps already-encoded JSON objects into the bytes of a JSON
	// array ("[]" when objs is empty).
	jsonArray := func(objs []string) []byte {
		return []byte("[" + strings.Join(objs, ",") + "]")
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			h := newRunHarness(t)

			if tc.staged != "" {
				stagedBytes := []byte(tc.staged)
				// Guard against a malformed fixture masking a real failure.
				require.True(t, json.Valid(stagedBytes))
				stagedPath := h.changeStager.path(vpnFirewallConfigChangeStagerName)
				require.NoError(t, os.WriteFile(stagedPath, stagedBytes, 0600))
			}

			var config daecommon.NetworkConfig
			require.NoError(t, json.Unmarshal(
				jsonArray(tc.outbound), &config.VPN.Firewall.Outbound,
			))
			require.NoError(t, json.Unmarshal(
				jsonArray(tc.inbound), &config.VPN.Firewall.Inbound,
			))

			// Only the non-staged path is given a GetConfig expectation;
			// Once() permits a single matching call.
			if usesStaged := slices.Contains(tc.flags, "--staged"); !usesStaged {
				h.daemonRPC.
					On("GetConfig", toolkit.MockArg[context.Context]()).
					Return(config, nil).
					Once()
			}

			args := make([]string, 0, 3+len(tc.flags))
			args = append(args, "vpn", "firewall", "list")
			args = append(args, tc.flags...)

			if tc.wantErr != "" {
				err := h.run(t, args...)
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tc.wantErr)
				}
				return
			}

			h.runAssertStdout(t, tc.want, args...)
		})
	}
}