Give 'vpn firewall list' a --staged flag
This commit is contained in:
parent
b38d780bdf
commit
1608031103
42
go/cmd/entrypoint/change_stager.go
Normal file
42
go/cmd/entrypoint/change_stager.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"isle/jsonutil"
|
||||||
|
"isle/toolkit"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
type changeStager struct {
|
||||||
|
dir toolkit.Dir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *changeStager) path(name string) string {
|
||||||
|
return filepath.Join(mgr.dir.Path, name+".json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *changeStager) get(into any, name string) (bool, error) {
|
||||||
|
path := mgr.path(name)
|
||||||
|
if err := jsonutil.LoadFile(into, path); errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return false, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *changeStager) set(v any, name string) error {
|
||||||
|
path := mgr.path(name)
|
||||||
|
return jsonutil.WriteFile(v, path, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *changeStager) del(name string) error {
|
||||||
|
path := mgr.path(name)
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
return fmt.Errorf("removing file %q: %w", path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -3,13 +3,16 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"isle/toolkit"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
|
"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
|
||||||
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||||
|
"github.com/adrg/xdg"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getAppDirPath() string {
|
func getAppDirPath() string {
|
||||||
@ -23,6 +26,19 @@ func getAppDirPath() string {
|
|||||||
var (
|
var (
|
||||||
envAppDirPath = getAppDirPath()
|
envAppDirPath = getAppDirPath()
|
||||||
envBinDirPath = filepath.Join(envAppDirPath, "bin")
|
envBinDirPath = filepath.Join(envAppDirPath, "bin")
|
||||||
|
envCacheDir = sync.OnceValue(func() toolkit.Dir {
|
||||||
|
cacheHome, err := toolkit.MkDir(xdg.CacheHome, true)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("creating cache directory %q: %w", xdg.CacheHome, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheDir, err := cacheHome.MkChildDir("isle", true)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("creating isle cache directory: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return cacheDir
|
||||||
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
func binPath(name string) string {
|
func binPath(name string) string {
|
||||||
|
@ -18,9 +18,10 @@ import (
|
|||||||
type runHarness struct {
|
type runHarness struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
logger *mlog.Logger
|
logger *mlog.Logger
|
||||||
stdout *bytes.Buffer
|
|
||||||
daemonRPC *daemon.MockRPC
|
daemonRPC *daemon.MockRPC
|
||||||
daemonRPCServer *httptest.Server
|
daemonRPCServer *httptest.Server
|
||||||
|
stdout *bytes.Buffer
|
||||||
|
changeStager *changeStager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRunHarness(t *testing.T) *runHarness {
|
func newRunHarness(t *testing.T) *runHarness {
|
||||||
@ -29,18 +30,22 @@ func newRunHarness(t *testing.T) *runHarness {
|
|||||||
var (
|
var (
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
logger = toolkit.NewTestLogger(t)
|
logger = toolkit.NewTestLogger(t)
|
||||||
stdout = new(bytes.Buffer)
|
|
||||||
|
|
||||||
daemonRPC = daemon.NewMockRPC(t)
|
daemonRPC = daemon.NewMockRPC(t)
|
||||||
daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler(
|
daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler(
|
||||||
logger.WithNamespace("rpc"), daemonRPC,
|
logger.WithNamespace("rpc"), daemonRPC,
|
||||||
))
|
))
|
||||||
daemonRPCServer = httptest.NewServer(daemonRPCHandler)
|
daemonRPCServer = httptest.NewServer(daemonRPCHandler)
|
||||||
|
|
||||||
|
stdout = new(bytes.Buffer)
|
||||||
|
changeStager = &changeStager{toolkit.TempDir(t)}
|
||||||
)
|
)
|
||||||
|
|
||||||
t.Cleanup(daemonRPCServer.Close)
|
t.Cleanup(daemonRPCServer.Close)
|
||||||
|
|
||||||
return &runHarness{ctx, logger, stdout, daemonRPC, daemonRPCServer}
|
return &runHarness{
|
||||||
|
ctx, logger, daemonRPC, daemonRPCServer, stdout, changeStager,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *runHarness) run(t *testing.T, args ...string) error {
|
func (h *runHarness) run(t *testing.T, args ...string) error {
|
||||||
@ -57,6 +62,7 @@ func (h *runHarness) run(t *testing.T, args ...string) error {
|
|||||||
args: args,
|
args: args,
|
||||||
daemonRPC: daemonRPCClient,
|
daemonRPC: daemonRPCClient,
|
||||||
stdout: h.stdout,
|
stdout: h.stdout,
|
||||||
|
changeStager: h.changeStager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ type subCmdCtxOpts struct {
|
|||||||
subCmdNames []string // names of subCmds so far, including this one
|
subCmdNames []string // names of subCmds so far, including this one
|
||||||
daemonRPC daemon.RPC
|
daemonRPC daemon.RPC
|
||||||
stdout io.Writer
|
stdout io.Writer
|
||||||
|
changeStager *changeStager
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
|
func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
|
||||||
@ -55,6 +56,10 @@ func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts {
|
|||||||
o.stdout = os.Stdout
|
o.stdout = os.Stdout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if o.changeStager == nil {
|
||||||
|
o.changeStager = &changeStager{envCacheDir()}
|
||||||
|
}
|
||||||
|
|
||||||
return o
|
return o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"isle/daemon/daecommon"
|
"isle/daemon/daecommon"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
|
||||||
|
|
||||||
type firewallRuleView struct {
|
type firewallRuleView struct {
|
||||||
Index int `yaml:"index"`
|
Index int `yaml:"index"`
|
||||||
daecommon.ConfigFirewallRule `yaml:",inline"`
|
daecommon.ConfigFirewallRule `yaml:",inline"`
|
||||||
@ -39,17 +42,35 @@ var subCmdVPNFirewallList = subCmd{
|
|||||||
name: "list",
|
name: "list",
|
||||||
descr: "List all currently configured firewall rules",
|
descr: "List all currently configured firewall rules",
|
||||||
do: doWithOutput(func(ctx subCmdCtx) (any, error) {
|
do: doWithOutput(func(ctx subCmdCtx) (any, error) {
|
||||||
|
staged := ctx.flags.Bool(
|
||||||
|
"staged",
|
||||||
|
false,
|
||||||
|
"Return the firewall configuration with staged changes included",
|
||||||
|
)
|
||||||
|
|
||||||
ctx, err := ctx.withParsedFlags()
|
ctx, err := ctx.withParsedFlags()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing flags: %w", err)
|
return nil, fmt.Errorf("parsing flags: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var firewallConfig daecommon.ConfigFirewall
|
||||||
|
if !*staged {
|
||||||
config, err := ctx.getDaemonRPC().GetConfig(ctx)
|
config, err := ctx.getDaemonRPC().GetConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting network config: %w", err)
|
return nil, fmt.Errorf("getting network config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newFirewallView(config.VPN.Firewall), nil
|
firewallConfig = config.VPN.Firewall
|
||||||
|
|
||||||
|
} else if ok, err := ctx.opts.changeStager.get(
|
||||||
|
&firewallConfig, vpnFirewallConfigChangeStagerName,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("checking for staged changes: %w", err)
|
||||||
|
} else if !ok {
|
||||||
|
return nil, errors.New("no firewall configuration changes have been staged")
|
||||||
|
}
|
||||||
|
|
||||||
|
return newFirewallView(firewallConfig), nil
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,10 +5,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"isle/daemon/daecommon"
|
"isle/daemon/daecommon"
|
||||||
"isle/toolkit"
|
"isle/toolkit"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVPNFirewallList(t *testing.T) {
|
func TestVPNFirewallList(t *testing.T) {
|
||||||
@ -17,7 +20,10 @@ func TestVPNFirewallList(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
outbound, inbound []string
|
outbound, inbound []string
|
||||||
|
staged string
|
||||||
|
flags []string
|
||||||
want map[string][]any
|
want map[string][]any
|
||||||
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "empty",
|
name: "empty",
|
||||||
@ -77,6 +83,64 @@ func TestVPNFirewallList(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "staged/nothing staged",
|
||||||
|
flags: []string{"--staged"},
|
||||||
|
wantErr: "no firewall configuration changes have been staged",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "staged/staged but no flag",
|
||||||
|
outbound: []string{
|
||||||
|
`{"port":"any","proto":"icmp","host":"any"}`,
|
||||||
|
},
|
||||||
|
staged: `{
|
||||||
|
"Inbound": [
|
||||||
|
{
|
||||||
|
"Port":"80",
|
||||||
|
"Proto":"tcp",
|
||||||
|
"Host":"some-host"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
want: map[string][]any{
|
||||||
|
"outbound": {
|
||||||
|
map[string]any{
|
||||||
|
"index": 0,
|
||||||
|
"port": "any",
|
||||||
|
"proto": "icmp",
|
||||||
|
"host": "any",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"inbound": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "staged/staged with flag",
|
||||||
|
outbound: []string{
|
||||||
|
`{"port":"any","proto":"icmp","host":"any"}`,
|
||||||
|
},
|
||||||
|
staged: `{
|
||||||
|
"Inbound": [
|
||||||
|
{
|
||||||
|
"Port":"80",
|
||||||
|
"Proto":"tcp",
|
||||||
|
"Host":"some-host"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
flags: []string{"--staged"},
|
||||||
|
want: map[string][]any{
|
||||||
|
"outbound": {},
|
||||||
|
"inbound": {
|
||||||
|
map[string]any{
|
||||||
|
"index": 0,
|
||||||
|
"port": "80",
|
||||||
|
"proto": "tcp",
|
||||||
|
"host": "some-host",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@ -89,6 +153,15 @@ func TestVPNFirewallList(t *testing.T) {
|
|||||||
inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]"
|
inboundRawJSON = "[" + strings.Join(test.inbound, ",") + "]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if test.staged != "" {
|
||||||
|
require.True(t, json.Valid([]byte(test.staged)))
|
||||||
|
require.NoError(t, os.WriteFile(
|
||||||
|
h.changeStager.path(vpnFirewallConfigChangeStagerName),
|
||||||
|
[]byte(test.staged),
|
||||||
|
0600,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
require.NoError(t, json.Unmarshal(
|
require.NoError(t, json.Unmarshal(
|
||||||
[]byte(outboundRawJSON), &config.VPN.Firewall.Outbound,
|
[]byte(outboundRawJSON), &config.VPN.Firewall.Outbound,
|
||||||
))
|
))
|
||||||
@ -97,12 +170,23 @@ func TestVPNFirewallList(t *testing.T) {
|
|||||||
[]byte(inboundRawJSON), &config.VPN.Firewall.Inbound,
|
[]byte(inboundRawJSON), &config.VPN.Firewall.Inbound,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
if !slices.Contains(test.flags, "--staged") {
|
||||||
h.daemonRPC.
|
h.daemonRPC.
|
||||||
On("GetConfig", toolkit.MockArg[context.Context]()).
|
On("GetConfig", toolkit.MockArg[context.Context]()).
|
||||||
Return(config, nil).
|
Return(config, nil).
|
||||||
Once()
|
Once()
|
||||||
|
}
|
||||||
|
|
||||||
h.runAssertStdout(t, test.want, "vpn", "firewall", "list")
|
args := append([]string{"vpn", "firewall", "list"}, test.flags...)
|
||||||
|
|
||||||
|
if test.wantErr == "" {
|
||||||
|
h.runAssertStdout(t, test.want, args...)
|
||||||
|
} else {
|
||||||
|
err := h.run(t, args...)
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), test.wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dir is a type which makes it possible to statically assert that a directory
|
// Dir is a type which makes it possible to statically assert that a directory
|
||||||
@ -58,6 +59,11 @@ func MkDir(path string, mayExist bool) (Dir, error) {
|
|||||||
return Dir{path}, nil
|
return Dir{path}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TempDir returns a Dir based on a temporary directory created via [t.TempDir].
|
||||||
|
func TempDir(t *testing.T) Dir {
|
||||||
|
return Dir{t.TempDir()}
|
||||||
|
}
|
||||||
|
|
||||||
// MkChildDir is a helper for joining Dir's path to the given name and calling
|
// MkChildDir is a helper for joining Dir's path to the given name and calling
|
||||||
// MkDir with the result.
|
// MkDir with the result.
|
||||||
func (d Dir) MkChildDir(name string, mayExist bool) (Dir, error) {
|
func (d Dir) MkChildDir(name string, mayExist bool) (Dir, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user