mtest: implement RandElement
This commit is contained in:
parent
5adbae953b
commit
f9ec4d7bce
@ -6,6 +6,7 @@ import (
|
|||||||
crand "crypto/rand"
|
crand "crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,3 +28,33 @@ func RandHex(n int) string {
|
|||||||
b := RandBytes(hex.DecodedLen(n))
|
b := RandBytes(hex.DecodedLen(n))
|
||||||
return hex.EncodeToString(b)
|
return hex.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RandElement returns a random element from the given slice.
|
||||||
|
//
|
||||||
|
// If a weighting function is given then that function is used to weight each
|
||||||
|
// element of the slice relative to the others, based on whatever metric and
|
||||||
|
// scale is desired. The weight function must be able to be called more than
|
||||||
|
// once on each element.
|
||||||
|
func RandElement(slice interface{}, weight func(i int) uint64) interface{} {
|
||||||
|
v := reflect.ValueOf(slice)
|
||||||
|
l := v.Len()
|
||||||
|
|
||||||
|
if weight == nil {
|
||||||
|
return v.Index(Rand.Intn(l)).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
var totalWeight uint64
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
totalWeight += weight(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
target := Rand.Int63n(int64(totalWeight))
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
w := int64(weight(i))
|
||||||
|
target -= w
|
||||||
|
if target < 0 {
|
||||||
|
return v.Index(i).Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic("should never get here, perhaps the weighting function is inconsistent?")
|
||||||
|
}
|
||||||
|
@ -21,3 +21,39 @@ func TestRandHex(t *T) {
|
|||||||
// much
|
// much
|
||||||
assert.Len(t, RandHex(16), 16)
|
assert.Len(t, RandHex(16), 16)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRandElement(t *T) {
|
||||||
|
slice := []uint64{1, 2, 3} // values are also each value's weight
|
||||||
|
total := func() uint64 {
|
||||||
|
var t uint64
|
||||||
|
for i := range slice {
|
||||||
|
t += slice[i]
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}()
|
||||||
|
m := map[uint64]uint64{}
|
||||||
|
|
||||||
|
iterations := 100000
|
||||||
|
for i := 0; i < iterations; i++ {
|
||||||
|
el := RandElement(slice, func(i int) uint64 { return slice[i] }).(uint64)
|
||||||
|
m[el]++
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range slice {
|
||||||
|
t.Logf("%d -> %d (%f)", slice[i], m[slice[i]], float64(m[slice[i]])/float64(iterations))
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEl := func(i int) {
|
||||||
|
el, elF := slice[i], float64(slice[i])
|
||||||
|
gotRatio := float64(m[el]) / float64(iterations)
|
||||||
|
expRatio := elF / float64(total)
|
||||||
|
diff := (gotRatio - expRatio) / expRatio
|
||||||
|
if diff > 0.1 || diff < -0.1 {
|
||||||
|
t.Fatalf("ratio of element %d is off: got %f, expected %f (diff:%f)", el, gotRatio, expRatio, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range slice {
|
||||||
|
assertEl(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user