From f48364e31d40fb9c3b6de7b7d20223edd6d49779 Mon Sep 17 00:00:00 2001 From: Tom Payne Date: Thu, 13 Jun 2024 22:38:53 +0200 Subject: [PATCH] interp: increase speed of findSegment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Using slices.BinarySearch instead of sort.Search increases the speed of findSegment by a factor of two and overall performance by about 30%. goos: linux goarch: amd64 pkg: gonum.org/v1/gonum/interp cpu: AMD Ryzen 7 5800 8-Core Processor │ old.bench │ new.bench │ │ sec/op │ sec/op vs base │ FindSegment-16 104.60n ± 1% 50.78n ± 1% -51.45% (p=0.000 n=10) NewPiecewiseLinear-16 114.5n ± 5% 112.2n ± 2% ~ (p=0.109 n=10) PiecewiseLinearPredict-16 116.00n ± 1% 84.44n ± 2% -27.21% (p=0.000 n=10) PiecewiseConstantPredict-16 87.95n ± 2% 63.93n ± 1% -27.31% (p=0.000 n=10) geomean 105.2n 74.47n -29.18% --- interp/interp.go | 11 +++++++---- interp/interp_test.go | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/interp/interp.go b/interp/interp.go index 5e75b810..be89ec97 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -4,7 +4,7 @@ package interp -import "sort" +import "slices" const ( differentLengths = "interp: input slices have different lengths" @@ -156,10 +156,13 @@ func (pc PiecewiseConstant) Predict(x float64) float64 { } // findSegment returns 0 <= i < len(xs) such that xs[i] <= x < xs[i + 1], where xs[len(xs)] -// is assumed to be +Inf. If no such i is found, it returns -1. It assumes that len(xs) >= 2 -// without checking. +// is assumed to be +Inf. If no such i is found, it returns -1. func findSegment(xs []float64, x float64) int { - return sort.Search(len(xs), func(i int) bool { return xs[i] > x }) - 1 + i, found := slices.BinarySearch(xs, x) + if !found { + return i - 1 + } + return i } // calculateSlopes calculates slopes (ys[i+1] - ys[i]) / (xs[i+1] - xs[i]). diff --git a/interp/interp_test.go b/interp/interp_test.go index 9ec506e7..c7b069a5 100644 --- a/interp/interp_test.go +++ b/interp/interp_test.go @@ -50,6 +50,28 @@ func TestFindSegment(t *testing.T) { } } +func TestFindSegmentEdgeCases(t *testing.T) { + t.Parallel() + + cases := []struct { + xs []float64 + x float64 + want int + }{ + {xs: nil, x: 0, want: -1}, + {xs: []float64{0}, x: -1, want: -1}, + {xs: []float64{0}, x: 0, want: 0}, + {xs: []float64{0}, x: 1, want: 0}, + } + + for _, test := range cases { + if got := findSegment(test.xs, test.x); got != test.want { + t.Errorf("unexpected value of findSegment(%v, %f): got %d want: %d", + test.xs, test.x, got, test.want) + } + } +} + func BenchmarkFindSegment(b *testing.B) { xs := []float64{0, 1.5, 3, 4.5, 6, 7.5, 9, 12, 13.5, 16.5} for i := 0; i < b.N; i++ {