diff --git a/srv/src/go.mod b/srv/src/go.mod index dd6509d..f65034f 100644 --- a/srv/src/go.mod +++ b/srv/src/go.mod @@ -10,6 +10,7 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/feeds v1.1.1 // indirect github.com/gorilla/websocket v1.4.2 + github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/mattn/go-sqlite3 v1.14.8 github.com/mediocregopher/mediocre-go-lib/v2 v2.0.0-beta.0.0.20220506011745-cbeee71cb1ee github.com/mediocregopher/radix/v4 v4.0.0-beta.1.0.20210726230805-d62fa1b2e3cb diff --git a/srv/src/go.sum b/srv/src/go.sum index f5bd1fe..d8139cc 100644 --- a/srv/src/go.sum +++ b/srv/src/go.sum @@ -75,6 +75,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= diff --git a/srv/src/http/api.go b/srv/src/http/api.go index 11d092d..ebc1de2 100644 --- a/srv/src/http/api.go +++ b/srv/src/http/api.go @@ -15,6 +15,7 @@ import ( "strings" "time" + lru "github.com/hashicorp/golang-lru" "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" "github.com/mediocregopher/blog.mediocregopher.com/srv/chat" "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" @@ -162,8 +163,11 @@ func (a *api) Shutdown(ctx context.Context) error { func (a *api) handler() http.Handler { - requirePow := func(h http.Handler) http.Handler { - return a.requirePowMiddleware(h) + cache, err := lru.New(5000) + + // instantiating the lru cache can't realistically fail + if err != nil { + panic(err) } mux := http.NewServeMux() @@ -172,12 +176,12 @@ func (a *api) handler() http.Handler { apiMux := http.NewServeMux() apiMux.Handle("/pow/challenge", a.newPowChallengeHandler()) apiMux.Handle("/pow/check", - requirePow( + a.requirePowMiddleware( http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), ), ) - apiMux.Handle("/mailinglist/subscribe", requirePow(a.mailingListSubscribeHandler())) + apiMux.Handle("/mailinglist/subscribe", a.requirePowMiddleware(a.mailingListSubscribeHandler())) apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler()) apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler()) @@ -222,10 +226,12 @@ func (a *api) handler() http.Handler { "GET": applyMiddlewares( globalHandler, logReqMiddleware, + cacheMiddleware(cache), setCSRFMiddleware, ), "*": applyMiddlewares( globalHandler, + purgeCacheOnOKMiddleware(cache), authMiddleware(a.auther), checkCSRFMiddleware, addResponseHeadersMiddleware(map[string]string{ diff --git a/srv/src/http/middleware.go b/srv/src/http/middleware.go index 02d156b..7296a35 100644 --- a/srv/src/http/middleware.go +++ b/srv/src/http/middleware.go @@ -1,10 +1,14 @@ package http import ( + "bytes" "net" "net/http" + "path/filepath" + "sync" "time" + lru "github.com/hashicorp/golang-lru" "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" "github.com/mediocregopher/mediocre-go-lib/v2/mctx" "github.com/mediocregopher/mediocre-go-lib/v2/mlog" @@ -52,33 +56,33 @@ func setLoggerMiddleware(logger *mlog.Logger) middleware { } } -type logResponseWriter struct { +type wrappedResponseWriter struct { http.ResponseWriter http.Hijacker statusCode int } -func newLogResponseWriter(rw http.ResponseWriter) *logResponseWriter { +func newWrappedResponseWriter(rw http.ResponseWriter) *wrappedResponseWriter { h, _ := rw.(http.Hijacker) - return &logResponseWriter{ + return &wrappedResponseWriter{ ResponseWriter: rw, Hijacker: h, statusCode: 200, } } -func (lrw *logResponseWriter) WriteHeader(statusCode int) { - lrw.statusCode = statusCode - lrw.ResponseWriter.WriteHeader(statusCode) +func (rw *wrappedResponseWriter) WriteHeader(statusCode int) { + rw.statusCode = statusCode + rw.ResponseWriter.WriteHeader(statusCode) } func logReqMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - lrw := newLogResponseWriter(rw) + wrw := newWrappedResponseWriter(rw) started := time.Now() - h.ServeHTTP(lrw, r) + h.ServeHTTP(wrw, r) took := time.Since(started) type logCtxKey string @@ -86,7 +90,7 @@ func logReqMiddleware(h http.Handler) http.Handler { ctx := r.Context() ctx = mctx.Annotate(ctx, logCtxKey("took"), took.String(), - logCtxKey("response_code"), lrw.statusCode, + logCtxKey("response_code"), wrw.statusCode, ) apiutil.GetRequestLogger(r).Info(ctx, "handled HTTP request") @@ -106,3 +110,80 @@ func disallowGetMiddleware(h http.Handler) http.Handler { rw.WriteHeader(405) }) } + +type cacheResponseWriter struct { + *wrappedResponseWriter + buf *bytes.Buffer +} + +func newCacheResponseWriter(rw http.ResponseWriter) *cacheResponseWriter { + return &cacheResponseWriter{ + wrappedResponseWriter: newWrappedResponseWriter(rw), + buf: new(bytes.Buffer), + } +} + +func (rw *cacheResponseWriter) Write(b []byte) (int, error) { + if _, err := rw.buf.Write(b); err != nil { + panic(err) + } + return rw.wrappedResponseWriter.Write(b) +} + +func cacheMiddleware(cache *lru.Cache) middleware { + + type entry struct { + body []byte + createdAt time.Time + } + + pool := sync.Pool{ + New: func() interface{} { return new(bytes.Reader) }, + } + + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + id := r.URL.RequestURI() + + if val, ok := cache.Get(id); ok { + + entry := val.(entry) + + reader := pool.Get().(*bytes.Reader) + defer pool.Put(reader) + + reader.Reset(entry.body) + + http.ServeContent( + rw, r, filepath.Base(r.URL.Path), entry.createdAt, reader, + ) + return + } + + cacheRW := newCacheResponseWriter(rw) + h.ServeHTTP(cacheRW, r) + + if cacheRW.statusCode == 200 { + cache.Add(id, entry{ + body: cacheRW.buf.Bytes(), + createdAt: time.Now(), + }) + } + }) + } +} + +func purgeCacheOnOKMiddleware(cache *lru.Cache) middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + wrw := newWrappedResponseWriter(rw) + h.ServeHTTP(wrw, r) + + if wrw.statusCode == 200 { + cache.Purge() + } + }) + } +}