mirror of
https://github.com/gonum/gonum.git
synced 2025-09-27 03:26:04 +08:00
403 lines
11 KiB
Go
403 lines
11 KiB
Go
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package vptree
|
|
|
|
import (
|
|
"container/heap"
|
|
"errors"
|
|
"math"
|
|
"math/rand/v2"
|
|
"sort"
|
|
|
|
"gonum.org/v1/gonum/stat"
|
|
)
|
|
|
|
// Comparable is the interface that every value stored in a vp-tree must
// implement.
type Comparable interface {
	// Distance reports the distance from the receiver to other. The
	// result must behave as a metric, so that for all values a, b and c:
	//
	//  - a.Distance(a) == 0
	//  - a.Distance(b) >= 0
	//  - a.Distance(b) == b.Distance(a)
	//  - a.Distance(b) <= a.Distance(c)+c.Distance(b)
	//
	Distance(other Comparable) float64
}
|
|
|
|
// Point represents a point in a Euclidean k-d space that satisfies the Comparable
|
|
// interface.
|
|
type Point []float64
|
|
|
|
// Distance returns the Euclidean distance between c and the receiver. The concrete
|
|
// type of c must be Point.
|
|
func (p Point) Distance(c Comparable) float64 {
|
|
q := c.(Point)
|
|
var sum float64
|
|
for dim, c := range p {
|
|
d := c - q[dim]
|
|
sum += d * d
|
|
}
|
|
return math.Sqrt(sum)
|
|
}
|
|
|
|
// Node holds a single point value in a vantage point tree.
|
|
type Node struct {
|
|
Point Comparable
|
|
Radius float64
|
|
Closer *Node
|
|
Further *Node
|
|
}
|
|
|
|
// Tree implements a vantage point tree creation and nearest neighbor search.
|
|
type Tree struct {
|
|
Root *Node
|
|
Count int
|
|
}
|
|
|
|
// New returns a vantage point tree constructed from the values in p. The effort
|
|
// parameter specifies how much work should be put into optimizing the choice of
|
|
// vantage point. If effort is one or less, random vantage points are chosen.
|
|
// The order of elements in p will be altered after New returns. The src parameter
|
|
// provides the source of randomness for vantage point selection. If src is nil
|
|
// global rand package functions are used. Points in p must not be infinitely
|
|
// distant.
|
|
func New(p []Comparable, effort int, src rand.Source) (t *Tree, err error) {
|
|
var intn func(int) int
|
|
var shuf func(n int, swap func(i, j int))
|
|
if src == nil {
|
|
intn = rand.IntN
|
|
shuf = rand.Shuffle
|
|
} else {
|
|
rnd := rand.New(src)
|
|
intn = rnd.IntN
|
|
shuf = rnd.Shuffle
|
|
}
|
|
b := builder{work: make([]float64, len(p)), intn: intn, shuf: shuf}
|
|
|
|
defer func() {
|
|
switch r := recover(); r {
|
|
case nil:
|
|
case pointAtInfinity:
|
|
t = nil
|
|
err = pointAtInfinity
|
|
default:
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
t = &Tree{
|
|
Root: b.build(p, effort),
|
|
Count: len(p),
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
var pointAtInfinity = errors.New("vptree: point at infinity")
|
|
|
|
// builder holds the state used while constructing a vp-tree, following the
// simple vp-tree algorithm described in
// http://pnylab.com/papers/vptree/vptree.pdf.
type builder struct {
	// work is scratch space for distances. It is resliced at each step of
	// the build; New allocates it with one slot per input value.
	work []float64

	intn func(n int) int                  // chooses a random index in [0, n)
	shuf func(n int, swap func(i, j int)) // randomly permutes n elements via swap
}
|
|
|
|
func (b *builder) build(s []Comparable, effort int) *Node {
|
|
if len(s) <= 1 {
|
|
if len(s) == 0 {
|
|
return nil
|
|
}
|
|
return &Node{Point: s[0]}
|
|
}
|
|
n := Node{Point: b.selectVantage(s, effort)}
|
|
radius, closer, further := b.partition(n.Point, s)
|
|
n.Radius = radius
|
|
n.Closer = b.build(closer, effort)
|
|
n.Further = b.build(further, effort)
|
|
return &n
|
|
}
|
|
|
|
func (b *builder) selectVantage(s []Comparable, effort int) Comparable {
|
|
if effort <= 1 {
|
|
return s[b.intn(len(s))]
|
|
}
|
|
if effort > len(s) {
|
|
effort = len(s)
|
|
}
|
|
var best Comparable
|
|
bestVar := -1.0
|
|
b.work = b.work[:effort]
|
|
choices := b.random(effort, s)
|
|
for _, p := range choices {
|
|
for i, q := range choices {
|
|
d := p.Distance(q)
|
|
if math.IsInf(d, 0) {
|
|
panic(pointAtInfinity)
|
|
}
|
|
b.work[i] = d
|
|
}
|
|
variance := stat.Variance(b.work, nil)
|
|
if variance > bestVar {
|
|
best, bestVar = p, variance
|
|
}
|
|
}
|
|
if best == nil {
|
|
// This should never be reached.
|
|
panic("vptree: could not find vantage point")
|
|
}
|
|
return best
|
|
}
|
|
|
|
func (b *builder) random(n int, s []Comparable) []Comparable {
|
|
if n >= len(s) {
|
|
n = len(s)
|
|
}
|
|
b.shuf(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] })
|
|
return s[:n]
|
|
}
|
|
|
|
func (b *builder) partition(v Comparable, s []Comparable) (radius float64, closer, further []Comparable) {
|
|
b.work = b.work[:len(s)]
|
|
for i, p := range s {
|
|
d := v.Distance(p)
|
|
if math.IsInf(d, 0) {
|
|
panic(pointAtInfinity)
|
|
}
|
|
b.work[i] = d
|
|
}
|
|
sort.Sort(byDist{dists: b.work, points: s})
|
|
|
|
// Note that this does not conform exactly to the description
|
|
// in the paper which specifies d(p, s) < mu for L; in cases
|
|
// where the median element has a lower indexed element with
|
|
// the same distance from the vantage point, L will include a
|
|
// d(p, s) == mu.
|
|
// The additional work required to satisfy the algorithm is
|
|
// not worth doing as it has no effect on the correctness or
|
|
// performance of the algorithm.
|
|
radius = b.work[len(b.work)/2]
|
|
|
|
if len(b.work) > 1 {
|
|
// Remove vantage if it is present.
|
|
closer = s[1 : len(b.work)/2]
|
|
}
|
|
further = s[len(b.work)/2:]
|
|
return radius, closer, further
|
|
}
|
|
|
|
type byDist struct {
|
|
dists []float64
|
|
points []Comparable
|
|
}
|
|
|
|
func (c byDist) Len() int { return len(c.dists) }
|
|
func (c byDist) Less(i, j int) bool { return c.dists[i] < c.dists[j] }
|
|
func (c byDist) Swap(i, j int) {
|
|
c.dists[i], c.dists[j] = c.dists[j], c.dists[i]
|
|
c.points[i], c.points[j] = c.points[j], c.points[i]
|
|
}
|
|
|
|
// Len returns the number of elements in the tree.
|
|
func (t *Tree) Len() int { return t.Count }
|
|
|
|
var inf = math.Inf(1)
|
|
|
|
// Nearest returns the nearest value to the query and the distance between them.
|
|
func (t *Tree) Nearest(q Comparable) (Comparable, float64) {
|
|
if t.Root == nil {
|
|
return nil, inf
|
|
}
|
|
n, dist := t.Root.search(q, inf)
|
|
if n == nil {
|
|
return nil, inf
|
|
}
|
|
return n.Point, dist
|
|
}
|
|
|
|
func (n *Node) search(q Comparable, dist float64) (*Node, float64) {
|
|
if n == nil {
|
|
return nil, inf
|
|
}
|
|
|
|
d := q.Distance(n.Point)
|
|
dist = math.Min(dist, d)
|
|
|
|
bn := n
|
|
if d < n.Radius {
|
|
cn, cd := n.Closer.search(q, dist)
|
|
if cd < dist {
|
|
bn, dist = cn, cd
|
|
}
|
|
if d+dist >= n.Radius {
|
|
fn, fd := n.Further.search(q, dist)
|
|
if fd < dist {
|
|
bn, dist = fn, fd
|
|
}
|
|
}
|
|
} else {
|
|
fn, fd := n.Further.search(q, dist)
|
|
if fd < dist {
|
|
bn, dist = fn, fd
|
|
}
|
|
if d-dist <= n.Radius {
|
|
cn, cd := n.Closer.search(q, dist)
|
|
if cd < dist {
|
|
bn, dist = cn, cd
|
|
}
|
|
}
|
|
}
|
|
|
|
return bn, dist
|
|
}
|
|
|
|
// ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable
|
|
// is used to mark the end of the heap, so clients should not store nil values except for
|
|
// this purpose.
|
|
type ComparableDist struct {
|
|
Comparable Comparable
|
|
Dist float64
|
|
}
|
|
|
|
// Heap is a max heap sorted on Dist.
|
|
type Heap []ComparableDist
|
|
|
|
func (h *Heap) Max() ComparableDist { return (*h)[0] }
|
|
func (h *Heap) Len() int { return len(*h) }
|
|
func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist }
|
|
func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] }
|
|
func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) }
|
|
func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i }
|
|
|
|
// NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep.
|
|
type NKeeper struct {
|
|
Heap
|
|
}
|
|
|
|
// NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The
|
|
// returned NKeeper is able to retain at most n values.
|
|
func NewNKeeper(n int) *NKeeper {
|
|
k := NKeeper{make(Heap, 1, n)}
|
|
k.Heap[0].Dist = inf
|
|
return &k
|
|
}
|
|
|
|
// Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding
|
|
// c would increase the size of the heap beyond the initial maximum length, the maximum value of
|
|
// the heap is dropped.
|
|
func (k *NKeeper) Keep(c ComparableDist) {
|
|
if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel.
|
|
if len(k.Heap) == cap(k.Heap) {
|
|
heap.Pop(k)
|
|
}
|
|
heap.Push(k, c)
|
|
}
|
|
}
|
|
|
|
// DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the
|
|
// query that it is called to Keep.
|
|
type DistKeeper struct {
|
|
Heap
|
|
}
|
|
|
|
// NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d.
|
|
func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} }
|
|
|
|
// Keep adds c to the heap if its distance is less than or equal to the max value of the heap.
|
|
func (k *DistKeeper) Keep(c ComparableDist) {
|
|
if c.Dist <= k.Heap[0].Dist {
|
|
heap.Push(k, c)
|
|
}
|
|
}
|
|
|
|
// Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type.
|
|
// vantage point search is guided by the distance stored in the max value of the heap.
|
|
type Keeper interface {
|
|
Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap.
|
|
Max() ComparableDist // Max returns the maximum element of the Keeper.
|
|
heap.Interface
|
|
}
|
|
|
|
// NearestSet finds the nearest values to the query accepted by the provided Keeper, k.
|
|
// k must be able to return a ComparableDist specifying the maximum acceptable distance
|
|
// when Max() is called, and retains the results of the search in min sorted order after
|
|
// the call to NearestSet returns.
|
|
// If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the
|
|
// maximum distance, NearestSet will remove it before returning.
|
|
func (t *Tree) NearestSet(k Keeper, q Comparable) {
|
|
if t.Root == nil {
|
|
return
|
|
}
|
|
t.Root.searchSet(q, k)
|
|
|
|
// Check whether we have retained a sentinel
|
|
// and flag removal if we have.
|
|
removeSentinel := k.Len() != 0 && k.Max().Comparable == nil
|
|
|
|
sort.Sort(sort.Reverse(k))
|
|
|
|
// This abuses the interface to drop the max.
|
|
// It is reasonable to do this because we know
|
|
// that the maximum value will now be at element
|
|
// zero, which is removed by the Pop method.
|
|
if removeSentinel {
|
|
k.Pop()
|
|
}
|
|
}
|
|
|
|
func (n *Node) searchSet(q Comparable, k Keeper) {
|
|
if n == nil {
|
|
return
|
|
}
|
|
|
|
k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)})
|
|
|
|
d := q.Distance(n.Point)
|
|
if d < n.Radius {
|
|
n.Closer.searchSet(q, k)
|
|
if d+k.Max().Dist >= n.Radius {
|
|
n.Further.searchSet(q, k)
|
|
}
|
|
} else {
|
|
n.Further.searchSet(q, k)
|
|
if d-k.Max().Dist <= n.Radius {
|
|
n.Closer.searchSet(q, k)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Operation is a function that operates on a Comparable. The bounding volume and tree depth
|
|
// of the point is also provided. If done is returned true, the Operation is indicating that no
|
|
// further work needs to be done and so the Do function should traverse no further.
|
|
type Operation func(Comparable, int) (done bool)
|
|
|
|
// Do performs fn on all values stored in the tree. A boolean is returned indicating whether the
|
|
// Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort
|
|
// relationships, future tree operation behaviors are undefined.
|
|
func (t *Tree) Do(fn Operation) bool {
|
|
if t.Root == nil {
|
|
return false
|
|
}
|
|
return t.Root.do(fn, 0)
|
|
}
|
|
|
|
func (n *Node) do(fn Operation, depth int) (done bool) {
|
|
if n.Closer != nil {
|
|
done = n.Closer.do(fn, depth+1)
|
|
if done {
|
|
return
|
|
}
|
|
}
|
|
done = fn(n.Point, depth)
|
|
if done {
|
|
return
|
|
}
|
|
if n.Further != nil {
|
|
done = n.Further.do(fn, depth+1)
|
|
}
|
|
return
|
|
}
|