From 6ec56f2a8842c0b2b17ffc5577500a8a3991e16d Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Mon, 11 Nov 2024 15:32:15 +0100 Subject: [PATCH] Pass NetworkConfig into Network loaders as an optional argument --- go/daemon/config.go | 10 +- go/daemon/daecommon/config.go | 20 +- go/daemon/daecommon/config_test.go | 72 +++++++ go/daemon/daemon.go | 87 +++++---- go/daemon/daemon_test.go | 227 ++++++++++++++++++++++ go/daemon/errors.go | 8 + go/daemon/network.go | 30 +-- go/daemon/network/config.go | 42 ++++ go/daemon/network/loader.go | 15 +- go/daemon/network/loader_mock.go | 56 +++--- go/daemon/network/network.go | 119 +++++++----- go/daemon/network/network_it_test.go | 91 ++++----- go/daemon/network/network_it_util_test.go | 124 ++++++------ go/daemon/rpc.go | 17 +- go/toolkit/testutils.go | 8 + go/yamlutil/testutils.go | 29 +++ 16 files changed, 691 insertions(+), 264 deletions(-) create mode 100644 go/daemon/daecommon/config_test.go create mode 100644 go/daemon/daemon_test.go create mode 100644 go/daemon/network/config.go create mode 100644 go/yamlutil/testutils.go diff --git a/go/daemon/config.go b/go/daemon/config.go index 62c3641..e26044e 100644 --- a/go/daemon/config.go +++ b/go/daemon/config.go @@ -44,22 +44,20 @@ var HTTPSocketPath = sync.OnceValue(func() string { func pickNetworkConfig( daemonConfig daecommon.Config, creationParams bootstrap.CreationParams, -) ( - daecommon.NetworkConfig, bool, -) { +) *daecommon.NetworkConfig { if len(daemonConfig.Networks) == 1 { // DEPRECATED if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok { - return c, true + return &c } } for searchStr, networkConfig := range daemonConfig.Networks { if creationParams.Matches(searchStr) { - return networkConfig, true + return &networkConfig } } - return daecommon.NetworkConfig{}, false + return nil } //////////////////////////////////////////////////////////////////////////////// diff --git a/go/daemon/daecommon/config.go b/go/daemon/daecommon/config.go index 6bc874b..0667f7a 100644 --- a/go/daemon/daecommon/config.go +++ b/go/daemon/daecommon/config.go @@ -86,6 +86,18 @@ type NetworkConfig struct { } `yaml:"storage"` } +// NewNetworkConfig returns a new NetworkConfig populated with its default +// values. If a callback is given the NetworkConfig will be passed to it for +// modification prior to having defaults populated. +func NewNetworkConfig(fn func(*NetworkConfig)) NetworkConfig { + var c NetworkConfig + if fn != nil { + fn(&c) + } + c.fillDefaults() + return c +} + func (c *NetworkConfig) fillDefaults() { if c.DNS.Resolvers == nil { c.DNS.Resolvers = []string{ @@ -213,9 +225,15 @@ func CopyDefaultConfig(into io.Writer) error { // fill in default values where it can. func (c *Config) UnmarshalYAML(n *yaml.Node) error { { // DEPRECATED + // We decode into a wrapped NetworkConfig, so that its UnmarshalYAML + // doesn't get invoked. This prevents fillDefaults from getting + // automatically called, so the IsZero check will work as intended. This + // means we need to call fillDefaults manually though. + type wrap NetworkConfig var networkConfig NetworkConfig - _ = n.Decode(&networkConfig) + _ = n.Decode((*wrap)(&networkConfig)) if !toolkit.IsZero(networkConfig) { + networkConfig.fillDefaults() *c = Config{ Networks: map[string]NetworkConfig{ DeprecatedNetworkID: networkConfig, diff --git a/go/daemon/daecommon/config_test.go b/go/daemon/daecommon/config_test.go new file mode 100644 index 0000000..4ccd1b5 --- /dev/null +++ b/go/daemon/daecommon/config_test.go @@ -0,0 +1,72 @@ +package daecommon + +import ( + "isle/yamlutil" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation + tests := []struct { + label string + str string + config Config + }{ + {"empty", ``, Config{}}, + { + "DEPRECATED single global network", + ` + {"dns":{"resolvers":["a"]}} + `, + Config{ + Networks: map[string]NetworkConfig{ + DeprecatedNetworkID: NewNetworkConfig(func(c *NetworkConfig) { + c.DNS.Resolvers = []string{"a"} + }), + }, + }, + }, + { + "single network", + ` + networks: + foo: {"dns":{"resolvers":["a"]}} + `, + Config{ + Networks: map[string]NetworkConfig{ + "foo": NewNetworkConfig(func(c *NetworkConfig) { + c.DNS.Resolvers = []string{"a"} + }), + }, + }, + }, + { + "multiple networks", + ` + networks: + foo: {"dns":{"resolvers":["a"]}} + bar: {} + `, + Config{ + Networks: map[string]NetworkConfig{ + "foo": NewNetworkConfig(func(c *NetworkConfig) { + c.DNS.Resolvers = []string{"a"} + }), + "bar": NewNetworkConfig(nil), + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.label, func(t *testing.T) { + test.str = yamlutil.ReplacePrefixTabs(test.str) + var config Config + err := yaml.Unmarshal([]byte(test.str), &config) + assert.NoError(t, err) + assert.Equal(t, test.config, config) + }) + } +} diff --git a/go/daemon/daemon.go b/go/daemon/daemon.go index c8ee78e..481ca9a 100644 --- a/go/daemon/daemon.go +++ b/go/daemon/daemon.go @@ -19,6 +19,11 @@ import ( var _ RPC = (*Daemon)(nil) +type joinedNetwork struct { + network.Network + config *daecommon.NetworkConfig +} + // Daemon implements all methods of the Daemon interface, plus others used // to manage this particular implementation. // @@ -41,7 +46,7 @@ type Daemon struct { daemonConfig daecommon.Config l sync.RWMutex - networks map[string]network.Network + networks map[string]joinedNetwork } // New initializes and returns a Daemon. @@ -57,7 +62,7 @@ func New( logger: logger, networkLoader: networkLoader, daemonConfig: daemonConfig, - networks: map[string]network.Network{}, + networks: map[string]joinedNetwork{}, } loadableNetworks, err := networkLoader.Loadable(ctx) @@ -69,20 +74,23 @@ func New( ctx = mctx.WithAnnotator(ctx, creationParams) var ( - id = creationParams.ID - networkConfig, _ = pickNetworkConfig(daemonConfig, creationParams) + id = creationParams.ID + networkConfig = pickNetworkConfig(daemonConfig, creationParams) ) - d.networks[id], err = networkLoader.Load( + n, err := networkLoader.Load( ctx, logger.WithNamespace("network"), - networkConfig, creationParams, - nil, + &network.Opts{ + Config: networkConfig, + }, ) if err != nil { return nil, fmt.Errorf("loading network %q: %w", id, err) } + + d.networks[id] = joinedNetwork{n, networkConfig} } return d, nil @@ -105,17 +113,16 @@ func (d *Daemon) CreateNetwork( ctx context.Context, name, domain string, ipNet nebula.IPNet, hostName nebula.HostName, ) error { - creationParams := bootstrap.NewCreationParams(name, domain) - ctx = mctx.WithAnnotator(ctx, creationParams) - - networkConfig, ok := pickNetworkConfig(d.daemonConfig, creationParams) - if !ok { - return errors.New("couldn't find network config for network being created") - } - d.l.Lock() defer d.l.Unlock() + var ( + creationParams = bootstrap.NewCreationParams(name, domain) + networkConfig = pickNetworkConfig(d.daemonConfig, creationParams) + ) + + ctx = mctx.WithAnnotator(ctx, creationParams) + if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil { return fmt.Errorf("checking if already joined to network: %w", err) } else if joined { @@ -126,18 +133,19 @@ func (d *Daemon) CreateNetwork( n, err := d.networkLoader.Create( ctx, d.logger.WithNamespace("network"), - networkConfig, creationParams, ipNet, hostName, - nil, + &network.Opts{ + Config: networkConfig, + }, ) if err != nil { return fmt.Errorf("creating network: %w", err) } d.logger.Info(ctx, "Network created successfully") - d.networks[creationParams.ID] = n + d.networks[creationParams.ID] = joinedNetwork{n, networkConfig} return nil } @@ -149,17 +157,17 @@ func (d *Daemon) CreateNetwork( func (d *Daemon) JoinNetwork( ctx context.Context, newBootstrap network.JoiningBootstrap, ) error { + d.l.Lock() + defer d.l.Unlock() + var ( - creationParams = newBootstrap.Bootstrap.NetworkCreationParams - networkConfig, _ = pickNetworkConfig(d.daemonConfig, creationParams) - networkID = creationParams.ID + creationParams = newBootstrap.Bootstrap.NetworkCreationParams + networkID = creationParams.ID + networkConfig = pickNetworkConfig(d.daemonConfig, creationParams) ) ctx = mctx.WithAnnotator(ctx, newBootstrap.Bootstrap.NetworkCreationParams) - d.l.Lock() - defer d.l.Unlock() - if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil { return fmt.Errorf("checking if already joined to network: %w", err) } else if joined { @@ -170,9 +178,10 @@ func (d *Daemon) JoinNetwork( n, err := d.networkLoader.Join( ctx, d.logger.WithNamespace("network"), - networkConfig, newBootstrap, - nil, + &network.Opts{ + Config: networkConfig, + }, ) if err != nil { return fmt.Errorf( @@ -181,14 +190,14 @@ func (d *Daemon) JoinNetwork( } d.logger.Info(ctx, "Network joined successfully") - d.networks[networkID] = n + d.networks[networkID] = joinedNetwork{n, networkConfig} return nil } func withNetwork[Res any]( ctx context.Context, d *Daemon, - fn func(context.Context, network.Network) (Res, error), + fn func(context.Context, joinedNetwork) (Res, error), ) ( Res, error, ) { @@ -238,7 +247,7 @@ func (d *Daemon) GetHosts(ctx context.Context) ([]bootstrap.Host, error) { return withNetwork( ctx, d, - func(ctx context.Context, n network.Network) ([]bootstrap.Host, error) { + func(ctx context.Context, n joinedNetwork) ([]bootstrap.Host, error) { return n.GetHosts(ctx) }, ) @@ -254,7 +263,7 @@ func (d *Daemon) GetGarageClientParams( ctx, d, func( - ctx context.Context, n network.Network, + ctx context.Context, n joinedNetwork, ) ( network.GarageClientParams, error, ) { @@ -274,7 +283,7 @@ func (d *Daemon) GetNebulaCAPublicCredentials( ctx, d, func( - ctx context.Context, n network.Network, + ctx context.Context, n joinedNetwork, ) ( nebula.CAPublicCredentials, error, ) { @@ -289,7 +298,7 @@ func (d *Daemon) RemoveHost(ctx context.Context, hostName nebula.HostName) error ctx, d, func( - ctx context.Context, n network.Network, + ctx context.Context, n joinedNetwork, ) ( struct{}, error, ) { @@ -311,7 +320,7 @@ func (d *Daemon) CreateHost( ctx, d, func( - ctx context.Context, n network.Network, + ctx context.Context, n joinedNetwork, ) ( network.JoiningBootstrap, error, ) { @@ -332,7 +341,7 @@ func (d *Daemon) CreateNebulaCertificate( ctx, d, func( - ctx context.Context, n network.Network, + ctx context.Context, n joinedNetwork, ) ( nebula.Certificate, error, ) { @@ -341,6 +350,7 @@ func (d *Daemon) CreateNebulaCertificate( ) } +// GetConfig implements the method for the network.RPC interface. func (d *Daemon) GetConfig( ctx context.Context, ) ( @@ -350,7 +360,7 @@ func (d *Daemon) GetConfig( ctx, d, func( - ctx context.Context, n network.Network, + ctx context.Context, n joinedNetwork, ) ( daecommon.NetworkConfig, error, ) { @@ -359,13 +369,18 @@ func (d *Daemon) GetConfig( ) } +// SetConfig implements the method for the network.RPC interface. func (d *Daemon) SetConfig( ctx context.Context, config daecommon.NetworkConfig, ) error { _, err := withNetwork( ctx, d, - func(ctx context.Context, n network.Network) (struct{}, error) { + func(ctx context.Context, n joinedNetwork) (struct{}, error) { + if n.config != nil { + return struct{}{}, ErrUserManagedNetworkConfig + } + // TODO needs to check that public addresses aren't being shared // across networks, and whatever else happens in Config.Validate. return struct{}{}, n.SetConfig(ctx, config) diff --git a/go/daemon/daemon_test.go b/go/daemon/daemon_test.go new file mode 100644 index 0000000..58e2367 --- /dev/null +++ b/go/daemon/daemon_test.go @@ -0,0 +1,227 @@ +package daemon + +import ( + "context" + "isle/bootstrap" + "isle/daemon/daecommon" + "isle/daemon/network" + "isle/toolkit" + "testing" + + "dev.mediocregopher.com/mediocre-go-lib.git/mlog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type expectNetworkLoad struct { + creationParams bootstrap.CreationParams + networkConfig *daecommon.NetworkConfig + network *network.MockNetwork +} + +type harnessOpts struct { + config daecommon.Config + expectNetworksLoaded []expectNetworkLoad +} + +func (o *harnessOpts) withDefaults() *harnessOpts { + if o == nil { + o = new(harnessOpts) + } + return o +} + +type harness struct { + ctx context.Context + networkLoader *network.MockLoader + daemon *Daemon +} + +func newHarness(t *testing.T, opts *harnessOpts) *harness { + t.Parallel() + + opts = opts.withDefaults() + + var ( + ctx = context.Background() + logger = toolkit.NewTestLogger(t) + networkLoader = network.NewMockLoader(t) + ) + + expectLoadable := make( + []bootstrap.CreationParams, len(opts.expectNetworksLoaded), + ) + for i, expectNetworkLoaded := range opts.expectNetworksLoaded { + expectLoadable[i] = expectNetworkLoaded.creationParams + networkLoader. + On( + "Load", + toolkit.MockArg[context.Context](), + toolkit.MockArg[*mlog.Logger](), + expectNetworkLoaded.creationParams, + &network.Opts{ + Config: expectNetworkLoaded.networkConfig, + }, + ). + Return(expectNetworkLoaded.network, nil). + Once() + + expectNetworkLoaded.network.On("Shutdown").Return(nil).Once() + } + + networkLoader. + On("Loadable", toolkit.MockArg[context.Context]()). + Return(expectLoadable, nil). + Once() + + daemon, err := New(ctx, logger, networkLoader, opts.config) + require.NoError(t, err) + + t.Cleanup(func() { + t.Log("Shutting down Daemon") + assert.NoError(t, daemon.Shutdown()) + }) + + return &harness{ctx, networkLoader, daemon} +} + +func TestNew(t *testing.T) { + t.Run("no networks loaded", func(t *testing.T) { + _ = newHarness(t, nil) + }) + + t.Run("DEPRECATED network config matching", func(t *testing.T) { + var ( + creationParams = bootstrap.NewCreationParams("AAA", "a.com") + networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.DNS.Resolvers = []string{"foo"} + }) + config = daecommon.Config{ + Networks: map[string]daecommon.NetworkConfig{ + daecommon.DeprecatedNetworkID: networkConfig, + }, + } + ) + + _ = newHarness(t, &harnessOpts{ + config: config, + expectNetworksLoaded: []expectNetworkLoad{ + { + creationParams, + &networkConfig, + network.NewMockNetwork(t), + }, + }, + }) + }) + + t.Run("network config matching", func(t *testing.T) { + var ( + creationParamsA = bootstrap.NewCreationParams("AAA", "a.com") + creationParamsB = bootstrap.NewCreationParams("BBB", "b.com") + creationParamsC = bootstrap.NewCreationParams("CCC", "c.com") + creationParamsD = bootstrap.NewCreationParams("DDD", "d.com") + + networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.DNS.Resolvers = []string{"foo"} + }) + + networkConfigB = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.VPN.Tun.Device = "bar" + }) + + networkConfigC = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.Storage.Allocations = []daecommon.ConfigStorageAllocation{ + { + DataPath: "/path/data", + MetaPath: "/path/meta", + Capacity: 1, + }, + } + }) + + config = daecommon.Config{ + Networks: map[string]daecommon.NetworkConfig{ + creationParamsA.ID: networkConfigA, + creationParamsB.Name: networkConfigB, + creationParamsC.Domain: networkConfigC, + "unknown": {}, + }, + } + ) + + _ = newHarness(t, &harnessOpts{ + config: config, + expectNetworksLoaded: []expectNetworkLoad{ + { + creationParamsA, + &networkConfigA, + network.NewMockNetwork(t), + }, + { + creationParamsB, + &networkConfigB, + network.NewMockNetwork(t), + }, + { + creationParamsC, + &networkConfigC, + network.NewMockNetwork(t), + }, + { + creationParamsD, + nil, + network.NewMockNetwork(t), + }, + }, + }) + }) +} + +func TestDaemon_SetConfig(t *testing.T) { + t.Run("success", func(t *testing.T) { + var ( + networkA = network.NewMockNetwork(t) + h = newHarness(t, &harnessOpts{ + expectNetworksLoaded: []expectNetworkLoad{{ + bootstrap.NewCreationParams("AAA", "a.com"), nil, networkA, + }}, + }) + + networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + c.VPN.Tun.Device = "foo" + }) + ) + + networkA. + On("SetConfig", toolkit.MockArg[context.Context](), networkConfig). + Return(nil). + Once() + + err := h.daemon.SetConfig(h.ctx, networkConfig) + assert.NoError(t, err) + }) + + t.Run("ErrUserManagedNetworkConfig", func(t *testing.T) { + var ( + creationParams = bootstrap.NewCreationParams("AAA", "a.com") + networkA = network.NewMockNetwork(t) + networkConfig = daecommon.NewNetworkConfig(nil) + + h = newHarness(t, &harnessOpts{ + config: daecommon.Config{ + Networks: map[string]daecommon.NetworkConfig{ + creationParams.Name: networkConfig, + }, + }, + expectNetworksLoaded: []expectNetworkLoad{ + {creationParams, &networkConfig, networkA}, + }, + }) + ) + + networkConfig.VPN.Tun.Device = "foo" + err := h.daemon.SetConfig(h.ctx, networkConfig) + assert.ErrorIs(t, err, ErrUserManagedNetworkConfig) + }) +} diff --git a/go/daemon/errors.go b/go/daemon/errors.go index 1ddd242..1e4b2ce 100644 --- a/go/daemon/errors.go +++ b/go/daemon/errors.go @@ -10,6 +10,7 @@ const ( errCodeAlreadyJoined errCodeNoMatchingNetworks errCodeMultipleMatchingNetworks + errCodeUserManagedNetworkConfig ) var ( @@ -33,4 +34,11 @@ var ( errCodeMultipleMatchingNetworks, "Multiple networks matched the search string", ) + + // ErrUserManagedNetworkConfig is returned when attempting to modify a + // network config which is managed by the user. + ErrUserManagedNetworkConfig = jsonrpc2.NewError( + errCodeUserManagedNetworkConfig, + "Network configuration is managed by the user", + ) ) diff --git a/go/daemon/network.go b/go/daemon/network.go index 3eb7b79..d0d7628 100644 --- a/go/daemon/network.go +++ b/go/daemon/network.go @@ -10,24 +10,30 @@ import ( func pickNetwork( ctx context.Context, networkLoader network.Loader, - networks map[string]network.Network, + networks map[string]joinedNetwork, ) ( - network.Network, error, + joinedNetwork, error, ) { if len(networks) == 0 { - return nil, ErrNoNetwork + return joinedNetwork{}, ErrNoNetwork + } + + networkSearchStr := getNetworkSearchStr(ctx) + if networkSearchStr == "" { + if len(networks) > 1 { + return joinedNetwork{}, ErrNoMatchingNetworks + } + for _, network := range networks { + return network, nil + } } creationParams, err := networkLoader.Loadable(ctx) if err != nil { - return nil, fmt.Errorf("getting loadable networks: %w", err) + return joinedNetwork{}, fmt.Errorf("getting loadable networks: %w", err) } - var ( - networkSearchStr = getNetworkSearchStr(ctx) - matchingNetworkIDs = make([]string, 0, len(networks)) - ) - + matchingNetworkIDs := make([]string, 0, len(networks)) for _, creationParam := range creationParams { if networkSearchStr == "" || creationParam.Matches(networkSearchStr) { matchingNetworkIDs = append(matchingNetworkIDs, creationParam.ID) @@ -35,9 +41,9 @@ func pickNetwork( } if len(matchingNetworkIDs) == 0 { - return nil, ErrNoMatchingNetworks + return joinedNetwork{}, ErrNoMatchingNetworks } else if len(matchingNetworkIDs) > 1 { - return nil, ErrMultipleMatchingNetworks + return joinedNetwork{}, ErrMultipleMatchingNetworks } return networks[matchingNetworkIDs[0]], nil @@ -45,7 +51,7 @@ func pickNetwork( func alreadyJoined( ctx context.Context, - networks map[string]network.Network, + networks map[string]joinedNetwork, creationParams bootstrap.CreationParams, ) ( bool, error, diff --git a/go/daemon/network/config.go b/go/daemon/network/config.go new file mode 100644 index 0000000..2eee09e --- /dev/null +++ b/go/daemon/network/config.go @@ -0,0 +1,42 @@ +package network + +import ( + "errors" + "fmt" + "io/fs" + "isle/daemon/daecommon" + "isle/jsonutil" + "isle/toolkit" + "path/filepath" +) + +// loadStoreConfig writes the given NetworkConfig to the networkStateDir if the +// config is non-nil. If the config is nil then a config is read from +// networkStateDir, returning the zero value if no config was previously +// stored. +func loadStoreConfig( + networkStateDir toolkit.Dir, config *daecommon.NetworkConfig, +) ( + daecommon.NetworkConfig, error, +) { + path := filepath.Join(networkStateDir.Path, "config.json") + + if config == nil { + config = new(daecommon.NetworkConfig) + err := jsonutil.LoadFile(&config, path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return daecommon.NetworkConfig{}, fmt.Errorf( + "loading %q: %w", path, err, + ) + } + return *config, nil + } + + if err := jsonutil.WriteFile(*config, path, 0600); err != nil { + return daecommon.NetworkConfig{}, fmt.Errorf( + "writing to %q: %w", path, err, + ) + } + + return *config, nil +} diff --git a/go/daemon/network/loader.go b/go/daemon/network/loader.go index d6a280e..9bf7d0e 100644 --- a/go/daemon/network/loader.go +++ b/go/daemon/network/loader.go @@ -59,7 +59,6 @@ type Loader interface { Load( context.Context, *mlog.Logger, - daecommon.NetworkConfig, bootstrap.CreationParams, *Opts, ) ( @@ -73,7 +72,6 @@ type Loader interface { Join( context.Context, *mlog.Logger, - daecommon.NetworkConfig, JoiningBootstrap, *Opts, ) ( @@ -91,12 +89,11 @@ type Loader interface { // - hostName: The name of this first host in the network. // // Errors: - // - ErrInvalidConfig - if daemonConfig doesn't have 3 storage allocations - // configured. + // - ErrInvalidConfig - If the Opts.Config field is not valid. It must be + // non-nil and have at least 3 storage allocations. Create( context.Context, *mlog.Logger, - daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, @@ -188,7 +185,7 @@ func (l *loader) Loadable( creationParams := make([]bootstrap.CreationParams, 0, len(networkStateDirs)) for _, networkStateDir := range networkStateDirs { - thisCreationParams, err := LoadCreationParams(networkStateDir) + thisCreationParams, err := loadCreationParams(networkStateDir) if err != nil { return nil, fmt.Errorf( "loading creation params from %q: %w", @@ -205,7 +202,6 @@ func (l *loader) Loadable( func (l *loader) Load( ctx context.Context, logger *mlog.Logger, - networkConfig daecommon.NetworkConfig, creationParams bootstrap.CreationParams, opts *Opts, ) ( @@ -226,7 +222,6 @@ func (l *loader) Load( ctx, logger.WithNamespace("network"), l.envBinDirPath, - networkConfig, networkStateDir, networkRuntimeDir, opts, @@ -236,7 +231,6 @@ func (l *loader) Load( func (l *loader) Join( ctx context.Context, logger *mlog.Logger, - networkConfig daecommon.NetworkConfig, joiningBootstrap JoiningBootstrap, opts *Opts, ) ( @@ -260,7 +254,6 @@ func (l *loader) Join( ctx, logger.WithNamespace("network"), l.envBinDirPath, - networkConfig, joiningBootstrap, networkStateDir, networkRuntimeDir, @@ -271,7 +264,6 @@ func (l *loader) Join( func (l *loader) Create( ctx context.Context, logger *mlog.Logger, - networkConfig daecommon.NetworkConfig, creationParams bootstrap.CreationParams, ipNet nebula.IPNet, hostName nebula.HostName, @@ -294,7 +286,6 @@ func (l *loader) Create( ctx, logger.WithNamespace("network"), l.envBinDirPath, - networkConfig, networkStateDir, networkRuntimeDir, creationParams, diff --git a/go/daemon/network/loader_mock.go b/go/daemon/network/loader_mock.go index 54004f3..e67321e 100644 --- a/go/daemon/network/loader_mock.go +++ b/go/daemon/network/loader_mock.go @@ -6,8 +6,6 @@ import ( context "context" bootstrap "isle/bootstrap" - daecommon "isle/daemon/daecommon" - mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog" mock "github.com/stretchr/testify/mock" @@ -20,9 +18,9 @@ type MockLoader struct { mock.Mock } -// Create provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5, _a6 -func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 bootstrap.CreationParams, _a4 nebula.IPNet, _a5 nebula.HostName, _a6 *Opts) (Network, error) { - ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5, _a6) +// Create provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5 +func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 bootstrap.CreationParams, _a3 nebula.IPNet, _a4 nebula.HostName, _a5 *Opts) (Network, error) { + ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5) if len(ret) == 0 { panic("no return value specified for Create") @@ -30,19 +28,19 @@ func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommo var r0 Network var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) (Network, error)); ok { - return rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6) + if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) (Network, error)); ok { + return rf(_a0, _a1, _a2, _a3, _a4, _a5) } - if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) Network); ok { - r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6) + if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) Network); ok { + r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(Network) } } - if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) error); ok { - r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6) + if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) error); ok { + r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5) } else { r1 = ret.Error(1) } @@ -50,9 +48,9 @@ func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommo return r0, r1 } -// Join provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4 -func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 JoiningBootstrap, _a4 *Opts) (Network, error) { - ret := _m.Called(_a0, _a1, _a2, _a3, _a4) +// Join provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 JoiningBootstrap, _a3 *Opts) (Network, error) { + ret := _m.Called(_a0, _a1, _a2, _a3) if len(ret) == 0 { panic("no return value specified for Join") @@ -60,19 +58,19 @@ func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon. var r0 Network var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) (Network, error)); ok { - return rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) (Network, error)); ok { + return rf(_a0, _a1, _a2, _a3) } - if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) Network); ok { - r0 = rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) Network); ok { + r0 = rf(_a0, _a1, _a2, _a3) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(Network) } } - if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) error); ok { - r1 = rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) error); ok { + r1 = rf(_a0, _a1, _a2, _a3) } else { r1 = ret.Error(1) } @@ -80,9 +78,9 @@ func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon. return r0, r1 } -// Load provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4 -func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 bootstrap.CreationParams, _a4 *Opts) (Network, error) { - ret := _m.Called(_a0, _a1, _a2, _a3, _a4) +// Load provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 bootstrap.CreationParams, _a3 *Opts) (Network, error) { + ret := _m.Called(_a0, _a1, _a2, _a3) if len(ret) == 0 { panic("no return value specified for Load") @@ -90,19 +88,19 @@ func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon. var r0 Network var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) (Network, error)); ok { - return rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) (Network, error)); ok { + return rf(_a0, _a1, _a2, _a3) } - if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) Network); ok { - r0 = rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) Network); ok { + r0 = rf(_a0, _a1, _a2, _a3) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(Network) } } - if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) error); ok { - r1 = rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) error); ok { + r1 = rf(_a0, _a1, _a2, _a3) } else { r1 = ret.Error(1) } diff --git a/go/daemon/network/network.go b/go/daemon/network/network.go index 1122e1f..9b2e3f0 100644 --- a/go/daemon/network/network.go +++ b/go/daemon/network/network.go @@ -154,6 +154,14 @@ type Network interface { // Network instance. A nil Opts is equivalent to a zero value. type Opts struct { GarageAdminToken string // Will be randomly generated if left unset. + + // Config will be used as the configuration of the Network from its + // initialization onwards. + // + // If not given then the most recent NetworkConfig for the network will be + // used, either that which it was most recently initialized with or which + // was passed to [SetConfig]. + Config *daecommon.NetworkConfig } func (o *Opts) withDefaults() *Opts { @@ -189,36 +197,53 @@ type network struct { wg sync.WaitGroup } -// instatiateNetwork returns an instantiated *network instance which has not yet -// been initialized. -func instatiateNetwork( +// newNetwork returns an instantiated *network instance. All initialization +// steps which are common to all *network creation methods (load, join, create) +// are included here as well. +func newNetwork( ctx context.Context, logger *mlog.Logger, - networkConfig daecommon.NetworkConfig, envBinDirPath string, stateDir toolkit.Dir, runtimeDir toolkit.Dir, + dirsMayExist bool, opts *Opts, -) *network { - ctx = context.WithoutCancel(ctx) - ctx, cancel := context.WithCancel(ctx) - return &network{ - logger: logger, - networkConfig: networkConfig, - envBinDirPath: envBinDirPath, - stateDir: stateDir, - runtimeDir: runtimeDir, - opts: opts.withDefaults(), - workerCtx: ctx, - workerCancel: cancel, +) ( + *network, error, +) { + ctx, cancel := context.WithCancel(context.WithoutCancel(ctx)) + + var ( + n = &network{ + logger: logger, + envBinDirPath: envBinDirPath, + stateDir: stateDir, + runtimeDir: runtimeDir, + opts: opts.withDefaults(), + workerCtx: ctx, + workerCancel: cancel, + } + err error + ) + + n.networkConfig, err = loadStoreConfig(n.stateDir, n.opts.Config) + if err != nil { + return nil, fmt.Errorf("resolving network config: %w", err) } + + secretsDir, err := n.stateDir.MkChildDir("secrets", dirsMayExist) + if err != nil { + return nil, fmt.Errorf("creating secrets dir: %w", err) + } + + n.secretsStore = secrets.NewFSStore(secretsDir.Path) + + return n, nil } -// LoadCreationParams returns the CreationParams of a Network which was +// loadCreationParams returns the CreationParams of a Network which was // Created/Joined with the given state directory. -// -// TODO probably can be private -func LoadCreationParams( +func loadCreationParams( stateDir toolkit.Dir, ) ( bootstrap.CreationParams, error, @@ -244,25 +269,23 @@ func load( ctx context.Context, logger *mlog.Logger, envBinDirPath string, - networkConfig daecommon.NetworkConfig, stateDir toolkit.Dir, runtimeDir toolkit.Dir, opts *Opts, ) ( Network, error, ) { - n := instatiateNetwork( + n, err := newNetwork( ctx, logger, - networkConfig, envBinDirPath, stateDir, runtimeDir, + true, opts, ) - - if err := n.initializeDirs(true); err != nil { - return nil, fmt.Errorf("initializing directories: %w", err) + if err != nil { + return nil, fmt.Errorf("instantiating Network: %w", err) } var ( @@ -285,7 +308,6 @@ func join( ctx context.Context, logger *mlog.Logger, envBinDirPath string, - networkConfig daecommon.NetworkConfig, joiningBootstrap JoiningBootstrap, stateDir toolkit.Dir, runtimeDir toolkit.Dir, @@ -293,18 +315,17 @@ func join( ) ( Network, error, ) { - n := instatiateNetwork( + n, err := newNetwork( ctx, logger, - networkConfig, envBinDirPath, stateDir, runtimeDir, + false, opts, ) - - if err := n.initializeDirs(false); err != nil { - return nil, fmt.Errorf("initializing directories: %w", err) + if err != nil { + return nil, fmt.Errorf("instantiating Network: %w", err) } if err := secrets.Import( @@ -324,7 +345,6 @@ func create( ctx context.Context, logger *mlog.Logger, envBinDirPath string, - networkConfig daecommon.NetworkConfig, stateDir toolkit.Dir, runtimeDir toolkit.Dir, creationParams bootstrap.CreationParams, @@ -334,12 +354,6 @@ func create( ) ( Network, error, ) { - if len(networkConfig.Storage.Allocations) < 3 { - return nil, ErrInvalidConfig.WithData( - "At least three storage allocations are required.", - ) - } - nebulaCACreds, err := nebula.NewCACredentials(creationParams.Domain, ipNet) if err != nil { return nil, fmt.Errorf("creating nebula CA cert: %w", err) @@ -347,18 +361,23 @@ func create( garageRPCSecret := toolkit.RandStr(32) - n := instatiateNetwork( + n, err := newNetwork( ctx, logger, - networkConfig, envBinDirPath, stateDir, runtimeDir, + false, opts, ) + if err != nil { + return nil, fmt.Errorf("instantiating Network: %w", err) + } - if err := n.initializeDirs(false); err != nil { - return nil, fmt.Errorf("initializing directories: %w", err) + if len(n.networkConfig.Storage.Allocations) < 3 { + return nil, ErrInvalidConfig.WithData( + "At least three storage allocations are required.", + ) } err = daecommon.SetGarageRPCSecret(ctx, n.secretsStore, garageRPCSecret) @@ -391,16 +410,6 @@ func create( return n, nil } -func (n *network) initializeDirs(mayExist bool) error { - secretsDir, err := n.stateDir.MkChildDir("secrets", mayExist) - if err != nil { - return fmt.Errorf("creating secrets dir: %w", err) - } - - n.secretsStore = secrets.NewFSStore(secretsDir.Path) - return nil -} - func (n *network) periodically( label string, fn func(context.Context) error, @@ -1016,6 +1025,10 @@ func (n *network) GetConfig(context.Context) (daecommon.NetworkConfig, error) { func (n *network) SetConfig( ctx context.Context, config daecommon.NetworkConfig, ) error { + if _, err := loadStoreConfig(n.stateDir, &config); err != nil { + return fmt.Errorf("storing new config: %w", err) + } + prevBootstrap, err := n.reload(ctx, &config, nil) if err != nil { return fmt.Errorf("reloading config: %w", err) diff --git a/go/daemon/network/network_it_test.go b/go/daemon/network/network_it_test.go index 248c304..dd25605 100644 --- a/go/daemon/network/network_it_test.go +++ b/go/daemon/network/network_it_test.go @@ -17,7 +17,7 @@ func TestCreate(t *testing.T) { network = h.createNetwork(t, "primus", nil) ) - gotCreationParams, err := LoadCreationParams(network.stateDir) + gotCreationParams, err := loadCreationParams(network.stateDir) assert.NoError(t, err) assert.Equal( t, gotCreationParams, network.getBootstrap(t).NetworkCreationParams, @@ -25,31 +25,30 @@ func TestCreate(t *testing.T) { } func TestLoad(t *testing.T) { - var ( - h = newIntegrationHarness(t) - network = h.createNetwork(t, "primus", &createNetworkOpts{ - manualShutdown: true, - }) - ) + t.Run("given config", func(t *testing.T) { + var ( + h = newIntegrationHarness(t) + network = h.createNetwork(t, "primus", nil) + networkConfig = network.getConfig(t) + ) - t.Log("Shutting down network") - assert.NoError(t, network.Shutdown()) + network.opts.Config = &networkConfig + network.restart(t) - t.Log("Calling Load") - loadedNetwork, err := load( - h.ctx, - h.logger.WithNamespace("loadedNetwork"), - getEnvBinDirPath(), - network.getConfig(t), - network.stateDir, - h.mkDir(t, "runtime"), - network.opts, - ) - assert.NoError(t, err) + assert.Equal(t, networkConfig, network.getConfig(t)) + }) - t.Cleanup(func() { - t.Log("Shutting down loadedNetwork") - assert.NoError(t, loadedNetwork.Shutdown()) + t.Run("load previous config", func(t *testing.T) { + var ( + h = newIntegrationHarness(t) + network = h.createNetwork(t, "primus", nil) + networkConfig = network.getConfig(t) + ) + + network.opts.Config = nil + network.restart(t) + + assert.Equal(t, networkConfig, network.getConfig(t)) }) } @@ -61,13 +60,7 @@ func TestJoin(t *testing.T) { secondus = h.joinNetwork(t, primus, "secondus", nil) ) - primusHosts, err := primus.GetHosts(h.ctx) - assert.NoError(t, err) - - secondusHosts, err := secondus.GetHosts(h.ctx) - assert.NoError(t, err) - - assert.Equal(t, primusHosts, secondusHosts) + assert.Equal(t, primus.getHostsByName(t), secondus.getHostsByName(t)) }) t.Run("with alloc", func(t *testing.T) { @@ -84,28 +77,10 @@ func TestJoin(t *testing.T) { t.Log("reloading primus' hosts") assert.NoError(t, primus.Network.(*network).reloadHosts(h.ctx)) - primusHosts, err := primus.GetHosts(h.ctx) - assert.NoError(t, err) - - secondusHosts, err := secondus.GetHosts(h.ctx) - assert.NoError(t, err) - - assert.Equal(t, primusHosts, secondusHosts) + assert.Equal(t, primus.getHostsByName(t), secondus.getHostsByName(t)) }) } -func TestNetwork_GetConfig(t *testing.T) { - var ( - h = newIntegrationHarness(t) - network = h.createNetwork(t, "primus", nil) - ) - - config, err := network.GetConfig(h.ctx) - assert.NoError(t, err) - - assert.Equal(t, config, network.getConfig(t)) -} - func TestNetwork_SetConfig(t *testing.T) { allocsToRoles := func( hostName nebula.HostName, allocs []bootstrap.GarageHostInstance, @@ -259,4 +234,22 @@ func TestNetwork_SetConfig(t *testing.T) { assert.NoError(t, err) assert.NotContains(t, layout.Roles, removedRole) }) + + t.Run("changes reflected after restart", func(t *testing.T) { + var ( + h = newIntegrationHarness(t) + network = h.createNetwork(t, "primus", &createNetworkOpts{ + numStorageAllocs: 4, + }) + networkConfig = network.getConfig(t) + ) + + networkConfig.Storage.Allocations = networkConfig.Storage.Allocations[:3] + assert.NoError(t, network.SetConfig(h.ctx, networkConfig)) + + network.opts.Config = nil + network.restart(t) + + assert.Equal(t, networkConfig, network.getConfig(t)) + }) } diff --git a/go/daemon/network/network_it_util_test.go b/go/daemon/network/network_it_util_test.go index 7e30624..ad3c5c0 100644 --- a/go/daemon/network/network_it_util_test.go +++ b/go/daemon/network/network_it_util_test.go @@ -2,7 +2,6 @@ package network import ( "context" - "encoding/json" "fmt" "isle/bootstrap" "isle/daemon/daecommon" @@ -18,7 +17,6 @@ import ( "dev.mediocregopher.com/mediocre-go-lib.git/mlog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" ) // Utilities related to running network integration tests @@ -62,16 +60,6 @@ func newTunDevice() string { return fmt.Sprintf("isle-test-%d", atomic.AddUint64(&tunDeviceCounter, 1)) } -func mustParseNetworkConfigf(str string, args ...any) daecommon.NetworkConfig { - str = fmt.Sprintf(str, args...) - - var networkConfig daecommon.NetworkConfig - if err := yaml.Unmarshal([]byte(str), &networkConfig); err != nil { - panic(fmt.Sprintf("parsing network config: %v", err)) - } - return networkConfig -} - type integrationHarness struct { ctx context.Context logger *mlog.Logger @@ -136,35 +124,24 @@ func (h *integrationHarness) mkNetworkConfig( opts = new(networkConfigOpts) } - publicAddr := "" - if opts.hasPublicAddr { - publicAddr = newPublicAddr() - } - - allocs := make([]map[string]any, opts.numStorageAllocs) - for i := range allocs { - allocs[i] = map[string]any{ - "data_path": h.mkDir(t, "data").Path, - "meta_path": h.mkDir(t, "meta").Path, - "capacity": 1, + return daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) { + if opts.hasPublicAddr { + c.VPN.PublicAddr = newPublicAddr() } - } - allocsJSON, err := json.Marshal(allocs) - require.NoError(t, err) + c.VPN.Tun.Device = newTunDevice() // TODO is this necessary?? - return mustParseNetworkConfigf(` - vpn: - public_addr: %q - tun: - device: %q - storage: - allocations: %s - `, - publicAddr, - newTunDevice(), - allocsJSON, - ) + c.Storage.Allocations = make( + []daecommon.ConfigStorageAllocation, opts.numStorageAllocs, + ) + for i := range c.Storage.Allocations { + c.Storage.Allocations[i] = daecommon.ConfigStorageAllocation{ + DataPath: h.mkDir(t, "data").Path, + MetaPath: h.mkDir(t, "meta").Path, + Capacity: 1, + } + } + }) } type createNetworkOpts struct { @@ -203,7 +180,7 @@ func (h *integrationHarness) createNetwork( t *testing.T, hostNameStr string, opts *createNetworkOpts, -) integrationHarnessNetwork { +) *integrationHarnessNetwork { t.Logf("Creating as %q", hostNameStr) opts = opts.withDefaults() @@ -222,6 +199,7 @@ func (h *integrationHarness) createNetwork( networkOpts = &Opts{ GarageAdminToken: "admin_token", + Config: &networkConfig, } ) @@ -229,7 +207,6 @@ func (h *integrationHarness) createNetwork( h.ctx, logger, getEnvBinDirPath(), - networkConfig, stateDir, runtimeDir, opts.creationParams, @@ -241,16 +218,7 @@ func (h *integrationHarness) createNetwork( t.Fatalf("creating Network: %v", err) } - if !opts.manualShutdown { - t.Cleanup(func() { - t.Logf("Shutting down Network %q", hostNameStr) - if err := network.Shutdown(); err != nil { - t.Logf("Shutting down Network %q failed: %v", hostNameStr, err) - } - }) - } - - return integrationHarnessNetwork{ + nh := &integrationHarnessNetwork{ network, h.ctx, logger, @@ -259,6 +227,17 @@ func (h *integrationHarness) createNetwork( runtimeDir, networkOpts, } + + if !opts.manualShutdown { + t.Cleanup(func() { + t.Logf("Shutting down Network %q", hostNameStr) + if err := nh.Shutdown(); err != nil { + t.Logf("Shutting down Network %q failed: %v", hostNameStr, err) + } + }) + } + + return nh } type joinNetworkOpts struct { @@ -279,10 +258,10 @@ func (o *joinNetworkOpts) withDefaults() *joinNetworkOpts { func (h *integrationHarness) joinNetwork( t *testing.T, - network integrationHarnessNetwork, + network *integrationHarnessNetwork, hostNameStr string, opts *joinNetworkOpts, -) integrationHarnessNetwork { +) *integrationHarnessNetwork { opts = opts.withDefaults() hostName := nebula.HostName(hostNameStr) @@ -301,6 +280,7 @@ func (h *integrationHarness) joinNetwork( runtimeDir = h.mkDir(t, "runtime") networkOpts = &Opts{ GarageAdminToken: "admin_token", + Config: &networkConfig, } ) @@ -309,7 +289,6 @@ func (h *integrationHarness) joinNetwork( h.ctx, logger, getEnvBinDirPath(), - networkConfig, joiningBootstrap, stateDir, runtimeDir, @@ -319,16 +298,7 @@ func (h *integrationHarness) joinNetwork( t.Fatalf("joining network: %v", err) } - if !opts.manualShutdown { - t.Cleanup(func() { - t.Logf("Shutting down Network %q", hostNameStr) - if err := joinedNetwork.Shutdown(); err != nil { - t.Logf("Shutting down Network %q failed: %v", hostNameStr, err) - } - }) - } - - return integrationHarnessNetwork{ + nh := &integrationHarnessNetwork{ joinedNetwork, h.ctx, logger, @@ -337,6 +307,34 @@ func (h *integrationHarness) joinNetwork( runtimeDir, networkOpts, } + + if !opts.manualShutdown { + t.Cleanup(func() { + t.Logf("Shutting down Network %q", hostNameStr) + if err := nh.Shutdown(); err != nil { + t.Logf("Shutting down Network %q failed: %v", hostNameStr, err) + } + }) + } + + return nh +} + +func (nh *integrationHarnessNetwork) restart(t *testing.T) { + t.Log("Shutting down network (restart)") + require.NoError(t, nh.Network.Shutdown()) + + t.Log("Loading network (restart)") + var err error + nh.Network, err = load( + nh.ctx, + nh.logger, + getEnvBinDirPath(), + nh.stateDir, + nh.runtimeDir, + nh.opts, + ) + require.NoError(t, err) } func (nh *integrationHarnessNetwork) getConfig(t *testing.T) daecommon.NetworkConfig { diff --git a/go/daemon/rpc.go b/go/daemon/rpc.go index 28f2863..698cc62 100644 --- a/go/daemon/rpc.go +++ b/go/daemon/rpc.go @@ -3,6 +3,7 @@ package daemon import ( "context" "isle/bootstrap" + "isle/daemon/daecommon" "isle/daemon/jsonrpc2" "isle/daemon/network" "isle/nebula" @@ -25,13 +26,23 @@ type RPC interface { GetNetworks(context.Context) ([]bootstrap.CreationParams, error) + // SetConfig extends the [network.RPC] method of the same name such that + // [ErrUserManagedNetworkConfig] is returned if the picked network is + // configured as part of the [daecommon.Config] which the Daemon was + // initialized with. + // + // See the `network.RPC` documentation in this interface for more usage + // details. + SetConfig(context.Context, daecommon.NetworkConfig) error + // All network.RPC methods are automatically implemented by Daemon using the // currently joined network. If no network is joined then any call to these // methods will return ErrNoNetwork. // - // All calls to these methods must be accompanied with a context produced by - // WithNetwork, in order to choose the network. These methods may return - // these errors, in addition to those documented on the individual methods: + // If more than one Network is joined then all calls to these methods must + // be accompanied with a context produced by WithNetwork, in order to choose + // the network. These methods may return these errors, in addition to those + // documented on the individual methods: // // Errors: // - ErrNoNetwork diff --git a/go/toolkit/testutils.go b/go/toolkit/testutils.go index 3cc1434..8efac13 100644 --- a/go/toolkit/testutils.go +++ b/go/toolkit/testutils.go @@ -6,6 +6,7 @@ import ( "testing" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" + "github.com/stretchr/testify/mock" ) // MarkIntegrationTest marks a test as being an integration test. It will be @@ -31,3 +32,10 @@ func NewTestLogger(t *testing.T) *mlog.Logger { MaxLevel: level.Int(), }) } + +// MockArg returns a value which can be used as a [mock.Call] argument, and +// which will match any value of type T. If T is an interface then also values +// implementing that interface will be matched. +func MockArg[T any]() any { + return mock.MatchedBy(func(T) bool { return true }) +} diff --git a/go/yamlutil/testutils.go b/go/yamlutil/testutils.go new file mode 100644 index 0000000..57dfbf6 --- /dev/null +++ b/go/yamlutil/testutils.go @@ -0,0 +1,29 @@ +package yamlutil + +import ( + "strings" +) + +// ReplacePrefixTabs will replace the leading tabs of each line with two spaces. +// Tabs are not a valid indent in yaml (wtf), but they are convenient to use in +// go when using multi-line strings. +// +// ReplacePrefixTabs should only be used within tests. +func ReplacePrefixTabs(str string) string { + lines := strings.Split(str, "\n") + for i := range lines { + var n int + for _, r := range lines[i] { + if r == '\t' { + n++ + } else { + break + } + } + + spaces := strings.Repeat(" ", n) + lines[i] = spaces + lines[i][n:] + } + + return strings.Join(lines, "\n") +}