diff --git a/srv/api/api.go b/srv/api/api.go index 39d73d9..bbb677a 100644 --- a/srv/api/api.go +++ b/srv/api/api.go @@ -142,6 +142,8 @@ func (a *api) handler() http.Handler { staticHandler = httputil.NewSingleHostReverseProxy(a.params.StaticProxy) } + staticHandler = setCSRFMiddleware(staticHandler) + // sugar requirePow := func(h http.Handler) http.Handler { return a.requirePowMiddleware(h) @@ -163,7 +165,9 @@ func (a *api) handler() http.Handler { apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler()) apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler()) - apiHandler := logMiddleware(a.params.Logger, apiMux) + var apiHandler http.Handler = apiMux + apiHandler = checkCSRFMiddleware(apiHandler) + apiHandler = logMiddleware(a.params.Logger, apiHandler) apiHandler = annotateMiddleware(apiHandler) apiHandler = addResponseHeaders(map[string]string{ "Cache-Control": "no-store, max-age=0", diff --git a/srv/api/csrf.go b/srv/api/csrf.go new file mode 100644 index 0000000..d705adb --- /dev/null +++ b/srv/api/csrf.go @@ -0,0 +1,50 @@ +package api + +import ( + "errors" + "net/http" +) + +const ( + csrfTokenCookieName = "csrf_token" + csrfTokenHeaderName = "X-CSRF-Token" +) + +func setCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := getCookie(r, csrfTokenCookieName, "") + + if err != nil { + internalServerError(rw, r, err) + return + + } else if csrfTok == "" { + http.SetCookie(rw, &http.Cookie{ + Name: csrfTokenCookieName, + Value: randStr(32), + Secure: true, + }) + } + + h.ServeHTTP(rw, r) + }) +} + +func checkCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := getCookie(r, csrfTokenCookieName, "") + + if err != nil { + internalServerError(rw, r, err) + return + + } else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok { + badRequest(rw, r, errors.New("invalid CSRF token")) + return + } + + h.ServeHTTP(rw, r) + }) +} diff --git a/srv/api/utils.go b/srv/api/utils.go index 7662e17..2cf40b6 100644 --- a/srv/api/utils.go +++ b/srv/api/utils.go @@ -2,7 +2,11 @@ package api import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" + "errors" + "fmt" "net/http" "strconv" @@ -66,3 +70,22 @@ func strToInt(str string, defaultVal int) (int, error) { } return strconv.Atoi(str) } + +func getCookie(r *http.Request, cookieName, defaultVal string) (string, error) { + c, err := r.Cookie(cookieName) + if errors.Is(err, http.ErrNoCookie) { + return defaultVal, nil + } else if err != nil { + return "", fmt.Errorf("reading cookie %q: %w", cookieName, err) + } + + return c.Value, nil +} + +func randStr(numBytesEntropy int) string { + b := make([]byte, numBytesEntropy) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return hex.EncodeToString(b) +} diff --git a/static/src/assets/api.js b/static/src/assets/api.js index bec2740..b591764 100644 --- a/static/src/assets/api.js +++ b/static/src/assets/api.js @@ -1,3 +1,4 @@ +import * as utils from "/assets/utils.js"; const doFetch = async (req) => { let res, jsonRes; @@ -48,7 +49,15 @@ const solvePow = async () => { const call = async (method, route, opts = {}) => { const { body = {}, requiresPow = false } = opts; - const reqOpts = { method }; + if (!utils.cookies["csrf_token"]) + throw "csrf_token cookie not set, can't make api call"; + + const reqOpts = { + method, + headers: { + "X-CSRF-Token": utils.cookies["csrf_token"], + }, + }; if (requiresPow) { const {seed, solution} = await solvePow(); @@ -57,7 +66,6 @@ const call = async (method, route, opts = {}) => { } if (Object.keys(body).length > 0) { - const form = new FormData(); for (const key in body) form.append(key, body[key]); diff --git a/static/src/assets/utils.js b/static/src/assets/utils.js new file mode 100644 index 0000000..96a2950 --- /dev/null +++ b/static/src/assets/utils.js @@ -0,0 +1,12 @@ +const cookies = {}; +const cookieKVs = document.cookie + .split(';') + .map(cookie => cookie.trim().split('=', 2)); + +for (const i in cookieKVs) { + cookies[cookieKVs[i][0]] = cookieKVs[i][1]; +} + +export { + cookies, +}