mirror of
https://github.com/gonum/gonum.git
synced 2025-09-27 03:26:04 +08:00
interp: increase speed of findSegment
Using slices.BinarySearch instead of sort.Search increases the speed of findSegment by a factor of two and overall performance by about 30%. goos: linux goarch: amd64 pkg: gonum.org/v1/gonum/interp cpu: AMD Ryzen 7 5800 8-Core Processor │ old.bench │ new.bench │ │ sec/op │ sec/op vs base │ FindSegment-16 104.60n ± 1% 50.78n ± 1% -51.45% (p=0.000 n=10) NewPiecewiseLinear-16 114.5n ± 5% 112.2n ± 2% ~ (p=0.109 n=10) PiecewiseLinearPredict-16 116.00n ± 1% 84.44n ± 2% -27.21% (p=0.000 n=10) PiecewiseConstantPredict-16 87.95n ± 2% 63.93n ± 1% -27.31% (p=0.000 n=10) geomean 105.2n 74.47n -29.18%
This commit is contained in:
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
package interp
|
package interp
|
||||||
|
|
||||||
import "sort"
|
import "slices"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
differentLengths = "interp: input slices have different lengths"
|
differentLengths = "interp: input slices have different lengths"
|
||||||
@@ -156,10 +156,13 @@ func (pc PiecewiseConstant) Predict(x float64) float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// findSegment returns 0 <= i < len(xs) such that xs[i] <= x < xs[i + 1], where xs[len(xs)]
|
// findSegment returns 0 <= i < len(xs) such that xs[i] <= x < xs[i + 1], where xs[len(xs)]
|
||||||
// is assumed to be +Inf. If no such i is found, it returns -1. It assumes that len(xs) >= 2
|
// is assumed to be +Inf. If no such i is found, it returns -1.
|
||||||
// without checking.
|
|
||||||
func findSegment(xs []float64, x float64) int {
|
func findSegment(xs []float64, x float64) int {
|
||||||
return sort.Search(len(xs), func(i int) bool { return xs[i] > x }) - 1
|
i, found := slices.BinarySearch(xs, x)
|
||||||
|
if !found {
|
||||||
|
return i - 1
|
||||||
|
}
|
||||||
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateSlopes calculates slopes (ys[i+1] - ys[i]) / (xs[i+1] - xs[i]).
|
// calculateSlopes calculates slopes (ys[i+1] - ys[i]) / (xs[i+1] - xs[i]).
|
||||||
|
@@ -50,6 +50,28 @@ func TestFindSegment(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindSegmentEdgeCases(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
xs []float64
|
||||||
|
x float64
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{xs: nil, x: 0, want: -1},
|
||||||
|
{xs: []float64{0}, x: -1, want: -1},
|
||||||
|
{xs: []float64{0}, x: 0, want: 0},
|
||||||
|
{xs: []float64{0}, x: 1, want: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range cases {
|
||||||
|
if got := findSegment(test.xs, test.x); got != test.want {
|
||||||
|
t.Errorf("unexpected value of findSegment(%v, %f): got %d want: %d",
|
||||||
|
test.xs, test.x, got, test.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkFindSegment(b *testing.B) {
|
func BenchmarkFindSegment(b *testing.B) {
|
||||||
xs := []float64{0, 1.5, 3, 4.5, 6, 7.5, 9, 12, 13.5, 16.5}
|
xs := []float64{0, 1.5, 3, 4.5, 6, 7.5, 9, 12, 13.5, 16.5}
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
Reference in New Issue
Block a user