package jsonrpc2

import (
	"context"
	"errors"
	"fmt"

	"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
	"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
)

// Handler is implemented by anything able to process arbitrary RPC calls.
// Whatever ServeRPC returns as its first value is JSON encoded, wrapped in the
// response object mandated by the spec, and sent back as the response.
//
// When the error return value is an Error, it is wrapped and sent back to the
// caller in place of a result. Any other non-nil error instead causes the
// client to receive an error response indicating that a server-side error
// occurred.
type Handler interface {
	ServeRPC(ctx context.Context, req Request) (any, error)
}

// HandlerFunc adapts an ordinary function so that it satisfies Handler.
type HandlerFunc func(context.Context, Request) (any, error)

// ServeRPC implements Handler by calling the underlying function directly.
func (f HandlerFunc) ServeRPC(ctx context.Context, req Request) (any, error) {
	return f(ctx, req)
}

// Middleware decorates a Handler, returning a new Handler which adds some
// functionality around it while remaining transparent to callers.
type Middleware func(h Handler) Handler

// Chain merges the given Middlewares into a single Middleware. When the result
// is applied to a Handler, mm[0] becomes the outermost wrapper and the last
// element of mm becomes the innermost one, directly around the Handler.
func Chain(mm ...Middleware) Middleware {
	return func(h Handler) Handler {
		// Wrap from the innermost (last) Middleware outwards, so the first
		// Middleware is applied last and therefore sits on the outside.
		for i := len(mm) - 1; i >= 0; i-- {
			h = mm[i](h)
		}
		return h
	}
}

// NewMLogMiddleware returns a Middleware which logs the outcome of every RPC
// request using the given Logger. Every entry is annotated with the request's
// ID, method, and meta parameters.
//
// The level of the outcome entry depends on what the inner Handler returned:
//   - Info when no error is returned.
//   - Warn when an Error is returned whose code is not the server-error code,
//     i.e. an application-level error intended for the client.
//   - Error for any other error, including an Error carrying the server-error
//     code (such as one produced by ExposeServerSideErrorsMiddleware).
//
// If the Logger's maximum level admits debug messages, a Debug entry carrying
// each raw request argument is also written before the request is handled.
func NewMLogMiddleware(logger *mlog.Logger) Middleware {
	return func(next Handler) Handler {
		return HandlerFunc(func(ctx context.Context, req Request) (any, error) {
			ctx = mctx.Annotate(
				ctx,
				"rpcRequestID", req.ID,
				"rpcRequestMethod", req.Method,
				"rpcRequestMeta", req.Params.Meta,
			)

			if logger.MaxLevel() >= mlog.LevelDebug.Int() {
				// The raw arguments belong on the debug entry only, so they are
				// accumulated on a separate context instead of on ctx itself.
				debugCtx := ctx
				for i, arg := range req.Params.Args {
					debugCtx = mctx.Annotate(
						debugCtx,
						fmt.Sprintf("rpcRequestArgs%d", i),
						string(arg),
					)
				}
				logger.Debug(debugCtx, "Handling RPC request")
			}

			res, err := next.ServeRPC(ctx, req)

			var rpcErr Error
			switch {
			case err == nil:
				logger.Info(ctx, "Handled RPC request")
			case errors.As(err, &rpcErr) && rpcErr.Code != errCodeServerError:
				logger.Warn(ctx, "Returning error to client", err)
			default:
				logger.Error(ctx, "Unexpected server-side error", err)
			}

			return res, err
		})
	}
}

// ExposeServerSideErrorsMiddleware converts any error from the wrapped Handler
// that is not an Error into an Error carrying the server-error code, and
// includes the original error's text in that Error's message. Without this
// middleware, the caller would receive only a generic indication that a
// server-side error occurred. Errors which already are (or wrap) an Error pass
// through unchanged.
var ExposeServerSideErrorsMiddleware Middleware = func(inner Handler) Handler {
	return HandlerFunc(func(ctx context.Context, req Request) (any, error) {
		res, err := inner.ServeRPC(ctx, req)
		if err == nil || errors.As(err, new(Error)) {
			return res, err
		}
		return res, NewError(
			errCodeServerError, "unexpected server-side error: %v", err,
		)
	})
}