465 lines
10 KiB
Go
465 lines
10 KiB
Go
|
// Package chat implements a simple chatroom system.
|
||
|
package chat
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"strconv"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
|
||
|
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
|
||
|
"github.com/mediocregopher/radix/v4"
|
||
|
)
|
||
|
|
||
|
// ErrInvalidArg is returned from methods in this package when a call fails due
// to invalid input. The underlying cause is held in Err and is also exposed via
// Unwrap, so errors.Is and errors.As can inspect it.
type ErrInvalidArg struct {
	Err error
}

// Error implements the error interface.
func (e ErrInvalidArg) Error() string {
	return fmt.Sprintf("invalid argument: %v", e.Err)
}

// Unwrap returns the underlying cause, allowing errors.Is / errors.As to see
// through an ErrInvalidArg.
func (e ErrInvalidArg) Unwrap() error {
	return e.Err
}

var (
	// errInvalidMessageID is an ErrInvalidArg whose cause is the
	// "invalid Message ID" error.
	errInvalidMessageID = ErrInvalidArg{Err: errors.New("invalid Message ID")}
)
|
||
|
|
||
|
// Message describes a message which has been posted to a Room.
|
||
|
type Message struct {
|
||
|
ID string `json:"id"`
|
||
|
UserID UserID `json:"userID"`
|
||
|
Body string `json:"body"`
|
||
|
}
|
||
|
|
||
|
func msgFromStreamEntry(entry radix.StreamEntry) (Message, error) {
|
||
|
|
||
|
// NOTE this should probably be a shortcut in radix
|
||
|
var bodyStr string
|
||
|
for _, field := range entry.Fields {
|
||
|
if field[0] == "json" {
|
||
|
bodyStr = field[1]
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if bodyStr == "" {
|
||
|
return Message{}, errors.New("no 'json' field")
|
||
|
}
|
||
|
|
||
|
var msg Message
|
||
|
if err := json.Unmarshal([]byte(bodyStr), &msg); err != nil {
|
||
|
return Message{}, fmt.Errorf(
|
||
|
"json unmarshaling body %q: %w", bodyStr, err,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
msg.ID = entry.ID.String()
|
||
|
return msg, nil
|
||
|
}
|
||
|
|
||
|
// MessageIterator returns a sequence of Messages which may or may not be
|
||
|
// unbounded.
|
||
|
type MessageIterator interface {
|
||
|
|
||
|
// Next blocks until it returns the next Message in the sequence, or the
|
||
|
// context error if the context is cancelled, or io.EOF if the sequence has
|
||
|
// been exhausted.
|
||
|
Next(context.Context) (Message, error)
|
||
|
|
||
|
// Close should always be called once Next has returned an error or the
|
||
|
// MessageIterator will no longer be used.
|
||
|
Close() error
|
||
|
}
|
||
|
|
||
|
// HistoryOpts are passed into Room's History method in order to affect its
|
||
|
// result. All fields are optional.
|
||
|
type HistoryOpts struct {
|
||
|
Limit int // defaults to, and is capped at, 100.
|
||
|
Cursor string // If not given then the most recent Messages are returned.
|
||
|
}
|
||
|
|
||
|
func (o HistoryOpts) sanitize() (HistoryOpts, error) {
|
||
|
if o.Limit < 0 || o.Limit > 100 {
|
||
|
o.Limit = 100
|
||
|
}
|
||
|
|
||
|
if o.Cursor != "" {
|
||
|
id, err := parseStreamEntryID(o.Cursor)
|
||
|
if err != nil {
|
||
|
return HistoryOpts{}, fmt.Errorf("parsing Cursor: %w", err)
|
||
|
}
|
||
|
o.Cursor = id.String()
|
||
|
}
|
||
|
|
||
|
return o, nil
|
||
|
}
|
||
|
|
||
|
// Room implements functionality related to a single, unique chat room.
|
||
|
type Room interface {
|
||
|
|
||
|
// Append accepts a new Message and stores it at the end of the room's
|
||
|
// history. The original Message is returned with any relevant fields (e.g.
|
||
|
// ID) updated.
|
||
|
Append(context.Context, Message) (Message, error)
|
||
|
|
||
|
// Returns a cursor and the list of historical Messages in time descending
|
||
|
// order. The cursor can be passed into the next call to History to receive
|
||
|
// the next set of Messages.
|
||
|
History(context.Context, HistoryOpts) (string, []Message, error)
|
||
|
|
||
|
// Listen returns a MessageIterator which will return all Messages appended
|
||
|
// to the Room since the given ID. Once all existing messages are iterated
|
||
|
// through then the MessageIterator will begin blocking until a new Message
|
||
|
// is posted.
|
||
|
Listen(ctx context.Context, sinceID string) (MessageIterator, error)
|
||
|
|
||
|
// Delete deletes a Message from the Room.
|
||
|
Delete(ctx context.Context, id string) error
|
||
|
|
||
|
// Close is used to clean up all resources created by the Room.
|
||
|
Close() error
|
||
|
}
|
||
|
|
||
|
// RoomParams are used to instantiate a new Room. All fields are required unless
|
||
|
// otherwise noted.
|
||
|
type RoomParams struct {
|
||
|
Logger *mlog.Logger
|
||
|
Redis radix.Client
|
||
|
ID string
|
||
|
MaxMessages int
|
||
|
}
|
||
|
|
||
|
func (p RoomParams) streamKey() string {
|
||
|
return fmt.Sprintf("chat:{%s}:stream", p.ID)
|
||
|
}
|
||
|
|
||
|
type room struct {
|
||
|
params RoomParams
|
||
|
|
||
|
closeCtx context.Context
|
||
|
closeCancel context.CancelFunc
|
||
|
wg sync.WaitGroup
|
||
|
|
||
|
listeningL sync.Mutex
|
||
|
listening map[chan Message]struct{}
|
||
|
listeningLastID radix.StreamEntryID
|
||
|
}
|
||
|
|
||
|
// NewRoom initializes and returns a new Room instance.
|
||
|
func NewRoom(ctx context.Context, params RoomParams) (Room, error) {
|
||
|
|
||
|
params.Logger = params.Logger.WithNamespace("chat-room")
|
||
|
|
||
|
r := &room{
|
||
|
params: params,
|
||
|
listening: map[chan Message]struct{}{},
|
||
|
}
|
||
|
|
||
|
r.closeCtx, r.closeCancel = context.WithCancel(context.Background())
|
||
|
|
||
|
// figure out the most recent message, if any.
|
||
|
lastEntryID, err := r.mostRecentMsgID(ctx)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("discovering most recent entry ID in stream: %w", err)
|
||
|
}
|
||
|
r.listeningLastID = lastEntryID
|
||
|
|
||
|
r.wg.Add(1)
|
||
|
go func() {
|
||
|
defer r.wg.Done()
|
||
|
r.readStreamLoop(r.closeCtx)
|
||
|
}()
|
||
|
|
||
|
return r, nil
|
||
|
}
|
||
|
|
||
|
func (r *room) Close() error {
|
||
|
r.closeCancel()
|
||
|
r.wg.Wait()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (r *room) mostRecentMsgID(ctx context.Context) (radix.StreamEntryID, error) {
|
||
|
|
||
|
var entries []radix.StreamEntry
|
||
|
err := r.params.Redis.Do(ctx, radix.Cmd(
|
||
|
&entries,
|
||
|
"XREVRANGE", r.params.streamKey(), "+", "-", "COUNT", "1",
|
||
|
))
|
||
|
|
||
|
if err != nil || len(entries) == 0 {
|
||
|
return radix.StreamEntryID{}, err
|
||
|
}
|
||
|
|
||
|
return entries[0].ID, nil
|
||
|
}
|
||
|
|
||
|
func (r *room) Append(ctx context.Context, msg Message) (Message, error) {
|
||
|
msg.ID = "" // just in case
|
||
|
|
||
|
b, err := json.Marshal(msg)
|
||
|
if err != nil {
|
||
|
return Message{}, fmt.Errorf("json marshaling Message: %w", err)
|
||
|
}
|
||
|
|
||
|
key := r.params.streamKey()
|
||
|
maxLen := strconv.Itoa(r.params.MaxMessages)
|
||
|
body := string(b)
|
||
|
|
||
|
var id string
|
||
|
|
||
|
err = r.params.Redis.Do(ctx, radix.Cmd(
|
||
|
&id, "XADD", key, "MAXLEN", "=", maxLen, "*", "json", body,
|
||
|
))
|
||
|
|
||
|
if err != nil {
|
||
|
return Message{}, fmt.Errorf("posting message to redis: %w", err)
|
||
|
}
|
||
|
|
||
|
msg.ID = id
|
||
|
return msg, nil
|
||
|
}
|
||
|
|
||
|
const zeroCursor = "0-0"
|
||
|
|
||
|
func (r *room) History(ctx context.Context, opts HistoryOpts) (string, []Message, error) {
|
||
|
opts, err := opts.sanitize()
|
||
|
if err != nil {
|
||
|
return "", nil, err
|
||
|
}
|
||
|
|
||
|
key := r.params.streamKey()
|
||
|
end := opts.Cursor
|
||
|
if end == "" {
|
||
|
end = "+"
|
||
|
}
|
||
|
start := "-"
|
||
|
count := strconv.Itoa(opts.Limit)
|
||
|
|
||
|
msgs := make([]Message, 0, opts.Limit)
|
||
|
streamEntries := make([]radix.StreamEntry, 0, opts.Limit)
|
||
|
|
||
|
err = r.params.Redis.Do(ctx, radix.Cmd(
|
||
|
&streamEntries,
|
||
|
"XREVRANGE", key, end, start, "COUNT", count,
|
||
|
))
|
||
|
|
||
|
if err != nil {
|
||
|
return "", nil, fmt.Errorf("calling XREVRANGE: %w", err)
|
||
|
}
|
||
|
|
||
|
var oldestEntryID radix.StreamEntryID
|
||
|
|
||
|
for _, entry := range streamEntries {
|
||
|
oldestEntryID = entry.ID
|
||
|
|
||
|
msg, err := msgFromStreamEntry(entry)
|
||
|
if err != nil {
|
||
|
return "", nil, fmt.Errorf(
|
||
|
"parsing stream entry %q: %w", entry.ID, err,
|
||
|
)
|
||
|
}
|
||
|
msgs = append(msgs, msg)
|
||
|
}
|
||
|
|
||
|
if len(msgs) < opts.Limit {
|
||
|
return zeroCursor, msgs, nil
|
||
|
}
|
||
|
|
||
|
cursor := oldestEntryID.Prev()
|
||
|
return cursor.String(), msgs, nil
|
||
|
}
|
||
|
|
||
|
func (r *room) readStream(ctx context.Context) error {
|
||
|
|
||
|
r.listeningL.Lock()
|
||
|
lastEntryID := r.listeningLastID
|
||
|
r.listeningL.Unlock()
|
||
|
|
||
|
redisAddr := r.params.Redis.Addr()
|
||
|
redisConn, err := radix.Dial(ctx, redisAddr.Network(), redisAddr.String())
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("creating redis connection: %w", err)
|
||
|
}
|
||
|
defer redisConn.Close()
|
||
|
|
||
|
streamReader := (radix.StreamReaderConfig{}).New(
|
||
|
redisConn,
|
||
|
map[string]radix.StreamConfig{
|
||
|
r.params.streamKey(): {After: lastEntryID},
|
||
|
},
|
||
|
)
|
||
|
|
||
|
for {
|
||
|
dlCtx, dlCtxCancel := context.WithTimeout(ctx, 10*time.Second)
|
||
|
_, streamEntry, err := streamReader.Next(dlCtx)
|
||
|
dlCtxCancel()
|
||
|
|
||
|
if errors.Is(err, radix.ErrNoStreamEntries) {
|
||
|
continue
|
||
|
} else if err != nil {
|
||
|
return fmt.Errorf("fetching next entry from stream: %w", err)
|
||
|
}
|
||
|
|
||
|
msg, err := msgFromStreamEntry(streamEntry)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("parsing stream entry %q: %w", streamEntry, err)
|
||
|
}
|
||
|
|
||
|
r.listeningL.Lock()
|
||
|
|
||
|
var dropped int
|
||
|
for ch := range r.listening {
|
||
|
select {
|
||
|
case ch <- msg:
|
||
|
default:
|
||
|
dropped++
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if dropped > 0 {
|
||
|
ctx := mctx.Annotate(ctx, "msgID", msg.ID, "dropped", dropped)
|
||
|
r.params.Logger.WarnString(ctx, "some listening channels full, messages dropped")
|
||
|
}
|
||
|
|
||
|
r.listeningLastID = streamEntry.ID
|
||
|
|
||
|
r.listeningL.Unlock()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *room) readStreamLoop(ctx context.Context) {
|
||
|
for {
|
||
|
err := r.readStream(ctx)
|
||
|
if errors.Is(err, context.Canceled) {
|
||
|
return
|
||
|
} else if err != nil {
|
||
|
r.params.Logger.Error(ctx, "reading from redis stream", err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type listenMsgIterator struct {
|
||
|
ch <-chan Message
|
||
|
missedMsgs []Message
|
||
|
sinceEntryID radix.StreamEntryID
|
||
|
cleanup func()
|
||
|
}
|
||
|
|
||
|
func (i *listenMsgIterator) Next(ctx context.Context) (Message, error) {
|
||
|
|
||
|
if len(i.missedMsgs) > 0 {
|
||
|
msg := i.missedMsgs[0]
|
||
|
i.missedMsgs = i.missedMsgs[1:]
|
||
|
return msg, nil
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return Message{}, ctx.Err()
|
||
|
case msg := <-i.ch:
|
||
|
|
||
|
entryID, err := parseStreamEntryID(msg.ID)
|
||
|
if err != nil {
|
||
|
return Message{}, fmt.Errorf("parsing Message ID %q: %w", msg.ID, err)
|
||
|
|
||
|
} else if !i.sinceEntryID.Before(entryID) {
|
||
|
// this can happen if someone Appends a Message at the same time
|
||
|
// as another calls Listen. The Listener might have already seen
|
||
|
// the Message by calling History prior to the stream reader
|
||
|
// having processed it and updating listeningLastID.
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
return msg, nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (i *listenMsgIterator) Close() error {
|
||
|
i.cleanup()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (r *room) Listen(
|
||
|
ctx context.Context, sinceID string,
|
||
|
) (
|
||
|
MessageIterator, error,
|
||
|
) {
|
||
|
|
||
|
var sinceEntryID radix.StreamEntryID
|
||
|
|
||
|
if sinceID != "" {
|
||
|
var err error
|
||
|
if sinceEntryID, err = parseStreamEntryID(sinceID); err != nil {
|
||
|
return nil, fmt.Errorf("parsing sinceID: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
ch := make(chan Message, 32)
|
||
|
|
||
|
r.listeningL.Lock()
|
||
|
lastEntryID := r.listeningLastID
|
||
|
r.listening[ch] = struct{}{}
|
||
|
r.listeningL.Unlock()
|
||
|
|
||
|
cleanup := func() {
|
||
|
r.listeningL.Lock()
|
||
|
defer r.listeningL.Unlock()
|
||
|
delete(r.listening, ch)
|
||
|
}
|
||
|
|
||
|
key := r.params.streamKey()
|
||
|
start := sinceEntryID.Next().String()
|
||
|
end := "+"
|
||
|
if lastEntryID != (radix.StreamEntryID{}) {
|
||
|
end = lastEntryID.String()
|
||
|
}
|
||
|
|
||
|
var streamEntries []radix.StreamEntry
|
||
|
|
||
|
err := r.params.Redis.Do(ctx, radix.Cmd(
|
||
|
&streamEntries,
|
||
|
"XRANGE", key, start, end,
|
||
|
))
|
||
|
|
||
|
if err != nil {
|
||
|
cleanup()
|
||
|
return nil, fmt.Errorf("retrieving missed stream entries: %w", err)
|
||
|
}
|
||
|
|
||
|
missedMsgs := make([]Message, len(streamEntries))
|
||
|
|
||
|
for i := range streamEntries {
|
||
|
|
||
|
msg, err := msgFromStreamEntry(streamEntries[i])
|
||
|
if err != nil {
|
||
|
cleanup()
|
||
|
return nil, fmt.Errorf(
|
||
|
"parsing stream entry %q: %w", streamEntries[i].ID, err,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
missedMsgs[i] = msg
|
||
|
}
|
||
|
|
||
|
return &listenMsgIterator{
|
||
|
ch: ch,
|
||
|
missedMsgs: missedMsgs,
|
||
|
sinceEntryID: sinceEntryID,
|
||
|
cleanup: cleanup,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (r *room) Delete(ctx context.Context, id string) error {
|
||
|
return r.params.Redis.Do(ctx, radix.Cmd(
|
||
|
nil, "XDEL", r.params.streamKey(), id,
|
||
|
))
|
||
|
}
|