package jsonrpc2 import ( "context" "encoding/json" "errors" "fmt" "reflect" ) var ( ctxT = reflect.TypeOf((*context.Context)(nil)).Elem() errT = reflect.TypeOf((*error)(nil)).Elem() ) type methodDispatchFunc func(context.Context, Request) (any, error) func newMethodDispatchFunc( method reflect.Value, ) methodDispatchFunc { paramTs := make([]reflect.Type, method.Type().NumIn()-1) for i := range paramTs { paramTs[i] = method.Type().In(i + 1) } return func(ctx context.Context, req Request) (any, error) { callVals := make([]reflect.Value, 0, len(paramTs)+1) callVals = append(callVals, reflect.ValueOf(ctx)) for i, paramT := range paramTs { paramPtrV := reflect.New(paramT) err := json.Unmarshal(req.Params[i], paramPtrV.Interface()) if err != nil { // The JSON has already been validated, so this is not an // errCodeParse situation. We assume it's an invalid param then, // unless the error says otherwise via an UnmarshalJSON method // returning an Error of its own. if !errors.As(err, new(Error)) { err = NewInvalidParamsError( "JSON unmarshaling param %d into %T: %v", i, paramT, err, ) } return nil, err } callVals = append(callVals, paramPtrV.Elem()) } var ( callResV = method.Call(callVals) resV, errV reflect.Value ) if len(callResV) == 1 { errV = callResV[0] } else { resV = callResV[0] errV = callResV[1] } if !errV.IsNil() { return nil, errV.Interface().(error) } if resV.IsValid() { return resV.Interface(), nil } return nil, nil } } type dispatcher struct { methods map[string]methodDispatchFunc } // NewDispatchHandler returns a Handler which will use methods on the passed in // value to dispatch RPC calls. The passed in value must be a pointer. All // exported methods which look like: // // MethodName(context.Context, ...ParamType) (ResponseType, error) // MethodName(context.Context, ...ParamType) error // // will be available via RPC calls. func NewDispatchHandler(i any) Handler { v := reflect.ValueOf(i) if v.Kind() != reflect.Pointer { panic(fmt.Sprintf("expected pointer but got type %T", i)) } v = v.Elem() var ( t = v.Type() numMethods = t.NumMethod() methods = make(map[string]methodDispatchFunc, numMethods) ) for i := range numMethods { var ( method = t.Method(i) methodV = v.Method(i) methodT = methodV.Type() ) if !method.IsExported() || methodT.NumIn() < 1 || methodT.In(0) != ctxT || (methodT.NumOut() == 1 && methodT.Out(0) != errT) || (methodT.NumOut() == 2 && methodT.Out(1) != errT) || methodT.NumOut() > 2 { continue } methods[method.Name] = newMethodDispatchFunc(methodV) } return &dispatcher{methods} } func (d *dispatcher) ServeRPC(ctx context.Context, req Request) (any, error) { fn, ok := d.methods[req.Method] if !ok { return nil, NewError( errCodeMethodNotFound, "unknown method %q", req.Method, ) } return fn(ctx, req) }