Define an actual middleware type, use that to set up API routes

This commit is contained in:
Brian Picciano 2022-05-21 09:17:43 -06:00
parent 034342421b
commit 1de0ab3b72
4 changed files with 87 additions and 66 deletions

View File

@ -166,24 +166,6 @@ func (a *api) handler() http.Handler {
return a.requirePowMiddleware(h) return a.requirePowMiddleware(h)
} }
formMiddleware := func(h http.Handler) http.Handler {
wh := checkCSRFMiddleware(h)
wh = logReqMiddleware(wh)
wh = addResponseHeaders(map[string]string{
"Cache-Control": "no-store, max-age=0",
"Pragma": "no-cache",
"Expires": "0",
}, wh)
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
wh.ServeHTTP(rw, r)
} else {
h.ServeHTTP(rw, r)
}
})
}
mux := http.NewServeMux() mux := http.NewServeMux()
{ {
@ -215,17 +197,17 @@ func (a *api) handler() http.Handler {
mux.Handle("/posts/", http.StripPrefix("/posts", mux.Handle("/posts/", http.StripPrefix("/posts",
apiutil.MethodMux(map[string]http.Handler{ apiutil.MethodMux(map[string]http.Handler{
"GET": a.renderPostHandler(), "GET": a.renderPostHandler(),
"POST": authMiddleware(a.auther, a.postPostHandler()), "POST": a.postPostHandler(),
"DELETE": authMiddleware(a.auther, a.deletePostHandler()), "DELETE": a.deletePostHandler(),
"PREVIEW": authMiddleware(a.auther, a.previewPostHandler()), "PREVIEW": a.previewPostHandler(),
}), }),
)) ))
mux.Handle("/assets/", http.StripPrefix("/assets", mux.Handle("/assets/", http.StripPrefix("/assets",
apiutil.MethodMux(map[string]http.Handler{ apiutil.MethodMux(map[string]http.Handler{
"GET": a.getPostAssetHandler(), "GET": a.getPostAssetHandler(),
"POST": authMiddleware(a.auther, a.postPostAssetHandler()), "POST": a.postPostAssetHandler(),
"DELETE": authMiddleware(a.auther, a.deletePostAssetHandler()), "DELETE": a.deletePostAssetHandler(),
}), }),
)) ))
@ -234,10 +216,28 @@ func (a *api) handler() http.Handler {
mux.Handle("/feed.xml", a.renderFeedHandler()) mux.Handle("/feed.xml", a.renderFeedHandler())
mux.Handle("/", a.renderIndexHandler()) mux.Handle("/", a.renderIndexHandler())
var globalHandler http.Handler = mux globalHandler := http.Handler(mux)
globalHandler = formMiddleware(globalHandler)
globalHandler = setCSRFMiddleware(globalHandler) globalHandler = apiutil.MethodMux(map[string]http.Handler{
globalHandler = setLoggerMiddleware(a.params.Logger, globalHandler) "GET": applyMiddlewares(
globalHandler,
logReqMiddleware,
setCSRFMiddleware,
),
"*": applyMiddlewares(
globalHandler,
authMiddleware(a.auther),
checkCSRFMiddleware,
addResponseHeadersMiddleware(map[string]string{
"Cache-Control": "no-store, max-age=0",
"Pragma": "no-cache",
"Expires": "0",
}),
logReqMiddleware,
),
})
globalHandler = setLoggerMiddleware(a.params.Logger)(globalHandler)
return globalHandler return globalHandler
} }

View File

@ -117,6 +117,9 @@ func RandStr(numBytes int) string {
// //
// If no Handler is defined for a method then a 405 Method Not Allowed error is // If no Handler is defined for a method then a 405 Method Not Allowed error is
// returned. // returned.
//
// If the method "*" is defined then all methods not defined will be directed to
// that handler, and 405 Method Not Allowed is never returned.
func MethodMux(handlers map[string]http.Handler) http.Handler { func MethodMux(handlers map[string]http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -128,13 +131,16 @@ func MethodMux(handlers map[string]http.Handler) http.Handler {
method = formMethod method = formMethod
} }
handler, ok := handlers[method] if handler, ok := handlers[method]; ok {
handler.ServeHTTP(rw, r)
if !ok {
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }
handler.ServeHTTP(rw, r) if handler, ok := handlers["*"]; ok {
handler.ServeHTTP(rw, r)
return
}
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
}) })
} }

View File

@ -65,7 +65,7 @@ func (a *auther) Allowed(ctx context.Context, username, password string) bool {
return err == nil return err == nil
} }
func authMiddleware(auther Auther, h http.Handler) http.Handler { func authMiddleware(auther Auther) middleware {
respondUnauthorized := func(rw http.ResponseWriter, r *http.Request) { respondUnauthorized := func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("WWW-Authenticate", `Basic realm="NOPE"`) rw.Header().Set("WWW-Authenticate", `Basic realm="NOPE"`)
@ -73,20 +73,22 @@ func authMiddleware(auther Auther, h http.Handler) http.Handler {
apiutil.GetRequestLogger(r).WarnString(r.Context(), "unauthorized") apiutil.GetRequestLogger(r).WarnString(r.Context(), "unauthorized")
} }
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
if !ok { if !ok {
respondUnauthorized(rw, r) respondUnauthorized(rw, r)
return return
} }
if !auther.Allowed(r.Context(), username, password) { if !auther.Allowed(r.Context(), username, password) {
respondUnauthorized(rw, r) respondUnauthorized(rw, r)
return return
} }
h.ServeHTTP(rw, r) h.ServeHTTP(rw, r)
}) })
}
} }

View File

@ -10,33 +10,46 @@ import (
"github.com/mediocregopher/mediocre-go-lib/v2/mlog" "github.com/mediocregopher/mediocre-go-lib/v2/mlog"
) )
func addResponseHeaders(headers map[string]string, h http.Handler) http.Handler { type middleware func(http.Handler) http.Handler
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
for k, v := range headers { func applyMiddlewares(h http.Handler, middlewares ...middleware) http.Handler {
rw.Header().Set(k, v) for _, m := range middlewares {
} h = m(h)
h.ServeHTTP(rw, r) }
}) return h
} }
func setLoggerMiddleware(logger *mlog.Logger, h http.Handler) http.Handler { func addResponseHeadersMiddleware(headers map[string]string) middleware {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
for k, v := range headers {
rw.Header().Set(k, v)
}
h.ServeHTTP(rw, r)
})
}
}
type reqInfoKey string func setLoggerMiddleware(logger *mlog.Logger) middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ip, _, _ := net.SplitHostPort(r.RemoteAddr) type logCtxKey string
ctx := r.Context() ip, _, _ := net.SplitHostPort(r.RemoteAddr)
ctx = mctx.Annotate(ctx,
reqInfoKey("remote_ip"), ip,
reqInfoKey("url"), r.URL,
reqInfoKey("method"), r.Method,
)
r = r.WithContext(ctx) ctx := r.Context()
r = apiutil.SetRequestLogger(r, logger) ctx = mctx.Annotate(ctx,
h.ServeHTTP(rw, r) logCtxKey("remote_ip"), ip,
}) logCtxKey("url"), r.URL,
logCtxKey("method"), r.Method,
)
r = r.WithContext(ctx)
r = apiutil.SetRequestLogger(r, logger)
h.ServeHTTP(rw, r)
})
}
} }
type logResponseWriter struct { type logResponseWriter struct {