A fast and simple blog backend.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
mediocre-blog/src/http/middleware.go

199 lines
4.5 KiB

package http
import (
	"bufio"
	"bytes"
	"errors"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"sync"
	"time"

	lru "github.com/hashicorp/golang-lru"
	"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
	"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
	"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
// middleware is a function which takes an http.Handler and returns an
// http.Handler, typically one which performs some extra work before and/or
// after calling the handler it was given.
type middleware func(http.Handler) http.Handler
// applyMiddlewares wraps h in each of the given middlewares, in the order they
// are given, and returns the resulting handler. Each middleware wraps the
// result of the previous one, so the last middleware given ends up outermost
// and is the first to see each request. With no middlewares h is returned
// as-is.
func applyMiddlewares(h http.Handler, middlewares ...middleware) http.Handler {
	wrapped := h
	for i := 0; i < len(middlewares); i++ {
		wrapped = middlewares[i](wrapped)
	}
	return wrapped
}
// addResponseHeadersMiddleware returns a middleware which sets every key/value
// pair in headers on the response, replacing any value already set for that
// key, before handing the request to the wrapped handler.
func addResponseHeadersMiddleware(headers map[string]string) middleware {
	return func(next http.Handler) http.Handler {
		serve := func(rw http.ResponseWriter, r *http.Request) {
			for name, value := range headers {
				rw.Header().Set(name, value)
			}
			next.ServeHTTP(rw, r)
		}
		return http.HandlerFunc(serve)
	}
}
// setLoggerMiddleware returns a middleware which annotates the request's
// context with details of the request and attaches logger to the request using
// apiutil.SetRequestLogger, before calling the wrapped handler with the
// updated request.
//
// The URL and method are always annotated. The X-Forwarded-For and X-Real-IP
// headers and the IP portion of the request's RemoteAddr are each annotated
// only when non-empty.
func setLoggerMiddleware(logger *mlog.Logger) middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			type logCtxKey string

			annotated := mctx.Annotate(
				r.Context(),
				logCtxKey("url"), r.URL,
				logCtxKey("method"), r.Method,
			)

			// headers which are only annotated if the request carries them.
			optionalHeaders := [...]struct {
				header string
				key    logCtxKey
			}{
				{header: "X-Forwarded-For", key: "x_forwarded_for"},
				{header: "X-Real-IP", key: "x_real_ip"},
			}
			for _, oh := range optionalHeaders {
				if val := r.Header.Get(oh.header); val != "" {
					annotated = mctx.Annotate(annotated, oh.key, val)
				}
			}

			// the error is deliberately ignored; if RemoteAddr can't be split
			// then the IP is empty and simply isn't annotated.
			remoteIP, _, _ := net.SplitHostPort(r.RemoteAddr)
			if remoteIP != "" {
				annotated = mctx.Annotate(annotated, logCtxKey("remote_ip"), remoteIP)
			}

			withLogger := apiutil.SetRequestLogger(r.WithContext(annotated), logger)
			next.ServeHTTP(rw, withLogger)
		})
	}
}
// wrappedResponseWriter wraps an http.ResponseWriter in order to keep track of
// the status code of the response.
//
// http.Hijacker is embedded so that the wrapper can be handed to code which
// needs to hijack the connection. The field is nil when the wrapped
// http.ResponseWriter does not support hijacking, in which case Hijack returns
// an error.
type wrappedResponseWriter struct {
	http.ResponseWriter
	http.Hijacker

	// statusCode is the status of the response. It starts as 200, which is
	// what net/http sends if WriteHeader is never called explicitly.
	statusCode int

	// wroteHeader is true once the final status has been committed, either by
	// WriteHeader or implicitly by the first Write. Any WriteHeader call after
	// that point is ignored by net/http, and so must not change statusCode.
	wroteHeader bool
}

// newWrappedResponseWriter returns a wrappedResponseWriter around rw with its
// statusCode initialized to 200.
func newWrappedResponseWriter(rw http.ResponseWriter) *wrappedResponseWriter {
	h, _ := rw.(http.Hijacker) // h is left nil if rw can't be hijacked
	return &wrappedResponseWriter{
		ResponseWriter: rw,
		Hijacker:       h,
		statusCode:     http.StatusOK,
	}
}

// WriteHeader records statusCode, if it is the one which will actually be
// sent as the final status, and then passes it on to the wrapped
// http.ResponseWriter.
func (rw *wrappedResponseWriter) WriteHeader(statusCode int) {
	// net/http treats 1xx codes other than 101 as informational: they may be
	// sent any number of times and are followed by the real status, so they
	// are passed through without being recorded.
	informational := statusCode >= 100 && statusCode <= 199 &&
		statusCode != http.StatusSwitchingProtocols

	if !informational && !rw.wroteHeader {
		rw.statusCode = statusCode
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(statusCode)
}

// Write passes b on to the wrapped http.ResponseWriter. Writing the body
// commits the header, so the status recorded so far (200 if WriteHeader was
// never called) becomes final.
func (rw *wrappedResponseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Hijack implements http.Hijacker. It defers to the wrapped
// http.ResponseWriter when that supports hijacking, and returns an error
// (rather than panicking on the nil embedded Hijacker) when it does not.
func (rw *wrappedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if rw.Hijacker == nil {
		return nil, nil, errors.New("wrapped http.ResponseWriter does not support hijacking")
	}
	return rw.Hijacker.Hijack()
}
// logReqMiddleware wraps h such that, once h has finished handling a request,
// an info message is logged using the request's logger (as returned by
// apiutil.GetRequestLogger). The message is annotated with how long h took and
// the response's status code as tracked by a wrappedResponseWriter.
func logReqMiddleware(h http.Handler) http.Handler {
	serve := func(rw http.ResponseWriter, r *http.Request) {
		type logCtxKey string

		tracked := newWrappedResponseWriter(rw)

		begin := time.Now()
		h.ServeHTTP(tracked, r)
		elapsed := time.Since(begin)

		logCtx := mctx.Annotate(
			r.Context(),
			logCtxKey("took"), elapsed.String(),
			logCtxKey("response_code"), tracked.statusCode,
		)
		apiutil.GetRequestLogger(r).Info(logCtx, "handled HTTP request")
	}
	return http.HandlerFunc(serve)
}
// disallowGetMiddleware wraps h such that GET requests are rejected with a 405
// (and a warning is logged using the request's logger), unless the request is
// a websocket upgrade. Requests using any other method are passed on to h
// untouched.
func disallowGetMiddleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		// we allow websockets to be GETs because, well, they must be. The
		// value of the Upgrade header is case-insensitive (RFC 6455), so e.g.
		// "WebSocket" must be accepted as well.
		isWebsocket := strings.EqualFold(r.Header.Get("Upgrade"), "websocket")

		if r.Method != http.MethodGet || isWebsocket {
			h.ServeHTTP(rw, r)
			return
		}

		apiutil.GetRequestLogger(r).WarnString(r.Context(), "method not allowed")
		rw.WriteHeader(http.StatusMethodNotAllowed)
	})
}
// cacheResponseWriter is a wrappedResponseWriter which additionally keeps a
// copy of everything written to the response body.
type cacheResponseWriter struct {
	*wrappedResponseWriter

	// buf holds a copy of every byte slice passed to Write.
	buf *bytes.Buffer

	// writeErr is the first error returned when writing to the wrapped
	// writer, or nil if every write so far has succeeded.
	writeErr error
}

// newCacheResponseWriter returns a cacheResponseWriter which wraps rw and has
// an empty buffer.
func newCacheResponseWriter(rw http.ResponseWriter) *cacheResponseWriter {
	return &cacheResponseWriter{
		wrappedResponseWriter: newWrappedResponseWriter(rw),
		buf:                   new(bytes.Buffer),
	}
}

// Write copies b into the buffer and then writes it to the wrapped writer,
// returning the result of the latter. The first error returned by the wrapped
// writer is remembered in writeErr.
func (rw *cacheResponseWriter) Write(b []byte) (int, error) {
	// bytes.Buffer.Write is documented to always return a nil error, so there
	// is nothing to check here.
	rw.buf.Write(b)

	n, err := rw.wrappedResponseWriter.Write(b)
	if err != nil && rw.writeErr == nil {
		rw.writeErr = err
	}
	return n, err
}

// cacheMiddleware returns a middleware which caches response bodies in cache,
// keyed by the request's URI (path plus query string).
//
// On a cache hit the cached body is served using http.ServeContent, with the
// time the entry was cached as its modification time, and the wrapped handler
// is not called. On a miss the wrapped handler is called, and its response
// body is cached if the status code was 200 and every write to the client
// succeeded. A failed write most likely means the handler stopped writing
// part-way through, and caching that truncated body would cause it to be
// served to all subsequent requests.
//
// NOTE only the body is cached. Headers set by the wrapped handler are not
// replayed on a cache hit; http.ServeContent picks the Content-Type based on
// the name's file extension or, failing that, the content itself.
func cacheMiddleware(cache *lru.Cache) middleware {
	type entry struct {
		body      []byte
		createdAt time.Time
	}

	pool := sync.Pool{
		New: func() interface{} { return new(bytes.Reader) },
	}

	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			id := r.URL.RequestURI()

			if val, ok := cache.Get(id); ok {
				cached := val.(entry)

				reader := pool.Get().(*bytes.Reader)
				defer func() {
					// drop the reference to the body so that a reader sitting
					// in the pool doesn't keep a purged entry alive.
					reader.Reset(nil)
					pool.Put(reader)
				}()
				reader.Reset(cached.body)

				// NOTE(review): r.URL.Path is slash-separated, for which the
				// path package is the conventional choice. filepath.Base
				// gives the same result on non-Windows platforms.
				name := filepath.Base(r.URL.Path)
				http.ServeContent(rw, r, name, cached.createdAt, reader)
				return
			}

			cacheRW := newCacheResponseWriter(rw)
			h.ServeHTTP(cacheRW, r)

			if cacheRW.statusCode != 200 || cacheRW.writeErr != nil {
				return
			}

			cache.Add(id, entry{
				body:      cacheRW.buf.Bytes(),
				createdAt: time.Now(),
			})
		})
	}
}
// purgeCacheOnOKMiddleware returns a middleware which calls the wrapped
// handler and then, if the response's status code was 200, logs an info
// message using the request's logger and purges cache. Responses with any
// other status code leave the cache untouched.
func purgeCacheOnOKMiddleware(cache *lru.Cache) middleware {
	return func(next http.Handler) http.Handler {
		serve := func(rw http.ResponseWriter, r *http.Request) {
			tracked := newWrappedResponseWriter(rw)
			next.ServeHTTP(tracked, r)

			if tracked.statusCode != 200 {
				return
			}

			apiutil.GetRequestLogger(r).Info(r.Context(), "purging cache!")
			cache.Purge()
		}
		return http.HandlerFunc(serve)
	}
}