add chat handlers and only allow POST methods

This commit is contained in:
Brian Picciano 2021-08-30 20:08:51 -06:00
parent 3e9a17abb9
commit 9343d2ea69
16 changed files with 291 additions and 138 deletions

View File

@ -26,7 +26,7 @@ type Params struct {
PowManager pow.Manager
MailingList mailinglist.MailingList
GlobalRoom chat.Room
UserIDCalculator chat.UserIDCalculator
UserIDCalculator *chat.UserIDCalculator
// ListenProto and ListenAddr are passed into net.Listen to create the
// API's listener. Both "tcp" and "unix" protocols are explicitly
@ -165,7 +165,14 @@ func (a *api) handler() http.Handler {
apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler())
apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler())
apiMux.Handle("/chat/global/", http.StripPrefix("/chat/global", newChatHandler(
a.params.GlobalRoom,
a.params.UserIDCalculator,
a.requirePowMiddleware,
)))
var apiHandler http.Handler = apiMux
apiHandler = allowedMethod("POST", apiHandler)
apiHandler = checkCSRFMiddleware(apiHandler)
apiHandler = logMiddleware(a.params.Logger, apiHandler)
apiHandler = annotateMiddleware(apiHandler)

View File

@ -0,0 +1,112 @@
// Package apiutils contains utilities which are useful for implementing api
// endpoints.
package apiutils
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
type loggerCtxKey int
// SetRequestLogger sets the given Logger onto the given Request's Context,
// returning a copy.
func SetRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, loggerCtxKey(0), logger)
return r.WithContext(ctx)
}
// GetRequestLogger returns the Logger which was set by SetRequestLogger onto
// this Request, or nil.
func GetRequestLogger(r *http.Request) *mlog.Logger {
ctx := r.Context()
logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger)
if logger == nil {
logger = mlog.Null
}
return logger
}
// JSONResult writes the JSON encoding of the given value as the response body.
func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) {
b, err := json.Marshal(v)
if err != nil {
InternalServerError(rw, r, err)
return
}
b = append(b, '\n')
rw.Header().Set("Content-Type", "application/json")
rw.Write(b)
}
// BadRequest writes a 400 status and a JSON encoded error struct containing the
// given error as the response body.
func BadRequest(rw http.ResponseWriter, r *http.Request, err error) {
GetRequestLogger(r).Warn(r.Context(), "bad request", err)
rw.WriteHeader(400)
JSONResult(rw, r, struct {
Error string `json:"error"`
}{
Error: err.Error(),
})
}
// InternalServerError writes a 500 status and a JSON encoded error struct
// containing a generic error as the response body (though it will log the given
// one).
func InternalServerError(rw http.ResponseWriter, r *http.Request, err error) {
GetRequestLogger(r).Error(r.Context(), "internal server error", err)
rw.WriteHeader(500)
JSONResult(rw, r, struct {
Error string `json:"error"`
}{
Error: "internal server error",
})
}
// StrToInt parses the given string as an integer, or returns the given default
// integer if the string is empty.
func StrToInt(str string, defaultVal int) (int, error) {
if str == "" {
return defaultVal, nil
}
return strconv.Atoi(str)
}
// GetCookie returns the namd cookie's value, or the given default value if the
// cookie is not set.
//
// This will only return an error if there was an unexpected error parsing the
// Request's cookies.
func GetCookie(r *http.Request, cookieName, defaultVal string) (string, error) {
c, err := r.Cookie(cookieName)
if errors.Is(err, http.ErrNoCookie) {
return defaultVal, nil
} else if err != nil {
return "", fmt.Errorf("reading cookie %q: %w", cookieName, err)
}
return c.Value, nil
}
// RandStr returns a human-readable random string with the given number of bytes
// of randomness.
func RandStr(numBytes int) string {
b := make([]byte, numBytes)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return hex.EncodeToString(b)
}

View File

@ -7,32 +7,57 @@ import (
"strings"
"unicode"
"github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
"github.com/mediocregopher/blog.mediocregopher.com/srv/chat"
)
func (a *api) chatHistoryHandler() http.Handler {
type chatHandler struct {
*http.ServeMux
room chat.Room
userIDCalc *chat.UserIDCalculator
}
func newChatHandler(
room chat.Room, userIDCalc *chat.UserIDCalculator,
requirePowMiddleware func(http.Handler) http.Handler,
) http.Handler {
c := &chatHandler{
ServeMux: http.NewServeMux(),
room: room,
userIDCalc: userIDCalc,
}
c.Handle("/history", c.historyHandler())
c.Handle("/user-id", requirePowMiddleware(c.userIDHandler()))
c.Handle("/append", requirePowMiddleware(c.appendHandler()))
return c
}
func (c *chatHandler) historyHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
limit, err := strToInt(r.FormValue("limit"), 0)
limit, err := apiutils.StrToInt(r.PostFormValue("limit"), 0)
if err != nil {
badRequest(rw, r, fmt.Errorf("invalid limit parameter: %w", err))
apiutils.BadRequest(rw, r, fmt.Errorf("invalid limit parameter: %w", err))
return
}
cursor := r.FormValue("cursor")
cursor := r.PostFormValue("cursor")
cursor, msgs, err := a.params.GlobalRoom.History(r.Context(), chat.HistoryOpts{
cursor, msgs, err := c.room.History(r.Context(), chat.HistoryOpts{
Limit: limit,
Cursor: cursor,
})
if argErr := (chat.ErrInvalidArg{}); errors.As(err, &argErr) {
badRequest(rw, r, argErr.Err)
apiutils.BadRequest(rw, r, argErr.Err)
return
} else if err != nil {
internalServerError(rw, r, err)
apiutils.InternalServerError(rw, r, err)
}
jsonResult(rw, r, struct {
apiutils.JSONResult(rw, r, struct {
Cursor string `json:"cursor"`
Messages []chat.Message `json:"messages"`
}{
@ -42,7 +67,7 @@ func (a *api) chatHistoryHandler() http.Handler {
})
}
func (a *api) getUserID(r *http.Request) (chat.UserID, error) {
func (c *chatHandler) userID(r *http.Request) (chat.UserID, error) {
name := r.PostFormValue("name")
if l := len(name); l == 0 {
return chat.UserID{}, errors.New("name is required")
@ -68,21 +93,58 @@ func (a *api) getUserID(r *http.Request) (chat.UserID, error) {
return chat.UserID{}, errors.New("password too long")
}
return a.params.UserIDCalculator.Calculate(name, password), nil
return c.userIDCalc.Calculate(name, password), nil
}
func (a *api) getUserIDHandler() http.Handler {
func (c *chatHandler) userIDHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
userID, err := a.getUserID(r)
userID, err := c.userID(r)
if err != nil {
badRequest(rw, r, err)
apiutils.BadRequest(rw, r, err)
return
}
jsonResult(rw, r, struct {
apiutils.JSONResult(rw, r, struct {
UserID chat.UserID `json:"userID"`
}{
UserID: userID,
})
})
}
func (c *chatHandler) appendHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
userID, err := c.userID(r)
if err != nil {
apiutils.BadRequest(rw, r, err)
return
}
body := r.PostFormValue("body")
if l := len(body); l == 0 {
apiutils.BadRequest(rw, r, errors.New("body is required"))
return
} else if l > 300 {
apiutils.BadRequest(rw, r, errors.New("body too long"))
return
}
msg, err := c.room.Append(r.Context(), chat.Message{
UserID: userID,
Body: body,
})
if err != nil {
apiutils.InternalServerError(rw, r, err)
return
}
apiutils.JSONResult(rw, r, struct {
MessageID string `json:"messageID"`
}{
MessageID: msg.ID,
})
})
}

View File

@ -3,6 +3,8 @@ package api
import (
"errors"
"net/http"
"github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
)
const (
@ -13,16 +15,16 @@ const (
func setCSRFMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
csrfTok, err := getCookie(r, csrfTokenCookieName, "")
csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "")
if err != nil {
internalServerError(rw, r, err)
apiutils.InternalServerError(rw, r, err)
return
} else if csrfTok == "" {
http.SetCookie(rw, &http.Cookie{
Name: csrfTokenCookieName,
Value: randStr(32),
Value: apiutils.RandStr(32),
Secure: true,
})
}
@ -34,14 +36,14 @@ func setCSRFMiddleware(h http.Handler) http.Handler {
func checkCSRFMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
csrfTok, err := getCookie(r, csrfTokenCookieName, "")
csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "")
if err != nil {
internalServerError(rw, r, err)
apiutils.InternalServerError(rw, r, err)
return
} else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok {
badRequest(rw, r, errors.New("invalid CSRF token"))
apiutils.BadRequest(rw, r, errors.New("invalid CSRF token"))
return
}

View File

@ -5,6 +5,7 @@ import (
"net/http"
"strings"
"github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
"github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist"
)
@ -15,7 +16,7 @@ func (a *api) mailingListSubscribeHandler() http.Handler {
parts[0] == "" ||
parts[1] == "" ||
len(email) >= 512 {
badRequest(rw, r, errors.New("invalid email"))
apiutils.BadRequest(rw, r, errors.New("invalid email"))
return
}
@ -25,11 +26,11 @@ func (a *api) mailingListSubscribeHandler() http.Handler {
// just eat the error, make it look to the user like the
// verification email was sent.
} else if err != nil {
internalServerError(rw, r, err)
apiutils.InternalServerError(rw, r, err)
return
}
jsonResult(rw, r, struct{}{})
apiutils.JSONResult(rw, r, struct{}{})
})
}
@ -39,25 +40,25 @@ func (a *api) mailingListFinalizeHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
subToken := r.PostFormValue("subToken")
if l := len(subToken); l == 0 || l > 128 {
badRequest(rw, r, errInvalidSubToken)
apiutils.BadRequest(rw, r, errInvalidSubToken)
return
}
err := a.params.MailingList.FinalizeSubscription(subToken)
if errors.Is(err, mailinglist.ErrNotFound) {
badRequest(rw, r, errInvalidSubToken)
apiutils.BadRequest(rw, r, errInvalidSubToken)
return
} else if errors.Is(err, mailinglist.ErrAlreadyVerified) {
// no problem
} else if err != nil {
internalServerError(rw, r, err)
apiutils.InternalServerError(rw, r, err)
return
}
jsonResult(rw, r, struct{}{})
apiutils.JSONResult(rw, r, struct{}{})
})
}
@ -67,21 +68,21 @@ func (a *api) mailingListUnsubscribeHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
unsubToken := r.PostFormValue("unsubToken")
if l := len(unsubToken); l == 0 || l > 128 {
badRequest(rw, r, errInvalidUnsubToken)
apiutils.BadRequest(rw, r, errInvalidUnsubToken)
return
}
err := a.params.MailingList.Unsubscribe(unsubToken)
if errors.Is(err, mailinglist.ErrNotFound) {
badRequest(rw, r, errInvalidUnsubToken)
apiutils.BadRequest(rw, r, errInvalidUnsubToken)
return
} else if err != nil {
internalServerError(rw, r, err)
apiutils.InternalServerError(rw, r, err)
return
}
jsonResult(rw, r, struct{}{})
apiutils.JSONResult(rw, r, struct{}{})
})
}

View File

@ -5,6 +5,7 @@ import (
"net/http"
"time"
"github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
@ -57,7 +58,7 @@ func (lrw *logResponseWriter) WriteHeader(statusCode int) {
func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
r = setRequestLogger(r, logger)
r = apiutils.SetRequestLogger(r, logger)
lrw := newLogResponseWriter(rw)
@ -76,3 +77,15 @@ func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler {
logger.Info(ctx, "handled HTTP request")
})
}
func allowedMethod(method string, h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.Method == method {
h.ServeHTTP(rw, r)
return
}
apiutils.GetRequestLogger(r).WarnString(r.Context(), "method not allowed")
rw.WriteHeader(405)
})
}

View File

@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"net/http"
"github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
)
func (a *api) newPowChallengeHandler() http.Handler {
@ -12,7 +14,7 @@ func (a *api) newPowChallengeHandler() http.Handler {
challenge := a.params.PowManager.NewChallenge()
jsonResult(rw, r, struct {
apiutils.JSONResult(rw, r, struct {
Seed string `json:"seed"`
Target uint32 `json:"target"`
}{
@ -28,21 +30,21 @@ func (a *api) requirePowMiddleware(h http.Handler) http.Handler {
seedHex := r.PostFormValue("powSeed")
seed, err := hex.DecodeString(seedHex)
if err != nil || len(seed) == 0 {
badRequest(rw, r, errors.New("invalid powSeed"))
apiutils.BadRequest(rw, r, errors.New("invalid powSeed"))
return
}
solutionHex := r.PostFormValue("powSolution")
solution, err := hex.DecodeString(solutionHex)
if err != nil || len(seed) == 0 {
badRequest(rw, r, errors.New("invalid powSolution"))
apiutils.BadRequest(rw, r, errors.New("invalid powSolution"))
return
}
err = a.params.PowManager.CheckSolution(seed, solution)
if err != nil {
badRequest(rw, r, fmt.Errorf("checking proof-of-work solution: %w", err))
apiutils.BadRequest(rw, r, fmt.Errorf("checking proof-of-work solution: %w", err))
return
}

View File

@ -1,91 +0,0 @@
package api
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
type loggerCtxKey int
func setRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, loggerCtxKey(0), logger)
return r.WithContext(ctx)
}
func getRequestLogger(r *http.Request) *mlog.Logger {
ctx := r.Context()
logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger)
if logger == nil {
logger = mlog.Null
}
return logger
}
func jsonResult(rw http.ResponseWriter, r *http.Request, v interface{}) {
b, err := json.Marshal(v)
if err != nil {
internalServerError(rw, r, err)
return
}
b = append(b, '\n')
rw.Header().Set("Content-Type", "application/json")
rw.Write(b)
}
func badRequest(rw http.ResponseWriter, r *http.Request, err error) {
getRequestLogger(r).Warn(r.Context(), "bad request", err)
rw.WriteHeader(400)
jsonResult(rw, r, struct {
Error string `json:"error"`
}{
Error: err.Error(),
})
}
func internalServerError(rw http.ResponseWriter, r *http.Request, err error) {
getRequestLogger(r).Error(r.Context(), "internal server error", err)
rw.WriteHeader(500)
jsonResult(rw, r, struct {
Error string `json:"error"`
}{
Error: "internal server error",
})
}
func strToInt(str string, defaultVal int) (int, error) {
if str == "" {
return defaultVal, nil
}
return strconv.Atoi(str)
}
func getCookie(r *http.Request, cookieName, defaultVal string) (string, error) {
c, err := r.Cookie(cookieName)
if errors.Is(err, http.ErrNoCookie) {
return defaultVal, nil
} else if err != nil {
return "", fmt.Errorf("reading cookie %q: %w", cookieName, err)
}
return c.Value, nil
}
func randStr(numBytesEntropy int) string {
b := make([]byte, numBytesEntropy)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return hex.EncodeToString(b)
}

View File

@ -84,7 +84,7 @@ type HistoryOpts struct {
}
func (o HistoryOpts) sanitize() (HistoryOpts, error) {
if o.Limit < 0 || o.Limit > 100 {
if o.Limit <= 0 || o.Limit > 100 {
o.Limit = 100
}

View File

@ -38,8 +38,8 @@ type UserIDCalculator struct {
}
// NewUserIDCalculator returns a UserIDCalculator with sane defaults.
func NewUserIDCalculator(secret []byte) UserIDCalculator {
return UserIDCalculator{
func NewUserIDCalculator(secret []byte) *UserIDCalculator {
return &UserIDCalculator{
Secret: secret,
TimeCost: 15,
MemoryCost: 128 * 1024,
@ -50,7 +50,7 @@ func NewUserIDCalculator(secret []byte) UserIDCalculator {
}
// Calculate accepts a name and password and returns the calculated UserID.
func (c UserIDCalculator) Calculate(name, password string) UserID {
func (c *UserIDCalculator) Calculate(name, password string) UserID {
input := fmt.Sprintf("%q:%q", name, password)

View File

@ -11,10 +11,12 @@ import (
"github.com/mediocregopher/blog.mediocregopher.com/srv/api"
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
"github.com/mediocregopher/blog.mediocregopher.com/srv/chat"
"github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist"
"github.com/mediocregopher/blog.mediocregopher.com/srv/pow"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
"github.com/mediocregopher/radix/v4"
"github.com/tilinna/clock"
)
@ -45,6 +47,13 @@ func main() {
apiParams.SetupCfg(cfg)
ctx = mctx.WithAnnotator(ctx, &apiParams)
redisProto := cfg.String("redis-proto", "tcp", "Network protocol to connect to redis over, can be tcp or unix")
redisAddr := cfg.String("redis-addr", "127.0.0.1:6379", "Address redis is expected to listen on")
redisPoolSize := cfg.Int("redis-pool-size", 5, "Number of connections in the redis pool to keep")
chatGlobalRoomMaxMsgs := cfg.Int("chat-global-room-max-messages", 1000, "Maximum number of messages the global chat room can retain")
chatUserIDCalcSecret := cfg.String("chat-user-id-calc-secret", "", "Secret to use when calculating user ids")
// initialization
err := cfg.Init(ctx)
@ -60,6 +69,10 @@ func main() {
ctx = mctx.Annotate(ctx,
"dataDir", *dataDir,
"redisProto", *redisProto,
"redisAddr", *redisAddr,
"redisPoolSize", *redisPoolSize,
"chatGlobalRoomMaxMsgs", *chatGlobalRoomMaxMsgs,
)
clock := clock.Realtime()
@ -92,9 +105,35 @@ func main() {
ml := mailinglist.New(mlParams)
redis, err := (radix.PoolConfig{
Size: *redisPoolSize,
}).New(
ctx, *redisProto, *redisAddr,
)
if err != nil {
loggerFatalErr(ctx, logger, "initializing redis pool", err)
}
defer redis.Close()
chatGlobalRoom, err := chat.NewRoom(ctx, chat.RoomParams{
Logger: logger.WithNamespace("global-chat-room"),
Redis: redis,
ID: "global",
MaxMessages: *chatGlobalRoomMaxMsgs,
})
if err != nil {
loggerFatalErr(ctx, logger, "initializing global chat room", err)
}
defer chatGlobalRoom.Close()
chatUserIDCalc := chat.NewUserIDCalculator([]byte(*chatUserIDCalcSecret))
apiParams.Logger = logger.WithNamespace("api")
apiParams.PowManager = powMgr
apiParams.MailingList = ml
apiParams.GlobalRoom = chatGlobalRoom
apiParams.UserIDCalculator = chatUserIDCalc
logger.Info(ctx, "listening")
a, err := api.New(apiParams)

View File

@ -11,6 +11,8 @@
"-pow-secret=${config.powSecret}"
"-listen-proto=${config.listenProto}"
"-listen-addr=${config.listenAddr}"
"-redis-proto=unix"
"-redis-addr=${config.runDir}/redis"
] ++ (
if config.staticProxyURL == ""
then [ "-static-dir=${staticBuild}" ]

View File

@ -31,7 +31,7 @@ const doFetch = async (req) => {
// may throw
const solvePow = async () => {
const res = await call('GET', '/api/pow/challenge');
const res = await call('/api/pow/challenge');
const worker = new Worker('/assets/solvePow.js');
@ -46,8 +46,12 @@ const solvePow = async () => {
return {seed: res.seed, solution: powSol};
}
const call = async (method, route, opts = {}) => {
const { body = {}, requiresPow = false } = opts;
const call = async (route, opts = {}) => {
const {
method = 'POST',
body = {},
requiresPow = false,
} = opts;
if (!utils.cookies["csrf_token"])
throw "csrf_token cookie not set, can't make api call";

View File

@ -63,7 +63,7 @@ emailSubscribe.onclick = async () => {
throw "The browser environment is not secure.";
}
await api.call('POST', '/api/mailinglist/subscribe', {
await api.call('/api/mailinglist/subscribe', {
body: { email: emailAddress.value },
requiresPow: true,
});

View File

@ -28,7 +28,7 @@ nofollow: true
const api = await import("/assets/api.js");
await api.call('POST', '/api/mailinglist/finalize', {
await api.call('/api/mailinglist/finalize', {
body: { subToken },
});

View File

@ -27,7 +27,7 @@ nofollow: true
const api = await import("/assets/api.js");
await api.call('POST', '/api/mailinglist/unsubscribe', {
await api.call('/api/mailinglist/unsubscribe', {
body: { unsubToken },
});