Allow variadic number of parameters on RPC calls
This commit is contained in:
parent
53ad8a91b4
commit
6c185f6263
@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
func (ctx subCmdCtx) getHosts() (daemon.GetHostsResult, error) {
|
||||
res, err := ctx.daemonRPC.GetHosts(ctx.ctx, struct{}{})
|
||||
res, err := ctx.daemonRPC.GetHosts(ctx.ctx)
|
||||
if err != nil {
|
||||
return daemon.GetHostsResult{}, fmt.Errorf("calling GetHosts: %w", err)
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ var subCmdGarageMC = subCmd{
|
||||
}
|
||||
|
||||
clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams(
|
||||
subCmdCtx.ctx, struct{}{},
|
||||
subCmdCtx.ctx,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("calling GetGarageClientParams: %w", err)
|
||||
@ -118,7 +118,7 @@ var subCmdGarageCLI = subCmd{
|
||||
do: func(subCmdCtx subCmdCtx) error {
|
||||
|
||||
clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams(
|
||||
subCmdCtx.ctx, struct{}{},
|
||||
subCmdCtx.ctx,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("calling GetGarageClientParams: %w", err)
|
||||
|
@ -86,7 +86,7 @@ var subCmdNebulaShow = subCmd{
|
||||
}
|
||||
|
||||
caPublicCreds, err := subCmdCtx.daemonRPC.GetNebulaCAPublicCredentials(
|
||||
subCmdCtx.ctx, struct{}{},
|
||||
subCmdCtx.ctx,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err)
|
||||
|
@ -52,12 +52,7 @@ var subCmdNetworkCreate = subCmd{
|
||||
}
|
||||
|
||||
_, err := subCmdCtx.daemonRPC.CreateNetwork(
|
||||
subCmdCtx.ctx, daemon.CreateNetworkRequest{
|
||||
Name: *name,
|
||||
Domain: *domain,
|
||||
IPNet: ipNet.V,
|
||||
HostName: hostName.V,
|
||||
},
|
||||
subCmdCtx.ctx, *name, *domain, ipNet.V, hostName.V,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating network: %w", err)
|
||||
|
@ -23,47 +23,84 @@ func RPCFromClient(client jsonrpc2.Client) RPC {
|
||||
}
|
||||
|
||||
func (c *rpcClient) CreateHost(ctx context.Context, req CreateHostRequest) (c2 CreateHostResult, err error) {
|
||||
err = c.client.Call(ctx, &c2, "CreateHost", req)
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&c2,
|
||||
"CreateHost",
|
||||
req,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *rpcClient) CreateNebulaCertificate(ctx context.Context, req CreateNebulaCertificateRequest) (c2 CreateNebulaCertificateResult, err error) {
|
||||
err = c.client.Call(ctx, &c2, "CreateNebulaCertificate", req)
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&c2,
|
||||
"CreateNebulaCertificate",
|
||||
req,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *rpcClient) CreateNetwork(ctx context.Context, req CreateNetworkRequest) (st1 struct {
|
||||
func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain string, ipNet nebula.IPNet, hostName nebula.HostName) (st1 struct {
|
||||
}, err error) {
|
||||
err = c.client.Call(ctx, &st1, "CreateNetwork", req)
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&st1,
|
||||
"CreateNetwork",
|
||||
name,
|
||||
domain,
|
||||
ipNet,
|
||||
hostName,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *rpcClient) GetGarageClientParams(ctx context.Context, req struct {
|
||||
}) (g1 GarageClientParams, err error) {
|
||||
err = c.client.Call(ctx, &g1, "GetGarageClientParams", req)
|
||||
func (c *rpcClient) GetGarageClientParams(ctx context.Context) (g1 GarageClientParams, err error) {
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&g1,
|
||||
"GetGarageClientParams",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *rpcClient) GetHosts(ctx context.Context, req struct {
|
||||
}) (g1 GetHostsResult, err error) {
|
||||
err = c.client.Call(ctx, &g1, "GetHosts", req)
|
||||
func (c *rpcClient) GetHosts(ctx context.Context) (g1 GetHostsResult, err error) {
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&g1,
|
||||
"GetHosts",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context, req struct {
|
||||
}) (c2 nebula.CAPublicCredentials, err error) {
|
||||
err = c.client.Call(ctx, &c2, "GetNebulaCAPublicCredentials", req)
|
||||
func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context) (c2 nebula.CAPublicCredentials, err error) {
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&c2,
|
||||
"GetNebulaCAPublicCredentials",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *rpcClient) JoinNetwork(ctx context.Context, req JoiningBootstrap) (st1 struct {
|
||||
}, err error) {
|
||||
err = c.client.Call(ctx, &st1, "JoinNetwork", req)
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&st1,
|
||||
"JoinNetwork",
|
||||
req,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *rpcClient) RemoveHost(ctx context.Context, req RemoveHostRequest) (st1 struct {
|
||||
}, err error) {
|
||||
err = c.client.Call(ctx, &st1, "RemoveHost", req)
|
||||
err = c.client.Call(
|
||||
ctx,
|
||||
&st1,
|
||||
"RemoveHost",
|
||||
req,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
@ -15,5 +15,5 @@ type Client interface {
|
||||
//
|
||||
// If an error result is returned from the server that will be returned as
|
||||
// an Error struct.
|
||||
Call(ctx context.Context, rcv any, method string, params any) error
|
||||
Call(ctx context.Context, rcv any, method string, params ...any) error
|
||||
}
|
||||
|
@ -17,10 +17,16 @@ func {{.Interface.Name}}FromClient(client jsonrpc2.Client) {{.Interface.Name}} {
|
||||
{{range $method := .Interface.Methods}}
|
||||
func (c *{{$t}}) {{$method.Declaration}} {
|
||||
{{- $ctx := (index $method.Params 0).Name}}
|
||||
{{- $arg := (index $method.Params 1).Name}}
|
||||
{{- $rcv := (index $method.Results 0).Name}}
|
||||
{{- $err := (index $method.Results 1).Name}}
|
||||
{{- $err}} = c.client.Call({{$ctx}}, &{{$rcv}}, "{{$method.Name}}", {{$arg}})
|
||||
{{- $err}} = c.client.Call(
|
||||
{{$ctx}},
|
||||
&{{$rcv}},
|
||||
"{{$method.Name}}",
|
||||
{{- range $param := (slice $method.Params 1)}}
|
||||
{{$param.Name}},
|
||||
{{- end}}
|
||||
)
|
||||
return
|
||||
}
|
||||
{{end}}
|
||||
|
@ -49,7 +49,7 @@ func NewUnixHTTPClient(unixSocketPath, reqPath string) Client {
|
||||
}
|
||||
|
||||
func (c *httpClient) Call(
|
||||
ctx context.Context, rcv any, method string, params any,
|
||||
ctx context.Context, rcv any, method string, params ...any,
|
||||
) error {
|
||||
var (
|
||||
body = new(bytes.Buffer)
|
||||
|
@ -19,7 +19,7 @@ func NewReadWriterClient(rw io.ReadWriter) Client {
|
||||
}
|
||||
|
||||
func (c rwClient) Call(
|
||||
ctx context.Context, rcv any, method string, params any,
|
||||
ctx context.Context, rcv any, method string, params ...any,
|
||||
) error {
|
||||
id, err := encodeRequest(c.enc, method, params)
|
||||
if err != nil {
|
||||
|
@ -18,29 +18,37 @@ type methodDispatchFunc func(context.Context, Request) (any, error)
|
||||
func newMethodDispatchFunc(
|
||||
method reflect.Value,
|
||||
) methodDispatchFunc {
|
||||
paramT := method.Type().In(1)
|
||||
return func(ctx context.Context, req Request) (any, error) {
|
||||
var (
|
||||
ctxV = reflect.ValueOf(ctx)
|
||||
paramPtrV = reflect.New(paramT)
|
||||
)
|
||||
paramTs := make([]reflect.Type, method.Type().NumIn()-1)
|
||||
for i := range paramTs {
|
||||
paramTs[i] = method.Type().In(i + 1)
|
||||
}
|
||||
|
||||
err := json.Unmarshal(req.Params, paramPtrV.Interface())
|
||||
if err != nil {
|
||||
// The JSON has already been validated, so this is not an
|
||||
// errCodeParse situation. We assume it's an invalid param then,
|
||||
// unless the error says otherwise via an UnmarshalJSON method
|
||||
// returning an Error of its own.
|
||||
if !errors.As(err, new(Error)) {
|
||||
err = NewInvalidParamsError(
|
||||
"JSON unmarshaling params into %T: %v", paramT, err,
|
||||
)
|
||||
return func(ctx context.Context, req Request) (any, error) {
|
||||
callVals := make([]reflect.Value, 0, len(paramTs)+1)
|
||||
callVals = append(callVals, reflect.ValueOf(ctx))
|
||||
|
||||
for i, paramT := range paramTs {
|
||||
paramPtrV := reflect.New(paramT)
|
||||
|
||||
err := json.Unmarshal(req.Params[i], paramPtrV.Interface())
|
||||
if err != nil {
|
||||
// The JSON has already been validated, so this is not an
|
||||
// errCodeParse situation. We assume it's an invalid param then,
|
||||
// unless the error says otherwise via an UnmarshalJSON method
|
||||
// returning an Error of its own.
|
||||
if !errors.As(err, new(Error)) {
|
||||
err = NewInvalidParamsError(
|
||||
"JSON unmarshaling param %d into %T: %v", i, paramT, err,
|
||||
)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
|
||||
callVals = append(callVals, paramPtrV.Elem())
|
||||
}
|
||||
|
||||
var (
|
||||
callResV = method.Call([]reflect.Value{ctxV, paramPtrV.Elem()})
|
||||
callResV = method.Call(callVals)
|
||||
resV = callResV[0]
|
||||
errV = callResV[1]
|
||||
)
|
||||
@ -86,7 +94,7 @@ func NewDispatchHandler(i any) Handler {
|
||||
)
|
||||
|
||||
if !method.IsExported() ||
|
||||
methodT.NumIn() != 2 ||
|
||||
methodT.NumIn() < 1 ||
|
||||
methodT.In(0) != ctxT ||
|
||||
methodT.NumOut() != 2 ||
|
||||
methodT.Out(1) != errT {
|
||||
|
@ -3,6 +3,7 @@ package jsonrpc2
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
|
||||
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
||||
@ -59,9 +60,14 @@ func NewMLogMiddleware(logger *mlog.Logger) Middleware {
|
||||
)
|
||||
|
||||
if logger.MaxLevel() >= mlog.LevelDebug.Int() {
|
||||
ctx := mctx.Annotate(
|
||||
ctx, "rpcRequestParams", string(req.Params),
|
||||
)
|
||||
ctx := ctx
|
||||
for i := range req.Params {
|
||||
ctx = mctx.Annotate(
|
||||
ctx,
|
||||
fmt.Sprintf("rpcRequestParam%d", i),
|
||||
string(req.Params[i]),
|
||||
)
|
||||
}
|
||||
logger.Debug(ctx, "Handling RPC request")
|
||||
}
|
||||
|
||||
|
@ -13,10 +13,10 @@ const version = "2.0"
|
||||
|
||||
// Request encodes an RPC request according to the spec.
|
||||
type Request struct {
|
||||
Version string `json:"jsonrpc"` // must be "2.0"
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Version string `json:"jsonrpc"` // must be "2.0"
|
||||
Method string `json:"method"`
|
||||
Params []json.RawMessage `json:"params,omitempty"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type response[Result any] struct {
|
||||
@ -37,13 +37,19 @@ func newID() string {
|
||||
// encodeRequest writes a request to an io.Writer, returning the ID of the
|
||||
// request.
|
||||
func encodeRequest(
|
||||
enc *json.Encoder, method string, params any,
|
||||
enc *json.Encoder, method string, params []any,
|
||||
) (
|
||||
string, error,
|
||||
) {
|
||||
paramsB, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding params as JSON: %w", err)
|
||||
var (
|
||||
paramsBs = make([]json.RawMessage, len(params))
|
||||
err error
|
||||
)
|
||||
for i := range params {
|
||||
paramsBs[i], err = json.Marshal(params[i])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding param %d as JSON: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
@ -51,7 +57,7 @@ func encodeRequest(
|
||||
reqEnvelope = Request{
|
||||
Version: version,
|
||||
Method: method,
|
||||
Params: paramsB,
|
||||
Params: paramsBs,
|
||||
ID: id,
|
||||
}
|
||||
)
|
||||
|
@ -26,14 +26,22 @@ var ErrDivideByZero = Error{
|
||||
|
||||
type dividerImpl struct{}
|
||||
|
||||
func (dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) {
|
||||
if p.Bottom == 0 {
|
||||
func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) {
|
||||
if bottom == 0 {
|
||||
return 0, ErrDivideByZero
|
||||
}
|
||||
if p.Top%p.Bottom != 0 {
|
||||
if top%bottom != 0 {
|
||||
return 0, errors.New("numbers don't divide evenly, cannot compute!")
|
||||
}
|
||||
return p.Top / p.Bottom, nil
|
||||
return top / bottom, nil
|
||||
}
|
||||
|
||||
func (i dividerImpl) Noop(ctx context.Context) (int, error) {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) {
|
||||
return i.Divide2(ctx, p.Top, p.Bottom)
|
||||
}
|
||||
|
||||
func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
|
||||
@ -41,6 +49,8 @@ func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
|
||||
}
|
||||
|
||||
type divider interface {
|
||||
Noop(ctx context.Context) (int, error)
|
||||
Divide2(ctx context.Context, top, bottom int) (int, error)
|
||||
Divide(ctx context.Context, p DivideParams) (int, error)
|
||||
}
|
||||
|
||||
@ -82,6 +92,26 @@ func testClient(t *testing.T, client Client) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success/multiple_params", func(t *testing.T) {
|
||||
var res int
|
||||
err := client.Call(ctx, &res, "Divide2", 6, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if res != 2 {
|
||||
t.Fatalf("expected 2, got %d", res)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success/no_params", func(t *testing.T) {
|
||||
var res int
|
||||
err := client.Call(ctx, &res, "Noop")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if res != 1 {
|
||||
t.Fatalf("expected 1, got %d", res)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("err/application", func(t *testing.T) {
|
||||
err := client.Call(ctx, nil, "Divide", DivideParams{})
|
||||
if !errors.Is(err, ErrDivideByZero) {
|
||||
|
@ -11,25 +11,6 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// CreateNetworkRequest contains the arguments to the CreateNetwork RPC method.
|
||||
//
|
||||
// All fields are required.
|
||||
type CreateNetworkRequest struct {
|
||||
// Human-readable name of the network.
|
||||
Name string
|
||||
|
||||
// Primary domain name that network services are served under.
|
||||
Domain string
|
||||
|
||||
// An IP subnet, in CIDR form, which will be the overall range of possible
|
||||
// IPs in the network. The first IP in this network range will become this
|
||||
// first host's IP.
|
||||
IPNet nebula.IPNet
|
||||
|
||||
// The name of this first host in the network.
|
||||
HostName nebula.HostName
|
||||
}
|
||||
|
||||
// GetHostsResult wraps the results from the GetHosts RPC method.
|
||||
type GetHostsResult struct {
|
||||
Hosts []bootstrap.Host
|
||||
@ -75,8 +56,20 @@ type CreateNebulaCertificateResult struct {
|
||||
// interface.
|
||||
type RPC interface {
|
||||
// CreateNetwork passes through to the Daemon method of the same name.
|
||||
//
|
||||
// name: Human-readable name of the network.
|
||||
// domain: Primary domain name that network services are served under.
|
||||
// ipNet:
|
||||
// An IP subnet, in CIDR form, which will be the overall range of
|
||||
// possible IPs in the network. The first IP in this network range will
|
||||
// become this first host's IP.
|
||||
// hostName: The name of this first host in the network.
|
||||
CreateNetwork(
|
||||
ctx context.Context, req CreateNetworkRequest,
|
||||
ctx context.Context,
|
||||
name string,
|
||||
domain string,
|
||||
ipNet nebula.IPNet,
|
||||
hostName nebula.HostName,
|
||||
) (
|
||||
struct{}, error,
|
||||
)
|
||||
@ -89,24 +82,16 @@ type RPC interface {
|
||||
)
|
||||
|
||||
// GetHosts returns all hosts known to the network, sorted by their name.
|
||||
GetHosts(
|
||||
ctx context.Context, req struct{},
|
||||
) (
|
||||
GetHostsResult, error,
|
||||
)
|
||||
GetHosts(ctx context.Context) (GetHostsResult, error)
|
||||
|
||||
// GetGarageClientParams passes the call through to the Daemon method of the
|
||||
// same name.
|
||||
GetGarageClientParams(
|
||||
ctx context.Context, req struct{},
|
||||
) (
|
||||
GarageClientParams, error,
|
||||
)
|
||||
GetGarageClientParams(ctx context.Context) (GarageClientParams, error)
|
||||
|
||||
// GetNebulaCAPublicCredentials returns the CAPublicCredentials for the
|
||||
// network.
|
||||
GetNebulaCAPublicCredentials(
|
||||
ctx context.Context, req struct{},
|
||||
ctx context.Context,
|
||||
) (
|
||||
nebula.CAPublicCredentials, error,
|
||||
)
|
||||
@ -144,12 +129,16 @@ func NewRPC(daemon Daemon) RPC {
|
||||
}
|
||||
|
||||
func (r *rpcImpl) CreateNetwork(
|
||||
ctx context.Context, req CreateNetworkRequest,
|
||||
ctx context.Context,
|
||||
name string,
|
||||
domain string,
|
||||
ipNet nebula.IPNet,
|
||||
hostName nebula.HostName,
|
||||
) (
|
||||
struct{}, error,
|
||||
) {
|
||||
return struct{}{}, r.daemon.CreateNetwork(
|
||||
ctx, req.Name, req.Domain, req.IPNet, req.HostName,
|
||||
ctx, name, domain, ipNet, hostName,
|
||||
)
|
||||
}
|
||||
|
||||
@ -161,11 +150,7 @@ func (r *rpcImpl) JoinNetwork(
|
||||
return struct{}{}, r.daemon.JoinNetwork(ctx, req)
|
||||
}
|
||||
|
||||
func (r *rpcImpl) GetHosts(
|
||||
ctx context.Context, req struct{},
|
||||
) (
|
||||
GetHostsResult, error,
|
||||
) {
|
||||
func (r *rpcImpl) GetHosts(ctx context.Context) (GetHostsResult, error) {
|
||||
b, err := r.daemon.GetBootstrap(ctx)
|
||||
if err != nil {
|
||||
return GetHostsResult{}, fmt.Errorf("retrieving bootstrap: %w", err)
|
||||
@ -180,7 +165,7 @@ func (r *rpcImpl) GetHosts(
|
||||
}
|
||||
|
||||
func (r *rpcImpl) GetGarageClientParams(
|
||||
ctx context.Context, req struct{},
|
||||
ctx context.Context,
|
||||
) (
|
||||
GarageClientParams, error,
|
||||
) {
|
||||
@ -188,7 +173,7 @@ func (r *rpcImpl) GetGarageClientParams(
|
||||
}
|
||||
|
||||
func (r *rpcImpl) GetNebulaCAPublicCredentials(
|
||||
ctx context.Context, req struct{},
|
||||
ctx context.Context,
|
||||
) (
|
||||
nebula.CAPublicCredentials, error,
|
||||
) {
|
||||
|
Loading…
Reference in New Issue
Block a user