Give 'vpn firewall list' a --staged flag
This commit is contained in:
parent
b38d780bdf
commit
1608031103
42
go/cmd/entrypoint/change_stager.go
Normal file
42
go/cmd/entrypoint/change_stager.go
Normal file
@ -0,0 +1,42 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"isle/jsonutil"
|
||||
"isle/toolkit"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type changeStager struct {
|
||||
dir toolkit.Dir
|
||||
}
|
||||
|
||||
func (mgr *changeStager) path(name string) string {
|
||||
return filepath.Join(mgr.dir.Path, name+".json")
|
||||
}
|
||||
|
||||
func (mgr *changeStager) get(into any, name string) (bool, error) {
|
||||
path := mgr.path(name)
|
||||
if err := jsonutil.LoadFile(into, path); errors.Is(err, fs.ErrNotExist) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (mgr *changeStager) set(v any, name string) error {
|
||||
path := mgr.path(name)
|
||||
return jsonutil.WriteFile(v, path, 0600)
|
||||
}
|
||||
|
||||
func (mgr *changeStager) del(name string) error {
|
||||
path := mgr.path(name)
|
||||
if err := os.Remove(path); err != nil {
|
||||
return fmt.Errorf("removing file %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
@ -3,13 +3,16 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"isle/toolkit"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
|
||||
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||
"github.com/adrg/xdg"
|
||||
)
|
||||
|
||||
func getAppDirPath() string {
|
||||
@ -23,6 +26,19 @@ func getAppDirPath() string {
|
||||
var (
|
||||
envAppDirPath = getAppDirPath()
|
||||
envBinDirPath = filepath.Join(envAppDirPath, "bin")
|
||||
envCacheDir = sync.OnceValue(func() toolkit.Dir {
|
||||
cacheHome, err := toolkit.MkDir(xdg.CacheHome, true)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("creating cache directory %q: %w", xdg.CacheHome, err))
|
||||
}
|
||||
|
||||
cacheDir, err := cacheHome.MkChildDir("isle", true)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("creating isle cache directory: %w", err))
|
||||
}
|
||||
|
||||
return cacheDir
|
||||
})
|
||||
)
|
||||
|
||||
func binPath(name string) string {
|
||||
|
@ -18,9 +18,10 @@ import (
|
||||
type runHarness struct {
|
||||
ctx context.Context
|
||||
logger *mlog.Logger
|
||||
stdout *bytes.Buffer
|
||||
daemonRPC *daemon.MockRPC
|
||||
daemonRPCServer *httptest.Server
|
||||
stdout *bytes.Buffer
|
||||
changeStager *changeStager
|
||||
}
|
||||
|
||||
func newRunHarness(t *testing.T) *runHarness {
|
||||
@ -29,18 +30,22 @@ func newRunHarness(t *testing.T) *runHarness {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
logger = toolkit.NewTestLogger(t)
|
||||
stdout = new(bytes.Buffer)
|
||||
|
||||
daemonRPC = daemon.NewMockRPC(t)
|
||||
daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler(
|
||||
logger.WithNamespace("rpc"), daemonRPC,
|
||||
))
|
||||
daemonRPCServer = httptest.NewServer(daemonRPCHandler)
|
||||
|
||||
stdout = new(bytes.Buffer)
|
||||
changeStager = &changeStager{toolkit.TempDir(t)}
|
||||
)
|
||||
|
||||
t.Cleanup(daemonRPCServer.Close)
|
||||
|
||||
return &runHarness{ctx, logger, stdout, daemonRPC, daemonRPCServer}
|
||||
return &runHarness{
|
||||
ctx, logger, daemonRPC, daemonRPCServer, stdout, changeStager,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *runHarness) run(t *testing.T, args ...string) error {
|
||||
@ -57,6 +62,7 @@ func (h *runHarness) run(t *testing.T, args ...string) error {
|
||||
args: args,
|
||||
daemonRPC: daemonRPCClient,
|
||||
stdout: h.stdout,
|
||||
changeStager: h.changeStager,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -40,6 +40,7 @@ type subCmdCtxOpts struct {
|
||||
subCmdNames []string // names of subCmds so far, including this one
|
||||
daemonRPC daemon.RPC
|
||||
stdout io.Writer
|
||||
changeStager *changeStager
|
||||
}
|
||||
|
||||
func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
|
||||
@ -55,6 +56,10 @@ func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
|
||||
o.stdout = os.Stdout
|
||||
}
|
||||
|
||||
if o.changeStager == nil {
|
||||
o.changeStager = &changeStager{envCacheDir()}
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"isle/daemon/daecommon"
|
||||
)
|
||||
|
||||
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
|
||||
|
||||
type firewallRuleView struct {
|
||||
Index int `yaml:"index"`
|
||||
daecommon.ConfigFirewallRule `yaml:",inline"`
|
||||
@ -39,17 +42,35 @@ var subCmdVPNFirewallList = subCmd{
|
||||
name: "list",
|
||||
descr: "List all currently configured firewall rules",
|
||||
do: doWithOutput(func(ctx subCmdCtx) (any, error) {
|
||||
staged := ctx.flags.Bool(
|
||||
"staged",
|
||||
false,
|
||||
"Return the firewall configuration with staged changes included",
|
||||
)
|
||||
|
||||
ctx, err := ctx.withParsedFlags()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing flags: %w", err)
|
||||
}
|
||||
|
||||
var firewallConfig daecommon.ConfigFirewall
|
||||
if !*staged {
|
||||
config, err := ctx.getDaemonRPC().GetConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting network config: %w", err)
|
||||
}
|
||||
|
||||
return newFirewallView(config.VPN.Firewall), nil
|
||||
firewallConfig = config.VPN.Firewall
|
||||
|
||||
} else if ok, err := ctx.opts.changeStager.get(
|
||||
&firewallConfig, vpnFirewallConfigChangeStagerName,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("checking for staged changes: %w", err)
|
||||
} else if !ok {
|
||||
return nil, errors.New("no firewall configuration changes have been staged")
|
||||
}
|
||||
|
||||
return newFirewallView(firewallConfig), nil
|
||||
}),
|
||||
}
|
||||
|
||||
|
@ -5,10 +5,13 @@ import (
|
||||
"encoding/json"
|
||||
"isle/daemon/daecommon"
|
||||
"isle/toolkit"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func TestVPNFirewallList(t *testing.T) {
|
||||
@ -17,7 +20,10 @@ func TestVPNFirewallList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
outbound, inbound []string
|
||||
staged string
|
||||
flags []string
|
||||
want map[string][]any
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
@ -77,6 +83,64 @@ func TestVPNFirewallList(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "staged/nothing staged",
|
||||
flags: []string{"--staged"},
|
||||
wantErr: "no firewall configuration changes have been staged",
|
||||
},
|
||||
{
|
||||
name: "staged/staged but no flag",
|
||||
outbound: []string{
|
||||
`{"port":"any","proto":"icmp","host":"any"}`,
|
||||
},
|
||||
staged: `{
|
||||
"Inbound": [
|
||||
{
|
||||
"Port":"80",
|
||||
"Proto":"tcp",
|
||||
"Host":"some-host"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
want: map[string][]any{
|
||||
"outbound": {
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"port": "any",
|
||||
"proto": "icmp",
|
||||
"host": "any",
|
||||
},
|
||||
},
|
||||
"inbound": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "staged/staged with flag",
|
||||
outbound: []string{
|
||||
`{"port":"any","proto":"icmp","host":"any"}`,
|
||||
},
|
||||
staged: `{
|
||||
"Inbound": [
|
||||
{
|
||||
"Port":"80",
|
||||
"Proto":"tcp",
|
||||
"Host":"some-host"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
flags: []string{"--staged"},
|
||||
want: map[string][]any{
|
||||
"outbound": {},
|
||||
"inbound": {
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"port": "80",
|
||||
"proto": "tcp",
|
||||
"host": "some-host",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@ -89,6 +153,15 @@ func TestVPNFirewallList(t *testing.T) {
|
||||
inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]"
|
||||
)
|
||||
|
||||
if test.staged != "" {
|
||||
require.True(t, json.Valid([]byte(test.staged)))
|
||||
require.NoError(t, os.WriteFile(
|
||||
h.changeStager.path(vpnFirewallConfigChangeStagerName),
|
||||
[]byte(test.staged),
|
||||
0600,
|
||||
))
|
||||
}
|
||||
|
||||
require.NoError(t, json.Unmarshal(
|
||||
[]byte(outboundRawJSON), &config.VPN.Firewall.Outbound,
|
||||
))
|
||||
@ -97,12 +170,23 @@ func TestVPNFirewallList(t *testing.T) {
|
||||
[]byte(inboundRawJSON), &config.VPN.Firewall.Inbound,
|
||||
))
|
||||
|
||||
if !slices.Contains(test.flags, "--staged") {
|
||||
h.daemonRPC.
|
||||
On("GetConfig", toolkit.MockArg[context.Context]()).
|
||||
Return(config, nil).
|
||||
Once()
|
||||
}
|
||||
|
||||
h.runAssertStdout(t, test.want, "vpn", "firewall", "list")
|
||||
args := append([]string{"vpn", "firewall", "list"}, test.flags...)
|
||||
|
||||
if test.wantErr == "" {
|
||||
h.runAssertStdout(t, test.want, args...)
|
||||
} else {
|
||||
err := h.run(t, args...)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), test.wantErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Dir is a type which makes it possible to statically assert that a directory
|
||||
@ -58,6 +59,11 @@ func MkDir(path string, mayExist bool) (Dir, error) {
|
||||
return Dir{path}, nil
|
||||
}
|
||||
|
||||
// TempDir returns a Dir based on a temporary directory created via [t.TempDir].
|
||||
func TempDir(t *testing.T) Dir {
|
||||
return Dir{t.TempDir()}
|
||||
}
|
||||
|
||||
// MkChildDir is a helper for joining Dir's path to the given name and calling
|
||||
// MkDir with the result.
|
||||
func (d Dir) MkChildDir(name string, mayExist bool) (Dir, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user