package jsonrpc2 import ( "context" "errors" "io" "isle/toolkit" "net" "net/http/httptest" "path/filepath" "strings" "sync" "testing" "time" "github.com/stretchr/testify/assert" ) type DivideParams struct { Top, Bottom int } var ErrDivideByZero = Error{ Code: 1, Message: "Cannot divide by zero", } type dividerImpl struct{} func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) { if bottom == 0 { return 0, ErrDivideByZero } if top%bottom != 0 { return 0, errors.New("numbers don't divide evenly, cannot compute!") } return top / bottom, nil } func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) { return i.Divide2(ctx, p.Top, p.Bottom) } func (i dividerImpl) One(ctx context.Context) (int, error) { return 1, nil } func (i dividerImpl) Noop(ctx context.Context) error { return nil } func (i dividerImpl) Divide2FromMeta(ctx context.Context) (int, error) { var ( meta = GetMeta(ctx) top = int(meta["top"].(float64)) bottom = int(meta["bottom"].(float64)) ) return i.Divide2(ctx, top, bottom) } func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { return 0, errors.New("Shouldn't be possible to call this!") } type divider interface { Divide2(ctx context.Context, top, bottom int) (int, error) Divide(ctx context.Context, p DivideParams) (int, error) One(ctx context.Context) (int, error) Noop(ctx context.Context) error Divide2FromMeta(ctx context.Context) (int, error) } func testHandler(t *testing.T) Handler { var ( logger = toolkit.NewTestLogger(t) d = divider(dividerImpl{}) ) return Chain( NewMLogMiddleware(logger), ExposeServerSideErrorsMiddleware, )( NewDispatchHandler(&d), ) } func testClient(t *testing.T, client Client) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() t.Run("success/no_result", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{12, 4}) if err != nil { t.Fatal(err) } }) t.Run("success/with_result", func(t *testing.T) { var res int err := client.Call(ctx, &res, "Divide", DivideParams{6, 3}) if err != nil { t.Fatal(err) } else if res != 2 { t.Fatalf("expected 2, got %d", res) } }) t.Run("success/multiple_params", func(t *testing.T) { var res int err := client.Call(ctx, &res, "Divide2", 6, 3) if err != nil { t.Fatal(err) } else if res != 2 { t.Fatalf("expected 2, got %d", res) } }) t.Run("success/no_params", func(t *testing.T) { var res int err := client.Call(ctx, &res, "One") if err != nil { t.Fatal(err) } else if res != 1 { t.Fatalf("expected 1, got %d", res) } }) t.Run("success/no_results", func(t *testing.T) { err := client.Call(ctx, nil, "Noop") if err != nil { t.Fatal(err) } }) t.Run("success/meta", func(t *testing.T) { ctx = WithMeta(ctx, "top", 6) ctx = WithMeta(ctx, "bottom", 2) var res int err := client.Call(ctx, &res, "Divide2FromMeta") if err != nil { t.Fatal(err) } else if res != 3 { t.Fatalf("expected 2, got %d", res) } }) t.Run("err/application", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{}) if !errors.Is(err, ErrDivideByZero) { t.Fatalf("expected application error but got: %#v", err) } }) t.Run("err/calling hidden method", func(t *testing.T) { err := client.Call(ctx, nil, "Hidden", struct{}{}) if jErr := (Error{}); !errors.As(err, &jErr) { t.Fatalf("expected RPC error but got: %#v", err) } else if jErr.Code != errCodeMethodNotFound { t.Fatalf("expected method not found error but got: %#v", jErr) } }) t.Run("err/sever-side error", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{6, 4}) t.Log(err) if jErr := (Error{}); !errors.As(err, &jErr) { t.Fatalf("expected RPC error but got: %#v", err) } else if jErr.Code != errCodeServerError { t.Fatalf("expected server error but got: %#v", jErr) } else if !strings.Contains(jErr.Message, "cannot compute!") { t.Fatalf("expected server-side error message to be propagated but got: %#v", jErr) } }) } func TestReadWriter(t *testing.T) { type ( rw struct { io.Reader io.Writer } ) var ( ctx = context.Background() clientReader, handlerWriter = io.Pipe() handlerReader, clientWriter = io.Pipe() clientRW = rw{clientReader, clientWriter} handlerRW = rw{handlerReader, handlerWriter} server = NewReadWriterServer(testHandler(t), handlerRW) client = NewReadWriterClient(clientRW) wg = new(sync.WaitGroup) ) defer wg.Wait() defer clientWriter.Close() defer handlerWriter.Close() wg.Add(1) go func() { defer wg.Done() for { if err := server.HandleNext(ctx); errors.Is(err, io.EOF) { return } else if err != nil { return } } }() testClient(t, client) } func TestHTTP(t *testing.T) { logger := toolkit.NewTestLogger(t) server := httptest.NewServer(NewHTTPHandler(testHandler(t))) t.Cleanup(server.Close) httpClient := toolkit.NewHTTPClient(logger) t.Cleanup(func() { assert.NoError(t, httpClient.Close()) }) testClient(t, NewHTTPClient(httpClient, server.URL)) } func TestUnixHTTP(t *testing.T) { var ( logger = toolkit.NewTestLogger(t) unixSocketPath = filepath.Join(t.TempDir(), "test.sock") server = httptest.NewUnstartedServer(NewHTTPHandler(testHandler(t))) ) var err error if server.Listener, err = net.Listen("unix", unixSocketPath); err != nil { t.Fatal(err) } server.Start() t.Cleanup(server.Close) httpClient, baseURL := toolkit.NewUnixHTTPClient(logger, unixSocketPath) t.Cleanup(func() { assert.NoError(t, httpClient.Close()) }) testClient(t, NewHTTPClient(httpClient, baseURL.String())) }