Perform full config validation using stored network configs during init and SetConfig
This commit is contained in:
parent
5669123c99
commit
886f76fe0b
@ -1,11 +1,13 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"isle/bootstrap"
|
||||
"isle/daemon/daecommon"
|
||||
"isle/daemon/network"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@ -41,15 +43,16 @@ var HTTPSocketPath = sync.OnceValue(func() string {
|
||||
})
|
||||
|
||||
func pickNetworkConfig(
|
||||
daemonConfig daecommon.Config, creationParams bootstrap.CreationParams,
|
||||
networkConfigs map[string]daecommon.NetworkConfig,
|
||||
creationParams bootstrap.CreationParams,
|
||||
) *daecommon.NetworkConfig {
|
||||
if len(daemonConfig.Networks) == 1 { // DEPRECATED
|
||||
if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok {
|
||||
if len(networkConfigs) == 1 { // DEPRECATED
|
||||
if c, ok := networkConfigs[daecommon.DeprecatedNetworkID]; ok {
|
||||
return &c
|
||||
}
|
||||
}
|
||||
|
||||
for searchStr, networkConfig := range daemonConfig.Networks {
|
||||
for searchStr, networkConfig := range networkConfigs {
|
||||
if creationParams.Matches(searchStr) {
|
||||
return &networkConfig
|
||||
}
|
||||
@ -58,6 +61,34 @@ func pickNetworkConfig(
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfig(
|
||||
ctx context.Context,
|
||||
networkLoader network.Loader,
|
||||
daemonConfig daecommon.Config,
|
||||
loadableNetworks []bootstrap.CreationParams,
|
||||
) error {
|
||||
givenConfigs := daemonConfig.Networks
|
||||
daemonConfig.Networks = map[string]daecommon.NetworkConfig{}
|
||||
|
||||
for _, creationParams := range loadableNetworks {
|
||||
id := creationParams.ID
|
||||
|
||||
if c := pickNetworkConfig(givenConfigs, creationParams); c != nil {
|
||||
daemonConfig.Networks[id] = *c
|
||||
continue
|
||||
}
|
||||
|
||||
c, err := networkLoader.StoredConfig(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting stored config for %q: %w", id, err)
|
||||
}
|
||||
|
||||
daemonConfig.Networks[id] = c
|
||||
}
|
||||
|
||||
return daemonConfig.Validate()
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Jigs
|
||||
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"isle/toolkit"
|
||||
"isle/yamlutil"
|
||||
"net"
|
||||
"sort"
|
||||
|
||||
_ "embed"
|
||||
|
||||
@ -210,10 +211,13 @@ func (c Config) Validate() error {
|
||||
"invalid vpn.public_addr %q: %w", network.VPN.PublicAddr, err,
|
||||
)
|
||||
} else if otherID, ok := nebulaPorts[port]; ok {
|
||||
ids := []string{id, otherID}
|
||||
sort.Strings(ids)
|
||||
|
||||
return fmt.Errorf(
|
||||
"two networks with the same vpn.public_addr port: %q and %q",
|
||||
id,
|
||||
otherID,
|
||||
ids[0],
|
||||
ids[1],
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,7 @@ func TestConfig_UnmarshalYAML(t *testing.T) {
|
||||
bar: {"vpn":{"public_addr":"1.1.1.1:4000"}}
|
||||
baz: {"vpn":{"public_addr":"2.2.2.2:4000"}}
|
||||
`,
|
||||
wantErr: `two networks with the same vpn.public_addr port: "baz" and "bar"`,
|
||||
wantErr: `two networks with the same vpn.public_addr port: "bar" and "baz"`,
|
||||
},
|
||||
{
|
||||
label: "err/invalid firewall",
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
var _ RPC = (*Daemon)(nil)
|
||||
|
||||
type joinedNetwork struct {
|
||||
id string
|
||||
network.Network
|
||||
config *daecommon.NetworkConfig
|
||||
}
|
||||
@ -69,10 +70,18 @@ func New(
|
||||
return nil, fmt.Errorf("listing loadable networks: %w", err)
|
||||
}
|
||||
|
||||
if err := validateConfig(
|
||||
ctx, networkLoader, daemonConfig, loadableNetworks,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("validating daemon config: %w", err)
|
||||
}
|
||||
|
||||
for _, creationParams := range loadableNetworks {
|
||||
var (
|
||||
id = creationParams.ID
|
||||
networkConfig = pickNetworkConfig(daemonConfig, creationParams)
|
||||
networkConfig = pickNetworkConfig(
|
||||
daemonConfig.Networks, creationParams,
|
||||
)
|
||||
)
|
||||
|
||||
n, err := networkLoader.Load(
|
||||
@ -87,7 +96,7 @@ func New(
|
||||
return nil, fmt.Errorf("loading network %q: %w", id, err)
|
||||
}
|
||||
|
||||
d.networks[id] = joinedNetwork{n, networkConfig}
|
||||
d.networks[id] = joinedNetwork{id, n, networkConfig}
|
||||
}
|
||||
|
||||
return d, nil
|
||||
@ -115,8 +124,10 @@ func (d *Daemon) CreateNetwork(
|
||||
|
||||
var (
|
||||
creationParams = bootstrap.NewCreationParams(name, domain)
|
||||
networkConfig = pickNetworkConfig(d.daemonConfig, creationParams)
|
||||
networkLogger = networkLogger(d.logger, creationParams)
|
||||
networkConfig = pickNetworkConfig(
|
||||
d.daemonConfig.Networks, creationParams,
|
||||
)
|
||||
networkLogger = networkLogger(d.logger, creationParams)
|
||||
)
|
||||
|
||||
if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil {
|
||||
@ -141,7 +152,9 @@ func (d *Daemon) CreateNetwork(
|
||||
}
|
||||
|
||||
networkLogger.Info(ctx, "Network created successfully")
|
||||
d.networks[creationParams.ID] = joinedNetwork{n, networkConfig}
|
||||
d.networks[creationParams.ID] = joinedNetwork{
|
||||
creationParams.ID, n, networkConfig,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -159,8 +172,10 @@ func (d *Daemon) JoinNetwork(
|
||||
var (
|
||||
creationParams = newBootstrap.Bootstrap.NetworkCreationParams
|
||||
networkID = creationParams.ID
|
||||
networkConfig = pickNetworkConfig(d.daemonConfig, creationParams)
|
||||
networkLogger = networkLogger(
|
||||
networkConfig = pickNetworkConfig(
|
||||
d.daemonConfig.Networks, creationParams,
|
||||
)
|
||||
networkLogger = networkLogger(
|
||||
d.logger, newBootstrap.Bootstrap.NetworkCreationParams,
|
||||
)
|
||||
)
|
||||
@ -187,7 +202,7 @@ func (d *Daemon) JoinNetwork(
|
||||
}
|
||||
|
||||
networkLogger.Info(ctx, "Network joined successfully")
|
||||
d.networks[networkID] = joinedNetwork{n, networkConfig}
|
||||
d.networks[networkID] = joinedNetwork{networkID, n, networkConfig}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -356,20 +371,46 @@ func (d *Daemon) GetConfig(
|
||||
|
||||
// SetConfig implements the method for the network.RPC interface.
|
||||
func (d *Daemon) SetConfig(
|
||||
ctx context.Context, config daecommon.NetworkConfig,
|
||||
ctx context.Context, networkConfig daecommon.NetworkConfig,
|
||||
) error {
|
||||
_, err := withNetwork(
|
||||
ctx,
|
||||
d,
|
||||
func(ctx context.Context, n joinedNetwork) (struct{}, error) {
|
||||
if n.config != nil {
|
||||
return struct{}{}, ErrManagedNetworkConfig
|
||||
}
|
||||
d.l.RLock()
|
||||
defer d.l.RUnlock()
|
||||
|
||||
return struct{}{}, n.SetConfig(ctx, config)
|
||||
},
|
||||
)
|
||||
return err
|
||||
pickedNetwork, err := pickNetwork(ctx, d.networkLoader, d.networks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pickedNetwork.config != nil {
|
||||
return ErrManagedNetworkConfig
|
||||
}
|
||||
|
||||
// Reconstruct the daemon config using the actual in-use network configs,
|
||||
// along with this new one, and do a validation before calling SetConfig.
|
||||
|
||||
daemonConfig := d.daemonConfig
|
||||
daemonConfig.Networks = map[string]daecommon.NetworkConfig{
|
||||
pickedNetwork.id: networkConfig,
|
||||
}
|
||||
|
||||
for id, joinedNetwork := range d.networks {
|
||||
if pickedNetwork.id == id {
|
||||
continue
|
||||
}
|
||||
|
||||
joinedNetworkConfig, err := joinedNetwork.GetConfig(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting network config of %q: %w", id, err)
|
||||
}
|
||||
|
||||
daemonConfig.Networks[id] = joinedNetworkConfig
|
||||
}
|
||||
|
||||
if err := daemonConfig.Validate(); err != nil {
|
||||
return network.ErrInvalidConfig.WithData(err.Error())
|
||||
}
|
||||
|
||||
return pickedNetwork.SetConfig(ctx, networkConfig)
|
||||
}
|
||||
|
||||
// Shutdown blocks until all resources held or created by the daemon,
|
||||
|
@ -22,6 +22,7 @@ type expectNetworkLoad struct {
|
||||
type harnessOpts struct {
|
||||
config daecommon.Config
|
||||
expectNetworksLoaded []expectNetworkLoad
|
||||
expectStoredConfigs map[string]daecommon.NetworkConfig
|
||||
}
|
||||
|
||||
func (o *harnessOpts) withDefaults() *harnessOpts {
|
||||
@ -53,6 +54,7 @@ func newHarness(t *testing.T, opts *harnessOpts) *harness {
|
||||
)
|
||||
for i, expectNetworkLoaded := range opts.expectNetworksLoaded {
|
||||
expectLoadable[i] = expectNetworkLoaded.creationParams
|
||||
|
||||
networkLoader.
|
||||
On(
|
||||
"Load",
|
||||
@ -69,10 +71,16 @@ func newHarness(t *testing.T, opts *harnessOpts) *harness {
|
||||
expectNetworkLoaded.network.On("Shutdown").Return(nil).Once()
|
||||
}
|
||||
|
||||
for id, networkConfig := range opts.expectStoredConfigs {
|
||||
networkLoader.
|
||||
On("StoredConfig", toolkit.MockArg[context.Context](), id).
|
||||
Return(networkConfig, nil).
|
||||
Once()
|
||||
}
|
||||
|
||||
networkLoader.
|
||||
On("Loadable", toolkit.MockArg[context.Context]()).
|
||||
Return(expectLoadable, nil).
|
||||
Once()
|
||||
Return(expectLoadable, nil)
|
||||
|
||||
daemon, err := New(ctx, logger, networkLoader, opts.config)
|
||||
require.NoError(t, err)
|
||||
@ -174,18 +182,70 @@ func TestNew(t *testing.T) {
|
||||
network.NewMockNetwork(t),
|
||||
},
|
||||
},
|
||||
expectStoredConfigs: map[string]daecommon.NetworkConfig{
|
||||
creationParamsD.ID: daecommon.NewNetworkConfig(nil),
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("invalid config", func(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
logger = toolkit.NewTestLogger(t)
|
||||
networkLoader = network.NewMockLoader(t)
|
||||
|
||||
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
|
||||
creationParamsB = bootstrap.NewCreationParams("BBB", "b.com")
|
||||
|
||||
networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||
c.VPN.PublicAddr = "1.1.1.1:5"
|
||||
})
|
||||
|
||||
networkConfigB = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||
c.VPN.PublicAddr = "2.2.2.2:5"
|
||||
})
|
||||
|
||||
config = daecommon.Config{
|
||||
Networks: map[string]daecommon.NetworkConfig{
|
||||
creationParamsA.ID: networkConfigA,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
networkLoader.
|
||||
On("Loadable", toolkit.MockArg[context.Context]()).
|
||||
Return(
|
||||
[]bootstrap.CreationParams{creationParamsA, creationParamsB},
|
||||
nil,
|
||||
).
|
||||
Once()
|
||||
|
||||
networkLoader.
|
||||
On(
|
||||
"StoredConfig",
|
||||
toolkit.MockArg[context.Context](),
|
||||
creationParamsB.ID,
|
||||
).
|
||||
Return(networkConfigB, nil).
|
||||
Once()
|
||||
|
||||
_, err := New(ctx, logger, networkLoader, config)
|
||||
assert.ErrorContains(t, err, "two networks with the same vpn.public_addr port")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaemon_SetConfig(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
var (
|
||||
networkA = network.NewMockNetwork(t)
|
||||
h = newHarness(t, &harnessOpts{
|
||||
networkA = network.NewMockNetwork(t)
|
||||
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
|
||||
h = newHarness(t, &harnessOpts{
|
||||
expectNetworksLoaded: []expectNetworkLoad{{
|
||||
bootstrap.NewCreationParams("AAA", "a.com"), nil, networkA,
|
||||
creationParamsA, nil, networkA,
|
||||
}},
|
||||
expectStoredConfigs: map[string]daecommon.NetworkConfig{
|
||||
creationParamsA.ID: daecommon.NewNetworkConfig(nil),
|
||||
},
|
||||
})
|
||||
|
||||
networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||
@ -224,4 +284,38 @@ func TestDaemon_SetConfig(t *testing.T) {
|
||||
err := h.daemon.SetConfig(h.ctx, networkConfig)
|
||||
assert.ErrorIs(t, err, ErrManagedNetworkConfig)
|
||||
})
|
||||
|
||||
t.Run("ErrInvalidConfig", func(t *testing.T) {
|
||||
var (
|
||||
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
|
||||
networkA = network.NewMockNetwork(t)
|
||||
networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||
c.VPN.PublicAddr = "1.2.3.4:5"
|
||||
})
|
||||
|
||||
creationParamsB = bootstrap.NewCreationParams("BBB", "b.com")
|
||||
networkB = network.NewMockNetwork(t)
|
||||
networkConfigB = daecommon.NewNetworkConfig(nil)
|
||||
|
||||
h = newHarness(t, &harnessOpts{
|
||||
expectNetworksLoaded: []expectNetworkLoad{
|
||||
{creationParamsA, nil, networkA},
|
||||
{creationParamsB, nil, networkB},
|
||||
},
|
||||
expectStoredConfigs: map[string]daecommon.NetworkConfig{
|
||||
creationParamsA.ID: networkConfigA,
|
||||
creationParamsB.ID: networkConfigB,
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
networkA.
|
||||
On("GetConfig", toolkit.MockArg[context.Context]()).
|
||||
Return(networkConfigA, nil).
|
||||
Once()
|
||||
|
||||
networkConfigB.VPN.PublicAddr = "1.1.1.1:5"
|
||||
err := h.daemon.SetConfig(WithNetwork(h.ctx, "BBB"), networkConfigB)
|
||||
assert.ErrorIs(t, err, network.ErrInvalidConfig)
|
||||
})
|
||||
}
|
||||
|
@ -10,6 +10,27 @@ import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func configPath(networkStateDir toolkit.Dir) string {
|
||||
return filepath.Join(networkStateDir.Path, "config.json")
|
||||
}
|
||||
|
||||
// loadConfig loads the stored NetworkConfig and returns it, or returns the
|
||||
// default NetworkConfig if none is stored.
|
||||
func loadConfig(networkStateDir toolkit.Dir) (daecommon.NetworkConfig, error) {
|
||||
path := configPath(networkStateDir)
|
||||
|
||||
var config daecommon.NetworkConfig
|
||||
if err := jsonutil.LoadFile(&config, path); errors.Is(err, fs.ErrNotExist) {
|
||||
return daecommon.NewNetworkConfig(nil), nil
|
||||
} else if err != nil {
|
||||
return daecommon.NetworkConfig{}, fmt.Errorf(
|
||||
"loading %q: %w", path, err,
|
||||
)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// loadStoreConfig writes the given NetworkConfig to the networkStateDir if the
|
||||
// config is non-nil. If the config is nil then a config is read from
|
||||
// networkStateDir, returning the zero value if no config was previously
|
||||
@ -19,19 +40,11 @@ func loadStoreConfig(
|
||||
) (
|
||||
daecommon.NetworkConfig, error,
|
||||
) {
|
||||
path := filepath.Join(networkStateDir.Path, "config.json")
|
||||
|
||||
if config == nil {
|
||||
config = new(daecommon.NetworkConfig)
|
||||
err := jsonutil.LoadFile(&config, path)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return daecommon.NetworkConfig{}, fmt.Errorf(
|
||||
"loading %q: %w", path, err,
|
||||
)
|
||||
}
|
||||
return *config, nil
|
||||
return loadConfig(networkStateDir)
|
||||
}
|
||||
|
||||
path := configPath(networkStateDir)
|
||||
if err := jsonutil.WriteFile(*config, path, 0600); err != nil {
|
||||
return daecommon.NetworkConfig{}, fmt.Errorf(
|
||||
"writing to %q: %w", path, err,
|
||||
|
@ -55,6 +55,14 @@ type Loader interface {
|
||||
// Loadable returns the CreationParams for all Networks which can be Loaded.
|
||||
Loadable(context.Context) ([]bootstrap.CreationParams, error)
|
||||
|
||||
// StoredConfig returns the NetworkConfig currently stored for the network
|
||||
// with the given ID, or the default NetworkConfig if none is stored.
|
||||
StoredConfig(
|
||||
ctx context.Context, networkID string,
|
||||
) (
|
||||
daecommon.NetworkConfig, error,
|
||||
)
|
||||
|
||||
// Load initializes and returns a Network instance for a network which was
|
||||
// previously joined or created, and which has the given CreationParams.
|
||||
Load(
|
||||
@ -202,6 +210,21 @@ func (l *loader) Loadable(
|
||||
return creationParams, nil
|
||||
}
|
||||
|
||||
func (l *loader) StoredConfig(
|
||||
_ context.Context, networkID string,
|
||||
) (
|
||||
daecommon.NetworkConfig, error,
|
||||
) {
|
||||
networkStateDir, err := networkStateDir(l.networksStateDir, networkID, true)
|
||||
if err != nil {
|
||||
return daecommon.NetworkConfig{}, fmt.Errorf(
|
||||
"getting network state dir: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
return loadConfig(networkStateDir)
|
||||
}
|
||||
|
||||
func (l *loader) Load(
|
||||
ctx context.Context,
|
||||
logger *mlog.Logger,
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
context "context"
|
||||
bootstrap "isle/bootstrap"
|
||||
|
||||
daecommon "isle/daemon/daecommon"
|
||||
|
||||
mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
@ -138,6 +140,34 @@ func (_m *MockLoader) Loadable(_a0 context.Context) ([]bootstrap.CreationParams,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// StoredConfig provides a mock function with given fields: ctx, networkID
|
||||
func (_m *MockLoader) StoredConfig(ctx context.Context, networkID string) (daecommon.NetworkConfig, error) {
|
||||
ret := _m.Called(ctx, networkID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for StoredConfig")
|
||||
}
|
||||
|
||||
var r0 daecommon.NetworkConfig
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (daecommon.NetworkConfig, error)); ok {
|
||||
return rf(ctx, networkID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) daecommon.NetworkConfig); ok {
|
||||
r0 = rf(ctx, networkID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(daecommon.NetworkConfig)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, networkID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewMockLoader creates a new instance of MockLoader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockLoader(t interface {
|
||||
|
@ -108,6 +108,9 @@ type RPC interface {
|
||||
|
||||
// SetConfig overrides the current config with the given one, adjusting any
|
||||
// running child processes as needed.
|
||||
//
|
||||
// Errors:
|
||||
// - ErrInvalidConfig
|
||||
SetConfig(context.Context, daecommon.NetworkConfig) error
|
||||
}
|
||||
|
||||
|
@ -1,8 +0,0 @@
|
||||
---
|
||||
type: task
|
||||
---
|
||||
|
||||
# Validate daemon Config in SetConfig
|
||||
|
||||
When the `Daemon.SetConfig` method is called, it should call Validate on the
|
||||
full new `daemon.Config` (not just the NetworkConfig).
|
Loading…
Reference in New Issue
Block a user