2021-08-30 04:15:58 +00:00
|
|
|
package api
|
|
|
|
|
|
|
|
import (
	"crypto/subtle"
	"errors"
	"net/http"

	"github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
)
|
|
|
|
|
|
|
|
const (
	// csrfTokenHeaderName is the request header in which a client may echo
	// its CSRF token back to the server.
	csrfTokenHeaderName = "X-CSRF-Token"

	// csrfTokenCookieName is the name of the cookie that carries the CSRF
	// token issued to a client.
	csrfTokenCookieName = "csrf_token"
)
|
|
|
|
|
|
|
|
func setCSRFMiddleware(h http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
|
|
|
2021-08-31 02:08:51 +00:00
|
|
|
csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "")
|
2021-08-30 04:15:58 +00:00
|
|
|
|
|
|
|
if err != nil {
|
2021-08-31 02:08:51 +00:00
|
|
|
apiutils.InternalServerError(rw, r, err)
|
2021-08-30 04:15:58 +00:00
|
|
|
return
|
|
|
|
|
|
|
|
} else if csrfTok == "" {
|
|
|
|
http.SetCookie(rw, &http.Cookie{
|
|
|
|
Name: csrfTokenCookieName,
|
2021-08-31 02:08:51 +00:00
|
|
|
Value: apiutils.RandStr(32),
|
2021-08-30 04:15:58 +00:00
|
|
|
Secure: true,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
h.ServeHTTP(rw, r)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func checkCSRFMiddleware(h http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
|
|
|
2021-08-31 02:08:51 +00:00
|
|
|
csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "")
|
2021-08-30 04:15:58 +00:00
|
|
|
|
|
|
|
if err != nil {
|
2021-08-31 02:08:51 +00:00
|
|
|
apiutils.InternalServerError(rw, r, err)
|
2021-08-30 04:15:58 +00:00
|
|
|
return
|
2021-09-02 23:02:20 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
givenCSRFTok := r.Header.Get(csrfTokenHeaderName)
|
|
|
|
if givenCSRFTok == "" {
|
|
|
|
givenCSRFTok = r.FormValue("csrfToken")
|
|
|
|
}
|
2021-08-30 04:15:58 +00:00
|
|
|
|
2021-09-02 23:02:20 +00:00
|
|
|
if csrfTok == "" || givenCSRFTok != csrfTok {
|
2021-08-31 02:08:51 +00:00
|
|
|
apiutils.BadRequest(rw, r, errors.New("invalid CSRF token"))
|
2021-08-30 04:15:58 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
h.ServeHTTP(rw, r)
|
|
|
|
})
|
|
|
|
}
|