diff --git a/go/cmd/entrypoint/host.go b/go/cmd/entrypoint/host.go index 7650ae4..70ab281 100644 --- a/go/cmd/entrypoint/host.go +++ b/go/cmd/entrypoint/host.go @@ -116,8 +116,7 @@ var subCmdHostRemove = subCmd{ return errors.New("--hostname is required") } - _, err := ctx.daemonRPC.RemoveHost(ctx, hostName.V) - if err != nil { + if err := ctx.daemonRPC.RemoveHost(ctx, hostName.V); err != nil { return fmt.Errorf("calling RemoveHost: %w", err) } diff --git a/go/cmd/entrypoint/network.go b/go/cmd/entrypoint/network.go index 7531305..80ad8a9 100644 --- a/go/cmd/entrypoint/network.go +++ b/go/cmd/entrypoint/network.go @@ -51,7 +51,7 @@ var subCmdNetworkCreate = subCmd{ return errors.New("--name, --domain, --ip-net, and --hostname are required") } - _, err := ctx.daemonRPC.CreateNetwork( + err := ctx.daemonRPC.CreateNetwork( ctx, *name, *domain, ipNet.V, hostName.V, ) if err != nil { @@ -88,8 +88,7 @@ var subCmdNetworkJoin = subCmd{ ) } - _, err := ctx.daemonRPC.JoinNetwork(ctx, newBootstrap) - return err + return ctx.daemonRPC.JoinNetwork(ctx, newBootstrap) }, } diff --git a/go/daemon/client.go b/go/daemon/client.go index ae06783..1fa8db0 100644 --- a/go/daemon/client.go +++ b/go/daemon/client.go @@ -44,11 +44,10 @@ func (c *rpcClient) CreateNebulaCertificate(ctx context.Context, hostName nebula return } -func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain string, ipNet nebula.IPNet, hostName nebula.HostName) (st1 struct { -}, err error) { +func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain string, ipNet nebula.IPNet, hostName nebula.HostName) (err error) { err = c.client.Call( ctx, - &st1, + nil, "CreateNetwork", name, domain, @@ -85,22 +84,20 @@ func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context) (c2 nebula return } -func (c *rpcClient) JoinNetwork(ctx context.Context, req JoiningBootstrap) (st1 struct { -}, err error) { +func (c *rpcClient) JoinNetwork(ctx context.Context, req JoiningBootstrap) (err error) { err = c.client.Call( ctx, - &st1, + nil, "JoinNetwork", req, ) return } -func (c *rpcClient) RemoveHost(ctx context.Context, hostName nebula.HostName) (st1 struct { -}, err error) { +func (c *rpcClient) RemoveHost(ctx context.Context, hostName nebula.HostName) (err error) { err = c.client.Call( ctx, - &st1, + nil, "RemoveHost", hostName, ) diff --git a/go/daemon/jsonrpc2/client_gen.tpl b/go/daemon/jsonrpc2/client_gen.tpl index a98f3ac..7018f60 100644 --- a/go/daemon/jsonrpc2/client_gen.tpl +++ b/go/daemon/jsonrpc2/client_gen.tpl @@ -17,11 +17,21 @@ func {{.Interface.Name}}FromClient(client jsonrpc2.Client) {{.Interface.Name}} { {{range $method := .Interface.Methods}} func (c *{{$t}}) {{$method.Declaration}} { {{- $ctx := (index $method.Params 0).Name}} - {{- $rcv := (index $method.Results 0).Name}} - {{- $err := (index $method.Results 1).Name}} + + {{- $rcv := ""}} + {{- $err := ""}} + + {{- if (eq (len $method.Results) 1)}} + {{- $rcv = "nil" }} + {{- $err = (index $method.Results 0).Name}} + {{- else}} + {{- $rcv = printf "&%s" (index $method.Results 0).Name}} + {{- $err = (index $method.Results 1).Name}} + {{- end}} + {{- $err}} = c.client.Call( {{$ctx}}, - &{{$rcv}}, + {{$rcv}}, "{{$method.Name}}", {{- range $param := (slice $method.Params 1)}} {{$param.Name}}, diff --git a/go/daemon/jsonrpc2/dispatcher.go b/go/daemon/jsonrpc2/dispatcher.go index b1b9ad6..4d25ece 100644 --- a/go/daemon/jsonrpc2/dispatcher.go +++ b/go/daemon/jsonrpc2/dispatcher.go @@ -48,16 +48,26 @@ func newMethodDispatchFunc( } var ( - callResV = method.Call(callVals) - resV = callResV[0] - errV = callResV[1] + callResV = method.Call(callVals) + resV, errV reflect.Value ) - if errV.IsNil() { + if len(callResV) == 1 { + errV = callResV[0] + } else { + resV = callResV[0] + errV = callResV[1] + } + + if !errV.IsNil() { + return nil, errV.Interface().(error) + } + + if resV.IsValid() { return resV.Interface(), nil } - return nil, errV.Interface().(error) + return nil, nil } } @@ -96,8 +106,9 @@ func NewDispatchHandler(i any) Handler { if !method.IsExported() || methodT.NumIn() < 1 || methodT.In(0) != ctxT || - methodT.NumOut() != 2 || - methodT.Out(1) != errT { + (methodT.NumOut() == 1 && methodT.Out(0) != errT) || + (methodT.NumOut() == 2 && methodT.Out(1) != errT) || + methodT.NumOut() > 2 { continue } diff --git a/go/daemon/jsonrpc2/jsonrpc2_test.go b/go/daemon/jsonrpc2/jsonrpc2_test.go index 4d9e977..fbcc245 100644 --- a/go/daemon/jsonrpc2/jsonrpc2_test.go +++ b/go/daemon/jsonrpc2/jsonrpc2_test.go @@ -36,12 +36,16 @@ func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) { return top / bottom, nil } -func (i dividerImpl) Noop(ctx context.Context) (int, error) { +func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) { + return i.Divide2(ctx, p.Top, p.Bottom) +} + +func (i dividerImpl) One(ctx context.Context) (int, error) { return 1, nil } -func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) { - return i.Divide2(ctx, p.Top, p.Bottom) +func (i dividerImpl) Noop(ctx context.Context) error { + return nil } func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { @@ -49,9 +53,10 @@ func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { } type divider interface { - Noop(ctx context.Context) (int, error) Divide2(ctx context.Context, top, bottom int) (int, error) Divide(ctx context.Context, p DivideParams) (int, error) + One(ctx context.Context) (int, error) + Noop(ctx context.Context) error } var testHandler = func() Handler { @@ -104,7 +109,7 @@ func testClient(t *testing.T, client Client) { t.Run("success/no_params", func(t *testing.T) { var res int - err := client.Call(ctx, &res, "Noop") + err := client.Call(ctx, &res, "One") if err != nil { t.Fatal(err) } else if res != 1 { @@ -112,6 +117,13 @@ func testClient(t *testing.T, client Client) { } }) + t.Run("success/no_results", func(t *testing.T) { + err := client.Call(ctx, nil, "Noop") + if err != nil { + t.Fatal(err) + } + }) + t.Run("err/application", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{}) if !errors.Is(err, ErrDivideByZero) { diff --git a/go/daemon/rpc.go b/go/daemon/rpc.go index 60ff262..965a43c 100644 --- a/go/daemon/rpc.go +++ b/go/daemon/rpc.go @@ -45,16 +45,10 @@ type RPC interface { domain string, ipNet nebula.IPNet, hostName nebula.HostName, - ) ( - struct{}, error, - ) + ) error // JoinNetwork passes through to the Daemon method of the same name. - JoinNetwork( - ctx context.Context, req JoiningBootstrap, - ) ( - struct{}, error, - ) + JoinNetwork(ctx context.Context, req JoiningBootstrap) error // GetHosts returns all hosts known to the network, sorted by their name. GetHosts(ctx context.Context) (GetHostsResult, error) @@ -72,11 +66,7 @@ type RPC interface { ) // RemoveHost passes the call through to the Daemon method of the same name. - RemoveHost( - ctx context.Context, hostName nebula.HostName, - ) ( - struct{}, error, - ) + RemoveHost(ctx context.Context, hostName nebula.HostName) error // CreateHost passes the call through to the Daemon method of the same name. CreateHost( @@ -111,20 +101,16 @@ func (r *rpcImpl) CreateNetwork( domain string, ipNet nebula.IPNet, hostName nebula.HostName, -) ( - struct{}, error, -) { - return struct{}{}, r.daemon.CreateNetwork( +) error { + return r.daemon.CreateNetwork( ctx, name, domain, ipNet, hostName, ) } func (r *rpcImpl) JoinNetwork( ctx context.Context, req JoiningBootstrap, -) ( - struct{}, error, -) { - return struct{}{}, r.daemon.JoinNetwork(ctx, req) +) error { + return r.daemon.JoinNetwork(ctx, req) } func (r *rpcImpl) GetHosts(ctx context.Context) (GetHostsResult, error) { @@ -166,10 +152,8 @@ func (r *rpcImpl) GetNebulaCAPublicCredentials( func (r *rpcImpl) RemoveHost( ctx context.Context, hostName nebula.HostName, -) ( - struct{}, error, -) { - return struct{}{}, r.daemon.RemoveHost(ctx, hostName) +) error { + return r.daemon.RemoveHost(ctx, hostName) } func (r *rpcImpl) CreateHost(