From 723642e13bca60614c62a6bac60bc418420633bc Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Fri, 6 Dec 2024 15:42:37 +0100 Subject: [PATCH] Remove 'nebula show' subcmd and add that data to 'network list' --- go/cmd/entrypoint/nebula.go | 61 --------- go/cmd/entrypoint/network.go | 75 ++++++++++- go/cmd/entrypoint/network_test.go | 217 ++++++++++++++++++++++++++++++ go/daemon/ctx.go | 11 ++ 4 files changed, 302 insertions(+), 62 deletions(-) create mode 100644 go/cmd/entrypoint/network_test.go diff --git a/go/cmd/entrypoint/nebula.go b/go/cmd/entrypoint/nebula.go index f25a929..9354ca1 100644 --- a/go/cmd/entrypoint/nebula.go +++ b/go/cmd/entrypoint/nebula.go @@ -63,73 +63,12 @@ var subCmdNebulaCreateCert = subCmd{ }, } -var subCmdNebulaShow = subCmd{ - name: "show", - descr: "Writes nebula network information to stdout in JSON format", - do: doWithOutput(func(ctx subCmdCtx) (any, error) { - ctx, err := ctx.withParsedFlags() - if err != nil { - return nil, fmt.Errorf("parsing flags: %w", err) - } - - currBoostrap, err := ctx.getDaemonRPC().GetBootstrap(ctx) - if err != nil { - return nil, fmt.Errorf("calling GetBootstrap: %w", err) - } - - var ( - hosts = currBoostrap.HostsOrdered() - caPublicCreds = currBoostrap.CAPublicCredentials - ) - - caCert := caPublicCreds.Cert - caCertDetails := caCert.Unwrap().Details - - if len(caCertDetails.Subnets) != 1 { - return nil, fmt.Errorf( - "malformed ca.crt, contains unexpected subnets %#v", - caCertDetails.Subnets, - ) - } - - subnet := caCertDetails.Subnets[0] - - type outLighthouse struct { - PublicAddr string - IP string - } - - out := struct { - CACert nebula.Certificate - SubnetCIDR string - Lighthouses []outLighthouse - }{ - CACert: caCert, - SubnetCIDR: subnet.String(), - } - - for _, h := range hosts { - if h.Nebula.PublicAddr == "" { - continue - } - - out.Lighthouses = append(out.Lighthouses, outLighthouse{ - PublicAddr: h.Nebula.PublicAddr, - IP: h.IP().String(), - }) - } - - return out, nil - }), -} - var subCmdNebula = subCmd{ name: "nebula", descr: "Sub-commands related to the nebula VPN", do: func(ctx subCmdCtx) error { return ctx.doSubCmd( subCmdNebulaCreateCert, - subCmdNebulaShow, ) }, } diff --git a/go/cmd/entrypoint/network.go b/go/cmd/entrypoint/network.go index 603b3d8..742c705 100644 --- a/go/cmd/entrypoint/network.go +++ b/go/cmd/entrypoint/network.go @@ -1,10 +1,15 @@ package main import ( + "cmp" "errors" "fmt" + "isle/bootstrap" + "isle/daemon" "isle/daemon/network" "isle/jsonutil" + "isle/nebula" + "slices" ) var subCmdNetworkCreate = subCmd{ @@ -102,7 +107,75 @@ var subCmdNetworkList = subCmd{ return nil, fmt.Errorf("parsing flags: %w", err) } - return ctx.getDaemonRPC().GetNetworks(ctx) + networkCreationParams, err := ctx.getDaemonRPC().GetNetworks(ctx) + if err != nil { + return nil, fmt.Errorf("calling GetNetworks: %w", err) + } + + type lighthouseView struct { + PublicAddr string `yaml:"public_addr,omitempty"` + IP string `yaml:"ip"` + } + + type networkView struct { + bootstrap.CreationParams `yaml:",inline"` + CACert nebula.Certificate `yaml:"ca_cert"` + SubnetCIDR string `yaml:"subnet_cidr"` + Lighthouses []lighthouseView `yaml:"lighthouses"` + } + + var ( + daemonRPC = ctx.getDaemonRPC() + networkViews = make([]networkView, len(networkCreationParams)) + ) + + for i, creationParams := range networkCreationParams { + ctx := daemon.WithNetwork(ctx, creationParams.ID) + + networkBootstrap, err := daemonRPC.GetBootstrap(ctx) + if err != nil { + return nil, fmt.Errorf( + "calling GetBootstrap with network:%+v: %w", + networkCreationParams, + err, + ) + } + + var ( + caCert = networkBootstrap.CAPublicCredentials.Cert + caCertDetails = caCert.Unwrap().Details + subnet = caCertDetails.Subnets[0] + + lighthouseViews []lighthouseView + ) + + for _, h := range networkBootstrap.HostsOrdered() { + if h.Nebula.PublicAddr == "" { + continue + } + + lighthouseViews = append(lighthouseViews, lighthouseView{ + PublicAddr: h.Nebula.PublicAddr, + IP: h.IP().String(), + }) + } + + networkViews[i] = networkView{ + CreationParams: creationParams, + CACert: caCert, + SubnetCIDR: subnet.String(), + Lighthouses: lighthouseViews, + } + } + + slices.SortFunc(networkViews, func(a, b networkView) int { + return cmp.Or( + cmp.Compare(a.Name, b.Name), + cmp.Compare(a.ID, b.ID), + ) + }) + + return networkViews, nil }), } diff --git a/go/cmd/entrypoint/network_test.go b/go/cmd/entrypoint/network_test.go new file mode 100644 index 0000000..a6e5931 --- /dev/null +++ b/go/cmd/entrypoint/network_test.go @@ -0,0 +1,217 @@ +package main + +import ( + "context" + "fmt" + "isle/bootstrap" + "isle/daemon" + "isle/nebula" + "isle/toolkit" + "net/netip" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNetworkList(t *testing.T) { + t.Parallel() + + type networkBase struct { + bootstrap.CreationParams + ipNet nebula.IPNet + caCreds nebula.CACredentials + caCertPEM string + } + + newNetworkBase := func(id, name, domain, ipNetStr string) networkBase { + var ipNet nebula.IPNet + require.NoError(t, ipNet.UnmarshalText([]byte(ipNetStr))) + + caCreds, err := nebula.NewCACredentials(domain, ipNet) + require.NoError(t, err) + + caCertPEM, err := caCreds.Public.Cert.MarshalText() + require.NoError(t, err) + + return networkBase{ + CreationParams: bootstrap.CreationParams{ + ID: id, + Name: name, + Domain: domain, + }, + ipNet: ipNet, + caCreds: caCreds, + caCertPEM: string(caCertPEM), + } + } + + var ( + networkBaseA = newNetworkBase("idA", "nameA", "a.com", "172.16.1.0/24") + networkBaseB = newNetworkBase("idB", "nameB", "b.com", "172.16.2.0/24") + ) + + type host struct { + ip string + publicAddr string + } + + type network struct { + networkBase + hosts []host + } + + tests := []struct { + name string + networks []network + want []map[string]any + }{ + { + name: "no networks", + want: []map[string]any{}, + }, + { + name: "single", + networks: []network{ + { + networkBase: networkBaseA, + hosts: []host{ + { + ip: "172.16.1.1", + publicAddr: "1.1.1.1:80", + }, + }, + }, + }, + want: []map[string]any{ + { + "id": "idA", + "name": "nameA", + "domain": "a.com", + "ca_cert": networkBaseA.caCertPEM, + "subnet_cidr": "172.16.1.0/24", + "lighthouses": []any{ + map[string]any{ + "ip": "172.16.1.1", + "public_addr": "1.1.1.1:80", + }, + }, + }, + }, + }, + { + name: "multiple", + networks: []network{ + { + networkBase: networkBaseB, + hosts: []host{ + { + ip: "172.16.2.1", + publicAddr: "2.2.2.2:80", + }, + { + ip: "172.16.2.2", + publicAddr: "3.3.3.3:80", + }, + { + ip: "172.16.2.3", + }, + }, + }, + { + networkBase: networkBaseA, + hosts: []host{ + { + ip: "172.16.1.1", + publicAddr: "1.1.1.1:80", + }, + }, + }, + }, + want: []map[string]any{ + { + "id": "idA", + "name": "nameA", + "domain": "a.com", + "ca_cert": networkBaseA.caCertPEM, + "subnet_cidr": "172.16.1.0/24", + "lighthouses": []any{ + map[string]any{ + "ip": "172.16.1.1", + "public_addr": "1.1.1.1:80", + }, + }, + }, + { + "id": "idB", + "name": "nameB", + "domain": "b.com", + "ca_cert": networkBaseB.caCertPEM, + "subnet_cidr": "172.16.2.0/24", + "lighthouses": []any{ + map[string]any{ + "ip": "172.16.2.1", + "public_addr": "2.2.2.2:80", + }, + map[string]any{ + "ip": "172.16.2.2", + "public_addr": "3.3.3.3:80", + }, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + h = newRunHarness(t) + creationParams = make([]bootstrap.CreationParams, len(test.networks)) + ) + + for i, testNetwork := range test.networks { + creationParams[i] = testNetwork.CreationParams + + hosts := map[nebula.HostName]bootstrap.Host{} + for _, testHost := range testNetwork.hosts { + var ( + hostName nebula.HostName + ip = netip.MustParseAddr(testHost.ip) + hostNameStr = fmt.Sprintf("host%d", len(hosts)) + ) + require.NoError( + t, hostName.UnmarshalText([]byte(hostNameStr)), + ) + + host, _, err := bootstrap.NewHost( + testNetwork.caCreds, hostName, ip, + ) + require.NoError(t, err) + + host.Nebula.PublicAddr = testHost.publicAddr + + hosts[hostName] = host + } + + h.daemonRPC. + On( + "GetBootstrap", + daemon.MockContextWithNetwork(testNetwork.ID), + ). + Return(bootstrap.Bootstrap{ + NetworkCreationParams: creationParams[i], + CAPublicCredentials: testNetwork.caCreds.Public, + Hosts: hosts, + }, nil). + Once() + } + + h.daemonRPC. + On("GetNetworks", toolkit.MockArg[context.Context]()). + Return(creationParams, nil). + Once() + + h.runAssertStdout(t, test.want, "network", "list") + }) + } +} diff --git a/go/daemon/ctx.go b/go/daemon/ctx.go index 5f001e6..d259f32 100644 --- a/go/daemon/ctx.go +++ b/go/daemon/ctx.go @@ -3,6 +3,8 @@ package daemon import ( "context" "isle/daemon/jsonrpc2" + + "github.com/stretchr/testify/mock" ) const metaKeyNetworkSearchStr = "daemon.networkSearchStr" @@ -18,3 +20,12 @@ func getNetworkSearchStr(ctx context.Context) string { v, _ := jsonrpc2.GetMeta(ctx)[metaKeyNetworkSearchStr].(string) return v } + +// MockContextWithNetwork returns a value which can be used with the +// tesstify/mock package to match a context which has a search string added to +// it by WithNetwork. +func MockContextWithNetwork(searchStr string) any { + return mock.MatchedBy(func(ctx context.Context) bool { + return getNetworkSearchStr(ctx) == searchStr + }) +}