package jsonrpc2

import (
	"context"
	"errors"
	"io"
	"isle/toolkit"
	"net"
	"net/http/httptest"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// DivideParams is the single-argument parameter form accepted by
// dividerImpl.Divide: Top is the dividend and Bottom is the divisor.
type DivideParams struct {
	Top    int
	Bottom int
}

// ErrDivideByZero is the application-level RPC error which
// dividerImpl.Divide2 returns when asked to divide by zero.
var ErrDivideByZero = Error{Code: 1, Message: "Cannot divide by zero"}

// dividerImpl is the stateless implementation of divider which testHandler
// exposes over RPC. It additionally carries a Hidden method which is not part
// of the divider interface.
type dividerImpl struct{}

// Divide2 returns top divided by bottom. It fails with ErrDivideByZero when
// bottom is zero, and with a plain (non-RPC) error when the division would
// leave a remainder.
func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) {
	switch {
	case bottom == 0:
		return 0, ErrDivideByZero
	case top%bottom != 0:
		// The message text is matched by the server-side error case in
		// testClient, so it must stay exactly as is.
		return 0, errors.New("numbers don't divide evenly, cannot compute!")
	default:
		return top / bottom, nil
	}
}

// Divide is the struct-parameter variant of Divide2: it unpacks p and
// delegates, so it returns the same results and errors as Divide2.
func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) {
	top, bottom := p.Top, p.Bottom
	return i.Divide2(ctx, top, bottom)
}

// One takes no parameters and always succeeds with the result 1.
func (dividerImpl) One(ctx context.Context) (int, error) {
	const result = 1
	return result, nil
}

// Noop takes no parameters, produces no result, and always succeeds.
func (dividerImpl) Noop(ctx context.Context) error {
	return nil
}

// Divide2FromMeta behaves like Divide2, except that the operands are read
// from the "top" and "bottom" entries of the meta carried by ctx rather than
// from call parameters.
//
// Both entries must be float64 values (presumably because the meta has been
// decoded from JSON by the time it reaches the handler; confirm against
// GetMeta). If either entry is missing or has another type an error is
// returned, rather than panicking inside the handler.
func (i dividerImpl) Divide2FromMeta(ctx context.Context) (int, error) {
	meta := GetMeta(ctx)

	top, ok := meta["top"].(float64)
	if !ok {
		return 0, errors.New(`meta field "top" is missing or not a number`)
	}

	bottom, ok := meta["bottom"].(float64)
	if !ok {
		return 0, errors.New(`meta field "bottom" is missing or not a number`)
	}

	return i.Divide2(ctx, int(top), int(bottom))
}

// Hidden is a method of dividerImpl which is not declared on the divider
// interface. It always fails; testClient expects calls to it to be rejected
// with errCodeMethodNotFound.
func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
	err := errors.New("Shouldn't be possible to call this!")
	return 0, err
}

// divider is the method set which testHandler passes (as a pointer to an
// interface value) to NewDispatchHandler. dividerImpl's Hidden method is left
// out of it; the "err/calling hidden method" case in testClient expects calls
// to Hidden to fail with errCodeMethodNotFound.
type divider interface {
	// Divide2 takes multiple positional parameters.
	Divide2(ctx context.Context, top, bottom int) (int, error)
	// Divide takes a single struct parameter.
	Divide(ctx context.Context, p DivideParams) (int, error)
	// One takes no parameters and returns a result.
	One(ctx context.Context) (int, error)
	// Noop takes no parameters and returns no result.
	Noop(ctx context.Context) error
	// Divide2FromMeta reads its operands from the meta attached to ctx.
	Divide2FromMeta(ctx context.Context) (int, error)
}

// testHandler builds the Handler served by every transport test in this file:
// a dispatch handler over a divider backed by dividerImpl, wrapped first in
// the logging middleware and then in ExposeServerSideErrorsMiddleware.
func testHandler(t *testing.T) Handler {
	var target divider = dividerImpl{}

	wrap := Chain(
		NewMLogMiddleware(toolkit.NewTestLogger(t)),
		ExposeServerSideErrorsMiddleware,
	)

	// The dispatch handler is given a pointer to the interface value, not to
	// the concrete dividerImpl.
	return wrap(NewDispatchHandler(&target))
}

// testClient runs the shared suite of RPC calls against client. The tests in
// this file call it with a client connected to a server which serves the
// Handler returned by testHandler.
//
// All calls share one context with a 5 second deadline. Each case runs as its
// own subtest, covering successful calls with the different parameter/result
// shapes, passing of meta, and the three kinds of error (application error,
// unknown method, and server-side error).
func testClient(t *testing.T, client Client) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	t.Run("success/no_result", func(t *testing.T) {
		// A nil result pointer discards the returned value.
		err := client.Call(ctx, nil, "Divide", DivideParams{12, 4})
		if err != nil {
			t.Fatal(err)
		}
	})

	t.Run("success/with_result", func(t *testing.T) {
		var res int
		err := client.Call(ctx, &res, "Divide", DivideParams{6, 3})
		if err != nil {
			t.Fatal(err)
		} else if res != 2 {
			t.Fatalf("expected 2, got %d", res)
		}
	})

	t.Run("success/multiple_params", func(t *testing.T) {
		var res int
		err := client.Call(ctx, &res, "Divide2", 6, 3)
		if err != nil {
			t.Fatal(err)
		} else if res != 2 {
			t.Fatalf("expected 2, got %d", res)
		}
	})

	t.Run("success/no_params", func(t *testing.T) {
		var res int
		err := client.Call(ctx, &res, "One")
		if err != nil {
			t.Fatal(err)
		} else if res != 1 {
			t.Fatalf("expected 1, got %d", res)
		}
	})

	t.Run("success/no_results", func(t *testing.T) {
		err := client.Call(ctx, nil, "Noop")
		if err != nil {
			t.Fatal(err)
		}
	})

	t.Run("success/meta", func(t *testing.T) {
		// Attach the meta to a subtest-local context so that it is not also
		// sent along with the calls made by the subtests which follow.
		metaCtx := WithMeta(ctx, "top", 6)
		metaCtx = WithMeta(metaCtx, "bottom", 2)

		var res int
		err := client.Call(metaCtx, &res, "Divide2FromMeta")
		if err != nil {
			t.Fatal(err)
		} else if res != 3 {
			t.Fatalf("expected 3, got %d", res)
		}
	})

	t.Run("err/application", func(t *testing.T) {
		// The zero DivideParams has Bottom == 0, triggering ErrDivideByZero.
		err := client.Call(ctx, nil, "Divide", DivideParams{})
		if !errors.Is(err, ErrDivideByZero) {
			t.Fatalf("expected application error but got: %#v", err)
		}
	})

	t.Run("err/calling hidden method", func(t *testing.T) {
		// Hidden exists on dividerImpl but not on the divider interface.
		err := client.Call(ctx, nil, "Hidden", struct{}{})
		if jErr := (Error{}); !errors.As(err, &jErr) {
			t.Fatalf("expected RPC error but got: %#v", err)
		} else if jErr.Code != errCodeMethodNotFound {
			t.Fatalf("expected method not found error but got: %#v", jErr)
		}
	})

	t.Run("err/sever-side error", func(t *testing.T) {
		// 6 is not evenly divisible by 4, so Divide2 returns a plain error
		// rather than an Error value.
		err := client.Call(ctx, nil, "Divide", DivideParams{6, 4})
		t.Log(err)
		if jErr := (Error{}); !errors.As(err, &jErr) {
			t.Fatalf("expected RPC error but got: %#v", err)
		} else if jErr.Code != errCodeServerError {
			t.Fatalf("expected server error but got: %#v", jErr)
		} else if !strings.Contains(jErr.Message, "cannot compute!") {
			t.Fatalf("expected server-side error message to be propagated but got: %#v", jErr)
		}
	})
}

// TestReadWriter runs the testClient suite over the io.Reader/io.Writer
// transport, with the client and server joined by a pair of in-memory pipes.
func TestReadWriter(t *testing.T) {
	type (
		// rw glues an independent Reader and Writer into one value.
		rw struct {
			io.Reader
			io.Writer
		}
	)

	var (
		ctx = context.Background()

		// One pipe per direction: handler -> client and client -> handler.
		clientReader, handlerWriter = io.Pipe()
		handlerReader, clientWriter = io.Pipe()

		clientRW  = rw{clientReader, clientWriter}
		handlerRW = rw{handlerReader, handlerWriter}

		server = NewReadWriterServer(testHandler(t), handlerRW)
		client = NewReadWriterClient(clientRW)

		wg = new(sync.WaitGroup)
	)

	// Deferred calls run in reverse order: both writers are closed first,
	// which ends the pipes, and only then is the server goroutine waited on.
	defer wg.Wait()
	defer clientWriter.Close()
	defer handlerWriter.Close()

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			err := server.HandleNext(ctx)
			if errors.Is(err, io.EOF) {
				// Treated as a clean shutdown of the request stream.
				return
			}
			if err != nil {
				// Surface the reason the server loop died; otherwise the
				// only symptom would be client calls timing out. Logging
				// here is safe because wg.Wait keeps the test function from
				// returning before this goroutine has finished.
				t.Logf("server loop stopped on unexpected error: %v", err)
				return
			}
		}
	}()

	testClient(t, client)
}

// TestHTTP runs the testClient suite over HTTP, against an httptest server
// serving the test handler.
func TestHTTP(t *testing.T) {
	var (
		logger = toolkit.NewTestLogger(t)
		server = httptest.NewServer(NewHTTPHandler(testHandler(t)))
	)
	t.Cleanup(server.Close)

	httpClient := toolkit.NewHTTPClient(logger)
	// Cleanups run last-registered-first, so the HTTP client is closed before
	// the server is shut down.
	t.Cleanup(func() {
		assert.NoError(t, httpClient.Close())
	})

	rpcClient := NewHTTPClient(httpClient, server.URL)
	testClient(t, rpcClient)
}

// TestUnixHTTP runs the testClient suite over HTTP carried on a unix domain
// socket, which is created inside the test's temporary directory.
func TestUnixHTTP(t *testing.T) {
	var (
		logger         = toolkit.NewTestLogger(t)
		unixSocketPath = filepath.Join(t.TempDir(), "test.sock")
		server         = httptest.NewUnstartedServer(NewHTTPHandler(testHandler(t)))
	)

	// NewUnstartedServer has already opened a listener of its own. Close it
	// before swapping in the unix socket listener, otherwise it is never
	// closed and leaks for the remainder of the test binary's run.
	if err := server.Listener.Close(); err != nil {
		t.Fatal(err)
	}

	var err error
	if server.Listener, err = net.Listen("unix", unixSocketPath); err != nil {
		t.Fatal(err)
	}

	server.Start()
	t.Cleanup(server.Close)

	httpClient, baseURL := toolkit.NewUnixHTTPClient(logger, unixSocketPath)
	// Cleanups run last-registered-first, so the HTTP client is closed before
	// the server is shut down.
	t.Cleanup(func() { assert.NoError(t, httpClient.Close()) })

	testClient(t, NewHTTPClient(httpClient, baseURL.String()))
}