diff --git a/mdb/mpubsub/pubsub.go b/mdb/mpubsub/pubsub.go index d228a0d..31e23be 100644 --- a/mdb/mpubsub/pubsub.go +++ b/mdb/mpubsub/pubsub.go @@ -9,15 +9,20 @@ import ( "time" "cloud.google.com/go/pubsub" - "github.com/mediocregopher/mediocre-go-lib/m" - "github.com/mediocregopher/mediocre-go-lib/mcfg" + "github.com/mediocregopher/mediocre-go-lib/mctx" "github.com/mediocregopher/mediocre-go-lib/mdb" + "github.com/mediocregopher/mediocre-go-lib/merr" "github.com/mediocregopher/mediocre-go-lib/mlog" + "github.com/mediocregopher/mediocre-go-lib/mrun" oldctx "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) +// TODO this package still uses context.Context in the callback functions +// TODO Consume (and probably BatchConsume) don't properly handle the Client +// being closed. + func isErrAlreadyExists(err error) bool { if err == nil { return false @@ -37,24 +42,39 @@ type PubSub struct { log *mlog.Logger } -// Cfg configures and returns a PubSub instance which will be usable once -// StartRun is called on the passed in Cfg instance. -func Cfg(cfg *mcfg.Cfg) *PubSub { - cfg = cfg.Child("pubsub") - var ps PubSub - ps.gce = mdb.CfgGCE(cfg) - ps.log = m.Log(cfg, &ps) - cfg.Start.Then(func(ctx context.Context) error { +// MNew returns a PubSub instance which will be initialized and configured when +// the start event is triggered on ctx (see mrun.Start). The PubSub instance +// will have Close called on it when the stop event is triggered on ctx (see +// mrun.Stop). +// +// gce is optional and can be passed in if there's an existing gce object which +// should be used, otherwise a new one will be created with mdb.MGCE. +func MNew(ctx mctx.Context, gce *mdb.GCE) *PubSub { + if gce == nil { + gce = mdb.MGCE(ctx, "") + } + + ctx = mctx.ChildOf(ctx, "pubsub") + ps := &PubSub{ + gce: gce, + log: mlog.From(ctx), + } + ps.log.SetKV(ps) + + mrun.OnStart(ctx, func(innerCtx mctx.Context) error { ps.log.Info("connecting to pubsub") var err error - ps.Client, err = pubsub.NewClient(ctx, ps.gce.Project, ps.gce.ClientOptions()...) - return mlog.ErrWithKV(err, &ps) + ps.Client, err = pubsub.NewClient(innerCtx, ps.gce.Project, ps.gce.ClientOptions()...) + return merr.WithKV(err, ps.KV()) }) - return &ps + mrun.OnStop(ctx, func(mctx.Context) error { + return ps.Client.Close() + }) + return ps } // KV implements the mlog.KVer interface -func (ps *PubSub) KV() mlog.KV { +func (ps *PubSub) KV() map[string]interface{} { return ps.gce.KV() } @@ -67,44 +87,43 @@ type Topic struct { // Topic returns, after potentially creating, a topic of the given name func (ps *PubSub) Topic(ctx context.Context, name string, create bool) (*Topic, error) { - kv := mlog.KVerFunc(func() mlog.KV { - return ps.KV().Set("topicName", name) - }) + t := &Topic{ + ps: ps, + name: name, + } - var t *pubsub.Topic var err error if create { - t, err = ps.Client.CreateTopic(ctx, name) + t.topic, err = ps.Client.CreateTopic(ctx, name) if isErrAlreadyExists(err) { - t = ps.Client.Topic(name) + t.topic = ps.Client.Topic(name) } else if err != nil { - return nil, mlog.ErrWithKV(err, kv) + return nil, merr.WithKV(err, t.KV()) } } else { - t = ps.Client.Topic(name) - if exists, err := t.Exists(ctx); err != nil { - return nil, mlog.ErrWithKV(err, kv) + t.topic = ps.Client.Topic(name) + if exists, err := t.topic.Exists(ctx); err != nil { + return nil, merr.WithKV(err, t.KV()) } else if !exists { - return nil, mlog.ErrWithKV(errors.New("topic dne"), kv) + err := merr.New("topic dne") + return nil, merr.WithKV(err, t.KV()) } } - return &Topic{ - ps: ps, - topic: t, - name: name, - }, nil + return t, nil } // KV implements the mlog.KVer interface -func (t *Topic) KV() mlog.KV { - return t.ps.KV().Set("topicName", t.name) +func (t *Topic) KV() map[string]interface{} { + kv := t.ps.KV() + kv["topicName"] = t.name + return kv } // Publish publishes a message with the given data as its body to the Topic func (t *Topic) Publish(ctx context.Context, data []byte) error { _, err := t.topic.Publish(ctx, &Message{Data: data}).Get(ctx) if err != nil { - return mlog.ErrWithKV(err, t) + return merr.WithKV(err, t.KV()) } return nil } @@ -123,39 +142,38 @@ type Subscription struct { // for the Topic func (t *Topic) Subscription(ctx context.Context, name string, create bool) (*Subscription, error) { name = t.name + "_" + name - kv := mlog.KVerFunc(func() mlog.KV { - return t.KV().Set("subName", name) - }) + s := &Subscription{ + topic: t, + name: name, + } - var s *pubsub.Subscription var err error if create { - s, err = t.ps.CreateSubscription(ctx, name, pubsub.SubscriptionConfig{ + s.sub, err = t.ps.CreateSubscription(ctx, name, pubsub.SubscriptionConfig{ Topic: t.topic, }) if isErrAlreadyExists(err) { - s = t.ps.Subscription(name) + s.sub = t.ps.Subscription(name) } else if err != nil { - return nil, mlog.ErrWithKV(err, kv) + return nil, merr.WithKV(err, s.KV()) } } else { - s = t.ps.Subscription(name) - if exists, err := s.Exists(ctx); err != nil { - return nil, mlog.ErrWithKV(err, kv) + s.sub = t.ps.Subscription(name) + if exists, err := s.sub.Exists(ctx); err != nil { + return nil, merr.WithKV(err, s.KV()) } else if !exists { - return nil, mlog.ErrWithKV(errors.New("sub dne"), kv) + err := merr.New("sub dne") + return nil, merr.WithKV(err, s.KV()) } } - return &Subscription{ - topic: t, - sub: s, - name: name, - }, nil + return s, nil } // KV implements the mlog.KVer interface -func (s *Subscription) KV() mlog.KV { - return s.topic.KV().Set("subName", s.name) +func (s *Subscription) KV() map[string]interface{} { + kv := s.topic.KV() + kv["subName"] = s.name + return kv } // ConsumerFunc is a function which messages being consumed will be passed. The @@ -208,7 +226,7 @@ func (s *Subscription) Consume(ctx context.Context, fn ConsumerFunc, opts Consum ok, err := fn(context.Context(innerCtx), msg) if err != nil { - s.topic.ps.log.Warn("error consuming pubsub message", s, mlog.ErrKV(err)) + s.topic.ps.log.Warn("error consuming pubsub message", s, merr.KV(err)) } if ok { @@ -220,7 +238,8 @@ func (s *Subscription) Consume(ctx context.Context, fn ConsumerFunc, opts Consum if octx.Err() == context.Canceled || err == nil { return } else if err != nil { - s.topic.ps.log.Warn("error consuming from pubsub", s, mlog.ErrKV(err)) + s.topic.ps.log.Warn("error consuming from pubsub", s, merr.KV(err)) + time.Sleep(1 * time.Second) } } } @@ -312,7 +331,7 @@ func (s *Subscription) BatchConsume( } ret, err := fn(thisCtx, msgs) if err != nil { - s.topic.ps.log.Warn("error consuming pubsub batch messages", s, mlog.ErrKV(err)) + s.topic.ps.log.Warn("error consuming pubsub batch messages", s, merr.KV(err)) } for i := range thisGroup { thisGroup[i].retCh <- ret // retCh is buffered diff --git a/mdb/mpubsub/pubsub_test.go b/mdb/mpubsub/pubsub_test.go index 7867f7f..e516d26 100644 --- a/mdb/mpubsub/pubsub_test.go +++ b/mdb/mpubsub/pubsub_test.go @@ -5,169 +5,169 @@ import ( . "testing" "time" - "github.com/mediocregopher/mediocre-go-lib/mcfg" - "github.com/mediocregopher/mediocre-go-lib/mdb" "github.com/mediocregopher/mediocre-go-lib/mrand" + "github.com/mediocregopher/mediocre-go-lib/mtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var testPS *PubSub - -func init() { - mdb.DefaultGCEProject = "test" - cfg := mcfg.New() - testPS = Cfg(cfg) - cfg.StartTestRun() -} - // this requires the pubsub emulator to be running func TestPubSub(t *T) { - topicName := "testTopic_" + mrand.Hex(8) - ctx := context.Background() + ctx := mtest.NewCtx() + mtest.SetEnv(ctx, "GCE_PROJECT", "test") + ps := MNew(ctx, nil) + mtest.Run(ctx, t, func() { + topicName := "testTopic_" + mrand.Hex(8) + ctx := context.Background() - // Topic shouldn't exist yet - _, err := testPS.Topic(ctx, topicName, false) - require.Error(t, err) + // Topic shouldn't exist yet + _, err := ps.Topic(ctx, topicName, false) + require.Error(t, err) - // ...so create it - topic, err := testPS.Topic(ctx, topicName, true) - require.NoError(t, err) + // ...so create it + topic, err := ps.Topic(ctx, topicName, true) + require.NoError(t, err) - // Create a subscription and consumer - sub, err := topic.Subscription(ctx, "testSub", true) - require.NoError(t, err) + // Create a subscription and consumer + sub, err := topic.Subscription(ctx, "testSub", true) + require.NoError(t, err) - msgCh := make(chan *Message) - go sub.Consume(ctx, func(ctx context.Context, m *Message) (bool, error) { - msgCh <- m - return true, nil - }, ConsumerOpts{}) - time.Sleep(1 * time.Second) // give consumer time to actually start + msgCh := make(chan *Message) + go sub.Consume(ctx, func(ctx context.Context, m *Message) (bool, error) { + msgCh <- m + return true, nil + }, ConsumerOpts{}) + time.Sleep(1 * time.Second) // give consumer time to actually start - // publish a message and make sure it gets consumed - assert.NoError(t, topic.Publish(ctx, []byte("foo"))) - msg := <-msgCh - assert.Equal(t, []byte("foo"), msg.Data) + // publish a message and make sure it gets consumed + assert.NoError(t, topic.Publish(ctx, []byte("foo"))) + msg := <-msgCh + assert.Equal(t, []byte("foo"), msg.Data) + }) } func TestBatchPubSub(t *T) { - ctx := context.Background() - topicName := "testBatchTopic_" + mrand.Hex(8) - topic, err := testPS.Topic(ctx, topicName, true) - require.NoError(t, err) + ctx := mtest.NewCtx() + mtest.SetEnv(ctx, "GCE_PROJECT", "test") + ps := MNew(ctx, nil) + mtest.Run(ctx, t, func() { - readBatch := func(ch chan []*Message) map[byte]int { - select { - case <-time.After(1 * time.Second): - assert.Fail(t, "waited too long to read batch") - return nil - case mm := <-ch: - ret := map[byte]int{} - for _, m := range mm { - ret[m.Data[0]]++ - } - return ret - } - } + topicName := "testBatchTopic_" + mrand.Hex(8) + topic, err := ps.Topic(ctx, topicName, true) + require.NoError(t, err) - // we use the same sub across the next two sections to ensure that cleanup - // also works - sub, err := topic.Subscription(ctx, "testSub", true) - require.NoError(t, err) - sub.batchTestTrigger = make(chan bool) - - { // no grouping - // Create a subscription and consumer - ctx, cancel := context.WithCancel(ctx) - ch := make(chan []*Message) - go func() { - sub.BatchConsume(ctx, - func(ctx context.Context, mm []*Message) (bool, error) { - ch <- mm - return true, nil - }, - nil, - ConsumerOpts{Concurrent: 5}, - ) - close(ch) - }() - time.Sleep(1 * time.Second) // give consumer time to actually start - - exp := map[byte]int{} - for i := byte(0); i <= 9; i++ { - require.NoError(t, topic.Publish(ctx, []byte{i})) - exp[i] = 1 - } - - time.Sleep(1 * time.Second) - sub.batchTestTrigger <- true - gotA := readBatch(ch) - assert.Len(t, gotA, 5) - - time.Sleep(1 * time.Second) - sub.batchTestTrigger <- true - gotB := readBatch(ch) - assert.Len(t, gotB, 5) - - for i, c := range gotB { - gotA[i] += c - } - assert.Equal(t, exp, gotA) - - time.Sleep(1 * time.Second) // give time to ack before cancelling - cancel() - <-ch - } - - { // with grouping - ctx, cancel := context.WithCancel(ctx) - ch := make(chan []*Message) - go func() { - sub.BatchConsume(ctx, - func(ctx context.Context, mm []*Message) (bool, error) { - ch <- mm - return true, nil - }, - func(a, b *Message) bool { return a.Data[0]%2 == b.Data[0]%2 }, - ConsumerOpts{Concurrent: 10}, - ) - close(ch) - }() - time.Sleep(1 * time.Second) // give consumer time to actually start - - exp := map[byte]int{} - for i := byte(0); i <= 9; i++ { - require.NoError(t, topic.Publish(ctx, []byte{i})) - exp[i] = 1 - } - - time.Sleep(1 * time.Second) - sub.batchTestTrigger <- true - gotA := readBatch(ch) - assert.Len(t, gotA, 5) - gotB := readBatch(ch) - assert.Len(t, gotB, 5) - - assertGotGrouped := func(got map[byte]int) { - prev := byte(255) - for i := range got { - if prev != 255 { - assert.Equal(t, prev%2, i%2) + readBatch := func(ch chan []*Message) map[byte]int { + select { + case <-time.After(1 * time.Second): + assert.Fail(t, "waited too long to read batch") + return nil + case mm := <-ch: + ret := map[byte]int{} + for _, m := range mm { + ret[m.Data[0]]++ } - prev = i + return ret } } - assertGotGrouped(gotA) - assertGotGrouped(gotB) - for i, c := range gotB { - gotA[i] += c - } - assert.Equal(t, exp, gotA) + // we use the same sub across the next two sections to ensure that cleanup + // also works + sub, err := topic.Subscription(ctx, "testSub", true) + require.NoError(t, err) + sub.batchTestTrigger = make(chan bool) - time.Sleep(1 * time.Second) // give time to ack before cancelling - cancel() - <-ch - } + { // no grouping + // Create a subscription and consumer + ctx, cancel := context.WithCancel(ctx) + ch := make(chan []*Message) + go func() { + sub.BatchConsume(ctx, + func(ctx context.Context, mm []*Message) (bool, error) { + ch <- mm + return true, nil + }, + nil, + ConsumerOpts{Concurrent: 5}, + ) + close(ch) + }() + time.Sleep(1 * time.Second) // give consumer time to actually start + + exp := map[byte]int{} + for i := byte(0); i <= 9; i++ { + require.NoError(t, topic.Publish(ctx, []byte{i})) + exp[i] = 1 + } + + time.Sleep(1 * time.Second) + sub.batchTestTrigger <- true + gotA := readBatch(ch) + assert.Len(t, gotA, 5) + + time.Sleep(1 * time.Second) + sub.batchTestTrigger <- true + gotB := readBatch(ch) + assert.Len(t, gotB, 5) + + for i, c := range gotB { + gotA[i] += c + } + assert.Equal(t, exp, gotA) + + time.Sleep(1 * time.Second) // give time to ack before cancelling + cancel() + <-ch + } + + { // with grouping + ctx, cancel := context.WithCancel(ctx) + ch := make(chan []*Message) + go func() { + sub.BatchConsume(ctx, + func(ctx context.Context, mm []*Message) (bool, error) { + ch <- mm + return true, nil + }, + func(a, b *Message) bool { return a.Data[0]%2 == b.Data[0]%2 }, + ConsumerOpts{Concurrent: 10}, + ) + close(ch) + }() + time.Sleep(1 * time.Second) // give consumer time to actually start + + exp := map[byte]int{} + for i := byte(0); i <= 9; i++ { + require.NoError(t, topic.Publish(ctx, []byte{i})) + exp[i] = 1 + } + + time.Sleep(1 * time.Second) + sub.batchTestTrigger <- true + gotA := readBatch(ch) + assert.Len(t, gotA, 5) + gotB := readBatch(ch) + assert.Len(t, gotB, 5) + + assertGotGrouped := func(got map[byte]int) { + prev := byte(255) + for i := range got { + if prev != 255 { + assert.Equal(t, prev%2, i%2) + } + prev = i + } + } + + assertGotGrouped(gotA) + assertGotGrouped(gotB) + for i, c := range gotB { + gotA[i] += c + } + assert.Equal(t, exp, gotA) + + time.Sleep(1 * time.Second) // give time to ack before cancelling + cancel() + <-ch + } + }) }