Give 'vpn firewall list' a --staged flag

This commit is contained in:
Brian Picciano 2024-12-09 18:09:45 +01:00
parent b38d780bdf
commit 1608031103
7 changed files with 199 additions and 19 deletions

View File

@ -0,0 +1,42 @@
package main
import (
"errors"
"fmt"
"io/fs"
"isle/jsonutil"
"isle/toolkit"
"os"
"path/filepath"
)
type changeStager struct {
dir toolkit.Dir
}
func (mgr *changeStager) path(name string) string {
return filepath.Join(mgr.dir.Path, name+".json")
}
func (mgr *changeStager) get(into any, name string) (bool, error) {
path := mgr.path(name)
if err := jsonutil.LoadFile(into, path); errors.Is(err, fs.ErrNotExist) {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}
func (mgr *changeStager) set(v any, name string) error {
path := mgr.path(name)
return jsonutil.WriteFile(v, path, 0600)
}
func (mgr *changeStager) del(name string) error {
path := mgr.path(name)
if err := os.Remove(path); err != nil {
return fmt.Errorf("removing file %q: %w", path, err)
}
return nil
}

View File

@ -3,13 +3,16 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"isle/toolkit"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"sync"
"syscall" "syscall"
"dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mctx"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog" "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"github.com/adrg/xdg"
) )
func getAppDirPath() string { func getAppDirPath() string {
@ -23,6 +26,19 @@ func getAppDirPath() string {
var ( var (
envAppDirPath = getAppDirPath() envAppDirPath = getAppDirPath()
envBinDirPath = filepath.Join(envAppDirPath, "bin") envBinDirPath = filepath.Join(envAppDirPath, "bin")
envCacheDir = sync.OnceValue(func() toolkit.Dir {
cacheHome, err := toolkit.MkDir(xdg.CacheHome, true)
if err != nil {
panic(fmt.Errorf("creating cache directory %q: %w", xdg.CacheHome, err))
}
cacheDir, err := cacheHome.MkChildDir("isle", true)
if err != nil {
panic(fmt.Errorf("creating isle cache directory: %w", err))
}
return cacheDir
})
) )
func binPath(name string) string { func binPath(name string) string {

View File

@ -18,9 +18,10 @@ import (
type runHarness struct { type runHarness struct {
ctx context.Context ctx context.Context
logger *mlog.Logger logger *mlog.Logger
stdout *bytes.Buffer
daemonRPC *daemon.MockRPC daemonRPC *daemon.MockRPC
daemonRPCServer *httptest.Server daemonRPCServer *httptest.Server
stdout *bytes.Buffer
changeStager *changeStager
} }
func newRunHarness(t *testing.T) *runHarness { func newRunHarness(t *testing.T) *runHarness {
@ -29,18 +30,22 @@ func newRunHarness(t *testing.T) *runHarness {
var ( var (
ctx = context.Background() ctx = context.Background()
logger = toolkit.NewTestLogger(t) logger = toolkit.NewTestLogger(t)
stdout = new(bytes.Buffer)
daemonRPC = daemon.NewMockRPC(t) daemonRPC = daemon.NewMockRPC(t)
daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler( daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler(
logger.WithNamespace("rpc"), daemonRPC, logger.WithNamespace("rpc"), daemonRPC,
)) ))
daemonRPCServer = httptest.NewServer(daemonRPCHandler) daemonRPCServer = httptest.NewServer(daemonRPCHandler)
stdout = new(bytes.Buffer)
changeStager = &changeStager{toolkit.TempDir(t)}
) )
t.Cleanup(daemonRPCServer.Close) t.Cleanup(daemonRPCServer.Close)
return &runHarness{ctx, logger, stdout, daemonRPC, daemonRPCServer} return &runHarness{
ctx, logger, daemonRPC, daemonRPCServer, stdout, changeStager,
}
} }
func (h *runHarness) run(t *testing.T, args ...string) error { func (h *runHarness) run(t *testing.T, args ...string) error {
@ -54,9 +59,10 @@ func (h *runHarness) run(t *testing.T, args ...string) error {
) )
return doRootCmd(h.ctx, h.logger, &subCmdCtxOpts{ return doRootCmd(h.ctx, h.logger, &subCmdCtxOpts{
args: args, args: args,
daemonRPC: daemonRPCClient, daemonRPC: daemonRPCClient,
stdout: h.stdout, stdout: h.stdout,
changeStager: h.changeStager,
}) })
} }

View File

@ -36,10 +36,11 @@ type subCmd struct {
} }
type subCmdCtxOpts struct { type subCmdCtxOpts struct {
args []string // command-line arguments, excluding the subCmd itself. args []string // command-line arguments, excluding the subCmd itself.
subCmdNames []string // names of subCmds so far, including this one subCmdNames []string // names of subCmds so far, including this one
daemonRPC daemon.RPC daemonRPC daemon.RPC
stdout io.Writer stdout io.Writer
changeStager *changeStager
} }
func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts { func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
@ -55,6 +56,10 @@ func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
o.stdout = os.Stdout o.stdout = os.Stdout
} }
if o.changeStager == nil {
o.changeStager = &changeStager{envCacheDir()}
}
return o return o
} }

View File

@ -1,10 +1,13 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"isle/daemon/daecommon" "isle/daemon/daecommon"
) )
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
type firewallRuleView struct { type firewallRuleView struct {
Index int `yaml:"index"` Index int `yaml:"index"`
daecommon.ConfigFirewallRule `yaml:",inline"` daecommon.ConfigFirewallRule `yaml:",inline"`
@ -39,17 +42,35 @@ var subCmdVPNFirewallList = subCmd{
name: "list", name: "list",
descr: "List all currently configured firewall rules", descr: "List all currently configured firewall rules",
do: doWithOutput(func(ctx subCmdCtx) (any, error) { do: doWithOutput(func(ctx subCmdCtx) (any, error) {
staged := ctx.flags.Bool(
"staged",
false,
"Return the firewall configuration with staged changes included",
)
ctx, err := ctx.withParsedFlags() ctx, err := ctx.withParsedFlags()
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing flags: %w", err) return nil, fmt.Errorf("parsing flags: %w", err)
} }
config, err := ctx.getDaemonRPC().GetConfig(ctx) var firewallConfig daecommon.ConfigFirewall
if err != nil { if !*staged {
return nil, fmt.Errorf("getting network config: %w", err) config, err := ctx.getDaemonRPC().GetConfig(ctx)
if err != nil {
return nil, fmt.Errorf("getting network config: %w", err)
}
firewallConfig = config.VPN.Firewall
} else if ok, err := ctx.opts.changeStager.get(
&firewallConfig, vpnFirewallConfigChangeStagerName,
); err != nil {
return nil, fmt.Errorf("checking for staged changes: %w", err)
} else if !ok {
return nil, errors.New("no firewall configuration changes have been staged")
} }
return newFirewallView(config.VPN.Firewall), nil return newFirewallView(firewallConfig), nil
}), }),
} }

View File

@ -5,10 +5,13 @@ import (
"encoding/json" "encoding/json"
"isle/daemon/daecommon" "isle/daemon/daecommon"
"isle/toolkit" "isle/toolkit"
"os"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
) )
func TestVPNFirewallList(t *testing.T) { func TestVPNFirewallList(t *testing.T) {
@ -17,7 +20,10 @@ func TestVPNFirewallList(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
outbound, inbound []string outbound, inbound []string
staged string
flags []string
want map[string][]any want map[string][]any
wantErr string
}{ }{
{ {
name: "empty", name: "empty",
@ -77,6 +83,64 @@ func TestVPNFirewallList(t *testing.T) {
}, },
}, },
}, },
{
name: "staged/nothing staged",
flags: []string{"--staged"},
wantErr: "no firewall configuration changes have been staged",
},
{
name: "staged/staged but no flag",
outbound: []string{
`{"port":"any","proto":"icmp","host":"any"}`,
},
staged: `{
"Inbound": [
{
"Port":"80",
"Proto":"tcp",
"Host":"some-host"
}
]
}`,
want: map[string][]any{
"outbound": {
map[string]any{
"index": 0,
"port": "any",
"proto": "icmp",
"host": "any",
},
},
"inbound": {},
},
},
{
name: "staged/staged with flag",
outbound: []string{
`{"port":"any","proto":"icmp","host":"any"}`,
},
staged: `{
"Inbound": [
{
"Port":"80",
"Proto":"tcp",
"Host":"some-host"
}
]
}`,
flags: []string{"--staged"},
want: map[string][]any{
"outbound": {},
"inbound": {
map[string]any{
"index": 0,
"port": "80",
"proto": "tcp",
"host": "some-host",
},
},
},
},
} }
for _, test := range tests { for _, test := range tests {
@ -89,6 +153,15 @@ func TestVPNFirewallList(t *testing.T) {
inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]" inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]"
) )
if test.staged != "" {
require.True(t, json.Valid([]byte(test.staged)))
require.NoError(t, os.WriteFile(
h.changeStager.path(vpnFirewallConfigChangeStagerName),
[]byte(test.staged),
0600,
))
}
require.NoError(t, json.Unmarshal( require.NoError(t, json.Unmarshal(
[]byte(outboundRawJSON), &config.VPN.Firewall.Outbound, []byte(outboundRawJSON), &config.VPN.Firewall.Outbound,
)) ))
@ -97,12 +170,23 @@ func TestVPNFirewallList(t *testing.T) {
[]byte(inboundRawJSON), &config.VPN.Firewall.Inbound, []byte(inboundRawJSON), &config.VPN.Firewall.Inbound,
)) ))
h.daemonRPC. if !slices.Contains(test.flags, "--staged") {
On("GetConfig", toolkit.MockArg[context.Context]()). h.daemonRPC.
Return(config, nil). On("GetConfig", toolkit.MockArg[context.Context]()).
Once() Return(config, nil).
Once()
}
h.runAssertStdout(t, test.want, "vpn", "firewall", "list") args := append([]string{"vpn", "firewall", "list"}, test.flags...)
if test.wantErr == "" {
h.runAssertStdout(t, test.want, args...)
} else {
err := h.run(t, args...)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), test.wantErr)
}
}
}) })
} }
} }

View File

@ -6,6 +6,7 @@ import (
"io/fs" "io/fs"
"os" "os"
"path/filepath" "path/filepath"
"testing"
) )
// Dir is a type which makes it possible to statically assert that a directory // Dir is a type which makes it possible to statically assert that a directory
@ -58,6 +59,11 @@ func MkDir(path string, mayExist bool) (Dir, error) {
return Dir{path}, nil return Dir{path}, nil
} }
// TempDir returns a Dir based on a temporary directory created via [t.TempDir].
func TempDir(t *testing.T) Dir {
return Dir{t.TempDir()}
}
// MkChildDir is a helper for joining Dir's path to the given name and calling // MkChildDir is a helper for joining Dir's path to the given name and calling
// MkDir with the result. // MkDir with the result.
func (d Dir) MkChildDir(name string, mayExist bool) (Dir, error) { func (d Dir) MkChildDir(name string, mayExist bool) (Dir, error) {