From df5ece950ad909edf4fbc9203cffcef673438074 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sat, 7 Dec 2024 20:39:13 +0100 Subject: [PATCH] Implement GetBootstrap to replace other redundant methods --- go/bootstrap/bootstrap.go | 61 +++++----- go/bootstrap/hosts.go | 25 ++++ go/cmd/entrypoint/client.go | 9 -- go/cmd/entrypoint/host.go | 49 +++++--- go/cmd/entrypoint/host_test.go | 134 ++++++++++++++++++++++ go/cmd/entrypoint/main_test.go | 41 +++++-- go/cmd/entrypoint/nebula.go | 12 +- go/daemon/client.go | 27 ++--- go/daemon/daemon.go | 36 ++---- go/daemon/jsonrpc2/server_http.go | 9 +- go/daemon/network/network.go | 46 ++------ go/daemon/network/network_it_test.go | 13 +++ go/daemon/network/network_it_util_test.go | 10 +- go/daemon/network/network_mock.go | 86 +++++--------- go/daemon/rpc.go | 16 ++- go/daemon/rpc_mock.go | 86 +++++--------- 16 files changed, 371 insertions(+), 289 deletions(-) delete mode 100644 go/cmd/entrypoint/client.go create mode 100644 go/cmd/entrypoint/host_test.go diff --git a/go/bootstrap/bootstrap.go b/go/bootstrap/bootstrap.go index edbd38f..6c96fd2 100644 --- a/go/bootstrap/bootstrap.go +++ b/go/bootstrap/bootstrap.go @@ -3,7 +3,7 @@ package bootstrap import ( - "crypto/sha512" + "cmp" "encoding/json" "fmt" "isle/nebula" @@ -11,7 +11,7 @@ import ( "maps" "net/netip" "path/filepath" - "sort" + "slices" "strings" "dev.mediocregopher.com/mediocre-go-lib.git/mctx" @@ -110,33 +110,26 @@ func New( ) ( Bootstrap, error, ) { - hostPubCreds, hostPrivCreds, err := nebula.NewHostCredentials( - caCreds, name, ip, - ) + host, hostPrivCreds, err := NewHost(caCreds, name, ip) if err != nil { - return Bootstrap{}, fmt.Errorf("generating host credentials: %w", err) + return Bootstrap{}, fmt.Errorf("creating host: %w", err) } - assigned := HostAssigned{ - Name: name, - PublicCredentials: hostPubCreds, - } - - signedAssigned, err := nebula.Sign(assigned, caCreds.SigningPrivateKey) + signedAssigned, err := nebula.Sign( + host.HostAssigned, caCreds.SigningPrivateKey, + ) if err != nil { return Bootstrap{}, fmt.Errorf("signing assigned fields: %w", err) } existingHosts = maps.Clone(existingHosts) - existingHosts[name] = Host{ - HostAssigned: assigned, - } + existingHosts[name] = host return Bootstrap{ NetworkCreationParams: adminCreationParams, CAPublicCredentials: caCreds.Public, PrivateCredentials: hostPrivCreds, - HostAssigned: assigned, + HostAssigned: host.HostAssigned, SignedHostAssigned: signedAssigned, Hosts: existingHosts, }, nil @@ -150,14 +143,18 @@ func (b *Bootstrap) UnmarshalJSON(data []byte) error { err := json.Unmarshal(data, (*inner)(b)) if err != nil { - return err + return fmt.Errorf("json unmarshaling: %w", err) } - b.HostAssigned, err = b.SignedHostAssigned.Unwrap( - b.CAPublicCredentials.SigningKey, - ) - if err != nil { - return fmt.Errorf("unwrapping HostAssigned: %w", err) + // Generally this will be filled, but during unit tests we sometimes leave + // it empty for convenience. + if b.SignedHostAssigned != nil { + b.HostAssigned, err = b.SignedHostAssigned.Unwrap( + b.CAPublicCredentials.SigningKey, + ) + if err != nil { + return fmt.Errorf("unwrapping HostAssigned: %w", err) + } } return nil @@ -174,22 +171,16 @@ func (b Bootstrap) ThisHost() Host { return host } -// Hash returns a deterministic hash of the given hosts map. -func HostsHash(hostsMap map[nebula.HostName]Host) ([]byte, error) { - - hosts := make([]Host, 0, len(hostsMap)) - for _, host := range hostsMap { +// HostsOrdered returns the Hosts as a slice in a deterministic order. +func (b Bootstrap) HostsOrdered() []Host { + hosts := make([]Host, 0, len(b.Hosts)) + for _, host := range b.Hosts { hosts = append(hosts, host) } - sort.Slice(hosts, func(i, j int) bool { - return hosts[i].Name < hosts[j].Name + slices.SortFunc(hosts, func(a, b Host) int { + return cmp.Compare(a.Name, b.Name) }) - h := sha512.New() - if err := json.NewEncoder(h).Encode(hosts); err != nil { - return nil, err - } - - return h.Sum(nil), nil + return hosts } diff --git a/go/bootstrap/hosts.go b/go/bootstrap/hosts.go index b4880da..f0ded15 100644 --- a/go/bootstrap/hosts.go +++ b/go/bootstrap/hosts.go @@ -73,6 +73,31 @@ type Host struct { HostConfigured } +// NewHost creates a Host instance using the given assigned fields, along with +// the HostPrivateCredentials which its PublicCredentials field. +func NewHost( + caCreds nebula.CACredentials, name nebula.HostName, ip netip.Addr, +) ( + host Host, hostPrivCreds nebula.HostPrivateCredentials, err error, +) { + hostPubCreds, hostPrivCreds, err := nebula.NewHostCredentials( + caCreds, name, ip, + ) + if err != nil { + err = fmt.Errorf("generating host credentials: %w", err) + return + } + + host = Host{ + HostAssigned: HostAssigned{ + Name: name, + PublicCredentials: hostPubCreds, + }, + } + + return +} + // IP returns the IP address encoded in the Host's nebula certificate, or panics // if there is an error. // diff --git a/go/cmd/entrypoint/client.go b/go/cmd/entrypoint/client.go deleted file mode 100644 index 8d65b79..0000000 --- a/go/cmd/entrypoint/client.go +++ /dev/null @@ -1,9 +0,0 @@ -package main - -import ( - "isle/bootstrap" -) - -func (ctx subCmdCtx) getHosts() ([]bootstrap.Host, error) { - return ctx.getDaemonRPC().GetHosts(ctx) -} diff --git a/go/cmd/entrypoint/host.go b/go/cmd/entrypoint/host.go index 650a4c5..3291498 100644 --- a/go/cmd/entrypoint/host.go +++ b/go/cmd/entrypoint/host.go @@ -4,10 +4,8 @@ import ( "encoding/json" "errors" "fmt" - "isle/bootstrap" "isle/daemon/network" "os" - "sort" ) var subCmdHostCreate = subCmd{ @@ -65,35 +63,48 @@ var subCmdHostList = subCmd{ return nil, fmt.Errorf("parsing flags: %w", err) } - hostsRes, err := ctx.getHosts() + currBoostrap, err := ctx.getDaemonRPC().GetBootstrap(ctx) if err != nil { - return nil, fmt.Errorf("calling GetHosts: %w", err) + return nil, fmt.Errorf("calling GetBootstrap: %w", err) } - type host struct { - Name string + hosts := currBoostrap.HostsOrdered() + + type storageInstanceView struct { + ID string `yaml:"id"` + RPCPort int `yaml:"rpc_port"` + S3APIPort int `yaml:"s3_api_port"` + } + + type hostView struct { + Name string `yaml:"name"` VPN struct { - IP string + IP string `yaml:"ip"` + PublicAddr string `yaml:"public_addr,omitempty"` } - Storage bootstrap.GarageHost `json:",omitempty"` + Storage struct { + Instances []storageInstanceView `yaml:"instances"` + } `yaml:",omitempty"` } - hosts := make([]host, 0, len(hostsRes)) - for _, h := range hostsRes { - - host := host{ - Name: string(h.Name), - Storage: h.Garage, + hostViews := make([]hostView, len(hosts)) + for i, host := range hosts { + storageInstanceViews := make([]storageInstanceView, len(host.Garage.Instances)) + for i := range host.Garage.Instances { + storageInstanceViews[i] = storageInstanceView(host.Garage.Instances[i]) } - host.VPN.IP = h.IP().String() + hostView := hostView{ + Name: string(host.Name), + } - hosts = append(hosts, host) + hostView.VPN.IP = host.IP().String() + hostView.VPN.PublicAddr = host.Nebula.PublicAddr + hostView.Storage.Instances = storageInstanceViews + hostViews[i] = hostView } - sort.Slice(hosts, func(i, j int) bool { return hosts[i].Name < hosts[j].Name }) - - return hosts, nil + return hostViews, nil }), } diff --git a/go/cmd/entrypoint/host_test.go b/go/cmd/entrypoint/host_test.go new file mode 100644 index 0000000..8274a51 --- /dev/null +++ b/go/cmd/entrypoint/host_test.go @@ -0,0 +1,134 @@ +package main + +import ( + "context" + "isle/bootstrap" + "isle/nebula" + "isle/toolkit" + "net/netip" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHostList(t *testing.T) { + t.Parallel() + + var ipNet nebula.IPNet + require.NoError(t, ipNet.UnmarshalText([]byte("172.16.0.0/16"))) + + caCreds, err := nebula.NewCACredentials("test.com", ipNet) + require.NoError(t, err) + + type host struct { + name string + ip string + publicAddr string + storageInstances []bootstrap.GarageHostInstance + } + + tests := []struct { + name string + hosts []host + want any + }{ + { + name: "no hosts", + want: []any{}, + }, + { + name: "single", + hosts: []host{ + { + name: "a", + ip: "172.16.0.1", + }, + }, + want: []map[string]any{ + { + "name": "a", + "vpn": map[string]any{ + "ip": "172.16.0.1", + }, + }, + }, + }, + { + name: "multiple", + hosts: []host{ + { + name: "a", + ip: "172.16.0.1", + }, + { + name: "b", + ip: "172.16.0.2", + publicAddr: "1.1.1.1:80", + storageInstances: []bootstrap.GarageHostInstance{{ + ID: "storageInstanceID", + RPCPort: 9000, + S3APIPort: 9001, + }}, + }, + }, + want: []map[string]any{ + { + "name": "a", + "vpn": map[string]any{ + "ip": "172.16.0.1", + }, + }, + { + "name": "b", + "vpn": map[string]any{ + "ip": "172.16.0.2", + "public_addr": "1.1.1.1:80", + }, + "storage": map[string]any{ + "instances": []any{ + map[string]any{ + "id": "storageInstanceID", + "rpc_port": 9000, + "s3_api_port": 9001, + }, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + h = newRunHarness(t) + hosts = map[nebula.HostName]bootstrap.Host{} + ) + + for _, testHost := range test.hosts { + var ( + hostName nebula.HostName + ip = netip.MustParseAddr(testHost.ip) + ) + require.NoError( + t, hostName.UnmarshalText([]byte(testHost.name)), + ) + + host, _, err := bootstrap.NewHost(caCreds, hostName, ip) + require.NoError(t, err) + + host.Nebula.PublicAddr = testHost.publicAddr + host.Garage.Instances = testHost.storageInstances + + hosts[hostName] = host + } + + h.daemonRPC. + On("GetBootstrap", toolkit.MockArg[context.Context]()). + Return(bootstrap.Bootstrap{Hosts: hosts}, nil). + Once() + + h.runAssertStdout(t, test.want, "host", "list") + }) + } +} diff --git a/go/cmd/entrypoint/main_test.go b/go/cmd/entrypoint/main_test.go index 2954537..34540ac 100644 --- a/go/cmd/entrypoint/main_test.go +++ b/go/cmd/entrypoint/main_test.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "isle/daemon" + "isle/daemon/jsonrpc2" "isle/toolkit" + "net/http/httptest" "reflect" "testing" @@ -14,29 +16,46 @@ import ( ) type runHarness struct { - ctx context.Context - logger *mlog.Logger - daemonRPC *daemon.MockRPC - stdout *bytes.Buffer + ctx context.Context + logger *mlog.Logger + stdout *bytes.Buffer + daemonRPC *daemon.MockRPC + daemonRPCServer *httptest.Server } func newRunHarness(t *testing.T) *runHarness { t.Parallel() var ( - ctx = context.Background() - logger = toolkit.NewTestLogger(t) - daemonRPC = daemon.NewMockRPC(t) - stdout = new(bytes.Buffer) + ctx = context.Background() + logger = toolkit.NewTestLogger(t) + stdout = new(bytes.Buffer) + + daemonRPC = daemon.NewMockRPC(t) + daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler( + logger.WithNamespace("rpc"), daemonRPC, + )) + daemonRPCServer = httptest.NewServer(daemonRPCHandler) ) - return &runHarness{ctx, logger, daemonRPC, stdout} + t.Cleanup(daemonRPCServer.Close) + + return &runHarness{ctx, logger, stdout, daemonRPC, daemonRPCServer} } -func (h *runHarness) run(_ *testing.T, args ...string) error { +func (h *runHarness) run(t *testing.T, args ...string) error { + httpClient := toolkit.NewHTTPClient(h.logger.WithNamespace("http")) + t.Cleanup(func() { + assert.NoError(t, httpClient.Close()) + }) + + daemonRPCClient := daemon.RPCFromClient( + jsonrpc2.NewHTTPClient(httpClient, h.daemonRPCServer.URL), + ) + return doRootCmd(h.ctx, h.logger, &subCmdCtxOpts{ args: args, - daemonRPC: h.daemonRPC, + daemonRPC: daemonRPCClient, stdout: h.stdout, }) } diff --git a/go/cmd/entrypoint/nebula.go b/go/cmd/entrypoint/nebula.go index f6cea20..f25a929 100644 --- a/go/cmd/entrypoint/nebula.go +++ b/go/cmd/entrypoint/nebula.go @@ -72,15 +72,15 @@ var subCmdNebulaShow = subCmd{ return nil, fmt.Errorf("parsing flags: %w", err) } - hosts, err := ctx.getHosts() + currBoostrap, err := ctx.getDaemonRPC().GetBootstrap(ctx) if err != nil { - return nil, fmt.Errorf("getting hosts: %w", err) + return nil, fmt.Errorf("calling GetBootstrap: %w", err) } - caPublicCreds, err := ctx.getDaemonRPC().GetNebulaCAPublicCredentials(ctx) - if err != nil { - return nil, fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err) - } + var ( + hosts = currBoostrap.HostsOrdered() + caPublicCreds = currBoostrap.CAPublicCredentials + ) caCert := caPublicCreds.Cert caCertDetails := caCert.Unwrap().Details diff --git a/go/daemon/client.go b/go/daemon/client.go index a2e2ea3..c780a11 100644 --- a/go/daemon/client.go +++ b/go/daemon/client.go @@ -60,6 +60,15 @@ func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain strin return } +func (c *rpcClient) GetBootstrap(ctx context.Context) (b1 bootstrap.Bootstrap, err error) { + err = c.client.Call( + ctx, + &b1, + "GetBootstrap", + ) + return +} + func (c *rpcClient) GetConfig(ctx context.Context) (n1 daecommon.NetworkConfig, err error) { err = c.client.Call( ctx, @@ -78,24 +87,6 @@ func (c *rpcClient) GetGarageClientParams(ctx context.Context) (g1 network.Garag return } -func (c *rpcClient) GetHosts(ctx context.Context) (ha1 []bootstrap.Host, err error) { - err = c.client.Call( - ctx, - &ha1, - "GetHosts", - ) - return -} - -func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context) (c2 nebula.CAPublicCredentials, err error) { - err = c.client.Call( - ctx, - &c2, - "GetNebulaCAPublicCredentials", - ) - return -} - func (c *rpcClient) GetNetworks(ctx context.Context) (ca1 []bootstrap.CreationParams, err error) { err = c.client.Call( ctx, diff --git a/go/daemon/daemon.go b/go/daemon/daemon.go index 970faa9..0bb8a04 100644 --- a/go/daemon/daemon.go +++ b/go/daemon/daemon.go @@ -242,13 +242,21 @@ func (d *Daemon) GetNetworks( return res, nil } -// GetHost implements the method for the network.RPC interface. -func (d *Daemon) GetHosts(ctx context.Context) ([]bootstrap.Host, error) { +// GetBootstrap implements the method for the network.RPC interface. +func (d *Daemon) GetBootstrap( + ctx context.Context, +) ( + bootstrap.Bootstrap, error, +) { return withNetwork( ctx, d, - func(ctx context.Context, n joinedNetwork) ([]bootstrap.Host, error) { - return n.GetHosts(ctx) + func( + ctx context.Context, n joinedNetwork, + ) ( + bootstrap.Bootstrap, error, + ) { + return n.GetBootstrap(ctx) }, ) } @@ -272,26 +280,6 @@ func (d *Daemon) GetGarageClientParams( ) } -// GetNebulaCAPublicCredentials implements the method for the network.RPC -// interface. -func (d *Daemon) GetNebulaCAPublicCredentials( - ctx context.Context, -) ( - nebula.CAPublicCredentials, error, -) { - return withNetwork( - ctx, - d, - func( - ctx context.Context, n joinedNetwork, - ) ( - nebula.CAPublicCredentials, error, - ) { - return n.GetNebulaCAPublicCredentials(ctx) - }, - ) -} - // RemoveHost implements the method for the network.RPC interface. func (d *Daemon) RemoveHost(ctx context.Context, hostName nebula.HostName) error { _, err := withNetwork( diff --git a/go/daemon/jsonrpc2/server_http.go b/go/daemon/jsonrpc2/server_http.go index d179bd0..baf2d42 100644 --- a/go/daemon/jsonrpc2/server_http.go +++ b/go/daemon/jsonrpc2/server_http.go @@ -3,6 +3,7 @@ package jsonrpc2 import ( "encoding/json" "errors" + "fmt" "net/http" ) @@ -50,7 +51,9 @@ func (h httpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { switch rpcErr.Code { case 0: // no error - _ = encodeResponse(enc, req.ID, res) + if err := encodeResponse(enc, req.ID, res); err != nil { + panic(fmt.Errorf("encoding response %#v: %w", res, err)) + } return case errCodeMethodNotFound: @@ -63,5 +66,7 @@ func (h httpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusBadRequest) } - _ = encodeErrorResponse(enc, req.ID, rpcErr) + if err := encodeErrorResponse(enc, req.ID, rpcErr); err != nil { + panic(fmt.Errorf("encoding error %+v: %w", rpcErr, err)) + } } diff --git a/go/daemon/network/network.go b/go/daemon/network/network.go index fc96f62..75c7ec0 100644 --- a/go/daemon/network/network.go +++ b/go/daemon/network/network.go @@ -6,7 +6,6 @@ package network import ( "bytes" - "cmp" "context" "crypto/rand" "encoding/json" @@ -21,13 +20,11 @@ import ( "isle/secrets" "isle/toolkit" "net/netip" - "slices" "sync" "time" "dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" - "golang.org/x/exp/maps" ) // GarageClientParams contains all the data needed to instantiate garage @@ -75,21 +72,14 @@ type JoiningBootstrap struct { // RPC defines the methods related to a single network which are available over // the daemon's RPC interface. type RPC interface { - // GetHosts returns all hosts known to the network, sorted by their name. - GetHosts(context.Context) ([]bootstrap.Host, error) + // GetBootstrap returns the currently active Bootstrap for the Network. The + // PrivateCredentials field will be zero'd out before being returned. + GetBootstrap(context.Context) (bootstrap.Bootstrap, error) // GetGarageClientParams returns a GarageClientParams for the current // network state. GetGarageClientParams(context.Context) (GarageClientParams, error) - // GetNebulaCAPublicCredentials returns the CAPublicCredentials for the - // network. - GetNebulaCAPublicCredentials( - context.Context, - ) ( - nebula.CAPublicCredentials, error, - ) - // RemoveHost removes the host of the given name from the network. RemoveHost(ctx context.Context, hostName nebula.HostName) error @@ -750,17 +740,18 @@ func (n *network) getBootstrap() ( }) } -func (n *network) GetHosts(ctx context.Context) ([]bootstrap.Host, error) { +func (n *network) GetBootstrap( + ctx context.Context, +) ( + bootstrap.Bootstrap, error, +) { return withCurrBootstrap(n, func( currBootstrap bootstrap.Bootstrap, ) ( - []bootstrap.Host, error, + bootstrap.Bootstrap, error, ) { - hosts := maps.Values(currBootstrap.Hosts) - slices.SortFunc(hosts, func(a, b bootstrap.Host) int { - return cmp.Compare(a.Name, b.Name) - }) - return hosts, nil + currBootstrap.PrivateCredentials = nebula.HostPrivateCredentials{} + return currBootstrap, nil }) } @@ -778,21 +769,6 @@ func (n *network) GetGarageClientParams( }) } -func (n *network) GetNebulaCAPublicCredentials( - ctx context.Context, -) ( - nebula.CAPublicCredentials, error, -) { - b, err := n.getBootstrap() - if err != nil { - return nebula.CAPublicCredentials{}, fmt.Errorf( - "retrieving bootstrap: %w", err, - ) - } - - return b.CAPublicCredentials, nil -} - func (n *network) RemoveHost(ctx context.Context, hostName nebula.HostName) error { // TODO RemoveHost should publish a certificate revocation for the host // being removed. diff --git a/go/daemon/network/network_it_test.go b/go/daemon/network/network_it_test.go index 8c453fb..5b0c595 100644 --- a/go/daemon/network/network_it_test.go +++ b/go/daemon/network/network_it_test.go @@ -85,6 +85,19 @@ func TestJoin(t *testing.T) { }) } +func TestNetwork_GetBootstrap(t *testing.T) { + var ( + h = newIntegrationHarness(t) + network = h.createNetwork(t, "primus", nil) + ) + + currBootstrap, err := network.GetBootstrap(h.ctx) + assert.NoError(t, err) + assert.Equal( + t, nebula.HostPrivateCredentials{}, currBootstrap.PrivateCredentials, + ) +} + func TestNetwork_SetConfig(t *testing.T) { t.Parallel() diff --git a/go/daemon/network/network_it_util_test.go b/go/daemon/network/network_it_util_test.go index a292110..6fd5796 100644 --- a/go/daemon/network/network_it_util_test.go +++ b/go/daemon/network/network_it_util_test.go @@ -364,13 +364,7 @@ func (nh *integrationHarnessNetwork) garageAdminClient( func (nh *integrationHarnessNetwork) getHostsByName( t *testing.T, ) map[nebula.HostName]bootstrap.Host { - hosts, err := nh.Network.GetHosts(nh.ctx) + currBootstrap, err := nh.Network.GetBootstrap(nh.ctx) require.NoError(t, err) - - hostsByName := map[nebula.HostName]bootstrap.Host{} - for _, h := range hosts { - hostsByName[h.Name] = h - } - - return hostsByName + return currBootstrap.Hosts } diff --git a/go/daemon/network/network_mock.go b/go/daemon/network/network_mock.go index 341dd4b..fc45164 100644 --- a/go/daemon/network/network_mock.go +++ b/go/daemon/network/network_mock.go @@ -74,6 +74,34 @@ func (_m *MockNetwork) CreateNebulaCertificate(_a0 context.Context, _a1 nebula.H return r0, r1 } +// GetBootstrap provides a mock function with given fields: _a0 +func (_m *MockNetwork) GetBootstrap(_a0 context.Context) (bootstrap.Bootstrap, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for GetBootstrap") + } + + var r0 bootstrap.Bootstrap + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (bootstrap.Bootstrap, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) bootstrap.Bootstrap); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bootstrap.Bootstrap) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetConfig provides a mock function with given fields: _a0 func (_m *MockNetwork) GetConfig(_a0 context.Context) (daecommon.NetworkConfig, error) { ret := _m.Called(_a0) @@ -130,64 +158,6 @@ func (_m *MockNetwork) GetGarageClientParams(_a0 context.Context) (GarageClientP return r0, r1 } -// GetHosts provides a mock function with given fields: _a0 -func (_m *MockNetwork) GetHosts(_a0 context.Context) ([]bootstrap.Host, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for GetHosts") - } - - var r0 []bootstrap.Host - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) ([]bootstrap.Host, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) []bootstrap.Host); ok { - r0 = rf(_a0) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]bootstrap.Host) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetNebulaCAPublicCredentials provides a mock function with given fields: _a0 -func (_m *MockNetwork) GetNebulaCAPublicCredentials(_a0 context.Context) (nebula.CAPublicCredentials, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for GetNebulaCAPublicCredentials") - } - - var r0 nebula.CAPublicCredentials - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (nebula.CAPublicCredentials, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) nebula.CAPublicCredentials); ok { - r0 = rf(_a0) - } else { - r0 = ret.Get(0).(nebula.CAPublicCredentials) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetNetworkCreationParams provides a mock function with given fields: _a0 func (_m *MockNetwork) GetNetworkCreationParams(_a0 context.Context) (bootstrap.CreationParams, error) { ret := _m.Called(_a0) diff --git a/go/daemon/rpc.go b/go/daemon/rpc.go index 63676f7..0cc436b 100644 --- a/go/daemon/rpc.go +++ b/go/daemon/rpc.go @@ -10,6 +10,8 @@ import ( "isle/daemon/network" "isle/nebula" "net/http" + + "dev.mediocregopher.com/mediocre-go-lib.git/mlog" ) // RPC defines the methods which the Daemon exposes over RPC (via the RPCHandler @@ -53,12 +55,13 @@ type RPC interface { network.RPC } -// RPCHandler returns a jsonrpc2.Handler which will use the Daemon to serve all -// methods defined on the RPC interface. -func (d *Daemon) RPCHandler() jsonrpc2.Handler { - rpc := RPC(d) +// NewRPCHandler returns a jsonrpc2.Handler which will use the given RPC to +// serve all methods defined on the interface. +func NewRPCHandler( + logger *mlog.Logger, rpc RPC, +) jsonrpc2.Handler { return jsonrpc2.Chain( - jsonrpc2.NewMLogMiddleware(d.logger.WithNamespace("rpc")), + jsonrpc2.NewMLogMiddleware(logger), jsonrpc2.ExposeServerSideErrorsMiddleware, )( jsonrpc2.NewDispatchHandler(&rpc), @@ -68,5 +71,6 @@ func (d *Daemon) RPCHandler() jsonrpc2.Handler { // HTTPRPCHandler returns an http.Handler which will use the Daemon to serve all // methods defined on the RPC interface via the JSONRPC2 protocol. func (d *Daemon) HTTPRPCHandler() http.Handler { - return jsonrpc2.NewHTTPHandler(d.RPCHandler()) + handler := NewRPCHandler(d.logger.WithNamespace("rpc"), d) + return jsonrpc2.NewHTTPHandler(handler) } diff --git a/go/daemon/rpc_mock.go b/go/daemon/rpc_mock.go index 496482c..e2e0209 100644 --- a/go/daemon/rpc_mock.go +++ b/go/daemon/rpc_mock.go @@ -94,6 +94,34 @@ func (_m *MockRPC) CreateNetwork(ctx context.Context, name string, domain string return r0 } +// GetBootstrap provides a mock function with given fields: _a0 +func (_m *MockRPC) GetBootstrap(_a0 context.Context) (bootstrap.Bootstrap, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for GetBootstrap") + } + + var r0 bootstrap.Bootstrap + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (bootstrap.Bootstrap, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) bootstrap.Bootstrap); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bootstrap.Bootstrap) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetConfig provides a mock function with given fields: _a0 func (_m *MockRPC) GetConfig(_a0 context.Context) (daecommon.NetworkConfig, error) { ret := _m.Called(_a0) @@ -150,64 +178,6 @@ func (_m *MockRPC) GetGarageClientParams(_a0 context.Context) (network.GarageCli return r0, r1 } -// GetHosts provides a mock function with given fields: _a0 -func (_m *MockRPC) GetHosts(_a0 context.Context) ([]bootstrap.Host, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for GetHosts") - } - - var r0 []bootstrap.Host - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) ([]bootstrap.Host, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) []bootstrap.Host); ok { - r0 = rf(_a0) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]bootstrap.Host) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetNebulaCAPublicCredentials provides a mock function with given fields: _a0 -func (_m *MockRPC) GetNebulaCAPublicCredentials(_a0 context.Context) (nebula.CAPublicCredentials, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for GetNebulaCAPublicCredentials") - } - - var r0 nebula.CAPublicCredentials - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (nebula.CAPublicCredentials, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) nebula.CAPublicCredentials); ok { - r0 = rf(_a0) - } else { - r0 = ret.Get(0).(nebula.CAPublicCredentials) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetNetworks provides a mock function with given fields: _a0 func (_m *MockRPC) GetNetworks(_a0 context.Context) ([]bootstrap.CreationParams, error) { ret := _m.Called(_a0)