diff --git a/interp/interp.go b/interp/interp.go index 5e75b810..be89ec97 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -4,7 +4,7 @@ package interp -import "sort" +import "slices" const ( differentLengths = "interp: input slices have different lengths" @@ -156,10 +156,13 @@ func (pc PiecewiseConstant) Predict(x float64) float64 { } // findSegment returns 0 <= i < len(xs) such that xs[i] <= x < xs[i + 1], where xs[len(xs)] -// is assumed to be +Inf. If no such i is found, it returns -1. It assumes that len(xs) >= 2 -// without checking. +// is assumed to be +Inf. If no such i is found, it returns -1. func findSegment(xs []float64, x float64) int { - return sort.Search(len(xs), func(i int) bool { return xs[i] > x }) - 1 + i, found := slices.BinarySearch(xs, x) + if !found { + return i - 1 + } + return i } // calculateSlopes calculates slopes (ys[i+1] - ys[i]) / (xs[i+1] - xs[i]). diff --git a/interp/interp_test.go b/interp/interp_test.go index 9ec506e7..c7b069a5 100644 --- a/interp/interp_test.go +++ b/interp/interp_test.go @@ -50,6 +50,28 @@ func TestFindSegment(t *testing.T) { } } +func TestFindSegmentEdgeCases(t *testing.T) { + t.Parallel() + + cases := []struct { + xs []float64 + x float64 + want int + }{ + {xs: nil, x: 0, want: -1}, + {xs: []float64{0}, x: -1, want: -1}, + {xs: []float64{0}, x: 0, want: 0}, + {xs: []float64{0}, x: 1, want: 0}, + } + + for _, test := range cases { + if got := findSegment(test.xs, test.x); got != test.want { + t.Errorf("unexpected value of findSegment(%v, %f): got %d want: %d", + test.xs, test.x, got, test.want) + } + } +} + func BenchmarkFindSegment(b *testing.B) { xs := []float64{0, 1.5, 3, 4.5, 6, 7.5, 9, 12, 13.5, 16.5} for i := 0; i < b.N; i++ {