diff --git a/srv/src/http/api.go b/srv/src/http/api.go index bcd0150..11d092d 100644 --- a/srv/src/http/api.go +++ b/srv/src/http/api.go @@ -166,24 +166,6 @@ func (a *api) handler() http.Handler { return a.requirePowMiddleware(h) } - formMiddleware := func(h http.Handler) http.Handler { - wh := checkCSRFMiddleware(h) - wh = logReqMiddleware(wh) - wh = addResponseHeaders(map[string]string{ - "Cache-Control": "no-store, max-age=0", - "Pragma": "no-cache", - "Expires": "0", - }, wh) - - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { - wh.ServeHTTP(rw, r) - } else { - h.ServeHTTP(rw, r) - } - }) - } - mux := http.NewServeMux() { @@ -215,17 +197,17 @@ func (a *api) handler() http.Handler { mux.Handle("/posts/", http.StripPrefix("/posts", apiutil.MethodMux(map[string]http.Handler{ "GET": a.renderPostHandler(), - "POST": authMiddleware(a.auther, a.postPostHandler()), - "DELETE": authMiddleware(a.auther, a.deletePostHandler()), - "PREVIEW": authMiddleware(a.auther, a.previewPostHandler()), + "POST": a.postPostHandler(), + "DELETE": a.deletePostHandler(), + "PREVIEW": a.previewPostHandler(), }), )) mux.Handle("/assets/", http.StripPrefix("/assets", apiutil.MethodMux(map[string]http.Handler{ "GET": a.getPostAssetHandler(), - "POST": authMiddleware(a.auther, a.postPostAssetHandler()), - "DELETE": authMiddleware(a.auther, a.deletePostAssetHandler()), + "POST": a.postPostAssetHandler(), + "DELETE": a.deletePostAssetHandler(), }), )) @@ -234,10 +216,28 @@ func (a *api) handler() http.Handler { mux.Handle("/feed.xml", a.renderFeedHandler()) mux.Handle("/", a.renderIndexHandler()) - var globalHandler http.Handler = mux - globalHandler = formMiddleware(globalHandler) - globalHandler = setCSRFMiddleware(globalHandler) - globalHandler = setLoggerMiddleware(a.params.Logger, globalHandler) + globalHandler := http.Handler(mux) + + globalHandler = apiutil.MethodMux(map[string]http.Handler{ + "GET": applyMiddlewares( + globalHandler, + logReqMiddleware, + setCSRFMiddleware, + ), + "*": applyMiddlewares( + globalHandler, + authMiddleware(a.auther), + checkCSRFMiddleware, + addResponseHeadersMiddleware(map[string]string{ + "Cache-Control": "no-store, max-age=0", + "Pragma": "no-cache", + "Expires": "0", + }), + logReqMiddleware, + ), + }) + + globalHandler = setLoggerMiddleware(a.params.Logger)(globalHandler) return globalHandler } diff --git a/srv/src/http/apiutil/apiutil.go b/srv/src/http/apiutil/apiutil.go index aa62299..fed6fb5 100644 --- a/srv/src/http/apiutil/apiutil.go +++ b/srv/src/http/apiutil/apiutil.go @@ -117,6 +117,9 @@ func RandStr(numBytes int) string { // // If no Handler is defined for a method then a 405 Method Not Allowed error is // returned. +// +// If the method "*" is defined then all methods not defined will be directed to +// that handler, and 405 Method Not Allowed is never returned. func MethodMux(handlers map[string]http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -128,13 +131,16 @@ func MethodMux(handlers map[string]http.Handler) http.Handler { method = formMethod } - handler, ok := handlers[method] + if handler, ok := handlers[method]; ok { + handler.ServeHTTP(rw, r) + return + } - if !ok { - http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed) + if handler, ok := handlers["*"]; ok { + handler.ServeHTTP(rw, r) return } - handler.ServeHTTP(rw, r) + http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed) }) } diff --git a/srv/src/http/auth.go b/srv/src/http/auth.go index 9527cc8..3ad026a 100644 --- a/srv/src/http/auth.go +++ b/srv/src/http/auth.go @@ -65,7 +65,7 @@ func (a *auther) Allowed(ctx context.Context, username, password string) bool { return err == nil } -func authMiddleware(auther Auther, h http.Handler) http.Handler { +func authMiddleware(auther Auther) middleware { respondUnauthorized := func(rw http.ResponseWriter, r *http.Request) { rw.Header().Set("WWW-Authenticate", `Basic realm="NOPE"`) @@ -73,20 +73,22 @@ func authMiddleware(auther Auther, h http.Handler) http.Handler { apiutil.GetRequestLogger(r).WarnString(r.Context(), "unauthorized") } - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - username, password, ok := r.BasicAuth() + username, password, ok := r.BasicAuth() - if !ok { - respondUnauthorized(rw, r) - return - } + if !ok { + respondUnauthorized(rw, r) + return + } - if !auther.Allowed(r.Context(), username, password) { - respondUnauthorized(rw, r) - return - } + if !auther.Allowed(r.Context(), username, password) { + respondUnauthorized(rw, r) + return + } - h.ServeHTTP(rw, r) - }) + h.ServeHTTP(rw, r) + }) + } } diff --git a/srv/src/http/middleware.go b/srv/src/http/middleware.go index 8299a71..02d156b 100644 --- a/srv/src/http/middleware.go +++ b/srv/src/http/middleware.go @@ -10,33 +10,46 @@ import ( "github.com/mediocregopher/mediocre-go-lib/v2/mlog" ) -func addResponseHeaders(headers map[string]string, h http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - for k, v := range headers { - rw.Header().Set(k, v) - } - h.ServeHTTP(rw, r) - }) +type middleware func(http.Handler) http.Handler + +func applyMiddlewares(h http.Handler, middlewares ...middleware) http.Handler { + for _, m := range middlewares { + h = m(h) + } + return h } -func setLoggerMiddleware(logger *mlog.Logger, h http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { +func addResponseHeadersMiddleware(headers map[string]string) middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + for k, v := range headers { + rw.Header().Set(k, v) + } + h.ServeHTTP(rw, r) + }) + } +} - type reqInfoKey string +func setLoggerMiddleware(logger *mlog.Logger) middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - ip, _, _ := net.SplitHostPort(r.RemoteAddr) + type logCtxKey string - ctx := r.Context() - ctx = mctx.Annotate(ctx, - reqInfoKey("remote_ip"), ip, - reqInfoKey("url"), r.URL, - reqInfoKey("method"), r.Method, - ) + ip, _, _ := net.SplitHostPort(r.RemoteAddr) - r = r.WithContext(ctx) - r = apiutil.SetRequestLogger(r, logger) - h.ServeHTTP(rw, r) - }) + ctx := r.Context() + ctx = mctx.Annotate(ctx, + logCtxKey("remote_ip"), ip, + logCtxKey("url"), r.URL, + logCtxKey("method"), r.Method, + ) + + r = r.WithContext(ctx) + r = apiutil.SetRequestLogger(r, logger) + h.ServeHTTP(rw, r) + }) + } } type logResponseWriter struct {