Pass NetworkConfig into Network loaders as an optional argument

This commit is contained in:
Brian Picciano 2024-11-11 15:32:15 +01:00
parent 72bca72b29
commit 6ec56f2a88
16 changed files with 691 additions and 264 deletions

View File

@ -44,22 +44,20 @@ var HTTPSocketPath = sync.OnceValue(func() string {
func pickNetworkConfig( func pickNetworkConfig(
daemonConfig daecommon.Config, creationParams bootstrap.CreationParams, daemonConfig daecommon.Config, creationParams bootstrap.CreationParams,
) ( ) *daecommon.NetworkConfig {
daecommon.NetworkConfig, bool,
) {
if len(daemonConfig.Networks) == 1 { // DEPRECATED if len(daemonConfig.Networks) == 1 { // DEPRECATED
if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok { if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok {
return c, true return &c
} }
} }
for searchStr, networkConfig := range daemonConfig.Networks { for searchStr, networkConfig := range daemonConfig.Networks {
if creationParams.Matches(searchStr) { if creationParams.Matches(searchStr) {
return networkConfig, true return &networkConfig
} }
} }
return daecommon.NetworkConfig{}, false return nil
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -86,6 +86,18 @@ type NetworkConfig struct {
} `yaml:"storage"` } `yaml:"storage"`
} }
// NewNetworkConfig returns a new NetworkConfig populated with its default
// values. If a callback is given the NetworkConfig will be passed to it for
// modification prior to having defaults populated.
func NewNetworkConfig(fn func(*NetworkConfig)) NetworkConfig {
var c NetworkConfig
if fn != nil {
fn(&c)
}
c.fillDefaults()
return c
}
func (c *NetworkConfig) fillDefaults() { func (c *NetworkConfig) fillDefaults() {
if c.DNS.Resolvers == nil { if c.DNS.Resolvers == nil {
c.DNS.Resolvers = []string{ c.DNS.Resolvers = []string{
@ -213,9 +225,15 @@ func CopyDefaultConfig(into io.Writer) error {
// fill in default values where it can. // fill in default values where it can.
func (c *Config) UnmarshalYAML(n *yaml.Node) error { func (c *Config) UnmarshalYAML(n *yaml.Node) error {
{ // DEPRECATED { // DEPRECATED
// We decode into a wrapped NetworkConfig, so that its UnmarshalYAML
// doesn't get invoked. This prevents fillDefaults from getting
// automatically called, so the IsZero check will work as intended. This
// means we need to call fillDefaults manually though.
type wrap NetworkConfig
var networkConfig NetworkConfig var networkConfig NetworkConfig
_ = n.Decode(&networkConfig) _ = n.Decode((*wrap)(&networkConfig))
if !toolkit.IsZero(networkConfig) { if !toolkit.IsZero(networkConfig) {
networkConfig.fillDefaults()
*c = Config{ *c = Config{
Networks: map[string]NetworkConfig{ Networks: map[string]NetworkConfig{
DeprecatedNetworkID: networkConfig, DeprecatedNetworkID: networkConfig,

View File

@ -0,0 +1,72 @@
package daecommon
import (
"isle/yamlutil"
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation
tests := []struct {
label string
str string
config Config
}{
{"empty", ``, Config{}},
{
"DEPRECATED single global network",
`
{"dns":{"resolvers":["a"]}}
`,
Config{
Networks: map[string]NetworkConfig{
DeprecatedNetworkID: NewNetworkConfig(func(c *NetworkConfig) {
c.DNS.Resolvers = []string{"a"}
}),
},
},
},
{
"single network",
`
networks:
foo: {"dns":{"resolvers":["a"]}}
`,
Config{
Networks: map[string]NetworkConfig{
"foo": NewNetworkConfig(func(c *NetworkConfig) {
c.DNS.Resolvers = []string{"a"}
}),
},
},
},
{
"multiple networks",
`
networks:
foo: {"dns":{"resolvers":["a"]}}
bar: {}
`,
Config{
Networks: map[string]NetworkConfig{
"foo": NewNetworkConfig(func(c *NetworkConfig) {
c.DNS.Resolvers = []string{"a"}
}),
"bar": NewNetworkConfig(nil),
},
},
},
}
for _, test := range tests {
t.Run(test.label, func(t *testing.T) {
test.str = yamlutil.ReplacePrefixTabs(test.str)
var config Config
err := yaml.Unmarshal([]byte(test.str), &config)
assert.NoError(t, err)
assert.Equal(t, test.config, config)
})
}
}

View File

@ -19,6 +19,11 @@ import (
var _ RPC = (*Daemon)(nil) var _ RPC = (*Daemon)(nil)
type joinedNetwork struct {
network.Network
config *daecommon.NetworkConfig
}
// Daemon implements all methods of the Daemon interface, plus others used // Daemon implements all methods of the Daemon interface, plus others used
// to manage this particular implementation. // to manage this particular implementation.
// //
@ -41,7 +46,7 @@ type Daemon struct {
daemonConfig daecommon.Config daemonConfig daecommon.Config
l sync.RWMutex l sync.RWMutex
networks map[string]network.Network networks map[string]joinedNetwork
} }
// New initializes and returns a Daemon. // New initializes and returns a Daemon.
@ -57,7 +62,7 @@ func New(
logger: logger, logger: logger,
networkLoader: networkLoader, networkLoader: networkLoader,
daemonConfig: daemonConfig, daemonConfig: daemonConfig,
networks: map[string]network.Network{}, networks: map[string]joinedNetwork{},
} }
loadableNetworks, err := networkLoader.Loadable(ctx) loadableNetworks, err := networkLoader.Loadable(ctx)
@ -69,20 +74,23 @@ func New(
ctx = mctx.WithAnnotator(ctx, creationParams) ctx = mctx.WithAnnotator(ctx, creationParams)
var ( var (
id = creationParams.ID id = creationParams.ID
networkConfig, _ = pickNetworkConfig(daemonConfig, creationParams) networkConfig = pickNetworkConfig(daemonConfig, creationParams)
) )
d.networks[id], err = networkLoader.Load( n, err := networkLoader.Load(
ctx, ctx,
logger.WithNamespace("network"), logger.WithNamespace("network"),
networkConfig,
creationParams, creationParams,
nil, &network.Opts{
Config: networkConfig,
},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("loading network %q: %w", id, err) return nil, fmt.Errorf("loading network %q: %w", id, err)
} }
d.networks[id] = joinedNetwork{n, networkConfig}
} }
return d, nil return d, nil
@ -105,17 +113,16 @@ func (d *Daemon) CreateNetwork(
ctx context.Context, ctx context.Context,
name, domain string, ipNet nebula.IPNet, hostName nebula.HostName, name, domain string, ipNet nebula.IPNet, hostName nebula.HostName,
) error { ) error {
creationParams := bootstrap.NewCreationParams(name, domain)
ctx = mctx.WithAnnotator(ctx, creationParams)
networkConfig, ok := pickNetworkConfig(d.daemonConfig, creationParams)
if !ok {
return errors.New("couldn't find network config for network being created")
}
d.l.Lock() d.l.Lock()
defer d.l.Unlock() defer d.l.Unlock()
var (
creationParams = bootstrap.NewCreationParams(name, domain)
networkConfig = pickNetworkConfig(d.daemonConfig, creationParams)
)
ctx = mctx.WithAnnotator(ctx, creationParams)
if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil { if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil {
return fmt.Errorf("checking if already joined to network: %w", err) return fmt.Errorf("checking if already joined to network: %w", err)
} else if joined { } else if joined {
@ -126,18 +133,19 @@ func (d *Daemon) CreateNetwork(
n, err := d.networkLoader.Create( n, err := d.networkLoader.Create(
ctx, ctx,
d.logger.WithNamespace("network"), d.logger.WithNamespace("network"),
networkConfig,
creationParams, creationParams,
ipNet, ipNet,
hostName, hostName,
nil, &network.Opts{
Config: networkConfig,
},
) )
if err != nil { if err != nil {
return fmt.Errorf("creating network: %w", err) return fmt.Errorf("creating network: %w", err)
} }
d.logger.Info(ctx, "Network created successfully") d.logger.Info(ctx, "Network created successfully")
d.networks[creationParams.ID] = n d.networks[creationParams.ID] = joinedNetwork{n, networkConfig}
return nil return nil
} }
@ -149,17 +157,17 @@ func (d *Daemon) CreateNetwork(
func (d *Daemon) JoinNetwork( func (d *Daemon) JoinNetwork(
ctx context.Context, newBootstrap network.JoiningBootstrap, ctx context.Context, newBootstrap network.JoiningBootstrap,
) error { ) error {
d.l.Lock()
defer d.l.Unlock()
var ( var (
creationParams = newBootstrap.Bootstrap.NetworkCreationParams creationParams = newBootstrap.Bootstrap.NetworkCreationParams
networkConfig, _ = pickNetworkConfig(d.daemonConfig, creationParams) networkID = creationParams.ID
networkID = creationParams.ID networkConfig = pickNetworkConfig(d.daemonConfig, creationParams)
) )
ctx = mctx.WithAnnotator(ctx, newBootstrap.Bootstrap.NetworkCreationParams) ctx = mctx.WithAnnotator(ctx, newBootstrap.Bootstrap.NetworkCreationParams)
d.l.Lock()
defer d.l.Unlock()
if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil { if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil {
return fmt.Errorf("checking if already joined to network: %w", err) return fmt.Errorf("checking if already joined to network: %w", err)
} else if joined { } else if joined {
@ -170,9 +178,10 @@ func (d *Daemon) JoinNetwork(
n, err := d.networkLoader.Join( n, err := d.networkLoader.Join(
ctx, ctx,
d.logger.WithNamespace("network"), d.logger.WithNamespace("network"),
networkConfig,
newBootstrap, newBootstrap,
nil, &network.Opts{
Config: networkConfig,
},
) )
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
@ -181,14 +190,14 @@ func (d *Daemon) JoinNetwork(
} }
d.logger.Info(ctx, "Network joined successfully") d.logger.Info(ctx, "Network joined successfully")
d.networks[networkID] = n d.networks[networkID] = joinedNetwork{n, networkConfig}
return nil return nil
} }
func withNetwork[Res any]( func withNetwork[Res any](
ctx context.Context, ctx context.Context,
d *Daemon, d *Daemon,
fn func(context.Context, network.Network) (Res, error), fn func(context.Context, joinedNetwork) (Res, error),
) ( ) (
Res, error, Res, error,
) { ) {
@ -238,7 +247,7 @@ func (d *Daemon) GetHosts(ctx context.Context) ([]bootstrap.Host, error) {
return withNetwork( return withNetwork(
ctx, ctx,
d, d,
func(ctx context.Context, n network.Network) ([]bootstrap.Host, error) { func(ctx context.Context, n joinedNetwork) ([]bootstrap.Host, error) {
return n.GetHosts(ctx) return n.GetHosts(ctx)
}, },
) )
@ -254,7 +263,7 @@ func (d *Daemon) GetGarageClientParams(
ctx, ctx,
d, d,
func( func(
ctx context.Context, n network.Network, ctx context.Context, n joinedNetwork,
) ( ) (
network.GarageClientParams, error, network.GarageClientParams, error,
) { ) {
@ -274,7 +283,7 @@ func (d *Daemon) GetNebulaCAPublicCredentials(
ctx, ctx,
d, d,
func( func(
ctx context.Context, n network.Network, ctx context.Context, n joinedNetwork,
) ( ) (
nebula.CAPublicCredentials, error, nebula.CAPublicCredentials, error,
) { ) {
@ -289,7 +298,7 @@ func (d *Daemon) RemoveHost(ctx context.Context, hostName nebula.HostName) error
ctx, ctx,
d, d,
func( func(
ctx context.Context, n network.Network, ctx context.Context, n joinedNetwork,
) ( ) (
struct{}, error, struct{}, error,
) { ) {
@ -311,7 +320,7 @@ func (d *Daemon) CreateHost(
ctx, ctx,
d, d,
func( func(
ctx context.Context, n network.Network, ctx context.Context, n joinedNetwork,
) ( ) (
network.JoiningBootstrap, error, network.JoiningBootstrap, error,
) { ) {
@ -332,7 +341,7 @@ func (d *Daemon) CreateNebulaCertificate(
ctx, ctx,
d, d,
func( func(
ctx context.Context, n network.Network, ctx context.Context, n joinedNetwork,
) ( ) (
nebula.Certificate, error, nebula.Certificate, error,
) { ) {
@ -341,6 +350,7 @@ func (d *Daemon) CreateNebulaCertificate(
) )
} }
// GetConfig implements the method for the network.RPC interface.
func (d *Daemon) GetConfig( func (d *Daemon) GetConfig(
ctx context.Context, ctx context.Context,
) ( ) (
@ -350,7 +360,7 @@ func (d *Daemon) GetConfig(
ctx, ctx,
d, d,
func( func(
ctx context.Context, n network.Network, ctx context.Context, n joinedNetwork,
) ( ) (
daecommon.NetworkConfig, error, daecommon.NetworkConfig, error,
) { ) {
@ -359,13 +369,18 @@ func (d *Daemon) GetConfig(
) )
} }
// SetConfig implements the method for the network.RPC interface.
func (d *Daemon) SetConfig( func (d *Daemon) SetConfig(
ctx context.Context, config daecommon.NetworkConfig, ctx context.Context, config daecommon.NetworkConfig,
) error { ) error {
_, err := withNetwork( _, err := withNetwork(
ctx, ctx,
d, d,
func(ctx context.Context, n network.Network) (struct{}, error) { func(ctx context.Context, n joinedNetwork) (struct{}, error) {
if n.config != nil {
return struct{}{}, ErrUserManagedNetworkConfig
}
// TODO needs to check that public addresses aren't being shared // TODO needs to check that public addresses aren't being shared
// across networks, and whatever else happens in Config.Validate. // across networks, and whatever else happens in Config.Validate.
return struct{}{}, n.SetConfig(ctx, config) return struct{}{}, n.SetConfig(ctx, config)

227
go/daemon/daemon_test.go Normal file
View File

@ -0,0 +1,227 @@
package daemon
import (
"context"
"isle/bootstrap"
"isle/daemon/daecommon"
"isle/daemon/network"
"isle/toolkit"
"testing"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type expectNetworkLoad struct {
creationParams bootstrap.CreationParams
networkConfig *daecommon.NetworkConfig
network *network.MockNetwork
}
type harnessOpts struct {
config daecommon.Config
expectNetworksLoaded []expectNetworkLoad
}
func (o *harnessOpts) withDefaults() *harnessOpts {
if o == nil {
o = new(harnessOpts)
}
return o
}
type harness struct {
ctx context.Context
networkLoader *network.MockLoader
daemon *Daemon
}
func newHarness(t *testing.T, opts *harnessOpts) *harness {
t.Parallel()
opts = opts.withDefaults()
var (
ctx = context.Background()
logger = toolkit.NewTestLogger(t)
networkLoader = network.NewMockLoader(t)
)
expectLoadable := make(
[]bootstrap.CreationParams, len(opts.expectNetworksLoaded),
)
for i, expectNetworkLoaded := range opts.expectNetworksLoaded {
expectLoadable[i] = expectNetworkLoaded.creationParams
networkLoader.
On(
"Load",
toolkit.MockArg[context.Context](),
toolkit.MockArg[*mlog.Logger](),
expectNetworkLoaded.creationParams,
&network.Opts{
Config: expectNetworkLoaded.networkConfig,
},
).
Return(expectNetworkLoaded.network, nil).
Once()
expectNetworkLoaded.network.On("Shutdown").Return(nil).Once()
}
networkLoader.
On("Loadable", toolkit.MockArg[context.Context]()).
Return(expectLoadable, nil).
Once()
daemon, err := New(ctx, logger, networkLoader, opts.config)
require.NoError(t, err)
t.Cleanup(func() {
t.Log("Shutting down Daemon")
assert.NoError(t, daemon.Shutdown())
})
return &harness{ctx, networkLoader, daemon}
}
func TestNew(t *testing.T) {
t.Run("no networks loaded", func(t *testing.T) {
_ = newHarness(t, nil)
})
t.Run("DEPRECATED network config matching", func(t *testing.T) {
var (
creationParams = bootstrap.NewCreationParams("AAA", "a.com")
networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.DNS.Resolvers = []string{"foo"}
})
config = daecommon.Config{
Networks: map[string]daecommon.NetworkConfig{
daecommon.DeprecatedNetworkID: networkConfig,
},
}
)
_ = newHarness(t, &harnessOpts{
config: config,
expectNetworksLoaded: []expectNetworkLoad{
{
creationParams,
&networkConfig,
network.NewMockNetwork(t),
},
},
})
})
t.Run("network config matching", func(t *testing.T) {
var (
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
creationParamsB = bootstrap.NewCreationParams("BBB", "b.com")
creationParamsC = bootstrap.NewCreationParams("CCC", "c.com")
creationParamsD = bootstrap.NewCreationParams("DDD", "d.com")
networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.DNS.Resolvers = []string{"foo"}
})
networkConfigB = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.VPN.Tun.Device = "bar"
})
networkConfigC = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.Storage.Allocations = []daecommon.ConfigStorageAllocation{
{
DataPath: "/path/data",
MetaPath: "/path/meta",
Capacity: 1,
},
}
})
config = daecommon.Config{
Networks: map[string]daecommon.NetworkConfig{
creationParamsA.ID: networkConfigA,
creationParamsB.Name: networkConfigB,
creationParamsC.Domain: networkConfigC,
"unknown": {},
},
}
)
_ = newHarness(t, &harnessOpts{
config: config,
expectNetworksLoaded: []expectNetworkLoad{
{
creationParamsA,
&networkConfigA,
network.NewMockNetwork(t),
},
{
creationParamsB,
&networkConfigB,
network.NewMockNetwork(t),
},
{
creationParamsC,
&networkConfigC,
network.NewMockNetwork(t),
},
{
creationParamsD,
nil,
network.NewMockNetwork(t),
},
},
})
})
}
func TestDaemon_SetConfig(t *testing.T) {
t.Run("success", func(t *testing.T) {
var (
networkA = network.NewMockNetwork(t)
h = newHarness(t, &harnessOpts{
expectNetworksLoaded: []expectNetworkLoad{{
bootstrap.NewCreationParams("AAA", "a.com"), nil, networkA,
}},
})
networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
c.VPN.Tun.Device = "foo"
})
)
networkA.
On("SetConfig", toolkit.MockArg[context.Context](), networkConfig).
Return(nil).
Once()
err := h.daemon.SetConfig(h.ctx, networkConfig)
assert.NoError(t, err)
})
t.Run("ErrUserManagedNetworkConfig", func(t *testing.T) {
var (
creationParams = bootstrap.NewCreationParams("AAA", "a.com")
networkA = network.NewMockNetwork(t)
networkConfig = daecommon.NewNetworkConfig(nil)
h = newHarness(t, &harnessOpts{
config: daecommon.Config{
Networks: map[string]daecommon.NetworkConfig{
creationParams.Name: networkConfig,
},
},
expectNetworksLoaded: []expectNetworkLoad{
{creationParams, &networkConfig, networkA},
},
})
)
networkConfig.VPN.Tun.Device = "foo"
err := h.daemon.SetConfig(h.ctx, networkConfig)
assert.ErrorIs(t, err, ErrUserManagedNetworkConfig)
})
}

View File

@ -10,6 +10,7 @@ const (
errCodeAlreadyJoined errCodeAlreadyJoined
errCodeNoMatchingNetworks errCodeNoMatchingNetworks
errCodeMultipleMatchingNetworks errCodeMultipleMatchingNetworks
errCodeUserManagedNetworkConfig
) )
var ( var (
@ -33,4 +34,11 @@ var (
errCodeMultipleMatchingNetworks, errCodeMultipleMatchingNetworks,
"Multiple networks matched the search string", "Multiple networks matched the search string",
) )
// ErrUserManagedNetworkConfig is returned when attempting to modify a
// network config which is managed by the user.
ErrUserManagedNetworkConfig = jsonrpc2.NewError(
errCodeUserManagedNetworkConfig,
"Network configuration is managed by the user",
)
) )

View File

@ -10,24 +10,30 @@ import (
func pickNetwork( func pickNetwork(
ctx context.Context, ctx context.Context,
networkLoader network.Loader, networkLoader network.Loader,
networks map[string]network.Network, networks map[string]joinedNetwork,
) ( ) (
network.Network, error, joinedNetwork, error,
) { ) {
if len(networks) == 0 { if len(networks) == 0 {
return nil, ErrNoNetwork return joinedNetwork{}, ErrNoNetwork
}
networkSearchStr := getNetworkSearchStr(ctx)
if networkSearchStr == "" {
if len(networks) > 1 {
return joinedNetwork{}, ErrNoMatchingNetworks
}
for _, network := range networks {
return network, nil
}
} }
creationParams, err := networkLoader.Loadable(ctx) creationParams, err := networkLoader.Loadable(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting loadable networks: %w", err) return joinedNetwork{}, fmt.Errorf("getting loadable networks: %w", err)
} }
var ( matchingNetworkIDs := make([]string, 0, len(networks))
networkSearchStr = getNetworkSearchStr(ctx)
matchingNetworkIDs = make([]string, 0, len(networks))
)
for _, creationParam := range creationParams { for _, creationParam := range creationParams {
if networkSearchStr == "" || creationParam.Matches(networkSearchStr) { if networkSearchStr == "" || creationParam.Matches(networkSearchStr) {
matchingNetworkIDs = append(matchingNetworkIDs, creationParam.ID) matchingNetworkIDs = append(matchingNetworkIDs, creationParam.ID)
@ -35,9 +41,9 @@ func pickNetwork(
} }
if len(matchingNetworkIDs) == 0 { if len(matchingNetworkIDs) == 0 {
return nil, ErrNoMatchingNetworks return joinedNetwork{}, ErrNoMatchingNetworks
} else if len(matchingNetworkIDs) > 1 { } else if len(matchingNetworkIDs) > 1 {
return nil, ErrMultipleMatchingNetworks return joinedNetwork{}, ErrMultipleMatchingNetworks
} }
return networks[matchingNetworkIDs[0]], nil return networks[matchingNetworkIDs[0]], nil
@ -45,7 +51,7 @@ func pickNetwork(
func alreadyJoined( func alreadyJoined(
ctx context.Context, ctx context.Context,
networks map[string]network.Network, networks map[string]joinedNetwork,
creationParams bootstrap.CreationParams, creationParams bootstrap.CreationParams,
) ( ) (
bool, error, bool, error,

View File

@ -0,0 +1,42 @@
package network
import (
"errors"
"fmt"
"io/fs"
"isle/daemon/daecommon"
"isle/jsonutil"
"isle/toolkit"
"path/filepath"
)
// loadStoreConfig writes the given NetworkConfig to the networkStateDir if the
// config is non-nil. If the config is nil then a config is read from
// networkStateDir, returning the zero value if no config was previously
// stored.
func loadStoreConfig(
networkStateDir toolkit.Dir, config *daecommon.NetworkConfig,
) (
daecommon.NetworkConfig, error,
) {
path := filepath.Join(networkStateDir.Path, "config.json")
if config == nil {
config = new(daecommon.NetworkConfig)
err := jsonutil.LoadFile(&config, path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return daecommon.NetworkConfig{}, fmt.Errorf(
"loading %q: %w", path, err,
)
}
return *config, nil
}
if err := jsonutil.WriteFile(*config, path, 0600); err != nil {
return daecommon.NetworkConfig{}, fmt.Errorf(
"writing to %q: %w", path, err,
)
}
return *config, nil
}

View File

@ -59,7 +59,6 @@ type Loader interface {
Load( Load(
context.Context, context.Context,
*mlog.Logger, *mlog.Logger,
daecommon.NetworkConfig,
bootstrap.CreationParams, bootstrap.CreationParams,
*Opts, *Opts,
) ( ) (
@ -73,7 +72,6 @@ type Loader interface {
Join( Join(
context.Context, context.Context,
*mlog.Logger, *mlog.Logger,
daecommon.NetworkConfig,
JoiningBootstrap, JoiningBootstrap,
*Opts, *Opts,
) ( ) (
@ -91,12 +89,11 @@ type Loader interface {
// - hostName: The name of this first host in the network. // - hostName: The name of this first host in the network.
// //
// Errors: // Errors:
// - ErrInvalidConfig - if daemonConfig doesn't have 3 storage allocations // - ErrInvalidConfig - If the Opts.Config field is not valid. It must be
// configured. // non-nil and have at least 3 storage allocations.
Create( Create(
context.Context, context.Context,
*mlog.Logger, *mlog.Logger,
daecommon.NetworkConfig,
bootstrap.CreationParams, bootstrap.CreationParams,
nebula.IPNet, nebula.IPNet,
nebula.HostName, nebula.HostName,
@ -188,7 +185,7 @@ func (l *loader) Loadable(
creationParams := make([]bootstrap.CreationParams, 0, len(networkStateDirs)) creationParams := make([]bootstrap.CreationParams, 0, len(networkStateDirs))
for _, networkStateDir := range networkStateDirs { for _, networkStateDir := range networkStateDirs {
thisCreationParams, err := LoadCreationParams(networkStateDir) thisCreationParams, err := loadCreationParams(networkStateDir)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"loading creation params from %q: %w", "loading creation params from %q: %w",
@ -205,7 +202,6 @@ func (l *loader) Loadable(
func (l *loader) Load( func (l *loader) Load(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,
networkConfig daecommon.NetworkConfig,
creationParams bootstrap.CreationParams, creationParams bootstrap.CreationParams,
opts *Opts, opts *Opts,
) ( ) (
@ -226,7 +222,6 @@ func (l *loader) Load(
ctx, ctx,
logger.WithNamespace("network"), logger.WithNamespace("network"),
l.envBinDirPath, l.envBinDirPath,
networkConfig,
networkStateDir, networkStateDir,
networkRuntimeDir, networkRuntimeDir,
opts, opts,
@ -236,7 +231,6 @@ func (l *loader) Load(
func (l *loader) Join( func (l *loader) Join(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,
networkConfig daecommon.NetworkConfig,
joiningBootstrap JoiningBootstrap, joiningBootstrap JoiningBootstrap,
opts *Opts, opts *Opts,
) ( ) (
@ -260,7 +254,6 @@ func (l *loader) Join(
ctx, ctx,
logger.WithNamespace("network"), logger.WithNamespace("network"),
l.envBinDirPath, l.envBinDirPath,
networkConfig,
joiningBootstrap, joiningBootstrap,
networkStateDir, networkStateDir,
networkRuntimeDir, networkRuntimeDir,
@ -271,7 +264,6 @@ func (l *loader) Join(
func (l *loader) Create( func (l *loader) Create(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,
networkConfig daecommon.NetworkConfig,
creationParams bootstrap.CreationParams, creationParams bootstrap.CreationParams,
ipNet nebula.IPNet, ipNet nebula.IPNet,
hostName nebula.HostName, hostName nebula.HostName,
@ -294,7 +286,6 @@ func (l *loader) Create(
ctx, ctx,
logger.WithNamespace("network"), logger.WithNamespace("network"),
l.envBinDirPath, l.envBinDirPath,
networkConfig,
networkStateDir, networkStateDir,
networkRuntimeDir, networkRuntimeDir,
creationParams, creationParams,

View File

@ -6,8 +6,6 @@ import (
context "context" context "context"
bootstrap "isle/bootstrap" bootstrap "isle/bootstrap"
daecommon "isle/daemon/daecommon"
mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog" mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
@ -20,9 +18,9 @@ type MockLoader struct {
mock.Mock mock.Mock
} }
// Create provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5, _a6 // Create provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5
func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 bootstrap.CreationParams, _a4 nebula.IPNet, _a5 nebula.HostName, _a6 *Opts) (Network, error) { func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 bootstrap.CreationParams, _a3 nebula.IPNet, _a4 nebula.HostName, _a5 *Opts) (Network, error) {
ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5, _a6) ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Create") panic("no return value specified for Create")
@ -30,19 +28,19 @@ func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommo
var r0 Network var r0 Network
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) (Network, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) (Network, error)); ok {
return rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6) return rf(_a0, _a1, _a2, _a3, _a4, _a5)
} }
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) Network); ok { if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) Network); ok {
r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6) r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(Network) r0 = ret.Get(0).(Network)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) error); ok {
r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6) r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -50,9 +48,9 @@ func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommo
return r0, r1 return r0, r1
} }
// Join provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4 // Join provides a mock function with given fields: _a0, _a1, _a2, _a3
func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 JoiningBootstrap, _a4 *Opts) (Network, error) { func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 JoiningBootstrap, _a3 *Opts) (Network, error) {
ret := _m.Called(_a0, _a1, _a2, _a3, _a4) ret := _m.Called(_a0, _a1, _a2, _a3)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Join") panic("no return value specified for Join")
@ -60,19 +58,19 @@ func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.
var r0 Network var r0 Network
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) (Network, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) (Network, error)); ok {
return rf(_a0, _a1, _a2, _a3, _a4) return rf(_a0, _a1, _a2, _a3)
} }
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) Network); ok { if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) Network); ok {
r0 = rf(_a0, _a1, _a2, _a3, _a4) r0 = rf(_a0, _a1, _a2, _a3)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(Network) r0 = ret.Get(0).(Network)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) error); ok {
r1 = rf(_a0, _a1, _a2, _a3, _a4) r1 = rf(_a0, _a1, _a2, _a3)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -80,9 +78,9 @@ func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.
return r0, r1 return r0, r1
} }
// Load provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4 // Load provides a mock function with given fields: _a0, _a1, _a2, _a3
func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 bootstrap.CreationParams, _a4 *Opts) (Network, error) { func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 bootstrap.CreationParams, _a3 *Opts) (Network, error) {
ret := _m.Called(_a0, _a1, _a2, _a3, _a4) ret := _m.Called(_a0, _a1, _a2, _a3)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Load") panic("no return value specified for Load")
@ -90,19 +88,19 @@ func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.
var r0 Network var r0 Network
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) (Network, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) (Network, error)); ok {
return rf(_a0, _a1, _a2, _a3, _a4) return rf(_a0, _a1, _a2, _a3)
} }
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) Network); ok { if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) Network); ok {
r0 = rf(_a0, _a1, _a2, _a3, _a4) r0 = rf(_a0, _a1, _a2, _a3)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(Network) r0 = ret.Get(0).(Network)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) error); ok {
r1 = rf(_a0, _a1, _a2, _a3, _a4) r1 = rf(_a0, _a1, _a2, _a3)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }

View File

@ -154,6 +154,14 @@ type Network interface {
// Network instance. A nil Opts is equivalent to a zero value. // Network instance. A nil Opts is equivalent to a zero value.
type Opts struct { type Opts struct {
GarageAdminToken string // Will be randomly generated if left unset. GarageAdminToken string // Will be randomly generated if left unset.
// Config will be used as the configuration of the Network from its
// initialization onwards.
//
// If not given then the most recent NetworkConfig for the network will be
// used, either that which it was most recently initialized with or which
// was passed to [SetConfig].
Config *daecommon.NetworkConfig
} }
func (o *Opts) withDefaults() *Opts { func (o *Opts) withDefaults() *Opts {
@ -189,36 +197,53 @@ type network struct {
wg sync.WaitGroup wg sync.WaitGroup
} }
// instatiateNetwork returns an instantiated *network instance which has not yet // newNetwork returns an instantiated *network instance. All initialization
// been initialized. // steps which are common to all *network creation methods (load, join, create)
func instatiateNetwork( // are included here as well.
func newNetwork(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,
networkConfig daecommon.NetworkConfig,
envBinDirPath string, envBinDirPath string,
stateDir toolkit.Dir, stateDir toolkit.Dir,
runtimeDir toolkit.Dir, runtimeDir toolkit.Dir,
dirsMayExist bool,
opts *Opts, opts *Opts,
) *network { ) (
ctx = context.WithoutCancel(ctx) *network, error,
ctx, cancel := context.WithCancel(ctx) ) {
return &network{ ctx, cancel := context.WithCancel(context.WithoutCancel(ctx))
logger: logger,
networkConfig: networkConfig, var (
envBinDirPath: envBinDirPath, n = &network{
stateDir: stateDir, logger: logger,
runtimeDir: runtimeDir, envBinDirPath: envBinDirPath,
opts: opts.withDefaults(), stateDir: stateDir,
workerCtx: ctx, runtimeDir: runtimeDir,
workerCancel: cancel, opts: opts.withDefaults(),
workerCtx: ctx,
workerCancel: cancel,
}
err error
)
n.networkConfig, err = loadStoreConfig(n.stateDir, n.opts.Config)
if err != nil {
return nil, fmt.Errorf("resolving network config: %w", err)
} }
secretsDir, err := n.stateDir.MkChildDir("secrets", dirsMayExist)
if err != nil {
return nil, fmt.Errorf("creating secrets dir: %w", err)
}
n.secretsStore = secrets.NewFSStore(secretsDir.Path)
return n, nil
} }
// LoadCreationParams returns the CreationParams of a Network which was // loadCreationParams returns the CreationParams of a Network which was
// Created/Joined with the given state directory. // Created/Joined with the given state directory.
// func loadCreationParams(
// TODO probably can be private
func LoadCreationParams(
stateDir toolkit.Dir, stateDir toolkit.Dir,
) ( ) (
bootstrap.CreationParams, error, bootstrap.CreationParams, error,
@ -244,25 +269,23 @@ func load(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,
envBinDirPath string, envBinDirPath string,
networkConfig daecommon.NetworkConfig,
stateDir toolkit.Dir, stateDir toolkit.Dir,
runtimeDir toolkit.Dir, runtimeDir toolkit.Dir,
opts *Opts, opts *Opts,
) ( ) (
Network, error, Network, error,
) { ) {
n := instatiateNetwork( n, err := newNetwork(
ctx, ctx,
logger, logger,
networkConfig,
envBinDirPath, envBinDirPath,
stateDir, stateDir,
runtimeDir, runtimeDir,
true,
opts, opts,
) )
if err != nil {
if err := n.initializeDirs(true); err != nil { return nil, fmt.Errorf("instantiating Network: %w", err)
return nil, fmt.Errorf("initializing directories: %w", err)
} }
var ( var (
@ -285,7 +308,6 @@ func join(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,
envBinDirPath string, envBinDirPath string,
networkConfig daecommon.NetworkConfig,
joiningBootstrap JoiningBootstrap, joiningBootstrap JoiningBootstrap,
stateDir toolkit.Dir, stateDir toolkit.Dir,
runtimeDir toolkit.Dir, runtimeDir toolkit.Dir,
@ -293,18 +315,17 @@ func join(
) ( ) (
Network, error, Network, error,
) { ) {
n := instatiateNetwork( n, err := newNetwork(
ctx, ctx,
logger, logger,
networkConfig,
envBinDirPath, envBinDirPath,
stateDir, stateDir,
runtimeDir, runtimeDir,
false,
opts, opts,
) )
if err != nil {
if err := n.initializeDirs(false); err != nil { return nil, fmt.Errorf("instantiating Network: %w", err)
return nil, fmt.Errorf("initializing directories: %w", err)
} }
if err := secrets.Import( if err := secrets.Import(
@ -324,7 +345,6 @@ func create(
ctx context.Context, ctx context.Context,
logger *mlog.Logger, logger *mlog.Logger,
envBinDirPath string, envBinDirPath string,
networkConfig daecommon.NetworkConfig,
stateDir toolkit.Dir, stateDir toolkit.Dir,
runtimeDir toolkit.Dir, runtimeDir toolkit.Dir,
creationParams bootstrap.CreationParams, creationParams bootstrap.CreationParams,
@ -334,12 +354,6 @@ func create(
) ( ) (
Network, error, Network, error,
) { ) {
if len(networkConfig.Storage.Allocations) < 3 {
return nil, ErrInvalidConfig.WithData(
"At least three storage allocations are required.",
)
}
nebulaCACreds, err := nebula.NewCACredentials(creationParams.Domain, ipNet) nebulaCACreds, err := nebula.NewCACredentials(creationParams.Domain, ipNet)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating nebula CA cert: %w", err) return nil, fmt.Errorf("creating nebula CA cert: %w", err)
@ -347,18 +361,23 @@ func create(
garageRPCSecret := toolkit.RandStr(32) garageRPCSecret := toolkit.RandStr(32)
n := instatiateNetwork( n, err := newNetwork(
ctx, ctx,
logger, logger,
networkConfig,
envBinDirPath, envBinDirPath,
stateDir, stateDir,
runtimeDir, runtimeDir,
false,
opts, opts,
) )
if err != nil {
return nil, fmt.Errorf("instantiating Network: %w", err)
}
if err := n.initializeDirs(false); err != nil { if len(n.networkConfig.Storage.Allocations) < 3 {
return nil, fmt.Errorf("initializing directories: %w", err) return nil, ErrInvalidConfig.WithData(
"At least three storage allocations are required.",
)
} }
err = daecommon.SetGarageRPCSecret(ctx, n.secretsStore, garageRPCSecret) err = daecommon.SetGarageRPCSecret(ctx, n.secretsStore, garageRPCSecret)
@ -391,16 +410,6 @@ func create(
return n, nil return n, nil
} }
func (n *network) initializeDirs(mayExist bool) error {
secretsDir, err := n.stateDir.MkChildDir("secrets", mayExist)
if err != nil {
return fmt.Errorf("creating secrets dir: %w", err)
}
n.secretsStore = secrets.NewFSStore(secretsDir.Path)
return nil
}
func (n *network) periodically( func (n *network) periodically(
label string, label string,
fn func(context.Context) error, fn func(context.Context) error,
@ -1016,6 +1025,10 @@ func (n *network) GetConfig(context.Context) (daecommon.NetworkConfig, error) {
func (n *network) SetConfig( func (n *network) SetConfig(
ctx context.Context, config daecommon.NetworkConfig, ctx context.Context, config daecommon.NetworkConfig,
) error { ) error {
if _, err := loadStoreConfig(n.stateDir, &config); err != nil {
return fmt.Errorf("storing new config: %w", err)
}
prevBootstrap, err := n.reload(ctx, &config, nil) prevBootstrap, err := n.reload(ctx, &config, nil)
if err != nil { if err != nil {
return fmt.Errorf("reloading config: %w", err) return fmt.Errorf("reloading config: %w", err)

View File

@ -17,7 +17,7 @@ func TestCreate(t *testing.T) {
network = h.createNetwork(t, "primus", nil) network = h.createNetwork(t, "primus", nil)
) )
gotCreationParams, err := LoadCreationParams(network.stateDir) gotCreationParams, err := loadCreationParams(network.stateDir)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal( assert.Equal(
t, gotCreationParams, network.getBootstrap(t).NetworkCreationParams, t, gotCreationParams, network.getBootstrap(t).NetworkCreationParams,
@ -25,31 +25,30 @@ func TestCreate(t *testing.T) {
} }
func TestLoad(t *testing.T) { func TestLoad(t *testing.T) {
var ( t.Run("given config", func(t *testing.T) {
h = newIntegrationHarness(t) var (
network = h.createNetwork(t, "primus", &createNetworkOpts{ h = newIntegrationHarness(t)
manualShutdown: true, network = h.createNetwork(t, "primus", nil)
}) networkConfig = network.getConfig(t)
) )
t.Log("Shutting down network") network.opts.Config = &networkConfig
assert.NoError(t, network.Shutdown()) network.restart(t)
t.Log("Calling Load") assert.Equal(t, networkConfig, network.getConfig(t))
loadedNetwork, err := load( })
h.ctx,
h.logger.WithNamespace("loadedNetwork"),
getEnvBinDirPath(),
network.getConfig(t),
network.stateDir,
h.mkDir(t, "runtime"),
network.opts,
)
assert.NoError(t, err)
t.Cleanup(func() { t.Run("load previous config", func(t *testing.T) {
t.Log("Shutting down loadedNetwork") var (
assert.NoError(t, loadedNetwork.Shutdown()) h = newIntegrationHarness(t)
network = h.createNetwork(t, "primus", nil)
networkConfig = network.getConfig(t)
)
network.opts.Config = nil
network.restart(t)
assert.Equal(t, networkConfig, network.getConfig(t))
}) })
} }
@ -61,13 +60,7 @@ func TestJoin(t *testing.T) {
secondus = h.joinNetwork(t, primus, "secondus", nil) secondus = h.joinNetwork(t, primus, "secondus", nil)
) )
primusHosts, err := primus.GetHosts(h.ctx) assert.Equal(t, primus.getHostsByName(t), secondus.getHostsByName(t))
assert.NoError(t, err)
secondusHosts, err := secondus.GetHosts(h.ctx)
assert.NoError(t, err)
assert.Equal(t, primusHosts, secondusHosts)
}) })
t.Run("with alloc", func(t *testing.T) { t.Run("with alloc", func(t *testing.T) {
@ -84,28 +77,10 @@ func TestJoin(t *testing.T) {
t.Log("reloading primus' hosts") t.Log("reloading primus' hosts")
assert.NoError(t, primus.Network.(*network).reloadHosts(h.ctx)) assert.NoError(t, primus.Network.(*network).reloadHosts(h.ctx))
primusHosts, err := primus.GetHosts(h.ctx) assert.Equal(t, primus.getHostsByName(t), secondus.getHostsByName(t))
assert.NoError(t, err)
secondusHosts, err := secondus.GetHosts(h.ctx)
assert.NoError(t, err)
assert.Equal(t, primusHosts, secondusHosts)
}) })
} }
func TestNetwork_GetConfig(t *testing.T) {
var (
h = newIntegrationHarness(t)
network = h.createNetwork(t, "primus", nil)
)
config, err := network.GetConfig(h.ctx)
assert.NoError(t, err)
assert.Equal(t, config, network.getConfig(t))
}
func TestNetwork_SetConfig(t *testing.T) { func TestNetwork_SetConfig(t *testing.T) {
allocsToRoles := func( allocsToRoles := func(
hostName nebula.HostName, allocs []bootstrap.GarageHostInstance, hostName nebula.HostName, allocs []bootstrap.GarageHostInstance,
@ -259,4 +234,22 @@ func TestNetwork_SetConfig(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotContains(t, layout.Roles, removedRole) assert.NotContains(t, layout.Roles, removedRole)
}) })
t.Run("changes reflected after restart", func(t *testing.T) {
var (
h = newIntegrationHarness(t)
network = h.createNetwork(t, "primus", &createNetworkOpts{
numStorageAllocs: 4,
})
networkConfig = network.getConfig(t)
)
networkConfig.Storage.Allocations = networkConfig.Storage.Allocations[:3]
assert.NoError(t, network.SetConfig(h.ctx, networkConfig))
network.opts.Config = nil
network.restart(t)
assert.Equal(t, networkConfig, network.getConfig(t))
})
} }

View File

@ -2,7 +2,6 @@ package network
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"isle/bootstrap" "isle/bootstrap"
"isle/daemon/daecommon" "isle/daemon/daecommon"
@ -18,7 +17,6 @@ import (
"dev.mediocregopher.com/mediocre-go-lib.git/mlog" "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
) )
// Utilities related to running network integration tests // Utilities related to running network integration tests
@ -62,16 +60,6 @@ func newTunDevice() string {
return fmt.Sprintf("isle-test-%d", atomic.AddUint64(&tunDeviceCounter, 1)) return fmt.Sprintf("isle-test-%d", atomic.AddUint64(&tunDeviceCounter, 1))
} }
func mustParseNetworkConfigf(str string, args ...any) daecommon.NetworkConfig {
str = fmt.Sprintf(str, args...)
var networkConfig daecommon.NetworkConfig
if err := yaml.Unmarshal([]byte(str), &networkConfig); err != nil {
panic(fmt.Sprintf("parsing network config: %v", err))
}
return networkConfig
}
type integrationHarness struct { type integrationHarness struct {
ctx context.Context ctx context.Context
logger *mlog.Logger logger *mlog.Logger
@ -136,35 +124,24 @@ func (h *integrationHarness) mkNetworkConfig(
opts = new(networkConfigOpts) opts = new(networkConfigOpts)
} }
publicAddr := "" return daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
if opts.hasPublicAddr { if opts.hasPublicAddr {
publicAddr = newPublicAddr() c.VPN.PublicAddr = newPublicAddr()
}
allocs := make([]map[string]any, opts.numStorageAllocs)
for i := range allocs {
allocs[i] = map[string]any{
"data_path": h.mkDir(t, "data").Path,
"meta_path": h.mkDir(t, "meta").Path,
"capacity": 1,
} }
}
allocsJSON, err := json.Marshal(allocs) c.VPN.Tun.Device = newTunDevice() // TODO is this necessary??
require.NoError(t, err)
return mustParseNetworkConfigf(` c.Storage.Allocations = make(
vpn: []daecommon.ConfigStorageAllocation, opts.numStorageAllocs,
public_addr: %q )
tun: for i := range c.Storage.Allocations {
device: %q c.Storage.Allocations[i] = daecommon.ConfigStorageAllocation{
storage: DataPath: h.mkDir(t, "data").Path,
allocations: %s MetaPath: h.mkDir(t, "meta").Path,
`, Capacity: 1,
publicAddr, }
newTunDevice(), }
allocsJSON, })
)
} }
type createNetworkOpts struct { type createNetworkOpts struct {
@ -203,7 +180,7 @@ func (h *integrationHarness) createNetwork(
t *testing.T, t *testing.T,
hostNameStr string, hostNameStr string,
opts *createNetworkOpts, opts *createNetworkOpts,
) integrationHarnessNetwork { ) *integrationHarnessNetwork {
t.Logf("Creating as %q", hostNameStr) t.Logf("Creating as %q", hostNameStr)
opts = opts.withDefaults() opts = opts.withDefaults()
@ -222,6 +199,7 @@ func (h *integrationHarness) createNetwork(
networkOpts = &Opts{ networkOpts = &Opts{
GarageAdminToken: "admin_token", GarageAdminToken: "admin_token",
Config: &networkConfig,
} }
) )
@ -229,7 +207,6 @@ func (h *integrationHarness) createNetwork(
h.ctx, h.ctx,
logger, logger,
getEnvBinDirPath(), getEnvBinDirPath(),
networkConfig,
stateDir, stateDir,
runtimeDir, runtimeDir,
opts.creationParams, opts.creationParams,
@ -241,16 +218,7 @@ func (h *integrationHarness) createNetwork(
t.Fatalf("creating Network: %v", err) t.Fatalf("creating Network: %v", err)
} }
if !opts.manualShutdown { nh := &integrationHarnessNetwork{
t.Cleanup(func() {
t.Logf("Shutting down Network %q", hostNameStr)
if err := network.Shutdown(); err != nil {
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
}
})
}
return integrationHarnessNetwork{
network, network,
h.ctx, h.ctx,
logger, logger,
@ -259,6 +227,17 @@ func (h *integrationHarness) createNetwork(
runtimeDir, runtimeDir,
networkOpts, networkOpts,
} }
if !opts.manualShutdown {
t.Cleanup(func() {
t.Logf("Shutting down Network %q", hostNameStr)
if err := nh.Shutdown(); err != nil {
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
}
})
}
return nh
} }
type joinNetworkOpts struct { type joinNetworkOpts struct {
@ -279,10 +258,10 @@ func (o *joinNetworkOpts) withDefaults() *joinNetworkOpts {
func (h *integrationHarness) joinNetwork( func (h *integrationHarness) joinNetwork(
t *testing.T, t *testing.T,
network integrationHarnessNetwork, network *integrationHarnessNetwork,
hostNameStr string, hostNameStr string,
opts *joinNetworkOpts, opts *joinNetworkOpts,
) integrationHarnessNetwork { ) *integrationHarnessNetwork {
opts = opts.withDefaults() opts = opts.withDefaults()
hostName := nebula.HostName(hostNameStr) hostName := nebula.HostName(hostNameStr)
@ -301,6 +280,7 @@ func (h *integrationHarness) joinNetwork(
runtimeDir = h.mkDir(t, "runtime") runtimeDir = h.mkDir(t, "runtime")
networkOpts = &Opts{ networkOpts = &Opts{
GarageAdminToken: "admin_token", GarageAdminToken: "admin_token",
Config: &networkConfig,
} }
) )
@ -309,7 +289,6 @@ func (h *integrationHarness) joinNetwork(
h.ctx, h.ctx,
logger, logger,
getEnvBinDirPath(), getEnvBinDirPath(),
networkConfig,
joiningBootstrap, joiningBootstrap,
stateDir, stateDir,
runtimeDir, runtimeDir,
@ -319,16 +298,7 @@ func (h *integrationHarness) joinNetwork(
t.Fatalf("joining network: %v", err) t.Fatalf("joining network: %v", err)
} }
if !opts.manualShutdown { nh := &integrationHarnessNetwork{
t.Cleanup(func() {
t.Logf("Shutting down Network %q", hostNameStr)
if err := joinedNetwork.Shutdown(); err != nil {
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
}
})
}
return integrationHarnessNetwork{
joinedNetwork, joinedNetwork,
h.ctx, h.ctx,
logger, logger,
@ -337,6 +307,34 @@ func (h *integrationHarness) joinNetwork(
runtimeDir, runtimeDir,
networkOpts, networkOpts,
} }
if !opts.manualShutdown {
t.Cleanup(func() {
t.Logf("Shutting down Network %q", hostNameStr)
if err := nh.Shutdown(); err != nil {
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
}
})
}
return nh
}
func (nh *integrationHarnessNetwork) restart(t *testing.T) {
t.Log("Shutting down network (restart)")
require.NoError(t, nh.Network.Shutdown())
t.Log("Loading network (restart)")
var err error
nh.Network, err = load(
nh.ctx,
nh.logger,
getEnvBinDirPath(),
nh.stateDir,
nh.runtimeDir,
nh.opts,
)
require.NoError(t, err)
} }
func (nh *integrationHarnessNetwork) getConfig(t *testing.T) daecommon.NetworkConfig { func (nh *integrationHarnessNetwork) getConfig(t *testing.T) daecommon.NetworkConfig {

View File

@ -3,6 +3,7 @@ package daemon
import ( import (
"context" "context"
"isle/bootstrap" "isle/bootstrap"
"isle/daemon/daecommon"
"isle/daemon/jsonrpc2" "isle/daemon/jsonrpc2"
"isle/daemon/network" "isle/daemon/network"
"isle/nebula" "isle/nebula"
@ -25,13 +26,23 @@ type RPC interface {
GetNetworks(context.Context) ([]bootstrap.CreationParams, error) GetNetworks(context.Context) ([]bootstrap.CreationParams, error)
// SetConfig extends the [network.RPC] method of the same name such that
// [ErrUserManagedNetworkConfig] is returned if the picked network is
// configured as part of the [daecommon.Config] which the Daemon was
// initialized with.
//
// See the `network.RPC` documentation in this interface for more usage
// details.
SetConfig(context.Context, daecommon.NetworkConfig) error
// All network.RPC methods are automatically implemented by Daemon using the // All network.RPC methods are automatically implemented by Daemon using the
// currently joined network. If no network is joined then any call to these // currently joined network. If no network is joined then any call to these
// methods will return ErrNoNetwork. // methods will return ErrNoNetwork.
// //
// All calls to these methods must be accompanied with a context produced by // If more than one Network is joined then all calls to these methods must
// WithNetwork, in order to choose the network. These methods may return // be accompanied with a context produced by WithNetwork, in order to choose
// these errors, in addition to those documented on the individual methods: // the network. These methods may return these errors, in addition to those
// documented on the individual methods:
// //
// Errors: // Errors:
// - ErrNoNetwork // - ErrNoNetwork

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog" "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"github.com/stretchr/testify/mock"
) )
// MarkIntegrationTest marks a test as being an integration test. It will be // MarkIntegrationTest marks a test as being an integration test. It will be
@ -31,3 +32,10 @@ func NewTestLogger(t *testing.T) *mlog.Logger {
MaxLevel: level.Int(), MaxLevel: level.Int(),
}) })
} }
// MockArg returns a value which can be used as a [mock.Call] argument, and
// which will match any value of type T. If T is an interface then also values
// implementing that interface will be matched.
func MockArg[T any]() any {
return mock.MatchedBy(func(T) bool { return true })
}

29
go/yamlutil/testutils.go Normal file
View File

@ -0,0 +1,29 @@
package yamlutil
import (
"strings"
)
// ReplacePrefixTabs will replace the leading tabs of each line with two spaces.
// Tabs are not a valid indent in yaml (wtf), but they are convenient to use in
// go when using multi-line strings.
//
// ReplacePrefixTabs should only be used within tests.
func ReplacePrefixTabs(str string) string {
lines := strings.Split(str, "\n")
for i := range lines {
var n int
for _, r := range lines[i] {
if r == '\t' {
n++
} else {
break
}
}
spaces := strings.Repeat(" ", n)
lines[i] = spaces + lines[i][n:]
}
return strings.Join(lines, "\n")
}