mirror of
https://github.com/gonum/gonum.git
synced 2025-10-08 16:40:06 +08:00
spatial/vptree: new package for vantage point tree NN search
This commit is contained in:
10
spatial/vptree/doc.go
Normal file
10
spatial/vptree/doc.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package vptree implements a vantage point tree. Vantage point
|
||||
// trees provide an efficient search for nearest neighbors in a
|
||||
// metric space.
|
||||
//
|
||||
// See http://pnylab.com/papers/vptree/vptree.pdf for details of vp-trees.
|
||||
package vptree // import "gonum.org/v1/gonum/spatial/vptree"
|
374
spatial/vptree/vptree.go
Normal file
374
spatial/vptree/vptree.go
Normal file
@@ -0,0 +1,374 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vptree
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"gonum.org/v1/gonum/stat"
|
||||
)
|
||||
|
||||
// Comparable is the element interface for values stored in a vp-tree.
// Tree construction and every query in this package interact with
// stored values only through the Distance method.
type Comparable interface {
	// Distance returns the distance between the receiver and the
	// parameter. The returned distance must satisfy the properties
	// of distances in a metric space.
	//
	// - a.Distance(a) == 0
	// - a.Distance(b) >= 0
	// - a.Distance(b) == b.Distance(a)
	// - a.Distance(b) <= a.Distance(c)+c.Distance(b)
	//
	Distance(Comparable) float64
}
|
||||
|
||||
// Point represents a point in a Euclidean k-d space that satisfies the Comparable
// interface. Each element of the slice is the coordinate in one dimension.
type Point []float64
|
||||
|
||||
// Distance returns the Euclidean distance between c and the receiver. The concrete
|
||||
// type of c must be Point.
|
||||
func (p Point) Distance(c Comparable) float64 {
|
||||
q := c.(Point)
|
||||
var sum float64
|
||||
for dim, c := range p {
|
||||
d := c - q[dim]
|
||||
sum += d * d
|
||||
}
|
||||
return math.Sqrt(sum)
|
||||
}
|
||||
|
||||
// Node holds a single point value in a vantage point tree.
type Node struct {
	// Point is the vantage point stored at this node.
	Point Comparable
	// Radius is the threshold distance from Point used to split
	// the remaining values between Closer and Further. It is left
	// as zero for nodes built from a single value.
	Radius float64
	// Closer is the subtree of values sorted before the threshold
	// distance from Point. It may be nil.
	Closer *Node
	// Further is the subtree of values at or beyond the threshold
	// distance from Point. It may be nil.
	Further *Node
}
|
||||
|
||||
// Tree implements a vantage point tree creation and nearest neighbor search.
type Tree struct {
	// Root is the top-level node of the tree. It is nil for an
	// empty tree.
	Root *Node
	// Count is the number of values held by the tree; New sets it
	// to the length of the slice the tree was built from.
	Count int
}
|
||||
|
||||
// New returns a vantage point tree constructed from the values in p. The effort
|
||||
// parameter specifies how much work should be put into optimizing the choice of
|
||||
// vantage point. If effort is one or less, random vantage points are chosen.
|
||||
// The order of elements in p will be altered after New returns. The src parameter
|
||||
// provides the source of randomness for vantage point selection. If src is nil
|
||||
// global rand package functions are used.
|
||||
func New(p []Comparable, effort int, src rand.Source) *Tree {
|
||||
var intn func(int) int
|
||||
var shuf func(n int, swap func(i, j int))
|
||||
if src == nil {
|
||||
intn = rand.Intn
|
||||
shuf = rand.Shuffle
|
||||
} else {
|
||||
rnd := rand.New(src)
|
||||
intn = rnd.Intn
|
||||
shuf = rnd.Shuffle
|
||||
}
|
||||
b := builder{work: make([]float64, len(p)), intn: intn, shuf: shuf}
|
||||
return &Tree{
|
||||
Root: b.build(p, effort),
|
||||
Count: len(p),
|
||||
}
|
||||
}
|
||||
|
||||
// builder performs vp-tree construction as described for the simple vp-tree
// algorithm in http://pnylab.com/papers/vptree/vptree.pdf.
type builder struct {
	// work is scratch space for distances. It is allocated once by
	// New with one element per input value and resliced as needed
	// during construction.
	work []float64
	// intn returns a random index in [0, n); it is used to pick a
	// vantage point.
	intn func(n int) int
	// shuf randomly permutes n elements through the swap callback;
	// it is used to sample vantage point candidates.
	shuf func(n int, swap func(i, j int))
}
|
||||
|
||||
func (b *builder) build(s []Comparable, effort int) *Node {
|
||||
if len(s) <= 1 {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &Node{Point: s[0]}
|
||||
}
|
||||
n := Node{Point: b.selectVantage(s, effort)}
|
||||
radius, closer, further := b.partition(n.Point, s)
|
||||
n.Radius = radius
|
||||
n.Closer = b.build(closer, effort)
|
||||
n.Further = b.build(further, effort)
|
||||
return &n
|
||||
}
|
||||
|
||||
func (b *builder) selectVantage(s []Comparable, effort int) Comparable {
|
||||
if effort <= 1 {
|
||||
return s[b.intn(len(s))]
|
||||
}
|
||||
if effort > len(s) {
|
||||
effort = len(s)
|
||||
}
|
||||
var best Comparable
|
||||
var bestVar float64
|
||||
b.work = b.work[:effort]
|
||||
choices := b.random(effort, s)
|
||||
for _, p := range choices {
|
||||
for i, q := range choices {
|
||||
b.work[i] = p.Distance(q)
|
||||
}
|
||||
variance := stat.Variance(b.work, nil)
|
||||
if variance > bestVar {
|
||||
best, bestVar = p, variance
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
func (b *builder) random(n int, s []Comparable) []Comparable {
|
||||
if n >= len(s) {
|
||||
return s
|
||||
}
|
||||
b.shuf(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] })
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
// partition sorts s in place by increasing distance from the vantage
// point v and splits it at the median. It returns the distance of the
// median element as radius, the elements before the median (excluding
// the first, nearest, element) as closer, and the median and all later
// elements as further. Both returned slices alias s, and b.work is
// overwritten with the sorted distances.
//
// s must not be empty.
func (b *builder) partition(v Comparable, s []Comparable) (radius float64, closer, further []Comparable) {
	b.work = b.work[:len(s)]
	for i, p := range s {
		b.work[i] = v.Distance(p)
	}
	// Sort the distances and the points together.
	sort.Sort(byDist{dists: b.work, points: s})

	// Note that this does not conform exactly to the description
	// in the paper which specifies d(p, s) < mu for L; in cases
	// where the median element has a lower indexed element with
	// the same distance from the vantage point, L will include a
	// d(p, s) == mu.
	// The additional work required to satisfy the algorithm is
	// not worth doing as it has no effect on the correctness or
	// performance of the algorithm.
	radius = b.work[len(b.work)/2]

	if len(b.work) > 1 {
		// Remove vantage if it is present.
		// After sorting, the element at index zero is the one
		// nearest to v, which is v itself when v is a member of s.
		closer = s[1 : len(b.work)/2]
	}
	further = s[len(b.work)/2:]
	return radius, closer, further
}
|
||||
|
||||
type byDist struct {
|
||||
dists []float64
|
||||
points []Comparable
|
||||
}
|
||||
|
||||
func (c byDist) Len() int { return len(c.dists) }
|
||||
func (c byDist) Less(i, j int) bool { return c.dists[i] < c.dists[j] }
|
||||
func (c byDist) Swap(i, j int) {
|
||||
c.dists[i], c.dists[j] = c.dists[j], c.dists[i]
|
||||
c.points[i], c.points[j] = c.points[j], c.points[i]
|
||||
}
|
||||
|
||||
// Len returns the number of elements in the tree.
|
||||
func (t *Tree) Len() int { return t.Count }
|
||||
|
||||
var inf = math.Inf(1)
|
||||
|
||||
// Nearest returns the nearest value to the query and the distance between them.
|
||||
func (t *Tree) Nearest(q Comparable) (Comparable, float64) {
|
||||
if t.Root == nil {
|
||||
return nil, inf
|
||||
}
|
||||
n, dist := t.Root.search(q, inf)
|
||||
if n == nil {
|
||||
return nil, inf
|
||||
}
|
||||
return n.Point, dist
|
||||
}
|
||||
|
||||
// search looks for the nearest value to q in the subtree rooted at n.
// dist is the best distance found so far. It returns the best node
// found and its distance from q.
//
// A nil receiver returns nil and positive infinity. When the returned
// distance is not strictly less than the dist that was passed in, no
// improvement was found and the returned node is not meaningful;
// callers only accept the result after that comparison.
func (n *Node) search(q Comparable, dist float64) (*Node, float64) {
	if n == nil {
		return nil, inf
	}

	d := q.Distance(n.Point)
	dist = math.Min(dist, d)

	bn := n
	if d < n.Radius {
		// The query is inside the radius, so the closer subtree is
		// searched first.
		cn, cd := n.Closer.search(q, dist)
		if cd < dist {
			bn, dist = cn, cd
		}
		// The further subtree is only searched when the current
		// best distance from the query reaches the radius.
		if d+dist >= n.Radius {
			fn, fd := n.Further.search(q, dist)
			if fd < dist {
				bn, dist = fn, fd
			}
		}
	} else {
		// The query is at or beyond the radius, so the further
		// subtree is searched first.
		fn, fd := n.Further.search(q, dist)
		if fd < dist {
			bn, dist = fn, fd
		}
		// The closer subtree is only searched when the current
		// best distance from the query reaches back to the radius.
		if d-dist <= n.Radius {
			cn, cd := n.Closer.search(q, dist)
			if cd < dist {
				bn, dist = cn, cd
			}
		}
	}

	return bn, dist
}
|
||||
|
||||
// ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable
// is used to mark the end of the heap, so clients should not store nil values except for
// this purpose.
type ComparableDist struct {
	// Comparable is the stored value, or nil for a sentinel.
	Comparable Comparable
	// Dist is the distance between Comparable and the query. For a
	// sentinel it holds the distance bound set by the Keeper.
	Dist float64
}
|
||||
|
||||
// Heap is a max heap sorted on Dist.
|
||||
type Heap []ComparableDist
|
||||
|
||||
func (h *Heap) Max() ComparableDist { return (*h)[0] }
|
||||
func (h *Heap) Len() int { return len(*h) }
|
||||
func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist }
|
||||
func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] }
|
||||
func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) }
|
||||
func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i }
|
||||
|
||||
// NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep.
|
||||
type NKeeper struct {
|
||||
Heap
|
||||
}
|
||||
|
||||
// NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The
|
||||
// returned NKeeper is able to retain at most n values.
|
||||
func NewNKeeper(n int) *NKeeper {
|
||||
k := NKeeper{make(Heap, 1, n)}
|
||||
k.Heap[0].Dist = inf
|
||||
return &k
|
||||
}
|
||||
|
||||
// Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding
|
||||
// c would increase the size of the heap beyond the initial maximum length, the maximum value of
|
||||
// the heap is dropped.
|
||||
func (k *NKeeper) Keep(c ComparableDist) {
|
||||
if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel.
|
||||
if len(k.Heap) == cap(k.Heap) {
|
||||
heap.Pop(k)
|
||||
}
|
||||
heap.Push(k, c)
|
||||
}
|
||||
}
|
||||
|
||||
// DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the
|
||||
// query that it is called to Keep.
|
||||
type DistKeeper struct {
|
||||
Heap
|
||||
}
|
||||
|
||||
// NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d.
|
||||
func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} }
|
||||
|
||||
// Keep adds c to the heap if its distance is less than or equal to the max value of the heap.
|
||||
func (k *DistKeeper) Keep(c ComparableDist) {
|
||||
if c.Dist <= k.Heap[0].Dist {
|
||||
heap.Push(k, c)
|
||||
}
|
||||
}
|
||||
|
||||
// Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type.
// Vantage point search is guided by the distance stored in the max value of the heap.
type Keeper interface {
	Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap.
	Max() ComparableDist // Max returns the maximum element of the Keeper.
	heap.Interface
}
|
||||
|
||||
// NearestSet finds the nearest values to the query accepted by the provided Keeper, k.
|
||||
// k must be able to return a ComparableDist specifying the maximum acceptable distance
|
||||
// when Max() is called, and retains the results of the search in min sorted order after
|
||||
// the call to NearestSet returns.
|
||||
// If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the
|
||||
// maximum distance, NearestSet will remove it before returning.
|
||||
func (t *Tree) NearestSet(k Keeper, q Comparable) {
|
||||
if t.Root == nil {
|
||||
return
|
||||
}
|
||||
t.Root.searchSet(q, k)
|
||||
|
||||
// Check whether we have retained a sentinel
|
||||
// and flag removal if we have.
|
||||
removeSentinel := k.Len() != 0 && k.Max().Comparable == nil
|
||||
|
||||
sort.Sort(sort.Reverse(k))
|
||||
|
||||
// This abuses the interface to drop the max.
|
||||
// It is reasonable to do this because we know
|
||||
// that the maximum value will now be at element
|
||||
// zero, which is removed by the Pop method.
|
||||
if removeSentinel {
|
||||
k.Pop()
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Node) searchSet(q Comparable, k Keeper) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
|
||||
k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)})
|
||||
|
||||
d := q.Distance(n.Point)
|
||||
if d < n.Radius {
|
||||
n.Closer.searchSet(q, k)
|
||||
if d+k.Max().Dist >= n.Radius {
|
||||
n.Further.searchSet(q, k)
|
||||
}
|
||||
} else {
|
||||
n.Further.searchSet(q, k)
|
||||
if d-k.Max().Dist <= n.Radius {
|
||||
n.Closer.searchSet(q, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Operation is a function that operates on a Comparable. The tree depth of the
// value is provided as the second argument. If done is returned true, the
// Operation is indicating that no further work needs to be done and so the Do
// function should traverse no further.
type Operation func(Comparable, int) (done bool)
|
||||
|
||||
// Do performs fn on all values stored in the tree. A boolean is returned indicating whether the
|
||||
// Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort
|
||||
// relationships, future tree operation behaviors are undefined.
|
||||
func (t *Tree) Do(fn Operation) bool {
|
||||
if t.Root == nil {
|
||||
return false
|
||||
}
|
||||
return t.Root.do(fn, 0)
|
||||
}
|
||||
|
||||
func (n *Node) do(fn Operation, depth int) (done bool) {
|
||||
if n.Closer != nil {
|
||||
done = n.Closer.do(fn, depth+1)
|
||||
if done {
|
||||
return
|
||||
}
|
||||
}
|
||||
done = fn(n.Point, depth)
|
||||
if done {
|
||||
return
|
||||
}
|
||||
if n.Further != nil {
|
||||
done = n.Further.do(fn, depth+1)
|
||||
}
|
||||
return
|
||||
}
|
60
spatial/vptree/vptree_simple_example_test.go
Normal file
60
spatial/vptree/vptree_simple_example_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vptree_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gonum.org/v1/gonum/spatial/vptree"
|
||||
)
|
||||
|
||||
func ExampleTree() {
	// Example data from https://en.wikipedia.org/wiki/K-d_tree
	points := []vptree.Comparable{
		vptree.Point{2, 3},
		vptree.Point{5, 4},
		vptree.Point{9, 6},
		vptree.Point{4, 7},
		vptree.Point{8, 1},
		vptree.Point{7, 2},
	}

	// Build the tree, considering up to three candidates for each
	// vantage point and using the global random source.
	t := vptree.New(points, 3, nil)

	// Find the stored point nearest to (8, 7).
	q := vptree.Point{8, 7}
	p, d := t.Nearest(q)
	fmt.Printf("%v is closest point to %v, d=%f\n", p, q, d)
	// Output:
	// [9 6] is closest point to [8 7], d=1.414214
}
|
||||
|
||||
func ExampleTree_Do() {
	// Example data from https://en.wikipedia.org/wiki/K-d_tree
	points := []vptree.Comparable{
		vptree.Point{2, 3},
		vptree.Point{5, 4},
		vptree.Point{9, 6},
		vptree.Point{4, 7},
		vptree.Point{8, 1},
		vptree.Point{7, 2},
	}

	// Print all points in the data set within 3 of (3, 5).
	// An effort of zero selects vantage points at random.
	t := vptree.New(points, 0, nil)
	q := vptree.Point{3, 5}
	t.Do(func(c vptree.Comparable, _ int) (done bool) {
		// Compare each distance and output points
		// with a Euclidean distance less than or
		// equal to 3. Distance returns the
		// Euclidean distance between points.
		if q.Distance(c) <= 3 {
			fmt.Println(c)
		}
		// done is left false so that every point is visited.
		return
	})
	// Unordered output:
	// [2 3]
	// [4 7]
	// [5 4]
}
|
530
spatial/vptree/vptree_test.go
Normal file
530
spatial/vptree/vptree_test.go
Normal file
@@ -0,0 +1,530 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vptree
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
)
|
||||
|
||||
// Flags controlling the output of DOT descriptions of trees for
// failing tests.
var (
	genDot   = flag.Bool("dot", false, "generate dot code for failing trees")
	dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format")
)

var (
	// Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572.
	//
	// wpData is shared between tests. New reorders the slice it is
	// given, so the order of the elements is not stable across tests.
	wpData = []Comparable{
		Point{2, 3},
		Point{5, 4},
		Point{9, 6},
		Point{4, 7},
		Point{8, 1},
		Point{7, 2},
	}
)

// newTests holds the construction cases for TestNew, covering effort
// values below, within and above the size of the data set.
var newTests = []struct {
	data   []Comparable
	effort int
}{
	{data: wpData, effort: 0},
	{data: wpData, effort: 1},
	{data: wpData, effort: 2},
	{data: wpData, effort: 4},
	{data: wpData, effort: 8},
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
for i, test := range newTests {
|
||||
var tree *Tree
|
||||
var panicked bool
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
tree = New(test.data, test.effort, rand.NewSource(1))
|
||||
}()
|
||||
if panicked {
|
||||
t.Errorf("unexpected panic for test %d", i)
|
||||
continue
|
||||
}
|
||||
|
||||
if !tree.Root.isVPTree() {
|
||||
t.Errorf("tree %d is not vp-tree", i)
|
||||
}
|
||||
|
||||
if t.Failed() && *genDot && tree.Len() <= *dotLimit {
|
||||
err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write DOT file: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compFn reports whether the distance v is on the expected side of radius.
type compFn func(v, radius float64) bool

// closer reports whether v is within radius, inclusive.
func closer(v, radius float64) bool {
	return v <= radius
}

// further reports whether v is at or beyond radius.
func further(v, radius float64) bool {
	return v >= radius
}
|
||||
|
||||
func (n *Node) isVPTree() bool {
|
||||
if n == nil {
|
||||
return true
|
||||
}
|
||||
if !n.Closer.isPartitioned(n.Point, closer, n.Radius) {
|
||||
return false
|
||||
}
|
||||
if !n.Further.isPartitioned(n.Point, further, n.Radius) {
|
||||
return false
|
||||
}
|
||||
return n.Closer.isVPTree() && n.Further.isVPTree()
|
||||
}
|
||||
|
||||
func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool {
|
||||
if n == nil {
|
||||
return true
|
||||
}
|
||||
if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) {
|
||||
return false
|
||||
}
|
||||
if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) {
|
||||
return false
|
||||
}
|
||||
return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius)
|
||||
}
|
||||
|
||||
func nearest(q Comparable, p []Comparable) (Comparable, float64) {
|
||||
min := q.Distance(p[0])
|
||||
var r int
|
||||
for i := 1; i < len(p); i++ {
|
||||
d := q.Distance(p[i])
|
||||
if d < min {
|
||||
min = d
|
||||
r = i
|
||||
}
|
||||
}
|
||||
return p[r], min
|
||||
}
|
||||
|
||||
// TestNearestRandom builds a tree from uniformly distributed random
// points and checks that Nearest agrees with the brute force nearest
// function for the same number of random queries.
func TestNearestRandom(t *testing.T) {
	rnd := rand.New(rand.NewSource(1))

	const (
		// Coordinate range of the generated points.
		min = 0.0
		max = 1000.0

		dims    = 4
		setSize = 10000
	)

	var randData []Comparable
	for i := 0; i < setSize; i++ {
		p := make(Point, dims)
		for j := 0; j < dims; j++ {
			p[j] = (max-min)*rnd.Float64() + min
		}
		randData = append(randData, p)
	}
	tree := New(randData, 10, rand.NewSource(1))

	for i := 0; i < setSize; i++ {
		q := make(Point, dims)
		for j := 0; j < dims; j++ {
			q[j] = (max-min)*rnd.Float64() + min
		}

		got, _ := tree.Nearest(q)
		want, _ := nearest(q, randData)
		if !reflect.DeepEqual(got, want) {
			t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want)
		}
	}
}
|
||||
|
||||
// TestNearest checks the point and distance returned by Nearest
// against the brute force nearest function for a set of fixed queries
// followed by each of the stored points.
func TestNearest(t *testing.T) {
	tree := New(wpData, 3, rand.NewSource(1))
	for _, q := range append([]Comparable{
		Point{4, 6},
		// Point{7, 5}, // Omitted because it ambiguously finds [9 6] or [5 4].
		Point{8, 7},
		Point{6, -5},
		Point{1e5, 1e5},
		Point{1e5, -1e5},
		Point{-1e5, 1e5},
		Point{-1e5, -1e5},
		Point{1e5, 0},
		Point{0, -1e5},
		Point{0, 1e5},
		Point{-1e5, 0},
	}, wpData...) {
		gotP, gotD := tree.Nearest(q)
		wantP, wantD := nearest(q, wpData)
		if !reflect.DeepEqual(gotP, wantP) {
			t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
		}
		if gotD != wantD {
			t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD)
		}
	}
}
|
||||
|
||||
func nearestN(n int, q Comparable, p []Comparable) []ComparableDist {
|
||||
nk := NewNKeeper(n)
|
||||
for i := 0; i < len(p); i++ {
|
||||
nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])})
|
||||
}
|
||||
if len(nk.Heap) == 1 {
|
||||
return nk.Heap
|
||||
}
|
||||
sort.Sort(nk)
|
||||
for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 {
|
||||
nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i]
|
||||
}
|
||||
return nk.Heap
|
||||
}
|
||||
|
||||
// TestNearestSetN checks NearestSet with an NKeeper against the brute
// force nearestN function for every keeper size from one to the size
// of the data set. Results are compared as sets of points grouped by
// distance, so the order of points with equal distances is ignored.
func TestNearestSetN(t *testing.T) {
	data := append([]Comparable{
		Point{4, 6},
		Point{7, 5}, // OK here because we collect N.
		Point{8, 7},
		Point{6, -5},
		Point{1e5, 1e5},
		Point{1e5, -1e5},
		Point{-1e5, 1e5},
		Point{-1e5, -1e5},
		Point{1e5, 0},
		Point{0, -1e5},
		Point{0, 1e5},
		Point{-1e5, 0}},
		wpData[:len(wpData)-1]...)

	tree := New(wpData, 3, rand.NewSource(1))
	for k := 1; k <= len(wpData); k++ {
		for _, q := range data {
			wantP := nearestN(k, q, wpData)

			nk := NewNKeeper(k)
			tree.NearestSet(nk, q)

			// Group the expected points by distance and record the
			// greatest expected distance.
			var max float64
			wantD := make(map[float64]map[string]struct{})
			for _, p := range wantP {
				if p.Dist > max {
					max = p.Dist
				}
				d, ok := wantD[p.Dist]
				if !ok {
					d = make(map[string]struct{})
				}
				d[fmt.Sprint(p.Comparable)] = struct{}{}
				wantD[p.Dist] = d
			}
			// Group the obtained points by distance, checking that
			// none is further than the greatest expected distance.
			gotD := make(map[float64]map[string]struct{})
			for _, p := range nk.Heap {
				if p.Dist > max {
					t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max)
				}
				d, ok := gotD[p.Dist]
				if !ok {
					d = make(map[string]struct{})
				}
				d[fmt.Sprint(p.Comparable)] = struct{}{}
				gotD[p.Dist] = d
			}

			// If the available number of slots does not fit all the coequal furthest points
			// we will fail the check. So remove, but check them minimally here.
			if !reflect.DeepEqual(wantD[max], gotD[max]) {
				// The best we can do at this stage is confirm that there are an equal number of matches at this distance.
				if len(gotD[max]) != len(wantD[max]) {
					t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max]))
				}
				delete(wantD, max)
				delete(gotD, max)
			}

			if !reflect.DeepEqual(gotD, wantD) {
				t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD)
			}
		}
	}
}
|
||||
|
||||
var nearestSetDistTests = []Point{
|
||||
{4, 6},
|
||||
{7, 5},
|
||||
{8, 7},
|
||||
{6, -5},
|
||||
}
|
||||
|
||||
func TestNearestSetDist(t *testing.T) {
|
||||
tree := New(wpData, 3, rand.NewSource(1))
|
||||
for i, q := range nearestSetDistTests {
|
||||
for d := 1.0; d < 100; d += 0.1 {
|
||||
dk := NewDistKeeper(d)
|
||||
tree.NearestSet(dk, q)
|
||||
|
||||
hits := make(map[string]float64)
|
||||
for _, p := range wpData {
|
||||
hits[fmt.Sprint(p)] = p.Distance(q)
|
||||
}
|
||||
|
||||
for _, p := range dk.Heap {
|
||||
var done bool
|
||||
if p.Comparable == nil {
|
||||
done = true
|
||||
continue
|
||||
}
|
||||
delete(hits, fmt.Sprint(p.Comparable))
|
||||
if done {
|
||||
t.Error("expectedly finished heap iteration")
|
||||
break
|
||||
}
|
||||
dist := p.Comparable.Distance(q)
|
||||
if dist > d {
|
||||
t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d)
|
||||
}
|
||||
}
|
||||
|
||||
for p, dist := range hits {
|
||||
if dist <= d {
|
||||
t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
tree := New(wpData, 3, rand.NewSource(1))
|
||||
var got []Point
|
||||
fn := func(c Comparable, _ int) (done bool) {
|
||||
got = append(got, c.(Point))
|
||||
return
|
||||
}
|
||||
killed := tree.Do(fn)
|
||||
|
||||
want := make([]Point, len(wpData))
|
||||
for i, p := range wpData {
|
||||
want[i] = p.(Point)
|
||||
}
|
||||
sort.Sort(lexical(got))
|
||||
sort.Sort(lexical(want))
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want)
|
||||
}
|
||||
if killed {
|
||||
t.Error("tree iteration unexpectedly killed")
|
||||
}
|
||||
}
|
||||
|
||||
type lexical []Point
|
||||
|
||||
func (c lexical) Len() int { return len(c) }
|
||||
func (c lexical) Less(i, j int) bool {
|
||||
a, b := c[i], c[j]
|
||||
l := len(a)
|
||||
if len(b) < l {
|
||||
l = len(b)
|
||||
}
|
||||
for k, v := range a[:l] {
|
||||
if v < b[k] {
|
||||
return true
|
||||
}
|
||||
if v > b[k] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(a) < len(b)
|
||||
}
|
||||
func (c lexical) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
|
||||
|
||||
// BenchmarkNew measures tree construction over 1e5 random points in
// three dimensions for a range of effort values.
func BenchmarkNew(b *testing.B) {
	for _, effort := range []int{0, 10, 100} {
		b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) {
			rnd := rand.New(rand.NewSource(1))
			p := make([]Comparable, 1e5)
			for i := range p {
				p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
			}
			// Exclude the data generation from the measurement.
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = New(p, effort, rand.NewSource(1))
			}
		})
	}
}
|
||||
|
||||
// Benchmark measures nearest neighbor queries against trees built from
// 1e5 random points in three dimensions for a range of effort values.
// Before the queries are timed each tree is checked against the
// vp-tree invariant and against the brute force nearest function. The
// brute force benchmarks are only run for an effort of zero.
func Benchmark(b *testing.B) {
	// r and d receive the results of the single nearest queries and
	// are checked after each timed loop.
	var r Comparable
	var d float64
	queryBenchmarks := []struct {
		name string
		fn   func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B)
	}{
		{
			name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					for i := 0; i < b.N; i++ {
						r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
					}
					if r == nil {
						b.Error("unexpected nil result")
					}
					if math.IsNaN(d) {
						b.Error("unexpected NaN result")
					}
				}
			},
		},
		{
			name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					var r []ComparableDist
					for i := 0; i < b.N; i++ {
						r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
					}
					if len(r) != 10 {
						b.Error("unexpected result length", len(r))
					}
				}
			},
		},
		{
			name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					for i := 0; i < b.N; i++ {
						r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
					}
					if r == nil {
						b.Error("unexpected nil result")
					}
					if math.IsNaN(d) {
						b.Error("unexpected NaN result")
					}
				}
			},
		},
		{
			name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					nk := NewNKeeper(10)
					for i := 0; i < b.N; i++ {
						tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
						if nk.Len() != 10 {
							b.Error("unexpected result length")
						}
						// Reset the keeper to a lone sentinel so that
						// it can be reused for the next query.
						nk.Heap = nk.Heap[:1]
						nk.Heap[0] = ComparableDist{Dist: inf}
					}
				}
			},
		},
	}

	for _, effort := range []int{0, 3, 10, 30, 100, 300} {
		rnd := rand.New(rand.NewSource(1))
		data := make([]Comparable, 1e5)
		for i := range data {
			data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
		}
		tree := New(data, effort, rand.NewSource(1))

		if !tree.Root.isVPTree() {
			b.Fatal("tree is not vantage point tree")
		}

		// Confirm that the tree gives correct answers before timing it.
		for i := 0; i < 1e3; i++ {
			q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
			gotP, gotD := tree.Nearest(q)
			wantP, wantD := nearest(q, data)
			if !reflect.DeepEqual(gotP, wantP) {
				b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
			}
			if gotD != wantD {
				b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD)
			}
		}

		if b.Failed() && *genDot && tree.Len() <= *dotLimit {
			err := dotFile(tree, "TestBenches", "")
			if err != nil {
				b.Fatalf("failed to write DOT file: %v", err)
			}
			return
		}

		for _, bench := range queryBenchmarks {
			if strings.Contains(bench.name, "Brute") && effort != 0 {
				continue
			}
			b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd))
		}
	}
}
|
||||
|
||||
// dot returns a DOT language description of the tree t as a digraph
// named label. Each node is rendered as a record showing its point and
// radius, with edges to its children labelled with the distance
// between the node and the child. Nodes are identified by their
// addresses. An empty string is returned if t is nil.
func dot(t *Tree, label string) string {
	if t == nil {
		return ""
	}
	var (
		s      []string
		follow func(*Node)
	)
	// follow appends the description of the subtree rooted at n to s.
	// The description of a node is appended after those of its
	// descendants.
	follow = func(n *Node) {
		id := uintptr(unsafe.Pointer(n))
		c := fmt.Sprintf("%d[label = \"<Closer> |<Elem> %.3f/%.3f|<Further>\"];",
			id, n.Point, n.Radius)
		if n.Closer != nil {
			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];",
				id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point))
			follow(n.Closer)
		}
		if n.Further != nil {
			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];",
				id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point))
			follow(n.Further)
		}
		s = append(s, c)
	}
	if t.Root != nil {
		follow(t.Root)
	}
	return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n",
		label,
		strings.Join(s, "\n\t"),
	)
}
|
||||
|
||||
func dotFile(t *Tree, label, dotString string) (err error) {
|
||||
if t == nil && dotString == "" {
|
||||
return
|
||||
}
|
||||
f, err := os.Create(label + ".dot")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
if dotString == "" {
|
||||
fmt.Fprintf(f, dot(t, label))
|
||||
} else {
|
||||
fmt.Fprintf(f, dotString)
|
||||
}
|
||||
return
|
||||
}
|
109
spatial/vptree/vptree_user_type_example_test.go
Normal file
109
spatial/vptree/vptree_user_type_example_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright ©2019 The Gonum Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vptree_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"gonum.org/v1/gonum/spatial/vptree"
|
||||
)
|
||||
|
||||
func Example_accessiblePublicTransport() {
	// Construct a vp tree of train station locations
	// to identify accessible public transport for the
	// elderly.
	t := vptree.New(stations, 5, nil)

	// Residence.
	q := place{lat: 51.501476, lon: -0.140634}

	var keep vptree.Keeper

	// Find all stations within 0.75 km of the residence.
	keep = vptree.NewDistKeeper(0.75)
	t.NearestSet(keep, q)

	fmt.Println(`Stations within 750 m of 51.501476N 0.140634W.`)
	for _, c := range keep.(*vptree.DistKeeper).Heap {
		p := c.Comparable.(place)
		fmt.Printf("%s: %0.3f km\n", p.name, p.Distance(q))
	}
	fmt.Println()

	// Find the five closest stations to the residence.
	keep = vptree.NewNKeeper(5)
	t.NearestSet(keep, q)

	fmt.Println(`5 closest stations to 51.501476N 0.140634W.`)
	for _, c := range keep.(*vptree.NKeeper).Heap {
		p := c.Comparable.(place)
		fmt.Printf("%s: %0.3f km\n", p.name, p.Distance(q))
	}

	// Output:
	//
	// Stations within 750 m of 51.501476N 0.140634W.
	// St. James's Park: 0.545 km
	// Green Park: 0.600 km
	// Victoria: 0.621 km
	//
	// 5 closest stations to 51.501476N 0.140634W.
	// St. James's Park: 0.545 km
	// Green Park: 0.600 km
	// Victoria: 0.621 km
	// Hyde Park Corner: 0.846 km
	// Picadilly Circus: 1.027 km
}
|
||||
|
||||
// stations is a list of railway stations. Coordinates are given in
// decimal degrees, with negative longitudes to the west.
var stations = []vptree.Comparable{
	place{name: "Bond Street", lat: 51.5142, lon: -0.1494},
	place{name: "Charing Cross", lat: 51.508, lon: -0.1247},
	place{name: "Covent Garden", lat: 51.5129, lon: -0.1243},
	place{name: "Embankment", lat: 51.5074, lon: -0.1223},
	place{name: "Green Park", lat: 51.5067, lon: -0.1428},
	place{name: "Hyde Park Corner", lat: 51.5027, lon: -0.1527},
	place{name: "Leicester Square", lat: 51.5113, lon: -0.1281},
	place{name: "Marble Arch", lat: 51.5136, lon: -0.1586},
	place{name: "Oxford Circus", lat: 51.515, lon: -0.1415},
	place{name: "Picadilly Circus", lat: 51.5098, lon: -0.1342},
	place{name: "Pimlico", lat: 51.4893, lon: -0.1334},
	place{name: "Sloane Square", lat: 51.4924, lon: -0.1565},
	place{name: "South Kensington", lat: 51.4941, lon: -0.1738},
	place{name: "St. James's Park", lat: 51.4994, lon: -0.1335},
	place{name: "Temple", lat: 51.5111, lon: -0.1141},
	place{name: "Tottenham Court Road", lat: 51.5165, lon: -0.131},
	place{name: "Vauxhall", lat: 51.4861, lon: -0.1253},
	place{name: "Victoria", lat: 51.4965, lon: -0.1447},
	place{name: "Waterloo", lat: 51.5036, lon: -0.1143},
	place{name: "Westminster", lat: 51.501, lon: -0.1254},
}
|
||||
|
||||
// place is a vptree.Comparable implementation describing a named
// geographic location.
type place struct {
	name string
	// lat and lon are the latitude and longitude in degrees.
	lat, lon float64
}
|
||||
|
||||
// Distance returns the distance between the receiver and c.
|
||||
func (p place) Distance(c vptree.Comparable) float64 {
|
||||
q := c.(place)
|
||||
return haversine(p.lat, p.lon, q.lat, q.lon)
|
||||
}
|
||||
|
||||
// haversine returns the great-circle distance in kilometres between
// two geographic coordinates given in degrees, using the haversine
// formula with an Earth radius of 6371 km.
func haversine(lat1, lon1, lat2, lon2 float64) float64 {
	const earthRadius = 6371 // km

	// Sines of half the differences in latitude and longitude.
	sinHalfLat := math.Sin(radians(lat2-lat1) / 2)
	sinHalfLon := math.Sin(radians(lon2-lon1) / 2)

	h := sinHalfLat*sinHalfLat + math.Cos(radians(lat1))*math.Cos(radians(lat2))*sinHalfLon*sinHalfLon
	return 2 * earthRadius * math.Asin(math.Sqrt(h)) // km
}

// radians converts the angle d from degrees to radians.
func radians(d float64) float64 {
	return d * math.Pi / 180
}
|
Reference in New Issue
Block a user