From 069ee93de17579230ef749d5804df7a0ac350ac5 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sun, 1 Aug 2021 17:54:53 -0600 Subject: [PATCH] implemented PoW backend --- srv/cmd/mediocre-blog/api.go | 24 +++ srv/cmd/mediocre-blog/main.go | 38 ++++- srv/cmd/mediocre-blog/pow.go | 22 +++ srv/go.mod | 2 +- srv/pow/pow.go | 288 ++++++++++++++++++++++++++++++++++ srv/pow/pow_test.go | 120 ++++++++++++++ srv/pow/store.go | 92 +++++++++++ srv/pow/store_test.go | 52 ++++++ 8 files changed, 635 insertions(+), 3 deletions(-) create mode 100644 srv/cmd/mediocre-blog/api.go create mode 100644 srv/cmd/mediocre-blog/pow.go create mode 100644 srv/pow/pow.go create mode 100644 srv/pow/pow_test.go create mode 100644 srv/pow/store.go create mode 100644 srv/pow/store_test.go diff --git a/srv/cmd/mediocre-blog/api.go b/srv/cmd/mediocre-blog/api.go new file mode 100644 index 0000000..b4f90d6 --- /dev/null +++ b/srv/cmd/mediocre-blog/api.go @@ -0,0 +1,24 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" +) + +func internalServerError(rw http.ResponseWriter, r *http.Request, err error) { + http.Error(rw, "internal server error", 500) + log.Printf("%s %s: internal server error: %v", r.Method, r.URL, err) +} + +func jsonResult(rw http.ResponseWriter, r *http.Request, v interface{}) { + b, err := json.Marshal(v) + if err != nil { + internalServerError(rw, r, err) + return + } + b = append(b, '\n') + + rw.Header().Set("Content-Type", "application/json") + rw.Write(b) +} diff --git a/srv/cmd/mediocre-blog/main.go b/srv/cmd/mediocre-blog/main.go index a1b20be..2952999 100644 --- a/srv/cmd/mediocre-blog/main.go +++ b/srv/cmd/mediocre-blog/main.go @@ -4,20 +4,54 @@ import ( "flag" "log" "net/http" + "strconv" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/pow" + "github.com/tilinna/clock" ) func main() { staticDir := flag.String("static-dir", "", "Directory from which static files are served") - //redisAddr := flag.String("redis-addr", "127.0.0.1:6379", "Address which redis is listening on") listenAddr := flag.String("listen-addr", ":4000", "Address to listen for HTTP requests on") + powTargetStr := flag.String("pow-target", "0x000FFFF", "Proof-of-work target, lower is more difficult") + powSecret := flag.String("pow-secret", "", "Secret used to sign proof-of-work challenge seeds") + + // parse config + flag.Parse() - if *staticDir == "" { + switch { + case *staticDir == "": log.Fatal("-static-dir is required") + case *powSecret == "": + log.Fatal("-pow-secret is required") + } + + powTargetUint, err := strconv.ParseUint(*powTargetStr, 0, 32) + if err != nil { + log.Fatalf("parsing -pow-target: %v", err) } + powTarget := uint32(powTargetUint) + + // initialization + + clock := clock.Realtime() + + powStore := pow.NewMemoryStore(clock) + defer powStore.Close() + + mgr := pow.NewManager(pow.ManagerParams{ + Clock: clock, + Store: powStore, + Secret: []byte(*powSecret), + Target: powTarget, + }) mux := http.NewServeMux() mux.Handle("/", http.FileServer(http.Dir(*staticDir))) + mux.Handle("/api/pow/challenge", newPowChallengeHandler(mgr)) + + // run log.Printf("listening on %q", *listenAddr) log.Fatal(http.ListenAndServe(*listenAddr, mux)) diff --git a/srv/cmd/mediocre-blog/pow.go b/srv/cmd/mediocre-blog/pow.go new file mode 100644 index 0000000..22b82f3 --- /dev/null +++ b/srv/cmd/mediocre-blog/pow.go @@ -0,0 +1,22 @@ +package main + +import ( + "encoding/hex" + "net/http" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/pow" +) + +func newPowChallengeHandler(mgr pow.Manager) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + challenge := mgr.NewChallenge() + + jsonResult(rw, r, struct { + Seed string `json:"seed"` + Target uint32 `json:"target"` + }{ + Seed: hex.EncodeToString(challenge.Seed), + Target: challenge.Target, + }) + }) +} diff --git a/srv/go.mod b/srv/go.mod index 67f424e..f14c154 100644 --- a/srv/go.mod +++ b/srv/go.mod @@ -1,4 +1,4 @@ -module blog.mediocregopher.com/srv +module github.com/mediocregopher/blog.mediocregopher.com/srv go 1.16 diff --git a/srv/pow/pow.go b/srv/pow/pow.go new file mode 100644 index 0000000..3de1450 --- /dev/null +++ b/srv/pow/pow.go @@ -0,0 +1,288 @@ +// Package pow creates proof-of-work challenges and validates their solutions. +package pow + +import ( + "bytes" + "crypto/hmac" + "crypto/md5" + "crypto/rand" + "crypto/sha512" + "encoding/binary" + "errors" + "fmt" + "hash" + "time" + + "github.com/tilinna/clock" +) + +type challengeParams struct { + Target uint32 + ExpiresAt int64 + Random []byte +} + +func (c challengeParams) MarshalBinary() ([]byte, error) { + buf := new(bytes.Buffer) + + var err error + write := func(v interface{}) { + if err != nil { + return + } + err = binary.Write(buf, binary.BigEndian, v) + } + + write(c.Target) + write(c.ExpiresAt) + + if err != nil { + return nil, err + } + + if _, err := buf.Write(c.Random); err != nil { + panic(err) + } + + return buf.Bytes(), nil +} + +func (c *challengeParams) UnmarshalBinary(b []byte) error { + buf := bytes.NewBuffer(b) + + var err error + read := func(into interface{}) { + if err != nil { + return + } + err = binary.Read(buf, binary.BigEndian, into) + } + + read(&c.Target) + read(&c.ExpiresAt) + + if buf.Len() > 0 { + c.Random = buf.Bytes() // whatever is left + } + + return err +} + +// The seed takes the form: +// +// (version)+(signature of challengeParams)+(challengeParams) +// +// Version is currently always 0. +func newSeed(c challengeParams, secret []byte) ([]byte, error) { + buf := new(bytes.Buffer) + buf.WriteByte(0) // version + + cb, err := c.MarshalBinary() + if err != nil { + return nil, err + } + + h := hmac.New(md5.New, secret) + h.Write(cb) + buf.Write(h.Sum(nil)) + + buf.Write(cb) + + return buf.Bytes(), nil +} + +var errMalformedSeed = errors.New("malformed seed") + +func challengeParamsFromSeed(seed, secret []byte) (challengeParams, error) { + h := hmac.New(md5.New, secret) + hSize := h.Size() + + if len(seed) < hSize+1 || seed[0] != 0 { + return challengeParams{}, errMalformedSeed + } + seed = seed[1:] + + sig, cb := seed[:hSize], seed[hSize:] + + // check signature + h.Write(cb) + if !hmac.Equal(sig, h.Sum(nil)) { + return challengeParams{}, errMalformedSeed + } + + var c challengeParams + if err := c.UnmarshalBinary(cb); err != nil { + return challengeParams{}, fmt.Errorf("unmarshaling challenge parameters: %w", err) + } + + return c, nil +} + +// Challenge is a set of fields presented to a client, with which they must +// generate a solution. +// +// Generating a solution is done by: +// +// - Collect up to len(Seed) random bytes. These will be the potential +// solution. +// +// - Calculate the sha512 of the concatenation of Seed and PotentialSolution. +// +// - Parse the first 4 bytes of the sha512 result as a big-endian uint32. +// +// - If the resulting number is _less_ than Target, the solution has been +// found. Otherwise go back to step 1 and try again. +// +type Challenge struct { + Seed []byte + Target uint32 +} + +// Errors which may be produced by a Manager. +var ( + ErrInvalidSolution = errors.New("invalid solution") + ErrExpiredSolution = errors.New("expired solution") +) + +// Manager is used to both produce proof-of-work challenges and check their +// solutions. +type Manager interface { + NewChallenge() Challenge + + // Will produce ErrInvalidSolution if the solution is invalid, or + // ErrExpiredSolution if the solution has expired. + CheckSolution(seed, solution []byte) error +} + +// ManagerParams are used to initialize a new Manager instance. All fields are +// required unless otherwise noted. +type ManagerParams struct { + Clock clock.Clock + Store Store + + // Secret is used to sign each Challenge's Seed, it should _not_ be shared + // with clients. + Secret []byte + + // The Target which Challenges should hit. Lower is more difficult. + // + // Defaults to 0x00FFFFFF + Target uint32 + + // ChallengeTimeout indicates how long before Challenges are considered + // expired and cannot be solved. + // + // Defaults to 1 minute. + ChallengeTimeout time.Duration +} + +func (p ManagerParams) withDefaults() ManagerParams { + if p.Target == 0 { + p.Target = 0x00FFFFFF + } + if p.ChallengeTimeout == 0 { + p.ChallengeTimeout = 1 * time.Minute + } + return p +} + +type manager struct { + params ManagerParams +} + +// NewManager initializes and returns a Manager instance using the given +// parameters. +func NewManager(params ManagerParams) Manager { + return &manager{ + params: params, + } +} + +func (m *manager) NewChallenge() Challenge { + target := m.params.Target + + c := challengeParams{ + Target: target, + ExpiresAt: m.params.Clock.Now().Add(m.params.ChallengeTimeout).Unix(), + Random: make([]byte, 8), + } + + if _, err := rand.Read(c.Random); err != nil { + panic(err) + } + + seed, err := newSeed(c, m.params.Secret) + if err != nil { + panic(err) + } + + return Challenge{ + Seed: seed, + Target: target, + } +} + +// SolutionChecker can be used to check possible Challenge solutions. It will +// cache certain values internally to save on allocations when used in a loop +// (e.g. when generating a solution). +// +// SolutionChecker is not thread-safe. +type SolutionChecker struct { + h hash.Hash // sha512 + sum []byte +} + +// Check returns true if the given bytes are a solution to the given Challenge. +func (s SolutionChecker) Check(challenge Challenge, solution []byte) bool { + if s.h == nil { + s.h = sha512.New() + } + s.h.Reset() + + s.h.Write(challenge.Seed) + s.h.Write(solution) + s.sum = s.h.Sum(s.sum[:0]) + + i := binary.BigEndian.Uint32(s.sum[:4]) + return i < challenge.Target +} + +func (m *manager) CheckSolution(seed, solution []byte) error { + c, err := challengeParamsFromSeed(seed, m.params.Secret) + if err != nil { + return fmt.Errorf("parsing challenge parameters from seed: %w", err) + + } else if c.ExpiresAt <= m.params.Clock.Now().Unix() { + return ErrExpiredSolution + } + + ok := (SolutionChecker{}).Check( + Challenge{Seed: seed, Target: c.Target}, solution, + ) + + if !ok { + return ErrInvalidSolution + } + + expiresAt := time.Unix(c.ExpiresAt, 0) + if err := m.params.Store.MarkSolved(seed, expiresAt.Add(1*time.Minute)); err != nil { + return fmt.Errorf("marking solution as solved: %w", err) + } + + return nil +} + +// Solve returns a solution for the given Challenge. This may take a while. +func Solve(challenge Challenge) []byte { + + chk := SolutionChecker{} + b := make([]byte, len(challenge.Seed)) + + for { + if _, err := rand.Read(b); err != nil { + panic(err) + } else if chk.Check(challenge, b) { + return b + } + } +} diff --git a/srv/pow/pow_test.go b/srv/pow/pow_test.go new file mode 100644 index 0000000..4bc4141 --- /dev/null +++ b/srv/pow/pow_test.go @@ -0,0 +1,120 @@ +package pow + +import ( + "encoding/hex" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/tilinna/clock" +) + +func TestChallengeParams(t *testing.T) { + tests := []challengeParams{ + {}, + { + Target: 1, + ExpiresAt: 3, + }, + { + Target: 2, + ExpiresAt: -10, + Random: []byte{0, 1, 2}, + }, + { + Random: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + } + + t.Run("marshal_unmarshal", func(t *testing.T) { + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + b, err := test.MarshalBinary() + assert.NoError(t, err) + + var c2 challengeParams + assert.NoError(t, c2.UnmarshalBinary(b)) + assert.Equal(t, test, c2) + + b2, err := c2.MarshalBinary() + assert.NoError(t, err) + assert.Equal(t, b, b2) + }) + } + }) + + secret := []byte("shhh") + + t.Run("to_from_seed", func(t *testing.T) { + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + seed, err := newSeed(test, secret) + assert.NoError(t, err) + + // generating seed should be deterministic + seed2, err := newSeed(test, secret) + assert.NoError(t, err) + assert.Equal(t, seed, seed2) + + c, err := challengeParamsFromSeed(seed, secret) + assert.NoError(t, err) + assert.Equal(t, test, c) + }) + } + }) + + t.Run("malformed_seed", func(t *testing.T) { + tests := []string{ + "", + "01", + "0000", + "00374a1ad84d6b7a93e68042c1f850cbb100000000000000000000000000000102030405060708A0", // changed one byte from a good seed + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + seed, err := hex.DecodeString(test) + if err != nil { + panic(err) + } + + _, err = challengeParamsFromSeed(seed, secret) + assert.ErrorIs(t, errMalformedSeed, err) + }) + } + }) +} + +func TestManager(t *testing.T) { + clock := clock.NewMock(time.Now().Truncate(time.Hour)) + + store := NewMemoryStore(clock) + defer store.Close() + + mgr := NewManager(ManagerParams{ + Clock: clock, + Store: store, + Secret: []byte("shhhh"), + Target: 0x00FFFFFF, + ChallengeTimeout: 1 * time.Second, + }) + + { + c := mgr.NewChallenge() + solution := Solve(c) + assert.NoError(t, mgr.CheckSolution(c.Seed, solution)) + + // doing again should fail, the seed should already be marked as solved + assert.ErrorIs(t, mgr.CheckSolution(c.Seed, solution), ErrSeedSolved) + } + + { + c := mgr.NewChallenge() + solution := Solve(c) + clock.Add(2 * time.Second) + assert.ErrorIs(t, mgr.CheckSolution(c.Seed, solution), ErrExpiredSolution) + } + +} diff --git a/srv/pow/store.go b/srv/pow/store.go new file mode 100644 index 0000000..0b5e7d0 --- /dev/null +++ b/srv/pow/store.go @@ -0,0 +1,92 @@ +package pow + +import ( + "errors" + "sync" + "time" + + "github.com/tilinna/clock" +) + +// ErrSeedSolved is used to indicate a seed has already been solved. +var ErrSeedSolved = errors.New("seed already solved") + +// Store is used to track information related to proof-of-work challenges and +// solutions. +type Store interface { + + // MarkSolved will return ErrSeedSolved if the seed was already marked. The + // seed will be cleared from the Store once expiresAt is reached. + MarkSolved(seed []byte, expiresAt time.Time) error + + Close() error +} + +type inMemStore struct { + clock clock.Clock + + m map[string]time.Time + l sync.Mutex + closeCh chan struct{} + spinLoopCh chan struct{} // only used by tests +} + +const inMemStoreGCPeriod = 5 * time.Second + +// NewMemoryStore initializes and returns an in-memory Store implementation. +func NewMemoryStore(clock clock.Clock) Store { + s := &inMemStore{ + clock: clock, + m: map[string]time.Time{}, + closeCh: make(chan struct{}), + spinLoopCh: make(chan struct{}, 1), + } + go s.spin(s.clock.NewTicker(inMemStoreGCPeriod)) + return s +} + +func (s *inMemStore) spin(ticker *clock.Ticker) { + defer ticker.Stop() + + for { + select { + case <-ticker.C: + now := s.clock.Now() + + s.l.Lock() + for seed, expiresAt := range s.m { + if !now.Before(expiresAt) { + delete(s.m, seed) + } + } + s.l.Unlock() + + case <-s.closeCh: + return + } + + select { + case s.spinLoopCh <- struct{}{}: + default: + } + } +} + +func (s *inMemStore) MarkSolved(seed []byte, expiresAt time.Time) error { + seedStr := string(seed) + + s.l.Lock() + defer s.l.Unlock() + + if _, ok := s.m[seedStr]; ok { + return ErrSeedSolved + } + + s.m[seedStr] = expiresAt + return nil +} + +func (s *inMemStore) Close() error { + close(s.closeCh) + return nil +} diff --git a/srv/pow/store_test.go b/srv/pow/store_test.go new file mode 100644 index 0000000..324a40c --- /dev/null +++ b/srv/pow/store_test.go @@ -0,0 +1,52 @@ +package pow + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/tilinna/clock" +) + +func TestStore(t *testing.T) { + clock := clock.NewMock(time.Now().Truncate(time.Hour)) + now := clock.Now() + + s := NewMemoryStore(clock) + defer s.Close() + + seed := []byte{0} + + // mark solved should work + err := s.MarkSolved(seed, now.Add(time.Second)) + assert.NoError(t, err) + + // mark again, should not work + err = s.MarkSolved(seed, now.Add(time.Hour)) + assert.ErrorIs(t, err, ErrSeedSolved) + + // marking a different seed should still work + seed2 := []byte{1} + err = s.MarkSolved(seed2, now.Add(inMemStoreGCPeriod*2)) + assert.NoError(t, err) + err = s.MarkSolved(seed2, now.Add(time.Hour)) + assert.ErrorIs(t, err, ErrSeedSolved) + + now = clock.Add(inMemStoreGCPeriod) + <-s.(*inMemStore).spinLoopCh + + // first one should be markable again, second shouldnt + err = s.MarkSolved(seed, now.Add(time.Second)) + assert.NoError(t, err) + err = s.MarkSolved(seed2, now.Add(time.Hour)) + assert.ErrorIs(t, err, ErrSeedSolved) + + now = clock.Add(inMemStoreGCPeriod) + <-s.(*inMemStore).spinLoopCh + + // now both should be expired + err = s.MarkSolved(seed, now.Add(time.Second)) + assert.NoError(t, err) + err = s.MarkSolved(seed2, now.Add(time.Second)) + assert.NoError(t, err) +}