Pass NetworkConfig into Network loaders as an optional argument
This commit is contained in:
parent
72bca72b29
commit
6ec56f2a88
@ -44,22 +44,20 @@ var HTTPSocketPath = sync.OnceValue(func() string {
|
|||||||
|
|
||||||
func pickNetworkConfig(
|
func pickNetworkConfig(
|
||||||
daemonConfig daecommon.Config, creationParams bootstrap.CreationParams,
|
daemonConfig daecommon.Config, creationParams bootstrap.CreationParams,
|
||||||
) (
|
) *daecommon.NetworkConfig {
|
||||||
daecommon.NetworkConfig, bool,
|
|
||||||
) {
|
|
||||||
if len(daemonConfig.Networks) == 1 { // DEPRECATED
|
if len(daemonConfig.Networks) == 1 { // DEPRECATED
|
||||||
if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok {
|
if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok {
|
||||||
return c, true
|
return &c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for searchStr, networkConfig := range daemonConfig.Networks {
|
for searchStr, networkConfig := range daemonConfig.Networks {
|
||||||
if creationParams.Matches(searchStr) {
|
if creationParams.Matches(searchStr) {
|
||||||
return networkConfig, true
|
return &networkConfig
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return daecommon.NetworkConfig{}, false
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -86,6 +86,18 @@ type NetworkConfig struct {
|
|||||||
} `yaml:"storage"`
|
} `yaml:"storage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewNetworkConfig returns a new NetworkConfig populated with its default
|
||||||
|
// values. If a callback is given the NetworkConfig will be passed to it for
|
||||||
|
// modification prior to having defaults populated.
|
||||||
|
func NewNetworkConfig(fn func(*NetworkConfig)) NetworkConfig {
|
||||||
|
var c NetworkConfig
|
||||||
|
if fn != nil {
|
||||||
|
fn(&c)
|
||||||
|
}
|
||||||
|
c.fillDefaults()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
func (c *NetworkConfig) fillDefaults() {
|
func (c *NetworkConfig) fillDefaults() {
|
||||||
if c.DNS.Resolvers == nil {
|
if c.DNS.Resolvers == nil {
|
||||||
c.DNS.Resolvers = []string{
|
c.DNS.Resolvers = []string{
|
||||||
@ -213,9 +225,15 @@ func CopyDefaultConfig(into io.Writer) error {
|
|||||||
// fill in default values where it can.
|
// fill in default values where it can.
|
||||||
func (c *Config) UnmarshalYAML(n *yaml.Node) error {
|
func (c *Config) UnmarshalYAML(n *yaml.Node) error {
|
||||||
{ // DEPRECATED
|
{ // DEPRECATED
|
||||||
|
// We decode into a wrapped NetworkConfig, so that its UnmarshalYAML
|
||||||
|
// doesn't get invoked. This prevents fillDefaults from getting
|
||||||
|
// automatically called, so the IsZero check will work as intended. This
|
||||||
|
// means we need to call fillDefaults manually though.
|
||||||
|
type wrap NetworkConfig
|
||||||
var networkConfig NetworkConfig
|
var networkConfig NetworkConfig
|
||||||
_ = n.Decode(&networkConfig)
|
_ = n.Decode((*wrap)(&networkConfig))
|
||||||
if !toolkit.IsZero(networkConfig) {
|
if !toolkit.IsZero(networkConfig) {
|
||||||
|
networkConfig.fillDefaults()
|
||||||
*c = Config{
|
*c = Config{
|
||||||
Networks: map[string]NetworkConfig{
|
Networks: map[string]NetworkConfig{
|
||||||
DeprecatedNetworkID: networkConfig,
|
DeprecatedNetworkID: networkConfig,
|
||||||
|
72
go/daemon/daecommon/config_test.go
Normal file
72
go/daemon/daecommon/config_test.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package daecommon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"isle/yamlutil"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation
|
||||||
|
tests := []struct {
|
||||||
|
label string
|
||||||
|
str string
|
||||||
|
config Config
|
||||||
|
}{
|
||||||
|
{"empty", ``, Config{}},
|
||||||
|
{
|
||||||
|
"DEPRECATED single global network",
|
||||||
|
`
|
||||||
|
{"dns":{"resolvers":["a"]}}
|
||||||
|
`,
|
||||||
|
Config{
|
||||||
|
Networks: map[string]NetworkConfig{
|
||||||
|
DeprecatedNetworkID: NewNetworkConfig(func(c *NetworkConfig) {
|
||||||
|
c.DNS.Resolvers = []string{"a"}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"single network",
|
||||||
|
`
|
||||||
|
networks:
|
||||||
|
foo: {"dns":{"resolvers":["a"]}}
|
||||||
|
`,
|
||||||
|
Config{
|
||||||
|
Networks: map[string]NetworkConfig{
|
||||||
|
"foo": NewNetworkConfig(func(c *NetworkConfig) {
|
||||||
|
c.DNS.Resolvers = []string{"a"}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple networks",
|
||||||
|
`
|
||||||
|
networks:
|
||||||
|
foo: {"dns":{"resolvers":["a"]}}
|
||||||
|
bar: {}
|
||||||
|
`,
|
||||||
|
Config{
|
||||||
|
Networks: map[string]NetworkConfig{
|
||||||
|
"foo": NewNetworkConfig(func(c *NetworkConfig) {
|
||||||
|
c.DNS.Resolvers = []string{"a"}
|
||||||
|
}),
|
||||||
|
"bar": NewNetworkConfig(nil),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.label, func(t *testing.T) {
|
||||||
|
test.str = yamlutil.ReplacePrefixTabs(test.str)
|
||||||
|
var config Config
|
||||||
|
err := yaml.Unmarshal([]byte(test.str), &config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, test.config, config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -19,6 +19,11 @@ import (
|
|||||||
|
|
||||||
var _ RPC = (*Daemon)(nil)
|
var _ RPC = (*Daemon)(nil)
|
||||||
|
|
||||||
|
type joinedNetwork struct {
|
||||||
|
network.Network
|
||||||
|
config *daecommon.NetworkConfig
|
||||||
|
}
|
||||||
|
|
||||||
// Daemon implements all methods of the Daemon interface, plus others used
|
// Daemon implements all methods of the Daemon interface, plus others used
|
||||||
// to manage this particular implementation.
|
// to manage this particular implementation.
|
||||||
//
|
//
|
||||||
@ -41,7 +46,7 @@ type Daemon struct {
|
|||||||
daemonConfig daecommon.Config
|
daemonConfig daecommon.Config
|
||||||
|
|
||||||
l sync.RWMutex
|
l sync.RWMutex
|
||||||
networks map[string]network.Network
|
networks map[string]joinedNetwork
|
||||||
}
|
}
|
||||||
|
|
||||||
// New initializes and returns a Daemon.
|
// New initializes and returns a Daemon.
|
||||||
@ -57,7 +62,7 @@ func New(
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
networkLoader: networkLoader,
|
networkLoader: networkLoader,
|
||||||
daemonConfig: daemonConfig,
|
daemonConfig: daemonConfig,
|
||||||
networks: map[string]network.Network{},
|
networks: map[string]joinedNetwork{},
|
||||||
}
|
}
|
||||||
|
|
||||||
loadableNetworks, err := networkLoader.Loadable(ctx)
|
loadableNetworks, err := networkLoader.Loadable(ctx)
|
||||||
@ -69,20 +74,23 @@ func New(
|
|||||||
ctx = mctx.WithAnnotator(ctx, creationParams)
|
ctx = mctx.WithAnnotator(ctx, creationParams)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
id = creationParams.ID
|
id = creationParams.ID
|
||||||
networkConfig, _ = pickNetworkConfig(daemonConfig, creationParams)
|
networkConfig = pickNetworkConfig(daemonConfig, creationParams)
|
||||||
)
|
)
|
||||||
|
|
||||||
d.networks[id], err = networkLoader.Load(
|
n, err := networkLoader.Load(
|
||||||
ctx,
|
ctx,
|
||||||
logger.WithNamespace("network"),
|
logger.WithNamespace("network"),
|
||||||
networkConfig,
|
|
||||||
creationParams,
|
creationParams,
|
||||||
nil,
|
&network.Opts{
|
||||||
|
Config: networkConfig,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading network %q: %w", id, err)
|
return nil, fmt.Errorf("loading network %q: %w", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.networks[id] = joinedNetwork{n, networkConfig}
|
||||||
}
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
@ -105,17 +113,16 @@ func (d *Daemon) CreateNetwork(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
name, domain string, ipNet nebula.IPNet, hostName nebula.HostName,
|
name, domain string, ipNet nebula.IPNet, hostName nebula.HostName,
|
||||||
) error {
|
) error {
|
||||||
creationParams := bootstrap.NewCreationParams(name, domain)
|
|
||||||
ctx = mctx.WithAnnotator(ctx, creationParams)
|
|
||||||
|
|
||||||
networkConfig, ok := pickNetworkConfig(d.daemonConfig, creationParams)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("couldn't find network config for network being created")
|
|
||||||
}
|
|
||||||
|
|
||||||
d.l.Lock()
|
d.l.Lock()
|
||||||
defer d.l.Unlock()
|
defer d.l.Unlock()
|
||||||
|
|
||||||
|
var (
|
||||||
|
creationParams = bootstrap.NewCreationParams(name, domain)
|
||||||
|
networkConfig = pickNetworkConfig(d.daemonConfig, creationParams)
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = mctx.WithAnnotator(ctx, creationParams)
|
||||||
|
|
||||||
if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil {
|
if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil {
|
||||||
return fmt.Errorf("checking if already joined to network: %w", err)
|
return fmt.Errorf("checking if already joined to network: %w", err)
|
||||||
} else if joined {
|
} else if joined {
|
||||||
@ -126,18 +133,19 @@ func (d *Daemon) CreateNetwork(
|
|||||||
n, err := d.networkLoader.Create(
|
n, err := d.networkLoader.Create(
|
||||||
ctx,
|
ctx,
|
||||||
d.logger.WithNamespace("network"),
|
d.logger.WithNamespace("network"),
|
||||||
networkConfig,
|
|
||||||
creationParams,
|
creationParams,
|
||||||
ipNet,
|
ipNet,
|
||||||
hostName,
|
hostName,
|
||||||
nil,
|
&network.Opts{
|
||||||
|
Config: networkConfig,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating network: %w", err)
|
return fmt.Errorf("creating network: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
d.logger.Info(ctx, "Network created successfully")
|
d.logger.Info(ctx, "Network created successfully")
|
||||||
d.networks[creationParams.ID] = n
|
d.networks[creationParams.ID] = joinedNetwork{n, networkConfig}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,17 +157,17 @@ func (d *Daemon) CreateNetwork(
|
|||||||
func (d *Daemon) JoinNetwork(
|
func (d *Daemon) JoinNetwork(
|
||||||
ctx context.Context, newBootstrap network.JoiningBootstrap,
|
ctx context.Context, newBootstrap network.JoiningBootstrap,
|
||||||
) error {
|
) error {
|
||||||
|
d.l.Lock()
|
||||||
|
defer d.l.Unlock()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
creationParams = newBootstrap.Bootstrap.NetworkCreationParams
|
creationParams = newBootstrap.Bootstrap.NetworkCreationParams
|
||||||
networkConfig, _ = pickNetworkConfig(d.daemonConfig, creationParams)
|
networkID = creationParams.ID
|
||||||
networkID = creationParams.ID
|
networkConfig = pickNetworkConfig(d.daemonConfig, creationParams)
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx = mctx.WithAnnotator(ctx, newBootstrap.Bootstrap.NetworkCreationParams)
|
ctx = mctx.WithAnnotator(ctx, newBootstrap.Bootstrap.NetworkCreationParams)
|
||||||
|
|
||||||
d.l.Lock()
|
|
||||||
defer d.l.Unlock()
|
|
||||||
|
|
||||||
if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil {
|
if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil {
|
||||||
return fmt.Errorf("checking if already joined to network: %w", err)
|
return fmt.Errorf("checking if already joined to network: %w", err)
|
||||||
} else if joined {
|
} else if joined {
|
||||||
@ -170,9 +178,10 @@ func (d *Daemon) JoinNetwork(
|
|||||||
n, err := d.networkLoader.Join(
|
n, err := d.networkLoader.Join(
|
||||||
ctx,
|
ctx,
|
||||||
d.logger.WithNamespace("network"),
|
d.logger.WithNamespace("network"),
|
||||||
networkConfig,
|
|
||||||
newBootstrap,
|
newBootstrap,
|
||||||
nil,
|
&network.Opts{
|
||||||
|
Config: networkConfig,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
@ -181,14 +190,14 @@ func (d *Daemon) JoinNetwork(
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.logger.Info(ctx, "Network joined successfully")
|
d.logger.Info(ctx, "Network joined successfully")
|
||||||
d.networks[networkID] = n
|
d.networks[networkID] = joinedNetwork{n, networkConfig}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func withNetwork[Res any](
|
func withNetwork[Res any](
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
d *Daemon,
|
d *Daemon,
|
||||||
fn func(context.Context, network.Network) (Res, error),
|
fn func(context.Context, joinedNetwork) (Res, error),
|
||||||
) (
|
) (
|
||||||
Res, error,
|
Res, error,
|
||||||
) {
|
) {
|
||||||
@ -238,7 +247,7 @@ func (d *Daemon) GetHosts(ctx context.Context) ([]bootstrap.Host, error) {
|
|||||||
return withNetwork(
|
return withNetwork(
|
||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(ctx context.Context, n network.Network) ([]bootstrap.Host, error) {
|
func(ctx context.Context, n joinedNetwork) ([]bootstrap.Host, error) {
|
||||||
return n.GetHosts(ctx)
|
return n.GetHosts(ctx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -254,7 +263,7 @@ func (d *Daemon) GetGarageClientParams(
|
|||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(
|
func(
|
||||||
ctx context.Context, n network.Network,
|
ctx context.Context, n joinedNetwork,
|
||||||
) (
|
) (
|
||||||
network.GarageClientParams, error,
|
network.GarageClientParams, error,
|
||||||
) {
|
) {
|
||||||
@ -274,7 +283,7 @@ func (d *Daemon) GetNebulaCAPublicCredentials(
|
|||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(
|
func(
|
||||||
ctx context.Context, n network.Network,
|
ctx context.Context, n joinedNetwork,
|
||||||
) (
|
) (
|
||||||
nebula.CAPublicCredentials, error,
|
nebula.CAPublicCredentials, error,
|
||||||
) {
|
) {
|
||||||
@ -289,7 +298,7 @@ func (d *Daemon) RemoveHost(ctx context.Context, hostName nebula.HostName) error
|
|||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(
|
func(
|
||||||
ctx context.Context, n network.Network,
|
ctx context.Context, n joinedNetwork,
|
||||||
) (
|
) (
|
||||||
struct{}, error,
|
struct{}, error,
|
||||||
) {
|
) {
|
||||||
@ -311,7 +320,7 @@ func (d *Daemon) CreateHost(
|
|||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(
|
func(
|
||||||
ctx context.Context, n network.Network,
|
ctx context.Context, n joinedNetwork,
|
||||||
) (
|
) (
|
||||||
network.JoiningBootstrap, error,
|
network.JoiningBootstrap, error,
|
||||||
) {
|
) {
|
||||||
@ -332,7 +341,7 @@ func (d *Daemon) CreateNebulaCertificate(
|
|||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(
|
func(
|
||||||
ctx context.Context, n network.Network,
|
ctx context.Context, n joinedNetwork,
|
||||||
) (
|
) (
|
||||||
nebula.Certificate, error,
|
nebula.Certificate, error,
|
||||||
) {
|
) {
|
||||||
@ -341,6 +350,7 @@ func (d *Daemon) CreateNebulaCertificate(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfig implements the method for the network.RPC interface.
|
||||||
func (d *Daemon) GetConfig(
|
func (d *Daemon) GetConfig(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
) (
|
) (
|
||||||
@ -350,7 +360,7 @@ func (d *Daemon) GetConfig(
|
|||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(
|
func(
|
||||||
ctx context.Context, n network.Network,
|
ctx context.Context, n joinedNetwork,
|
||||||
) (
|
) (
|
||||||
daecommon.NetworkConfig, error,
|
daecommon.NetworkConfig, error,
|
||||||
) {
|
) {
|
||||||
@ -359,13 +369,18 @@ func (d *Daemon) GetConfig(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetConfig implements the method for the network.RPC interface.
|
||||||
func (d *Daemon) SetConfig(
|
func (d *Daemon) SetConfig(
|
||||||
ctx context.Context, config daecommon.NetworkConfig,
|
ctx context.Context, config daecommon.NetworkConfig,
|
||||||
) error {
|
) error {
|
||||||
_, err := withNetwork(
|
_, err := withNetwork(
|
||||||
ctx,
|
ctx,
|
||||||
d,
|
d,
|
||||||
func(ctx context.Context, n network.Network) (struct{}, error) {
|
func(ctx context.Context, n joinedNetwork) (struct{}, error) {
|
||||||
|
if n.config != nil {
|
||||||
|
return struct{}{}, ErrUserManagedNetworkConfig
|
||||||
|
}
|
||||||
|
|
||||||
// TODO needs to check that public addresses aren't being shared
|
// TODO needs to check that public addresses aren't being shared
|
||||||
// across networks, and whatever else happens in Config.Validate.
|
// across networks, and whatever else happens in Config.Validate.
|
||||||
return struct{}{}, n.SetConfig(ctx, config)
|
return struct{}{}, n.SetConfig(ctx, config)
|
||||||
|
227
go/daemon/daemon_test.go
Normal file
227
go/daemon/daemon_test.go
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
package daemon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"isle/bootstrap"
|
||||||
|
"isle/daemon/daecommon"
|
||||||
|
"isle/daemon/network"
|
||||||
|
"isle/toolkit"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type expectNetworkLoad struct {
|
||||||
|
creationParams bootstrap.CreationParams
|
||||||
|
networkConfig *daecommon.NetworkConfig
|
||||||
|
network *network.MockNetwork
|
||||||
|
}
|
||||||
|
|
||||||
|
type harnessOpts struct {
|
||||||
|
config daecommon.Config
|
||||||
|
expectNetworksLoaded []expectNetworkLoad
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *harnessOpts) withDefaults() *harnessOpts {
|
||||||
|
if o == nil {
|
||||||
|
o = new(harnessOpts)
|
||||||
|
}
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
type harness struct {
|
||||||
|
ctx context.Context
|
||||||
|
networkLoader *network.MockLoader
|
||||||
|
daemon *Daemon
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHarness(t *testing.T, opts *harnessOpts) *harness {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
opts = opts.withDefaults()
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
logger = toolkit.NewTestLogger(t)
|
||||||
|
networkLoader = network.NewMockLoader(t)
|
||||||
|
)
|
||||||
|
|
||||||
|
expectLoadable := make(
|
||||||
|
[]bootstrap.CreationParams, len(opts.expectNetworksLoaded),
|
||||||
|
)
|
||||||
|
for i, expectNetworkLoaded := range opts.expectNetworksLoaded {
|
||||||
|
expectLoadable[i] = expectNetworkLoaded.creationParams
|
||||||
|
networkLoader.
|
||||||
|
On(
|
||||||
|
"Load",
|
||||||
|
toolkit.MockArg[context.Context](),
|
||||||
|
toolkit.MockArg[*mlog.Logger](),
|
||||||
|
expectNetworkLoaded.creationParams,
|
||||||
|
&network.Opts{
|
||||||
|
Config: expectNetworkLoaded.networkConfig,
|
||||||
|
},
|
||||||
|
).
|
||||||
|
Return(expectNetworkLoaded.network, nil).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
expectNetworkLoaded.network.On("Shutdown").Return(nil).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
networkLoader.
|
||||||
|
On("Loadable", toolkit.MockArg[context.Context]()).
|
||||||
|
Return(expectLoadable, nil).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
daemon, err := New(ctx, logger, networkLoader, opts.config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
t.Log("Shutting down Daemon")
|
||||||
|
assert.NoError(t, daemon.Shutdown())
|
||||||
|
})
|
||||||
|
|
||||||
|
return &harness{ctx, networkLoader, daemon}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
t.Run("no networks loaded", func(t *testing.T) {
|
||||||
|
_ = newHarness(t, nil)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DEPRECATED network config matching", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
creationParams = bootstrap.NewCreationParams("AAA", "a.com")
|
||||||
|
networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||||
|
c.DNS.Resolvers = []string{"foo"}
|
||||||
|
})
|
||||||
|
config = daecommon.Config{
|
||||||
|
Networks: map[string]daecommon.NetworkConfig{
|
||||||
|
daecommon.DeprecatedNetworkID: networkConfig,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = newHarness(t, &harnessOpts{
|
||||||
|
config: config,
|
||||||
|
expectNetworksLoaded: []expectNetworkLoad{
|
||||||
|
{
|
||||||
|
creationParams,
|
||||||
|
&networkConfig,
|
||||||
|
network.NewMockNetwork(t),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("network config matching", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
creationParamsA = bootstrap.NewCreationParams("AAA", "a.com")
|
||||||
|
creationParamsB = bootstrap.NewCreationParams("BBB", "b.com")
|
||||||
|
creationParamsC = bootstrap.NewCreationParams("CCC", "c.com")
|
||||||
|
creationParamsD = bootstrap.NewCreationParams("DDD", "d.com")
|
||||||
|
|
||||||
|
networkConfigA = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||||
|
c.DNS.Resolvers = []string{"foo"}
|
||||||
|
})
|
||||||
|
|
||||||
|
networkConfigB = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||||
|
c.VPN.Tun.Device = "bar"
|
||||||
|
})
|
||||||
|
|
||||||
|
networkConfigC = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||||
|
c.Storage.Allocations = []daecommon.ConfigStorageAllocation{
|
||||||
|
{
|
||||||
|
DataPath: "/path/data",
|
||||||
|
MetaPath: "/path/meta",
|
||||||
|
Capacity: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
config = daecommon.Config{
|
||||||
|
Networks: map[string]daecommon.NetworkConfig{
|
||||||
|
creationParamsA.ID: networkConfigA,
|
||||||
|
creationParamsB.Name: networkConfigB,
|
||||||
|
creationParamsC.Domain: networkConfigC,
|
||||||
|
"unknown": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = newHarness(t, &harnessOpts{
|
||||||
|
config: config,
|
||||||
|
expectNetworksLoaded: []expectNetworkLoad{
|
||||||
|
{
|
||||||
|
creationParamsA,
|
||||||
|
&networkConfigA,
|
||||||
|
network.NewMockNetwork(t),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
creationParamsB,
|
||||||
|
&networkConfigB,
|
||||||
|
network.NewMockNetwork(t),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
creationParamsC,
|
||||||
|
&networkConfigC,
|
||||||
|
network.NewMockNetwork(t),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
creationParamsD,
|
||||||
|
nil,
|
||||||
|
network.NewMockNetwork(t),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDaemon_SetConfig(t *testing.T) {
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
networkA = network.NewMockNetwork(t)
|
||||||
|
h = newHarness(t, &harnessOpts{
|
||||||
|
expectNetworksLoaded: []expectNetworkLoad{{
|
||||||
|
bootstrap.NewCreationParams("AAA", "a.com"), nil, networkA,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
networkConfig = daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||||
|
c.VPN.Tun.Device = "foo"
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
networkA.
|
||||||
|
On("SetConfig", toolkit.MockArg[context.Context](), networkConfig).
|
||||||
|
Return(nil).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
err := h.daemon.SetConfig(h.ctx, networkConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ErrUserManagedNetworkConfig", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
creationParams = bootstrap.NewCreationParams("AAA", "a.com")
|
||||||
|
networkA = network.NewMockNetwork(t)
|
||||||
|
networkConfig = daecommon.NewNetworkConfig(nil)
|
||||||
|
|
||||||
|
h = newHarness(t, &harnessOpts{
|
||||||
|
config: daecommon.Config{
|
||||||
|
Networks: map[string]daecommon.NetworkConfig{
|
||||||
|
creationParams.Name: networkConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectNetworksLoaded: []expectNetworkLoad{
|
||||||
|
{creationParams, &networkConfig, networkA},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
networkConfig.VPN.Tun.Device = "foo"
|
||||||
|
err := h.daemon.SetConfig(h.ctx, networkConfig)
|
||||||
|
assert.ErrorIs(t, err, ErrUserManagedNetworkConfig)
|
||||||
|
})
|
||||||
|
}
|
@ -10,6 +10,7 @@ const (
|
|||||||
errCodeAlreadyJoined
|
errCodeAlreadyJoined
|
||||||
errCodeNoMatchingNetworks
|
errCodeNoMatchingNetworks
|
||||||
errCodeMultipleMatchingNetworks
|
errCodeMultipleMatchingNetworks
|
||||||
|
errCodeUserManagedNetworkConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -33,4 +34,11 @@ var (
|
|||||||
errCodeMultipleMatchingNetworks,
|
errCodeMultipleMatchingNetworks,
|
||||||
"Multiple networks matched the search string",
|
"Multiple networks matched the search string",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrUserManagedNetworkConfig is returned when attempting to modify a
|
||||||
|
// network config which is managed by the user.
|
||||||
|
ErrUserManagedNetworkConfig = jsonrpc2.NewError(
|
||||||
|
errCodeUserManagedNetworkConfig,
|
||||||
|
"Network configuration is managed by the user",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
@ -10,24 +10,30 @@ import (
|
|||||||
func pickNetwork(
|
func pickNetwork(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
networkLoader network.Loader,
|
networkLoader network.Loader,
|
||||||
networks map[string]network.Network,
|
networks map[string]joinedNetwork,
|
||||||
) (
|
) (
|
||||||
network.Network, error,
|
joinedNetwork, error,
|
||||||
) {
|
) {
|
||||||
if len(networks) == 0 {
|
if len(networks) == 0 {
|
||||||
return nil, ErrNoNetwork
|
return joinedNetwork{}, ErrNoNetwork
|
||||||
|
}
|
||||||
|
|
||||||
|
networkSearchStr := getNetworkSearchStr(ctx)
|
||||||
|
if networkSearchStr == "" {
|
||||||
|
if len(networks) > 1 {
|
||||||
|
return joinedNetwork{}, ErrNoMatchingNetworks
|
||||||
|
}
|
||||||
|
for _, network := range networks {
|
||||||
|
return network, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
creationParams, err := networkLoader.Loadable(ctx)
|
creationParams, err := networkLoader.Loadable(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting loadable networks: %w", err)
|
return joinedNetwork{}, fmt.Errorf("getting loadable networks: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
matchingNetworkIDs := make([]string, 0, len(networks))
|
||||||
networkSearchStr = getNetworkSearchStr(ctx)
|
|
||||||
matchingNetworkIDs = make([]string, 0, len(networks))
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, creationParam := range creationParams {
|
for _, creationParam := range creationParams {
|
||||||
if networkSearchStr == "" || creationParam.Matches(networkSearchStr) {
|
if networkSearchStr == "" || creationParam.Matches(networkSearchStr) {
|
||||||
matchingNetworkIDs = append(matchingNetworkIDs, creationParam.ID)
|
matchingNetworkIDs = append(matchingNetworkIDs, creationParam.ID)
|
||||||
@ -35,9 +41,9 @@ func pickNetwork(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(matchingNetworkIDs) == 0 {
|
if len(matchingNetworkIDs) == 0 {
|
||||||
return nil, ErrNoMatchingNetworks
|
return joinedNetwork{}, ErrNoMatchingNetworks
|
||||||
} else if len(matchingNetworkIDs) > 1 {
|
} else if len(matchingNetworkIDs) > 1 {
|
||||||
return nil, ErrMultipleMatchingNetworks
|
return joinedNetwork{}, ErrMultipleMatchingNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
return networks[matchingNetworkIDs[0]], nil
|
return networks[matchingNetworkIDs[0]], nil
|
||||||
@ -45,7 +51,7 @@ func pickNetwork(
|
|||||||
|
|
||||||
func alreadyJoined(
|
func alreadyJoined(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
networks map[string]network.Network,
|
networks map[string]joinedNetwork,
|
||||||
creationParams bootstrap.CreationParams,
|
creationParams bootstrap.CreationParams,
|
||||||
) (
|
) (
|
||||||
bool, error,
|
bool, error,
|
||||||
|
42
go/daemon/network/config.go
Normal file
42
go/daemon/network/config.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"isle/daemon/daecommon"
|
||||||
|
"isle/jsonutil"
|
||||||
|
"isle/toolkit"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// loadStoreConfig writes the given NetworkConfig to the networkStateDir if the
|
||||||
|
// config is non-nil. If the config is nil then a config is read from
|
||||||
|
// networkStateDir, returning the zero value if no config was previously
|
||||||
|
// stored.
|
||||||
|
func loadStoreConfig(
|
||||||
|
networkStateDir toolkit.Dir, config *daecommon.NetworkConfig,
|
||||||
|
) (
|
||||||
|
daecommon.NetworkConfig, error,
|
||||||
|
) {
|
||||||
|
path := filepath.Join(networkStateDir.Path, "config.json")
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
config = new(daecommon.NetworkConfig)
|
||||||
|
err := jsonutil.LoadFile(&config, path)
|
||||||
|
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return daecommon.NetworkConfig{}, fmt.Errorf(
|
||||||
|
"loading %q: %w", path, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return *config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := jsonutil.WriteFile(*config, path, 0600); err != nil {
|
||||||
|
return daecommon.NetworkConfig{}, fmt.Errorf(
|
||||||
|
"writing to %q: %w", path, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return *config, nil
|
||||||
|
}
|
@ -59,7 +59,6 @@ type Loader interface {
|
|||||||
Load(
|
Load(
|
||||||
context.Context,
|
context.Context,
|
||||||
*mlog.Logger,
|
*mlog.Logger,
|
||||||
daecommon.NetworkConfig,
|
|
||||||
bootstrap.CreationParams,
|
bootstrap.CreationParams,
|
||||||
*Opts,
|
*Opts,
|
||||||
) (
|
) (
|
||||||
@ -73,7 +72,6 @@ type Loader interface {
|
|||||||
Join(
|
Join(
|
||||||
context.Context,
|
context.Context,
|
||||||
*mlog.Logger,
|
*mlog.Logger,
|
||||||
daecommon.NetworkConfig,
|
|
||||||
JoiningBootstrap,
|
JoiningBootstrap,
|
||||||
*Opts,
|
*Opts,
|
||||||
) (
|
) (
|
||||||
@ -91,12 +89,11 @@ type Loader interface {
|
|||||||
// - hostName: The name of this first host in the network.
|
// - hostName: The name of this first host in the network.
|
||||||
//
|
//
|
||||||
// Errors:
|
// Errors:
|
||||||
// - ErrInvalidConfig - if daemonConfig doesn't have 3 storage allocations
|
// - ErrInvalidConfig - If the Opts.Config field is not valid. It must be
|
||||||
// configured.
|
// non-nil and have at least 3 storage allocations.
|
||||||
Create(
|
Create(
|
||||||
context.Context,
|
context.Context,
|
||||||
*mlog.Logger,
|
*mlog.Logger,
|
||||||
daecommon.NetworkConfig,
|
|
||||||
bootstrap.CreationParams,
|
bootstrap.CreationParams,
|
||||||
nebula.IPNet,
|
nebula.IPNet,
|
||||||
nebula.HostName,
|
nebula.HostName,
|
||||||
@ -188,7 +185,7 @@ func (l *loader) Loadable(
|
|||||||
creationParams := make([]bootstrap.CreationParams, 0, len(networkStateDirs))
|
creationParams := make([]bootstrap.CreationParams, 0, len(networkStateDirs))
|
||||||
|
|
||||||
for _, networkStateDir := range networkStateDirs {
|
for _, networkStateDir := range networkStateDirs {
|
||||||
thisCreationParams, err := LoadCreationParams(networkStateDir)
|
thisCreationParams, err := loadCreationParams(networkStateDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"loading creation params from %q: %w",
|
"loading creation params from %q: %w",
|
||||||
@ -205,7 +202,6 @@ func (l *loader) Loadable(
|
|||||||
func (l *loader) Load(
|
func (l *loader) Load(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
logger *mlog.Logger,
|
logger *mlog.Logger,
|
||||||
networkConfig daecommon.NetworkConfig,
|
|
||||||
creationParams bootstrap.CreationParams,
|
creationParams bootstrap.CreationParams,
|
||||||
opts *Opts,
|
opts *Opts,
|
||||||
) (
|
) (
|
||||||
@ -226,7 +222,6 @@ func (l *loader) Load(
|
|||||||
ctx,
|
ctx,
|
||||||
logger.WithNamespace("network"),
|
logger.WithNamespace("network"),
|
||||||
l.envBinDirPath,
|
l.envBinDirPath,
|
||||||
networkConfig,
|
|
||||||
networkStateDir,
|
networkStateDir,
|
||||||
networkRuntimeDir,
|
networkRuntimeDir,
|
||||||
opts,
|
opts,
|
||||||
@ -236,7 +231,6 @@ func (l *loader) Load(
|
|||||||
func (l *loader) Join(
|
func (l *loader) Join(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
logger *mlog.Logger,
|
logger *mlog.Logger,
|
||||||
networkConfig daecommon.NetworkConfig,
|
|
||||||
joiningBootstrap JoiningBootstrap,
|
joiningBootstrap JoiningBootstrap,
|
||||||
opts *Opts,
|
opts *Opts,
|
||||||
) (
|
) (
|
||||||
@ -260,7 +254,6 @@ func (l *loader) Join(
|
|||||||
ctx,
|
ctx,
|
||||||
logger.WithNamespace("network"),
|
logger.WithNamespace("network"),
|
||||||
l.envBinDirPath,
|
l.envBinDirPath,
|
||||||
networkConfig,
|
|
||||||
joiningBootstrap,
|
joiningBootstrap,
|
||||||
networkStateDir,
|
networkStateDir,
|
||||||
networkRuntimeDir,
|
networkRuntimeDir,
|
||||||
@ -271,7 +264,6 @@ func (l *loader) Join(
|
|||||||
func (l *loader) Create(
|
func (l *loader) Create(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
logger *mlog.Logger,
|
logger *mlog.Logger,
|
||||||
networkConfig daecommon.NetworkConfig,
|
|
||||||
creationParams bootstrap.CreationParams,
|
creationParams bootstrap.CreationParams,
|
||||||
ipNet nebula.IPNet,
|
ipNet nebula.IPNet,
|
||||||
hostName nebula.HostName,
|
hostName nebula.HostName,
|
||||||
@ -294,7 +286,6 @@ func (l *loader) Create(
|
|||||||
ctx,
|
ctx,
|
||||||
logger.WithNamespace("network"),
|
logger.WithNamespace("network"),
|
||||||
l.envBinDirPath,
|
l.envBinDirPath,
|
||||||
networkConfig,
|
|
||||||
networkStateDir,
|
networkStateDir,
|
||||||
networkRuntimeDir,
|
networkRuntimeDir,
|
||||||
creationParams,
|
creationParams,
|
||||||
|
@ -6,8 +6,6 @@ import (
|
|||||||
context "context"
|
context "context"
|
||||||
bootstrap "isle/bootstrap"
|
bootstrap "isle/bootstrap"
|
||||||
|
|
||||||
daecommon "isle/daemon/daecommon"
|
|
||||||
|
|
||||||
mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
mlog "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||||
|
|
||||||
mock "github.com/stretchr/testify/mock"
|
mock "github.com/stretchr/testify/mock"
|
||||||
@ -20,9 +18,9 @@ type MockLoader struct {
|
|||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5, _a6
|
// Create provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5
|
||||||
func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 bootstrap.CreationParams, _a4 nebula.IPNet, _a5 nebula.HostName, _a6 *Opts) (Network, error) {
|
func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 bootstrap.CreationParams, _a3 nebula.IPNet, _a4 nebula.HostName, _a5 *Opts) (Network, error) {
|
||||||
ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5, _a6)
|
ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
panic("no return value specified for Create")
|
panic("no return value specified for Create")
|
||||||
@ -30,19 +28,19 @@ func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommo
|
|||||||
|
|
||||||
var r0 Network
|
var r0 Network
|
||||||
var r1 error
|
var r1 error
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) (Network, error)); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) (Network, error)); ok {
|
||||||
return rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6)
|
return rf(_a0, _a1, _a2, _a3, _a4, _a5)
|
||||||
}
|
}
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) Network); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) Network); ok {
|
||||||
r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6)
|
r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5)
|
||||||
} else {
|
} else {
|
||||||
if ret.Get(0) != nil {
|
if ret.Get(0) != nil {
|
||||||
r0 = ret.Get(0).(Network)
|
r0 = ret.Get(0).(Network)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) error); ok {
|
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, nebula.IPNet, nebula.HostName, *Opts) error); ok {
|
||||||
r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6)
|
r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5)
|
||||||
} else {
|
} else {
|
||||||
r1 = ret.Error(1)
|
r1 = ret.Error(1)
|
||||||
}
|
}
|
||||||
@ -50,9 +48,9 @@ func (_m *MockLoader) Create(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommo
|
|||||||
return r0, r1
|
return r0, r1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Join provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4
|
// Join provides a mock function with given fields: _a0, _a1, _a2, _a3
|
||||||
func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 JoiningBootstrap, _a4 *Opts) (Network, error) {
|
func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 JoiningBootstrap, _a3 *Opts) (Network, error) {
|
||||||
ret := _m.Called(_a0, _a1, _a2, _a3, _a4)
|
ret := _m.Called(_a0, _a1, _a2, _a3)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
panic("no return value specified for Join")
|
panic("no return value specified for Join")
|
||||||
@ -60,19 +58,19 @@ func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.
|
|||||||
|
|
||||||
var r0 Network
|
var r0 Network
|
||||||
var r1 error
|
var r1 error
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) (Network, error)); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) (Network, error)); ok {
|
||||||
return rf(_a0, _a1, _a2, _a3, _a4)
|
return rf(_a0, _a1, _a2, _a3)
|
||||||
}
|
}
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) Network); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) Network); ok {
|
||||||
r0 = rf(_a0, _a1, _a2, _a3, _a4)
|
r0 = rf(_a0, _a1, _a2, _a3)
|
||||||
} else {
|
} else {
|
||||||
if ret.Get(0) != nil {
|
if ret.Get(0) != nil {
|
||||||
r0 = ret.Get(0).(Network)
|
r0 = ret.Get(0).(Network)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, JoiningBootstrap, *Opts) error); ok {
|
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, JoiningBootstrap, *Opts) error); ok {
|
||||||
r1 = rf(_a0, _a1, _a2, _a3, _a4)
|
r1 = rf(_a0, _a1, _a2, _a3)
|
||||||
} else {
|
} else {
|
||||||
r1 = ret.Error(1)
|
r1 = ret.Error(1)
|
||||||
}
|
}
|
||||||
@ -80,9 +78,9 @@ func (_m *MockLoader) Join(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.
|
|||||||
return r0, r1
|
return r0, r1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4
|
// Load provides a mock function with given fields: _a0, _a1, _a2, _a3
|
||||||
func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.NetworkConfig, _a3 bootstrap.CreationParams, _a4 *Opts) (Network, error) {
|
func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 bootstrap.CreationParams, _a3 *Opts) (Network, error) {
|
||||||
ret := _m.Called(_a0, _a1, _a2, _a3, _a4)
|
ret := _m.Called(_a0, _a1, _a2, _a3)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
panic("no return value specified for Load")
|
panic("no return value specified for Load")
|
||||||
@ -90,19 +88,19 @@ func (_m *MockLoader) Load(_a0 context.Context, _a1 *mlog.Logger, _a2 daecommon.
|
|||||||
|
|
||||||
var r0 Network
|
var r0 Network
|
||||||
var r1 error
|
var r1 error
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) (Network, error)); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) (Network, error)); ok {
|
||||||
return rf(_a0, _a1, _a2, _a3, _a4)
|
return rf(_a0, _a1, _a2, _a3)
|
||||||
}
|
}
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) Network); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) Network); ok {
|
||||||
r0 = rf(_a0, _a1, _a2, _a3, _a4)
|
r0 = rf(_a0, _a1, _a2, _a3)
|
||||||
} else {
|
} else {
|
||||||
if ret.Get(0) != nil {
|
if ret.Get(0) != nil {
|
||||||
r0 = ret.Get(0).(Network)
|
r0 = ret.Get(0).(Network)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, daecommon.NetworkConfig, bootstrap.CreationParams, *Opts) error); ok {
|
if rf, ok := ret.Get(1).(func(context.Context, *mlog.Logger, bootstrap.CreationParams, *Opts) error); ok {
|
||||||
r1 = rf(_a0, _a1, _a2, _a3, _a4)
|
r1 = rf(_a0, _a1, _a2, _a3)
|
||||||
} else {
|
} else {
|
||||||
r1 = ret.Error(1)
|
r1 = ret.Error(1)
|
||||||
}
|
}
|
||||||
|
@ -154,6 +154,14 @@ type Network interface {
|
|||||||
// Network instance. A nil Opts is equivalent to a zero value.
|
// Network instance. A nil Opts is equivalent to a zero value.
|
||||||
type Opts struct {
|
type Opts struct {
|
||||||
GarageAdminToken string // Will be randomly generated if left unset.
|
GarageAdminToken string // Will be randomly generated if left unset.
|
||||||
|
|
||||||
|
// Config will be used as the configuration of the Network from its
|
||||||
|
// initialization onwards.
|
||||||
|
//
|
||||||
|
// If not given then the most recent NetworkConfig for the network will be
|
||||||
|
// used, either that which it was most recently initialized with or which
|
||||||
|
// was passed to [SetConfig].
|
||||||
|
Config *daecommon.NetworkConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Opts) withDefaults() *Opts {
|
func (o *Opts) withDefaults() *Opts {
|
||||||
@ -189,36 +197,53 @@ type network struct {
|
|||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// instatiateNetwork returns an instantiated *network instance which has not yet
|
// newNetwork returns an instantiated *network instance. All initialization
|
||||||
// been initialized.
|
// steps which are common to all *network creation methods (load, join, create)
|
||||||
func instatiateNetwork(
|
// are included here as well.
|
||||||
|
func newNetwork(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
logger *mlog.Logger,
|
logger *mlog.Logger,
|
||||||
networkConfig daecommon.NetworkConfig,
|
|
||||||
envBinDirPath string,
|
envBinDirPath string,
|
||||||
stateDir toolkit.Dir,
|
stateDir toolkit.Dir,
|
||||||
runtimeDir toolkit.Dir,
|
runtimeDir toolkit.Dir,
|
||||||
|
dirsMayExist bool,
|
||||||
opts *Opts,
|
opts *Opts,
|
||||||
) *network {
|
) (
|
||||||
ctx = context.WithoutCancel(ctx)
|
*network, error,
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
) {
|
||||||
return &network{
|
ctx, cancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||||
logger: logger,
|
|
||||||
networkConfig: networkConfig,
|
var (
|
||||||
envBinDirPath: envBinDirPath,
|
n = &network{
|
||||||
stateDir: stateDir,
|
logger: logger,
|
||||||
runtimeDir: runtimeDir,
|
envBinDirPath: envBinDirPath,
|
||||||
opts: opts.withDefaults(),
|
stateDir: stateDir,
|
||||||
workerCtx: ctx,
|
runtimeDir: runtimeDir,
|
||||||
workerCancel: cancel,
|
opts: opts.withDefaults(),
|
||||||
|
workerCtx: ctx,
|
||||||
|
workerCancel: cancel,
|
||||||
|
}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
n.networkConfig, err = loadStoreConfig(n.stateDir, n.opts.Config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolving network config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
secretsDir, err := n.stateDir.MkChildDir("secrets", dirsMayExist)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating secrets dir: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n.secretsStore = secrets.NewFSStore(secretsDir.Path)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadCreationParams returns the CreationParams of a Network which was
|
// loadCreationParams returns the CreationParams of a Network which was
|
||||||
// Created/Joined with the given state directory.
|
// Created/Joined with the given state directory.
|
||||||
//
|
func loadCreationParams(
|
||||||
// TODO probably can be private
|
|
||||||
func LoadCreationParams(
|
|
||||||
stateDir toolkit.Dir,
|
stateDir toolkit.Dir,
|
||||||
) (
|
) (
|
||||||
bootstrap.CreationParams, error,
|
bootstrap.CreationParams, error,
|
||||||
@ -244,25 +269,23 @@ func load(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
logger *mlog.Logger,
|
logger *mlog.Logger,
|
||||||
envBinDirPath string,
|
envBinDirPath string,
|
||||||
networkConfig daecommon.NetworkConfig,
|
|
||||||
stateDir toolkit.Dir,
|
stateDir toolkit.Dir,
|
||||||
runtimeDir toolkit.Dir,
|
runtimeDir toolkit.Dir,
|
||||||
opts *Opts,
|
opts *Opts,
|
||||||
) (
|
) (
|
||||||
Network, error,
|
Network, error,
|
||||||
) {
|
) {
|
||||||
n := instatiateNetwork(
|
n, err := newNetwork(
|
||||||
ctx,
|
ctx,
|
||||||
logger,
|
logger,
|
||||||
networkConfig,
|
|
||||||
envBinDirPath,
|
envBinDirPath,
|
||||||
stateDir,
|
stateDir,
|
||||||
runtimeDir,
|
runtimeDir,
|
||||||
|
true,
|
||||||
opts,
|
opts,
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
if err := n.initializeDirs(true); err != nil {
|
return nil, fmt.Errorf("instantiating Network: %w", err)
|
||||||
return nil, fmt.Errorf("initializing directories: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -285,7 +308,6 @@ func join(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
logger *mlog.Logger,
|
logger *mlog.Logger,
|
||||||
envBinDirPath string,
|
envBinDirPath string,
|
||||||
networkConfig daecommon.NetworkConfig,
|
|
||||||
joiningBootstrap JoiningBootstrap,
|
joiningBootstrap JoiningBootstrap,
|
||||||
stateDir toolkit.Dir,
|
stateDir toolkit.Dir,
|
||||||
runtimeDir toolkit.Dir,
|
runtimeDir toolkit.Dir,
|
||||||
@ -293,18 +315,17 @@ func join(
|
|||||||
) (
|
) (
|
||||||
Network, error,
|
Network, error,
|
||||||
) {
|
) {
|
||||||
n := instatiateNetwork(
|
n, err := newNetwork(
|
||||||
ctx,
|
ctx,
|
||||||
logger,
|
logger,
|
||||||
networkConfig,
|
|
||||||
envBinDirPath,
|
envBinDirPath,
|
||||||
stateDir,
|
stateDir,
|
||||||
runtimeDir,
|
runtimeDir,
|
||||||
|
false,
|
||||||
opts,
|
opts,
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
if err := n.initializeDirs(false); err != nil {
|
return nil, fmt.Errorf("instantiating Network: %w", err)
|
||||||
return nil, fmt.Errorf("initializing directories: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := secrets.Import(
|
if err := secrets.Import(
|
||||||
@ -324,7 +345,6 @@ func create(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
logger *mlog.Logger,
|
logger *mlog.Logger,
|
||||||
envBinDirPath string,
|
envBinDirPath string,
|
||||||
networkConfig daecommon.NetworkConfig,
|
|
||||||
stateDir toolkit.Dir,
|
stateDir toolkit.Dir,
|
||||||
runtimeDir toolkit.Dir,
|
runtimeDir toolkit.Dir,
|
||||||
creationParams bootstrap.CreationParams,
|
creationParams bootstrap.CreationParams,
|
||||||
@ -334,12 +354,6 @@ func create(
|
|||||||
) (
|
) (
|
||||||
Network, error,
|
Network, error,
|
||||||
) {
|
) {
|
||||||
if len(networkConfig.Storage.Allocations) < 3 {
|
|
||||||
return nil, ErrInvalidConfig.WithData(
|
|
||||||
"At least three storage allocations are required.",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
nebulaCACreds, err := nebula.NewCACredentials(creationParams.Domain, ipNet)
|
nebulaCACreds, err := nebula.NewCACredentials(creationParams.Domain, ipNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating nebula CA cert: %w", err)
|
return nil, fmt.Errorf("creating nebula CA cert: %w", err)
|
||||||
@ -347,18 +361,23 @@ func create(
|
|||||||
|
|
||||||
garageRPCSecret := toolkit.RandStr(32)
|
garageRPCSecret := toolkit.RandStr(32)
|
||||||
|
|
||||||
n := instatiateNetwork(
|
n, err := newNetwork(
|
||||||
ctx,
|
ctx,
|
||||||
logger,
|
logger,
|
||||||
networkConfig,
|
|
||||||
envBinDirPath,
|
envBinDirPath,
|
||||||
stateDir,
|
stateDir,
|
||||||
runtimeDir,
|
runtimeDir,
|
||||||
|
false,
|
||||||
opts,
|
opts,
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("instantiating Network: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := n.initializeDirs(false); err != nil {
|
if len(n.networkConfig.Storage.Allocations) < 3 {
|
||||||
return nil, fmt.Errorf("initializing directories: %w", err)
|
return nil, ErrInvalidConfig.WithData(
|
||||||
|
"At least three storage allocations are required.",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = daecommon.SetGarageRPCSecret(ctx, n.secretsStore, garageRPCSecret)
|
err = daecommon.SetGarageRPCSecret(ctx, n.secretsStore, garageRPCSecret)
|
||||||
@ -391,16 +410,6 @@ func create(
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *network) initializeDirs(mayExist bool) error {
|
|
||||||
secretsDir, err := n.stateDir.MkChildDir("secrets", mayExist)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating secrets dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
n.secretsStore = secrets.NewFSStore(secretsDir.Path)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *network) periodically(
|
func (n *network) periodically(
|
||||||
label string,
|
label string,
|
||||||
fn func(context.Context) error,
|
fn func(context.Context) error,
|
||||||
@ -1016,6 +1025,10 @@ func (n *network) GetConfig(context.Context) (daecommon.NetworkConfig, error) {
|
|||||||
func (n *network) SetConfig(
|
func (n *network) SetConfig(
|
||||||
ctx context.Context, config daecommon.NetworkConfig,
|
ctx context.Context, config daecommon.NetworkConfig,
|
||||||
) error {
|
) error {
|
||||||
|
if _, err := loadStoreConfig(n.stateDir, &config); err != nil {
|
||||||
|
return fmt.Errorf("storing new config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
prevBootstrap, err := n.reload(ctx, &config, nil)
|
prevBootstrap, err := n.reload(ctx, &config, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reloading config: %w", err)
|
return fmt.Errorf("reloading config: %w", err)
|
||||||
|
@ -17,7 +17,7 @@ func TestCreate(t *testing.T) {
|
|||||||
network = h.createNetwork(t, "primus", nil)
|
network = h.createNetwork(t, "primus", nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
gotCreationParams, err := LoadCreationParams(network.stateDir)
|
gotCreationParams, err := loadCreationParams(network.stateDir)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t, gotCreationParams, network.getBootstrap(t).NetworkCreationParams,
|
t, gotCreationParams, network.getBootstrap(t).NetworkCreationParams,
|
||||||
@ -25,31 +25,30 @@ func TestCreate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoad(t *testing.T) {
|
func TestLoad(t *testing.T) {
|
||||||
var (
|
t.Run("given config", func(t *testing.T) {
|
||||||
h = newIntegrationHarness(t)
|
var (
|
||||||
network = h.createNetwork(t, "primus", &createNetworkOpts{
|
h = newIntegrationHarness(t)
|
||||||
manualShutdown: true,
|
network = h.createNetwork(t, "primus", nil)
|
||||||
})
|
networkConfig = network.getConfig(t)
|
||||||
)
|
)
|
||||||
|
|
||||||
t.Log("Shutting down network")
|
network.opts.Config = &networkConfig
|
||||||
assert.NoError(t, network.Shutdown())
|
network.restart(t)
|
||||||
|
|
||||||
t.Log("Calling Load")
|
assert.Equal(t, networkConfig, network.getConfig(t))
|
||||||
loadedNetwork, err := load(
|
})
|
||||||
h.ctx,
|
|
||||||
h.logger.WithNamespace("loadedNetwork"),
|
|
||||||
getEnvBinDirPath(),
|
|
||||||
network.getConfig(t),
|
|
||||||
network.stateDir,
|
|
||||||
h.mkDir(t, "runtime"),
|
|
||||||
network.opts,
|
|
||||||
)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Run("load previous config", func(t *testing.T) {
|
||||||
t.Log("Shutting down loadedNetwork")
|
var (
|
||||||
assert.NoError(t, loadedNetwork.Shutdown())
|
h = newIntegrationHarness(t)
|
||||||
|
network = h.createNetwork(t, "primus", nil)
|
||||||
|
networkConfig = network.getConfig(t)
|
||||||
|
)
|
||||||
|
|
||||||
|
network.opts.Config = nil
|
||||||
|
network.restart(t)
|
||||||
|
|
||||||
|
assert.Equal(t, networkConfig, network.getConfig(t))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,13 +60,7 @@ func TestJoin(t *testing.T) {
|
|||||||
secondus = h.joinNetwork(t, primus, "secondus", nil)
|
secondus = h.joinNetwork(t, primus, "secondus", nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
primusHosts, err := primus.GetHosts(h.ctx)
|
assert.Equal(t, primus.getHostsByName(t), secondus.getHostsByName(t))
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
secondusHosts, err := secondus.GetHosts(h.ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, primusHosts, secondusHosts)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with alloc", func(t *testing.T) {
|
t.Run("with alloc", func(t *testing.T) {
|
||||||
@ -84,28 +77,10 @@ func TestJoin(t *testing.T) {
|
|||||||
t.Log("reloading primus' hosts")
|
t.Log("reloading primus' hosts")
|
||||||
assert.NoError(t, primus.Network.(*network).reloadHosts(h.ctx))
|
assert.NoError(t, primus.Network.(*network).reloadHosts(h.ctx))
|
||||||
|
|
||||||
primusHosts, err := primus.GetHosts(h.ctx)
|
assert.Equal(t, primus.getHostsByName(t), secondus.getHostsByName(t))
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
secondusHosts, err := secondus.GetHosts(h.ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, primusHosts, secondusHosts)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNetwork_GetConfig(t *testing.T) {
|
|
||||||
var (
|
|
||||||
h = newIntegrationHarness(t)
|
|
||||||
network = h.createNetwork(t, "primus", nil)
|
|
||||||
)
|
|
||||||
|
|
||||||
config, err := network.GetConfig(h.ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, config, network.getConfig(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNetwork_SetConfig(t *testing.T) {
|
func TestNetwork_SetConfig(t *testing.T) {
|
||||||
allocsToRoles := func(
|
allocsToRoles := func(
|
||||||
hostName nebula.HostName, allocs []bootstrap.GarageHostInstance,
|
hostName nebula.HostName, allocs []bootstrap.GarageHostInstance,
|
||||||
@ -259,4 +234,22 @@ func TestNetwork_SetConfig(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotContains(t, layout.Roles, removedRole)
|
assert.NotContains(t, layout.Roles, removedRole)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("changes reflected after restart", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
h = newIntegrationHarness(t)
|
||||||
|
network = h.createNetwork(t, "primus", &createNetworkOpts{
|
||||||
|
numStorageAllocs: 4,
|
||||||
|
})
|
||||||
|
networkConfig = network.getConfig(t)
|
||||||
|
)
|
||||||
|
|
||||||
|
networkConfig.Storage.Allocations = networkConfig.Storage.Allocations[:3]
|
||||||
|
assert.NoError(t, network.SetConfig(h.ctx, networkConfig))
|
||||||
|
|
||||||
|
network.opts.Config = nil
|
||||||
|
network.restart(t)
|
||||||
|
|
||||||
|
assert.Equal(t, networkConfig, network.getConfig(t))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package network
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"isle/bootstrap"
|
"isle/bootstrap"
|
||||||
"isle/daemon/daecommon"
|
"isle/daemon/daecommon"
|
||||||
@ -18,7 +17,6 @@ import (
|
|||||||
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Utilities related to running network integration tests
|
// Utilities related to running network integration tests
|
||||||
@ -62,16 +60,6 @@ func newTunDevice() string {
|
|||||||
return fmt.Sprintf("isle-test-%d", atomic.AddUint64(&tunDeviceCounter, 1))
|
return fmt.Sprintf("isle-test-%d", atomic.AddUint64(&tunDeviceCounter, 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustParseNetworkConfigf(str string, args ...any) daecommon.NetworkConfig {
|
|
||||||
str = fmt.Sprintf(str, args...)
|
|
||||||
|
|
||||||
var networkConfig daecommon.NetworkConfig
|
|
||||||
if err := yaml.Unmarshal([]byte(str), &networkConfig); err != nil {
|
|
||||||
panic(fmt.Sprintf("parsing network config: %v", err))
|
|
||||||
}
|
|
||||||
return networkConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
type integrationHarness struct {
|
type integrationHarness struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
logger *mlog.Logger
|
logger *mlog.Logger
|
||||||
@ -136,35 +124,24 @@ func (h *integrationHarness) mkNetworkConfig(
|
|||||||
opts = new(networkConfigOpts)
|
opts = new(networkConfigOpts)
|
||||||
}
|
}
|
||||||
|
|
||||||
publicAddr := ""
|
return daecommon.NewNetworkConfig(func(c *daecommon.NetworkConfig) {
|
||||||
if opts.hasPublicAddr {
|
if opts.hasPublicAddr {
|
||||||
publicAddr = newPublicAddr()
|
c.VPN.PublicAddr = newPublicAddr()
|
||||||
}
|
|
||||||
|
|
||||||
allocs := make([]map[string]any, opts.numStorageAllocs)
|
|
||||||
for i := range allocs {
|
|
||||||
allocs[i] = map[string]any{
|
|
||||||
"data_path": h.mkDir(t, "data").Path,
|
|
||||||
"meta_path": h.mkDir(t, "meta").Path,
|
|
||||||
"capacity": 1,
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
allocsJSON, err := json.Marshal(allocs)
|
c.VPN.Tun.Device = newTunDevice() // TODO is this necessary??
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return mustParseNetworkConfigf(`
|
c.Storage.Allocations = make(
|
||||||
vpn:
|
[]daecommon.ConfigStorageAllocation, opts.numStorageAllocs,
|
||||||
public_addr: %q
|
)
|
||||||
tun:
|
for i := range c.Storage.Allocations {
|
||||||
device: %q
|
c.Storage.Allocations[i] = daecommon.ConfigStorageAllocation{
|
||||||
storage:
|
DataPath: h.mkDir(t, "data").Path,
|
||||||
allocations: %s
|
MetaPath: h.mkDir(t, "meta").Path,
|
||||||
`,
|
Capacity: 1,
|
||||||
publicAddr,
|
}
|
||||||
newTunDevice(),
|
}
|
||||||
allocsJSON,
|
})
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type createNetworkOpts struct {
|
type createNetworkOpts struct {
|
||||||
@ -203,7 +180,7 @@ func (h *integrationHarness) createNetwork(
|
|||||||
t *testing.T,
|
t *testing.T,
|
||||||
hostNameStr string,
|
hostNameStr string,
|
||||||
opts *createNetworkOpts,
|
opts *createNetworkOpts,
|
||||||
) integrationHarnessNetwork {
|
) *integrationHarnessNetwork {
|
||||||
t.Logf("Creating as %q", hostNameStr)
|
t.Logf("Creating as %q", hostNameStr)
|
||||||
opts = opts.withDefaults()
|
opts = opts.withDefaults()
|
||||||
|
|
||||||
@ -222,6 +199,7 @@ func (h *integrationHarness) createNetwork(
|
|||||||
|
|
||||||
networkOpts = &Opts{
|
networkOpts = &Opts{
|
||||||
GarageAdminToken: "admin_token",
|
GarageAdminToken: "admin_token",
|
||||||
|
Config: &networkConfig,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -229,7 +207,6 @@ func (h *integrationHarness) createNetwork(
|
|||||||
h.ctx,
|
h.ctx,
|
||||||
logger,
|
logger,
|
||||||
getEnvBinDirPath(),
|
getEnvBinDirPath(),
|
||||||
networkConfig,
|
|
||||||
stateDir,
|
stateDir,
|
||||||
runtimeDir,
|
runtimeDir,
|
||||||
opts.creationParams,
|
opts.creationParams,
|
||||||
@ -241,16 +218,7 @@ func (h *integrationHarness) createNetwork(
|
|||||||
t.Fatalf("creating Network: %v", err)
|
t.Fatalf("creating Network: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !opts.manualShutdown {
|
nh := &integrationHarnessNetwork{
|
||||||
t.Cleanup(func() {
|
|
||||||
t.Logf("Shutting down Network %q", hostNameStr)
|
|
||||||
if err := network.Shutdown(); err != nil {
|
|
||||||
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return integrationHarnessNetwork{
|
|
||||||
network,
|
network,
|
||||||
h.ctx,
|
h.ctx,
|
||||||
logger,
|
logger,
|
||||||
@ -259,6 +227,17 @@ func (h *integrationHarness) createNetwork(
|
|||||||
runtimeDir,
|
runtimeDir,
|
||||||
networkOpts,
|
networkOpts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !opts.manualShutdown {
|
||||||
|
t.Cleanup(func() {
|
||||||
|
t.Logf("Shutting down Network %q", hostNameStr)
|
||||||
|
if err := nh.Shutdown(); err != nil {
|
||||||
|
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nh
|
||||||
}
|
}
|
||||||
|
|
||||||
type joinNetworkOpts struct {
|
type joinNetworkOpts struct {
|
||||||
@ -279,10 +258,10 @@ func (o *joinNetworkOpts) withDefaults() *joinNetworkOpts {
|
|||||||
|
|
||||||
func (h *integrationHarness) joinNetwork(
|
func (h *integrationHarness) joinNetwork(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
network integrationHarnessNetwork,
|
network *integrationHarnessNetwork,
|
||||||
hostNameStr string,
|
hostNameStr string,
|
||||||
opts *joinNetworkOpts,
|
opts *joinNetworkOpts,
|
||||||
) integrationHarnessNetwork {
|
) *integrationHarnessNetwork {
|
||||||
opts = opts.withDefaults()
|
opts = opts.withDefaults()
|
||||||
hostName := nebula.HostName(hostNameStr)
|
hostName := nebula.HostName(hostNameStr)
|
||||||
|
|
||||||
@ -301,6 +280,7 @@ func (h *integrationHarness) joinNetwork(
|
|||||||
runtimeDir = h.mkDir(t, "runtime")
|
runtimeDir = h.mkDir(t, "runtime")
|
||||||
networkOpts = &Opts{
|
networkOpts = &Opts{
|
||||||
GarageAdminToken: "admin_token",
|
GarageAdminToken: "admin_token",
|
||||||
|
Config: &networkConfig,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -309,7 +289,6 @@ func (h *integrationHarness) joinNetwork(
|
|||||||
h.ctx,
|
h.ctx,
|
||||||
logger,
|
logger,
|
||||||
getEnvBinDirPath(),
|
getEnvBinDirPath(),
|
||||||
networkConfig,
|
|
||||||
joiningBootstrap,
|
joiningBootstrap,
|
||||||
stateDir,
|
stateDir,
|
||||||
runtimeDir,
|
runtimeDir,
|
||||||
@ -319,16 +298,7 @@ func (h *integrationHarness) joinNetwork(
|
|||||||
t.Fatalf("joining network: %v", err)
|
t.Fatalf("joining network: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !opts.manualShutdown {
|
nh := &integrationHarnessNetwork{
|
||||||
t.Cleanup(func() {
|
|
||||||
t.Logf("Shutting down Network %q", hostNameStr)
|
|
||||||
if err := joinedNetwork.Shutdown(); err != nil {
|
|
||||||
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return integrationHarnessNetwork{
|
|
||||||
joinedNetwork,
|
joinedNetwork,
|
||||||
h.ctx,
|
h.ctx,
|
||||||
logger,
|
logger,
|
||||||
@ -337,6 +307,34 @@ func (h *integrationHarness) joinNetwork(
|
|||||||
runtimeDir,
|
runtimeDir,
|
||||||
networkOpts,
|
networkOpts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !opts.manualShutdown {
|
||||||
|
t.Cleanup(func() {
|
||||||
|
t.Logf("Shutting down Network %q", hostNameStr)
|
||||||
|
if err := nh.Shutdown(); err != nil {
|
||||||
|
t.Logf("Shutting down Network %q failed: %v", hostNameStr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nh *integrationHarnessNetwork) restart(t *testing.T) {
|
||||||
|
t.Log("Shutting down network (restart)")
|
||||||
|
require.NoError(t, nh.Network.Shutdown())
|
||||||
|
|
||||||
|
t.Log("Loading network (restart)")
|
||||||
|
var err error
|
||||||
|
nh.Network, err = load(
|
||||||
|
nh.ctx,
|
||||||
|
nh.logger,
|
||||||
|
getEnvBinDirPath(),
|
||||||
|
nh.stateDir,
|
||||||
|
nh.runtimeDir,
|
||||||
|
nh.opts,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nh *integrationHarnessNetwork) getConfig(t *testing.T) daecommon.NetworkConfig {
|
func (nh *integrationHarnessNetwork) getConfig(t *testing.T) daecommon.NetworkConfig {
|
||||||
|
@ -3,6 +3,7 @@ package daemon
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"isle/bootstrap"
|
"isle/bootstrap"
|
||||||
|
"isle/daemon/daecommon"
|
||||||
"isle/daemon/jsonrpc2"
|
"isle/daemon/jsonrpc2"
|
||||||
"isle/daemon/network"
|
"isle/daemon/network"
|
||||||
"isle/nebula"
|
"isle/nebula"
|
||||||
@ -25,13 +26,23 @@ type RPC interface {
|
|||||||
|
|
||||||
GetNetworks(context.Context) ([]bootstrap.CreationParams, error)
|
GetNetworks(context.Context) ([]bootstrap.CreationParams, error)
|
||||||
|
|
||||||
|
// SetConfig extends the [network.RPC] method of the same name such that
|
||||||
|
// [ErrUserManagedNetworkConfig] is returned if the picked network is
|
||||||
|
// configured as part of the [daecommon.Config] which the Daemon was
|
||||||
|
// initialized with.
|
||||||
|
//
|
||||||
|
// See the `network.RPC` documentation in this interface for more usage
|
||||||
|
// details.
|
||||||
|
SetConfig(context.Context, daecommon.NetworkConfig) error
|
||||||
|
|
||||||
// All network.RPC methods are automatically implemented by Daemon using the
|
// All network.RPC methods are automatically implemented by Daemon using the
|
||||||
// currently joined network. If no network is joined then any call to these
|
// currently joined network. If no network is joined then any call to these
|
||||||
// methods will return ErrNoNetwork.
|
// methods will return ErrNoNetwork.
|
||||||
//
|
//
|
||||||
// All calls to these methods must be accompanied with a context produced by
|
// If more than one Network is joined then all calls to these methods must
|
||||||
// WithNetwork, in order to choose the network. These methods may return
|
// be accompanied with a context produced by WithNetwork, in order to choose
|
||||||
// these errors, in addition to those documented on the individual methods:
|
// the network. These methods may return these errors, in addition to those
|
||||||
|
// documented on the individual methods:
|
||||||
//
|
//
|
||||||
// Errors:
|
// Errors:
|
||||||
// - ErrNoNetwork
|
// - ErrNoNetwork
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MarkIntegrationTest marks a test as being an integration test. It will be
|
// MarkIntegrationTest marks a test as being an integration test. It will be
|
||||||
@ -31,3 +32,10 @@ func NewTestLogger(t *testing.T) *mlog.Logger {
|
|||||||
MaxLevel: level.Int(),
|
MaxLevel: level.Int(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockArg returns a value which can be used as a [mock.Call] argument, and
|
||||||
|
// which will match any value of type T. If T is an interface then also values
|
||||||
|
// implementing that interface will be matched.
|
||||||
|
func MockArg[T any]() any {
|
||||||
|
return mock.MatchedBy(func(T) bool { return true })
|
||||||
|
}
|
||||||
|
29
go/yamlutil/testutils.go
Normal file
29
go/yamlutil/testutils.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package yamlutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReplacePrefixTabs will replace the leading tabs of each line with two spaces.
|
||||||
|
// Tabs are not a valid indent in yaml (wtf), but they are convenient to use in
|
||||||
|
// go when using multi-line strings.
|
||||||
|
//
|
||||||
|
// ReplacePrefixTabs should only be used within tests.
|
||||||
|
func ReplacePrefixTabs(str string) string {
|
||||||
|
lines := strings.Split(str, "\n")
|
||||||
|
for i := range lines {
|
||||||
|
var n int
|
||||||
|
for _, r := range lines[i] {
|
||||||
|
if r == '\t' {
|
||||||
|
n++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
spaces := strings.Repeat(" ", n)
|
||||||
|
lines[i] = spaces + lines[i][n:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user