Don't error from 'vpn firewall show --staged' if nothing is staged, return the live config instead

This commit is contained in:
Brian Picciano 2024-12-10 15:35:13 +01:00
parent dd847cafe1
commit 9b27676521
2 changed files with 33 additions and 24 deletions

View File

@ -257,21 +257,26 @@ var subCmdVPNFirewallShow = subCmd{
return nil, fmt.Errorf("parsing flags: %w", err) return nil, fmt.Errorf("parsing flags: %w", err)
} }
var firewallConfig daecommon.ConfigFirewall var (
if !*staged { firewallConfig daecommon.ConfigFirewall
foundStaged bool
)
if *staged {
var err error
if foundStaged, err = ctx.opts.changeStager.get(
&firewallConfig, vpnFirewallConfigChangeStagerName,
); err != nil {
return nil, fmt.Errorf("checking for staged changes: %w", err)
}
}
if !foundStaged {
config, err := ctx.getDaemonRPC().GetConfig(ctx) config, err := ctx.getDaemonRPC().GetConfig(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting network config: %w", err) return nil, fmt.Errorf("getting network config: %w", err)
} }
firewallConfig = config.VPN.Firewall firewallConfig = config.VPN.Firewall
} else if ok, err := ctx.opts.changeStager.get(
&firewallConfig, vpnFirewallConfigChangeStagerName,
); err != nil {
return nil, fmt.Errorf("checking for staged changes: %w", err)
} else if !ok {
return nil, errors.New("no firewall configuration changes have been staged")
} }
return newFirewallView(firewallConfig), nil return newFirewallView(firewallConfig), nil

View File

@ -7,12 +7,12 @@ import (
"isle/daemon/daecommon" "isle/daemon/daecommon"
"isle/toolkit" "isle/toolkit"
"os" "os"
"slices"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
) )
func TestVPNFirewallAdd(t *testing.T) { func TestVPNFirewallAdd(t *testing.T) {
@ -296,7 +296,6 @@ func TestVPNFirewallShow(t *testing.T) {
staged string staged string
flags []string flags []string
want map[string][]any want map[string][]any
wantErr string
}{ }{
{ {
name: "empty", name: "empty",
@ -357,9 +356,22 @@ func TestVPNFirewallShow(t *testing.T) {
}, },
}, },
{ {
name: "staged/nothing staged", name: "staged/nothing staged",
flags: []string{"--staged"}, outbound: []string{
wantErr: "no firewall configuration changes have been staged", `{"port":"any","proto":"icmp","host":"any"}`,
},
flags: []string{"--staged"},
want: map[string][]any{
"outbound": {
map[string]any{
"index": 0,
"port": "any",
"proto": "icmp",
"host": "any",
},
},
"inbound": {},
},
}, },
{ {
name: "staged/staged but no flag", name: "staged/staged but no flag",
@ -443,7 +455,7 @@ func TestVPNFirewallShow(t *testing.T) {
[]byte(inboundRawJSON), &config.VPN.Firewall.Inbound, []byte(inboundRawJSON), &config.VPN.Firewall.Inbound,
)) ))
if !slices.Contains(test.flags, "--staged") { if !slices.Contains(test.flags, "--staged") || test.staged == "" {
h.daemonRPC. h.daemonRPC.
On("GetConfig", toolkit.MockArg[context.Context]()). On("GetConfig", toolkit.MockArg[context.Context]()).
Return(config, nil). Return(config, nil).
@ -451,15 +463,7 @@ func TestVPNFirewallShow(t *testing.T) {
} }
args := append([]string{"vpn", "firewall", "show"}, test.flags...) args := append([]string{"vpn", "firewall", "show"}, test.flags...)
h.runAssertStdout(t, test.want, args...)
if test.wantErr == "" {
h.runAssertStdout(t, test.want, args...)
} else {
err := h.run(t, args...)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), test.wantErr)
}
}
}) })
} }
} }