178 lines
3.8 KiB
Go
178 lines
3.8 KiB
Go
package jsonrpc2
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"net/http/httptest"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
|
|
)
|
|
|
|
// DivideParams carries the operands of the Divide RPC method: Top is the
// dividend and Bottom is the divisor. Field order is relied upon by the
// positional composite literals used in the tests.
type DivideParams struct {
	Top    int
	Bottom int
}
|
|
|
|
var ErrDivideByZero = Error{
|
|
Code: 1,
|
|
Message: "Cannot divide by zero",
|
|
}
|
|
|
|
// dividerImpl is the stateless service implementation behind testHandler. It
// has one method that the divider interface exposes (Divide) and one that it
// deliberately does not (Hidden).
type (
	dividerImpl struct{}
)
|
|
|
|
func (dividerImpl) Divide(ctx context.Context, p DivideParams) (int, error) {
|
|
if p.Bottom == 0 {
|
|
return 0, ErrDivideByZero
|
|
}
|
|
if p.Top%p.Bottom != 0 {
|
|
return 0, errors.New("numbers don't divide evenly, cannot compute!")
|
|
}
|
|
return p.Top / p.Bottom, nil
|
|
}
|
|
|
|
func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) {
|
|
return 0, errors.New("Shouldn't be possible to call this!")
|
|
}
|
|
|
|
type divider interface {
|
|
Divide(ctx context.Context, p DivideParams) (int, error)
|
|
}
|
|
|
|
var testHandler = func() Handler {
|
|
var (
|
|
logger = mlog.NewLogger(&mlog.LoggerOpts{
|
|
MaxLevel: mlog.LevelDebug.Int(),
|
|
})
|
|
|
|
d = divider(dividerImpl{})
|
|
)
|
|
|
|
return Chain(
|
|
NewMLogMiddleware(logger),
|
|
ExposeServerSideErrorsMiddleware,
|
|
)(
|
|
NewDispatchHandler(&d),
|
|
)
|
|
}()
|
|
|
|
func testClient(t *testing.T, client Client) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
t.Run("success/no_result", func(t *testing.T) {
|
|
err := client.Call(ctx, nil, "Divide", DivideParams{12, 4})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
|
|
t.Run("success/with_result", func(t *testing.T) {
|
|
var res int
|
|
err := client.Call(ctx, &res, "Divide", DivideParams{6, 3})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
} else if res != 2 {
|
|
t.Fatalf("expected 2, got %d", res)
|
|
}
|
|
})
|
|
|
|
t.Run("err/application", func(t *testing.T) {
|
|
err := client.Call(ctx, nil, "Divide", DivideParams{})
|
|
if !errors.Is(err, ErrDivideByZero) {
|
|
t.Fatalf("expected application error but got: %#v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("err/calling hidden method", func(t *testing.T) {
|
|
err := client.Call(ctx, nil, "Hidden", struct{}{})
|
|
if jErr := (Error{}); !errors.As(err, &jErr) {
|
|
t.Fatalf("expected RPC error but got: %#v", err)
|
|
} else if jErr.Code != errCodeMethodNotFound {
|
|
t.Fatalf("expected method not found error but got: %#v", jErr)
|
|
}
|
|
})
|
|
|
|
t.Run("err/sever-side error", func(t *testing.T) {
|
|
err := client.Call(ctx, nil, "Divide", DivideParams{6, 4})
|
|
t.Log(err)
|
|
if jErr := (Error{}); !errors.As(err, &jErr) {
|
|
t.Fatalf("expected RPC error but got: %#v", err)
|
|
} else if jErr.Code != errCodeServerError {
|
|
t.Fatalf("expected server error but got: %#v", jErr)
|
|
} else if !strings.Contains(jErr.Message, "cannot compute!") {
|
|
t.Fatalf("expected server-side error message to be propagated but got: %#v", jErr)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestReadWriter(t *testing.T) {
|
|
type (
|
|
rw struct {
|
|
io.Reader
|
|
io.Writer
|
|
}
|
|
)
|
|
|
|
var (
|
|
ctx = context.Background()
|
|
|
|
clientReader, handlerWriter = io.Pipe()
|
|
handlerReader, clientWriter = io.Pipe()
|
|
|
|
clientRW = rw{clientReader, clientWriter}
|
|
handlerRW = rw{handlerReader, handlerWriter}
|
|
|
|
server = NewReadWriterServer(testHandler, handlerRW)
|
|
client = NewReadWriterClient(clientRW)
|
|
|
|
wg = new(sync.WaitGroup)
|
|
)
|
|
|
|
defer wg.Wait()
|
|
defer clientWriter.Close()
|
|
defer handlerWriter.Close()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for {
|
|
if err := server.HandleNext(ctx); errors.Is(err, io.EOF) {
|
|
return
|
|
} else if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
testClient(t, client)
|
|
}
|
|
|
|
func TestHTTP(t *testing.T) {
|
|
server := httptest.NewServer(NewHTTPHandler(testHandler))
|
|
t.Cleanup(server.Close)
|
|
testClient(t, NewHTTPClient(server.URL))
|
|
}
|
|
|
|
func TestUnixHTTP(t *testing.T) {
|
|
var (
|
|
unixSocketPath = filepath.Join(t.TempDir(), "test.sock")
|
|
server = httptest.NewUnstartedServer(NewHTTPHandler(testHandler))
|
|
)
|
|
|
|
var err error
|
|
if server.Listener, err = net.Listen("unix", unixSocketPath); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
server.Start()
|
|
t.Cleanup(server.Close)
|
|
|
|
testClient(t, NewUnixHTTPClient(unixSocketPath, "/"))
|
|
}
|