mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
spatial/kdtree: new package implementing k-d tree storage
This commit is contained in:
8
spatial/kdtree/doc.go
Normal file
8
spatial/kdtree/doc.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package kdtree implements a k-d tree.
|
||||||
|
//
|
||||||
|
// See https://en.wikipedia.org/wiki/K-d_tree for more details of k-d tree functionality.
|
||||||
|
package kdtree // import "gonum.org/v1/gonum/spatial/kdtree"
|
467
spatial/kdtree/kdtree.go
Normal file
467
spatial/kdtree/kdtree.go
Normal file
@@ -0,0 +1,467 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/heap"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Interface is the set of methods required for construction of efficiently
|
||||||
|
// searchable k-d trees. A k-d tree may be constructed without using the
|
||||||
|
// Interface type, but it is likely to have reduced search performance.
|
||||||
|
type Interface interface {
|
||||||
|
// Index returns the ith element of the list of points.
|
||||||
|
Index(i int) Comparable
|
||||||
|
|
||||||
|
// Len returns the length of the list.
|
||||||
|
Len() int
|
||||||
|
|
||||||
|
// Pivot partitions the list based on the dimension specified.
|
||||||
|
Pivot(Dim) int
|
||||||
|
|
||||||
|
// Slice returns a slice of the list using zero-based half
|
||||||
|
// open indexing equivalent to built-in slice indexing.
|
||||||
|
Slice(start, end int) Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounder returns a bounding volume containing the list of points. Bounds may return nil.
|
||||||
|
type Bounder interface {
|
||||||
|
Bounds() *Bounding
|
||||||
|
}
|
||||||
|
|
||||||
|
type bounder interface {
|
||||||
|
Interface
|
||||||
|
Bounder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dim is an index into a point's coordinates.
|
||||||
|
type Dim int
|
||||||
|
|
||||||
|
// Comparable is the element interface for values stored in a k-d tree.
|
||||||
|
type Comparable interface {
|
||||||
|
// Compare returns the signed distance of a from the plane passing through
|
||||||
|
// b and perpendicular to the dimension d.
|
||||||
|
//
|
||||||
|
// Given c = a.Compare(b, d):
|
||||||
|
// c = a_d - b_d
|
||||||
|
//
|
||||||
|
Compare(Comparable, Dim) float64
|
||||||
|
|
||||||
|
// Dims returns the number of dimensions described in the Comparable.
|
||||||
|
Dims() int
|
||||||
|
|
||||||
|
// Distance returns the squared Euclidean distance between the receiver and
|
||||||
|
// the parameter.
|
||||||
|
Distance(Comparable) float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extender is a Comparable that can increase a bounding volume to include the
|
||||||
|
// point represented by the Comparable.
|
||||||
|
type Extender interface {
|
||||||
|
Comparable
|
||||||
|
|
||||||
|
// Extend returns a bounding box that has been extended to include the
|
||||||
|
// receiver. Extend may return nil.
|
||||||
|
Extend(*Bounding) *Bounding
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bounding represents a volume bounding box.
|
||||||
|
type Bounding struct {
|
||||||
|
Min, Max Comparable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains returns whether c is within the volume of the Bounding. A nil Bounding
|
||||||
|
// returns true.
|
||||||
|
func (b *Bounding) Contains(c Comparable) bool {
|
||||||
|
if b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for d := Dim(0); d < Dim(c.Dims()); d++ {
|
||||||
|
if c.Compare(b.Min, d) < 0 || 0 < c.Compare(b.Max, d) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Node holds a single point value in a k-d tree.
|
||||||
|
type Node struct {
|
||||||
|
Point Comparable
|
||||||
|
Plane Dim
|
||||||
|
Left, Right *Node
|
||||||
|
*Bounding
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) String() string {
|
||||||
|
if n == nil {
|
||||||
|
return "<nil>"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.3f %d", n.Point, n.Plane)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tree implements a k-d tree creation and nearest neighbor search.
|
||||||
|
type Tree struct {
|
||||||
|
Root *Node
|
||||||
|
Count int
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a k-d tree constructed from the values in p. If p is a Bounder and
|
||||||
|
// bounding is true, bounds are determined for each node.
|
||||||
|
// The ordering of elements in p may be altered after New returns.
|
||||||
|
func New(p Interface, bounding bool) *Tree {
|
||||||
|
if p, ok := p.(bounder); ok && bounding {
|
||||||
|
return &Tree{
|
||||||
|
Root: buildBounded(p, 0, bounding),
|
||||||
|
Count: p.Len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Tree{
|
||||||
|
Root: build(p, 0),
|
||||||
|
Count: p.Len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func build(p Interface, plane Dim) *Node {
|
||||||
|
if p.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
piv := p.Pivot(plane)
|
||||||
|
d := p.Index(piv)
|
||||||
|
np := (plane + 1) % Dim(d.Dims())
|
||||||
|
|
||||||
|
return &Node{
|
||||||
|
Point: d,
|
||||||
|
Plane: plane,
|
||||||
|
Left: build(p.Slice(0, piv), np),
|
||||||
|
Right: build(p.Slice(piv+1, p.Len()), np),
|
||||||
|
Bounding: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildBounded(p bounder, plane Dim, bounding bool) *Node {
|
||||||
|
if p.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
piv := p.Pivot(plane)
|
||||||
|
d := p.Index(piv)
|
||||||
|
np := (plane + 1) % Dim(d.Dims())
|
||||||
|
|
||||||
|
b := p.Bounds()
|
||||||
|
return &Node{
|
||||||
|
Point: d,
|
||||||
|
Plane: plane,
|
||||||
|
Left: buildBounded(p.Slice(0, piv).(bounder), np, bounding),
|
||||||
|
Right: buildBounded(p.Slice(piv+1, p.Len()).(bounder), np, bounding),
|
||||||
|
Bounding: b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert adds a point to the tree, updating the bounding volumes if bounding is
|
||||||
|
// true, and the tree is empty or the tree already has bounding volumes stored,
|
||||||
|
// and c is an Extender. No rebalancing of the tree is performed.
|
||||||
|
func (t *Tree) Insert(c Comparable, bounding bool) {
|
||||||
|
t.Count++
|
||||||
|
if t.Root != nil {
|
||||||
|
bounding = t.Root.Bounding != nil
|
||||||
|
}
|
||||||
|
if c, ok := c.(Extender); ok && bounding {
|
||||||
|
t.Root = t.Root.insertBounded(c, 0, bounding)
|
||||||
|
return
|
||||||
|
} else if !ok && t.Root != nil {
|
||||||
|
// If we are not rebounding, mark the tree as non-bounded.
|
||||||
|
t.Root.Bounding = nil
|
||||||
|
}
|
||||||
|
t.Root = t.Root.insert(c, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) insert(c Comparable, d Dim) *Node {
|
||||||
|
if n == nil {
|
||||||
|
return &Node{
|
||||||
|
Point: c,
|
||||||
|
Plane: d,
|
||||||
|
Bounding: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d = (n.Plane + 1) % Dim(c.Dims())
|
||||||
|
if c.Compare(n.Point, n.Plane) <= 0 {
|
||||||
|
n.Left = n.Left.insert(c, d)
|
||||||
|
} else {
|
||||||
|
n.Right = n.Right.insert(c, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) insertBounded(c Extender, d Dim, bounding bool) *Node {
|
||||||
|
if n == nil {
|
||||||
|
var b *Bounding
|
||||||
|
if bounding {
|
||||||
|
b = c.Extend(b)
|
||||||
|
}
|
||||||
|
return &Node{
|
||||||
|
Point: c,
|
||||||
|
Plane: d,
|
||||||
|
Bounding: b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bounding {
|
||||||
|
n.Bounding = c.Extend(n.Bounding)
|
||||||
|
}
|
||||||
|
d = (n.Plane + 1) % Dim(c.Dims())
|
||||||
|
if c.Compare(n.Point, n.Plane) <= 0 {
|
||||||
|
n.Left = n.Left.insertBounded(c, d, bounding)
|
||||||
|
} else {
|
||||||
|
n.Right = n.Right.insertBounded(c, d, bounding)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of elements in the tree.
|
||||||
|
func (t *Tree) Len() int { return t.Count }
|
||||||
|
|
||||||
|
// Contains returns whether a Comparable is in the bounds of the tree. If no bounding has
|
||||||
|
// been constructed Contains returns true.
|
||||||
|
func (t *Tree) Contains(c Comparable) bool {
|
||||||
|
if t.Root.Bounding == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return t.Root.Contains(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
var inf = math.Inf(1)
|
||||||
|
|
||||||
|
// Nearest returns the nearest value to the query and the distance between them.
|
||||||
|
func (t *Tree) Nearest(q Comparable) (Comparable, float64) {
|
||||||
|
if t.Root == nil {
|
||||||
|
return nil, inf
|
||||||
|
}
|
||||||
|
n, dist := t.Root.search(q, inf)
|
||||||
|
if n == nil {
|
||||||
|
return nil, inf
|
||||||
|
}
|
||||||
|
return n.Point, dist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) search(q Comparable, dist float64) (*Node, float64) {
|
||||||
|
if n == nil {
|
||||||
|
return nil, inf
|
||||||
|
}
|
||||||
|
|
||||||
|
c := q.Compare(n.Point, n.Plane)
|
||||||
|
dist = math.Min(dist, q.Distance(n.Point))
|
||||||
|
|
||||||
|
bn := n
|
||||||
|
if c <= 0 {
|
||||||
|
ln, ld := n.Left.search(q, dist)
|
||||||
|
if ld < dist {
|
||||||
|
bn, dist = ln, ld
|
||||||
|
}
|
||||||
|
if c*c < dist {
|
||||||
|
rn, rd := n.Right.search(q, dist)
|
||||||
|
if rd < dist {
|
||||||
|
bn, dist = rn, rd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bn, dist
|
||||||
|
}
|
||||||
|
rn, rd := n.Right.search(q, dist)
|
||||||
|
if rd < dist {
|
||||||
|
bn, dist = rn, rd
|
||||||
|
}
|
||||||
|
if c*c < dist {
|
||||||
|
ln, ld := n.Left.search(q, dist)
|
||||||
|
if ld < dist {
|
||||||
|
bn, dist = ln, ld
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bn, dist
|
||||||
|
}
|
||||||
|
|
||||||
|
// ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable
|
||||||
|
// is used to mark the end of the heap, so clients should not store nil values except for
|
||||||
|
// this purpose.
|
||||||
|
type ComparableDist struct {
|
||||||
|
Comparable Comparable
|
||||||
|
Dist float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heap is a max heap sorted on Dist.
|
||||||
|
type Heap []ComparableDist
|
||||||
|
|
||||||
|
func (h *Heap) Max() ComparableDist { return (*h)[0] }
|
||||||
|
func (h *Heap) Len() int { return len(*h) }
|
||||||
|
func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist }
|
||||||
|
func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] }
|
||||||
|
func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) }
|
||||||
|
func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i }
|
||||||
|
|
||||||
|
// NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep.
|
||||||
|
type NKeeper struct {
|
||||||
|
Heap
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The
|
||||||
|
// returned NKeeper is able to retain at most n values.
|
||||||
|
func NewNKeeper(n int) *NKeeper {
|
||||||
|
k := NKeeper{make(Heap, 1, n)}
|
||||||
|
k.Heap[0].Dist = inf
|
||||||
|
return &k
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding
|
||||||
|
// c would increase the size of the heap beyond the initial maximum length, the maximum value of
|
||||||
|
// the heap is dropped.
|
||||||
|
func (k *NKeeper) Keep(c ComparableDist) {
|
||||||
|
if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel.
|
||||||
|
if len(k.Heap) == cap(k.Heap) {
|
||||||
|
heap.Pop(k)
|
||||||
|
}
|
||||||
|
heap.Push(k, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the
|
||||||
|
// query that it is called to Keep.
|
||||||
|
type DistKeeper struct {
|
||||||
|
Heap
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d.
|
||||||
|
func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} }
|
||||||
|
|
||||||
|
// Keep adds c to the heap if its distance is less than or equal to the max value of the heap.
|
||||||
|
func (k *DistKeeper) Keep(c ComparableDist) {
|
||||||
|
if c.Dist <= k.Heap[0].Dist {
|
||||||
|
heap.Push(k, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type.
|
||||||
|
// kd search is guided by the distance stored in the max value of the heap.
|
||||||
|
type Keeper interface {
|
||||||
|
Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap.
|
||||||
|
Max() ComparableDist // Max returns the maximum element of the Keeper.
|
||||||
|
heap.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
// NearestSet finds the nearest values to the query accepted by the provided Keeper, k.
|
||||||
|
// k must be able to return a ComparableDist specifying the maximum acceptable distance
|
||||||
|
// when Max() is called, and retains the results of the search in min sorted order after
|
||||||
|
// the call to NearestSet returns.
|
||||||
|
// If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the
|
||||||
|
// maximum distance, NearestSet will remove it before returning.
|
||||||
|
func (t *Tree) NearestSet(k Keeper, q Comparable) {
|
||||||
|
if t.Root == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Root.searchSet(q, k)
|
||||||
|
|
||||||
|
// Check whether we have retained a sentinel
|
||||||
|
// and flag removal if we have.
|
||||||
|
removeSentinel := k.Len() != 0 && k.Max().Comparable == nil
|
||||||
|
|
||||||
|
sort.Sort(sort.Reverse(k))
|
||||||
|
|
||||||
|
// This abuses the interface to drop the max.
|
||||||
|
// It is reasonable to do this because we know
|
||||||
|
// that the maximum value will now be at element
|
||||||
|
// zero, which is removed by the Pop method.
|
||||||
|
if removeSentinel {
|
||||||
|
k.Pop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) searchSet(q Comparable, k Keeper) {
|
||||||
|
if n == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c := q.Compare(n.Point, n.Plane)
|
||||||
|
k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)})
|
||||||
|
if c <= 0 {
|
||||||
|
n.Left.searchSet(q, k)
|
||||||
|
if c*c <= k.Max().Dist {
|
||||||
|
n.Right.searchSet(q, k)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.Right.searchSet(q, k)
|
||||||
|
if c*c <= k.Max().Dist {
|
||||||
|
n.Left.searchSet(q, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operation is a function that operates on a Comparable. The bounding volume and tree depth
|
||||||
|
// of the point is also provided. If done is returned true, the Operation is indicating that no
|
||||||
|
// further work needs to be done and so the Do function should traverse no further.
|
||||||
|
type Operation func(Comparable, *Bounding, int) (done bool)
|
||||||
|
|
||||||
|
// Do performs fn on all values stored in the tree. A boolean is returned indicating whether the
|
||||||
|
// Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort
|
||||||
|
// relationships, future tree operation behaviors are undefined.
|
||||||
|
func (t *Tree) Do(fn Operation) bool {
|
||||||
|
if t.Root == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t.Root.do(fn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) do(fn Operation, depth int) (done bool) {
|
||||||
|
if n.Left != nil {
|
||||||
|
done = n.Left.do(fn, depth+1)
|
||||||
|
if done {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
done = fn(n.Point, n.Bounding, depth)
|
||||||
|
if done {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n.Right != nil {
|
||||||
|
done = n.Right.do(fn, depth+1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoBounded performs fn on all values stored in the tree that are within the specified bound.
|
||||||
|
// If b is nil, the result is the same as a Do. A boolean is returned indicating whether the
|
||||||
|
// DoBounded traversal was interrupted by an Operation returning true. If fn alters stored
|
||||||
|
// values' sort relationships future tree operation behaviors are undefined.
|
||||||
|
func (t *Tree) DoBounded(b *Bounding, fn Operation) bool {
|
||||||
|
if t.Root == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b == nil {
|
||||||
|
return t.Root.do(fn, 0)
|
||||||
|
}
|
||||||
|
return t.Root.doBounded(fn, b, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) doBounded(fn Operation, b *Bounding, depth int) (done bool) {
|
||||||
|
if n.Left != nil && b.Min.Compare(n.Point, n.Plane) < 0 {
|
||||||
|
done = n.Left.doBounded(fn, b, depth+1)
|
||||||
|
if done {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if b.Contains(n.Point) {
|
||||||
|
done = fn(n.Point, n.Bounding, depth)
|
||||||
|
if done {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if n.Right != nil && 0 < b.Max.Compare(n.Point, n.Plane) {
|
||||||
|
done = n.Right.doBounded(fn, b, depth+1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
77
spatial/kdtree/kdtree_simple_example_test.go
Normal file
77
spatial/kdtree/kdtree_simple_example_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/spatial/kdtree"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleTree() {
|
||||||
|
// Example data from https://en.wikipedia.org/wiki/K-d_tree
|
||||||
|
points := kdtree.Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
|
||||||
|
|
||||||
|
t := kdtree.New(points, false)
|
||||||
|
q := kdtree.Point{8, 7}
|
||||||
|
p, d := t.Nearest(q)
|
||||||
|
fmt.Printf("%v is closest point to %v, d=%f\n", p, q, math.Sqrt(d))
|
||||||
|
// Output:
|
||||||
|
// [9 6] is closest point to [8 7], d=1.414214
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleTree_bounds() {
|
||||||
|
// Example data from https://en.wikipedia.org/wiki/K-d_tree
|
||||||
|
points := kdtree.Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
|
||||||
|
|
||||||
|
t := kdtree.New(points, true)
|
||||||
|
fmt.Printf("Bounding box of points is %+v\n", t.Root.Bounding)
|
||||||
|
// Output:
|
||||||
|
// Bounding box of points is &{Min:[2 1] Max:[9 7]}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleTree_Do() {
|
||||||
|
// Example data from https://en.wikipedia.org/wiki/K-d_tree
|
||||||
|
points := kdtree.Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
|
||||||
|
|
||||||
|
// Print all points in the data set within 3 of (3, 5).
|
||||||
|
t := kdtree.New(points, false)
|
||||||
|
q := kdtree.Point{3, 5}
|
||||||
|
t.Do(func(c kdtree.Comparable, _ *kdtree.Bounding, _ int) (done bool) {
|
||||||
|
// Compare each distance and output points
|
||||||
|
// with a Euclidean distance less than 3.
|
||||||
|
// Distance returns the square of the
|
||||||
|
// Euclidean distance between points.
|
||||||
|
if q.Distance(c) <= 3*3 {
|
||||||
|
fmt.Println(c)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
})
|
||||||
|
// Unordered output:
|
||||||
|
// [2 3]
|
||||||
|
// [4 7]
|
||||||
|
// [5 4]
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleTree_DoBounded() {
|
||||||
|
// Example data from https://en.wikipedia.org/wiki/K-d_tree
|
||||||
|
points := kdtree.Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
|
||||||
|
|
||||||
|
// Find all points within the bounding box ((3, 3), (6, 8))
|
||||||
|
// and print them with their bounding boxes and tree depth.
|
||||||
|
t := kdtree.New(points, true) // Construct tree with bounding boxes.
|
||||||
|
b := &kdtree.Bounding{
|
||||||
|
Min: kdtree.Point{3, 3},
|
||||||
|
Max: kdtree.Point{6, 8},
|
||||||
|
}
|
||||||
|
t.DoBounded(b, func(c kdtree.Comparable, bound *kdtree.Bounding, depth int) (done bool) {
|
||||||
|
fmt.Printf("p=%v bound=%+v depth=%d\n", c, bound, depth)
|
||||||
|
return
|
||||||
|
})
|
||||||
|
// Output:
|
||||||
|
// p=[5 4] bound=&{Min:[2 3] Max:[5 7]} depth=1
|
||||||
|
// p=[4 7] bound=&{Min:[4 7] Max:[4 7]} depth=2
|
||||||
|
}
|
666
spatial/kdtree/kdtree_test.go
Normal file
666
spatial/kdtree/kdtree_test.go
Normal file
@@ -0,0 +1,666 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
genDot = flag.Bool("dot", false, "generate dot code for failing trees")
|
||||||
|
dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572.
|
||||||
|
wpData = Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
|
||||||
|
nbWpData = nbPoints{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
|
||||||
|
wpBound = &Bounding{Point{2, 1}, Point{9, 7}}
|
||||||
|
)
|
||||||
|
|
||||||
|
var newTests = []struct {
|
||||||
|
data Interface
|
||||||
|
bounding bool
|
||||||
|
wantBounds *Bounding
|
||||||
|
}{
|
||||||
|
{data: wpData, bounding: false, wantBounds: nil},
|
||||||
|
{data: nbWpData, bounding: false, wantBounds: nil},
|
||||||
|
{data: wpData, bounding: true, wantBounds: wpBound},
|
||||||
|
{data: nbWpData, bounding: true, wantBounds: nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
for i, test := range newTests {
|
||||||
|
var tree *Tree
|
||||||
|
var panicked bool
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
panicked = true
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
tree = New(test.data, test.bounding)
|
||||||
|
}()
|
||||||
|
if panicked {
|
||||||
|
t.Errorf("unexpected panic for test %d", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tree.Root.isKDTree() {
|
||||||
|
t.Errorf("tree %d is not k-d tree", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch data := test.data.(type) {
|
||||||
|
case Points:
|
||||||
|
for _, p := range data {
|
||||||
|
if !tree.Contains(p) {
|
||||||
|
t.Errorf("failed to find point %.3f in test %d", p, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case nbPoints:
|
||||||
|
for _, p := range data {
|
||||||
|
if !tree.Contains(p) {
|
||||||
|
t.Errorf("failed to find point %.3f in test %d", p, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Fatalf("bad test: unknown data type: %T", test.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) {
|
||||||
|
t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v",
|
||||||
|
i, test.data, tree.Root.Bounding, test.wantBounds)
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Failed() && *genDot && tree.Len() <= *dotLimit {
|
||||||
|
err := dotFile(tree, fmt.Sprintf("TestNew%T", test.data), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write DOT file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var insertTests = []struct {
|
||||||
|
data Interface
|
||||||
|
insert []Comparable
|
||||||
|
wantBounds *Bounding
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
data: wpData,
|
||||||
|
insert: []Comparable{Point{0, 0}, Point{10, 10}},
|
||||||
|
wantBounds: &Bounding{Point{0, 0}, Point{10, 10}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
data: nbWpData,
|
||||||
|
insert: []Comparable{nbPoint{0, 0}, nbPoint{10, 10}},
|
||||||
|
wantBounds: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInsert(t *testing.T) {
|
||||||
|
for i, test := range insertTests {
|
||||||
|
tree := New(test.data, true)
|
||||||
|
for _, v := range test.insert {
|
||||||
|
tree.Insert(v, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tree.Root.isKDTree() {
|
||||||
|
t.Errorf("tree %d is not k-d tree", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) {
|
||||||
|
t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v",
|
||||||
|
i, test.data, tree.Root.Bounding, test.wantBounds)
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Failed() && *genDot && tree.Len() <= *dotLimit {
|
||||||
|
err := dotFile(tree, fmt.Sprintf("TestInsert%T", test.data), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write DOT file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type compFn func(float64) bool
|
||||||
|
|
||||||
|
func left(v float64) bool { return v <= 0 }
|
||||||
|
func right(v float64) bool { return !left(v) }
|
||||||
|
|
||||||
|
func (n *Node) isKDTree() bool {
|
||||||
|
if n == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
d := n.Point.Dims()
|
||||||
|
// Together these define the property of minimal orthogonal bounding.
|
||||||
|
if !(n.isContainedBy(n.Bounding) && n.Bounding.planesHaveCoincidentPointsIn(n, [2][]bool{make([]bool, d), make([]bool, d)})) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !n.Left.isPartitioned(n.Point, left, n.Plane) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !n.Right.isPartitioned(n.Point, right, n.Plane) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return n.Left.isKDTree() && n.Right.isKDTree()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) isPartitioned(pivot Comparable, fn compFn, plane Dim) bool {
|
||||||
|
if n == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if n.Left != nil && fn(pivot.Compare(n.Left.Point, plane)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if n.Right != nil && fn(pivot.Compare(n.Right.Point, plane)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return n.Left.isPartitioned(pivot, fn, plane) && n.Right.isPartitioned(pivot, fn, plane)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) isContainedBy(b *Bounding) bool {
|
||||||
|
if n == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !b.Contains(n.Point) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return n.Left.isContainedBy(b) && n.Right.isContainedBy(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bounding) planesHaveCoincidentPointsIn(n *Node, tight [2][]bool) bool {
|
||||||
|
if b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if n == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
b.planesHaveCoincidentPointsIn(n.Left, tight)
|
||||||
|
b.planesHaveCoincidentPointsIn(n.Right, tight)
|
||||||
|
|
||||||
|
var ok = true
|
||||||
|
for i := range tight {
|
||||||
|
for d := 0; d < n.Point.Dims(); d++ {
|
||||||
|
if c := n.Point.Compare(b.Min, Dim(d)); c == 0 {
|
||||||
|
tight[i][d] = true
|
||||||
|
}
|
||||||
|
ok = ok && tight[i][d]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func nearest(q Point, p Points) (Point, float64) {
|
||||||
|
min := q.Distance(p[0])
|
||||||
|
var r int
|
||||||
|
for i := 1; i < p.Len(); i++ {
|
||||||
|
d := q.Distance(p[i])
|
||||||
|
if d < min {
|
||||||
|
min = d
|
||||||
|
r = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p[r], min
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNearestRandom(t *testing.T) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
const (
|
||||||
|
min = 0.0
|
||||||
|
max = 1000.0
|
||||||
|
|
||||||
|
dims = 4
|
||||||
|
setSize = 10000
|
||||||
|
)
|
||||||
|
|
||||||
|
var randData Points
|
||||||
|
for i := 0; i < setSize; i++ {
|
||||||
|
p := make(Point, dims)
|
||||||
|
for j := 0; j < dims; j++ {
|
||||||
|
p[j] = (max-min)*rnd.Float64() + min
|
||||||
|
}
|
||||||
|
randData = append(randData, p)
|
||||||
|
}
|
||||||
|
tree := New(randData, false)
|
||||||
|
|
||||||
|
for i := 0; i < setSize; i++ {
|
||||||
|
q := make(Point, dims)
|
||||||
|
for j := 0; j < dims; j++ {
|
||||||
|
q[j] = (max-min)*rnd.Float64() + min
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := tree.Nearest(q)
|
||||||
|
want, _ := nearest(q, randData)
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNearest(t *testing.T) {
|
||||||
|
tree := New(wpData, false)
|
||||||
|
for _, q := range append([]Point{
|
||||||
|
{4, 6},
|
||||||
|
{7, 5},
|
||||||
|
{8, 7},
|
||||||
|
{6, -5},
|
||||||
|
{1e5, 1e5},
|
||||||
|
{1e5, -1e5},
|
||||||
|
{-1e5, 1e5},
|
||||||
|
{-1e5, -1e5},
|
||||||
|
{1e5, 0},
|
||||||
|
{0, -1e5},
|
||||||
|
{0, 1e5},
|
||||||
|
{-1e5, 0},
|
||||||
|
}, wpData...) {
|
||||||
|
gotP, gotD := tree.Nearest(q)
|
||||||
|
wantP, wantD := nearest(q, wpData)
|
||||||
|
if !reflect.DeepEqual(gotP, wantP) {
|
||||||
|
t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
|
||||||
|
}
|
||||||
|
if gotD != wantD {
|
||||||
|
t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func nearestN(n int, q Point, p Points) []ComparableDist {
|
||||||
|
nk := NewNKeeper(n)
|
||||||
|
for i := 0; i < p.Len(); i++ {
|
||||||
|
nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])})
|
||||||
|
}
|
||||||
|
if len(nk.Heap) == 1 {
|
||||||
|
return nk.Heap
|
||||||
|
}
|
||||||
|
sort.Sort(nk)
|
||||||
|
for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i]
|
||||||
|
}
|
||||||
|
return nk.Heap
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNearestSetN(t *testing.T) {
|
||||||
|
data := append([]Point{
|
||||||
|
{4, 6},
|
||||||
|
{7, 5},
|
||||||
|
{8, 7},
|
||||||
|
{6, -5},
|
||||||
|
{1e5, 1e5},
|
||||||
|
{1e5, -1e5},
|
||||||
|
{-1e5, 1e5},
|
||||||
|
{-1e5, -1e5},
|
||||||
|
{1e5, 0},
|
||||||
|
{0, -1e5},
|
||||||
|
{0, 1e5},
|
||||||
|
{-1e5, 0}},
|
||||||
|
wpData[:len(wpData)-1]...)
|
||||||
|
|
||||||
|
tree := New(wpData, false)
|
||||||
|
for k := 1; k <= len(wpData); k++ {
|
||||||
|
for _, q := range data {
|
||||||
|
wantP := nearestN(k, q, wpData)
|
||||||
|
|
||||||
|
nk := NewNKeeper(k)
|
||||||
|
tree.NearestSet(nk, q)
|
||||||
|
|
||||||
|
var max float64
|
||||||
|
wantD := make(map[float64]map[string]struct{})
|
||||||
|
for _, p := range wantP {
|
||||||
|
if p.Dist > max {
|
||||||
|
max = p.Dist
|
||||||
|
}
|
||||||
|
d, ok := wantD[p.Dist]
|
||||||
|
if !ok {
|
||||||
|
d = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
d[fmt.Sprint(p.Comparable)] = struct{}{}
|
||||||
|
wantD[p.Dist] = d
|
||||||
|
}
|
||||||
|
gotD := make(map[float64]map[string]struct{})
|
||||||
|
for _, p := range nk.Heap {
|
||||||
|
if p.Dist > max {
|
||||||
|
t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max)
|
||||||
|
}
|
||||||
|
d, ok := gotD[p.Dist]
|
||||||
|
if !ok {
|
||||||
|
d = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
d[fmt.Sprint(p.Comparable)] = struct{}{}
|
||||||
|
gotD[p.Dist] = d
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the available number of slots does not fit all the coequal furthest points
|
||||||
|
// we will fail the check. So remove, but check them minimally here.
|
||||||
|
if !reflect.DeepEqual(wantD[max], gotD[max]) {
|
||||||
|
// The best we can do at this stage is confirm that there are an equal number of matches at this distance.
|
||||||
|
if len(gotD[max]) != len(wantD[max]) {
|
||||||
|
t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max]))
|
||||||
|
}
|
||||||
|
delete(wantD, max)
|
||||||
|
delete(gotD, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(gotD, wantD) {
|
||||||
|
t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var nearestSetDistTests = []Point{
|
||||||
|
{4, 6},
|
||||||
|
{7, 5},
|
||||||
|
{8, 7},
|
||||||
|
{6, -5},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNearestSetDist(t *testing.T) {
|
||||||
|
tree := New(wpData, false)
|
||||||
|
for i, q := range nearestSetDistTests {
|
||||||
|
for d := 1.0; d < 100; d += 0.1 {
|
||||||
|
dk := NewDistKeeper(d)
|
||||||
|
tree.NearestSet(dk, q)
|
||||||
|
|
||||||
|
hits := make(map[string]float64)
|
||||||
|
for _, p := range wpData {
|
||||||
|
hits[fmt.Sprint(p)] = p.Distance(q)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range dk.Heap {
|
||||||
|
var done bool
|
||||||
|
if p.Comparable == nil {
|
||||||
|
done = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(hits, fmt.Sprint(p.Comparable))
|
||||||
|
if done {
|
||||||
|
t.Error("expectedly finished heap iteration")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
dist := p.Comparable.Distance(q)
|
||||||
|
if dist > d {
|
||||||
|
t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for p, dist := range hits {
|
||||||
|
if dist <= d {
|
||||||
|
t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDo(t *testing.T) {
|
||||||
|
tree := New(wpData, false)
|
||||||
|
var got Points
|
||||||
|
fn := func(c Comparable, _ *Bounding, _ int) (done bool) {
|
||||||
|
got = append(got, c.(Point))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
killed := tree.Do(fn)
|
||||||
|
if !reflect.DeepEqual(got, wpData) {
|
||||||
|
t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, wpData)
|
||||||
|
}
|
||||||
|
if killed {
|
||||||
|
t.Error("tree iteration unexpectedly killed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var doBoundedTests = []struct {
|
||||||
|
bounds *Bounding
|
||||||
|
want Points
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
bounds: nil,
|
||||||
|
want: wpData,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{0, 0}, Point{10, 10}},
|
||||||
|
want: wpData,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{3, 4}, Point{10, 10}},
|
||||||
|
want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{3, 3}, Point{10, 10}},
|
||||||
|
want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{0, 0}, Point{6, 5}},
|
||||||
|
want: Points{Point{2, 3}, Point{5, 4}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{5, 2}, Point{7, 4}},
|
||||||
|
want: Points{Point{5, 4}, Point{7, 2}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{2, 2}, Point{7, 4}},
|
||||||
|
want: Points{Point{2, 3}, Point{5, 4}, Point{7, 2}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{2, 3}, Point{9, 6}},
|
||||||
|
want: Points{Point{2, 3}, Point{5, 4}, Point{9, 6}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
bounds: &Bounding{Point{7, 2}, Point{7, 2}},
|
||||||
|
want: Points{Point{7, 2}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoBounded(t *testing.T) {
|
||||||
|
for _, test := range doBoundedTests {
|
||||||
|
tree := New(wpData, false)
|
||||||
|
var got Points
|
||||||
|
fn := func(c Comparable, _ *Bounding, _ int) (done bool) {
|
||||||
|
got = append(got, c.(Point))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
killed := tree.DoBounded(test.bounds, fn)
|
||||||
|
if !reflect.DeepEqual(got, test.want) {
|
||||||
|
t.Errorf("unexpected result from bounded tree iteration: got:%v want:%v", got, test.want)
|
||||||
|
}
|
||||||
|
if killed {
|
||||||
|
t.Error("tree iteration unexpectedly killed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNew(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
p := make(Points, 1e5)
|
||||||
|
for i := range p {
|
||||||
|
p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = New(p, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNewBounds(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
p := make(Points, 1e5)
|
||||||
|
for i := range p {
|
||||||
|
p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = New(p, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkInsert(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
t := &Tree{}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkInsertBounds(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
t := &Tree{}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Benchmark(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
data := make(Points, 1e2)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
|
||||||
|
}
|
||||||
|
tree := New(data, true)
|
||||||
|
|
||||||
|
if !tree.Root.isKDTree() {
|
||||||
|
b.Fatal("tree is not k-d tree")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 1e3; i++ {
|
||||||
|
q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
|
||||||
|
gotP, gotD := tree.Nearest(q)
|
||||||
|
wantP, wantD := nearest(q, data)
|
||||||
|
if !reflect.DeepEqual(gotP, wantP) {
|
||||||
|
b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
|
||||||
|
}
|
||||||
|
if gotD != wantD {
|
||||||
|
b.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.Failed() && *genDot && tree.Len() <= *dotLimit {
|
||||||
|
err := dotFile(tree, "TestBenches", "")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to write DOT file: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var r Comparable
|
||||||
|
var d float64
|
||||||
|
queryBenchmarks := []struct {
|
||||||
|
name string
|
||||||
|
fn func(*testing.B)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Nearest", fn: func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
|
||||||
|
}
|
||||||
|
if r == nil {
|
||||||
|
b.Error("unexpected nil result")
|
||||||
|
}
|
||||||
|
if math.IsNaN(d) {
|
||||||
|
b.Error("unexpected NaN result")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NearestBrute", fn: func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
|
||||||
|
}
|
||||||
|
if r == nil {
|
||||||
|
b.Error("unexpected nil result")
|
||||||
|
}
|
||||||
|
if math.IsNaN(d) {
|
||||||
|
b.Error("unexpected NaN result")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NearestSetN10", fn: func(b *testing.B) {
|
||||||
|
nk := NewNKeeper(10)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
|
||||||
|
}
|
||||||
|
if nk.Len() != 10 {
|
||||||
|
b.Error("unexpected result length")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NearestBruteN10", fn: func(b *testing.B) {
|
||||||
|
var r []ComparableDist
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
|
||||||
|
}
|
||||||
|
if len(r) != 10 {
|
||||||
|
b.Error("unexpected result length", len(r))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, bench := range queryBenchmarks {
|
||||||
|
b.Run(bench.name, bench.fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dot(t *Tree, label string) string {
|
||||||
|
if t == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
s []string
|
||||||
|
follow func(*Node)
|
||||||
|
)
|
||||||
|
follow = func(n *Node) {
|
||||||
|
id := uintptr(unsafe.Pointer(n))
|
||||||
|
c := fmt.Sprintf("%d[label = \"<Left> |<Elem> %s/%.3f\\n%.3f|<Right>\"];",
|
||||||
|
id, n, n.Point.(Point)[n.Plane], *n.Bounding)
|
||||||
|
if n.Left != nil {
|
||||||
|
c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Left -> \"%d\":Elem;",
|
||||||
|
id, uintptr(unsafe.Pointer(n.Left)))
|
||||||
|
follow(n.Left)
|
||||||
|
}
|
||||||
|
if n.Right != nil {
|
||||||
|
c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Right -> \"%d\":Elem;",
|
||||||
|
id, uintptr(unsafe.Pointer(n.Right)))
|
||||||
|
follow(n.Right)
|
||||||
|
}
|
||||||
|
s = append(s, c)
|
||||||
|
}
|
||||||
|
if t.Root != nil {
|
||||||
|
follow(t.Root)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n",
|
||||||
|
label,
|
||||||
|
strings.Join(s, "\n\t"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dotFile(t *Tree, label, dotString string) (err error) {
|
||||||
|
if t == nil && dotString == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f, err := os.Create(label + ".dot")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
if dotString == "" {
|
||||||
|
fmt.Fprintf(f, dot(t, label))
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(f, dotString)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
163
spatial/kdtree/kdtree_user_type_example_test.go
Normal file
163
spatial/kdtree/kdtree_user_type_example_test.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/spatial/kdtree"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Example_accessiblePublicTransport() {
|
||||||
|
// Construct a k-d tree of train station locations
|
||||||
|
// to identify accessible public transport for the
|
||||||
|
// elderly.
|
||||||
|
t := kdtree.New(stations, false)
|
||||||
|
|
||||||
|
// Residence.
|
||||||
|
q := place{lat: 51.501476, lon: -0.140634}
|
||||||
|
|
||||||
|
var keep kdtree.Keeper
|
||||||
|
|
||||||
|
// Find all stations within 0.75 of the residence.
|
||||||
|
keep = kdtree.NewDistKeeper(0.75 * 0.75) // Distances are squared.
|
||||||
|
t.NearestSet(keep, q)
|
||||||
|
|
||||||
|
fmt.Println(`Stations within 750 m of 51.501476N 0.140634W.`)
|
||||||
|
for _, c := range keep.(*kdtree.DistKeeper).Heap {
|
||||||
|
p := c.Comparable.(place)
|
||||||
|
fmt.Printf("%s: %0.3f km\n", p.name, math.Sqrt(p.Distance(q)))
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// Find the five closest stations to the residence.
|
||||||
|
keep = kdtree.NewNKeeper(5)
|
||||||
|
t.NearestSet(keep, q)
|
||||||
|
|
||||||
|
fmt.Println(`5 closest stations to 51.501476N 0.140634W.`)
|
||||||
|
for _, c := range keep.(*kdtree.NKeeper).Heap {
|
||||||
|
p := c.Comparable.(place)
|
||||||
|
fmt.Printf("%s: %0.3f km\n", p.name, math.Sqrt(p.Distance(q)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
//
|
||||||
|
// Stations within 750 m of 51.501476N 0.140634W.
|
||||||
|
// St. James's Park: 0.545 km
|
||||||
|
// Green Park: 0.600 km
|
||||||
|
// Victoria: 0.621 km
|
||||||
|
//
|
||||||
|
// 5 closest stations to 51.501476N 0.140634W.
|
||||||
|
// St. James's Park: 0.545 km
|
||||||
|
// Green Park: 0.600 km
|
||||||
|
// Victoria: 0.621 km
|
||||||
|
// Hyde Park Corner: 0.846 km
|
||||||
|
// Picadilly Circus: 1.027 km
|
||||||
|
}
|
||||||
|
|
||||||
|
// stations is a list of railways stations satisfying the
|
||||||
|
// kdtree.Interface.
|
||||||
|
var stations = places{
|
||||||
|
{name: "Bond Street", lat: 51.5142, lon: -0.1494},
|
||||||
|
{name: "Charing Cross", lat: 51.508, lon: -0.1247},
|
||||||
|
{name: "Covent Garden", lat: 51.5129, lon: -0.1243},
|
||||||
|
{name: "Embankment", lat: 51.5074, lon: -0.1223},
|
||||||
|
{name: "Green Park", lat: 51.5067, lon: -0.1428},
|
||||||
|
{name: "Hyde Park Corner", lat: 51.5027, lon: -0.1527},
|
||||||
|
{name: "Leicester Square", lat: 51.5113, lon: -0.1281},
|
||||||
|
{name: "Marble Arch", lat: 51.5136, lon: -0.1586},
|
||||||
|
{name: "Oxford Circus", lat: 51.515, lon: -0.1415},
|
||||||
|
{name: "Picadilly Circus", lat: 51.5098, lon: -0.1342},
|
||||||
|
{name: "Pimlico", lat: 51.4893, lon: -0.1334},
|
||||||
|
{name: "Sloane Square", lat: 51.4924, lon: -0.1565},
|
||||||
|
{name: "South Kensington", lat: 51.4941, lon: -0.1738},
|
||||||
|
{name: "St. James's Park", lat: 51.4994, lon: -0.1335},
|
||||||
|
{name: "Temple", lat: 51.5111, lon: -0.1141},
|
||||||
|
{name: "Tottenham Court Road", lat: 51.5165, lon: -0.131},
|
||||||
|
{name: "Vauxhall", lat: 51.4861, lon: -0.1253},
|
||||||
|
{name: "Victoria", lat: 51.4965, lon: -0.1447},
|
||||||
|
{name: "Waterloo", lat: 51.5036, lon: -0.1143},
|
||||||
|
{name: "Westminster", lat: 51.501, lon: -0.1254},
|
||||||
|
}
|
||||||
|
|
||||||
|
// place is a kdtree.Comparable implementations.
|
||||||
|
type place struct {
|
||||||
|
name string
|
||||||
|
lat, lon float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare satisfies the axis comparisons method of the kdtree.Comparable interface.
|
||||||
|
// The dimensions are:
|
||||||
|
// 0 = lat
|
||||||
|
// 1 = lon
|
||||||
|
func (p place) Compare(c kdtree.Comparable, d kdtree.Dim) float64 {
|
||||||
|
q := c.(place)
|
||||||
|
switch d {
|
||||||
|
case 0:
|
||||||
|
return p.lat - q.lat
|
||||||
|
case 1:
|
||||||
|
return p.lon - q.lon
|
||||||
|
default:
|
||||||
|
panic("illegal dimension")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dims returns the number of dimensions to be considered.
|
||||||
|
func (p place) Dims() int { return 2 }
|
||||||
|
|
||||||
|
// Distance returns the distance between the receiver and c.
|
||||||
|
func (p place) Distance(c kdtree.Comparable) float64 {
|
||||||
|
q := c.(place)
|
||||||
|
d := haversine(p.lat, p.lon, q.lat, q.lon)
|
||||||
|
return d * d
|
||||||
|
}
|
||||||
|
|
||||||
|
// haversine returns the distance between two geographic coordinates.
|
||||||
|
func haversine(lat1, lon1, lat2, lon2 float64) float64 {
|
||||||
|
const r = 6371 // km
|
||||||
|
sdLat := math.Sin(radians(lat2-lat1) / 2)
|
||||||
|
sdLon := math.Sin(radians(lon2-lon1) / 2)
|
||||||
|
a := sdLat*sdLat + math.Cos(radians(lat1))*math.Cos(radians(lat2))*sdLon*sdLon
|
||||||
|
d := 2 * r * math.Asin(math.Sqrt(a))
|
||||||
|
return d // km
|
||||||
|
}
|
||||||
|
|
||||||
|
func radians(d float64) float64 {
|
||||||
|
return d * math.Pi / 180
|
||||||
|
}
|
||||||
|
|
||||||
|
// places is a collection of the place type that satisfies kdtree.Interface.
|
||||||
|
type places []place
|
||||||
|
|
||||||
|
func (p places) Index(i int) kdtree.Comparable { return p[i] }
|
||||||
|
func (p places) Len() int { return len(p) }
|
||||||
|
func (p places) Pivot(d kdtree.Dim) int { return plane{places: p, Dim: d}.Pivot() }
|
||||||
|
func (p places) Slice(start, end int) kdtree.Interface { return p[start:end] }
|
||||||
|
|
||||||
|
// plane is required to help places.
|
||||||
|
type plane struct {
|
||||||
|
kdtree.Dim
|
||||||
|
places
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p plane) Less(i, j int) bool {
|
||||||
|
switch p.Dim {
|
||||||
|
case 0:
|
||||||
|
return p.places[i].lat < p.places[j].lat
|
||||||
|
case 1:
|
||||||
|
return p.places[i].lon < p.places[j].lon
|
||||||
|
default:
|
||||||
|
panic("illegal dimension")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (p plane) Pivot() int { return kdtree.Partition(p, kdtree.MedianOfMedians(p)) }
|
||||||
|
func (p plane) Slice(start, end int) kdtree.SortSlicer {
|
||||||
|
p.places = p.places[start:end]
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
func (p plane) Swap(i, j int) {
|
||||||
|
p.places[i], p.places[j] = p.places[j], p.places[i]
|
||||||
|
}
|
105
spatial/kdtree/medians.go
Normal file
105
spatial/kdtree/medians.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Partition partitions list such that all elements less than the value at pivot prior to the
|
||||||
|
// call are placed before that element and all elements greater than that value are placed after it.
|
||||||
|
// The final location of the element at pivot prior to the call is returned.
|
||||||
|
func Partition(list sort.Interface, pivot int) int {
|
||||||
|
var index, last int
|
||||||
|
if last = list.Len() - 1; last < 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
list.Swap(pivot, last)
|
||||||
|
for i := 0; i < last; i++ {
|
||||||
|
if !list.Less(last, i) {
|
||||||
|
list.Swap(index, i)
|
||||||
|
index++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
list.Swap(last, index)
|
||||||
|
return index
|
||||||
|
}
|
||||||
|
|
||||||
|
// SortSlicer satisfies the sort.Interface and is able to slice itself.
|
||||||
|
type SortSlicer interface {
|
||||||
|
sort.Interface
|
||||||
|
Slice(start, end int) SortSlicer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select partitions list such that all elements less than the kth largest element are
|
||||||
|
// placed placed before k in the resulting list and all elements greater than it are placed
|
||||||
|
// after the position k.
|
||||||
|
func Select(list SortSlicer, k int) int {
|
||||||
|
var (
|
||||||
|
start int
|
||||||
|
end = list.Len()
|
||||||
|
)
|
||||||
|
if k >= end {
|
||||||
|
if k == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
panic("kdtree: index out of range")
|
||||||
|
}
|
||||||
|
if start == end-1 {
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if start == end {
|
||||||
|
panic("kdtree: internal inconsistency")
|
||||||
|
}
|
||||||
|
sub := list.Slice(start, end)
|
||||||
|
pivot := Partition(sub, rand.Intn(sub.Len()))
|
||||||
|
switch {
|
||||||
|
case pivot == k:
|
||||||
|
return k
|
||||||
|
case k < pivot:
|
||||||
|
end = pivot + start
|
||||||
|
default:
|
||||||
|
k -= pivot
|
||||||
|
start += pivot
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a > b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// MedianOfMedians returns the index to the median value of the medians of groups of 5 consecutive elements.
|
||||||
|
func MedianOfMedians(list SortSlicer) int {
|
||||||
|
n := list.Len() / 5
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
left := i * 5
|
||||||
|
sub := list.Slice(left, min(left+5, list.Len()-1))
|
||||||
|
Select(sub, 2)
|
||||||
|
list.Swap(i, left+2)
|
||||||
|
}
|
||||||
|
Select(list.Slice(0, min(n, list.Len()-1)), min(list.Len(), n/2))
|
||||||
|
return n / 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// MedianOfRandoms returns the index to the median value of up to n randomly chosen elements in list.
|
||||||
|
func MedianOfRandoms(list SortSlicer, n int) int {
|
||||||
|
if l := list.Len(); n <= l {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
list.Swap(i, rand.Intn(n))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n = l
|
||||||
|
}
|
||||||
|
Select(list.Slice(0, n), n/2)
|
||||||
|
return n / 2
|
||||||
|
}
|
187
spatial/kdtree/medians_test.go
Normal file
187
spatial/kdtree/medians_test.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ints []int
|
||||||
|
|
||||||
|
func (a ints) Len() int { return len(a) }
|
||||||
|
func (a ints) Less(i, j int) bool { return a[i] < a[j] }
|
||||||
|
func (a ints) Slice(s, e int) SortSlicer { return a[s:e] }
|
||||||
|
func (a ints) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
|
|
||||||
|
func TestPartition(t *testing.T) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
for p := 0; p < 100; p++ {
|
||||||
|
list := make(ints, 1e5)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Int()
|
||||||
|
}
|
||||||
|
pi := Partition(list, rnd.Intn(list.Len()))
|
||||||
|
for i := 0; i < pi; i++ {
|
||||||
|
if list[i] > list[pi] {
|
||||||
|
t.Errorf("unexpected partition sort order p[%d] > p[%d]: %d > %d", i, pi, list[i], list[pi])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := pi + 1; i < len(list); i++ {
|
||||||
|
if list[i] <= list[pi] {
|
||||||
|
t.Errorf("unexpected partition sort order p[%d] <= p[%d]: %d <= %d", i, pi, list[i], list[pi])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPartitionCollision(t *testing.T) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
for p := 0; p < 10; p++ {
|
||||||
|
list := make(ints, 10)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Intn(5)
|
||||||
|
}
|
||||||
|
pi := Partition(list, p)
|
||||||
|
for i := 0; i < pi; i++ {
|
||||||
|
if list[i] > list[pi] {
|
||||||
|
t.Errorf("unexpected partition sort order p[%d] > p[%d]: %d > %d", i, pi, list[i], list[pi])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := pi + 1; i < len(list); i++ {
|
||||||
|
if list[i] <= list[pi] {
|
||||||
|
t.Errorf("unexpected partition sort order p[%d] <= p[%d]: %d <= %d", i, pi, list[i], list[pi])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortSelection(list ints, k int) int {
|
||||||
|
sort.Sort(list)
|
||||||
|
return list[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelect(t *testing.T) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
for k := 0; k < 2121; k++ {
|
||||||
|
list := make(ints, 2121)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Intn(1000)
|
||||||
|
}
|
||||||
|
Select(list, k)
|
||||||
|
sorted := append(ints(nil), list...)
|
||||||
|
want := sortSelection(sorted, k)
|
||||||
|
if list[k] != want {
|
||||||
|
t.Errorf("unexpected result from Select(..., %d): got:%v want:%d", k, list[k], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMedianOfMedians(t *testing.T) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
list := make(ints, 1e4)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Int()
|
||||||
|
}
|
||||||
|
p := MedianOfMedians(list)
|
||||||
|
med := list[p]
|
||||||
|
sort.Sort(list)
|
||||||
|
var found bool
|
||||||
|
for _, v := range list[len(list)*3/10 : len(list)*7/10+1] {
|
||||||
|
if v == med {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("failed to find median")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMedianOfRandoms(t *testing.T) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
list := make(ints, 1e4)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Int()
|
||||||
|
}
|
||||||
|
p := MedianOfRandoms(list, randoms)
|
||||||
|
med := list[p]
|
||||||
|
sort.Sort(list)
|
||||||
|
var found bool
|
||||||
|
for _, v := range list[len(list)*3/10 : len(list)*7/10+1] {
|
||||||
|
if v == med {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("failed to find median")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var benchSink int
|
||||||
|
|
||||||
|
func BenchmarkMedianOfMedians(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
b.StopTimer()
|
||||||
|
list := make(ints, 1e4)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Int()
|
||||||
|
}
|
||||||
|
b.StartTimer()
|
||||||
|
benchSink = MedianOfMedians(list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPartitionMedianOfMedians(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
b.StopTimer()
|
||||||
|
list := make(ints, 1e4)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Int()
|
||||||
|
}
|
||||||
|
b.StartTimer()
|
||||||
|
benchSink = Partition(list, MedianOfMedians(list))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMedianOfRandoms(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
b.StopTimer()
|
||||||
|
list := make(ints, 1e4)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Int()
|
||||||
|
}
|
||||||
|
b.StartTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
benchSink = MedianOfRandoms(list, list.Len()/1e3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPartitionMedianOfRandoms(b *testing.B) {
|
||||||
|
rnd := rand.New(rand.NewSource(1))
|
||||||
|
|
||||||
|
b.StopTimer()
|
||||||
|
list := make(ints, 1e4)
|
||||||
|
for i := range list {
|
||||||
|
list[i] = rnd.Int()
|
||||||
|
}
|
||||||
|
b.StartTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
benchSink = Partition(list, MedianOfRandoms(list, list.Len()/1e3))
|
||||||
|
}
|
||||||
|
}
|
50
spatial/kdtree/nbpoints_test.go
Normal file
50
spatial/kdtree/nbpoints_test.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Interface = nbPoints{}
|
||||||
|
_ Comparable = nbPoint{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// nbRandoms is the maximum number of random values to sample for calculation of median of
|
||||||
|
// random elements.
|
||||||
|
var nbRandoms = 100
|
||||||
|
|
||||||
|
// nbPoint represents a point in a k-d space that satisfies the Comparable interface.
|
||||||
|
type nbPoint Point
|
||||||
|
|
||||||
|
func (p nbPoint) Compare(c Comparable, d Dim) float64 { q := c.(nbPoint); return p[d] - q[d] }
|
||||||
|
func (p nbPoint) Dims() int { return len(p) }
|
||||||
|
func (p nbPoint) Distance(c Comparable) float64 {
|
||||||
|
q := c.(nbPoint)
|
||||||
|
var sum float64
|
||||||
|
for dim, c := range p {
|
||||||
|
d := c - q[dim]
|
||||||
|
sum += d * d
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
// nbPoints is a collection of point values that satisfies the Interface.
|
||||||
|
type nbPoints []nbPoint
|
||||||
|
|
||||||
|
func (p nbPoints) Index(i int) Comparable { return p[i] }
|
||||||
|
func (p nbPoints) Len() int { return len(p) }
|
||||||
|
func (p nbPoints) Pivot(d Dim) int { return nbPlane{nbPoints: p, Dim: d}.Pivot() }
|
||||||
|
func (p nbPoints) Slice(start, end int) Interface { return p[start:end] }
|
||||||
|
|
||||||
|
// nbPlane is a wrapping type that allows a Points type be pivoted on a dimension.
|
||||||
|
type nbPlane struct {
|
||||||
|
Dim
|
||||||
|
nbPoints
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p nbPlane) Less(i, j int) bool { return p.nbPoints[i][p.Dim] < p.nbPoints[j][p.Dim] }
|
||||||
|
func (p nbPlane) Pivot() int { return Partition(p, MedianOfRandoms(p, nbRandoms)) }
|
||||||
|
func (p nbPlane) Slice(start, end int) SortSlicer { p.nbPoints = p.nbPoints[start:end]; return p }
|
||||||
|
func (p nbPlane) Swap(i, j int) {
|
||||||
|
p.nbPoints[i], p.nbPoints[j] = p.nbPoints[j], p.nbPoints[i]
|
||||||
|
}
|
88
spatial/kdtree/points.go
Normal file
88
spatial/kdtree/points.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package kdtree
|
||||||
|
|
||||||
|
import "math"
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Interface = Points(nil)
|
||||||
|
_ Comparable = Point(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Point represents a point in a k-d space that satisfies the Comparable interface.
|
||||||
|
type Point []float64
|
||||||
|
|
||||||
|
// Compare returns the signed distance of p from the plane passing through c and
|
||||||
|
// perpendicular to the dimension d. The concrete type of c must be Point.
|
||||||
|
func (p Point) Compare(c Comparable, d Dim) float64 { q := c.(Point); return p[d] - q[d] }
|
||||||
|
|
||||||
|
// Dims returns the number of dimensions described by the receiver.
|
||||||
|
func (p Point) Dims() int { return len(p) }
|
||||||
|
|
||||||
|
// Distance returns the squared Euclidean distance between c and the receiver. The
|
||||||
|
// concrete type of c must be Point.
|
||||||
|
func (p Point) Distance(c Comparable) float64 {
|
||||||
|
q := c.(Point)
|
||||||
|
var sum float64
|
||||||
|
for dim, c := range p {
|
||||||
|
d := c - q[dim]
|
||||||
|
sum += d * d
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend returns a bounding box that has been extended to include the receiver.
|
||||||
|
func (p Point) Extend(b *Bounding) *Bounding {
|
||||||
|
if b == nil {
|
||||||
|
b = &Bounding{append(Point(nil), p...), append(Point(nil), p...)}
|
||||||
|
}
|
||||||
|
min := b.Min.(Point)
|
||||||
|
max := b.Max.(Point)
|
||||||
|
for d, v := range p {
|
||||||
|
min[d] = math.Min(min[d], v)
|
||||||
|
max[d] = math.Max(max[d], v)
|
||||||
|
}
|
||||||
|
*b = Bounding{Min: min, Max: max}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Points is a collection of point values that satisfies the Interface.
|
||||||
|
type Points []Point
|
||||||
|
|
||||||
|
func (p Points) Bounds() *Bounding {
|
||||||
|
if p.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
min := append(Point(nil), p[0]...)
|
||||||
|
max := append(Point(nil), p[0]...)
|
||||||
|
for _, e := range p[1:] {
|
||||||
|
for d, v := range e {
|
||||||
|
min[d] = math.Min(min[d], v)
|
||||||
|
max[d] = math.Max(max[d], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Bounding{Min: min, Max: max}
|
||||||
|
}
|
||||||
|
func (p Points) Index(i int) Comparable { return p[i] }
|
||||||
|
func (p Points) Len() int { return len(p) }
|
||||||
|
func (p Points) Pivot(d Dim) int { return Plane{Points: p, Dim: d}.Pivot() }
|
||||||
|
func (p Points) Slice(start, end int) Interface { return p[start:end] }
|
||||||
|
|
||||||
|
// Plane is a wrapping type that allows a Points type be pivoted on a dimension.
|
||||||
|
// The Pivot method of Plane uses MedianOfRandoms sampling at most 100 elements
|
||||||
|
// to find a pivot element.
|
||||||
|
type Plane struct {
|
||||||
|
Dim
|
||||||
|
Points
|
||||||
|
}
|
||||||
|
|
||||||
|
// randoms is the maximum number of random values to sample for calculation of
|
||||||
|
// median of random elements.
|
||||||
|
const randoms = 100
|
||||||
|
|
||||||
|
func (p Plane) Less(i, j int) bool { return p.Points[i][p.Dim] < p.Points[j][p.Dim] }
|
||||||
|
func (p Plane) Pivot() int { return Partition(p, MedianOfRandoms(p, randoms)) }
|
||||||
|
func (p Plane) Slice(start, end int) SortSlicer { p.Points = p.Points[start:end]; return p }
|
||||||
|
func (p Plane) Swap(i, j int) { p.Points[i], p.Points[j] = p.Points[j], p.Points[i] }
|
Reference in New Issue
Block a user