MVP of chat package

pull/15/head
Brian Picciano 3 years ago
parent ac5275353c
commit eaccf41563
  1. 475
      srv/chat/chat.go
  2. 200
      srv/chat/chat_test.go
  3. 28
      srv/chat/util.go
  4. 1
      srv/go.mod
  5. 3
      srv/go.sum

@ -0,0 +1,475 @@
// Package chat implements a simple chatroom system.
package chat
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"time"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
"github.com/mediocregopher/radix/v4"
)
// ErrInvalidArg is returned from methods in this package when a call fails due
// to invalid input.
type ErrInvalidArg struct {
Err error
}
func (e ErrInvalidArg) Error() string {
return fmt.Sprintf("invalid argument: %v", e.Err)
}
var (
errInvalidMessageID = ErrInvalidArg{Err: errors.New("invalid Message ID")}
)
// UserID uniquely identifies an individual user who has posted a message in a
// Room.
type UserID struct {
// Name will be the user's chosen display name.
Name string `json:"name"`
// Hash will be a hex string generated from a secret only the user knows.
Hash string `json:"id"`
}
// Message describes a message which has been posted to a Room.
type Message struct {
ID string `json:"id"`
UserID UserID `json:"userID"`
Body string `json:"body"`
}
func msgFromStreamEntry(entry radix.StreamEntry) (Message, error) {
// NOTE this should probably be a shortcut in radix
var bodyStr string
for _, field := range entry.Fields {
if field[0] == "json" {
bodyStr = field[1]
break
}
}
if bodyStr == "" {
return Message{}, errors.New("no 'json' field")
}
var msg Message
if err := json.Unmarshal([]byte(bodyStr), &msg); err != nil {
return Message{}, fmt.Errorf(
"json unmarshaling body %q: %w", bodyStr, err,
)
}
msg.ID = entry.ID.String()
return msg, nil
}
// MessageIterator returns a sequence of Messages which may or may not be
// unbounded.
type MessageIterator interface {
// Next blocks until it returns the next Message in the sequence, or the
// context error if the context is cancelled, or io.EOF if the sequence has
// been exhausted.
Next(context.Context) (Message, error)
// Close should always be called once Next has returned an error or the
// MessageIterator will no longer be used.
Close() error
}
// HistoryOpts are passed into Room's History method in order to affect its
// result. All fields are optional.
type HistoryOpts struct {
Limit int // defaults to, and is capped at, 100.
Cursor string // If not given then the most recent Messages are returned.
}
func (o HistoryOpts) sanitize() (HistoryOpts, error) {
if o.Limit < 0 || o.Limit > 100 {
o.Limit = 100
}
if o.Cursor != "" {
id, err := parseStreamEntryID(o.Cursor)
if err != nil {
return HistoryOpts{}, fmt.Errorf("parsing Cursor: %w", err)
}
o.Cursor = id.String()
}
return o, nil
}
// Room implements functionality related to a single, unique chat room.
type Room interface {
// Append accepts a new Message and stores it at the end of the room's
// history. The original Message is returned with any relevant fields (e.g.
// ID) updated.
Append(context.Context, Message) (Message, error)
// Returns a cursor and the list of historical Messages in time descending
// order. The cursor can be passed into the next call to History to receive
// the next set of Messages.
History(context.Context, HistoryOpts) (string, []Message, error)
// Listen returns a MessageIterator which will return all Messages appended
// to the Room since the given ID. Once all existing messages are iterated
// through then the MessageIterator will begin blocking until a new Message
// is posted.
Listen(ctx context.Context, sinceID string) (MessageIterator, error)
// Delete deletes a Message from the Room.
Delete(ctx context.Context, id string) error
// Close is used to clean up all resources created by the Room.
Close() error
}
// RoomParams are used to instantiate a new Room. All fields are required unless
// otherwise noted.
type RoomParams struct {
Logger *mlog.Logger
Redis radix.Client
ID string
MaxMessages int
}
func (p RoomParams) streamKey() string {
return fmt.Sprintf("chat:{%s}:stream", p.ID)
}
type room struct {
params RoomParams
closeCtx context.Context
closeCancel context.CancelFunc
wg sync.WaitGroup
listeningL sync.Mutex
listening map[chan Message]struct{}
listeningLastID radix.StreamEntryID
}
// NewRoom initializes and returns a new Room instance.
func NewRoom(ctx context.Context, params RoomParams) (Room, error) {
params.Logger = params.Logger.WithNamespace("chat-room")
r := &room{
params: params,
listening: map[chan Message]struct{}{},
}
r.closeCtx, r.closeCancel = context.WithCancel(context.Background())
// figure out the most recent message, if any.
lastEntryID, err := r.mostRecentMsgID(ctx)
if err != nil {
return nil, fmt.Errorf("discovering most recent entry ID in stream: %w", err)
}
r.listeningLastID = lastEntryID
r.wg.Add(1)
go func() {
defer r.wg.Done()
r.readStreamLoop(r.closeCtx)
}()
return r, nil
}
func (r *room) Close() error {
r.closeCancel()
r.wg.Wait()
return nil
}
func (r *room) mostRecentMsgID(ctx context.Context) (radix.StreamEntryID, error) {
var entries []radix.StreamEntry
err := r.params.Redis.Do(ctx, radix.Cmd(
&entries,
"XREVRANGE", r.params.streamKey(), "+", "-", "COUNT", "1",
))
if err != nil || len(entries) == 0 {
return radix.StreamEntryID{}, err
}
return entries[0].ID, nil
}
func (r *room) Append(ctx context.Context, msg Message) (Message, error) {
msg.ID = "" // just in case
b, err := json.Marshal(msg)
if err != nil {
return Message{}, fmt.Errorf("json marshaling Message: %w", err)
}
key := r.params.streamKey()
maxLen := strconv.Itoa(r.params.MaxMessages)
body := string(b)
var id string
err = r.params.Redis.Do(ctx, radix.Cmd(
&id, "XADD", key, "MAXLEN", "=", maxLen, "*", "json", body,
))
if err != nil {
return Message{}, fmt.Errorf("posting message to redis: %w", err)
}
msg.ID = id
return msg, nil
}
const zeroCursor = "0-0"
func (r *room) History(ctx context.Context, opts HistoryOpts) (string, []Message, error) {
opts, err := opts.sanitize()
if err != nil {
return "", nil, err
}
key := r.params.streamKey()
end := opts.Cursor
if end == "" {
end = "+"
}
start := "-"
count := strconv.Itoa(opts.Limit)
msgs := make([]Message, 0, opts.Limit)
streamEntries := make([]radix.StreamEntry, 0, opts.Limit)
err = r.params.Redis.Do(ctx, radix.Cmd(
&streamEntries,
"XREVRANGE", key, end, start, "COUNT", count,
))
if err != nil {
return "", nil, fmt.Errorf("calling XREVRANGE: %w", err)
}
var oldestEntryID radix.StreamEntryID
for _, entry := range streamEntries {
oldestEntryID = entry.ID
msg, err := msgFromStreamEntry(entry)
if err != nil {
return "", nil, fmt.Errorf(
"parsing stream entry %q: %w", entry.ID, err,
)
}
msgs = append(msgs, msg)
}
if len(msgs) < opts.Limit {
return zeroCursor, msgs, nil
}
cursor := oldestEntryID.Prev()
return cursor.String(), msgs, nil
}
func (r *room) readStream(ctx context.Context) error {
r.listeningL.Lock()
lastEntryID := r.listeningLastID
r.listeningL.Unlock()
redisAddr := r.params.Redis.Addr()
redisConn, err := radix.Dial(ctx, redisAddr.Network(), redisAddr.String())
if err != nil {
return fmt.Errorf("creating redis connection: %w", err)
}
defer redisConn.Close()
streamReader := (radix.StreamReaderConfig{}).New(
redisConn,
map[string]radix.StreamConfig{
r.params.streamKey(): {After: lastEntryID},
},
)
for {
dlCtx, dlCtxCancel := context.WithTimeout(ctx, 10*time.Second)
_, streamEntry, err := streamReader.Next(dlCtx)
dlCtxCancel()
if errors.Is(err, radix.ErrNoStreamEntries) {
continue
} else if err != nil {
return fmt.Errorf("fetching next entry from stream: %w", err)
}
msg, err := msgFromStreamEntry(streamEntry)
if err != nil {
return fmt.Errorf("parsing stream entry %q: %w", streamEntry, err)
}
r.listeningL.Lock()
var dropped int
for ch := range r.listening {
select {
case ch <- msg:
default:
dropped++
}
}
if dropped > 0 {
ctx := mctx.Annotate(ctx, "msgID", msg.ID, "dropped", dropped)
r.params.Logger.WarnString(ctx, "some listening channels full, messages dropped")
}
r.listeningLastID = streamEntry.ID
r.listeningL.Unlock()
}
}
func (r *room) readStreamLoop(ctx context.Context) {
for {
err := r.readStream(ctx)
if errors.Is(err, context.Canceled) {
return
} else if err != nil {
r.params.Logger.Error(ctx, "reading from redis stream", err)
}
}
}
type listenMsgIterator struct {
ch <-chan Message
missedMsgs []Message
sinceEntryID radix.StreamEntryID
cleanup func()
}
func (i *listenMsgIterator) Next(ctx context.Context) (Message, error) {
if len(i.missedMsgs) > 0 {
msg := i.missedMsgs[0]
i.missedMsgs = i.missedMsgs[1:]
return msg, nil
}
for {
select {
case <-ctx.Done():
return Message{}, ctx.Err()
case msg := <-i.ch:
entryID, err := parseStreamEntryID(msg.ID)
if err != nil {
return Message{}, fmt.Errorf("parsing Message ID %q: %w", msg.ID, err)
} else if !i.sinceEntryID.Before(entryID) {
// this can happen if someone Appends a Message at the same time
// as another calls Listen. The Listener might have already seen
// the Message by calling History prior to the stream reader
// having processed it and updating listeningLastID.
continue
}
return msg, nil
}
}
}
func (i *listenMsgIterator) Close() error {
i.cleanup()
return nil
}
func (r *room) Listen(
ctx context.Context, sinceID string,
) (
MessageIterator, error,
) {
var sinceEntryID radix.StreamEntryID
if sinceID != "" {
var err error
if sinceEntryID, err = parseStreamEntryID(sinceID); err != nil {
return nil, fmt.Errorf("parsing sinceID: %w", err)
}
}
ch := make(chan Message, 32)
r.listeningL.Lock()
lastEntryID := r.listeningLastID
r.listening[ch] = struct{}{}
r.listeningL.Unlock()
cleanup := func() {
r.listeningL.Lock()
defer r.listeningL.Unlock()
delete(r.listening, ch)
}
key := r.params.streamKey()
start := sinceEntryID.Next().String()
end := "+"
if lastEntryID != (radix.StreamEntryID{}) {
end = lastEntryID.String()
}
var streamEntries []radix.StreamEntry
err := r.params.Redis.Do(ctx, radix.Cmd(
&streamEntries,
"XRANGE", key, start, end,
))
if err != nil {
cleanup()
return nil, fmt.Errorf("retrieving missed stream entries: %w", err)
}
missedMsgs := make([]Message, len(streamEntries))
for i := range streamEntries {
msg, err := msgFromStreamEntry(streamEntries[i])
if err != nil {
cleanup()
return nil, fmt.Errorf(
"parsing stream entry %q: %w", streamEntries[i].ID, err,
)
}
missedMsgs[i] = msg
}
return &listenMsgIterator{
ch: ch,
missedMsgs: missedMsgs,
sinceEntryID: sinceEntryID,
cleanup: cleanup,
}, nil
}
func (r *room) Delete(ctx context.Context, id string) error {
return r.params.Redis.Do(ctx, radix.Cmd(
nil, "XDEL", r.params.streamKey(), id,
))
}

@ -0,0 +1,200 @@
package chat
import (
"context"
"strconv"
"testing"
"time"
"github.com/google/uuid"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
"github.com/mediocregopher/radix/v4"
"github.com/stretchr/testify/assert"
)
const roomTestHarnessMaxMsgs = 10
type roomTestHarness struct {
ctx context.Context
room Room
allMsgs []Message
}
func (h *roomTestHarness) newMsg(t *testing.T) Message {
msg, err := h.room.Append(h.ctx, Message{
UserID: UserID{
Name: uuid.New().String(),
Hash: "0000",
},
Body: uuid.New().String(),
})
assert.NoError(t, err)
t.Logf("appended message %s", msg.ID)
h.allMsgs = append([]Message{msg}, h.allMsgs...)
if len(h.allMsgs) > roomTestHarnessMaxMsgs {
h.allMsgs = h.allMsgs[:roomTestHarnessMaxMsgs]
}
return msg
}
func newRoomTestHarness(t *testing.T) *roomTestHarness {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
redis, err := radix.Dial(ctx, "tcp", "127.0.0.1:6379")
assert.NoError(t, err)
t.Cleanup(func() { redis.Close() })
roomParams := RoomParams{
Logger: mlog.NewLogger(nil),
Redis: redis,
ID: uuid.New().String(),
MaxMessages: roomTestHarnessMaxMsgs,
}
t.Logf("creating test Room %q", roomParams.ID)
room, err := NewRoom(ctx, roomParams)
assert.NoError(t, err)
t.Cleanup(func() {
err := redis.Do(context.Background(), radix.Cmd(
nil, "DEL", roomParams.streamKey(),
))
assert.NoError(t, err)
})
return &roomTestHarness{ctx: ctx, room: room}
}
func TestRoom(t *testing.T) {
t.Run("history", func(t *testing.T) {
tests := []struct {
numMsgs int
limit int
}{
{numMsgs: 0, limit: 1},
{numMsgs: 1, limit: 1},
{numMsgs: 2, limit: 1},
{numMsgs: 2, limit: 10},
{numMsgs: 9, limit: 2},
{numMsgs: 9, limit: 3},
{numMsgs: 9, limit: 4},
{numMsgs: 15, limit: 3},
}
for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Logf("test: %+v", test)
h := newRoomTestHarness(t)
for j := 0; j < test.numMsgs; j++ {
h.newMsg(t)
}
var gotMsgs []Message
var cursor string
for {
var msgs []Message
var err error
cursor, msgs, err = h.room.History(h.ctx, HistoryOpts{
Cursor: cursor,
Limit: test.limit,
})
assert.NoError(t, err)
assert.NotEmpty(t, cursor)
if len(msgs) == 0 {
break
}
gotMsgs = append(gotMsgs, msgs...)
}
assert.Equal(t, h.allMsgs, gotMsgs)
})
}
})
assertNextMsg := func(
t *testing.T, expMsg Message,
ctx context.Context, it MessageIterator,
) {
t.Helper()
gotMsg, err := it.Next(ctx)
assert.NoError(t, err)
assert.Equal(t, expMsg, gotMsg)
}
t.Run("listen/already_populated", func(t *testing.T) {
h := newRoomTestHarness(t)
msgA, msgB, msgC := h.newMsg(t), h.newMsg(t), h.newMsg(t)
_ = msgA
_ = msgB
itFoo, err := h.room.Listen(h.ctx, msgC.ID)
assert.NoError(t, err)
defer itFoo.Close()
itBar, err := h.room.Listen(h.ctx, msgA.ID)
assert.NoError(t, err)
defer itBar.Close()
msgD := h.newMsg(t)
// itBar should get msgB and msgC before anything else.
assertNextMsg(t, msgB, h.ctx, itBar)
assertNextMsg(t, msgC, h.ctx, itBar)
// now both iterators should give msgD
assertNextMsg(t, msgD, h.ctx, itFoo)
assertNextMsg(t, msgD, h.ctx, itBar)
// timeout should be honored
{
timeoutCtx, timeoutCancel := context.WithTimeout(h.ctx, 1*time.Second)
_, errFoo := itFoo.Next(timeoutCtx)
_, errBar := itBar.Next(timeoutCtx)
timeoutCancel()
assert.ErrorIs(t, errFoo, context.DeadlineExceeded)
assert.ErrorIs(t, errBar, context.DeadlineExceeded)
}
// new message should work
{
expMsg := h.newMsg(t)
timeoutCtx, timeoutCancel := context.WithTimeout(h.ctx, 1*time.Second)
gotFooMsg, errFoo := itFoo.Next(timeoutCtx)
gotBarMsg, errBar := itBar.Next(timeoutCtx)
timeoutCancel()
assert.Equal(t, expMsg, gotFooMsg)
assert.NoError(t, errFoo)
assert.Equal(t, expMsg, gotBarMsg)
assert.NoError(t, errBar)
}
})
t.Run("listen/empty", func(t *testing.T) {
h := newRoomTestHarness(t)
it, err := h.room.Listen(h.ctx, "")
assert.NoError(t, err)
defer it.Close()
msg := h.newMsg(t)
assertNextMsg(t, msg, h.ctx, it)
})
}

@ -0,0 +1,28 @@
package chat
import (
"strconv"
"strings"
"github.com/mediocregopher/radix/v4"
)
func parseStreamEntryID(str string) (radix.StreamEntryID, error) {
split := strings.SplitN(str, "-", 2)
if len(split) != 2 {
return radix.StreamEntryID{}, errInvalidMessageID
}
time, err := strconv.ParseUint(split[0], 10, 64)
if err != nil {
return radix.StreamEntryID{}, errInvalidMessageID
}
seq, err := strconv.ParseUint(split[1], 10, 64)
if err != nil {
return radix.StreamEntryID{}, errInvalidMessageID
}
return radix.StreamEntryID{Time: time, Seq: seq}, nil
}

@ -8,6 +8,7 @@ require (
github.com/google/uuid v1.3.0
github.com/mattn/go-sqlite3 v1.14.8
github.com/mediocregopher/mediocre-go-lib/v2 v2.0.0-beta.0
github.com/mediocregopher/radix/v4 v4.0.0-beta.1.0.20210726230805-d62fa1b2e3cb // indirect
github.com/rubenv/sql-migrate v0.0.0-20210614095031-55d5740dbbcc
github.com/stretchr/testify v1.7.0
github.com/tilinna/clock v1.1.0

@ -106,6 +106,8 @@ github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mediocregopher/mediocre-go-lib/v2 v2.0.0-beta.0 h1:i9FBkcCaWXxteJ8458AD8dBL2YqSxVlpsHOMWg5N9Dc=
github.com/mediocregopher/mediocre-go-lib/v2 v2.0.0-beta.0/go.mod h1:wOZVlnKYvIbkzyCJ3dxy1k40XkirvCd1pisX2O91qoQ=
github.com/mediocregopher/radix/v4 v4.0.0-beta.1.0.20210726230805-d62fa1b2e3cb h1:7Y2vAC5q44VJzbBUdxRUEqfz88ySJ/6yXXkpQ+sxke4=
github.com/mediocregopher/radix/v4 v4.0.0-beta.1.0.20210726230805-d62fa1b2e3cb/go.mod h1:ajchozX/6ELmydxWeWM6xCFHVpZ4+67LXHOTOVR0nCE=
github.com/mitchellh/cli v1.1.2/go.mod h1:6iaV0fGdElS6dPBx0EApTxHrcWvmJphyh2n8YBLPPZ4=
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
@ -154,6 +156,7 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tilinna/clock v1.0.2/go.mod h1:ZsP7BcY7sEEz7ktc0IVy8Us6boDrK8VradlKRUGfOao=
github.com/tilinna/clock v1.1.0 h1:6IQQQCo6KoBxVudv6gwtY8o4eDfhHo8ojA5dP0MfhSs=
github.com/tilinna/clock v1.1.0/go.mod h1:ZsP7BcY7sEEz7ktc0IVy8Us6boDrK8VradlKRUGfOao=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=

Loading…
Cancel
Save