Allow variadic number of parameters on RPC calls

This commit is contained in:
Brian Picciano 2024-09-04 22:25:38 +02:00
parent 53ad8a91b4
commit 6c185f6263
14 changed files with 178 additions and 105 deletions

View File

@ -6,7 +6,7 @@ import (
) )
func (ctx subCmdCtx) getHosts() (daemon.GetHostsResult, error) { func (ctx subCmdCtx) getHosts() (daemon.GetHostsResult, error) {
res, err := ctx.daemonRPC.GetHosts(ctx.ctx, struct{}{}) res, err := ctx.daemonRPC.GetHosts(ctx.ctx)
if err != nil { if err != nil {
return daemon.GetHostsResult{}, fmt.Errorf("calling GetHosts: %w", err) return daemon.GetHostsResult{}, fmt.Errorf("calling GetHosts: %w", err)
} }

View File

@ -52,7 +52,7 @@ var subCmdGarageMC = subCmd{
} }
clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams( clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams(
subCmdCtx.ctx, struct{}{}, subCmdCtx.ctx,
) )
if err != nil { if err != nil {
return fmt.Errorf("calling GetGarageClientParams: %w", err) return fmt.Errorf("calling GetGarageClientParams: %w", err)
@ -118,7 +118,7 @@ var subCmdGarageCLI = subCmd{
do: func(subCmdCtx subCmdCtx) error { do: func(subCmdCtx subCmdCtx) error {
clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams( clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams(
subCmdCtx.ctx, struct{}{}, subCmdCtx.ctx,
) )
if err != nil { if err != nil {
return fmt.Errorf("calling GetGarageClientParams: %w", err) return fmt.Errorf("calling GetGarageClientParams: %w", err)

View File

@ -86,7 +86,7 @@ var subCmdNebulaShow = subCmd{
} }
caPublicCreds, err := subCmdCtx.daemonRPC.GetNebulaCAPublicCredentials( caPublicCreds, err := subCmdCtx.daemonRPC.GetNebulaCAPublicCredentials(
subCmdCtx.ctx, struct{}{}, subCmdCtx.ctx,
) )
if err != nil { if err != nil {
return fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err) return fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err)

View File

@ -52,12 +52,7 @@ var subCmdNetworkCreate = subCmd{
} }
_, err := subCmdCtx.daemonRPC.CreateNetwork( _, err := subCmdCtx.daemonRPC.CreateNetwork(
subCmdCtx.ctx, daemon.CreateNetworkRequest{ subCmdCtx.ctx, *name, *domain, ipNet.V, hostName.V,
Name: *name,
Domain: *domain,
IPNet: ipNet.V,
HostName: hostName.V,
},
) )
if err != nil { if err != nil {
return fmt.Errorf("creating network: %w", err) return fmt.Errorf("creating network: %w", err)

View File

@ -23,47 +23,84 @@ func RPCFromClient(client jsonrpc2.Client) RPC {
} }
func (c *rpcClient) CreateHost(ctx context.Context, req CreateHostRequest) (c2 CreateHostResult, err error) { func (c *rpcClient) CreateHost(ctx context.Context, req CreateHostRequest) (c2 CreateHostResult, err error) {
err = c.client.Call(ctx, &c2, "CreateHost", req) err = c.client.Call(
ctx,
&c2,
"CreateHost",
req,
)
return return
} }
func (c *rpcClient) CreateNebulaCertificate(ctx context.Context, req CreateNebulaCertificateRequest) (c2 CreateNebulaCertificateResult, err error) { func (c *rpcClient) CreateNebulaCertificate(ctx context.Context, req CreateNebulaCertificateRequest) (c2 CreateNebulaCertificateResult, err error) {
err = c.client.Call(ctx, &c2, "CreateNebulaCertificate", req) err = c.client.Call(
ctx,
&c2,
"CreateNebulaCertificate",
req,
)
return return
} }
func (c *rpcClient) CreateNetwork(ctx context.Context, req CreateNetworkRequest) (st1 struct { func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain string, ipNet nebula.IPNet, hostName nebula.HostName) (st1 struct {
}, err error) { }, err error) {
err = c.client.Call(ctx, &st1, "CreateNetwork", req) err = c.client.Call(
ctx,
&st1,
"CreateNetwork",
name,
domain,
ipNet,
hostName,
)
return return
} }
func (c *rpcClient) GetGarageClientParams(ctx context.Context, req struct { func (c *rpcClient) GetGarageClientParams(ctx context.Context) (g1 GarageClientParams, err error) {
}) (g1 GarageClientParams, err error) { err = c.client.Call(
err = c.client.Call(ctx, &g1, "GetGarageClientParams", req) ctx,
&g1,
"GetGarageClientParams",
)
return return
} }
func (c *rpcClient) GetHosts(ctx context.Context, req struct { func (c *rpcClient) GetHosts(ctx context.Context) (g1 GetHostsResult, err error) {
}) (g1 GetHostsResult, err error) { err = c.client.Call(
err = c.client.Call(ctx, &g1, "GetHosts", req) ctx,
&g1,
"GetHosts",
)
return return
} }
func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context, req struct { func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context) (c2 nebula.CAPublicCredentials, err error) {
}) (c2 nebula.CAPublicCredentials, err error) { err = c.client.Call(
err = c.client.Call(ctx, &c2, "GetNebulaCAPublicCredentials", req) ctx,
&c2,
"GetNebulaCAPublicCredentials",
)
return return
} }
func (c *rpcClient) JoinNetwork(ctx context.Context, req JoiningBootstrap) (st1 struct { func (c *rpcClient) JoinNetwork(ctx context.Context, req JoiningBootstrap) (st1 struct {
}, err error) { }, err error) {
err = c.client.Call(ctx, &st1, "JoinNetwork", req) err = c.client.Call(
ctx,
&st1,
"JoinNetwork",
req,
)
return return
} }
func (c *rpcClient) RemoveHost(ctx context.Context, req RemoveHostRequest) (st1 struct { func (c *rpcClient) RemoveHost(ctx context.Context, req RemoveHostRequest) (st1 struct {
}, err error) { }, err error) {
err = c.client.Call(ctx, &st1, "RemoveHost", req) err = c.client.Call(
ctx,
&st1,
"RemoveHost",
req,
)
return return
} }

View File

@ -15,5 +15,5 @@ type Client interface {
// //
// If an error result is returned from the server that will be returned as // If an error result is returned from the server that will be returned as
// an Error struct. // an Error struct.
Call(ctx context.Context, rcv any, method string, params any) error Call(ctx context.Context, rcv any, method string, params ...any) error
} }

View File

@ -17,10 +17,16 @@ func {{.Interface.Name}}FromClient(client jsonrpc2.Client) {{.Interface.Name}} {
{{range $method := .Interface.Methods}} {{range $method := .Interface.Methods}}
func (c *{{$t}}) {{$method.Declaration}} { func (c *{{$t}}) {{$method.Declaration}} {
{{- $ctx := (index $method.Params 0).Name}} {{- $ctx := (index $method.Params 0).Name}}
{{- $arg := (index $method.Params 1).Name}}
{{- $rcv := (index $method.Results 0).Name}} {{- $rcv := (index $method.Results 0).Name}}
{{- $err := (index $method.Results 1).Name}} {{- $err := (index $method.Results 1).Name}}
{{- $err}} = c.client.Call({{$ctx}}, &{{$rcv}}, "{{$method.Name}}", {{$arg}}) {{- $err}} = c.client.Call(
{{$ctx}},
&{{$rcv}},
"{{$method.Name}}",
{{- range $param := (slice $method.Params 1)}}
{{$param.Name}},
{{- end}}
)
return return
} }
{{end}} {{end}}

View File

@ -49,7 +49,7 @@ func NewUnixHTTPClient(unixSocketPath, reqPath string) Client {
} }
func (c *httpClient) Call( func (c *httpClient) Call(
ctx context.Context, rcv any, method string, params any, ctx context.Context, rcv any, method string, params ...any,
) error { ) error {
var ( var (
body = new(bytes.Buffer) body = new(bytes.Buffer)

View File

@ -19,7 +19,7 @@ func NewReadWriterClient(rw io.ReadWriter) Client {
} }
func (c rwClient) Call( func (c rwClient) Call(
ctx context.Context, rcv any, method string, params any, ctx context.Context, rcv any, method string, params ...any,
) error { ) error {
id, err := encodeRequest(c.enc, method, params) id, err := encodeRequest(c.enc, method, params)
if err != nil { if err != nil {

View File

@ -18,14 +18,19 @@ type methodDispatchFunc func(context.Context, Request) (any, error)
func newMethodDispatchFunc( func newMethodDispatchFunc(
method reflect.Value, method reflect.Value,
) methodDispatchFunc { ) methodDispatchFunc {
paramT := method.Type().In(1) paramTs := make([]reflect.Type, method.Type().NumIn()-1)
return func(ctx context.Context, req Request) (any, error) { for i := range paramTs {
var ( paramTs[i] = method.Type().In(i + 1)
ctxV = reflect.ValueOf(ctx) }
paramPtrV = reflect.New(paramT)
)
err := json.Unmarshal(req.Params, paramPtrV.Interface()) return func(ctx context.Context, req Request) (any, error) {
callVals := make([]reflect.Value, 0, len(paramTs)+1)
callVals = append(callVals, reflect.ValueOf(ctx))
for i, paramT := range paramTs {
paramPtrV := reflect.New(paramT)
err := json.Unmarshal(req.Params[i], paramPtrV.Interface())
if err != nil { if err != nil {
// The JSON has already been validated, so this is not an // The JSON has already been validated, so this is not an
// errCodeParse situation. We assume it's an invalid param then, // errCodeParse situation. We assume it's an invalid param then,
@ -33,14 +38,17 @@ func newMethodDispatchFunc(
// returning an Error of its own. // returning an Error of its own.
if !errors.As(err, new(Error)) { if !errors.As(err, new(Error)) {
err = NewInvalidParamsError( err = NewInvalidParamsError(
"JSON unmarshaling params into %T: %v", paramT, err, "JSON unmarshaling param %d into %T: %v", i, paramT, err,
) )
} }
return nil, err return nil, err
} }
callVals = append(callVals, paramPtrV.Elem())
}
var ( var (
callResV = method.Call([]reflect.Value{ctxV, paramPtrV.Elem()}) callResV = method.Call(callVals)
resV = callResV[0] resV = callResV[0]
errV = callResV[1] errV = callResV[1]
) )
@ -86,7 +94,7 @@ func NewDispatchHandler(i any) Handler {
) )
if !method.IsExported() || if !method.IsExported() ||
methodT.NumIn() != 2 || methodT.NumIn() < 1 ||
methodT.In(0) != ctxT || methodT.In(0) != ctxT ||
methodT.NumOut() != 2 || methodT.NumOut() != 2 ||
methodT.Out(1) != errT { methodT.Out(1) != errT {

View File

@ -3,6 +3,7 @@ package jsonrpc2
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mctx"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog" "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
@ -59,9 +60,14 @@ func NewMLogMiddleware(logger *mlog.Logger) Middleware {
) )
if logger.MaxLevel() >= mlog.LevelDebug.Int() { if logger.MaxLevel() >= mlog.LevelDebug.Int() {
ctx := mctx.Annotate( ctx := ctx
ctx, "rpcRequestParams", string(req.Params), for i := range req.Params {
ctx = mctx.Annotate(
ctx,
fmt.Sprintf("rpcRequestParam%d", i),
string(req.Params[i]),
) )
}
logger.Debug(ctx, "Handling RPC request") logger.Debug(ctx, "Handling RPC request")
} }

View File

@ -15,7 +15,7 @@ const version = "2.0"
type Request struct { type Request struct {
Version string `json:"jsonrpc"` // must be "2.0" Version string `json:"jsonrpc"` // must be "2.0"
Method string `json:"method"` Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"` Params []json.RawMessage `json:"params,omitempty"`
ID string `json:"id"` ID string `json:"id"`
} }
@ -37,13 +37,19 @@ func newID() string {
// encodeRequest writes a request to an io.Writer, returning the ID of the // encodeRequest writes a request to an io.Writer, returning the ID of the
// request. // request.
func encodeRequest( func encodeRequest(
enc *json.Encoder, method string, params any, enc *json.Encoder, method string, params []any,
) ( ) (
string, error, string, error,
) { ) {
paramsB, err := json.Marshal(params) var (
paramsBs = make([]json.RawMessage, len(params))
err error
)
for i := range params {
paramsBs[i], err = json.Marshal(params[i])
if err != nil { if err != nil {
return "", fmt.Errorf("encoding params as JSON: %w", err) return "", fmt.Errorf("encoding param %d as JSON: %w", i, err)
}
} }
var ( var (
@ -51,7 +57,7 @@ func encodeRequest(
reqEnvelope = Request{ reqEnvelope = Request{
Version: version, Version: version,
Method: method, Method: method,
Params: paramsB, Params: paramsBs,
ID: id, ID: id,
} }
) )

View File

@ -26,14 +26,22 @@ var ErrDivideByZero = Error{
type dividerImpl struct{} type dividerImpl struct{}
func (dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) { func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) {
if p.Bottom == 0 { if bottom == 0 {
return 0, ErrDivideByZero return 0, ErrDivideByZero
} }
if p.Top%p.Bottom != 0 { if top%bottom != 0 {
return 0, errors.New("numbers don't divide evenly, cannot compute!") return 0, errors.New("numbers don't divide evenly, cannot compute!")
} }
return p.Top / p.Bottom, nil return top / bottom, nil
}
func (i dividerImpl) Noop(ctx context.Context) (int, error) {
return 1, nil
}
func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) {
return i.Divide2(ctx, p.Top, p.Bottom)
} }
func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
@ -41,6 +49,8 @@ func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
} }
type divider interface { type divider interface {
Noop(ctx context.Context) (int, error)
Divide2(ctx context.Context, top, bottom int) (int, error)
Divide(ctx context.Context, p DivideParams) (int, error) Divide(ctx context.Context, p DivideParams) (int, error)
} }
@ -82,6 +92,26 @@ func testClient(t *testing.T, client Client) {
} }
}) })
t.Run("success/multiple_params", func(t *testing.T) {
var res int
err := client.Call(ctx, &res, "Divide2", 6, 3)
if err != nil {
t.Fatal(err)
} else if res != 2 {
t.Fatalf("expected 2, got %d", res)
}
})
t.Run("success/no_params", func(t *testing.T) {
var res int
err := client.Call(ctx, &res, "Noop")
if err != nil {
t.Fatal(err)
} else if res != 1 {
t.Fatalf("expected 1, got %d", res)
}
})
t.Run("err/application", func(t *testing.T) { t.Run("err/application", func(t *testing.T) {
err := client.Call(ctx, nil, "Divide", DivideParams{}) err := client.Call(ctx, nil, "Divide", DivideParams{})
if !errors.Is(err, ErrDivideByZero) { if !errors.Is(err, ErrDivideByZero) {

View File

@ -11,25 +11,6 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
// CreateNetworkRequest contains the arguments to the CreateNetwork RPC method.
//
// All fields are required.
type CreateNetworkRequest struct {
// Human-readable name of the network.
Name string
// Primary domain name that network services are served under.
Domain string
// An IP subnet, in CIDR form, which will be the overall range of possible
// IPs in the network. The first IP in this network range will become this
// first host's IP.
IPNet nebula.IPNet
// The name of this first host in the network.
HostName nebula.HostName
}
// GetHostsResult wraps the results from the GetHosts RPC method. // GetHostsResult wraps the results from the GetHosts RPC method.
type GetHostsResult struct { type GetHostsResult struct {
Hosts []bootstrap.Host Hosts []bootstrap.Host
@ -75,8 +56,20 @@ type CreateNebulaCertificateResult struct {
// interface. // interface.
type RPC interface { type RPC interface {
// CreateNetwork passes through to the Daemon method of the same name. // CreateNetwork passes through to the Daemon method of the same name.
//
// name: Human-readable name of the network.
// domain: Primary domain name that network services are served under.
// ipNet:
// An IP subnet, in CIDR form, which will be the overall range of
// possible IPs in the network. The first IP in this network range will
// become this first host's IP.
// hostName: The name of this first host in the network.
CreateNetwork( CreateNetwork(
ctx context.Context, req CreateNetworkRequest, ctx context.Context,
name string,
domain string,
ipNet nebula.IPNet,
hostName nebula.HostName,
) ( ) (
struct{}, error, struct{}, error,
) )
@ -89,24 +82,16 @@ type RPC interface {
) )
// GetHosts returns all hosts known to the network, sorted by their name. // GetHosts returns all hosts known to the network, sorted by their name.
GetHosts( GetHosts(ctx context.Context) (GetHostsResult, error)
ctx context.Context, req struct{},
) (
GetHostsResult, error,
)
// GetGarageClientParams passes the call through to the Daemon method of the // GetGarageClientParams passes the call through to the Daemon method of the
// same name. // same name.
GetGarageClientParams( GetGarageClientParams(ctx context.Context) (GarageClientParams, error)
ctx context.Context, req struct{},
) (
GarageClientParams, error,
)
// GetNebulaCAPublicCredentials returns the CAPublicCredentials for the // GetNebulaCAPublicCredentials returns the CAPublicCredentials for the
// network. // network.
GetNebulaCAPublicCredentials( GetNebulaCAPublicCredentials(
ctx context.Context, req struct{}, ctx context.Context,
) ( ) (
nebula.CAPublicCredentials, error, nebula.CAPublicCredentials, error,
) )
@ -144,12 +129,16 @@ func NewRPC(daemon Daemon) RPC {
} }
func (r *rpcImpl) CreateNetwork( func (r *rpcImpl) CreateNetwork(
ctx context.Context, req CreateNetworkRequest, ctx context.Context,
name string,
domain string,
ipNet nebula.IPNet,
hostName nebula.HostName,
) ( ) (
struct{}, error, struct{}, error,
) { ) {
return struct{}{}, r.daemon.CreateNetwork( return struct{}{}, r.daemon.CreateNetwork(
ctx, req.Name, req.Domain, req.IPNet, req.HostName, ctx, name, domain, ipNet, hostName,
) )
} }
@ -161,11 +150,7 @@ func (r *rpcImpl) JoinNetwork(
return struct{}{}, r.daemon.JoinNetwork(ctx, req) return struct{}{}, r.daemon.JoinNetwork(ctx, req)
} }
func (r *rpcImpl) GetHosts( func (r *rpcImpl) GetHosts(ctx context.Context) (GetHostsResult, error) {
ctx context.Context, req struct{},
) (
GetHostsResult, error,
) {
b, err := r.daemon.GetBootstrap(ctx) b, err := r.daemon.GetBootstrap(ctx)
if err != nil { if err != nil {
return GetHostsResult{}, fmt.Errorf("retrieving bootstrap: %w", err) return GetHostsResult{}, fmt.Errorf("retrieving bootstrap: %w", err)
@ -180,7 +165,7 @@ func (r *rpcImpl) GetHosts(
} }
func (r *rpcImpl) GetGarageClientParams( func (r *rpcImpl) GetGarageClientParams(
ctx context.Context, req struct{}, ctx context.Context,
) ( ) (
GarageClientParams, error, GarageClientParams, error,
) { ) {
@ -188,7 +173,7 @@ func (r *rpcImpl) GetGarageClientParams(
} }
func (r *rpcImpl) GetNebulaCAPublicCredentials( func (r *rpcImpl) GetNebulaCAPublicCredentials(
ctx context.Context, req struct{}, ctx context.Context,
) ( ) (
nebula.CAPublicCredentials, error, nebula.CAPublicCredentials, error,
) { ) {