Allow variadic number of parameters on RPC calls

This commit is contained in:
Brian Picciano 2024-09-04 22:25:38 +02:00
parent 53ad8a91b4
commit 6c185f6263
14 changed files with 178 additions and 105 deletions

View File

@ -6,7 +6,7 @@ import (
)
func (ctx subCmdCtx) getHosts() (daemon.GetHostsResult, error) {
res, err := ctx.daemonRPC.GetHosts(ctx.ctx, struct{}{})
res, err := ctx.daemonRPC.GetHosts(ctx.ctx)
if err != nil {
return daemon.GetHostsResult{}, fmt.Errorf("calling GetHosts: %w", err)
}

View File

@ -52,7 +52,7 @@ var subCmdGarageMC = subCmd{
}
clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams(
subCmdCtx.ctx, struct{}{},
subCmdCtx.ctx,
)
if err != nil {
return fmt.Errorf("calling GetGarageClientParams: %w", err)
@ -118,7 +118,7 @@ var subCmdGarageCLI = subCmd{
do: func(subCmdCtx subCmdCtx) error {
clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams(
subCmdCtx.ctx, struct{}{},
subCmdCtx.ctx,
)
if err != nil {
return fmt.Errorf("calling GetGarageClientParams: %w", err)

View File

@ -86,7 +86,7 @@ var subCmdNebulaShow = subCmd{
}
caPublicCreds, err := subCmdCtx.daemonRPC.GetNebulaCAPublicCredentials(
subCmdCtx.ctx, struct{}{},
subCmdCtx.ctx,
)
if err != nil {
return fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err)

View File

@ -52,12 +52,7 @@ var subCmdNetworkCreate = subCmd{
}
_, err := subCmdCtx.daemonRPC.CreateNetwork(
subCmdCtx.ctx, daemon.CreateNetworkRequest{
Name: *name,
Domain: *domain,
IPNet: ipNet.V,
HostName: hostName.V,
},
subCmdCtx.ctx, *name, *domain, ipNet.V, hostName.V,
)
if err != nil {
return fmt.Errorf("creating network: %w", err)

View File

@ -23,47 +23,84 @@ func RPCFromClient(client jsonrpc2.Client) RPC {
}
func (c *rpcClient) CreateHost(ctx context.Context, req CreateHostRequest) (c2 CreateHostResult, err error) {
err = c.client.Call(ctx, &c2, "CreateHost", req)
err = c.client.Call(
ctx,
&c2,
"CreateHost",
req,
)
return
}
func (c *rpcClient) CreateNebulaCertificate(ctx context.Context, req CreateNebulaCertificateRequest) (c2 CreateNebulaCertificateResult, err error) {
err = c.client.Call(ctx, &c2, "CreateNebulaCertificate", req)
err = c.client.Call(
ctx,
&c2,
"CreateNebulaCertificate",
req,
)
return
}
func (c *rpcClient) CreateNetwork(ctx context.Context, req CreateNetworkRequest) (st1 struct {
func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain string, ipNet nebula.IPNet, hostName nebula.HostName) (st1 struct {
}, err error) {
err = c.client.Call(ctx, &st1, "CreateNetwork", req)
err = c.client.Call(
ctx,
&st1,
"CreateNetwork",
name,
domain,
ipNet,
hostName,
)
return
}
func (c *rpcClient) GetGarageClientParams(ctx context.Context, req struct {
}) (g1 GarageClientParams, err error) {
err = c.client.Call(ctx, &g1, "GetGarageClientParams", req)
func (c *rpcClient) GetGarageClientParams(ctx context.Context) (g1 GarageClientParams, err error) {
err = c.client.Call(
ctx,
&g1,
"GetGarageClientParams",
)
return
}
func (c *rpcClient) GetHosts(ctx context.Context, req struct {
}) (g1 GetHostsResult, err error) {
err = c.client.Call(ctx, &g1, "GetHosts", req)
func (c *rpcClient) GetHosts(ctx context.Context) (g1 GetHostsResult, err error) {
err = c.client.Call(
ctx,
&g1,
"GetHosts",
)
return
}
func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context, req struct {
}) (c2 nebula.CAPublicCredentials, err error) {
err = c.client.Call(ctx, &c2, "GetNebulaCAPublicCredentials", req)
func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context) (c2 nebula.CAPublicCredentials, err error) {
err = c.client.Call(
ctx,
&c2,
"GetNebulaCAPublicCredentials",
)
return
}
func (c *rpcClient) JoinNetwork(ctx context.Context, req JoiningBootstrap) (st1 struct {
}, err error) {
err = c.client.Call(ctx, &st1, "JoinNetwork", req)
err = c.client.Call(
ctx,
&st1,
"JoinNetwork",
req,
)
return
}
func (c *rpcClient) RemoveHost(ctx context.Context, req RemoveHostRequest) (st1 struct {
}, err error) {
err = c.client.Call(ctx, &st1, "RemoveHost", req)
err = c.client.Call(
ctx,
&st1,
"RemoveHost",
req,
)
return
}

View File

@ -15,5 +15,5 @@ type Client interface {
//
// If an error result is returned from the server that will be returned as
// an Error struct.
Call(ctx context.Context, rcv any, method string, params any) error
Call(ctx context.Context, rcv any, method string, params ...any) error
}

View File

@ -17,10 +17,16 @@ func {{.Interface.Name}}FromClient(client jsonrpc2.Client) {{.Interface.Name}} {
{{range $method := .Interface.Methods}}
func (c *{{$t}}) {{$method.Declaration}} {
{{- $ctx := (index $method.Params 0).Name}}
{{- $arg := (index $method.Params 1).Name}}
{{- $rcv := (index $method.Results 0).Name}}
{{- $err := (index $method.Results 1).Name}}
{{- $err}} = c.client.Call({{$ctx}}, &{{$rcv}}, "{{$method.Name}}", {{$arg}})
{{- $err}} = c.client.Call(
{{$ctx}},
&{{$rcv}},
"{{$method.Name}}",
{{- range $param := (slice $method.Params 1)}}
{{$param.Name}},
{{- end}}
)
return
}
{{end}}

View File

@ -49,7 +49,7 @@ func NewUnixHTTPClient(unixSocketPath, reqPath string) Client {
}
func (c *httpClient) Call(
ctx context.Context, rcv any, method string, params any,
ctx context.Context, rcv any, method string, params ...any,
) error {
var (
body = new(bytes.Buffer)

View File

@ -19,7 +19,7 @@ func NewReadWriterClient(rw io.ReadWriter) Client {
}
func (c rwClient) Call(
ctx context.Context, rcv any, method string, params any,
ctx context.Context, rcv any, method string, params ...any,
) error {
id, err := encodeRequest(c.enc, method, params)
if err != nil {

View File

@ -18,29 +18,37 @@ type methodDispatchFunc func(context.Context, Request) (any, error)
func newMethodDispatchFunc(
method reflect.Value,
) methodDispatchFunc {
paramT := method.Type().In(1)
return func(ctx context.Context, req Request) (any, error) {
var (
ctxV = reflect.ValueOf(ctx)
paramPtrV = reflect.New(paramT)
)
paramTs := make([]reflect.Type, method.Type().NumIn()-1)
for i := range paramTs {
paramTs[i] = method.Type().In(i + 1)
}
err := json.Unmarshal(req.Params, paramPtrV.Interface())
if err != nil {
// The JSON has already been validated, so this is not an
// errCodeParse situation. We assume it's an invalid param then,
// unless the error says otherwise via an UnmarshalJSON method
// returning an Error of its own.
if !errors.As(err, new(Error)) {
err = NewInvalidParamsError(
"JSON unmarshaling params into %T: %v", paramT, err,
)
return func(ctx context.Context, req Request) (any, error) {
callVals := make([]reflect.Value, 0, len(paramTs)+1)
callVals = append(callVals, reflect.ValueOf(ctx))
for i, paramT := range paramTs {
paramPtrV := reflect.New(paramT)
err := json.Unmarshal(req.Params[i], paramPtrV.Interface())
if err != nil {
// The JSON has already been validated, so this is not an
// errCodeParse situation. We assume it's an invalid param then,
// unless the error says otherwise via an UnmarshalJSON method
// returning an Error of its own.
if !errors.As(err, new(Error)) {
err = NewInvalidParamsError(
"JSON unmarshaling param %d into %T: %v", i, paramT, err,
)
}
return nil, err
}
return nil, err
callVals = append(callVals, paramPtrV.Elem())
}
var (
callResV = method.Call([]reflect.Value{ctxV, paramPtrV.Elem()})
callResV = method.Call(callVals)
resV = callResV[0]
errV = callResV[1]
)
@ -86,7 +94,7 @@ func NewDispatchHandler(i any) Handler {
)
if !method.IsExported() ||
methodT.NumIn() != 2 ||
methodT.NumIn() < 1 ||
methodT.In(0) != ctxT ||
methodT.NumOut() != 2 ||
methodT.Out(1) != errT {

View File

@ -3,6 +3,7 @@ package jsonrpc2
import (
"context"
"errors"
"fmt"
"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
@ -59,9 +60,14 @@ func NewMLogMiddleware(logger *mlog.Logger) Middleware {
)
if logger.MaxLevel() >= mlog.LevelDebug.Int() {
ctx := mctx.Annotate(
ctx, "rpcRequestParams", string(req.Params),
)
ctx := ctx
for i := range req.Params {
ctx = mctx.Annotate(
ctx,
fmt.Sprintf("rpcRequestParam%d", i),
string(req.Params[i]),
)
}
logger.Debug(ctx, "Handling RPC request")
}

View File

@ -13,10 +13,10 @@ const version = "2.0"
// Request encodes an RPC request according to the spec.
type Request struct {
Version string `json:"jsonrpc"` // must be "2.0"
Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"`
ID string `json:"id"`
Version string `json:"jsonrpc"` // must be "2.0"
Method string `json:"method"`
Params []json.RawMessage `json:"params,omitempty"`
ID string `json:"id"`
}
type response[Result any] struct {
@ -37,13 +37,19 @@ func newID() string {
// encodeRequest writes a request to an io.Writer, returning the ID of the
// request.
func encodeRequest(
enc *json.Encoder, method string, params any,
enc *json.Encoder, method string, params []any,
) (
string, error,
) {
paramsB, err := json.Marshal(params)
if err != nil {
return "", fmt.Errorf("encoding params as JSON: %w", err)
var (
paramsBs = make([]json.RawMessage, len(params))
err error
)
for i := range params {
paramsBs[i], err = json.Marshal(params[i])
if err != nil {
return "", fmt.Errorf("encoding param %d as JSON: %w", i, err)
}
}
var (
@ -51,7 +57,7 @@ func encodeRequest(
reqEnvelope = Request{
Version: version,
Method: method,
Params: paramsB,
Params: paramsBs,
ID: id,
}
)

View File

@ -26,14 +26,22 @@ var ErrDivideByZero = Error{
type dividerImpl struct{}
func (dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) {
if p.Bottom == 0 {
func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) {
if bottom == 0 {
return 0, ErrDivideByZero
}
if p.Top%p.Bottom != 0 {
if top%bottom != 0 {
return 0, errors.New("numbers don't divide evenly, cannot compute!")
}
return p.Top / p.Bottom, nil
return top / bottom, nil
}
func (i dividerImpl) Noop(ctx context.Context) (int, error) {
return 1, nil
}
func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) {
return i.Divide2(ctx, p.Top, p.Bottom)
}
func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
@ -41,6 +49,8 @@ func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
}
type divider interface {
Noop(ctx context.Context) (int, error)
Divide2(ctx context.Context, top, bottom int) (int, error)
Divide(ctx context.Context, p DivideParams) (int, error)
}
@ -82,6 +92,26 @@ func testClient(t *testing.T, client Client) {
}
})
t.Run("success/multiple_params", func(t *testing.T) {
var res int
err := client.Call(ctx, &res, "Divide2", 6, 3)
if err != nil {
t.Fatal(err)
} else if res != 2 {
t.Fatalf("expected 2, got %d", res)
}
})
t.Run("success/no_params", func(t *testing.T) {
var res int
err := client.Call(ctx, &res, "Noop")
if err != nil {
t.Fatal(err)
} else if res != 1 {
t.Fatalf("expected 1, got %d", res)
}
})
t.Run("err/application", func(t *testing.T) {
err := client.Call(ctx, nil, "Divide", DivideParams{})
if !errors.Is(err, ErrDivideByZero) {

View File

@ -11,25 +11,6 @@ import (
"golang.org/x/exp/maps"
)
// CreateNetworkRequest contains the arguments to the CreateNetwork RPC method.
//
// All fields are required.
type CreateNetworkRequest struct {
// Human-readable name of the network.
Name string
// Primary domain name that network services are served under.
Domain string
// An IP subnet, in CIDR form, which will be the overall range of possible
// IPs in the network. The first IP in this network range will become this
// first host's IP.
IPNet nebula.IPNet
// The name of this first host in the network.
HostName nebula.HostName
}
// GetHostsResult wraps the results from the GetHosts RPC method.
type GetHostsResult struct {
Hosts []bootstrap.Host
@ -75,8 +56,20 @@ type CreateNebulaCertificateResult struct {
// interface.
type RPC interface {
// CreateNetwork passes through to the Daemon method of the same name.
//
// name: Human-readable name of the network.
// domain: Primary domain name that network services are served under.
// ipNet:
// An IP subnet, in CIDR form, which will be the overall range of
// possible IPs in the network. The first IP in this network range will
// become this first host's IP.
// hostName: The name of this first host in the network.
CreateNetwork(
ctx context.Context, req CreateNetworkRequest,
ctx context.Context,
name string,
domain string,
ipNet nebula.IPNet,
hostName nebula.HostName,
) (
struct{}, error,
)
@ -89,24 +82,16 @@ type RPC interface {
)
// GetHosts returns all hosts known to the network, sorted by their name.
GetHosts(
ctx context.Context, req struct{},
) (
GetHostsResult, error,
)
GetHosts(ctx context.Context) (GetHostsResult, error)
// GetGarageClientParams passes the call through to the Daemon method of the
// same name.
GetGarageClientParams(
ctx context.Context, req struct{},
) (
GarageClientParams, error,
)
GetGarageClientParams(ctx context.Context) (GarageClientParams, error)
// GetNebulaCAPublicCredentials returns the CAPublicCredentials for the
// network.
GetNebulaCAPublicCredentials(
ctx context.Context, req struct{},
ctx context.Context,
) (
nebula.CAPublicCredentials, error,
)
@ -144,12 +129,16 @@ func NewRPC(daemon Daemon) RPC {
}
func (r *rpcImpl) CreateNetwork(
ctx context.Context, req CreateNetworkRequest,
ctx context.Context,
name string,
domain string,
ipNet nebula.IPNet,
hostName nebula.HostName,
) (
struct{}, error,
) {
return struct{}{}, r.daemon.CreateNetwork(
ctx, req.Name, req.Domain, req.IPNet, req.HostName,
ctx, name, domain, ipNet, hostName,
)
}
@ -161,11 +150,7 @@ func (r *rpcImpl) JoinNetwork(
return struct{}{}, r.daemon.JoinNetwork(ctx, req)
}
func (r *rpcImpl) GetHosts(
ctx context.Context, req struct{},
) (
GetHostsResult, error,
) {
func (r *rpcImpl) GetHosts(ctx context.Context) (GetHostsResult, error) {
b, err := r.daemon.GetBootstrap(ctx)
if err != nil {
return GetHostsResult{}, fmt.Errorf("retrieving bootstrap: %w", err)
@ -180,7 +165,7 @@ func (r *rpcImpl) GetHosts(
}
func (r *rpcImpl) GetGarageClientParams(
ctx context.Context, req struct{},
ctx context.Context,
) (
GarageClientParams, error,
) {
@ -188,7 +173,7 @@ func (r *rpcImpl) GetGarageClientParams(
}
func (r *rpcImpl) GetNebulaCAPublicCredentials(
ctx context.Context, req struct{},
ctx context.Context,
) (
nebula.CAPublicCredentials, error,
) {