148 lines
3.8 KiB
Go
148 lines
3.8 KiB
Go
package toolkit
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// ExpectedBlockpoint represents the expectation that Blockpoint will be called
// on a TestBlocker. It is possible to both wait for the Blockpoint call to
// occur and to unblock it once it has occurred.
type ExpectedBlockpoint struct {
	// waitCh is closed by Blockpoint once this expectation has been hit; Wait
	// returns nil as soon as it is closed.
	waitCh chan struct{}

	// unblockCh is closed by Unblock, which releases the Blockpoint call that
	// is (or will be) blocked on this expectation.
	unblockCh chan struct{}
}
|
|
|
|
// Wait will block until blockpoint has been hit and is itself blocking, or will
|
|
// return the context error.
|
|
func (eb ExpectedBlockpoint) Wait(ctx context.Context) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-eb.waitCh:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Unblock unblocks the Blockpoint call which is/was expected. If Unblock can be
|
|
// called prior to Wait being called (and therefore prior to the Blockpoint
|
|
// being hit).
|
|
func (eb ExpectedBlockpoint) Unblock() {
|
|
close(eb.unblockCh)
|
|
}
|
|
|
|
// Then is a helper which will spawn a go-routine, call Wait on the
|
|
// ExpectedBlockpoint, call the given callback, and then Unblock the
|
|
// ExpectedBlockpoint.
|
|
//
|
|
// If Wait returns an error (due to context cancellation) then this fails the
|
|
// test and returns without calling the callback.
|
|
func (eb ExpectedBlockpoint) Then(t *testing.T, ctx context.Context, cb func()) {
|
|
go func() {
|
|
defer eb.Unblock()
|
|
if !assert.NoError(t, eb.Wait(ctx)) {
|
|
return
|
|
}
|
|
cb()
|
|
}()
|
|
}
|
|
|
|
// TestBlocker is used as an injected dependency into components, so that tests
|
|
// can cause those components to block at specific execution points internally.
|
|
// This is useful for testing race conditions between multiple components.
|
|
//
|
|
// A TestBlocker is initialized using `new`. A nil TestBlocker will never block.
|
|
type TestBlocker struct {
|
|
l sync.Mutex
|
|
blockpointsByID map[string][]ExpectedBlockpoint
|
|
blocksByID map[string]int
|
|
}
|
|
|
|
// NewTestBlocker initializes a TestBlocker and registers a Cleanup callback on
|
|
// the T which will call AssertExpectations.
|
|
func NewTestBlocker(t *testing.T) *TestBlocker {
|
|
b := new(TestBlocker)
|
|
t.Cleanup(func() { b.AssertExpectations(t) })
|
|
return b
|
|
}
|
|
|
|
// Blockpoint will block if and only if TestBlocker is non-nil and Expect has
|
|
// been called with the same ID previously. If the context is canceled while
|
|
// blocking then this call will return.
|
|
func (b *TestBlocker) Blockpoint(ctx context.Context, id string) {
|
|
if b == nil {
|
|
return
|
|
}
|
|
|
|
b.l.Lock()
|
|
|
|
blockpoints := b.blockpointsByID[id]
|
|
if len(blockpoints) == 0 {
|
|
b.l.Unlock()
|
|
return
|
|
}
|
|
|
|
blockpoint, blockpoints := blockpoints[0], blockpoints[1:]
|
|
b.blockpointsByID[id] = blockpoints
|
|
b.blocksByID[id]++
|
|
|
|
b.l.Unlock()
|
|
|
|
close(blockpoint.waitCh)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
case <-blockpoint.unblockCh:
|
|
}
|
|
}
|
|
|
|
// Expect will cause the TestBlocker to block upon the next call to Blockpoint
|
|
// using the same id. The returned ExpectedBlockpoint can be used to wait until
|
|
// Blockpoint is called, as well as to unblock it.
|
|
func (b *TestBlocker) Expect(id string) ExpectedBlockpoint {
|
|
b.l.Lock()
|
|
defer b.l.Unlock()
|
|
|
|
if b.blockpointsByID == nil {
|
|
b.blockpointsByID = map[string][]ExpectedBlockpoint{}
|
|
}
|
|
|
|
if b.blocksByID == nil {
|
|
b.blocksByID = map[string]int{}
|
|
}
|
|
|
|
blockpoint := ExpectedBlockpoint{
|
|
waitCh: make(chan struct{}),
|
|
unblockCh: make(chan struct{}),
|
|
}
|
|
|
|
b.blockpointsByID[id] = append(b.blockpointsByID[id], blockpoint)
|
|
|
|
return blockpoint
|
|
}
|
|
|
|
// AssertExpectations will Fail the test and return false if any calls to Expect
|
|
// have not had a corresponding Blockpoint call made.
|
|
func (b *TestBlocker) AssertExpectations(t *testing.T) bool {
|
|
b.l.Lock()
|
|
defer b.l.Unlock()
|
|
|
|
var failed bool
|
|
for id, blockpoints := range b.blockpointsByID {
|
|
if len(blockpoints) == 0 {
|
|
continue
|
|
}
|
|
|
|
failed = true
|
|
t.Errorf(
|
|
"Blockpoint(%q) called %d times, expected %d more",
|
|
id, b.blocksByID[id], len(blockpoints),
|
|
)
|
|
}
|
|
|
|
return !failed
|
|
}
|