From c21b3e0c33cf2284b906ef51c688bbbc21c0e88f Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Thu, 12 Dec 2024 21:05:36 +0100 Subject: [PATCH] Test daemon config validation, fix a bug which came out of it --- go/daemon/daecommon/config.go | 30 +++++++-------- go/daemon/daecommon/config_test.go | 59 ++++++++++++++++++++++-------- 2 files changed, 57 insertions(+), 32 deletions(-) diff --git a/go/daemon/daecommon/config.go b/go/daemon/daecommon/config.go index e6e6672..d7db927 100644 --- a/go/daemon/daecommon/config.go +++ b/go/daemon/daecommon/config.go @@ -203,25 +203,23 @@ func (c Config) Validate() error { nebulaPorts := map[string]string{} for id, network := range c.Networks { - if network.VPN.PublicAddr == "" { - continue - } + if network.VPN.PublicAddr != "" { + _, port, err := net.SplitHostPort(network.VPN.PublicAddr) + if err != nil { + return fmt.Errorf( + "invalid vpn.public_addr %q: %w", network.VPN.PublicAddr, err, + ) + } else if otherID, ok := nebulaPorts[port]; ok { + return fmt.Errorf( + "two networks with the same vpn.public_addr port: %q and %q", + id, + otherID, + ) + } - _, port, err := net.SplitHostPort(network.VPN.PublicAddr) - if err != nil { - return fmt.Errorf( - "invalid vpn.public_addr %q: %w", network.VPN.PublicAddr, err, - ) - } else if otherID, ok := nebulaPorts[port]; ok { - return fmt.Errorf( - "two networks with the same vpn.public_addr: %q and %q", - id, - otherID, - ) + nebulaPorts[port] = id } - nebulaPorts[port] = id - if err := network.Validate(); err != nil { return fmt.Errorf("invalid config for network %q: %w", id, err) } diff --git a/go/daemon/daecommon/config_test.go b/go/daemon/daecommon/config_test.go index 4ccd1b5..2bd0f51 100644 --- a/go/daemon/daecommon/config_test.go +++ b/go/daemon/daecommon/config_test.go @@ -8,19 +8,24 @@ import ( "gopkg.in/yaml.v3" ) -func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation +func TestConfig_UnmarshalYAML(t *testing.T) { tests := []struct { - label string - str string - config Config + label string + str string + wantConfig Config + wantErr string }{ - {"empty", ``, Config{}}, { - "DEPRECATED single global network", - ` + label: "empty", + str: ``, + wantConfig: Config{}, + }, + { + label: "DEPRECATED single global network", + str: ` {"dns":{"resolvers":["a"]}} `, - Config{ + wantConfig: Config{ Networks: map[string]NetworkConfig{ DeprecatedNetworkID: NewNetworkConfig(func(c *NetworkConfig) { c.DNS.Resolvers = []string{"a"} @@ -29,12 +34,12 @@ func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation }, }, { - "single network", - ` + label: "single network", + str: ` networks: foo: {"dns":{"resolvers":["a"]}} `, - Config{ + wantConfig: Config{ Networks: map[string]NetworkConfig{ "foo": NewNetworkConfig(func(c *NetworkConfig) { c.DNS.Resolvers = []string{"a"} @@ -43,13 +48,13 @@ func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation }, }, { - "multiple networks", - ` + label: "multiple networks", + str: ` networks: foo: {"dns":{"resolvers":["a"]}} bar: {} `, - Config{ + wantConfig: Config{ Networks: map[string]NetworkConfig{ "foo": NewNetworkConfig(func(c *NetworkConfig) { c.DNS.Resolvers = []string{"a"} @@ -58,6 +63,24 @@ func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation }, }, }, + { + label: "err/shared vpn.public_addr port", + str: ` + networks: + foo: {"vpn":{"public_addr":"1.1.1.1:4001"}} + bar: {"vpn":{"public_addr":"1.1.1.1:4000"}} + baz: {"vpn":{"public_addr":"2.2.2.2:4000"}} + `, + wantErr: `two networks with the same vpn.public_addr port: "baz" and "bar"`, + }, + { + label: "err/invalid firewall", + str: ` + networks: + foo: {"vpn":{"firewall":{"inbound":[{"host":"f","port":"no"}]}}} + `, + wantErr: "port was not a number", + }, } for _, test := range tests { @@ -65,8 +88,12 @@ func TestConfig_UnmarshalYAML(t *testing.T) { // TODO test validation test.str = yamlutil.ReplacePrefixTabs(test.str) var config Config err := yaml.Unmarshal([]byte(test.str), &config) - assert.NoError(t, err) - assert.Equal(t, test.config, config) + if test.wantErr != "" { + assert.ErrorContains(t, err, test.wantErr) + } else { + assert.NoError(t, err) + assert.Equal(t, test.wantConfig, config) + } }) } }