// Copyright ©2019 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package kdtree

import (
	"container/heap"
	"fmt"
	"math"
	"sort"
)

// Interface is the set of methods required for construction of efficiently
// searchable k-d trees. A k-d tree may be constructed without using the
// Interface type, but it is likely to have reduced search performance.
type Interface interface {
	// Index returns the ith element of the list of points.
	Index(i int) Comparable

	// Len returns the length of the list.
	Len() int

	// Pivot partitions the list based on the dimension specified and
	// returns the index of the pivot element. The tree builder stores the
	// element at that index in a node and recurses on the elements before
	// it (left subtree) and after it (right subtree).
	Pivot(Dim) int

	// Slice returns a slice of the list using zero-based half
	// open indexing equivalent to built-in slice indexing.
	Slice(start, end int) Interface
}

// Bounder is implemented by lists of points that can report a bounding
// volume containing all of their points. Bounds may return nil.
type Bounder interface {
	// Bounds returns the bounding volume of the list, or nil.
	Bounds() *Bounding
}

// bounder is the combination of Interface and Bounder that New requires in
// order to record bounding volumes while building a tree.
type bounder interface {
	Interface
	Bounder
}

// Dim is an index into a point's coordinates.
type Dim int

// Comparable is the element interface for values stored in a k-d tree.
type Comparable interface {
	// Compare returns the signed distance of a from the plane passing through
	// b and perpendicular to the dimension d.
	//
	// Given c = a.Compare(b, d):
	//  c = a_d - b_d
	//
	Compare(Comparable, Dim) float64

	// Dims returns the number of dimensions described in the Comparable.
	Dims() int

	// Distance returns the squared Euclidean distance between the receiver and
	// the parameter.
	Distance(Comparable) float64
}

// Extender is a Comparable that can increase a bounding volume to include the
// point represented by the Comparable.
type Extender interface {
	Comparable

	// Extend returns a bounding box that has been extended to include the
	// receiver. Extend may return nil. Insert passes a nil *Bounding when it
	// creates a new node, so implementations must accept a nil argument.
	Extend(*Bounding) *Bounding
}

// Bounding represents a volume bounding box.
type Bounding struct {
	Min, Max Comparable
}

// Contains reports whether c lies inside the volume described by b, with the
// boundaries themselves counted as inside. A nil Bounding contains every point.
func (b *Bounding) Contains(c Comparable) bool {
	if b == nil {
		return true
	}
	for axis := Dim(0); axis < Dim(c.Dims()); axis++ {
		belowMin := c.Compare(b.Min, axis) < 0
		if belowMin || c.Compare(b.Max, axis) > 0 {
			return false
		}
	}
	return true
}

// Node holds a single point value in a k-d tree.
type Node struct {
	Point       Comparable
	Plane       Dim
	Left, Right *Node
	*Bounding
}

// String formats the node's point and splitting plane. A nil node formats as
// the empty string.
func (n *Node) String() string {
	if n != nil {
		return fmt.Sprintf("%.3f %d", n.Point, n.Plane)
	}
	return ""
}

// Tree implements a k-d tree creation and nearest neighbor search.
type Tree struct {
	Root  *Node
	Count int
}

// New returns a k-d tree constructed from the values in p. Bounds are recorded
// for each node only when bounding is true and p also implements Bounder.
// The ordering of elements in p may be altered after New returns.
func New(p Interface, bounding bool) *Tree {
	var root *Node
	if bp, ok := p.(bounder); ok && bounding {
		root = buildBounded(bp, 0, bounding)
	} else {
		root = build(p, 0)
	}
	return &Tree{Root: root, Count: p.Len()}
}

// build recursively constructs an unbounded subtree from p, splitting on plane
// at this level and cycling to the next dimension for the children.
func build(p Interface, plane Dim) *Node {
	if p.Len() == 0 {
		return nil
	}

	pivot := p.Pivot(plane)
	node := &Node{Point: p.Index(pivot), Plane: plane}
	next := (plane + 1) % Dim(node.Point.Dims())

	// Elements before the pivot form the left subtree, those after it the right.
	node.Left = build(p.Slice(0, pivot), next)
	node.Right = build(p.Slice(pivot+1, p.Len()), next)
	return node
}

// buildBounded is build for lists that can report their bounds; each node
// additionally stores the bounding volume of the list it was built from.
func buildBounded(p bounder, plane Dim, bounding bool) *Node {
	if p.Len() == 0 {
		return nil
	}

	pivot := p.Pivot(plane)
	point := p.Index(pivot)
	next := (plane + 1) % Dim(point.Dims())
	node := &Node{Point: point, Plane: plane, Bounding: p.Bounds()}

	node.Left = buildBounded(p.Slice(0, pivot).(bounder), next, bounding)
	node.Right = buildBounded(p.Slice(pivot+1, p.Len()).(bounder), next, bounding)
	return node
}

// Insert adds a point to the tree, updating the bounding volumes if bounding is
// true, and the tree is empty or the tree already has bounding volumes stored,
// and c is an Extender. No rebalancing of the tree is performed.
func (t *Tree) Insert(c Comparable, bounding bool) { t.Count++ if t.Root != nil { bounding = t.Root.Bounding != nil } if c, ok := c.(Extender); ok && bounding { t.Root = t.Root.insertBounded(c, 0, bounding) return } else if !ok && t.Root != nil { // If we are not rebounding, mark the tree as non-bounded. t.Root.Bounding = nil } t.Root = t.Root.insert(c, 0) } func (n *Node) insert(c Comparable, d Dim) *Node { if n == nil { return &Node{ Point: c, Plane: d, Bounding: nil, } } d = (n.Plane + 1) % Dim(c.Dims()) if c.Compare(n.Point, n.Plane) <= 0 { n.Left = n.Left.insert(c, d) } else { n.Right = n.Right.insert(c, d) } return n } func (n *Node) insertBounded(c Extender, d Dim, bounding bool) *Node { if n == nil { var b *Bounding if bounding { b = c.Extend(b) } return &Node{ Point: c, Plane: d, Bounding: b, } } if bounding { n.Bounding = c.Extend(n.Bounding) } d = (n.Plane + 1) % Dim(c.Dims()) if c.Compare(n.Point, n.Plane) <= 0 { n.Left = n.Left.insertBounded(c, d, bounding) } else { n.Right = n.Right.insertBounded(c, d, bounding) } return n } // Len returns the number of elements in the tree. func (t *Tree) Len() int { return t.Count } // Contains returns whether a Comparable is in the bounds of the tree. If no bounding has // been constructed Contains returns true. func (t *Tree) Contains(c Comparable) bool { if t.Root.Bounding == nil { return true } return t.Root.Contains(c) } var inf = math.Inf(1) // Nearest returns the nearest value to the query and the distance between them. 
func (t *Tree) Nearest(q Comparable) (Comparable, float64) { if t.Root == nil { return nil, inf } n, dist := t.Root.search(q, inf) if n == nil { return nil, inf } return n.Point, dist } func (n *Node) search(q Comparable, dist float64) (*Node, float64) { if n == nil { return nil, inf } c := q.Compare(n.Point, n.Plane) dist = math.Min(dist, q.Distance(n.Point)) bn := n if c <= 0 { ln, ld := n.Left.search(q, dist) if ld < dist { bn, dist = ln, ld } if c*c < dist { rn, rd := n.Right.search(q, dist) if rd < dist { bn, dist = rn, rd } } return bn, dist } rn, rd := n.Right.search(q, dist) if rd < dist { bn, dist = rn, rd } if c*c < dist { ln, ld := n.Left.search(q, dist) if ld < dist { bn, dist = ln, ld } } return bn, dist } // ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable // is used to mark the end of the heap, so clients should not store nil values except for // this purpose. type ComparableDist struct { Comparable Comparable Dist float64 } // Heap is a max heap sorted on Dist. type Heap []ComparableDist func (h *Heap) Max() ComparableDist { return (*h)[0] } func (h *Heap) Len() int { return len(*h) } func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist } func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) } func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i } // NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep. type NKeeper struct { Heap } // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The // returned NKeeper is able to retain at most n values. func NewNKeeper(n int) *NKeeper { k := NKeeper{make(Heap, 1, n)} k.Heap[0].Dist = inf return &k } // Keep adds c to the heap if its distance is less than the maximum value of the heap. 
If adding // c would increase the size of the heap beyond the initial maximum length, the maximum value of // the heap is dropped. func (k *NKeeper) Keep(c ComparableDist) { if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel. if len(k.Heap) == cap(k.Heap) { heap.Pop(k) } heap.Push(k, c) } } // DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the // query that it is called to Keep. type DistKeeper struct { Heap } // NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d. func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} } // Keep adds c to the heap if its distance is less than or equal to the max value of the heap. func (k *DistKeeper) Keep(c ComparableDist) { if c.Dist <= k.Heap[0].Dist { heap.Push(k, c) } } // Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type. // kd search is guided by the distance stored in the max value of the heap. type Keeper interface { Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap. Max() ComparableDist // Max returns the maximum element of the Keeper. heap.Interface } // NearestSet finds the nearest values to the query accepted by the provided Keeper, k. // k must be able to return a ComparableDist specifying the maximum acceptable distance // when Max() is called, and retains the results of the search in min sorted order after // the call to NearestSet returns. // If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the // maximum distance, NearestSet will remove it before returning. func (t *Tree) NearestSet(k Keeper, q Comparable) { if t.Root == nil { return } t.Root.searchSet(q, k) // Check whether we have retained a sentinel // and flag removal if we have. removeSentinel := k.Len() != 0 && k.Max().Comparable == nil sort.Sort(sort.Reverse(k)) // This abuses the interface to drop the max. 
// It is reasonable to do this because we know // that the maximum value will now be at element // zero, which is removed by the Pop method. if removeSentinel { k.Pop() } } func (n *Node) searchSet(q Comparable, k Keeper) { if n == nil { return } c := q.Compare(n.Point, n.Plane) k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)}) if c <= 0 { n.Left.searchSet(q, k) if c*c <= k.Max().Dist { n.Right.searchSet(q, k) } return } n.Right.searchSet(q, k) if c*c <= k.Max().Dist { n.Left.searchSet(q, k) } } // Operation is a function that operates on a Comparable. The bounding volume and tree depth // of the point is also provided. If done is returned true, the Operation is indicating that no // further work needs to be done and so the Do function should traverse no further. type Operation func(Comparable, *Bounding, int) (done bool) // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort // relationships, future tree operation behaviors are undefined. func (t *Tree) Do(fn Operation) bool { if t.Root == nil { return false } return t.Root.do(fn, 0) } func (n *Node) do(fn Operation, depth int) (done bool) { if n.Left != nil { done = n.Left.do(fn, depth+1) if done { return } } done = fn(n.Point, n.Bounding, depth) if done { return } if n.Right != nil { done = n.Right.do(fn, depth+1) } return } // DoBounded performs fn on all values stored in the tree that are within the specified bound. // If b is nil, the result is the same as a Do. A boolean is returned indicating whether the // DoBounded traversal was interrupted by an Operation returning true. If fn alters stored // values' sort relationships future tree operation behaviors are undefined. 
func (t *Tree) DoBounded(b *Bounding, fn Operation) bool { if t.Root == nil { return false } if b == nil { return t.Root.do(fn, 0) } return t.Root.doBounded(fn, b, 0) } func (n *Node) doBounded(fn Operation, b *Bounding, depth int) (done bool) { if n.Left != nil && b.Min.Compare(n.Point, n.Plane) < 0 { done = n.Left.doBounded(fn, b, depth+1) if done { return } } if b.Contains(n.Point) { done = fn(n.Point, n.Bounding, depth) if done { return } } if n.Right != nil && 0 < b.Max.Compare(n.Point, n.Plane) { done = n.Right.doBounded(fn, b, depth+1) } return }