mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
154 lines
3.4 KiB
Go
154 lines
3.4 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
|
// Use of this code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file
|
|
|
|
package sampleuv
|
|
|
|
import "golang.org/x/exp/rand"
|
|
|
|
// Weighted provides sampling without replacement from a collection of items with
|
|
// non-uniform probability.
|
|
type Weighted struct {
|
|
weights []float64
|
|
// heap is a weight heap.
|
|
//
|
|
// It keeps a heap-organised sum of remaining
|
|
// index weights that are available to be taken
|
|
// from.
|
|
//
|
|
// Each element holds the sum of weights for
|
|
// the corresponding index, plus the sum of
|
|
// its children's weights; the children of
|
|
// an element i can be found at positions
|
|
// 2*(i+1)-1 and 2*(i+1). The root of the
|
|
// weight heap is at element 0.
|
|
//
|
|
// See comments in container/heap for an
|
|
// explanation of the layout of a heap.
|
|
heap []float64
|
|
rnd *rand.Rand
|
|
}
|
|
|
|
// NewWeighted returns a Weighted for the weights w. If src is nil, rand.Rand is
|
|
// used as the random number generator.
|
|
//
|
|
// Note that sampling from weights with a high variance or overall low absolute
|
|
// value sum may result in problems with numerical stability.
|
|
func NewWeighted(w []float64, src rand.Source) Weighted {
|
|
s := Weighted{
|
|
weights: make([]float64, len(w)),
|
|
heap: make([]float64, len(w)),
|
|
}
|
|
if src != nil {
|
|
s.rnd = rand.New(src)
|
|
}
|
|
s.ReweightAll(w)
|
|
return s
|
|
}
|
|
|
|
// Len returns the number of items held by the Weighted, including items
|
|
// already taken.
|
|
func (s Weighted) Len() int { return len(s.weights) }
|
|
|
|
// Take returns an index from the Weighted with probability proportional
|
|
// to the weight of the item. The weight of the item is then set to zero.
|
|
// Take returns false if there are no items remaining.
|
|
func (s Weighted) Take() (idx int, ok bool) {
|
|
if s.heap[0] == 0 {
|
|
return -1, false
|
|
}
|
|
|
|
var r float64
|
|
if s.rnd == nil {
|
|
r = rand.Float64()
|
|
} else {
|
|
r = s.rnd.Float64()
|
|
}
|
|
|
|
r *= s.heap[0]
|
|
i := 0
|
|
for {
|
|
r -= s.weights[i]
|
|
if r < 0 {
|
|
break // Fall within item i.
|
|
}
|
|
|
|
li := i*2 + 1 // Move to left child.
|
|
// Left node should exist, because r is non-negative,
|
|
// but there could be floating point errors, so we
|
|
// check index explicitly.
|
|
if li >= len(s.heap) {
|
|
break
|
|
}
|
|
|
|
i = li
|
|
|
|
d := s.heap[i]
|
|
if r >= d {
|
|
// If there is enough r to pass left child try to
|
|
// move to the right child.
|
|
r -= d
|
|
ri := i + 1
|
|
|
|
if ri >= len(s.heap) {
|
|
break
|
|
}
|
|
|
|
i = ri
|
|
}
|
|
}
|
|
|
|
s.Reweight(i, 0)
|
|
|
|
return i, true
|
|
}
|
|
|
|
// Reweight sets the weight of item idx to w.
|
|
func (s Weighted) Reweight(idx int, w float64) {
|
|
s.weights[idx] = w
|
|
|
|
// We want to keep the heap state here consistent
|
|
// with the result of a reset call. So we sum
|
|
// weights in the same order, since floating point
|
|
// addition is not associative.
|
|
for {
|
|
w = s.weights[idx]
|
|
|
|
ri := idx*2 + 2
|
|
if ri < len(s.heap) {
|
|
w += s.heap[ri]
|
|
}
|
|
|
|
li := ri - 1
|
|
if li < len(s.heap) {
|
|
w += s.heap[li]
|
|
}
|
|
|
|
s.heap[idx] = w
|
|
|
|
if idx == 0 {
|
|
break
|
|
}
|
|
|
|
idx = (idx - 1) / 2
|
|
}
|
|
}
|
|
|
|
// ReweightAll sets the weight of all items in the Weighted. ReweightAll
|
|
// panics if len(w) != s.Len.
|
|
func (s Weighted) ReweightAll(w []float64) {
|
|
if len(w) != s.Len() {
|
|
panic("floats: length of the slices do not match")
|
|
}
|
|
copy(s.weights, w)
|
|
s.reset()
|
|
}
|
|
|
|
func (s Weighted) reset() {
|
|
copy(s.heap, s.weights)
|
|
for i := len(s.heap) - 1; i > 0; i-- {
|
|
// Sometimes 1-based counting makes sense.
|
|
s.heap[((i+1)>>1)-1] += s.heap[i]
|
|
}
|
|
}
|