Implement ratelimit on authentications

This commit is contained in:
Brian Picciano 2022-05-20 14:54:26 -06:00
parent ae1fa76efc
commit 1ffda21ae3
6 changed files with 84 additions and 23 deletions

View File

@ -16,4 +16,7 @@
httpAuthUsers = {
"foo" = "$2a$13$0JdWlUfHc.3XimEMpEu1cuu6RodhUvzD9l7iiAqa4YkM3mcFV5Pxi";
};
# Very low, should be increased for prod.
httpAuthRatelimit = "1s";
}

View File

@ -29,6 +29,8 @@
export MEDIOCRE_BLOG_LISTEN_PROTO="${config.listenProto}"
export MEDIOCRE_BLOG_LISTEN_ADDR="${config.listenAddr}"
export MEDIOCRE_BLOG_HTTP_AUTH_USERS='${builtins.toJSON config.httpAuthUsers}'
export MEDIOCRE_BLOG_HTTP_AUTH_RATELIMIT='${config.httpAuthRatelimit}'
'';
build = buildGoModule {

View File

@ -11,6 +11,7 @@ import (
"net"
"net/http"
"os"
"time"
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
"github.com/mediocregopher/blog.mediocregopher.com/srv/chat"
@ -52,6 +53,10 @@ type Params struct {
// and the values are the password hash which accompanies those users. The
// password hash must have been produced by NewPasswordHash.
AuthUsers map[string]string
// AuthRatelimit indicates how much time must pass between subsequent auth
// attempts.
AuthRatelimit time.Duration
}
// SetupCfg implement the cfg.Cfger interface.
@ -61,10 +66,20 @@ func (p *Params) SetupCfg(cfg *cfg.Cfg) {
httpAuthUsersStr := cfg.String("http-auth-users", "{}", "JSON object with usernames as values and password hashes (produced by the hash-password binary) as values. Denotes users which are able to edit server-side data")
httpAuthRatelimitStr := cfg.String("http-auth-ratelimit", "5s", "Minimum duration which must be waited between subsequent auth attempts")
cfg.OnInit(func(context.Context) error {
if err := json.Unmarshal([]byte(*httpAuthUsersStr), &p.AuthUsers); err != nil {
err := json.Unmarshal([]byte(*httpAuthUsersStr), &p.AuthUsers)
if err != nil {
return fmt.Errorf("unmarshaling -http-auth-users: %w", err)
}
if p.AuthRatelimit, err = time.ParseDuration(*httpAuthRatelimitStr); err != nil {
return fmt.Errorf("unmarshaling -http-auth-ratelimit: %w", err)
}
return nil
})
}
@ -73,6 +88,7 @@ func (p *Params) SetupCfg(cfg *cfg.Cfg) {
func (p *Params) Annotate(a mctx.Annotations) {
a["listenProto"] = p.ListenProto
a["listenAddr"] = p.ListenAddr
a["authRatelimit"] = p.AuthRatelimit
}
// API will listen on the port configured for it, and serve HTTP requests for
@ -86,6 +102,7 @@ type api struct {
srv *http.Server
redirectTpl *template.Template
auther Auther
}
// New initializes and returns a new API instance, including setting up all
@ -105,6 +122,7 @@ func New(params Params) (API, error) {
a := &api{
params: params,
auther: NewAuther(params.AuthUsers, params.AuthRatelimit),
}
a.redirectTpl = a.mustParseTpl("redirect.html")
@ -124,6 +142,7 @@ func New(params Params) (API, error) {
}
func (a *api) Shutdown(ctx context.Context) error {
defer a.auther.Close()
if err := a.srv.Shutdown(ctx); err != nil {
return err
}
@ -149,8 +168,6 @@ func (a *api) handler() http.Handler {
return h
}
auther := NewAuther(a.params.AuthUsers)
mux := http.NewServeMux()
{
@ -179,13 +196,13 @@ func (a *api) handler() http.Handler {
apiutil.MethodMux(map[string]http.Handler{
"GET": a.renderPostHandler(),
"EDIT": a.editPostHandler(),
"POST": authMiddleware(auther,
"POST": authMiddleware(a.auther,
formMiddleware(a.postPostHandler()),
),
"DELETE": authMiddleware(auther,
"DELETE": authMiddleware(a.auther,
formMiddleware(a.deletePostHandler()),
),
"PREVIEW": authMiddleware(auther,
"PREVIEW": authMiddleware(a.auther,
formMiddleware(a.previewPostHandler()),
),
}),
@ -194,10 +211,10 @@ func (a *api) handler() http.Handler {
mux.Handle("/assets/", http.StripPrefix("/assets",
apiutil.MethodMux(map[string]http.Handler{
"GET": a.getPostAssetHandler(),
"POST": authMiddleware(auther,
"POST": authMiddleware(a.auther,
formMiddleware(a.postPostAssetHandler()),
),
"DELETE": authMiddleware(auther,
"DELETE": authMiddleware(a.auther,
formMiddleware(a.deletePostAssetHandler()),
),
}),

View File

@ -1,7 +1,9 @@
package http
import (
"context"
"net/http"
"time"
"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
"golang.org/x/crypto/bcrypt"
@ -19,21 +21,37 @@ func NewPasswordHash(plaintext string) string {
// Auther determines who can do what.
type Auther interface {
Allowed(username, password string) bool
Allowed(ctx context.Context, username, password string) bool
Close() error
}
type auther struct {
users map[string]string
ticker *time.Ticker
}
// NewAuther initializes and returns an Auther will which allow the given
// username and password hash combinations. Password hashes must have been
// created using NewPasswordHash.
func NewAuther(users map[string]string) Auther {
return &auther{users: users}
func NewAuther(users map[string]string, ratelimit time.Duration) Auther {
return &auther{
users: users,
ticker: time.NewTicker(ratelimit),
}
}
func (a *auther) Allowed(username, password string) bool {
func (a *auther) Close() error {
a.ticker.Stop()
return nil
}
func (a *auther) Allowed(ctx context.Context, username, password string) bool {
select {
case <-ctx.Done():
return false
case <-a.ticker.C:
}
hashedPassword, ok := a.users[username]
if !ok {
@ -64,7 +82,7 @@ func authMiddleware(auther Auther, h http.Handler) http.Handler {
return
}
if !auther.Allowed(username, password) {
if !auther.Allowed(r.Context(), username, password) {
respondUnauthorized(rw, r)
return
}

View File

@ -1,21 +1,24 @@
package http
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAuther(t *testing.T) {
ctx := context.Background()
password := "foo"
hashedPassword := NewPasswordHash(password)
auther := NewAuther(map[string]string{
"FOO": hashedPassword,
})
}, 1*time.Millisecond)
assert.False(t, auther.Allowed("BAR", password))
assert.False(t, auther.Allowed("FOO", "bar"))
assert.True(t, auther.Allowed("FOO", password))
assert.False(t, auther.Allowed(ctx, "BAR", password))
assert.False(t, auther.Allowed(ctx, "FOO", "bar"))
assert.True(t, auther.Allowed(ctx, "FOO", password))
}

View File

@ -197,7 +197,7 @@ func (a *api) editPostHandler() http.Handler {
})
}
func postFromPostReq(r *http.Request) post.Post {
func postFromPostReq(r *http.Request) (post.Post, error) {
p := post.Post{
ID: r.PostFormValue("id"),
@ -207,18 +207,30 @@ func postFromPostReq(r *http.Request) post.Post {
Series: r.PostFormValue("series"),
}
p.Body = strings.TrimSpace(r.PostFormValue("body"))
// textareas encode newlines as CRLF for historical reasons
p.Body = strings.ReplaceAll(p.Body, "\r\n", "\n")
p.Body = strings.TrimSpace(r.PostFormValue("body"))
return p
if p.ID == "" ||
p.Title == "" ||
p.Description == "" ||
p.Body == "" ||
len(p.Tags) == 0 {
return post.Post{}, errors.New("ID, Title, Description, Tags, and Body are all required")
}
return p, nil
}
func (a *api) postPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
p := postFromPostReq(r)
p, err := postFromPostReq(r)
if err != nil {
apiutil.BadRequest(rw, r, err)
return
}
if err := a.params.PostStore.Set(p, time.Now()); err != nil {
apiutil.InternalServerError(
@ -267,8 +279,14 @@ func (a *api) previewPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
p, err := postFromPostReq(r)
if err != nil {
apiutil.BadRequest(rw, r, err)
return
}
storedPost := post.StoredPost{
Post: postFromPostReq(r),
Post: p,
PublishedAt: time.Now(),
}