diff --git a/stat/combin/combin.go b/stat/combin/combin.go index c15540f6..71145ac0 100644 --- a/stat/combin/combin.go +++ b/stat/combin/combin.go @@ -6,8 +6,6 @@ package combin import ( "math" - - "gonum.org/v1/gonum/mat" ) const ( @@ -185,49 +183,34 @@ func nextCombination(s []int, n, k int) { } } -// Cartesian returns the cartesian product of the slices in data. The Cartesian +// Cartesian returns indices into the cartesian product for sets of the given lengths. The Cartesian // product of two sets is the set of all combinations of the items. For example, // given the input -// [][]float64{{1,2},{3,4},{5,6}} +// []int{2, 3, 1} // the returned matrix will be -// [ 1 3 5 ] -// [ 1 3 6 ] -// [ 1 4 5 ] -// [ 1 4 6 ] -// [ 2 3 5 ] -// [ 2 3 6 ] -// [ 2 4 5 ] -// [ 2 4 6 ] -// If dst is nil, a new matrix will be allocated and returned, otherwise the number -// of rows of dst must equal \prod_i len(data[i]), and the number of columns in -// dst must equal len(data). Cartesian also panics if len(data) = 0. -func Cartesian(dst *mat.Dense, data [][]float64) *mat.Dense { - if len(data) == 0 { - panic("combin: empty data input") +// [ 0 0 0 ] +// [ 0 1 0 ] +// [ 0 2 0 ] +// [ 1 0 0 ] +// [ 1 1 0 ] +// [ 1 2 0 ] +// Cartesian panics if any of the provided lengths are less than 1. +func Cartesian(lens []int) [][]int { + if len(lens) == 0 { + panic("combin: empty lengths") } - cols := len(data) rows := 1 - lens := make([]int, cols) - for i, d := range data { - v := len(d) - lens[i] = v + for _, v := range lens { + if v < 1 { + panic("combin: length less than zero") + } rows *= v } - if dst == nil { - dst = mat.NewDense(rows, cols, nil) - } - r, c := dst.Dims() - if r != rows || c != cols { - panic("combin: destination matrix size mismatch") - } - idxs := make([]int, cols) + out := make([][]int, rows) for i := 0; i < rows; i++ { - SubFor(idxs, i, lens) - for j := 0; j < len(data); j++ { - dst.Set(i, j, data[j][idxs[j]]) - } + out[i] = SubFor(nil, i, lens) } - return dst + return out } // IdxFor converts a multi-dimensional index into a linear index for a diff --git a/stat/combin/combin_test.go b/stat/combin/combin_test.go index ac41a3c1..ef884cb4 100644 --- a/stat/combin/combin_test.go +++ b/stat/combin/combin_test.go @@ -10,7 +10,6 @@ import ( "testing" "gonum.org/v1/gonum/floats" - "gonum.org/v1/gonum/mat" ) // intSosMatch returns true if the two slices of slices are equal. @@ -184,62 +183,36 @@ func TestCombinationGenerator(t *testing.T) { func TestCartesian(t *testing.T) { // First, test with a known return. - data := [][]float64{ - {1, 2}, - {3, 4}, - {5, 6}, + lens := []int{2, 3, 4} + want := [][]int{ + {0, 0, 0}, + {0, 0, 1}, + {0, 0, 2}, + {0, 0, 3}, + {0, 1, 0}, + {0, 1, 1}, + {0, 1, 2}, + {0, 1, 3}, + {0, 2, 0}, + {0, 2, 1}, + {0, 2, 2}, + {0, 2, 3}, + {1, 0, 0}, + {1, 0, 1}, + {1, 0, 2}, + {1, 0, 3}, + {1, 1, 0}, + {1, 1, 1}, + {1, 1, 2}, + {1, 1, 3}, + {1, 2, 0}, + {1, 2, 1}, + {1, 2, 2}, + {1, 2, 3}, } - want := mat.NewDense(8, 3, []float64{ - 1, 3, 5, - 1, 3, 6, - 1, 4, 5, - 1, 4, 6, - 2, 3, 5, - 2, 3, 6, - 2, 4, 5, - 2, 4, 6, - }) - got := Cartesian(nil, data) - if !mat.Equal(want, got) { - t.Errorf("cartesian data mismatch.\nwant:\n%v\ngot:\n%v", mat.Formatted(want), mat.Formatted(got)) - } - gotTo := mat.NewDense(8, 3, nil) - Cartesian(gotTo, data) - if !mat.Equal(want, got) { - t.Errorf("cartesian data mismatch with supplied.\nwant:\n%v\ngot:\n%v", mat.Formatted(want), mat.Formatted(gotTo)) - } - - // Test that Cartesian generates unique vectors. - for cas, data := range [][][]float64{ - {{1}, {2, 3}, {8, 9, 10}}, - {{1, 10}, {2, 3}, {8, 9, 10}}, - {{1, 10, 11}, {2, 3}, {8}}, - } { - cart := Cartesian(nil, data) - r, c := cart.Dims() - if c != len(data) { - t.Errorf("Case %v: wrong number of columns. Want %v, got %v", cas, len(data), c) - } - wantRows := 1 - for _, v := range data { - wantRows *= len(v) - } - if r != wantRows { - t.Errorf("Case %v: wrong number of rows. Want %v, got %v", cas, wantRows, r) - } - for i := 0; i < r; i++ { - for j := i + 1; j < r; j++ { - if floats.Equal(cart.RawRowView(i), cart.RawRowView(j)) { - t.Errorf("Cas %v: rows %d and %d are equal", cas, i, j) - } - } - } - - cartTo := mat.NewDense(r, c, nil) - Cartesian(cartTo, data) - if !mat.Equal(cart, cartTo) { - t.Errorf("cartesian data mismatch with supplied.\nwant:\n%v\ngot:\n%v", mat.Formatted(cart), mat.Formatted(cartTo)) - } + got := Cartesian(lens) + if !intSosMatch(want, got) { + t.Errorf("cartesian data mismatch.\nwant:\n%v\ngot:\n%v", want, got) } }