Perform full config validation using stored network configs during init and SetConfig

This commit is contained in:
Brian Picciano 2024-12-14 15:57:07 +01:00
parent 5669123c99
commit 886f76fe0b
10 changed files with 281 additions and 50 deletions

View File

@ -1,11 +1,13 @@
package daemon package daemon
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"isle/bootstrap" "isle/bootstrap"
"isle/daemon/daecommon" "isle/daemon/daecommon"
"isle/daemon/network"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
@ -41,15 +43,16 @@ var HTTPSocketPath = sync.OnceValue(func() string {
}) })
func pickNetworkConfig( func pickNetworkConfig(
daemonConfig daecommon.Config, creationParams bootstrap.CreationParams, networkConfigs map[string]daecommon.NetworkConfig,
creationParams bootstrap.CreationParams,
) *daecommon.NetworkConfig { ) *daecommon.NetworkConfig {
if len(daemonConfig.Networks) == 1 { // DEPRECATED if len(networkConfigs) == 1 { // DEPRECATED
if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok { if c, ok := networkConfigs[daecommon.DeprecatedNetworkID]; ok {
return &c return &c
} }
} }
for searchStr, networkConfig := range daemonConfig.Networks { for searchStr, networkConfig := range networkConfigs {
if creationParams.Matches(searchStr) { if creationParams.Matches(searchStr) {
return &networkConfig return &networkConfig
} }
@ -58,6 +61,34 @@ func pickNetworkConfig(
return nil return nil
} }
func validateConfig(
ctx context.Context,
networkLoader network.Loader,
daemonConfig daecommon.Config,
loadableNetworks []bootstrap.CreationParams,
) error {
givenConfigs := daemonConfig.Networks
daemonConfig.Networks = map[string]daecommon.NetworkConfig{}
for _, creationParams := range loadableNetworks {
id := creationParams.ID
if c := pickNetworkConfig(givenConfigs, creationParams); c != nil {
daemonConfig.Networks[id] = *c
continue
}
c, err := networkLoader.StoredConfig(ctx, id)
if err != nil {
return fmt.Errorf("getting stored config for %q: %w", id, err)
}
daemonConfig.Networks[id] = c
}
return daemonConfig.Validate()
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Jigs // Jigs

View File

@ -8,6 +8,7 @@ import (
"isle/toolkit" "isle/toolkit"
"isle/yamlutil" "isle/yamlutil"
"net" "net"
"sort"
_ "embed" _ "embed"
@ -210,10 +211,13 @@ func (c Config) Validate() error {
"invalid vpn.public_addr %q: %w", network.VPN.PublicAddr, err, "invalid vpn.public_addr %q: %w", network.VPN.PublicAddr, err,
) )
} else if otherID, ok := nebulaPorts[port]; ok { } else if otherID, ok := nebulaPorts[port]; ok {
ids := []string{id, otherID}
sort.Strings(ids)
return fmt.Errorf( return fmt.Errorf(
"two networks with the same vpn.public_addr port: %q and %q", "two networks with the same vpn.public_addr port: %q and %q",
id, ids[0],
otherID, ids[1],
) )
} }

View File

@ -71,7 +71,7 @@ func TestConfig_UnmarshalYAML(t *testing.T) {
bar: {"vpn":{"public_addr":"1.1.1.1:4000"}} bar: {"vpn":{"public_addr":"1.1.1.1:4000"}}
baz: {"vpn":{"public_addr":"2.2.2.2:4000"}} baz: {"vpn":{"public_addr":"2.2.2.2:4000"}}
`, `,
wantErr: `two networks with the same vpn.public_addr port: "baz" and "bar"`, wantErr: `two networks with the same vpn.public_addr port: "bar" and "baz"`,
}, },
{ {
label: "err/invalid firewall", label: "err/invalid firewall",

View File

@ -19,6 +19,7 @@ import (
var _ RPC = (*Daemon)(nil) var _ RPC = (*Daemon)(nil)
type joinedNetwork struct { type joinedNetwork struct {
id string
network.Network network.Network
config *daecommon.NetworkConfig config *daecommon.NetworkConfig
} }
@ -69,10 +70,18 @@ func New(
return nil, fmt.Errorf("listing loadable networks: %w", err) return nil, fmt.Errorf("listing loadable networks: %w", err)
} }
if err := validateConfig(
ctx, networkLoader, daemonConfig, loadableNetworks,
); err != nil {
return nil, fmt.Errorf("validating daemon config: %w", err)
}
for _, creationParams := range loadableNetworks { for _, creationParams := range loadableNetworks {
var ( var (
id = creationParams.ID id = creationParams.ID
networkConfig = pickNetworkConfig(daemonConfig, creationParams) networkConfig = pickNetworkConfig(
daemonConfig.Networks, creationParams,
)
) )
n, err := networkLoader.Load( n, err := networkLoader.Load(
@ -87,7 +96,7 @@ func New(
return nil, fmt.Errorf("loading network %q: %w", id, err) return nil, fmt.Errorf("loading network %q: %w", id, err)
} }
d.networks[id] = joinedNetwork{n, networkConfig} d.networks[id] = joinedNetwork{id, n, networkConfig}
} }
return d, nil return d, nil
@ -115,7 +124,9 @@ func (d *Daemon) CreateNetwork(
var ( var (
creationParams = bootstrap.NewCreationParams(name, domain) creationParams = bootstrap.NewCreationParams(name, domain)
networkConfig = pickNetworkConfig(d.daemonConfig, creationParams) networkConfig = pickNetworkConfig(
d.daemonConfig.Networks, creationParams,
)
networkLogger = networkLogger(d.logger, creationParams) networkLogger = networkLogger(d.logger, creationParams)
) )
@ -141,7 +152,9 @@ func (d *Daemon) CreateNetwork(
} }
networkLogger.Info(ctx, "Network created successfully") networkLogger.Info(ctx, "Network created successfully")
d.networks[creationParams.ID] = joinedNetwork{n, networkConfig} d.networks[creationParams.ID] = joinedNetwork{
creationParams.ID, n, networkConfig,
}
return nil return nil
} }
@ -159,7 +172,9 @@ func (d *Daemon) JoinNetwork(
var ( var (
creationParams = newBootstrap.Bootstrap.NetworkCreationParams creationParams = newBootstrap.Bootstrap.NetworkCreationParams
networkID = creationParams.ID networkID = creationParams.ID
networkConfig = pickNetworkConfig(d.daemonConfig, creationParams) networkConfig = pickNetworkConfig(
d.daemonConfig.Networks, creationParams,
)
networkLogger = networkLogger( networkLogger = networkLogger(
d.logger, newBootstrap.Bootstrap.NetworkCreationParams, d.logger, newBootstrap.Bootstrap.NetworkCreationParams,
) )
@ -187,7 +202,7 @@ func (d *Daemon) JoinNetwork(
} }
networkLogger.Info(ctx, "Network joined successfully") networkLogger.Info(ctx, "Network joined successfully")
d.networks[networkID] = joinedNetwork{n, networkConfig} d.networks[networkID] = joinedNetwork{networkID, n, networkConfig}
return nil return nil
} }
@ -356,20 +371,46 @@ func (d *Daemon) GetConfig(
// SetConfig implements the method for the network.RPC interface. // SetConfig implements the method for the network.RPC interface.
func (d *Daemon) SetConfig( func (d *Daemon) SetConfig(
ctx context.Context, config daecommon.NetworkConfig, ctx context.Context, networkConfig daecommon.NetworkConfig,
) error { ) error {
_, err := withNetwork( d.l.RLock()
ctx, defer d.l.RUnlock()
d,
func(ctx context.Context, n joinedNetwork) (struct{}, error) { pickedNetwork, err := pickNetwork(ctx, d.networkLoader, d.networks)
if n.config != nil { if err != nil {
return struct{}{}, ErrManagedNetworkConfig return err
} }
return struct{}{}, n.SetConfig(ctx, config) if pickedNetwork.config != nil {
}, return ErrManagedNetworkConfig
) }
return err
// Reconstruct the daemon config using the actual in-use network configs,
// along with this new one, and do a validation before calling SetConfig.
daemonConfig := d.daemonConfig
daemonConfig.Networks = map[string]daecommon.NetworkConfig{
pickedNetwork.id: networkConfig,
}
for id, joinedNetwork := range d.networks {
if pickedNetwork.id == id {
continue
}
joinedNetworkConfig, err := joinedNetwork.GetConfig(ctx)
if err != nil {
return fmt.Errorf("getting network config of %q: %w", id, err)
}
daemonConfig.Networks[id] = joinedNetworkConfig
}
if err := daemonConfig.Validate(); err != nil {
return network.ErrInvalidConfig.WithData(err.Error())
}
return pickedNetwork.SetConfig(ctx, networkConfig)
} }
// Shutdown blocks until all resources held or created by the daemon, // Shutdown blocks until all resources held or created by the daemon,

View File

@ -22,6 +22,7 @@ type expectNetworkLoad struct {
type harnessOpts struct { type harnessOpts struct {
config daecommon.Config config daecommon.Config
expectNetworksLoaded []expectNetworkLoad expectNetworksLoaded []expectNetworkLoad
expectStoredConfigs map[string]daecommon.NetworkConfig
} }
func (o *harnessOpts) withDefaults() *harnessOpts { func (o *harnessOpts) withDefaults() *harnessOpts {
@ -53,6 +54,7 @@ func newHarness(t *testing.T, opts *harnessOpts) *harness {
) )
for i, expectNetworkLoaded := range opts.expectNetworksLoaded { for i, expectNetworkLoaded := range opts.expectNetworksLoaded {
expectLoadable[i] = expectNetworkLoaded.creationParams expectLoadable[i] = expectNetworkLoaded.creationParams
networkLoader. networkLoader.
On( On(
"Load", "Load",
@ -69,10 +71,16 @@ func newHarness(t *testing.T, opts *harnessOpts) *harness {
expectNetworkLoaded.network.On("Shutdown").Return(nil).Once() expectNetworkLoaded.network.On("Shutdown").Return(nil).Once()
} }
for id, networkConfig := range opts.expectStoredConfigs {
networkLoader.
On("StoredConfig", toolkit.MockArg[context.Context](), id).
Return(networkConfig, nil).
Once()
}
networkLoader. networkLoader.
On("Loadable", toolkit.MockArg[context.Context]()). On("Loadable", toolkit.MockArg[context.Context]()).
Return(expectLoadable, nil). Return(expectLoadable, nil)
Once()
daemon, err := New(ctx, logger, networkLoader, opts.config) daemon, err := New(ctx, logger, networkLoader, opts.config)
require.NoError(t, err) require.NoError(t, err)
@ -174,18 +182,70 @@ func TestNew(t *testing.T) {
network.NewMockNetwork(t), network.NewMockNetwork(t),
}, },
}, },
expectStoredConfigs: map[string]daecommon.NetworkConfig{
creationParamsD.ID: daecommon.NewNetworkConfig(nil),
},
}) })
}) })
t.Run("invalid config", func(t *testing.T) {
var (
ctx = context.Background()
logger = toolkit.NewTestLogger(t)
networkLoader = network.NewMockLoader(t)
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
creationParamsB = bootstrap.NewCreationParams("BBB", "b.com")
networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.VPN.PublicAddr = "1.1.1.1:5"
})
networkConfigB = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.VPN.PublicAddr = "2.2.2.2:5"
})
config = daecommon.Config{
Networks: map[string]daecommon.NetworkConfig{
creationParamsA.ID: networkConfigA,
},
}
)
networkLoader.
On("Loadable", toolkit.MockArg[context.Context]()).
Return(
[]bootstrap.CreationParams{creationParamsA, creationParamsB},
nil,
).
Once()
networkLoader.
On(
"StoredConfig",
toolkit.MockArg[context.Context](),
creationParamsB.ID,
).
Return(networkConfigB, nil).
Once()
_, err := New(ctx, logger, networkLoader, config)
assert.ErrorContains(t, err, "two networks with the same vpn.public_addr port")
})
} }
func TestDaemon_SetConfig(t *testing.T) { func TestDaemon_SetConfig(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
var ( var (
networkA = network.NewMockNetwork(t) networkA = network.NewMockNetwork(t)
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
h = newHarness(t, &harnessOpts{ h = newHarness(t, &harnessOpts{
expectNetworksLoaded: []expectNetworkLoad{{ expectNetworksLoaded: []expectNetworkLoad{{
bootstrap.NewCreationParams("AAA", "a.com"), nil, networkA, creationParamsA, nil, networkA,
}}, }},
expectStoredConfigs: map[string]daecommon.NetworkConfig{
creationParamsA.ID: daecommon.NewNetworkConfig(nil),
},
}) })
networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
@ -224,4 +284,38 @@ func TestDaemon_SetConfig(t *testing.T) {
err := h.daemon.SetConfig(h.ctx, networkConfig) err := h.daemon.SetConfig(h.ctx, networkConfig)
assert.ErrorIs(t, err, ErrManagedNetworkConfig) assert.ErrorIs(t, err, ErrManagedNetworkConfig)
}) })
t.Run("ErrInvalidConfig", func(t *testing.T) {
var (
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
networkA = network.NewMockNetwork(t)
networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.VPN.PublicAddr = "1.2.3.4:5"
})
creationParamsB = bootstrap.NewCreationParams("BBB", "b.com")
networkB = network.NewMockNetwork(t)
networkConfigB = daecommon.NewNetworkConfig(nil)
h = newHarness(t, &harnessOpts{
expectNetworksLoaded: []expectNetworkLoad{
{creationParamsA, nil, networkA},
{creationParamsB, nil, networkB},
},
expectStoredConfigs: map[string]daecommon.NetworkConfig{
creationParamsA.ID: networkConfigA,
creationParamsB.ID: networkConfigB,
},
})
)
networkA.
On("GetConfig", toolkit.MockArg[context.Context]()).
Return(networkConfigA, nil).
Once()
networkConfigB.VPN.PublicAddr = "1.1.1.1:5"
err := h.daemon.SetConfig(WithNetwork(h.ctx, "BBB"), networkConfigB)
assert.ErrorIs(t, err, network.ErrInvalidConfig)
})
} }

View File

@ -10,6 +10,27 @@ import (
"path/filepath" "path/filepath"
) )
func configPath(networkStateDir toolkit.Dir) string {
return filepath.Join(networkStateDir.Path, "config.json")
}
// loadConfig loads the stored NetworkConfig and returns it, or returns the
// default NetworkConfig if none is stored.
func loadConfig(networkStateDir toolkit.Dir) (daecommon.NetworkConfig, error) {
path := configPath(networkStateDir)
var config daecommon.NetworkConfig
if err := jsonutil.LoadFile(&config, path); errors.Is(err, fs.ErrNotExist) {
return daecommon.NewNetworkConfig(nil), nil
} else if err != nil {
return daecommon.NetworkConfig{}, fmt.Errorf(
"loading %q: %w", path, err,
)
}
return config, nil
}
// loadStoreConfig writes the given NetworkConfig to the networkStateDir if the // loadStoreConfig writes the given NetworkConfig to the networkStateDir if the
// config is non-nil. If the config is nil then a config is read from // config is non-nil. If the config is nil then a config is read from
// networkStateDir, returning the zero value if no config was previously // networkStateDir, returning the zero value if no config was previously
@ -19,19 +40,11 @@ func loadStoreConfig(
) ( ) (
daecommon.NetworkConfig, error, daecommon.NetworkConfig, error,
) { ) {
path := filepath.Join(networkStateDir.Path, "config.json")
if config == nil { if config == nil {
config = new(daecommon.NetworkConfig) return loadConfig(networkStateDir)
err := jsonutil.LoadFile(&config, path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return daecommon.NetworkConfig{}, fmt.Errorf(
"loading %q: %w", path, err,
)
}
return *config, nil
} }
path := configPath(networkStateDir)
if err := jsonutil.WriteFile(*config, path, 0600); err != nil { if err := jsonutil.WriteFile(*config, path, 0600); err != nil {
return daecommon.NetworkConfig{}, fmt.Errorf( return daecommon.NetworkConfig{}, fmt.Errorf(
"writing to %q: %w", path, err, "writing to %q: %w", path, err,

View File

@ -55,6 +55,14 @@ type Loader interface {
// Loadable returns the CreationParams for all Networks which can be Loaded. // Loadable returns the CreationParams for all Networks which can be Loaded.
Loadable(context.Context) ([]bootstrap.CreationParams, error) Loadable(context.Context) ([]bootstrap.CreationParams, error)
// StoredConfig returns the NetworkConfig currently stored for the network
// with the given ID, or the default NetworkConfig if none is stored.
StoredConfig(
ctx context.Context, networkID string,
) (
daecommon.NetworkConfig, error,
)
// Load initializes and returns a Network instance for a network which was // Load initializes and returns a Network instance for a network which was
// previously joined or created, and which has the given CreationParams. // previously joined or created, and which has the given CreationParams.
Load( Load(
@ -202,6 +210,21 @@ func (l *loader) Loadable(
return creationParams, nil return creationParams, nil
} }
func (l *loader) StoredConfig(
_ context.Context, networkID string,
) (
daecommon.NetworkConfig, error,
) {
networkStateDir, err := networkStateDir(l.networksStateDir, networkID, true)
if err != nil {
return daecommon.NetworkConfig{}, fmt.Errorf(
"getting network state dir: %w", err,
)
}
return loadConfig(networkStateDir)
}
func (l *loader) Load( func (l *loader) Load(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,

View File

@ -6,6 +6,8 @@ import (
context "context" context "context"
bootstrap "isle/bootstrap" bootstrap "isle/bootstrap"
daecommon "isle/daemon/daecommon"
mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog" mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
@ -138,6 +140,34 @@ func (_m *MockLoader) Loadable(_a0 context.Context) ([]bootstrap.CreationParams,
return r0, r1 return r0, r1
} }
// StoredConfig provides a mock function with given fields: ctx, networkID
func (_m *MockLoader) StoredConfig(ctx context.Context, networkID string) (daecommon.NetworkConfig, error) {
ret := _m.Called(ctx, networkID)
if len(ret) == 0 {
panic("no return value specified for StoredConfig")
}
var r0 daecommon.NetworkConfig
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (daecommon.NetworkConfig, error)); ok {
return rf(ctx, networkID)
}
if rf, ok := ret.Get(0).(func(context.Context, string) daecommon.NetworkConfig); ok {
r0 = rf(ctx, networkID)
} else {
r0 = ret.Get(0).(daecommon.NetworkConfig)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, networkID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewMockLoader creates a new instance of MockLoader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // NewMockLoader creates a new instance of MockLoader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value. // The first argument is typically a *testing.T value.
func NewMockLoader(t interface { func NewMockLoader(t interface {

View File

@ -108,6 +108,9 @@ type RPC interface {
// SetConfig overrides the current config with the given one, adjusting any // SetConfig overrides the current config with the given one, adjusting any
// running child processes as needed. // running child processes as needed.
//
// Errors:
// - ErrInvalidConfig
SetConfig(context.Context, daecommon.NetworkConfig) error SetConfig(context.Context, daecommon.NetworkConfig) error
} }

View File

@ -1,8 +0,0 @@
---
type: task
---
# Validate daemon Config in SetConfig
When the `Daemon.SetConfig` method is called, it should call Validate on the
full new `daemon.Config` (not just the NetworkConfig).