You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
204 lines
4.7 KiB
204 lines
4.7 KiB
package http
|
|
|
|
import (
	"bufio"
	"bytes"
	"net"
	"net/http"
	"net/url"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/mediocregopher/blog.mediocregopher.com/srv/cache"
	"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
	"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
	"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
|
|
|
|
type middleware func(http.Handler) http.Handler
|
|
|
|
func applyMiddlewares(h http.Handler, middlewares ...middleware) http.Handler {
|
|
for _, m := range middlewares {
|
|
h = m(h)
|
|
}
|
|
return h
|
|
}
|
|
|
|
// addResponseHeadersMiddleware returns a middleware which, for every request,
// sets each of the given headers on the response (replacing any existing
// values for that key, per http.Header.Set) before calling the wrapped
// handler. The wrapped handler remains free to overwrite them.
func addResponseHeadersMiddleware(headers map[string]string) middleware {
	return func(next http.Handler) http.Handler {
		serve := func(rw http.ResponseWriter, r *http.Request) {
			for name, value := range headers {
				rw.Header().Set(name, value)
			}
			next.ServeHTTP(rw, r)
		}
		return http.HandlerFunc(serve)
	}
}
|
|
|
|
// setLoggerMiddleware returns a middleware which annotates each request's
// context with its URL and method, plus the X-Forwarded-For header, the
// X-Real-IP header and the remote IP when those are non-empty. It then
// attaches logger to the request via apiutil.SetRequestLogger before calling
// the wrapped handler.
func setLoggerMiddleware(logger *mlog.Logger) middleware {

	// Annotation keys use a dedicated local type, presumably so they can't be
	// confused with keys annotated elsewhere.
	type logCtxKey string

	// Optional proxy-provided headers, in the order they get annotated.
	proxyHeaders := [...]struct{ header, key string }{
		{"X-Forwarded-For", "x_forwarded_for"},
		{"X-Real-IP", "x_real_ip"},
	}

	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			ctx := mctx.Annotate(r.Context(),
				logCtxKey("url"), r.URL,
				logCtxKey("method"), r.Method,
			)

			for _, ph := range proxyHeaders {
				if v := r.Header.Get(ph.header); v != "" {
					ctx = mctx.Annotate(ctx, logCtxKey(ph.key), v)
				}
			}

			// RemoteAddr is usually "host:port"; when it can't be split the
			// host comes back empty and the remote IP is simply omitted.
			if ip, _, _ := net.SplitHostPort(r.RemoteAddr); ip != "" {
				ctx = mctx.Annotate(ctx, logCtxKey("remote_ip"), ip)
			}

			h.ServeHTTP(rw, apiutil.SetRequestLogger(r.WithContext(ctx), logger))
		})
	}
}
|
|
|
|
type wrappedResponseWriter struct {
|
|
http.ResponseWriter
|
|
http.Hijacker
|
|
statusCode int
|
|
}
|
|
|
|
func newWrappedResponseWriter(rw http.ResponseWriter) *wrappedResponseWriter {
|
|
h, _ := rw.(http.Hijacker)
|
|
return &wrappedResponseWriter{
|
|
ResponseWriter: rw,
|
|
Hijacker: h,
|
|
statusCode: 200,
|
|
}
|
|
}
|
|
|
|
func (rw *wrappedResponseWriter) WriteHeader(statusCode int) {
|
|
rw.statusCode = statusCode
|
|
rw.ResponseWriter.WriteHeader(statusCode)
|
|
}
|
|
|
|
// logReqMiddleware logs an info line for every request once the wrapped
// handler has returned. The request's logging context is annotated with how
// long the handler took and the status code it responded with.
func logReqMiddleware(h http.Handler) http.Handler {

	type logCtxKey string

	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		statusRW := newWrappedResponseWriter(rw)

		start := time.Now()
		h.ServeHTTP(statusRW, r)
		elapsed := time.Since(start)

		annotated := mctx.Annotate(r.Context(),
			logCtxKey("took"), elapsed.String(),
			logCtxKey("response_code"), statusRW.statusCode,
		)

		logger := apiutil.GetRequestLogger(r)
		logger.Info(annotated, "handled HTTP request")
	})
}
|
|
|
|
// disallowGetMiddleware rejects GET requests with 405 Method Not Allowed,
// logging a warning, except for websocket upgrade requests, which must be
// GETs. Requests with any other method are passed to the wrapped handler.
func disallowGetMiddleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {

		// we allow websockets to be GETs because, well, they must be. RFC 6455
		// treats the Upgrade header's "websocket" value as ASCII
		// case-insensitive, so e.g. "WebSocket" must be accepted too.
		isWebsocket := strings.EqualFold(r.Header.Get("Upgrade"), "websocket")

		if r.Method != http.MethodGet || isWebsocket {
			h.ServeHTTP(rw, r)
			return
		}

		apiutil.GetRequestLogger(r).WarnString(r.Context(), "method not allowed")
		rw.WriteHeader(http.StatusMethodNotAllowed)
	})
}
|
|
|
|
type cacheResponseWriter struct {
|
|
*wrappedResponseWriter
|
|
buf *bytes.Buffer
|
|
}
|
|
|
|
func newCacheResponseWriter(rw http.ResponseWriter) *cacheResponseWriter {
|
|
return &cacheResponseWriter{
|
|
wrappedResponseWriter: newWrappedResponseWriter(rw),
|
|
buf: new(bytes.Buffer),
|
|
}
|
|
}
|
|
|
|
func (rw *cacheResponseWriter) Write(b []byte) (int, error) {
|
|
if _, err := rw.buf.Write(b); err != nil {
|
|
panic(err)
|
|
}
|
|
return rw.wrappedResponseWriter.Write(b)
|
|
}
|
|
|
|
// cacheMiddleware returns a middleware which caches the bodies of successful
// (200) GET responses. Later GET and HEAD requests for the same URL are served
// from the cache via http.ServeContent, without calling the wrapped handler.
//
// Entries are keyed on publicURL plus the full request URL (including the
// query string). Only the body and the Content-Type header are cached; other
// response headers set by the wrapped handler are not replayed on a hit.
//
// Requests with any other method, and upgrade requests (e.g. websockets),
// always bypass the cache. Answering them from the cache would skip the
// handler's side effects, or would replay the empty body left behind by a
// hijacked connection.
func cacheMiddleware(cache cache.Cache, publicURL *url.URL) middleware {

	type entry struct {
		body        []byte
		contentType string
		createdAt   time.Time
	}

	// Readers handed to http.ServeContent on cache hits are recycled.
	pool := sync.Pool{
		New: func() interface{} { return new(bytes.Reader) },
	}

	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {

			isGet := r.Method == http.MethodGet
			isHead := r.Method == http.MethodHead
			if (!isGet && !isHead) || r.Header.Get("Upgrade") != "" {
				h.ServeHTTP(rw, r)
				return
			}

			// r.URL doesn't have Scheme or Host populated, better to add the
			// public url to the key to make sure there's no possibility of
			// collision with other protocols using the cache.
			id := publicURL.String() + "|" + r.URL.String()

			if value := cache.Get(id); value != nil {
				cached := value.(entry)

				reader := pool.Get().(*bytes.Reader)
				defer pool.Put(reader)
				reader.Reset(cached.body)

				// Replay the original Content-Type. Otherwise ServeContent has
				// to guess it from the path's extension or by sniffing the body.
				if cached.contentType != "" {
					rw.Header().Set("Content-Type", cached.contentType)
				}

				http.ServeContent(
					rw, r, filepath.Base(r.URL.Path), cached.createdAt, reader,
				)
				return
			}

			// A handler may legitimately omit the body of a HEAD response
			// (http.ServeContent does), so only GETs populate the cache.
			if isHead {
				h.ServeHTTP(rw, r)
				return
			}

			cacheRW := newCacheResponseWriter(rw)
			h.ServeHTTP(cacheRW, r)

			if cacheRW.statusCode == 200 {
				cache.Set(id, entry{
					body:        cacheRW.buf.Bytes(),
					contentType: cacheRW.Header().Get("Content-Type"),
					createdAt:   time.Now(),
				})
			}
		})
	}
}
|
|
|
|
// purgeCacheOnOKMiddleware returns a middleware which purges the whole cache
// after the wrapped handler finishes with a 200 status code. Any other status
// leaves the cache untouched.
func purgeCacheOnOKMiddleware(cache cache.Cache) middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			statusRW := newWrappedResponseWriter(rw)
			next.ServeHTTP(statusRW, r)

			if statusRW.statusCode != 200 {
				return
			}

			apiutil.GetRequestLogger(r).Info(r.Context(), "purging cache!")
			cache.Purge()
		})
	}
}
|
|
|