isle/go/daemon/jsonrpc2/dispatcher.go

132 lines
2.9 KiB
Go
Raw Normal View History

package jsonrpc2
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
)
var (
ctxT = reflect.TypeOf((*context.Context)(nil)).Elem()
errT = reflect.TypeOf((*error)(nil)).Elem()
)
// methodDispatchFunc decodes the params of a Request, calls a single Go
// method with them, and returns that method's result and error.
type methodDispatchFunc func(context.Context, Request) (any, error)

// newMethodDispatchFunc wraps a bound method value as a methodDispatchFunc.
//
// The method's first parameter receives the call's context. Each remaining
// parameter is JSON-decoded, by position, from the corresponding entry of
// Request.Params. Supplying fewer params than the method takes results in an
// invalid-params error; any extra entries are ignored.
//
// The method is expected to return either a single error or a
// (result, error) pair; NewDispatchHandler only registers methods of those
// shapes.
func newMethodDispatchFunc(
	method reflect.Value,
) methodDispatchFunc {
	methodT := method.Type()
	paramTs := make([]reflect.Type, methodT.NumIn()-1)
	for i := range paramTs {
		paramTs[i] = methodT.In(i + 1)
	}

	return func(ctx context.Context, req Request) (any, error) {
		// Without this check, indexing req.Params below would panic when the
		// caller supplies fewer params than the method takes.
		if len(req.Params) < len(paramTs) {
			return nil, NewInvalidParamsError(
				"expected %d params but got %d", len(paramTs), len(req.Params),
			)
		}

		callVals := make([]reflect.Value, 0, len(paramTs)+1)
		callVals = append(callVals, reflect.ValueOf(ctx))
		for i, paramT := range paramTs {
			paramPtrV := reflect.New(paramT)
			err := json.Unmarshal(req.Params[i], paramPtrV.Interface())
			if err != nil {
				// The JSON has already been validated, so this is not an
				// errCodeParse situation. We assume it's an invalid param then,
				// unless the error says otherwise via an UnmarshalJSON method
				// returning an Error of its own.
				if !errors.As(err, new(Error)) {
					// paramT is a reflect.Type, so %v (its String method)
					// names the param's Go type; %T would instead print the
					// reflect package's internal implementation type.
					err = NewInvalidParamsError(
						"JSON unmarshaling param %d into %v: %v", i, paramT, err,
					)
				}
				return nil, err
			}
			callVals = append(callVals, paramPtrV.Elem())
		}

		var (
			callResV   = method.Call(callVals)
			resV, errV reflect.Value
		)
		if len(callResV) == 1 {
			errV = callResV[0]
		} else {
			resV, errV = callResV[0], callResV[1]
		}

		if !errV.IsNil() {
			return nil, errV.Interface().(error)
		}

		// resV is only valid for (result, error) methods; error-only
		// methods produce a nil result.
		if resV.IsValid() {
			return resV.Interface(), nil
		}
		return nil, nil
	}
}
// dispatcher is the Handler implementation returned (as *dispatcher) by
// NewDispatchHandler.
type dispatcher struct {
	// methods maps an RPC method name (the Go method's name) to the function
	// which decodes that request's params and invokes the Go method.
	methods map[string]methodDispatchFunc
}
// NewDispatchHandler returns a Handler which will use methods on the passed in
// value to dispatch RPC calls. The passed in value must be a pointer. All
// exported methods which look like:
//
//	MethodName(context.Context, ...ParamType) (ResponseType, error)
//	MethodName(context.Context, ...ParamType) error
//
// will be available via RPC calls, keyed by the method's name. Methods of any
// other shape (no leading context.Context, no trailing error result, or zero
// or more than two results) are skipped.
//
// If the pointer type itself has methods (e.g. a pointer to a struct), its
// method set is used, which includes both value- and pointer-receiver
// methods. Otherwise (e.g. a pointer to an interface value) the methods of
// the pointed-to value are used.
//
// NewDispatchHandler panics if i is not a pointer.
func NewDispatchHandler(i any) Handler {
	v := reflect.ValueOf(i)
	if v.Kind() != reflect.Pointer {
		panic(fmt.Sprintf("expected pointer but got type %T", i))
	}

	// A pointer's method set is a superset of its element's, except when the
	// element is an interface or another pointer, in which case the pointer
	// type has no methods at all. Only descend when that's the case.
	if v.Type().NumMethod() == 0 {
		v = v.Elem()
	}

	var (
		t          = v.Type()
		numMethods = t.NumMethod()
		methods    = make(map[string]methodDispatchFunc, numMethods)
	)
	for i := range numMethods {
		var (
			method  = t.Method(i)
			methodV = v.Method(i)
			methodT = methodV.Type()
			numOut  = methodT.NumOut()
		)

		// The dispatch func always reads a trailing error result, so methods
		// with no results must be rejected here as well as those with more
		// than two.
		if !method.IsExported() ||
			methodT.NumIn() < 1 ||
			methodT.In(0) != ctxT ||
			numOut < 1 || numOut > 2 ||
			methodT.Out(numOut-1) != errT {
			continue
		}

		methods[method.Name] = newMethodDispatchFunc(methodV)
	}

	return &dispatcher{methods}
}
// ServeRPC invokes the method registered under req.Method, returning its
// result and error. An unregistered method name yields an
// errCodeMethodNotFound error.
func (d *dispatcher) ServeRPC(ctx context.Context, req Request) (any, error) {
	if dispatch, found := d.methods[req.Method]; found {
		return dispatch(ctx, req)
	}
	return nil, NewError(errCodeMethodNotFound, "unknown method %q", req.Method)
}