// isle/go/cmd/entrypoint/vpn_firewall_test.go
//
// 339 lines
// 7.0 KiB
// Go
// Raw Normal View History

package main
import (
"context"
"encoding/json"
// 2024-12-10 12:52:57 +00:00
"fmt"
"isle/daemon/daecommon"
"isle/toolkit"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
// 2024-12-10 12:52:57 +00:00

// TestVPNFirewallAdd drives the `vpn firewall add` sub-command through the run
// harness using a table of flag combinations.
//
// Each case expects exactly one of the following outcomes:
//   - wantFlagErr: the command fails with an error containing this text. No
//     GetConfig expectation is registered on the daemon RPC mock for these
//     cases.
//   - wantValidateErr: the command fails with an error containing this text
//     after a GetConfig expectation has been registered (and any staged
//     firewall config has been written).
//   - want: the command succeeds and the staged firewall config equals the
//     starting firewall config (test.staged if given, otherwise the zero
//     value) with `want` appended to the list named by `to` ("inbound" or
//     "outbound").
func TestVPNFirewallAdd(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name            string
		staged          *daecommon.ConfigFirewall
		flags           []string
		to              string
		wantFlagErr     string
		wantValidateErr string
		want            daecommon.ConfigFirewallRule
	}{
		{
			name:        "flag error/to missing",
			wantFlagErr: "--to is required",
		},
		{
			name:        "flag error/to invalid",
			flags:       []string{"--to=what"},
			wantFlagErr: "invalid --to value",
		},
		{
			name:        "flag error/host and groups given",
			flags:       []string{"--to=inbound", "--groups=foo,bar", "--host=baz"},
			wantFlagErr: "--host and --groups are mutually exclusive",
		},
		{
			name:            "validate error/bad port",
			flags:           []string{"--to=inbound", "--port=80-20"},
			wantValidateErr: "start port was lower than end port",
		},
		{
			name:  "success/only host",
			flags: []string{"--to=inbound", "--host=foo"},
			to:    "inbound",
			want: daecommon.ConfigFirewallRule{
				Port:  "any",
				Proto: "any",
				Host:  "foo",
			},
		},
		{
			name:  "success/groups",
			flags: []string{"--to=outbound", "--groups=foo,bar", "--groups=baz"},
			to:    "outbound",
			want: daecommon.ConfigFirewallRule{
				Port:   "any",
				Proto:  "any",
				Groups: []string{"foo", "bar", "baz"},
			},
		},
		{
			name:  "success/port and proto",
			flags: []string{"--to=outbound", "--port=22", "--proto=tcp"},
			to:    "outbound",
			want: daecommon.ConfigFirewallRule{
				Port:  "22",
				Proto: "tcp",
				Host:  "any",
			},
		},
		{
			name: "success/with staged",
			staged: &daecommon.ConfigFirewall{
				Inbound: []daecommon.ConfigFirewallRule{
					{
						Port:  "1",
						Proto: "tcp",
						Host:  "any",
					},
				},
			},
			flags: []string{"--to=inbound", "--port=2"},
			to:    "inbound",
			want: daecommon.ConfigFirewallRule{
				Port:  "2",
				Proto: "any",
				Host:  "any",
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			var (
				h      = newRunHarness(t)
				config daecommon.NetworkConfig
				args   = append([]string{"vpn", "firewall", "add"}, test.flags...)
			)

			// expectRunErr runs the command and checks that it fails with an
			// error whose message contains wantSubstr.
			expectRunErr := func(wantSubstr string) {
				err := h.run(t, args...)
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), wantSubstr)
				}
			}

			if test.wantFlagErr != "" {
				expectRunErr(test.wantFlagErr)
				return
			}

			h.daemonRPC.
				On("GetConfig", toolkit.MockArg[context.Context]()).
				Return(config, nil).
				Once()

			if test.staged != nil {
				// Staging is a precondition of the case, not the behavior
				// under test: abort the subtest immediately if it fails rather
				// than running the command against an unexpected state.
				require.NoError(t, h.changeStager.set(
					*test.staged, vpnFirewallConfigChangeStagerName,
				))
			}

			if test.wantValidateErr != "" {
				expectRunErr(test.wantValidateErr)
				return
			}

			// Build the expected firewall config: start from the staged config
			// when there is one, then append the new rule to the targeted list.
			wantConfig := config
			if test.staged != nil {
				wantConfig.VPN.Firewall = *test.staged
			}

			switch test.to {
			case "outbound":
				wantConfig.VPN.Firewall.Outbound = append(
					wantConfig.VPN.Firewall.Outbound, test.want,
				)
			case "inbound":
				wantConfig.VPN.Firewall.Inbound = append(
					wantConfig.VPN.Firewall.Inbound, test.want,
				)
			default:
				// A success case with an unknown `to` is a mistake in the
				// test table itself.
				panic(fmt.Sprintf("invalid test.to %q", test.to))
			}

			assert.NoError(t, h.run(t, args...))
			h.assertChangeStaged(
				t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
			)
		})
	}
}
// TestVPNFirewallList drives the `vpn firewall list` sub-command through the
// run harness.
//
// For every case the outbound/inbound rule fixtures (JSON objects) are joined
// into JSON arrays and decoded into the firewall section of a NetworkConfig.
// That config is offered through a mocked GetConfig call, but only when the
// case does not pass --staged. When `staged` is non-empty it is written
// verbatim to the change stager's file for vpnFirewallConfigChangeStagerName
// before the command runs. A case then either expects the given stdout
// structure (want) or an error containing wantErr.
func TestVPNFirewallList(t *testing.T) {
	t.Parallel()

	// Fixtures shared by several cases.
	const (
		icmpAnyRule = `{"port":"any","proto":"icmp","host":"any"}`
		sshFooRule  = `{"port":"22","proto":"tcp","host":"foo"}`

		stagedInbound = `{
"Inbound": [
{
"Port":"80",
"Proto":"tcp",
"Host":"some-host"
}
]
}`
	)

	// entry builds one expected row of the command's output.
	entry := func(index int, port, proto, host string) map[string]any {
		return map[string]any{
			"index": index,
			"port":  port,
			"proto": proto,
			"host":  host,
		}
	}

	// jsonArray wraps the given JSON values into a single JSON array document.
	jsonArray := func(elems []string) []byte {
		return []byte("[" + strings.Join(elems, ",") + "]")
	}

	tests := []struct {
		name              string
		outbound, inbound []string
		staged            string
		flags             []string
		want              map[string][]any
		wantErr           string
	}{
		{
			name: "empty",
			want: map[string][]any{
				"outbound": {},
				"inbound":  {},
			},
		},
		{
			name:     "single",
			outbound: []string{icmpAnyRule},
			want: map[string][]any{
				"outbound": {entry(0, "any", "icmp", "any")},
				"inbound":  {},
			},
		},
		{
			name:     "multiple",
			outbound: []string{icmpAnyRule},
			inbound:  []string{icmpAnyRule, sshFooRule},
			want: map[string][]any{
				"outbound": {entry(0, "any", "icmp", "any")},
				"inbound": {
					entry(0, "any", "icmp", "any"),
					entry(1, "22", "tcp", "foo"),
				},
			},
		},
		{
			name:    "staged/nothing staged",
			flags:   []string{"--staged"},
			wantErr: "no firewall configuration changes have been staged",
		},
		{
			name:     "staged/staged but no flag",
			outbound: []string{icmpAnyRule},
			staged:   stagedInbound,
			want: map[string][]any{
				"outbound": {entry(0, "any", "icmp", "any")},
				"inbound":  {},
			},
		},
		{
			name:     "staged/staged with flag",
			outbound: []string{icmpAnyRule},
			staged:   stagedInbound,
			flags:    []string{"--staged"},
			want: map[string][]any{
				"outbound": {},
				"inbound":  {entry(0, "80", "tcp", "some-host")},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			var (
				harness = newRunHarness(t)
				cfg     daecommon.NetworkConfig
			)

			// Write the staged document straight to the stager's file, after
			// checking that the fixture itself is well-formed JSON.
			if test.staged != "" {
				stagedBytes := []byte(test.staged)
				require.True(t, json.Valid(stagedBytes))

				stagedPath := harness.changeStager.path(
					vpnFirewallConfigChangeStagerName,
				)
				require.NoError(t, os.WriteFile(stagedPath, stagedBytes, 0600))
			}

			// Populate the config before registering the mock, which is handed
			// cfg by value.
			firewall := &cfg.VPN.Firewall
			require.NoError(t, json.Unmarshal(
				jsonArray(test.outbound), &firewall.Outbound,
			))
			require.NoError(t, json.Unmarshal(
				jsonArray(test.inbound), &firewall.Inbound,
			))

			// GetConfig is only expected when --staged was not passed.
			usesStaged := slices.Contains(test.flags, "--staged")
			if !usesStaged {
				harness.daemonRPC.
					On("GetConfig", toolkit.MockArg[context.Context]()).
					Return(cfg, nil).
					Once()
			}

			cmdArgs := append([]string{"vpn", "firewall", "list"}, test.flags...)

			if test.wantErr != "" {
				err := harness.run(t, cmdArgs...)
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), test.wantErr)
				}
				return
			}

			harness.runAssertStdout(t, test.want, cmdArgs...)
		})
	}
}