totp-proxy: refactor to use new Component logic

This commit is contained in:
Brian Picciano 2019-06-17 18:18:50 -04:00
parent e91ef01857
commit 7a5ac9caa0
3 changed files with 32 additions and 14 deletions

View File

@ -21,6 +21,7 @@ import (
"github.com/mediocregopher/mediocre-go-lib/mcfg" "github.com/mediocregopher/mediocre-go-lib/mcfg"
"github.com/mediocregopher/mediocre-go-lib/mcrypto" "github.com/mediocregopher/mediocre-go-lib/mcrypto"
"github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mctx"
"github.com/mediocregopher/mediocre-go-lib/merr"
"github.com/mediocregopher/mediocre-go-lib/mhttp" "github.com/mediocregopher/mediocre-go-lib/mhttp"
"github.com/mediocregopher/mediocre-go-lib/mlog" "github.com/mediocregopher/mediocre-go-lib/mlog"
"github.com/mediocregopher/mediocre-go-lib/mrand" "github.com/mediocregopher/mediocre-go-lib/mrand"
@ -30,30 +31,39 @@ import (
) )
func main() { func main() {
ctx := m.ServiceContext() cmp := m.RootServiceComponent()
ctx, cookieName := mcfg.WithString(ctx, "cookie-name", "_totp_proxy", "String to use as the name for cookies") cookieName := mcfg.String(cmp, "cookie-name",
ctx, cookieTimeout := mcfg.WithDuration(ctx, "cookie-timeout", mtime.Duration{1 * time.Hour}, "Timeout for cookies") mcfg.ParamDefault("_totp_proxy"),
mcfg.ParamUsage("String to use as the name for cookies"))
cookieTimeout := mcfg.Duration(cmp, "cookie-timeout",
mcfg.ParamDefault(mtime.Duration{1 * time.Hour}),
mcfg.ParamUsage("Timeout for cookies"))
var userSecrets map[string]string var userSecrets map[string]string
ctx = mcfg.WithRequiredJSON(ctx, "users", &userSecrets, "JSON object which maps usernames to their TOTP secret strings") mcfg.JSON(cmp, "users", &userSecrets,
mcfg.ParamRequired(),
mcfg.ParamUsage("JSON object which maps usernames to their TOTP secret strings"))
var secret mcrypto.Secret var secret mcrypto.Secret
ctx, secretStr := mcfg.WithString(ctx, "secret", "", "String used to sign authentication tokens. If one isn't given a new one will be generated on each startup, invalidating all previous tokens.") secretStr := mcfg.String(cmp, "secret",
ctx = mrun.WithStartHook(ctx, func(context.Context) error { mcfg.ParamUsage("String used to sign authentication tokens. If one isn't given a new one will be generated on each startup, invalidating all previous tokens."))
mrun.InitHook(cmp, func(context.Context) error {
if *secretStr == "" { if *secretStr == "" {
*secretStr = mrand.Hex(32) *secretStr = mrand.Hex(32)
} }
mlog.Info("generating secret", ctx) mlog.From(cmp).Info("generating secret")
secret = mcrypto.NewSecret([]byte(*secretStr)) secret = mcrypto.NewSecret([]byte(*secretStr))
return nil return nil
}) })
proxyHandler := new(struct{ http.Handler }) proxyHandler := new(struct{ http.Handler })
ctx, proxyURL := mcfg.WithRequiredString(ctx, "dst-url", "URL to proxy requests to. Only the scheme and host should be set.") proxyURL := mcfg.String(cmp, "dst-url",
ctx = mrun.WithStartHook(ctx, func(context.Context) error { mcfg.ParamRequired(),
mcfg.ParamUsage("URL to proxy requests to. Only the scheme and host should be set."))
mrun.InitHook(cmp, func(context.Context) error {
u, err := url.Parse(*proxyURL) u, err := url.Parse(*proxyURL)
if err != nil { if err != nil {
return err return merr.Wrap(err, cmp.Context())
} }
proxyHandler.Handler = mhttp.ReverseProxy(u) proxyHandler.Handler = mhttp.ReverseProxy(u)
return nil return nil
@ -64,11 +74,13 @@ func main() {
ctx := r.Context() ctx := r.Context()
unauthorized := func() { unauthorized := func() {
mlog.From(cmp).Debug("connection is unauthorized")
w.Header().Add("WWW-Authenticate", "Basic") w.Header().Add("WWW-Authenticate", "Basic")
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
} }
authorized := func() { authorized := func() {
mlog.From(cmp).Debug("connection is authorized, rewriting cookies")
sig := mcrypto.SignString(secret, "") sig := mcrypto.SignString(secret, "")
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: *cookieName, Name: *cookieName,
@ -79,7 +91,7 @@ func main() {
} }
if cookie, _ := r.Cookie(*cookieName); cookie != nil { if cookie, _ := r.Cookie(*cookieName); cookie != nil {
mlog.Debug("authenticating with cookie", mlog.From(cmp).Debug("authenticating with cookie",
mctx.Annotate(ctx, "cookie", cookie.String())) mctx.Annotate(ctx, "cookie", cookie.String()))
var sig mcrypto.Signature var sig mcrypto.Signature
if err := sig.UnmarshalText([]byte(cookie.Value)); err == nil { if err := sig.UnmarshalText([]byte(cookie.Value)); err == nil {
@ -92,7 +104,7 @@ func main() {
} }
if user, pass, ok := r.BasicAuth(); ok && pass != "" { if user, pass, ok := r.BasicAuth(); ok && pass != "" {
mlog.Debug("authenticating with user", mlog.From(cmp).Debug("authenticating with user",
mctx.Annotate(ctx, "user", user)) mctx.Annotate(ctx, "user", user))
if userSecret, ok := userSecrets[user]; ok { if userSecret, ok := userSecrets[user]; ok {
if totp.Validate(pass, userSecret) { if totp.Validate(pass, userSecret) {
@ -105,6 +117,6 @@ func main() {
unauthorized() unauthorized()
}) })
ctx, _ = mhttp.WithListeningServer(ctx, authHandler) mhttp.InstListeningServer(cmp, authHandler)
m.StartWaitStop(ctx) m.Exec(cmp)
} }

1
m/m.go
View File

@ -76,6 +76,7 @@ func RootComponent() *mcmp.Component {
mctx.Annotated("log-level", *logLevelStr)) mctx.Annotated("log-level", *logLevelStr))
} }
logger.SetMaxLevel(logLevel) logger.SetMaxLevel(logLevel)
mlog.SetLogger(cmp, logger)
return nil return nil
}) })

View File

@ -9,6 +9,11 @@ type cmpKey int
// SetLogger sets the given logger onto the Component. The logger can later be // SetLogger sets the given logger onto the Component. The logger can later be
// retrieved from the Component, or any of its children, using From. // retrieved from the Component, or any of its children, using From.
//
// NOTE that if a Logger is set onto a Component and then changed, even though
// the Logger is a pointer and so is changed within the Component, SetLogger
// should still be called. This is due to some caching that From does for
// performance.
func SetLogger(cmp *mcmp.Component, l *Logger) { func SetLogger(cmp *mcmp.Component, l *Logger) {
cmp.SetValue(cmpKey(0), l) cmp.SetValue(cmpKey(0), l)