diff --git a/go/cmd/entrypoint/client.go b/go/cmd/entrypoint/client.go index ea2f716..ca16bc3 100644 --- a/go/cmd/entrypoint/client.go +++ b/go/cmd/entrypoint/client.go @@ -6,7 +6,7 @@ import ( ) func (ctx subCmdCtx) getHosts() (daemon.GetHostsResult, error) { - res, err := ctx.daemonRPC.GetHosts(ctx.ctx, struct{}{}) + res, err := ctx.daemonRPC.GetHosts(ctx.ctx) if err != nil { return daemon.GetHostsResult{}, fmt.Errorf("calling GetHosts: %w", err) } diff --git a/go/cmd/entrypoint/garage.go b/go/cmd/entrypoint/garage.go index 1eca175..ac2a956 100644 --- a/go/cmd/entrypoint/garage.go +++ b/go/cmd/entrypoint/garage.go @@ -52,7 +52,7 @@ var subCmdGarageMC = subCmd{ } clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams( - subCmdCtx.ctx, struct{}{}, + subCmdCtx.ctx, ) if err != nil { return fmt.Errorf("calling GetGarageClientParams: %w", err) @@ -118,7 +118,7 @@ var subCmdGarageCLI = subCmd{ do: func(subCmdCtx subCmdCtx) error { clientParams, err := subCmdCtx.daemonRPC.GetGarageClientParams( - subCmdCtx.ctx, struct{}{}, + subCmdCtx.ctx, ) if err != nil { return fmt.Errorf("calling GetGarageClientParams: %w", err) diff --git a/go/cmd/entrypoint/nebula.go b/go/cmd/entrypoint/nebula.go index 5869afd..333d454 100644 --- a/go/cmd/entrypoint/nebula.go +++ b/go/cmd/entrypoint/nebula.go @@ -86,7 +86,7 @@ var subCmdNebulaShow = subCmd{ } caPublicCreds, err := subCmdCtx.daemonRPC.GetNebulaCAPublicCredentials( - subCmdCtx.ctx, struct{}{}, + subCmdCtx.ctx, ) if err != nil { return fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err) diff --git a/go/cmd/entrypoint/network.go b/go/cmd/entrypoint/network.go index 5672858..e1f9998 100644 --- a/go/cmd/entrypoint/network.go +++ b/go/cmd/entrypoint/network.go @@ -52,12 +52,7 @@ var subCmdNetworkCreate = subCmd{ } _, err := subCmdCtx.daemonRPC.CreateNetwork( - subCmdCtx.ctx, daemon.CreateNetworkRequest{ - Name: *name, - Domain: *domain, - IPNet: ipNet.V, - HostName: hostName.V, - }, + subCmdCtx.ctx, *name, *domain, ipNet.V, hostName.V, ) if err != nil { return fmt.Errorf("creating network: %w", err) diff --git a/go/daemon/client.go b/go/daemon/client.go index 4bcc0a4..d457740 100644 --- a/go/daemon/client.go +++ b/go/daemon/client.go @@ -23,47 +23,84 @@ func RPCFromClient(client jsonrpc2.Client) RPC { } func (c *rpcClient) CreateHost(ctx context.Context, req CreateHostRequest) (c2 CreateHostResult, err error) { - err = c.client.Call(ctx, &c2, "CreateHost", req) + err = c.client.Call( + ctx, + &c2, + "CreateHost", + req, + ) return } func (c *rpcClient) CreateNebulaCertificate(ctx context.Context, req CreateNebulaCertificateRequest) (c2 CreateNebulaCertificateResult, err error) { - err = c.client.Call(ctx, &c2, "CreateNebulaCertificate", req) + err = c.client.Call( + ctx, + &c2, + "CreateNebulaCertificate", + req, + ) return } -func (c *rpcClient) CreateNetwork(ctx context.Context, req CreateNetworkRequest) (st1 struct { +func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain string, ipNet nebula.IPNet, hostName nebula.HostName) (st1 struct { }, err error) { - err = c.client.Call(ctx, &st1, "CreateNetwork", req) + err = c.client.Call( + ctx, + &st1, + "CreateNetwork", + name, + domain, + ipNet, + hostName, + ) return } -func (c *rpcClient) GetGarageClientParams(ctx context.Context, req struct { -}) (g1 GarageClientParams, err error) { - err = c.client.Call(ctx, &g1, "GetGarageClientParams", req) +func (c *rpcClient) GetGarageClientParams(ctx context.Context) (g1 GarageClientParams, err error) { + err = c.client.Call( + ctx, + &g1, + "GetGarageClientParams", + ) return } -func (c *rpcClient) GetHosts(ctx context.Context, req struct { -}) (g1 GetHostsResult, err error) { - err = c.client.Call(ctx, &g1, "GetHosts", req) +func (c *rpcClient) GetHosts(ctx context.Context) (g1 GetHostsResult, err error) { + err = c.client.Call( + ctx, + &g1, + "GetHosts", + ) return } -func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context, req struct { -}) (c2 nebula.CAPublicCredentials, err error) { - err = c.client.Call(ctx, &c2, "GetNebulaCAPublicCredentials", req) +func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context) (c2 nebula.CAPublicCredentials, err error) { + err = c.client.Call( + ctx, + &c2, + "GetNebulaCAPublicCredentials", + ) return } func (c *rpcClient) JoinNetwork(ctx context.Context, req JoiningBootstrap) (st1 struct { }, err error) { - err = c.client.Call(ctx, &st1, "JoinNetwork", req) + err = c.client.Call( + ctx, + &st1, + "JoinNetwork", + req, + ) return } func (c *rpcClient) RemoveHost(ctx context.Context, req RemoveHostRequest) (st1 struct { }, err error) { - err = c.client.Call(ctx, &st1, "RemoveHost", req) + err = c.client.Call( + ctx, + &st1, + "RemoveHost", + req, + ) return } diff --git a/go/daemon/jsonrpc2/client.go b/go/daemon/jsonrpc2/client.go index d2b652e..fba9b42 100644 --- a/go/daemon/jsonrpc2/client.go +++ b/go/daemon/jsonrpc2/client.go @@ -15,5 +15,5 @@ type Client interface { // // If an error result is returned from the server that will be returned as // an Error struct. - Call(ctx context.Context, rcv any, method string, params any) error + Call(ctx context.Context, rcv any, method string, params ...any) error } diff --git a/go/daemon/jsonrpc2/client_gen.tpl b/go/daemon/jsonrpc2/client_gen.tpl index ee4cbfc..a98f3ac 100644 --- a/go/daemon/jsonrpc2/client_gen.tpl +++ b/go/daemon/jsonrpc2/client_gen.tpl @@ -17,10 +17,16 @@ func {{.Interface.Name}}FromClient(client jsonrpc2.Client) {{.Interface.Name}} { {{range $method := .Interface.Methods}} func (c *{{$t}}) {{$method.Declaration}} { {{- $ctx := (index $method.Params 0).Name}} - {{- $arg := (index $method.Params 1).Name}} {{- $rcv := (index $method.Results 0).Name}} {{- $err := (index $method.Results 1).Name}} - {{- $err}} = c.client.Call({{$ctx}}, &{{$rcv}}, "{{$method.Name}}", {{$arg}}) + {{- $err}} = c.client.Call( + {{$ctx}}, + &{{$rcv}}, + "{{$method.Name}}", + {{- range $param := (slice $method.Params 1)}} + {{$param.Name}}, + {{- end}} + ) return } {{end}} diff --git a/go/daemon/jsonrpc2/client_http.go b/go/daemon/jsonrpc2/client_http.go index 187267b..6b2845c 100644 --- a/go/daemon/jsonrpc2/client_http.go +++ b/go/daemon/jsonrpc2/client_http.go @@ -49,7 +49,7 @@ func NewUnixHTTPClient(unixSocketPath, reqPath string) Client { } func (c *httpClient) Call( - ctx context.Context, rcv any, method string, params any, + ctx context.Context, rcv any, method string, params ...any, ) error { var ( body = new(bytes.Buffer) diff --git a/go/daemon/jsonrpc2/client_rw.go b/go/daemon/jsonrpc2/client_rw.go index d7841ca..71f7e8c 100644 --- a/go/daemon/jsonrpc2/client_rw.go +++ b/go/daemon/jsonrpc2/client_rw.go @@ -19,7 +19,7 @@ func NewReadWriterClient(rw io.ReadWriter) Client { } func (c rwClient) Call( - ctx context.Context, rcv any, method string, params any, + ctx context.Context, rcv any, method string, params ...any, ) error { id, err := encodeRequest(c.enc, method, params) if err != nil { diff --git a/go/daemon/jsonrpc2/dispatcher.go b/go/daemon/jsonrpc2/dispatcher.go index 8eca3f7..b1b9ad6 100644 --- a/go/daemon/jsonrpc2/dispatcher.go +++ b/go/daemon/jsonrpc2/dispatcher.go @@ -18,29 +18,37 @@ type methodDispatchFunc func(context.Context, Request) (any, error) func newMethodDispatchFunc( method reflect.Value, ) methodDispatchFunc { - paramT := method.Type().In(1) - return func(ctx context.Context, req Request) (any, error) { - var ( - ctxV = reflect.ValueOf(ctx) - paramPtrV = reflect.New(paramT) - ) + paramTs := make([]reflect.Type, method.Type().NumIn()-1) + for i := range paramTs { + paramTs[i] = method.Type().In(i + 1) + } - err := json.Unmarshal(req.Params, paramPtrV.Interface()) - if err != nil { - // The JSON has already been validated, so this is not an - // errCodeParse situation. We assume it's an invalid param then, - // unless the error says otherwise via an UnmarshalJSON method - // returning an Error of its own. - if !errors.As(err, new(Error)) { - err = NewInvalidParamsError( - "JSON unmarshaling params into %T: %v", paramT, err, - ) + return func(ctx context.Context, req Request) (any, error) { + callVals := make([]reflect.Value, 0, len(paramTs)+1) + callVals = append(callVals, reflect.ValueOf(ctx)) + + for i, paramT := range paramTs { + paramPtrV := reflect.New(paramT) + + err := json.Unmarshal(req.Params[i], paramPtrV.Interface()) + if err != nil { + // The JSON has already been validated, so this is not an + // errCodeParse situation. We assume it's an invalid param then, + // unless the error says otherwise via an UnmarshalJSON method + // returning an Error of its own. + if !errors.As(err, new(Error)) { + err = NewInvalidParamsError( + "JSON unmarshaling param %d into %T: %v", i, paramT, err, + ) + } + return nil, err } - return nil, err + + callVals = append(callVals, paramPtrV.Elem()) } var ( - callResV = method.Call([]reflect.Value{ctxV, paramPtrV.Elem()}) + callResV = method.Call(callVals) resV = callResV[0] errV = callResV[1] ) @@ -86,7 +94,7 @@ func NewDispatchHandler(i any) Handler { ) if !method.IsExported() || - methodT.NumIn() != 2 || + methodT.NumIn() < 1 || methodT.In(0) != ctxT || methodT.NumOut() != 2 || methodT.Out(1) != errT { diff --git a/go/daemon/jsonrpc2/handler.go b/go/daemon/jsonrpc2/handler.go index 32f705d..ded4b6b 100644 --- a/go/daemon/jsonrpc2/handler.go +++ b/go/daemon/jsonrpc2/handler.go @@ -3,6 +3,7 @@ package jsonrpc2 import ( "context" "errors" + "fmt" "dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" @@ -59,9 +60,14 @@ func NewMLogMiddleware(logger *mlog.Logger) Middleware { ) if logger.MaxLevel() >= mlog.LevelDebug.Int() { - ctx := mctx.Annotate( - ctx, "rpcRequestParams", string(req.Params), - ) + ctx := ctx + for i := range req.Params { + ctx = mctx.Annotate( + ctx, + fmt.Sprintf("rpcRequestParam%d", i), + string(req.Params[i]), + ) + } logger.Debug(ctx, "Handling RPC request") } diff --git a/go/daemon/jsonrpc2/jsonrpc2.go b/go/daemon/jsonrpc2/jsonrpc2.go index 2fac7b1..83a47f5 100644 --- a/go/daemon/jsonrpc2/jsonrpc2.go +++ b/go/daemon/jsonrpc2/jsonrpc2.go @@ -13,10 +13,10 @@ const version = "2.0" // Request encodes an RPC request according to the spec. type Request struct { - Version string `json:"jsonrpc"` // must be "2.0" - Method string `json:"method"` - Params json.RawMessage `json:"params,omitempty"` - ID string `json:"id"` + Version string `json:"jsonrpc"` // must be "2.0" + Method string `json:"method"` + Params []json.RawMessage `json:"params,omitempty"` + ID string `json:"id"` } type response[Result any] struct { @@ -37,13 +37,19 @@ func newID() string { // encodeRequest writes a request to an io.Writer, returning the ID of the // request. func encodeRequest( - enc *json.Encoder, method string, params any, + enc *json.Encoder, method string, params []any, ) ( string, error, ) { - paramsB, err := json.Marshal(params) - if err != nil { - return "", fmt.Errorf("encoding params as JSON: %w", err) + var ( + paramsBs = make([]json.RawMessage, len(params)) + err error + ) + for i := range params { + paramsBs[i], err = json.Marshal(params[i]) + if err != nil { + return "", fmt.Errorf("encoding param %d as JSON: %w", i, err) + } } var ( @@ -51,7 +57,7 @@ func encodeRequest( reqEnvelope = Request{ Version: version, Method: method, - Params: paramsB, + Params: paramsBs, ID: id, } ) diff --git a/go/daemon/jsonrpc2/jsonrpc2_test.go b/go/daemon/jsonrpc2/jsonrpc2_test.go index 9735df0..4d9e977 100644 --- a/go/daemon/jsonrpc2/jsonrpc2_test.go +++ b/go/daemon/jsonrpc2/jsonrpc2_test.go @@ -26,14 +26,22 @@ var ErrDivideByZero = Error{ type dividerImpl struct{} -func (dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) { - if p.Bottom == 0 { +func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) { + if bottom == 0 { return 0, ErrDivideByZero } - if p.Top%p.Bottom != 0 { + if top%bottom != 0 { return 0, errors.New("numbers don't divide evenly, cannot compute!") } - return p.Top / p.Bottom, nil + return top / bottom, nil +} + +func (i dividerImpl) Noop(ctx context.Context) (int, error) { + return 1, nil +} + +func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) { + return i.Divide2(ctx, p.Top, p.Bottom) } func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { @@ -41,6 +49,8 @@ func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { } type divider interface { + Noop(ctx context.Context) (int, error) + Divide2(ctx context.Context, top, bottom int) (int, error) Divide(ctx context.Context, p DivideParams) (int, error) } @@ -82,6 +92,26 @@ func testClient(t *testing.T, client Client) { } }) + t.Run("success/multiple_params", func(t *testing.T) { + var res int + err := client.Call(ctx, &res, "Divide2", 6, 3) + if err != nil { + t.Fatal(err) + } else if res != 2 { + t.Fatalf("expected 2, got %d", res) + } + }) + + t.Run("success/no_params", func(t *testing.T) { + var res int + err := client.Call(ctx, &res, "Noop") + if err != nil { + t.Fatal(err) + } else if res != 1 { + t.Fatalf("expected 1, got %d", res) + } + }) + t.Run("err/application", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{}) if !errors.Is(err, ErrDivideByZero) { diff --git a/go/daemon/rpc.go b/go/daemon/rpc.go index d929ed3..c6f7ed6 100644 --- a/go/daemon/rpc.go +++ b/go/daemon/rpc.go @@ -11,25 +11,6 @@ import ( "golang.org/x/exp/maps" ) -// CreateNetworkRequest contains the arguments to the CreateNetwork RPC method. -// -// All fields are required. -type CreateNetworkRequest struct { - // Human-readable name of the network. - Name string - - // Primary domain name that network services are served under. - Domain string - - // An IP subnet, in CIDR form, which will be the overall range of possible - // IPs in the network. The first IP in this network range will become this - // first host's IP. - IPNet nebula.IPNet - - // The name of this first host in the network. - HostName nebula.HostName -} - // GetHostsResult wraps the results from the GetHosts RPC method. type GetHostsResult struct { Hosts []bootstrap.Host @@ -75,8 +56,20 @@ type CreateNebulaCertificateResult struct { // interface. type RPC interface { // CreateNetwork passes through to the Daemon method of the same name. + // + // name: Human-readable name of the network. + // domain: Primary domain name that network services are served under. + // ipNet: + // An IP subnet, in CIDR form, which will be the overall range of + // possible IPs in the network. The first IP in this network range will + // become this first host's IP. + // hostName: The name of this first host in the network. CreateNetwork( - ctx context.Context, req CreateNetworkRequest, + ctx context.Context, + name string, + domain string, + ipNet nebula.IPNet, + hostName nebula.HostName, ) ( struct{}, error, ) @@ -89,24 +82,16 @@ type RPC interface { ) // GetHosts returns all hosts known to the network, sorted by their name. - GetHosts( - ctx context.Context, req struct{}, - ) ( - GetHostsResult, error, - ) + GetHosts(ctx context.Context) (GetHostsResult, error) // GetGarageClientParams passes the call through to the Daemon method of the // same name. - GetGarageClientParams( - ctx context.Context, req struct{}, - ) ( - GarageClientParams, error, - ) + GetGarageClientParams(ctx context.Context) (GarageClientParams, error) // GetNebulaCAPublicCredentials returns the CAPublicCredentials for the // network. GetNebulaCAPublicCredentials( - ctx context.Context, req struct{}, + ctx context.Context, ) ( nebula.CAPublicCredentials, error, ) @@ -144,12 +129,16 @@ func NewRPC(daemon Daemon) RPC { } func (r *rpcImpl) CreateNetwork( - ctx context.Context, req CreateNetworkRequest, + ctx context.Context, + name string, + domain string, + ipNet nebula.IPNet, + hostName nebula.HostName, ) ( struct{}, error, ) { return struct{}{}, r.daemon.CreateNetwork( - ctx, req.Name, req.Domain, req.IPNet, req.HostName, + ctx, name, domain, ipNet, hostName, ) } @@ -161,11 +150,7 @@ func (r *rpcImpl) JoinNetwork( return struct{}{}, r.daemon.JoinNetwork(ctx, req) } -func (r *rpcImpl) GetHosts( - ctx context.Context, req struct{}, -) ( - GetHostsResult, error, -) { +func (r *rpcImpl) GetHosts(ctx context.Context) (GetHostsResult, error) { b, err := r.daemon.GetBootstrap(ctx) if err != nil { return GetHostsResult{}, fmt.Errorf("retrieving bootstrap: %w", err) @@ -180,7 +165,7 @@ func (r *rpcImpl) GetHosts( } func (r *rpcImpl) GetGarageClientParams( - ctx context.Context, req struct{}, + ctx context.Context, ) ( GarageClientParams, error, ) { @@ -188,7 +173,7 @@ func (r *rpcImpl) GetGarageClientParams( } func (r *rpcImpl) GetNebulaCAPublicCredentials( - ctx context.Context, req struct{}, + ctx context.Context, ) ( nebula.CAPublicCredentials, error, ) {