package jsonrpc2 import ( "context" "errors" "io" "net" "net/http/httptest" "path/filepath" "strings" "sync" "testing" "time" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" ) type DivideParams struct { Top, Bottom int } var ErrDivideByZero = Error{ Code: 1, Message: "Cannot divide by zero", } type dividerImpl struct{} func (dividerImpl) Divide2(ctx context.Context, top, bottom int) (int, error) { if bottom == 0 { return 0, ErrDivideByZero } if top%bottom != 0 { return 0, errors.New("numbers don't divide evenly, cannot compute!") } return top / bottom, nil } func (i dividerImpl) Noop(ctx context.Context) (int, error) { return 1, nil } func (i dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) { return i.Divide2(ctx, p.Top, p.Bottom) } func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { return 0, errors.New("Shouldn't be possible to call this!") } type divider interface { Noop(ctx context.Context) (int, error) Divide2(ctx context.Context, top, bottom int) (int, error) Divide(ctx context.Context, p DivideParams) (int, error) } var testHandler = func() Handler { var ( logger = mlog.NewLogger(&mlog.LoggerOpts{ MaxLevel: mlog.LevelDebug.Int(), }) d = divider(dividerImpl{}) ) return Chain( NewMLogMiddleware(logger), ExposeServerSideErrorsMiddleware, )( NewDispatchHandler(&d), ) }() func testClient(t *testing.T, client Client) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() t.Run("success/no_result", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{12, 4}) if err != nil { t.Fatal(err) } }) t.Run("success/with_result", func(t *testing.T) { var res int err := client.Call(ctx, &res, "Divide", DivideParams{6, 3}) if err != nil { t.Fatal(err) } else if res != 2 { t.Fatalf("expected 2, got %d", res) } }) t.Run("success/multiple_params", func(t *testing.T) { var res int err := client.Call(ctx, &res, "Divide2", 6, 3) if err != nil { t.Fatal(err) } else if res != 2 { t.Fatalf("expected 2, got %d", res) } }) t.Run("success/no_params", func(t *testing.T) { var res int err := client.Call(ctx, &res, "Noop") if err != nil { t.Fatal(err) } else if res != 1 { t.Fatalf("expected 1, got %d", res) } }) t.Run("err/application", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{}) if !errors.Is(err, ErrDivideByZero) { t.Fatalf("expected application error but got: %#v", err) } }) t.Run("err/calling hidden method", func(t *testing.T) { err := client.Call(ctx, nil, "Hidden", struct{}{}) if jErr := (Error{}); !errors.As(err, &jErr) { t.Fatalf("expected RPC error but got: %#v", err) } else if jErr.Code != errCodeMethodNotFound { t.Fatalf("expected method not found error but got: %#v", jErr) } }) t.Run("err/sever-side error", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{6, 4}) t.Log(err) if jErr := (Error{}); !errors.As(err, &jErr) { t.Fatalf("expected RPC error but got: %#v", err) } else if jErr.Code != errCodeServerError { t.Fatalf("expected server error but got: %#v", jErr) } else if !strings.Contains(jErr.Message, "cannot compute!") { t.Fatalf("expected server-side error message to be propagated but got: %#v", jErr) } }) } func TestReadWriter(t *testing.T) { type ( rw struct { io.Reader io.Writer } ) var ( ctx = context.Background() clientReader, handlerWriter = io.Pipe() handlerReader, clientWriter = io.Pipe() clientRW = rw{clientReader, clientWriter} handlerRW = rw{handlerReader, handlerWriter} server = NewReadWriterServer(testHandler, handlerRW) client = NewReadWriterClient(clientRW) wg = new(sync.WaitGroup) ) defer wg.Wait() defer clientWriter.Close() defer handlerWriter.Close() wg.Add(1) go func() { defer wg.Done() for { if err := server.HandleNext(ctx); errors.Is(err, io.EOF) { return } else if err != nil { return } } }() testClient(t, client) } func TestHTTP(t *testing.T) { server := httptest.NewServer(NewHTTPHandler(testHandler)) t.Cleanup(server.Close) testClient(t, NewHTTPClient(server.URL)) } func TestUnixHTTP(t *testing.T) { var ( unixSocketPath = filepath.Join(t.TempDir(), "test.sock") server = httptest.NewUnstartedServer(NewHTTPHandler(testHandler)) ) var err error if server.Listener, err = net.Listen("unix", unixSocketPath); err != nil { t.Fatal(err) } server.Start() t.Cleanup(server.Close) testClient(t, NewUnixHTTPClient(unixSocketPath, "/")) }