diff --git a/go/daemon/config.go b/go/daemon/config.go index ee17058..8097f39 100644 --- a/go/daemon/config.go +++ b/go/daemon/config.go @@ -1,11 +1,13 @@ package daemon import ( + "context" "errors" "fmt" "io/fs" "isle/bootstrap" "isle/daemon/daecommon" + "isle/daemon/network" "os" "path/filepath" "slices" @@ -41,15 +43,16 @@ var HTTPSocketPath = sync.OnceValue(func() string { }) func pickNetworkConfig( - daemonConfig daecommon.Config, creationParams bootstrap.CreationParams, + networkConfigs map[string]daecommon.NetworkConfig, + creationParams bootstrap.CreationParams, ) *daecommon.NetworkConfig { - if len(daemonConfig.Networks) == 1 { // DEPRECATED - if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok { + if len(networkConfigs) == 1 { // DEPRECATED + if c, ok := networkConfigs[daecommon.DeprecatedNetworkID]; ok { return &c } } - for searchStr, networkConfig := range daemonConfig.Networks { + for searchStr, networkConfig := range networkConfigs { if creationParams.Matches(searchStr) { return &networkConfig } @@ -58,6 +61,34 @@ func pickNetworkConfig( return nil } +func validateConfig( + ctx context.Context, + networkLoader network.Loader, + daemonConfig daecommon.Config, + loadableNetworks []bootstrap.CreationParams, +) error { + givenConfigs := daemonConfig.Networks + daemonConfig.Networks = map[string]daecommon.NetworkConfig{} + + for _, creationParams := range loadableNetworks { + id := creationParams.ID + + if c := pickNetworkConfig(givenConfigs, creationParams); c != nil { + daemonConfig.Networks[id] = *c + continue + } + + c, err := networkLoader.StoredConfig(ctx, id) + if err != nil { + return fmt.Errorf("getting stored config for %q: %w", id, err) + } + + daemonConfig.Networks[id] = c + } + + return daemonConfig.Validate() +} + //////////////////////////////////////////////////////////////////////////////// // Jigs diff --git a/go/daemon/daecommon/config.go b/go/daemon/daecommon/config.go index d7db927..5c60933 100644 --- a/go/daemon/daecommon/config.go +++ b/go/daemon/daecommon/config.go @@ -8,6 +8,7 @@ import ( "isle/toolkit" "isle/yamlutil" "net" + "sort" _ "embed" @@ -210,10 +211,13 @@ func (c Config) Validate() error { "invalid vpn.public_addr %q: %w", network.VPN.PublicAddr, err, ) } else if otherID, ok := nebulaPorts[port]; ok { + ids := []string{id, otherID} + sort.Strings(ids) + return fmt.Errorf( "two networks with the same vpn.public_addr port: %q and %q", - id, - otherID, + ids[0], + ids[1], ) } diff --git a/go/daemon/daecommon/config_test.go b/go/daemon/daecommon/config_test.go index 2bd0f51..9907052 100644 --- a/go/daemon/daecommon/config_test.go +++ b/go/daemon/daecommon/config_test.go @@ -71,7 +71,7 @@ func TestConfig_UnmarshalYAML(t *testing.T) { bar: {"vpn":{"public_addr":"1.1.1.1:4000"}} baz: {"vpn":{"public_addr":"2.2.2.2:4000"}} `, - wantErr: `two networks with the same vpn.public_addr port: "baz" and "bar"`, + wantErr: `two networks with the same vpn.public_addr port: "bar" and "baz"`, }, { label: "err/invalid firewall", diff --git a/go/daemon/daemon.go b/go/daemon/daemon.go index d3b399f..633d91b 100644 --- a/go/daemon/daemon.go +++ b/go/daemon/daemon.go @@ -19,6 +19,7 @@ import ( var _ RPC = (*Daemon)(nil) type joinedNetwork struct { + id string network.Network config *daecommon.NetworkConfig } @@ -69,10 +70,18 @@ func New( return nil, fmt.Errorf("listing loadable networks: %w", err) } + if err := validateConfig( + ctx, networkLoader, daemonConfig, loadableNetworks, + ); err != nil { + return nil, fmt.Errorf("validating daemon config: %w", err) + } + for _, creationParams := range loadableNetworks { var ( id = creationParams.ID - networkConfig = pickNetworkConfig(daemonConfig, creationParams) + networkConfig = pickNetworkConfig( + daemonConfig.Networks, creationParams, + ) ) n, err := networkLoader.Load( @@ -87,7 +96,7 @@ func New( return nil, fmt.Errorf("loading network %q: %w", id, err) } - d.networks[id] = joinedNetwork{n, networkConfig} + d.networks[id] = joinedNetwork{id, n, networkConfig} } return d, nil @@ -115,8 +124,10 @@ func (d *Daemon) CreateNetwork( var ( creationParams = bootstrap.NewCreationParams(name, domain) - networkConfig = pickNetworkConfig(d.daemonConfig, creationParams) - networkLogger = networkLogger(d.logger, creationParams) + networkConfig = pickNetworkConfig( + d.daemonConfig.Networks, creationParams, + ) + networkLogger = networkLogger(d.logger, creationParams) ) if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil { @@ -141,7 +152,9 @@ func (d *Daemon) CreateNetwork( } networkLogger.Info(ctx, "Network created successfully") - d.networks[creationParams.ID] = joinedNetwork{n, networkConfig} + d.networks[creationParams.ID] = joinedNetwork{ + creationParams.ID, n, networkConfig, + } return nil } @@ -159,8 +172,10 @@ func (d *Daemon) JoinNetwork( var ( creationParams = newBootstrap.Bootstrap.NetworkCreationParams networkID = creationParams.ID - networkConfig = pickNetworkConfig(d.daemonConfig, creationParams) - networkLogger = networkLogger( + networkConfig = pickNetworkConfig( + d.daemonConfig.Networks, creationParams, + ) + networkLogger = networkLogger( d.logger, newBootstrap.Bootstrap.NetworkCreationParams, ) ) @@ -187,7 +202,7 @@ func (d *Daemon) JoinNetwork( } networkLogger.Info(ctx, "Network joined successfully") - d.networks[networkID] = joinedNetwork{n, networkConfig} + d.networks[networkID] = joinedNetwork{networkID, n, networkConfig} return nil } @@ -356,20 +371,46 @@ func (d *Daemon) GetConfig( // SetConfig implements the method for the network.RPC interface. func (d *Daemon) SetConfig( - ctx context.Context, config daecommon.NetworkConfig, + ctx context.Context, networkConfig daecommon.NetworkConfig, ) error { - _, err := withNetwork( - ctx, - d, - func(ctx context.Context, n joinedNetwork) (struct{}, error) { - if n.config != nil { - return struct{}{}, ErrManagedNetworkConfig - } + d.l.RLock() + defer d.l.RUnlock() - return struct{}{}, n.SetConfig(ctx, config) - }, - ) - return err + pickedNetwork, err := pickNetwork(ctx, d.networkLoader, d.networks) + if err != nil { + return err + } + + if pickedNetwork.config != nil { + return ErrManagedNetworkConfig + } + + // Reconstruct the daemon config using the actual in-use network configs, + // along with this new one, and do a validation before calling SetConfig. + + daemonConfig := d.daemonConfig + daemonConfig.Networks = map[string]daecommon.NetworkConfig{ + pickedNetwork.id: networkConfig, + } + + for id, joinedNetwork := range d.networks { + if pickedNetwork.id == id { + continue + } + + joinedNetworkConfig, err := joinedNetwork.GetConfig(ctx) + if err != nil { + return fmt.Errorf("getting network config of %q: %w", id, err) + } + + daemonConfig.Networks[id] = joinedNetworkConfig + } + + if err := daemonConfig.Validate(); err != nil { + return network.ErrInvalidConfig.WithData(err.Error()) + } + + return pickedNetwork.SetConfig(ctx, networkConfig) } // Shutdown blocks until all resources held or created by the daemon, diff --git a/go/daemon/daemon_test.go b/go/daemon/daemon_test.go index c7c0645..afb966e 100644 --- a/go/daemon/daemon_test.go +++ b/go/daemon/daemon_test.go @@ -22,6 +22,7 @@ type expectNetworkLoad struct { type harnessOpts struct { config daecommon.Config expectNetworksLoaded []expectNetworkLoad + expectStoredConfigs map[string]daecommon.NetworkConfig } func (o *harnessOpts) withDefaults() *harnessOpts { @@ -53,6 +54,7 @@ func newHarness(t *testing.T, opts *harnessOpts) *harness { ) for i, expectNetworkLoaded := range opts.expectNetworksLoaded { expectLoadable[i] = expectNetworkLoaded.creationParams + networkLoader. On( "Load", @@ -69,10 +71,16 @@ func newHarness(t *testing.T, opts *harnessOpts) *harness { expectNetworkLoaded.network.On("Shutdown").Return(nil).Once() } + for id, networkConfig := range opts.expectStoredConfigs { + networkLoader. + On("StoredConfig", toolkit.MockArg[context.Context](), id). + Return(networkConfig, nil). + Once() + } + networkLoader. On("Loadable", toolkit.MockArg[context.Context]()). - Return(expectLoadable, nil). - Once() + Return(expectLoadable, nil) daemon, err := New(ctx, logger, networkLoader, opts.config) require.NoError(t, err) @@ -174,18 +182,70 @@ func TestNew(t *testing.T) { network.NewMockNetwork(t), }, }, + expectStoredConfigs: map[string]daecommon.NetworkConfig{ + creationParamsD.ID: daecommon.NewNetworkConfig(nil), + }, }) }) + + t.Run("invalid config", func(t *testing.T) { + var ( + ctx = context.Background() + logger = toolkit.NewTestLogger(t) + networkLoader = network.NewMockLoader(t) + + creationParamsA = bootstrap.NewCreationParams("AAA", "a.com") + creationParamsB = bootstrap.NewCreationParams("BBB", "b.com") + + networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.VPN.PublicAddr = "1.1.1.1:5" + }) + + networkConfigB = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.VPN.PublicAddr = "2.2.2.2:5" + }) + + config = daecommon.Config{ + Networks: map[string]daecommon.NetworkConfig{ + creationParamsA.ID: networkConfigA, + }, + } + ) + + networkLoader. + On("Loadable", toolkit.MockArg[context.Context]()). + Return( + []bootstrap.CreationParams{creationParamsA, creationParamsB}, + nil, + ). + Once() + + networkLoader. + On( + "StoredConfig", + toolkit.MockArg[context.Context](), + creationParamsB.ID, + ). + Return(networkConfigB, nil). + Once() + + _, err := New(ctx, logger, networkLoader, config) + assert.ErrorContains(t, err, "two networks with the same vpn.public_addr port") + }) } func TestDaemon_SetConfig(t *testing.T) { t.Run("success", func(t *testing.T) { var ( - networkA = network.NewMockNetwork(t) - h = newHarness(t, &harnessOpts{ + networkA = network.NewMockNetwork(t) + creationParamsA = bootstrap.NewCreationParams("AAA", "a.com") + h = newHarness(t, &harnessOpts{ expectNetworksLoaded: []expectNetworkLoad{{ - bootstrap.NewCreationParams("AAA", "a.com"), nil, networkA, + creationParamsA, nil, networkA, }}, + expectStoredConfigs: map[string]daecommon.NetworkConfig{ + creationParamsA.ID: daecommon.NewNetworkConfig(nil), + }, }) networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { @@ -224,4 +284,38 @@ func TestDaemon_SetConfig(t *testing.T) { err := h.daemon.SetConfig(h.ctx, networkConfig) assert.ErrorIs(t, err, ErrManagedNetworkConfig) }) + + t.Run("ErrInvalidConfig", func(t *testing.T) { + var ( + creationParamsA = bootstrap.NewCreationParams("AAA", "a.com") + networkA = network.NewMockNetwork(t) + networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.VPN.PublicAddr = "1.2.3.4:5" + }) + + creationParamsB = bootstrap.NewCreationParams("BBB", "b.com") + networkB = network.NewMockNetwork(t) + networkConfigB = daecommon.NewNetworkConfig(nil) + + h = newHarness(t, &harnessOpts{ + expectNetworksLoaded: []expectNetworkLoad{ + {creationParamsA, nil, networkA}, + {creationParamsB, nil, networkB}, + }, + expectStoredConfigs: map[string]daecommon.NetworkConfig{ + creationParamsA.ID: networkConfigA, + creationParamsB.ID: networkConfigB, + }, + }) + ) + + networkA. + On("GetConfig", toolkit.MockArg[context.Context]()). + Return(networkConfigA, nil). + Once() + + networkConfigB.VPN.PublicAddr = "1.1.1.1:5" + err := h.daemon.SetConfig(WithNetwork(h.ctx, "BBB"), networkConfigB) + assert.ErrorIs(t, err, network.ErrInvalidConfig) + }) } diff --git a/go/daemon/network/config.go b/go/daemon/network/config.go index 2eee09e..03e5f85 100644 --- a/go/daemon/network/config.go +++ b/go/daemon/network/config.go @@ -10,6 +10,27 @@ import ( "path/filepath" ) +func configPath(networkStateDir toolkit.Dir) string { + return filepath.Join(networkStateDir.Path, "config.json") +} + +// loadConfig loads the stored NetworkConfig and returns it, or returns the +// default NetworkConfig if none is stored. +func loadConfig(networkStateDir toolkit.Dir) (daecommon.NetworkConfig, error) { + path := configPath(networkStateDir) + + var config daecommon.NetworkConfig + if err := jsonutil.LoadFile(&config, path); errors.Is(err, fs.ErrNotExist) { + return daecommon.NewNetworkConfig(nil), nil + } else if err != nil { + return daecommon.NetworkConfig{}, fmt.Errorf( + "loading %q: %w", path, err, + ) + } + + return config, nil +} + // loadStoreConfig writes the given NetworkConfig to the networkStateDir if the // config is non-nil. If the config is nil then a config is read from // networkStateDir, returning the zero value if no config was previously @@ -19,19 +40,11 @@ func loadStoreConfig( ) ( daecommon.NetworkConfig, error, ) { - path := filepath.Join(networkStateDir.Path, "config.json") - if config == nil { - config = new(daecommon.NetworkConfig) - err := jsonutil.LoadFile(&config, path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return daecommon.NetworkConfig{}, fmt.Errorf( - "loading %q: %w", path, err, - ) - } - return *config, nil + return loadConfig(networkStateDir) } + path := configPath(networkStateDir) if err := jsonutil.WriteFile(*config, path, 0600); err != nil { return daecommon.NetworkConfig{}, fmt.Errorf( "writing to %q: %w", path, err, diff --git a/go/daemon/network/loader.go b/go/daemon/network/loader.go index 2fb5e0b..4b1789c 100644 --- a/go/daemon/network/loader.go +++ b/go/daemon/network/loader.go @@ -55,6 +55,14 @@ type Loader interface { // Loadable returns the CreationParams for all Networks which can be Loaded. Loadable(context.Context) ([]bootstrap.CreationParams, error) + // StoredConfig returns the NetworkConfig currently stored for the network + // with the given ID, or the default NetworkConfig if none is stored. + StoredConfig( + ctx context.Context, networkID string, + ) ( + daecommon.NetworkConfig, error, + ) + // Load initializes and returns a Network instance for a network which was // previously joined or created, and which has the given CreationParams. Load( @@ -202,6 +210,21 @@ func (l *loader) Loadable( return creationParams, nil } +func (l *loader) StoredConfig( + _ context.Context, networkID string, +) ( + daecommon.NetworkConfig, error, +) { + networkStateDir, err := networkStateDir(l.networksStateDir, networkID, true) + if err != nil { + return daecommon.NetworkConfig{}, fmt.Errorf( + "getting network state dir: %w", err, + ) + } + + return loadConfig(networkStateDir) +} + func (l *loader) Load( ctx context.Context, logger *mlog.Logger, diff --git a/go/daemon/network/loader_mock.go b/go/daemon/network/loader_mock.go index e67321e..f88bc75 100644 --- a/go/daemon/network/loader_mock.go +++ b/go/daemon/network/loader_mock.go @@ -6,6 +6,8 @@ import ( context "context" bootstrap "isle/bootstrap" + daecommon "isle/daemon/daecommon" + mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog" mock "github.com/stretchr/testify/mock" @@ -138,6 +140,34 @@ func (_m *MockLoader) Loadable(_a0 context.Context) ([]bootstrap.CreationParams, return r0, r1 } +// StoredConfig provides a mock function with given fields: ctx, networkID +func (_m *MockLoader) StoredConfig(ctx context.Context, networkID string) (daecommon.NetworkConfig, error) { + ret := _m.Called(ctx, networkID) + + if len(ret) == 0 { + panic("no return value specified for StoredConfig") + } + + var r0 daecommon.NetworkConfig + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (daecommon.NetworkConfig, error)); ok { + return rf(ctx, networkID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) daecommon.NetworkConfig); ok { + r0 = rf(ctx, networkID) + } else { + r0 = ret.Get(0).(daecommon.NetworkConfig) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, networkID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // NewMockLoader creates a new instance of MockLoader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockLoader(t interface { diff --git a/go/daemon/network/network.go b/go/daemon/network/network.go index 88400b1..c00046c 100644 --- a/go/daemon/network/network.go +++ b/go/daemon/network/network.go @@ -108,6 +108,9 @@ type RPC interface { // SetConfig overrides the current config with the given one, adjusting any // running child processes as needed. + // + // Errors: + // - ErrInvalidConfig SetConfig(context.Context, daecommon.NetworkConfig) error } diff --git a/tasks/v0.0.3/code/setconfig-validate.md b/tasks/v0.0.3/code/setconfig-validate.md deleted file mode 100644 index 6be4590..0000000 --- a/tasks/v0.0.3/code/setconfig-validate.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -type: task ---- - -# Validate daemon Config in SetConfig - -When the `Daemon.SetConfig` method is called, it should call Validate on the -full new `daemon.Config` (not just the NetworkConfig).