isle/go/cmd/entrypoint/vpn_firewall_test.go

541 lines
12 KiB
Go
Raw Normal View History

package main
import (
"context"
"encoding/json"
2024-12-10 12:52:57 +00:00
"fmt"
"isle/daemon/daecommon"
"isle/toolkit"
"os"
"slices"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
2024-12-10 12:52:57 +00:00
func TestVPNFirewallAdd(t *testing.T) {
t.Parallel()
tests := []struct {
name string
staged *daecommon.ConfigFirewall
flags []string
to string
wantFlagErr string
wantValidateErr string
want daecommon.ConfigFirewallRule
}{
{
name: "flag error/to missing",
wantFlagErr: "--to is required",
},
{
name: "flag error/to invalid",
flags: []string{"--to=what"},
wantFlagErr: "invalid --to value",
},
{
name: "flag error/host and groups given",
flags: []string{"--to=inbound", "--groups=foo,bar", "--host=baz"},
wantFlagErr: "--host and --groups are mutually exclusive",
},
{
name: "validate error/bad port",
flags: []string{"--to=inbound", "--port=80-20"},
wantValidateErr: "start port was lower than end port",
},
{
name: "success/only host",
flags: []string{"--to=inbound", "--host=foo"},
to: "inbound",
want: daecommon.ConfigFirewallRule{
Port: "any",
Proto: "any",
Host: "foo",
},
},
{
name: "success/groups",
flags: []string{"--to=outbound", "--groups=foo,bar", "--groups=baz"},
to: "outbound",
want: daecommon.ConfigFirewallRule{
Port: "any",
Proto: "any",
Groups: []string{"foo", "bar", "baz"},
},
},
{
name: "success/port and proto",
flags: []string{"--to=outbound", "--port=22", "--proto=tcp"},
to: "outbound",
want: daecommon.ConfigFirewallRule{
Port: "22",
Proto: "tcp",
Host: "any",
},
},
{
name: "success/with staged",
staged: &daecommon.ConfigFirewall{
Inbound: []daecommon.ConfigFirewallRule{
{
Port: "1",
Proto: "tcp",
Host: "any",
},
},
},
flags: []string{"--to=inbound", "--port=2"},
to: "inbound",
want: daecommon.ConfigFirewallRule{
Port: "2",
Proto: "any",
Host: "any",
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
h = newRunHarness(t)
config daecommon.NetworkConfig
)
args := append([]string{"vpn", "firewall", "add"}, test.flags...)
if test.wantFlagErr != "" {
2024-12-10 14:17:07 +00:00
h.runAssertErrorContains(t, test.wantFlagErr, args...)
2024-12-10 12:52:57 +00:00
return
}
h.daemonRPC.
On("GetConfig", toolkit.MockArg[context.Context]()).
Return(config, nil).
Once()
if test.staged != nil {
assert.NoError(t, h.changeStager.set(
*test.staged, vpnFirewallConfigChangeStagerName,
))
}
if test.wantValidateErr != "" {
2024-12-10 14:17:07 +00:00
h.runAssertErrorContains(t, test.wantValidateErr, args...)
2024-12-10 12:52:57 +00:00
return
}
wantConfig := config
if test.staged != nil {
wantConfig.VPN.Firewall = *test.staged
}
switch test.to {
case "outbound":
wantConfig.VPN.Firewall.Outbound = append(
wantConfig.VPN.Firewall.Outbound, test.want,
)
case "inbound":
wantConfig.VPN.Firewall.Inbound = append(
wantConfig.VPN.Firewall.Inbound, test.want,
)
default:
panic(fmt.Sprintf("invalid test.to %q", test.to))
}
assert.NoError(t, h.run(t, args...))
h.assertChangeStaged(
t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
)
2024-12-10 14:17:07 +00:00
})
}
}
// TestVPNFirewallCommit exercises `vpn firewall commit`.
//
// With nothing staged, the command must fail with "no changes staged".
// Otherwise the test expects SetConfig to receive the config returned by
// GetConfig, with its VPN firewall replaced by the staged one.
func TestVPNFirewallCommit(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		staged *daecommon.ConfigFirewall
	}{
		{
			name: "error/nothing staged",
		},
		{
			name: "success",
			staged: &daecommon.ConfigFirewall{
				Outbound: []daecommon.ConfigFirewallRule{
					{
						Port:  "any",
						Proto: "any",
						Host:  "any",
					},
				},
				Inbound: []daecommon.ConfigFirewallRule{
					{
						Port:  "22",
						Proto: "tcp",
						Host:  "foo",
					},
					{
						Port:  "80",
						Proto: "any",
						Host:  "any",
					},
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			var (
				h                  = newRunHarness(t)
				config, wantConfig daecommon.NetworkConfig
				args               = []string{"vpn", "firewall", "commit"}
			)

			if test.staged == nil {
				h.runAssertErrorContains(t, "no changes staged", args...)
				return
			}

			// Fixture setup. Stop the subtest right away if staging fails
			// rather than running the command against an unexpected state.
			require.NoError(t, h.changeStager.set(
				*test.staged, vpnFirewallConfigChangeStagerName,
			))

			h.daemonRPC.
				On("GetConfig", toolkit.MockArg[context.Context]()).
				Return(config, nil).
				Once()

			// Keep a separate variable for the expected SetConfig argument so
			// the value returned by GetConfig above is never mutated afterwards.
			wantConfig = config
			wantConfig.VPN.Firewall = *test.staged
			h.daemonRPC.
				On("SetConfig", toolkit.MockArg[context.Context](), wantConfig).
				Return(nil).
				Once()

			assert.NoError(t, h.run(t, args...))
		})
	}
}
2024-12-10 14:17:07 +00:00
func TestVPNFirewallRemove(t *testing.T) {
t.Parallel()
rules := func(hosts ...string) []daecommon.ConfigFirewallRule {
out := make([]daecommon.ConfigFirewallRule, len(hosts))
for i := range hosts {
out[i] = daecommon.ConfigFirewallRule{
Port: "any",
Proto: "any",
Host: hosts[i],
}
}
return out
}
tests := []struct {
name string
outbound, inbound []string
stagedOutbound, stagedInbound []string
flags []string
wantFlagErr string
wantValidateErr string
wantOutbound, wantInbound []string
}{
{
name: "flag error/from missing",
wantFlagErr: "--from and --indexes are required",
},
{
name: "flag error/indexes missing",
flags: []string{"--from=what"},
wantFlagErr: "--from and --indexes are required",
},
{
name: "flag error/from invalid",
flags: []string{"--from=what", "--indexes=1,2,3"},
wantFlagErr: "invalid --from value",
},
{
name: "flag error/indexes invalid",
flags: []string{"--from=inbound", "--indexes=1,-2,3"},
wantFlagErr: "invalid index -2",
},
{
name: "validate error/indexes invalid",
inbound: []string{"foo"},
flags: []string{"--from=inbound", "--indexes=0,3,4"},
wantValidateErr: "invalid index(es): [3 4]",
},
{
name: "validate error/indexes invalid staged",
inbound: []string{"foo", "bar"},
stagedInbound: []string{"foo"},
flags: []string{"--from=inbound", "--indexes=0,1"},
wantValidateErr: "invalid index(es): [1]",
},
{
name: "success/remove inbound single",
inbound: []string{"foo", "bar", "baz"},
flags: []string{"--from=inbound", "--indexes=1"},
wantInbound: []string{"foo", "baz"},
},
{
name: "success/remove outbound multiple",
outbound: []string{"foo", "bar", "baz"},
inbound: []string{"any"},
flags: []string{"--from=outbound", "--indexes=0,2"},
wantOutbound: []string{"bar"},
wantInbound: []string{"any"},
},
{
name: "success/remove staged outbound multiple",
inbound: []string{"foo", "bar"},
outbound: []string{"foo", "bar", "baz"},
stagedOutbound: []string{"foo", "bar", "baz", "biz"},
stagedInbound: []string{"any"},
flags: []string{"--from=outbound", "--indexes=0,2,3"},
wantOutbound: []string{"bar"},
wantInbound: []string{"any"},
},
}
2024-12-10 12:52:57 +00:00
2024-12-10 14:17:07 +00:00
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
h = newRunHarness(t)
config, wantConfig daecommon.NetworkConfig
)
config.VPN.Firewall = daecommon.ConfigFirewall{
Outbound: rules(test.outbound...),
Inbound: rules(test.inbound...),
}
if len(test.stagedOutbound) > 0 || len(test.stagedInbound) > 0 {
assert.NoError(t, h.changeStager.set(
daecommon.ConfigFirewall{
Outbound: rules(test.stagedOutbound...),
Inbound: rules(test.stagedInbound...),
},
vpnFirewallConfigChangeStagerName,
))
}
wantConfig.VPN.Firewall = daecommon.ConfigFirewall{
Outbound: rules(test.wantOutbound...),
Inbound: rules(test.wantInbound...),
}
args := append([]string{"vpn", "firewall", "remove"}, test.flags...)
if test.wantFlagErr != "" {
h.runAssertErrorContains(t, test.wantFlagErr, args...)
return
}
h.daemonRPC.
On("GetConfig", toolkit.MockArg[context.Context]()).
Return(config, nil).
Once()
if test.wantValidateErr != "" {
h.runAssertErrorContains(t, test.wantValidateErr, args...)
return
}
assert.NoError(t, h.run(t, args...))
h.assertChangeStaged(
t, wantConfig.VPN.Firewall, vpnFirewallConfigChangeStagerName,
)
2024-12-10 12:52:57 +00:00
})
}
}
// TestVPNFirewallShow checks the JSON printed by `vpn firewall show`.
//
// Without --staged (or when nothing is staged), the active config is expected
// to come from GetConfig. With --staged and a staged change present, the
// staged rules are printed instead.
func TestVPNFirewallShow(t *testing.T) {
	t.Parallel()

	// Raw JSON for an any-port ICMP rule targeting any host.
	const icmpAnyRule = `{"port":"any","proto":"icmp","host":"any"}`

	// Contents written verbatim to the stager's file for the staged cases.
	const stagedInboundRule = `{
"Inbound": [
{
"Port":"80",
"Proto":"tcp",
"Host":"some-host"
}
]
}`

	// shownRule returns a new map describing a single printed rule row. Each
	// call allocates a fresh map, so expected rows never share state.
	shownRule := func(index int, port, proto, host string) map[string]any {
		return map[string]any{
			"index": index,
			"port":  port,
			"proto": proto,
			"host":  host,
		}
	}

	tests := []struct {
		name              string
		outbound, inbound []string
		staged            string
		flags             []string
		want              map[string][]any
	}{
		{
			name: "empty",
			want: map[string][]any{
				"outbound": {},
				"inbound":  {},
			},
		},
		{
			name:     "single",
			outbound: []string{icmpAnyRule},
			want: map[string][]any{
				"outbound": {shownRule(0, "any", "icmp", "any")},
				"inbound":  {},
			},
		},
		{
			name:     "multiple",
			outbound: []string{icmpAnyRule},
			inbound: []string{
				icmpAnyRule,
				`{"port":"22","proto":"tcp","host":"foo"}`,
			},
			want: map[string][]any{
				"outbound": {shownRule(0, "any", "icmp", "any")},
				"inbound": {
					shownRule(0, "any", "icmp", "any"),
					shownRule(1, "22", "tcp", "foo"),
				},
			},
		},
		{
			name:     "staged/nothing staged",
			outbound: []string{icmpAnyRule},
			flags:    []string{"--staged"},
			want: map[string][]any{
				"outbound": {shownRule(0, "any", "icmp", "any")},
				"inbound":  {},
			},
		},
		{
			name:     "staged/staged but no flag",
			outbound: []string{icmpAnyRule},
			staged:   stagedInboundRule,
			want: map[string][]any{
				"outbound": {shownRule(0, "any", "icmp", "any")},
				"inbound":  {},
			},
		},
		{
			name:     "staged/staged with flag",
			outbound: []string{icmpAnyRule},
			staged:   stagedInboundRule,
			flags:    []string{"--staged"},
			want: map[string][]any{
				"outbound": {},
				"inbound":  {shownRule(0, "80", "tcp", "some-host")},
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			h := newRunHarness(t)

			if tc.staged != "" {
				// Write the staged change straight to the stager's file so the
				// test controls the exact bytes on disk.
				require.True(t, json.Valid([]byte(tc.staged)))
				stagedPath := h.changeStager.path(vpnFirewallConfigChangeStagerName)
				require.NoError(t, os.WriteFile(stagedPath, []byte(tc.staged), 0o600))
			}

			// decodeRules joins the raw rule objects into a JSON array and
			// decodes it into dst. No rules yields "[]", an empty slice.
			decodeRules := func(raw []string, dst *[]daecommon.ConfigFirewallRule) {
				arrayJSON := "[" + strings.Join(raw, ",") + "]"
				require.NoError(t, json.Unmarshal([]byte(arrayJSON), dst))
			}

			var config daecommon.NetworkConfig
			decodeRules(tc.outbound, &config.VPN.Firewall.Outbound)
			decodeRules(tc.inbound, &config.VPN.Firewall.Inbound)

			useStaged := slices.Contains(tc.flags, "--staged") && tc.staged != ""
			if !useStaged {
				// The active config is expected to be fetched from the daemon
				// exactly once when there is no staged change to show.
				h.daemonRPC.
					On("GetConfig", toolkit.MockArg[context.Context]()).
					Return(config, nil).
					Once()
			}

			args := append([]string{"vpn", "firewall", "show"}, tc.flags...)
			h.runAssertStdout(t, tc.want, args...)
		})
	}
}