Replace CSRF token checking with Referer checking

This commit is contained in:
Brian Picciano 2022-05-24 17:42:00 -06:00
parent 159638084e
commit 08811a6da7
7 changed files with 10 additions and 105 deletions

View File

@ -164,8 +164,6 @@ func (a *api) Shutdown(ctx context.Context) error {
func (a *api) apiHandler() http.Handler { func (a *api) apiHandler() http.Handler {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/csrf", a.getCSRFTokenHandler())
mux.Handle("/pow/challenge", a.newPowChallengeHandler()) mux.Handle("/pow/challenge", a.newPowChallengeHandler())
mux.Handle("/pow/check", mux.Handle("/pow/check",
a.requirePowMiddleware( a.requirePowMiddleware(
@ -250,11 +248,10 @@ func (a *api) handler() http.Handler {
h := apiutil.MethodMux(map[string]http.Handler{ h := apiutil.MethodMux(map[string]http.Handler{
"GET": applyMiddlewares( "GET": applyMiddlewares(
mux, mux,
setCSRFMiddleware,
), ),
"*": applyMiddlewares( "*": applyMiddlewares(
mux, mux,
checkCSRFMiddleware, a.checkCSRFMiddleware,
addResponseHeadersMiddleware(map[string]string{ addResponseHeadersMiddleware(map[string]string{
"Cache-Control": "no-store, max-age=0", "Cache-Control": "no-store, max-age=0",
"Pragma": "no-cache", "Pragma": "no-cache",

View File

@ -3,76 +3,26 @@ package http
import ( import (
"errors" "errors"
"net/http" "net/http"
"net/url"
"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
) )
const ( func (a *api) checkCSRFMiddleware(h http.Handler) http.Handler {
csrfTokenCookieName = "csrf_token"
csrfTokenHeaderName = "X-CSRF-Token"
csrfTokenFormName = "csrfToken"
)
func setCSRFMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
csrfTok, err := apiutil.GetCookie(r, csrfTokenCookieName, "") refererURL, err := url.Parse(r.Referer())
if err != nil { if err != nil {
apiutil.InternalServerError(rw, r, err) apiutil.BadRequest(rw, r, errors.New("invalid Referer"))
return
} else if csrfTok == "" {
http.SetCookie(rw, &http.Cookie{
Name: csrfTokenCookieName,
Value: apiutil.RandStr(32),
Secure: true,
})
}
h.ServeHTTP(rw, r)
})
}
func checkCSRFMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
csrfTok, err := apiutil.GetCookie(r, csrfTokenCookieName, "")
if err != nil {
apiutil.InternalServerError(rw, r, err)
return return
} }
givenCSRFTok := r.Header.Get(csrfTokenHeaderName) if refererURL.Scheme != a.params.PublicURL.Scheme ||
if givenCSRFTok == "" { refererURL.Host != a.params.PublicURL.Host {
givenCSRFTok = r.FormValue(csrfTokenFormName) apiutil.BadRequest(rw, r, errors.New("invalid Referer"))
}
if csrfTok == "" || givenCSRFTok != csrfTok {
apiutil.BadRequest(rw, r, errors.New("invalid CSRF token"))
return return
} }
h.ServeHTTP(rw, r) h.ServeHTTP(rw, r)
}) })
} }
func (a *api) getCSRFTokenHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
csrfTok, err := apiutil.GetCookie(r, csrfTokenCookieName, "")
if err != nil {
apiutil.InternalServerError(rw, r, err)
return
}
apiutil.JSONResult(rw, r, struct {
CSRFToken string
}{
CSRFToken: csrfTok,
})
})
}

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"html/template" "html/template"
"io/fs" "io/fs"
"log"
"net/http" "net/http"
"path/filepath" "path/filepath"
"strings" "strings"
@ -100,21 +99,12 @@ func (a *api) mustParseTpl(name string) *template.Template {
func (a *api) mustParseBasedTpl(name string) *template.Template { func (a *api) mustParseBasedTpl(name string) *template.Template {
tpl := a.mustParseTpl(name) tpl := a.mustParseTpl(name)
tpl = template.Must(tpl.New("load-csrf.html").Parse(mustReadTplFile("load-csrf.html")))
tpl = template.Must(tpl.New("base.html").Parse(mustReadTplFile("base.html"))) tpl = template.Must(tpl.New("base.html").Parse(mustReadTplFile("base.html")))
return tpl return tpl
} }
type tplData struct { type tplData struct {
Payload interface{} Payload interface{}
CSRFToken string
}
func (t tplData) CSRFFormInput() template.HTML {
return template.HTML(fmt.Sprintf(
`<input type="hidden" name="%s" class="csrfHiddenInput" />`,
csrfTokenFormName,
))
} }
// executeTemplate expects to be the final action in an http.Handler // executeTemplate expects to be the final action in an http.Handler
@ -123,11 +113,8 @@ func executeTemplate(
tpl *template.Template, payload interface{}, tpl *template.Template, payload interface{},
) { ) {
csrfToken, _ := apiutil.GetCookie(r, csrfTokenCookieName, "")
tplData := tplData{ tplData := tplData{
Payload: payload, Payload: payload,
CSRFToken: csrfToken,
} }
if err := tpl.Execute(rw, tplData); err != nil { if err := tpl.Execute(rw, tplData); err != nil {
@ -141,7 +128,6 @@ func executeTemplate(
func (a *api) executeRedirectTpl( func (a *api) executeRedirectTpl(
rw http.ResponseWriter, r *http.Request, url string, rw http.ResponseWriter, r *http.Request, url string,
) { ) {
log.Printf("here url:%q", url)
executeTemplate(rw, r, a.redirectTpl, struct { executeTemplate(rw, r, a.redirectTpl, struct {
URL string URL string
}{ }{

View File

@ -1,7 +1,5 @@
{{ define "body" }} {{ define "body" }}
{{ $csrfFormInput := .CSRFFormInput }}
<h2>Upload Asset</h2> <h2>Upload Asset</h2>
<p> <p>
@ -10,7 +8,6 @@
</p> </p>
<form action="{{ BlogURL "assets/" }}" method="POST" enctype="multipart/form-data"> <form action="{{ BlogURL "assets/" }}" method="POST" enctype="multipart/form-data">
{{ $csrfFormInput }}
<div class="row"> <div class="row">
<div class="four columns"> <div class="four columns">
<input type="text" placeholder="Unique ID" name="id" /> <input type="text" placeholder="Unique ID" name="id" />
@ -37,7 +34,6 @@
method="POST" method="POST"
style="margin-bottom: 0;" style="margin-bottom: 0;"
> >
{{ $csrfFormInput }}
<input type="submit" value="Delete" /> <input type="submit" value="Delete" />
</form> </form>
</td> </td>
@ -46,8 +42,6 @@
</table> </table>
{{ template "load-csrf.html" . }}
{{ end }} {{ end }}
{{ template "base.html" . }} {{ template "base.html" . }}

View File

@ -2,8 +2,6 @@
<form method="POST" action="{{ BlogURL "posts/" }}"> <form method="POST" action="{{ BlogURL "posts/" }}">
{{ .CSRFFormInput }}
<div class="row"> <div class="row">
<div class="columns six"> <div class="columns six">
@ -99,8 +97,6 @@
</form> </form>
{{ template "load-csrf.html" . }}
{{ end }} {{ end }}
{{ template "base.html" . }} {{ template "base.html" . }}

View File

@ -1,13 +0,0 @@
<script async type="module" src="{{ StaticURL "api.js" }}"></script>
<script type="text/javascript">
(async () => {
const api = await import("{{ StaticURL "api.js" }}");
const res = await api.call("/api/csrf");
const els = document.getElementsByClassName("csrfHiddenInput");
for (let i = 0; i < els.length; i++) {
els[i].value = res.CSRFToken;
}
})();
</script>

View File

@ -18,8 +18,6 @@
{{ define "body" }} {{ define "body" }}
{{ $csrfFormInput := .CSRFFormInput }}
<p style="text-align: center;"> <p style="text-align: center;">
<a href="{{ BlogURL "posts/" }}?edit"> <a href="{{ BlogURL "posts/" }}?edit">
<button>New Post</button> <button>New Post</button>
@ -44,7 +42,6 @@
action="{{ PostURL .ID }}?method=delete" action="{{ PostURL .ID }}?method=delete"
method="POST" method="POST"
> >
{{ $csrfFormInput }}
<input type="submit" value="Delete" /> <input type="submit" value="Delete" />
</form> </form>
</td> </td>
@ -55,8 +52,6 @@
{{ template "posts-nextprev" . }} {{ template "posts-nextprev" . }}
{{ template "load-csrf.html" . }}
{{ end }} {{ end }}
{{ template "base.html" . }} {{ template "base.html" . }}