package jsonrpc2 import ( "context" "errors" "fmt" "dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" ) // Handler is any type which is capable of handling arbitrary RPC calls. The // value returned from the ServeRPC method will be JSON encoded, wrapped in the // response object dictated by the spec, and returned as the response. // // If an Error is returned as the error field then that will be wrapped and // returned to the caller instead. If a non-Error error is returned then an // error response indicating that a server-side error occurred will be returned // to the client. type Handler interface { ServeRPC(context.Context, Request) (any, error) } // HandlerFunc implements Handler on a stand-alone function. type HandlerFunc func(context.Context, Request) (any, error) func (hf HandlerFunc) ServeRPC(ctx context.Context, req Request) (any, error) { return hf(ctx, req) } // Middleware is used to transparently wrap a Handler with some functionality. type Middleware func(Handler) Handler // Chain combines a sequence of Middlewares into a single one. The Middlewares // will apply to the wrapped handler such that the first given Middleware is the // outermost one. func Chain(mm ...Middleware) Middleware { return func(h Handler) Handler { for i := range mm { h = mm[len(mm)-1-i](h) } return h } } // NewMLogMiddleware returns a Middleware which will log an Info message on // every successful RPC request, a Warn message if an application-level error is // being returned (denoted by an Error being returned from the inner Handler), // and an Error message otherwise. // // If the Logger has debug logging enabled then the full request parameters will // also be logged. func NewMLogMiddleware(logger *mlog.Logger) Middleware { return func(h Handler) Handler { return HandlerFunc(func(ctx context.Context, req Request) (any, error) { ctx = mctx.Annotate( ctx, "rpcRequestID", req.ID, "rpcRequestMethod", req.Method, ) if logger.MaxLevel() >= mlog.LevelDebug.Int() { ctx := ctx for i := range req.Params { ctx = mctx.Annotate( ctx, fmt.Sprintf("rpcRequestParam%d", i), string(req.Params[i]), ) } logger.Debug(ctx, "Handling RPC request") } res, err := h.ServeRPC(ctx, req) if jErr := (Error{}); errors.As(err, &jErr) && jErr.Code != errCodeServerError { logger.Warn(ctx, "Returning error to client", err) } else if err != nil { logger.Error(ctx, "Unexpected server-side error", err) } else { logger.Info(ctx, "Handled RPC request") } return res, err }) } } // ExposeServerSideErrorsMiddleware causes non-Error error messages to be // exposed to the caller in the error message they receive. var ExposeServerSideErrorsMiddleware Middleware = func(h Handler) Handler { return HandlerFunc(func(ctx context.Context, req Request) (any, error) { res, err := h.ServeRPC(ctx, req) if err != nil && !errors.As(err, new(Error)) { err = NewError( errCodeServerError, "unexpected server-side error: %v", err, ) } return res, err }) }