package toolkit

import (
	"context"
	"sort"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
)

// ExpectedBlockpoint represents the expectation that Blockpoint will be called
// on a TestBlocker. It is possible both to wait for the Blockpoint call to
// occur and to unblock it once it has occurred.
//
// ExpectedBlockpoint values are created by TestBlocker.ExpectBlockpoint; the
// zero value is not usable.
type ExpectedBlockpoint struct {
	// waitCh is closed by Blockpoint once this expectation has been reached.
	waitCh chan struct{}
	// unblockCh is closed by Unblock to release the blocked Blockpoint call.
	unblockCh chan struct{}
	// unblockOnce guards the close of unblockCh so that Unblock is idempotent.
	// It is a pointer so that every copy of this value-typed struct shares it.
	unblockOnce *sync.Once
}

// Wait blocks until the Blockpoint call has been hit and is itself blocking,
// in which case it returns nil, or until ctx is done, in which case it returns
// ctx.Err().
func (eb ExpectedBlockpoint) Wait(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-eb.waitCh:
		return nil
	}
}

// Unblock releases the Blockpoint call which is/was expected. Unblock may be
// called before Wait (and therefore before the Blockpoint is hit), in which
// case the Blockpoint call will not block at all. Calling Unblock more than
// once is a no-op, so it is safe to call from within an On callback.
func (eb ExpectedBlockpoint) Unblock() {
	eb.unblockOnce.Do(func() { close(eb.unblockCh) })
}

// On is a helper which spawns a goroutine that calls Wait on the
// ExpectedBlockpoint, calls cb, and then Unblocks the ExpectedBlockpoint.
//
// If Wait returns an error (due to context cancellation) then this fails the
// test and returns without calling cb; the ExpectedBlockpoint is still
// unblocked.
//
// NOTE: the spawned goroutine is not waited on. If ctx is never canceled and
// the Blockpoint is never hit, it stays blocked in Wait after the test ends,
// so prefer a ctx that is canceled when the test finishes.
func (eb ExpectedBlockpoint) On(t *testing.T, ctx context.Context, cb func()) {
	go func() {
		// Always release the Blockpoint, even if Wait failed or cb panics.
		defer eb.Unblock()
		if !assert.NoError(t, eb.Wait(ctx)) {
			return
		}
		cb()
	}()
}

// TestBlocker is used as an injected dependency into components, so that tests
// can cause those components to block at specific execution points internally.
// This is useful for testing race conditions between multiple components.
//
// The zero value is ready to use (e.g. via `new`); NewTestBlocker additionally
// registers AssertExpectations as a test cleanup. A nil *TestBlocker will never
// block.
type TestBlocker struct {
	// l guards both maps below.
	l sync.Mutex
	// blockpointsByID holds the not-yet-hit expectations for each id, consumed
	// in the order ExpectBlockpoint registered them.
	blockpointsByID map[string][]ExpectedBlockpoint
	// blocksByID counts, per id, the Blockpoint calls that consumed an
	// expectation.
	blocksByID map[string]int
}

// NewTestBlocker initializes a TestBlocker and registers a Cleanup callback on
// the T which will call AssertExpectations.
func NewTestBlocker(t *testing.T) *TestBlocker {
	b := new(TestBlocker)
	t.Cleanup(func() { b.AssertExpectations(t) })
	return b
}

// Blockpoint will block if and only if TestBlocker is non-nil and
// ExpectBlockpoint has been called with the same ID previously (and that
// expectation has not already been consumed). If the context is canceled
// while blocking then this call will return.
func (b *TestBlocker) Blockpoint(ctx context.Context, id string) {
	if b == nil {
		return
	}

	b.l.Lock()
	pending := b.blockpointsByID[id]
	if len(pending) == 0 {
		b.l.Unlock()
		return
	}
	blockpoint := pending[0]
	b.blockpointsByID[id] = pending[1:]
	b.blocksByID[id]++
	b.l.Unlock()

	// Tell any Wait-ers that the blockpoint has been reached, then block.
	close(blockpoint.waitCh)

	select {
	case <-ctx.Done():
	case <-blockpoint.unblockCh:
	}
}

// ExpectBlockpoint will cause the TestBlocker to block upon the next call to
// Blockpoint using the same id. The returned ExpectedBlockpoint can be used to
// wait until Blockpoint is called, as well as to unblock it. Multiple
// expectations for the same id are consumed in registration order.
func (b *TestBlocker) ExpectBlockpoint(id string) ExpectedBlockpoint {
	b.l.Lock()
	defer b.l.Unlock()

	// Lazily initialize so that a zero-value TestBlocker is usable.
	if b.blockpointsByID == nil {
		b.blockpointsByID = map[string][]ExpectedBlockpoint{}
	}
	if b.blocksByID == nil {
		b.blocksByID = map[string]int{}
	}

	blockpoint := ExpectedBlockpoint{
		waitCh:      make(chan struct{}),
		unblockCh:   make(chan struct{}),
		unblockOnce: new(sync.Once),
	}
	b.blockpointsByID[id] = append(b.blockpointsByID[id], blockpoint)
	return blockpoint
}

// AssertExpectations will fail the test and return false if any calls to
// ExpectBlockpoint have not had a corresponding Blockpoint call made. A nil
// TestBlocker has no expectations and always returns true. Failures are
// reported in sorted id order.
func (b *TestBlocker) AssertExpectations(t *testing.T) bool {
	t.Helper()
	if b == nil {
		return true
	}

	b.l.Lock()
	defer b.l.Unlock()

	// Collect and sort the unmet ids so failure output is deterministic
	// (map iteration order is randomized).
	ids := make([]string, 0, len(b.blockpointsByID))
	for id, pending := range b.blockpointsByID {
		if len(pending) > 0 {
			ids = append(ids, id)
		}
	}
	sort.Strings(ids)

	for _, id := range ids {
		t.Errorf(
			"Blockpoint(%q) called %d times, expected %d more",
			id, b.blocksByID[id], len(b.blockpointsByID[id]),
		)
	}
	return len(ids) == 0
}