From 047f2c2add5e169ea594d11fbbcbbba2ffb44aa9 Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Thu, 16 May 2019 07:34:28 +0930 Subject: [PATCH] spatial/vptree: new package for vantage point tree NN search --- spatial/vptree/doc.go | 10 + spatial/vptree/vptree.go | 374 ++++++++++++ spatial/vptree/vptree_simple_example_test.go | 60 ++ spatial/vptree/vptree_test.go | 530 ++++++++++++++++++ .../vptree/vptree_user_type_example_test.go | 109 ++++ 5 files changed, 1083 insertions(+) create mode 100644 spatial/vptree/doc.go create mode 100644 spatial/vptree/vptree.go create mode 100644 spatial/vptree/vptree_simple_example_test.go create mode 100644 spatial/vptree/vptree_test.go create mode 100644 spatial/vptree/vptree_user_type_example_test.go diff --git a/spatial/vptree/doc.go b/spatial/vptree/doc.go new file mode 100644 index 00000000..e9ccb7a3 --- /dev/null +++ b/spatial/vptree/doc.go @@ -0,0 +1,10 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package vptree implements a vantage point tree. Vantage point +// trees provide an efficient search for nearest neighbors in a +// metric space. +// +// See http://pnylab.com/papers/vptree/vptree.pdf for details of vp-trees. +package vptree // import "gonum.org/v1/gonum/spatial/vptree" diff --git a/spatial/vptree/vptree.go b/spatial/vptree/vptree.go new file mode 100644 index 00000000..afd05721 --- /dev/null +++ b/spatial/vptree/vptree.go @@ -0,0 +1,374 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package vptree + +import ( + "container/heap" + "math" + "sort" + + "golang.org/x/exp/rand" + + "gonum.org/v1/gonum/stat" +) + +// Comparable is the element interface for values stored in a vp-tree. +type Comparable interface { + // Distance returns the distance between the receiver and the + // parameter. The returned distance must satisfy the properties + // of distances in a metric space. + // + // - a.Distance(a) == 0 + // - a.Distance(b) >= 0 + // - a.Distance(b) == b.Distance(a) + // - a.Distance(b) <= a.Distance(c)+c.Distance(b) + // + Distance(Comparable) float64 +} + +// Point represents a point in a Euclidean k-d space that satisfies the Comparable +// interface. +type Point []float64 + +// Distance returns the Euclidean distance between c and the receiver. The concrete +// type of c must be Point. +func (p Point) Distance(c Comparable) float64 { + q := c.(Point) + var sum float64 + for dim, c := range p { + d := c - q[dim] + sum += d * d + } + return math.Sqrt(sum) +} + +// Node holds a single point value in a vantage point tree. +type Node struct { + Point Comparable + Radius float64 + Closer *Node + Further *Node +} + +// Tree implements a vantage point tree creation and nearest neighbor search. +type Tree struct { + Root *Node + Count int +} + +// New returns a vantage point tree constructed from the values in p. The effort +// parameter specifies how much work should be put into optimizing the choice of +// vantage point. If effort is one or less, random vantage points are chosen. +// The order of elements in p will be altered after New returns. The src parameter +// provides the source of randomness for vantage point selection. If src is nil +// global rand package functions are used. +func New(p []Comparable, effort int, src rand.Source) *Tree { + var intn func(int) int + var shuf func(n int, swap func(i, j int)) + if src == nil { + intn = rand.Intn + shuf = rand.Shuffle + } else { + rnd := rand.New(src) + intn = rnd.Intn + shuf = rnd.Shuffle + } + b := builder{work: make([]float64, len(p)), intn: intn, shuf: shuf} + return &Tree{ + Root: b.build(p, effort), + Count: len(p), + } +} + +// builder performs vp-tree construction as described for the simple vp-tree +// algorithm in http://pnylab.com/papers/vptree/vptree.pdf. +type builder struct { + work []float64 + intn func(n int) int + shuf func(n int, swap func(i, j int)) +} + +func (b *builder) build(s []Comparable, effort int) *Node { + if len(s) <= 1 { + if len(s) == 0 { + return nil + } + return &Node{Point: s[0]} + } + n := Node{Point: b.selectVantage(s, effort)} + radius, closer, further := b.partition(n.Point, s) + n.Radius = radius + n.Closer = b.build(closer, effort) + n.Further = b.build(further, effort) + return &n +} + +func (b *builder) selectVantage(s []Comparable, effort int) Comparable { + if effort <= 1 { + return s[b.intn(len(s))] + } + if effort > len(s) { + effort = len(s) + } + var best Comparable + var bestVar float64 + b.work = b.work[:effort] + choices := b.random(effort, s) + for _, p := range choices { + for i, q := range choices { + b.work[i] = p.Distance(q) + } + variance := stat.Variance(b.work, nil) + if variance > bestVar { + best, bestVar = p, variance + } + } + return best +} + +func (b *builder) random(n int, s []Comparable) []Comparable { + if n >= len(s) { + return s + } + b.shuf(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] }) + return s[:n] +} + +func (b *builder) partition(v Comparable, s []Comparable) (radius float64, closer, further []Comparable) { + b.work = b.work[:len(s)] + for i, p := range s { + b.work[i] = v.Distance(p) + } + sort.Sort(byDist{dists: b.work, points: s}) + + // Note that this does not conform exactly to the description + // in the paper which specifies d(p, s) < mu for L; in cases + // where the median element has a lower indexed element with + // the same distance from the vantage point, L will include a + // d(p, s) == mu. + // The additional work required to satisfy the algorithm is + // not worth doing as it has no effect on the correctness or + // performance of the algorithm. + radius = b.work[len(b.work)/2] + + if len(b.work) > 1 { + // Remove vantage if it is present. + closer = s[1 : len(b.work)/2] + } + further = s[len(b.work)/2:] + return radius, closer, further +} + +type byDist struct { + dists []float64 + points []Comparable +} + +func (c byDist) Len() int { return len(c.dists) } +func (c byDist) Less(i, j int) bool { return c.dists[i] < c.dists[j] } +func (c byDist) Swap(i, j int) { + c.dists[i], c.dists[j] = c.dists[j], c.dists[i] + c.points[i], c.points[j] = c.points[j], c.points[i] +} + +// Len returns the number of elements in the tree. +func (t *Tree) Len() int { return t.Count } + +var inf = math.Inf(1) + +// Nearest returns the nearest value to the query and the distance between them. +func (t *Tree) Nearest(q Comparable) (Comparable, float64) { + if t.Root == nil { + return nil, inf + } + n, dist := t.Root.search(q, inf) + if n == nil { + return nil, inf + } + return n.Point, dist +} + +func (n *Node) search(q Comparable, dist float64) (*Node, float64) { + if n == nil { + return nil, inf + } + + d := q.Distance(n.Point) + dist = math.Min(dist, d) + + bn := n + if d < n.Radius { + cn, cd := n.Closer.search(q, dist) + if cd < dist { + bn, dist = cn, cd + } + if d+dist >= n.Radius { + fn, fd := n.Further.search(q, dist) + if fd < dist { + bn, dist = fn, fd + } + } + } else { + fn, fd := n.Further.search(q, dist) + if fd < dist { + bn, dist = fn, fd + } + if d-dist <= n.Radius { + cn, cd := n.Closer.search(q, dist) + if cd < dist { + bn, dist = cn, cd + } + } + } + + return bn, dist +} + +// ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable +// is used to mark the end of the heap, so clients should not store nil values except for +// this purpose. +type ComparableDist struct { + Comparable Comparable + Dist float64 +} + +// Heap is a max heap sorted on Dist. +type Heap []ComparableDist + +func (h *Heap) Max() ComparableDist { return (*h)[0] } +func (h *Heap) Len() int { return len(*h) } +func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist } +func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } +func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) } +func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i } + +// NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep. +type NKeeper struct { + Heap +} + +// NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The +// returned NKeeper is able to retain at most n values. +func NewNKeeper(n int) *NKeeper { + k := NKeeper{make(Heap, 1, n)} + k.Heap[0].Dist = inf + return &k +} + +// Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding +// c would increase the size of the heap beyond the initial maximum length, the maximum value of +// the heap is dropped. +func (k *NKeeper) Keep(c ComparableDist) { + if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel. + if len(k.Heap) == cap(k.Heap) { + heap.Pop(k) + } + heap.Push(k, c) + } +} + +// DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the +// query that it is called to Keep. +type DistKeeper struct { + Heap +} + +// NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d. +func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} } + +// Keep adds c to the heap if its distance is less than or equal to the max value of the heap. +func (k *DistKeeper) Keep(c ComparableDist) { + if c.Dist <= k.Heap[0].Dist { + heap.Push(k, c) + } +} + +// Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type. +// vantage point search is guided by the distance stored in the max value of the heap. +type Keeper interface { + Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap. + Max() ComparableDist // Max returns the maximum element of the Keeper. + heap.Interface +} + +// NearestSet finds the nearest values to the query accepted by the provided Keeper, k. +// k must be able to return a ComparableDist specifying the maximum acceptable distance +// when Max() is called, and retains the results of the search in min sorted order after +// the call to NearestSet returns. +// If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the +// maximum distance, NearestSet will remove it before returning. +func (t *Tree) NearestSet(k Keeper, q Comparable) { + if t.Root == nil { + return + } + t.Root.searchSet(q, k) + + // Check whether we have retained a sentinel + // and flag removal if we have. + removeSentinel := k.Len() != 0 && k.Max().Comparable == nil + + sort.Sort(sort.Reverse(k)) + + // This abuses the interface to drop the max. + // It is reasonable to do this because we know + // that the maximum value will now be at element + // zero, which is removed by the Pop method. + if removeSentinel { + k.Pop() + } +} + +func (n *Node) searchSet(q Comparable, k Keeper) { + if n == nil { + return + } + + k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)}) + + d := q.Distance(n.Point) + if d < n.Radius { + n.Closer.searchSet(q, k) + if d+k.Max().Dist >= n.Radius { + n.Further.searchSet(q, k) + } + } else { + n.Further.searchSet(q, k) + if d-k.Max().Dist <= n.Radius { + n.Closer.searchSet(q, k) + } + } +} + +// Operation is a function that operates on a Comparable. The bounding volume and tree depth +// of the point is also provided. If done is returned true, the Operation is indicating that no +// further work needs to be done and so the Do function should traverse no further. +type Operation func(Comparable, int) (done bool) + +// Do performs fn on all values stored in the tree. A boolean is returned indicating whether the +// Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort +// relationships, future tree operation behaviors are undefined. +func (t *Tree) Do(fn Operation) bool { + if t.Root == nil { + return false + } + return t.Root.do(fn, 0) +} + +func (n *Node) do(fn Operation, depth int) (done bool) { + if n.Closer != nil { + done = n.Closer.do(fn, depth+1) + if done { + return + } + } + done = fn(n.Point, depth) + if done { + return + } + if n.Further != nil { + done = n.Further.do(fn, depth+1) + } + return +} diff --git a/spatial/vptree/vptree_simple_example_test.go b/spatial/vptree/vptree_simple_example_test.go new file mode 100644 index 00000000..1b296fc4 --- /dev/null +++ b/spatial/vptree/vptree_simple_example_test.go @@ -0,0 +1,60 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package vptree_test + +import ( + "fmt" + + "gonum.org/v1/gonum/spatial/vptree" +) + +func ExampleTree() { + // Example data from https://en.wikipedia.org/wiki/K-d_tree + points := []vptree.Comparable{ + vptree.Point{2, 3}, + vptree.Point{5, 4}, + vptree.Point{9, 6}, + vptree.Point{4, 7}, + vptree.Point{8, 1}, + vptree.Point{7, 2}, + } + + t := vptree.New(points, 3, nil) + q := vptree.Point{8, 7} + p, d := t.Nearest(q) + fmt.Printf("%v is closest point to %v, d=%f\n", p, q, d) + // Output: + // [9 6] is closest point to [8 7], d=1.414214 +} + +func ExampleTree_Do() { + // Example data from https://en.wikipedia.org/wiki/K-d_tree + points := []vptree.Comparable{ + vptree.Point{2, 3}, + vptree.Point{5, 4}, + vptree.Point{9, 6}, + vptree.Point{4, 7}, + vptree.Point{8, 1}, + vptree.Point{7, 2}, + } + + // Print all points in the data set within 3 of (3, 5). + t := vptree.New(points, 0, nil) + q := vptree.Point{3, 5} + t.Do(func(c vptree.Comparable, _ int) (done bool) { + // Compare each distance and output points + // with a Euclidean distance less than or + // equal to 3. Distance returns the + // Euclidean distance between points. + if q.Distance(c) <= 3 { + fmt.Println(c) + } + return + }) + // Unordered output: + // [2 3] + // [4 7] + // [5 4] +} diff --git a/spatial/vptree/vptree_test.go b/spatial/vptree/vptree_test.go new file mode 100644 index 00000000..e70b5e81 --- /dev/null +++ b/spatial/vptree/vptree_test.go @@ -0,0 +1,530 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package vptree + +import ( + "flag" + "fmt" + "math" + "os" + "reflect" + "sort" + "strings" + "testing" + "unsafe" + + "golang.org/x/exp/rand" +) + +var ( + genDot = flag.Bool("dot", false, "generate dot code for failing trees") + dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format") +) + +var ( + // Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572. + wpData = []Comparable{ + Point{2, 3}, + Point{5, 4}, + Point{9, 6}, + Point{4, 7}, + Point{8, 1}, + Point{7, 2}, + } +) + +var newTests = []struct { + data []Comparable + effort int +}{ + {data: wpData, effort: 0}, + {data: wpData, effort: 1}, + {data: wpData, effort: 2}, + {data: wpData, effort: 4}, + {data: wpData, effort: 8}, +} + +func TestNew(t *testing.T) { + for i, test := range newTests { + var tree *Tree + var panicked bool + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + tree = New(test.data, test.effort, rand.NewSource(1)) + }() + if panicked { + t.Errorf("unexpected panic for test %d", i) + continue + } + + if !tree.Root.isVPTree() { + t.Errorf("tree %d is not vp-tree", i) + } + + if t.Failed() && *genDot && tree.Len() <= *dotLimit { + err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "") + if err != nil { + t.Fatalf("failed to write DOT file: %v", err) + } + } + } +} + +type compFn func(v, radius float64) bool + +func closer(v, radius float64) bool { return v <= radius } +func further(v, radius float64) bool { return v >= radius } + +func (n *Node) isVPTree() bool { + if n == nil { + return true + } + if !n.Closer.isPartitioned(n.Point, closer, n.Radius) { + return false + } + if !n.Further.isPartitioned(n.Point, further, n.Radius) { + return false + } + return n.Closer.isVPTree() && n.Further.isVPTree() +} + +func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool { + if n == nil { + return true + } + if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) { + return false + } + if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) { + return false + } + return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius) +} + +func nearest(q Comparable, p []Comparable) (Comparable, float64) { + min := q.Distance(p[0]) + var r int + for i := 1; i < len(p); i++ { + d := q.Distance(p[i]) + if d < min { + min = d + r = i + } + } + return p[r], min +} + +func TestNearestRandom(t *testing.T) { + rnd := rand.New(rand.NewSource(1)) + + const ( + min = 0.0 + max = 1000.0 + + dims = 4 + setSize = 10000 + ) + + var randData []Comparable + for i := 0; i < setSize; i++ { + p := make(Point, dims) + for j := 0; j < dims; j++ { + p[j] = (max-min)*rnd.Float64() + min + } + randData = append(randData, p) + } + tree := New(randData, 10, rand.NewSource(1)) + + for i := 0; i < setSize; i++ { + q := make(Point, dims) + for j := 0; j < dims; j++ { + q[j] = (max-min)*rnd.Float64() + min + } + + got, _ := tree.Nearest(q) + want, _ := nearest(q, randData) + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want) + } + } +} + +func TestNearest(t *testing.T) { + tree := New(wpData, 3, rand.NewSource(1)) + for _, q := range append([]Comparable{ + Point{4, 6}, + // Point{7, 5}, // Omitted because it is ambiguously finds [9 6] or [5 4]. + Point{8, 7}, + Point{6, -5}, + Point{1e5, 1e5}, + Point{1e5, -1e5}, + Point{-1e5, 1e5}, + Point{-1e5, -1e5}, + Point{1e5, 0}, + Point{0, -1e5}, + Point{0, 1e5}, + Point{-1e5, 0}, + }, wpData...) { + gotP, gotD := tree.Nearest(q) + wantP, wantD := nearest(q, wpData) + if !reflect.DeepEqual(gotP, wantP) { + t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) + } + if gotD != wantD { + t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) + } + } +} + +func nearestN(n int, q Comparable, p []Comparable) []ComparableDist { + nk := NewNKeeper(n) + for i := 0; i < len(p); i++ { + nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) + } + if len(nk.Heap) == 1 { + return nk.Heap + } + sort.Sort(nk) + for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 { + nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i] + } + return nk.Heap +} + +func TestNearestSetN(t *testing.T) { + data := append([]Comparable{ + Point{4, 6}, + Point{7, 5}, // OK here because we collect N. + Point{8, 7}, + Point{6, -5}, + Point{1e5, 1e5}, + Point{1e5, -1e5}, + Point{-1e5, 1e5}, + Point{-1e5, -1e5}, + Point{1e5, 0}, + Point{0, -1e5}, + Point{0, 1e5}, + Point{-1e5, 0}}, + wpData[:len(wpData)-1]...) + + tree := New(wpData, 3, rand.NewSource(1)) + for k := 1; k <= len(wpData); k++ { + for _, q := range data { + wantP := nearestN(k, q, wpData) + + nk := NewNKeeper(k) + tree.NearestSet(nk, q) + + var max float64 + wantD := make(map[float64]map[string]struct{}) + for _, p := range wantP { + if p.Dist > max { + max = p.Dist + } + d, ok := wantD[p.Dist] + if !ok { + d = make(map[string]struct{}) + } + d[fmt.Sprint(p.Comparable)] = struct{}{} + wantD[p.Dist] = d + } + gotD := make(map[float64]map[string]struct{}) + for _, p := range nk.Heap { + if p.Dist > max { + t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max) + } + d, ok := gotD[p.Dist] + if !ok { + d = make(map[string]struct{}) + } + d[fmt.Sprint(p.Comparable)] = struct{}{} + gotD[p.Dist] = d + } + + // If the available number of slots does not fit all the coequal furthest points + // we will fail the check. So remove, but check them minimally here. + if !reflect.DeepEqual(wantD[max], gotD[max]) { + // The best we can do at this stage is confirm that there are an equal number of matches at this distance. + if len(gotD[max]) != len(wantD[max]) { + t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max])) + } + delete(wantD, max) + delete(gotD, max) + } + + if !reflect.DeepEqual(gotD, wantD) { + t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD) + } + } + } +} + +var nearestSetDistTests = []Point{ + {4, 6}, + {7, 5}, + {8, 7}, + {6, -5}, +} + +func TestNearestSetDist(t *testing.T) { + tree := New(wpData, 3, rand.NewSource(1)) + for i, q := range nearestSetDistTests { + for d := 1.0; d < 100; d += 0.1 { + dk := NewDistKeeper(d) + tree.NearestSet(dk, q) + + hits := make(map[string]float64) + for _, p := range wpData { + hits[fmt.Sprint(p)] = p.Distance(q) + } + + for _, p := range dk.Heap { + var done bool + if p.Comparable == nil { + done = true + continue + } + delete(hits, fmt.Sprint(p.Comparable)) + if done { + t.Error("expectedly finished heap iteration") + break + } + dist := p.Comparable.Distance(q) + if dist > d { + t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d) + } + } + + for p, dist := range hits { + if dist <= d { + t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d) + } + } + } + } +} + +func TestDo(t *testing.T) { + tree := New(wpData, 3, rand.NewSource(1)) + var got []Point + fn := func(c Comparable, _ int) (done bool) { + got = append(got, c.(Point)) + return + } + killed := tree.Do(fn) + + want := make([]Point, len(wpData)) + for i, p := range wpData { + want[i] = p.(Point) + } + sort.Sort(lexical(got)) + sort.Sort(lexical(want)) + + if !reflect.DeepEqual(got, want) { + t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want) + } + if killed { + t.Error("tree iteration unexpectedly killed") + } +} + +type lexical []Point + +func (c lexical) Len() int { return len(c) } +func (c lexical) Less(i, j int) bool { + a, b := c[i], c[j] + l := len(a) + if len(b) < l { + l = len(b) + } + for k, v := range a[:l] { + if v < b[k] { + return true + } + if v > b[k] { + return false + } + } + return len(a) < len(b) +} +func (c lexical) Swap(i, j int) { c[i], c[j] = c[j], c[i] } + +func BenchmarkNew(b *testing.B) { + for _, effort := range []int{0, 10, 100} { + b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) { + rnd := rand.New(rand.NewSource(1)) + p := make([]Comparable, 1e5) + for i := range p { + p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = New(p, effort, rand.NewSource(1)) + } + }) + } +} + +func Benchmark(b *testing.B) { + var r Comparable + var d float64 + queryBenchmarks := []struct { + name string + fn func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B) + }{ + { + name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { + return func(b *testing.B) { + for i := 0; i < b.N; i++ { + r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) + } + if r == nil { + b.Error("unexpected nil result") + } + if math.IsNaN(d) { + b.Error("unexpected NaN result") + } + } + }, + }, + { + name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { + return func(b *testing.B) { + var r []ComparableDist + for i := 0; i < b.N; i++ { + r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) + } + if len(r) != 10 { + b.Error("unexpected result length", len(r)) + } + } + }, + }, + { + name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { + return func(b *testing.B) { + for i := 0; i < b.N; i++ { + r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) + } + if r == nil { + b.Error("unexpected nil result") + } + if math.IsNaN(d) { + b.Error("unexpected NaN result") + } + } + }, + }, + { + name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { + return func(b *testing.B) { + nk := NewNKeeper(10) + for i := 0; i < b.N; i++ { + tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) + if nk.Len() != 10 { + b.Error("unexpected result length") + } + nk.Heap = nk.Heap[:1] + nk.Heap[0] = ComparableDist{Dist: inf} + } + } + }, + }, + } + + for _, effort := range []int{0, 3, 10, 30, 100, 300} { + rnd := rand.New(rand.NewSource(1)) + data := make([]Comparable, 1e5) + for i := range data { + data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} + } + tree := New(data, effort, rand.NewSource(1)) + + if !tree.Root.isVPTree() { + b.Fatal("tree is not vantage point tree") + } + + for i := 0; i < 1e3; i++ { + q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} + gotP, gotD := tree.Nearest(q) + wantP, wantD := nearest(q, data) + if !reflect.DeepEqual(gotP, wantP) { + b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) + } + if gotD != wantD { + b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD) + } + } + + if b.Failed() && *genDot && tree.Len() <= *dotLimit { + err := dotFile(tree, "TestBenches", "") + if err != nil { + b.Fatalf("failed to write DOT file: %v", err) + } + return + } + + for _, bench := range queryBenchmarks { + if strings.Contains(bench.name, "Brute") && effort != 0 { + continue + } + b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd)) + } + } +} + +func dot(t *Tree, label string) string { + if t == nil { + return "" + } + var ( + s []string + follow func(*Node) + ) + follow = func(n *Node) { + id := uintptr(unsafe.Pointer(n)) + c := fmt.Sprintf("%d[label = \" | %.3f/%.3f|\"];", + id, n.Point, n.Radius) + if n.Closer != nil { + c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];", + id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point)) + follow(n.Closer) + } + if n.Further != nil { + c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];", + id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point)) + follow(n.Further) + } + s = append(s, c) + } + if t.Root != nil { + follow(t.Root) + } + return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", + label, + strings.Join(s, "\n\t"), + ) +} + +func dotFile(t *Tree, label, dotString string) (err error) { + if t == nil && dotString == "" { + return + } + f, err := os.Create(label + ".dot") + if err != nil { + return + } + defer f.Close() + if dotString == "" { + fmt.Fprintf(f, dot(t, label)) + } else { + fmt.Fprintf(f, dotString) + } + return +} diff --git a/spatial/vptree/vptree_user_type_example_test.go b/spatial/vptree/vptree_user_type_example_test.go new file mode 100644 index 00000000..9c9e70b6 --- /dev/null +++ b/spatial/vptree/vptree_user_type_example_test.go @@ -0,0 +1,109 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package vptree_test + +import ( + "fmt" + "math" + + "gonum.org/v1/gonum/spatial/vptree" +) + +func Example_accessiblePublicTransport() { + // Construct a vp tree of train station locations + // to identify accessible public transport for the + // elderly. + t := vptree.New(stations, 5, nil) + + // Residence. + q := place{lat: 51.501476, lon: -0.140634} + + var keep vptree.Keeper + + // Find all stations within 0.75 of the residence. + keep = vptree.NewDistKeeper(0.75) + t.NearestSet(keep, q) + + fmt.Println(`Stations within 750 m of 51.501476N 0.140634W.`) + for _, c := range keep.(*vptree.DistKeeper).Heap { + p := c.Comparable.(place) + fmt.Printf("%s: %0.3f km\n", p.name, p.Distance(q)) + } + fmt.Println() + + // Find the five closest stations to the residence. + keep = vptree.NewNKeeper(5) + t.NearestSet(keep, q) + + fmt.Println(`5 closest stations to 51.501476N 0.140634W.`) + for _, c := range keep.(*vptree.NKeeper).Heap { + p := c.Comparable.(place) + fmt.Printf("%s: %0.3f km\n", p.name, p.Distance(q)) + } + + // Output: + // + // Stations within 750 m of 51.501476N 0.140634W. + // St. James's Park: 0.545 km + // Green Park: 0.600 km + // Victoria: 0.621 km + // + // 5 closest stations to 51.501476N 0.140634W. + // St. James's Park: 0.545 km + // Green Park: 0.600 km + // Victoria: 0.621 km + // Hyde Park Corner: 0.846 km + // Picadilly Circus: 1.027 km +} + +// stations is a list of railways stations. +var stations = []vptree.Comparable{ + place{name: "Bond Street", lat: 51.5142, lon: -0.1494}, + place{name: "Charing Cross", lat: 51.508, lon: -0.1247}, + place{name: "Covent Garden", lat: 51.5129, lon: -0.1243}, + place{name: "Embankment", lat: 51.5074, lon: -0.1223}, + place{name: "Green Park", lat: 51.5067, lon: -0.1428}, + place{name: "Hyde Park Corner", lat: 51.5027, lon: -0.1527}, + place{name: "Leicester Square", lat: 51.5113, lon: -0.1281}, + place{name: "Marble Arch", lat: 51.5136, lon: -0.1586}, + place{name: "Oxford Circus", lat: 51.515, lon: -0.1415}, + place{name: "Picadilly Circus", lat: 51.5098, lon: -0.1342}, + place{name: "Pimlico", lat: 51.4893, lon: -0.1334}, + place{name: "Sloane Square", lat: 51.4924, lon: -0.1565}, + place{name: "South Kensington", lat: 51.4941, lon: -0.1738}, + place{name: "St. James's Park", lat: 51.4994, lon: -0.1335}, + place{name: "Temple", lat: 51.5111, lon: -0.1141}, + place{name: "Tottenham Court Road", lat: 51.5165, lon: -0.131}, + place{name: "Vauxhall", lat: 51.4861, lon: -0.1253}, + place{name: "Victoria", lat: 51.4965, lon: -0.1447}, + place{name: "Waterloo", lat: 51.5036, lon: -0.1143}, + place{name: "Westminster", lat: 51.501, lon: -0.1254}, +} + +// place is a vptree.Comparable implementations. +type place struct { + name string + lat, lon float64 +} + +// Distance returns the distance between the receiver and c. +func (p place) Distance(c vptree.Comparable) float64 { + q := c.(place) + return haversine(p.lat, p.lon, q.lat, q.lon) +} + +// haversine returns the distance between two geographic coordinates. +func haversine(lat1, lon1, lat2, lon2 float64) float64 { + const r = 6371 // km + sdLat := math.Sin(radians(lat2-lat1) / 2) + sdLon := math.Sin(radians(lon2-lon1) / 2) + a := sdLat*sdLat + math.Cos(radians(lat1))*math.Cos(radians(lat2))*sdLon*sdLon + d := 2 * r * math.Asin(math.Sqrt(a)) + return d // km +} + +func radians(d float64) float64 { + return d * math.Pi / 180 +}