Fix CSRF loading on static GET pages

This commit is contained in:
Brian Picciano 2022-05-24 17:27:03 -06:00
parent 88ebaeda8f
commit 159638084e
7 changed files with 44 additions and 3 deletions

View File

@ -163,6 +163,9 @@ func (a *api) Shutdown(ctx context.Context) error {
func (a *api) apiHandler() http.Handler { func (a *api) apiHandler() http.Handler {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/csrf", a.getCSRFTokenHandler())
mux.Handle("/pow/challenge", a.newPowChallengeHandler()) mux.Handle("/pow/challenge", a.newPowChallengeHandler())
mux.Handle("/pow/check", mux.Handle("/pow/check",
a.requirePowMiddleware( a.requirePowMiddleware(

View File

@ -57,3 +57,22 @@ func checkCSRFMiddleware(h http.Handler) http.Handler {
h.ServeHTTP(rw, r) h.ServeHTTP(rw, r)
}) })
} }
func (a *api) getCSRFTokenHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
csrfTok, err := apiutil.GetCookie(r, csrfTokenCookieName, "")
if err != nil {
apiutil.InternalServerError(rw, r, err)
return
}
apiutil.JSONResult(rw, r, struct {
CSRFToken string
}{
CSRFToken: csrfTok,
})
})
}

View File

@ -100,6 +100,7 @@ func (a *api) mustParseTpl(name string) *template.Template {
func (a *api) mustParseBasedTpl(name string) *template.Template { func (a *api) mustParseBasedTpl(name string) *template.Template {
tpl := a.mustParseTpl(name) tpl := a.mustParseTpl(name)
tpl = template.Must(tpl.New("load-csrf.html").Parse(mustReadTplFile("load-csrf.html")))
tpl = template.Must(tpl.New("base.html").Parse(mustReadTplFile("base.html"))) tpl = template.Must(tpl.New("base.html").Parse(mustReadTplFile("base.html")))
return tpl return tpl
} }
@ -111,8 +112,8 @@ type tplData struct {
func (t tplData) CSRFFormInput() template.HTML { func (t tplData) CSRFFormInput() template.HTML {
return template.HTML(fmt.Sprintf( return template.HTML(fmt.Sprintf(
`<input type="hidden" name="%s" value="%s" />`, `<input type="hidden" name="%s" class="csrfHiddenInput" />`,
csrfTokenFormName, t.CSRFToken, csrfTokenFormName,
)) ))
} }

View File

@ -46,6 +46,8 @@
</table> </table>
{{ template "load-csrf.html" . }}
{{ end }} {{ end }}
{{ template "base.html" . }} {{ template "base.html" . }}

View File

@ -99,6 +99,8 @@
</form> </form>
{{ template "load-csrf.html" . }}
{{ end }} {{ end }}
{{ template "base.html" . }} {{ template "base.html" . }}

View File

@ -0,0 +1,13 @@
<script async type="module" src="{{ StaticURL "api.js" }}"></script>
<script type="text/javascript">
(async () => {
const api = await import("{{ StaticURL "api.js" }}");
const res = await api.call("/api/csrf");
const els = document.getElementsByClassName("csrfHiddenInput");
for (let i = 0; i < els.length; i++) {
els[i].value = res.CSRFToken;
}
})();
</script>

View File

@ -20,7 +20,6 @@
{{ $csrfFormInput := .CSRFFormInput }} {{ $csrfFormInput := .CSRFFormInput }}
<p style="text-align: center;"> <p style="text-align: center;">
<a href="{{ BlogURL "posts/" }}?edit"> <a href="{{ BlogURL "posts/" }}?edit">
<button>New Post</button> <button>New Post</button>
@ -56,6 +55,8 @@
{{ template "posts-nextprev" . }} {{ template "posts-nextprev" . }}
{{ template "load-csrf.html" . }}
{{ end }} {{ end }}
{{ template "base.html" . }} {{ template "base.html" . }}