Implement ratelimit on authentications

This commit is contained in:
Brian Picciano 2022-05-20 14:54:26 -06:00
parent ae1fa76efc
commit 1ffda21ae3
6 changed files with 84 additions and 23 deletions

View File

@ -16,4 +16,7 @@
httpAuthUsers = { httpAuthUsers = {
"foo" = "$2a$13$0JdWlUfHc.3XimEMpEu1cuu6RodhUvzD9l7iiAqa4YkM3mcFV5Pxi"; "foo" = "$2a$13$0JdWlUfHc.3XimEMpEu1cuu6RodhUvzD9l7iiAqa4YkM3mcFV5Pxi";
}; };
# Very low, should be increased for prod.
httpAuthRatelimit = "1s";
} }

View File

@ -29,6 +29,8 @@
export MEDIOCRE_BLOG_LISTEN_PROTO="${config.listenProto}" export MEDIOCRE_BLOG_LISTEN_PROTO="${config.listenProto}"
export MEDIOCRE_BLOG_LISTEN_ADDR="${config.listenAddr}" export MEDIOCRE_BLOG_LISTEN_ADDR="${config.listenAddr}"
export MEDIOCRE_BLOG_HTTP_AUTH_USERS='${builtins.toJSON config.httpAuthUsers}' export MEDIOCRE_BLOG_HTTP_AUTH_USERS='${builtins.toJSON config.httpAuthUsers}'
export MEDIOCRE_BLOG_HTTP_AUTH_RATELIMIT='${config.httpAuthRatelimit}'
''; '';
build = buildGoModule { build = buildGoModule {

View File

@ -11,6 +11,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"time"
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
"github.com/mediocregopher/blog.mediocregopher.com/srv/chat" "github.com/mediocregopher/blog.mediocregopher.com/srv/chat"
@ -52,6 +53,10 @@ type Params struct {
// and the values are the password hash which accompanies those users. The // and the values are the password hash which accompanies those users. The
// password hash must have been produced by NewPasswordHash. // password hash must have been produced by NewPasswordHash.
AuthUsers map[string]string AuthUsers map[string]string
// AuthRatelimit indicates how much time must pass between subsequent auth
// attempts.
AuthRatelimit time.Duration
} }
// SetupCfg implement the cfg.Cfger interface. // SetupCfg implement the cfg.Cfger interface.
@ -61,10 +66,20 @@ func (p *Params) SetupCfg(cfg *cfg.Cfg) {
httpAuthUsersStr := cfg.String("http-auth-users", "{}", "JSON object with usernames as values and password hashes (produced by the hash-password binary) as values. Denotes users which are able to edit server-side data") httpAuthUsersStr := cfg.String("http-auth-users", "{}", "JSON object with usernames as values and password hashes (produced by the hash-password binary) as values. Denotes users which are able to edit server-side data")
httpAuthRatelimitStr := cfg.String("http-auth-ratelimit", "5s", "Minimum duration which must be waited between subsequent auth attempts")
cfg.OnInit(func(context.Context) error { cfg.OnInit(func(context.Context) error {
if err := json.Unmarshal([]byte(*httpAuthUsersStr), &p.AuthUsers); err != nil {
err := json.Unmarshal([]byte(*httpAuthUsersStr), &p.AuthUsers)
if err != nil {
return fmt.Errorf("unmarshaling -http-auth-users: %w", err) return fmt.Errorf("unmarshaling -http-auth-users: %w", err)
} }
if p.AuthRatelimit, err = time.ParseDuration(*httpAuthRatelimitStr); err != nil {
return fmt.Errorf("unmarshaling -http-auth-ratelimit: %w", err)
}
return nil return nil
}) })
} }
@ -73,6 +88,7 @@ func (p *Params) SetupCfg(cfg *cfg.Cfg) {
func (p *Params) Annotate(a mctx.Annotations) { func (p *Params) Annotate(a mctx.Annotations) {
a["listenProto"] = p.ListenProto a["listenProto"] = p.ListenProto
a["listenAddr"] = p.ListenAddr a["listenAddr"] = p.ListenAddr
a["authRatelimit"] = p.AuthRatelimit
} }
// API will listen on the port configured for it, and serve HTTP requests for // API will listen on the port configured for it, and serve HTTP requests for
@ -86,6 +102,7 @@ type api struct {
srv *http.Server srv *http.Server
redirectTpl *template.Template redirectTpl *template.Template
auther Auther
} }
// New initializes and returns a new API instance, including setting up all // New initializes and returns a new API instance, including setting up all
@ -105,6 +122,7 @@ func New(params Params) (API, error) {
a := &api{ a := &api{
params: params, params: params,
auther: NewAuther(params.AuthUsers, params.AuthRatelimit),
} }
a.redirectTpl = a.mustParseTpl("redirect.html") a.redirectTpl = a.mustParseTpl("redirect.html")
@ -124,6 +142,7 @@ func New(params Params) (API, error) {
} }
func (a *api) Shutdown(ctx context.Context) error { func (a *api) Shutdown(ctx context.Context) error {
defer a.auther.Close()
if err := a.srv.Shutdown(ctx); err != nil { if err := a.srv.Shutdown(ctx); err != nil {
return err return err
} }
@ -149,8 +168,6 @@ func (a *api) handler() http.Handler {
return h return h
} }
auther := NewAuther(a.params.AuthUsers)
mux := http.NewServeMux() mux := http.NewServeMux()
{ {
@ -179,13 +196,13 @@ func (a *api) handler() http.Handler {
apiutil.MethodMux(map[string]http.Handler{ apiutil.MethodMux(map[string]http.Handler{
"GET": a.renderPostHandler(), "GET": a.renderPostHandler(),
"EDIT": a.editPostHandler(), "EDIT": a.editPostHandler(),
"POST": authMiddleware(auther, "POST": authMiddleware(a.auther,
formMiddleware(a.postPostHandler()), formMiddleware(a.postPostHandler()),
), ),
"DELETE": authMiddleware(auther, "DELETE": authMiddleware(a.auther,
formMiddleware(a.deletePostHandler()), formMiddleware(a.deletePostHandler()),
), ),
"PREVIEW": authMiddleware(auther, "PREVIEW": authMiddleware(a.auther,
formMiddleware(a.previewPostHandler()), formMiddleware(a.previewPostHandler()),
), ),
}), }),
@ -194,10 +211,10 @@ func (a *api) handler() http.Handler {
mux.Handle("/assets/", http.StripPrefix("/assets", mux.Handle("/assets/", http.StripPrefix("/assets",
apiutil.MethodMux(map[string]http.Handler{ apiutil.MethodMux(map[string]http.Handler{
"GET": a.getPostAssetHandler(), "GET": a.getPostAssetHandler(),
"POST": authMiddleware(auther, "POST": authMiddleware(a.auther,
formMiddleware(a.postPostAssetHandler()), formMiddleware(a.postPostAssetHandler()),
), ),
"DELETE": authMiddleware(auther, "DELETE": authMiddleware(a.auther,
formMiddleware(a.deletePostAssetHandler()), formMiddleware(a.deletePostAssetHandler()),
), ),
}), }),

View File

@ -1,7 +1,9 @@
package http package http
import ( import (
"context"
"net/http" "net/http"
"time"
"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -19,21 +21,37 @@ func NewPasswordHash(plaintext string) string {
// Auther determines who can do what. // Auther determines who can do what.
type Auther interface { type Auther interface {
Allowed(username, password string) bool Allowed(ctx context.Context, username, password string) bool
Close() error
} }
type auther struct { type auther struct {
users map[string]string users map[string]string
ticker *time.Ticker
} }
// NewAuther initializes and returns an Auther will which allow the given // NewAuther initializes and returns an Auther will which allow the given
// username and password hash combinations. Password hashes must have been // username and password hash combinations. Password hashes must have been
// created using NewPasswordHash. // created using NewPasswordHash.
func NewAuther(users map[string]string) Auther { func NewAuther(users map[string]string, ratelimit time.Duration) Auther {
return &auther{users: users} return &auther{
users: users,
ticker: time.NewTicker(ratelimit),
}
} }
func (a *auther) Allowed(username, password string) bool { func (a *auther) Close() error {
a.ticker.Stop()
return nil
}
func (a *auther) Allowed(ctx context.Context, username, password string) bool {
select {
case <-ctx.Done():
return false
case <-a.ticker.C:
}
hashedPassword, ok := a.users[username] hashedPassword, ok := a.users[username]
if !ok { if !ok {
@ -64,7 +82,7 @@ func authMiddleware(auther Auther, h http.Handler) http.Handler {
return return
} }
if !auther.Allowed(username, password) { if !auther.Allowed(r.Context(), username, password) {
respondUnauthorized(rw, r) respondUnauthorized(rw, r)
return return
} }

View File

@ -1,21 +1,24 @@
package http package http
import ( import (
"context"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestAuther(t *testing.T) { func TestAuther(t *testing.T) {
ctx := context.Background()
password := "foo" password := "foo"
hashedPassword := NewPasswordHash(password) hashedPassword := NewPasswordHash(password)
auther := NewAuther(map[string]string{ auther := NewAuther(map[string]string{
"FOO": hashedPassword, "FOO": hashedPassword,
}) }, 1*time.Millisecond)
assert.False(t, auther.Allowed("BAR", password)) assert.False(t, auther.Allowed(ctx, "BAR", password))
assert.False(t, auther.Allowed("FOO", "bar")) assert.False(t, auther.Allowed(ctx, "FOO", "bar"))
assert.True(t, auther.Allowed("FOO", password)) assert.True(t, auther.Allowed(ctx, "FOO", password))
} }

View File

@ -197,7 +197,7 @@ func (a *api) editPostHandler() http.Handler {
}) })
} }
func postFromPostReq(r *http.Request) post.Post { func postFromPostReq(r *http.Request) (post.Post, error) {
p := post.Post{ p := post.Post{
ID: r.PostFormValue("id"), ID: r.PostFormValue("id"),
@ -207,18 +207,30 @@ func postFromPostReq(r *http.Request) post.Post {
Series: r.PostFormValue("series"), Series: r.PostFormValue("series"),
} }
p.Body = strings.TrimSpace(r.PostFormValue("body"))
// textareas encode newlines as CRLF for historical reasons // textareas encode newlines as CRLF for historical reasons
p.Body = strings.ReplaceAll(p.Body, "\r\n", "\n") p.Body = strings.ReplaceAll(p.Body, "\r\n", "\n")
p.Body = strings.TrimSpace(r.PostFormValue("body"))
return p if p.ID == "" ||
p.Title == "" ||
p.Description == "" ||
p.Body == "" ||
len(p.Tags) == 0 {
return post.Post{}, errors.New("ID, Title, Description, Tags, and Body are all required")
}
return p, nil
} }
func (a *api) postPostHandler() http.Handler { func (a *api) postPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
p := postFromPostReq(r) p, err := postFromPostReq(r)
if err != nil {
apiutil.BadRequest(rw, r, err)
return
}
if err := a.params.PostStore.Set(p, time.Now()); err != nil { if err := a.params.PostStore.Set(p, time.Now()); err != nil {
apiutil.InternalServerError( apiutil.InternalServerError(
@ -267,8 +279,14 @@ func (a *api) previewPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
p, err := postFromPostReq(r)
if err != nil {
apiutil.BadRequest(rw, r, err)
return
}
storedPost := post.StoredPost{ storedPost := post.StoredPost{
Post: postFromPostReq(r), Post: p,
PublishedAt: time.Now(), PublishedAt: time.Now(),
} }