diff --git a/go/cmd/entrypoint/sub_cmd.go b/go/cmd/entrypoint/sub_cmd.go index a02f805..a479b5b 100644 --- a/go/cmd/entrypoint/sub_cmd.go +++ b/go/cmd/entrypoint/sub_cmd.go @@ -52,10 +52,6 @@ func (o *subCmdCtxOpts) withDefaults() *subCmdCtxOpts { o.stdout = os.Stdout } - if o.changeStager == nil { - o.changeStager = &changeStager{envCacheDir()} - } - return o } @@ -89,6 +85,14 @@ func newSubCmdCtx( } } +func (ctx subCmdCtx) getChangeStager() *changeStager { + if ctx.opts.changeStager != nil { + return ctx.opts.changeStager + } + + return &changeStager{envCacheDir()} +} + func usagePrefix(subCmdNames []string) string { subCmdNamesStr := strings.Join(subCmdNames, " ") if subCmdNamesStr != "" { diff --git a/go/cmd/entrypoint/vpn_firewall.go b/go/cmd/entrypoint/vpn_firewall.go index f4963b4..5250f1f 100644 --- a/go/cmd/entrypoint/vpn_firewall.go +++ b/go/cmd/entrypoint/vpn_firewall.go @@ -21,7 +21,7 @@ func vpnFirewallGetConfig(ctx subCmdCtx) (daecommon.NetworkConfig, error) { } var firewallConfig daecommon.ConfigFirewall - if ok, err := ctx.opts.changeStager.get( + if ok, err := ctx.getChangeStager().get( &firewallConfig, vpnFirewallConfigChangeStagerName, ); err != nil { return daecommon.NetworkConfig{}, fmt.Errorf( @@ -122,7 +122,7 @@ var subCmdVPNFirewallAdd = subCmd{ return err } - if err := ctx.opts.changeStager.set( + if err := ctx.getChangeStager().set( config.VPN.Firewall, vpnFirewallConfigChangeStagerName, ); err != nil { return fmt.Errorf("staging changes: %w", err) @@ -142,7 +142,7 @@ var subCmdVPNFirewallCommit = subCmd{ } var firewallConfig daecommon.ConfigFirewall - ok, err := ctx.opts.changeStager.get( + ok, err := ctx.getChangeStager().get( &firewallConfig, vpnFirewallConfigChangeStagerName, ) if err != nil { @@ -232,7 +232,7 @@ var subCmdVPNFirewallRemove = subCmd{ return err } - if err := ctx.opts.changeStager.set( + if err := ctx.getChangeStager().set( config.VPN.Firewall, vpnFirewallConfigChangeStagerName, ); err != nil { return fmt.Errorf("staging changes: %w", err) @@ -246,7 +246,7 @@ var subCmdVPNFirewallReset = subCmd{ name: "reset", descr: "Discard all changes which have been staged", do: func(ctx subCmdCtx) error { - return ctx.opts.changeStager.del(vpnFirewallConfigChangeStagerName) + return ctx.getChangeStager().del(vpnFirewallConfigChangeStagerName) }, } @@ -301,7 +301,7 @@ var subCmdVPNFirewallShow = subCmd{ ) if *staged { var err error - if foundStaged, err = ctx.opts.changeStager.get( + if foundStaged, err = ctx.getChangeStager().get( &firewallConfig, vpnFirewallConfigChangeStagerName, ); err != nil { return nil, fmt.Errorf("checking for staged changes: %w", err)