// Copyright ©2019 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package vptree import ( "container/heap" "errors" "math" "sort" "golang.org/x/exp/rand" "gonum.org/v1/gonum/stat" ) // Comparable is the element interface for values stored in a vp-tree. type Comparable interface { // Distance returns the distance between the receiver and the // parameter. The returned distance must satisfy the properties // of distances in a metric space. // // - a.Distance(a) == 0 // - a.Distance(b) >= 0 // - a.Distance(b) == b.Distance(a) // - a.Distance(b) <= a.Distance(c)+c.Distance(b) // Distance(Comparable) float64 } // Point represents a point in a Euclidean k-d space that satisfies the Comparable // interface. type Point []float64 // Distance returns the Euclidean distance between c and the receiver. The concrete // type of c must be Point. func (p Point) Distance(c Comparable) float64 { q := c.(Point) var sum float64 for dim, c := range p { d := c - q[dim] sum += d * d } return math.Sqrt(sum) } // Node holds a single point value in a vantage point tree. type Node struct { Point Comparable Radius float64 Closer *Node Further *Node } // Tree implements a vantage point tree creation and nearest neighbor search. type Tree struct { Root *Node Count int } // New returns a vantage point tree constructed from the values in p. The effort // parameter specifies how much work should be put into optimizing the choice of // vantage point. If effort is one or less, random vantage points are chosen. // The order of elements in p will be altered after New returns. The src parameter // provides the source of randomness for vantage point selection. If src is nil // global rand package functions are used. Points in p must not be infinitely // distant. func New(p []Comparable, effort int, src rand.Source) (t *Tree, err error) { var intn func(int) int var shuf func(n int, swap func(i, j int)) if src == nil { intn = rand.Intn shuf = rand.Shuffle } else { rnd := rand.New(src) intn = rnd.Intn shuf = rnd.Shuffle } b := builder{work: make([]float64, len(p)), intn: intn, shuf: shuf} defer func() { switch r := recover(); r { case nil: case pointAtInfinity: t = nil err = pointAtInfinity default: panic(r) } }() t = &Tree{ Root: b.build(p, effort), Count: len(p), } return t, nil } var pointAtInfinity = errors.New("vptree: point at infinity") // builder performs vp-tree construction as described for the simple vp-tree // algorithm in http://pnylab.com/papers/vptree/vptree.pdf. type builder struct { work []float64 intn func(n int) int shuf func(n int, swap func(i, j int)) } func (b *builder) build(s []Comparable, effort int) *Node { if len(s) <= 1 { if len(s) == 0 { return nil } return &Node{Point: s[0]} } n := Node{Point: b.selectVantage(s, effort)} radius, closer, further := b.partition(n.Point, s) n.Radius = radius n.Closer = b.build(closer, effort) n.Further = b.build(further, effort) return &n } func (b *builder) selectVantage(s []Comparable, effort int) Comparable { if effort <= 1 { return s[b.intn(len(s))] } if effort > len(s) { effort = len(s) } var best Comparable var bestVar float64 b.work = b.work[:effort] choices := b.random(effort, s) for _, p := range choices { for i, q := range choices { d := p.Distance(q) if math.IsInf(d, 0) { panic(pointAtInfinity) } b.work[i] = d } variance := stat.Variance(b.work, nil) if variance > bestVar { best, bestVar = p, variance } } if best == nil { // This should never be reached. panic("vptree: could not find vantage point") } return best } func (b *builder) random(n int, s []Comparable) []Comparable { if n >= len(s) { return s } b.shuf(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] }) return s[:n] } func (b *builder) partition(v Comparable, s []Comparable) (radius float64, closer, further []Comparable) { b.work = b.work[:len(s)] for i, p := range s { d := v.Distance(p) if math.IsInf(d, 0) { panic(pointAtInfinity) } b.work[i] = d } sort.Sort(byDist{dists: b.work, points: s}) // Note that this does not conform exactly to the description // in the paper which specifies d(p, s) < mu for L; in cases // where the median element has a lower indexed element with // the same distance from the vantage point, L will include a // d(p, s) == mu. // The additional work required to satisfy the algorithm is // not worth doing as it has no effect on the correctness or // performance of the algorithm. radius = b.work[len(b.work)/2] if len(b.work) > 1 { // Remove vantage if it is present. closer = s[1 : len(b.work)/2] } further = s[len(b.work)/2:] return radius, closer, further } type byDist struct { dists []float64 points []Comparable } func (c byDist) Len() int { return len(c.dists) } func (c byDist) Less(i, j int) bool { return c.dists[i] < c.dists[j] } func (c byDist) Swap(i, j int) { c.dists[i], c.dists[j] = c.dists[j], c.dists[i] c.points[i], c.points[j] = c.points[j], c.points[i] } // Len returns the number of elements in the tree. func (t *Tree) Len() int { return t.Count } var inf = math.Inf(1) // Nearest returns the nearest value to the query and the distance between them. func (t *Tree) Nearest(q Comparable) (Comparable, float64) { if t.Root == nil { return nil, inf } n, dist := t.Root.search(q, inf) if n == nil { return nil, inf } return n.Point, dist } func (n *Node) search(q Comparable, dist float64) (*Node, float64) { if n == nil { return nil, inf } d := q.Distance(n.Point) dist = math.Min(dist, d) bn := n if d < n.Radius { cn, cd := n.Closer.search(q, dist) if cd < dist { bn, dist = cn, cd } if d+dist >= n.Radius { fn, fd := n.Further.search(q, dist) if fd < dist { bn, dist = fn, fd } } } else { fn, fd := n.Further.search(q, dist) if fd < dist { bn, dist = fn, fd } if d-dist <= n.Radius { cn, cd := n.Closer.search(q, dist) if cd < dist { bn, dist = cn, cd } } } return bn, dist } // ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable // is used to mark the end of the heap, so clients should not store nil values except for // this purpose. type ComparableDist struct { Comparable Comparable Dist float64 } // Heap is a max heap sorted on Dist. type Heap []ComparableDist func (h *Heap) Max() ComparableDist { return (*h)[0] } func (h *Heap) Len() int { return len(*h) } func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist } func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) } func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i } // NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep. type NKeeper struct { Heap } // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The // returned NKeeper is able to retain at most n values. func NewNKeeper(n int) *NKeeper { k := NKeeper{make(Heap, 1, n)} k.Heap[0].Dist = inf return &k } // Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding // c would increase the size of the heap beyond the initial maximum length, the maximum value of // the heap is dropped. func (k *NKeeper) Keep(c ComparableDist) { if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel. if len(k.Heap) == cap(k.Heap) { heap.Pop(k) } heap.Push(k, c) } } // DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the // query that it is called to Keep. type DistKeeper struct { Heap } // NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d. func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} } // Keep adds c to the heap if its distance is less than or equal to the max value of the heap. func (k *DistKeeper) Keep(c ComparableDist) { if c.Dist <= k.Heap[0].Dist { heap.Push(k, c) } } // Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type. // vantage point search is guided by the distance stored in the max value of the heap. type Keeper interface { Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap. Max() ComparableDist // Max returns the maximum element of the Keeper. heap.Interface } // NearestSet finds the nearest values to the query accepted by the provided Keeper, k. // k must be able to return a ComparableDist specifying the maximum acceptable distance // when Max() is called, and retains the results of the search in min sorted order after // the call to NearestSet returns. // If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the // maximum distance, NearestSet will remove it before returning. func (t *Tree) NearestSet(k Keeper, q Comparable) { if t.Root == nil { return } t.Root.searchSet(q, k) // Check whether we have retained a sentinel // and flag removal if we have. removeSentinel := k.Len() != 0 && k.Max().Comparable == nil sort.Sort(sort.Reverse(k)) // This abuses the interface to drop the max. // It is reasonable to do this because we know // that the maximum value will now be at element // zero, which is removed by the Pop method. if removeSentinel { k.Pop() } } func (n *Node) searchSet(q Comparable, k Keeper) { if n == nil { return } k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)}) d := q.Distance(n.Point) if d < n.Radius { n.Closer.searchSet(q, k) if d+k.Max().Dist >= n.Radius { n.Further.searchSet(q, k) } } else { n.Further.searchSet(q, k) if d-k.Max().Dist <= n.Radius { n.Closer.searchSet(q, k) } } } // Operation is a function that operates on a Comparable. The bounding volume and tree depth // of the point is also provided. If done is returned true, the Operation is indicating that no // further work needs to be done and so the Do function should traverse no further. type Operation func(Comparable, int) (done bool) // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort // relationships, future tree operation behaviors are undefined. func (t *Tree) Do(fn Operation) bool { if t.Root == nil { return false } return t.Root.do(fn, 0) } func (n *Node) do(fn Operation, depth int) (done bool) { if n.Closer != nil { done = n.Closer.do(fn, depth+1) if done { return } } done = fn(n.Point, depth) if done { return } if n.Further != nil { done = n.Further.do(fn, depth+1) } return }