120 lines
2.7 KiB
Go
120 lines
2.7 KiB
Go
package jsonrpc2
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
)
|
|
|
|
// ctxT is the reflected type of the context.Context interface; a method is
// only dispatchable if its first parameter is exactly this type.
var ctxT = reflect.TypeOf(new(context.Context)).Elem()

// errT is the reflected type of the built-in error interface; a dispatchable
// method's second result must be exactly this type.
var errT = reflect.TypeOf(new(error)).Elem()
|
|
|
|
// methodDispatchFunc is the uniform calling convention for one dispatchable
// method: given the call's context and the incoming request, it yields the
// method's result or an error.
type methodDispatchFunc func(ctx context.Context, req Request) (result any, err error)
|
|
|
|
func newMethodDispatchFunc(
|
|
method reflect.Value,
|
|
) methodDispatchFunc {
|
|
paramTs := make([]reflect.Type, method.Type().NumIn()-1)
|
|
for i := range paramTs {
|
|
paramTs[i] = method.Type().In(i + 1)
|
|
}
|
|
|
|
return func(ctx context.Context, req Request) (any, error) {
|
|
callVals := make([]reflect.Value, 0, len(paramTs)+1)
|
|
callVals = append(callVals, reflect.ValueOf(ctx))
|
|
|
|
for i, paramT := range paramTs {
|
|
paramPtrV := reflect.New(paramT)
|
|
|
|
err := json.Unmarshal(req.Params[i], paramPtrV.Interface())
|
|
if err != nil {
|
|
// The JSON has already been validated, so this is not an
|
|
// errCodeParse situation. We assume it's an invalid param then,
|
|
// unless the error says otherwise via an UnmarshalJSON method
|
|
// returning an Error of its own.
|
|
if !errors.As(err, new(Error)) {
|
|
err = NewInvalidParamsError(
|
|
"JSON unmarshaling param %d into %T: %v", i, paramT, err,
|
|
)
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
callVals = append(callVals, paramPtrV.Elem())
|
|
}
|
|
|
|
var (
|
|
callResV = method.Call(callVals)
|
|
resV = callResV[0]
|
|
errV = callResV[1]
|
|
)
|
|
|
|
if errV.IsNil() {
|
|
return resV.Interface(), nil
|
|
}
|
|
|
|
return nil, errV.Interface().(error)
|
|
}
|
|
}
|
|
|
|
type dispatcher struct {
|
|
methods map[string]methodDispatchFunc
|
|
}
|
|
|
|
// NewDispatchHandler returns a Handler which will use methods on the passed in
|
|
// value to dispatch RPC calls. The passed in value must be a pointer. All
|
|
// exported methods which look like:
|
|
//
|
|
// MethodName(context.Context, ParamType) (ResponseType, error)
|
|
//
|
|
// will be available via RPC calls.
|
|
func NewDispatchHandler(i any) Handler {
|
|
v := reflect.ValueOf(i)
|
|
if v.Kind() != reflect.Pointer {
|
|
panic(fmt.Sprintf("expected pointer but got type %T", i))
|
|
}
|
|
|
|
v = v.Elem()
|
|
|
|
var (
|
|
t = v.Type()
|
|
numMethods = t.NumMethod()
|
|
methods = make(map[string]methodDispatchFunc, numMethods)
|
|
)
|
|
|
|
for i := range numMethods {
|
|
var (
|
|
method = t.Method(i)
|
|
methodV = v.Method(i)
|
|
methodT = methodV.Type()
|
|
)
|
|
|
|
if !method.IsExported() ||
|
|
methodT.NumIn() < 1 ||
|
|
methodT.In(0) != ctxT ||
|
|
methodT.NumOut() != 2 ||
|
|
methodT.Out(1) != errT {
|
|
continue
|
|
}
|
|
|
|
methods[method.Name] = newMethodDispatchFunc(methodV)
|
|
}
|
|
|
|
return &dispatcher{methods}
|
|
}
|
|
|
|
func (d *dispatcher) ServeRPC(ctx context.Context, req Request) (any, error) {
|
|
fn, ok := d.methods[req.Method]
|
|
if !ok {
|
|
return nil, NewError(
|
|
errCodeMethodNotFound, "unknown method %q", req.Method,
|
|
)
|
|
}
|
|
|
|
return fn(ctx, req)
|
|
}
|