Give 'vpn firewall list' a --staged flag

This commit is contained in:
Brian Picciano 2024-12-09 18:09:45 +01:00
parent b38d780bdf
commit 1608031103
7 changed files with 199 additions and 19 deletions

View File

@ -0,0 +1,42 @@
package main
import (
"errors"
"fmt"
"io/fs"
"isle/jsonutil"
"isle/toolkit"
"os"
"path/filepath"
)
type changeStager struct {
dir toolkit.Dir
}
func (mgr *changeStager) path(name string) string {
return filepath.Join(mgr.dir.Path, name+".json")
}
func (mgr *changeStager) get(into any, name string) (bool, error) {
path := mgr.path(name)
if err := jsonutil.LoadFile(into, path); errors.Is(err, fs.ErrNotExist) {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}
func (mgr *changeStager) set(v any, name string) error {
path := mgr.path(name)
return jsonutil.WriteFile(v, path, 0600)
}
func (mgr *changeStager) del(name string) error {
path := mgr.path(name)
if err := os.Remove(path); err != nil {
return fmt.Errorf("removing file %q: %w", path, err)
}
return nil
}

View File

@ -3,13 +3,16 @@ package main
import (
"context"
"fmt"
"isle/toolkit"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"github.com/adrg/xdg"
)
func getAppDirPath() string {
@ -23,6 +26,19 @@ func getAppDirPath() string {
var (
envAppDirPath = getAppDirPath()
envBinDirPath = filepath.Join(envAppDirPath, "bin")
envCacheDir = sync.OnceValue(func() toolkit.Dir {
cacheHome, err := toolkit.MkDir(xdg.CacheHome, true)
if err != nil {
panic(fmt.Errorf("creating cache directory %q: %w", xdg.CacheHome, err))
}
cacheDir, err := cacheHome.MkChildDir("isle", true)
if err != nil {
panic(fmt.Errorf("creating isle cache directory: %w", err))
}
return cacheDir
})
)
func binPath(name string) string {

View File

@ -18,9 +18,10 @@ import (
type runHarness struct {
ctx context.Context
logger *mlog.Logger
stdout *bytes.Buffer
daemonRPC *daemon.MockRPC
daemonRPCServer *httptest.Server
stdout *bytes.Buffer
changeStager *changeStager
}
func newRunHarness(t *testing.T) *runHarness {
@ -29,18 +30,22 @@ func newRunHarness(t *testing.T) *runHarness {
var (
ctx = context.Background()
logger = toolkit.NewTestLogger(t)
stdout = new(bytes.Buffer)
daemonRPC = daemon.NewMockRPC(t)
daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler(
logger.WithNamespace("rpc"), daemonRPC,
))
daemonRPCServer = httptest.NewServer(daemonRPCHandler)
stdout = new(bytes.Buffer)
changeStager = &changeStager{toolkit.TempDir(t)}
)
t.Cleanup(daemonRPCServer.Close)
return &runHarness{ctx, logger, stdout, daemonRPC, daemonRPCServer}
return &runHarness{
ctx, logger, daemonRPC, daemonRPCServer, stdout, changeStager,
}
}
func (h *runHarness) run(t *testing.T, args ...string) error {
@ -57,6 +62,7 @@ func (h *runHarness) run(t *testing.T, args ...string) error {
args: args,
daemonRPC: daemonRPCClient,
stdout: h.stdout,
changeStager: h.changeStager,
})
}

View File

@ -40,6 +40,7 @@ type subCmdCtxOpts struct {
subCmdNames []string // names of subCmds so far, including this one
daemonRPC daemon.RPC
stdout io.Writer
changeStager *changeStager
}
func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
@ -55,6 +56,10 @@ func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
o.stdout = os.Stdout
}
if o.changeStager == nil {
o.changeStager = &changeStager{envCacheDir()}
}
return o
}

View File

@ -1,10 +1,13 @@
package main
import (
"errors"
"fmt"
"isle/daemon/daecommon"
)
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
type firewallRuleView struct {
Index int `yaml:"index"`
daecommon.ConfigFirewallRule `yaml:",inline"`
@ -39,17 +42,35 @@ var subCmdVPNFirewallList = subCmd{
name: "list",
descr: "List all currently configured firewall rules",
do: doWithOutput(func(ctx subCmdCtx) (any, error) {
staged := ctx.flags.Bool(
"staged",
false,
"Return the firewall configuration with staged changes included",
)
ctx, err := ctx.withParsedFlags()
if err != nil {
return nil, fmt.Errorf("parsing flags: %w", err)
}
var firewallConfig daecommon.ConfigFirewall
if !*staged {
config, err := ctx.getDaemonRPC().GetConfig(ctx)
if err != nil {
return nil, fmt.Errorf("getting network config: %w", err)
}
return newFirewallView(config.VPN.Firewall), nil
firewallConfig = config.VPN.Firewall
} else if ok, err := ctx.opts.changeStager.get(
&firewallConfig, vpnFirewallConfigChangeStagerName,
); err != nil {
return nil, fmt.Errorf("checking for staged changes: %w", err)
} else if !ok {
return nil, errors.New("no firewall configuration changes have been staged")
}
return newFirewallView(firewallConfig), nil
}),
}

View File

@ -5,10 +5,13 @@ import (
"encoding/json"
"isle/daemon/daecommon"
"isle/toolkit"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
func TestVPNFirewallList(t *testing.T) {
@ -17,7 +20,10 @@ func TestVPNFirewallList(t *testing.T) {
tests := []struct {
name string
outbound, inbound []string
staged string
flags []string
want map[string][]any
wantErr string
}{
{
name: "empty",
@ -77,6 +83,64 @@ func TestVPNFirewallList(t *testing.T) {
},
},
},
{
name: "staged/nothing staged",
flags: []string{"--staged"},
wantErr: "no firewall configuration changes have been staged",
},
{
name: "staged/staged but no flag",
outbound: []string{
`{"port":"any","proto":"icmp","host":"any"}`,
},
staged: `{
"Inbound": [
{
"Port":"80",
"Proto":"tcp",
"Host":"some-host"
}
]
}`,
want: map[string][]any{
"outbound": {
map[string]any{
"index": 0,
"port": "any",
"proto": "icmp",
"host": "any",
},
},
"inbound": {},
},
},
{
name: "staged/staged with flag",
outbound: []string{
`{"port":"any","proto":"icmp","host":"any"}`,
},
staged: `{
"Inbound": [
{
"Port":"80",
"Proto":"tcp",
"Host":"some-host"
}
]
}`,
flags: []string{"--staged"},
want: map[string][]any{
"outbound": {},
"inbound": {
map[string]any{
"index": 0,
"port": "80",
"proto": "tcp",
"host": "some-host",
},
},
},
},
}
for _, test := range tests {
@ -89,6 +153,15 @@ func TestVPNFirewallList(t *testing.T) {
inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]"
)
if test.staged != "" {
require.True(t, json.Valid([]byte(test.staged)))
require.NoError(t, os.WriteFile(
h.changeStager.path(vpnFirewallConfigChangeStagerName),
[]byte(test.staged),
0600,
))
}
require.NoError(t, json.Unmarshal(
[]byte(outboundRawJSON), &config.VPN.Firewall.Outbound,
))
@ -97,12 +170,23 @@ func TestVPNFirewallList(t *testing.T) {
[]byte(inboundRawJSON), &config.VPN.Firewall.Inbound,
))
if !slices.Contains(test.flags, "--staged") {
h.daemonRPC.
On("GetConfig", toolkit.MockArg[context.Context]()).
Return(config, nil).
Once()
}
h.runAssertStdout(t, test.want, "vpn", "firewall", "list")
args := append([]string{"vpn", "firewall", "list"}, test.flags...)
if test.wantErr == "" {
h.runAssertStdout(t, test.want, args...)
} else {
err := h.run(t, args...)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), test.wantErr)
}
}
})
}
}

View File

@ -6,6 +6,7 @@ import (
"io/fs"
"os"
"path/filepath"
"testing"
)
// Dir is a type which makes it possible to statically assert that a directory
@ -58,6 +59,11 @@ func MkDir(path string, mayExist bool) (Dir, error) {
return Dir{path}, nil
}
// TempDir returns a Dir based on a temporary directory created via [t.TempDir].
func TempDir(t *testing.T) Dir {
return Dir{t.TempDir()}
}
// MkChildDir is a helper for joining Dir's path to the given name and calling
// MkDir with the result.
func (d Dir) MkChildDir(name string, mayExist bool) (Dir, error) {