Simplify routes by moving formMiddleware to the global level

This commit is contained in:
Brian Picciano 2022-05-20 19:29:01 -06:00
parent 1181af0318
commit 034342421b
2 changed files with 29 additions and 27 deletions

View File

@ -167,15 +167,21 @@ func (a *api) handler() http.Handler {
} }
formMiddleware := func(h http.Handler) http.Handler { formMiddleware := func(h http.Handler) http.Handler {
h = checkCSRFMiddleware(h) wh := checkCSRFMiddleware(h)
h = disallowGetMiddleware(h) wh = logReqMiddleware(wh)
h = logReqMiddleware(h) wh = addResponseHeaders(map[string]string{
h = addResponseHeaders(map[string]string{
"Cache-Control": "no-store, max-age=0", "Cache-Control": "no-store, max-age=0",
"Pragma": "no-cache", "Pragma": "no-cache",
"Expires": "0", "Expires": "0",
}, h) }, wh)
return h
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
wh.ServeHTTP(rw, r)
} else {
h.ServeHTTP(rw, r)
}
})
} }
mux := http.NewServeMux() mux := http.NewServeMux()
@ -199,33 +205,27 @@ func (a *api) handler() http.Handler {
a.requirePowMiddleware, a.requirePowMiddleware,
))) )))
mux.Handle("/api/", http.StripPrefix("/api", formMiddleware(apiMux))) mux.Handle("/api/", http.StripPrefix("/api",
// disallowGetMiddleware is used rather than a MethodMux because it
// has an exception for websockets, which is needed for chat.
disallowGetMiddleware(apiMux),
))
} }
mux.Handle("/posts/", http.StripPrefix("/posts", mux.Handle("/posts/", http.StripPrefix("/posts",
apiutil.MethodMux(map[string]http.Handler{ apiutil.MethodMux(map[string]http.Handler{
"GET": a.renderPostHandler(), "GET": a.renderPostHandler(),
"POST": authMiddleware(a.auther, "POST": authMiddleware(a.auther, a.postPostHandler()),
formMiddleware(a.postPostHandler()), "DELETE": authMiddleware(a.auther, a.deletePostHandler()),
), "PREVIEW": authMiddleware(a.auther, a.previewPostHandler()),
"DELETE": authMiddleware(a.auther,
formMiddleware(a.deletePostHandler()),
),
"PREVIEW": authMiddleware(a.auther,
formMiddleware(a.previewPostHandler()),
),
}), }),
)) ))
mux.Handle("/assets/", http.StripPrefix("/assets", mux.Handle("/assets/", http.StripPrefix("/assets",
apiutil.MethodMux(map[string]http.Handler{ apiutil.MethodMux(map[string]http.Handler{
"GET": a.getPostAssetHandler(), "GET": a.getPostAssetHandler(),
"POST": authMiddleware(a.auther, "POST": authMiddleware(a.auther, a.postPostAssetHandler()),
formMiddleware(a.postPostAssetHandler()), "DELETE": authMiddleware(a.auther, a.deletePostAssetHandler()),
),
"DELETE": authMiddleware(a.auther,
formMiddleware(a.deletePostAssetHandler()),
),
}), }),
)) ))
@ -235,6 +235,7 @@ func (a *api) handler() http.Handler {
mux.Handle("/", a.renderIndexHandler()) mux.Handle("/", a.renderIndexHandler())
var globalHandler http.Handler = mux var globalHandler http.Handler = mux
globalHandler = formMiddleware(globalHandler)
globalHandler = setCSRFMiddleware(globalHandler) globalHandler = setCSRFMiddleware(globalHandler)
globalHandler = setLoggerMiddleware(a.params.Logger, globalHandler) globalHandler = setLoggerMiddleware(a.params.Logger, globalHandler)

View File

@ -121,10 +121,11 @@ func MethodMux(handlers map[string]http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
method := strings.ToUpper(r.FormValue("method")) method := strings.ToUpper(r.Method)
formMethod := strings.ToUpper(r.FormValue("method"))
if method == "" { if method == "POST" && formMethod != "" {
method = strings.ToUpper(r.Method) method = formMethod
} }
handler, ok := handlers[method] handler, ok := handlers[method]