diff --git a/mtest/mtest.go b/mtest/mtest.go index 326b66f..03f39fa 100644 --- a/mtest/mtest.go +++ b/mtest/mtest.go @@ -6,6 +6,7 @@ import ( crand "crypto/rand" "encoding/hex" "math/rand" + "reflect" "time" ) @@ -27,3 +28,33 @@ func RandHex(n int) string { b := RandBytes(hex.DecodedLen(n)) return hex.EncodeToString(b) } + +// RandElement returns a random element from the given slice. +// +// If a weighting function is given then that function is used to weight each +// element of the slice relative to the others, based on whatever metric and +// scale is desired. The weight function must be able to be called more than +// once on each element. +func RandElement(slice interface{}, weight func(i int) uint64) interface{} { + v := reflect.ValueOf(slice) + l := v.Len() + + if weight == nil { + return v.Index(Rand.Intn(l)).Interface() + } + + var totalWeight uint64 + for i := 0; i < l; i++ { + totalWeight += weight(i) + } + + target := Rand.Int63n(int64(totalWeight)) + for i := 0; i < l; i++ { + w := int64(weight(i)) + target -= w + if target < 0 { + return v.Index(i).Interface() + } + } + panic("should never get here, perhaps the weighting function is inconsistent?") +} diff --git a/mtest/mtest_test.go b/mtest/mtest_test.go index 3f283fd..ae28cb1 100644 --- a/mtest/mtest_test.go +++ b/mtest/mtest_test.go @@ -21,3 +21,39 @@ func TestRandHex(t *T) { // much assert.Len(t, RandHex(16), 16) } + +func TestRandElement(t *T) { + slice := []uint64{1, 2, 3} // values are also each value's weight + total := func() uint64 { + var t uint64 + for i := range slice { + t += slice[i] + } + return t + }() + m := map[uint64]uint64{} + + iterations := 100000 + for i := 0; i < iterations; i++ { + el := RandElement(slice, func(i int) uint64 { return slice[i] }).(uint64) + m[el]++ + } + + for i := range slice { + t.Logf("%d -> %d (%f)", slice[i], m[slice[i]], float64(m[slice[i]])/float64(iterations)) + } + + assertEl := func(i int) { + el, elF := slice[i], float64(slice[i]) + gotRatio := float64(m[el]) / float64(iterations) + expRatio := elF / float64(total) + diff := (gotRatio - expRatio) / expRatio + if diff > 0.1 || diff < -0.1 { + t.Fatalf("ratio of element %d is off: got %f, expected %f (diff:%f)", el, gotRatio, expRatio, diff) + } + } + + for i := range slice { + assertEl(i) + } +}