2024-06-22 15:37:15 +00:00
|
|
|
package jsonrpc2
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"reflect"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Reflected interface types used when validating the signatures of methods
// passed to NewDispatchHandler: ctxT must be a method's first parameter type
// and errT its last result type.
var (
	ctxT = reflect.TypeOf(new(context.Context)).Elem()
	errT = reflect.TypeOf(new(error)).Elem()
)
|
|
|
|
|
|
|
|
// methodDispatchFunc handles one RPC call for a single registered method: it
// receives the call's context and Request and yields the response value and
// error to send back.
type methodDispatchFunc func(ctx context.Context, req Request) (any, error)
|
|
|
|
|
|
|
|
func newMethodDispatchFunc(
|
|
|
|
method reflect.Value,
|
|
|
|
) methodDispatchFunc {
|
2024-09-04 20:25:38 +00:00
|
|
|
paramTs := make([]reflect.Type, method.Type().NumIn()-1)
|
|
|
|
for i := range paramTs {
|
|
|
|
paramTs[i] = method.Type().In(i + 1)
|
|
|
|
}
|
2024-06-22 15:37:15 +00:00
|
|
|
|
2024-09-04 20:25:38 +00:00
|
|
|
return func(ctx context.Context, req Request) (any, error) {
|
|
|
|
callVals := make([]reflect.Value, 0, len(paramTs)+1)
|
|
|
|
callVals = append(callVals, reflect.ValueOf(ctx))
|
|
|
|
|
|
|
|
for i, paramT := range paramTs {
|
|
|
|
paramPtrV := reflect.New(paramT)
|
|
|
|
|
|
|
|
err := json.Unmarshal(req.Params[i], paramPtrV.Interface())
|
|
|
|
if err != nil {
|
|
|
|
// The JSON has already been validated, so this is not an
|
|
|
|
// errCodeParse situation. We assume it's an invalid param then,
|
|
|
|
// unless the error says otherwise via an UnmarshalJSON method
|
|
|
|
// returning an Error of its own.
|
|
|
|
if !errors.As(err, new(Error)) {
|
|
|
|
err = NewInvalidParamsError(
|
|
|
|
"JSON unmarshaling param %d into %T: %v", i, paramT, err,
|
|
|
|
)
|
|
|
|
}
|
|
|
|
return nil, err
|
2024-06-22 15:37:15 +00:00
|
|
|
}
|
2024-09-04 20:25:38 +00:00
|
|
|
|
|
|
|
callVals = append(callVals, paramPtrV.Elem())
|
2024-06-22 15:37:15 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
var (
|
2024-09-05 15:28:10 +00:00
|
|
|
callResV = method.Call(callVals)
|
|
|
|
resV, errV reflect.Value
|
2024-06-22 15:37:15 +00:00
|
|
|
)
|
|
|
|
|
2024-09-05 15:28:10 +00:00
|
|
|
if len(callResV) == 1 {
|
|
|
|
errV = callResV[0]
|
|
|
|
} else {
|
|
|
|
resV = callResV[0]
|
|
|
|
errV = callResV[1]
|
|
|
|
}
|
|
|
|
|
|
|
|
if !errV.IsNil() {
|
|
|
|
return nil, errV.Interface().(error)
|
|
|
|
}
|
|
|
|
|
|
|
|
if resV.IsValid() {
|
2024-06-22 15:37:15 +00:00
|
|
|
return resV.Interface(), nil
|
|
|
|
}
|
|
|
|
|
2024-09-05 15:28:10 +00:00
|
|
|
return nil, nil
|
2024-06-22 15:37:15 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
type dispatcher struct {
|
|
|
|
methods map[string]methodDispatchFunc
|
|
|
|
}
|
|
|
|
|
|
|
|
// NewDispatchHandler returns a Handler which will use methods on the passed in
|
|
|
|
// value to dispatch RPC calls. The passed in value must be a pointer. All
|
|
|
|
// exported methods which look like:
|
|
|
|
//
|
2024-09-05 17:36:21 +00:00
|
|
|
// MethodName(context.Context, ...ParamType) (ResponseType, error)
|
|
|
|
// MethodName(context.Context, ...ParamType) error
|
2024-06-22 15:37:15 +00:00
|
|
|
//
|
|
|
|
// will be available via RPC calls.
|
|
|
|
func NewDispatchHandler(i any) Handler {
|
|
|
|
v := reflect.ValueOf(i)
|
|
|
|
if v.Kind() != reflect.Pointer {
|
|
|
|
panic(fmt.Sprintf("expected pointer but got type %T", i))
|
|
|
|
}
|
|
|
|
|
|
|
|
v = v.Elem()
|
|
|
|
|
|
|
|
var (
|
|
|
|
t = v.Type()
|
|
|
|
numMethods = t.NumMethod()
|
|
|
|
methods = make(map[string]methodDispatchFunc, numMethods)
|
|
|
|
)
|
|
|
|
|
|
|
|
for i := range numMethods {
|
|
|
|
var (
|
|
|
|
method = t.Method(i)
|
|
|
|
methodV = v.Method(i)
|
|
|
|
methodT = methodV.Type()
|
|
|
|
)
|
|
|
|
|
|
|
|
if !method.IsExported() ||
|
2024-09-04 20:25:38 +00:00
|
|
|
methodT.NumIn() < 1 ||
|
2024-06-22 15:37:15 +00:00
|
|
|
methodT.In(0) != ctxT ||
|
2024-09-05 15:28:10 +00:00
|
|
|
(methodT.NumOut() == 1 && methodT.Out(0) != errT) ||
|
|
|
|
(methodT.NumOut() == 2 && methodT.Out(1) != errT) ||
|
|
|
|
methodT.NumOut() > 2 {
|
2024-06-22 15:37:15 +00:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
methods[method.Name] = newMethodDispatchFunc(methodV)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &dispatcher{methods}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (d *dispatcher) ServeRPC(ctx context.Context, req Request) (any, error) {
|
|
|
|
fn, ok := d.methods[req.Method]
|
|
|
|
if !ok {
|
|
|
|
return nil, NewError(
|
|
|
|
errCodeMethodNotFound, "unknown method %q", req.Method,
|
|
|
|
)
|
|
|
|
}
|
|
|
|
|
|
|
|
return fn(ctx, req)
|
|
|
|
}
|