diff --git a/mat/basictypes_test.go b/mat/basictypes_test.go new file mode 100644 index 00000000..a881f0f2 --- /dev/null +++ b/mat/basictypes_test.go @@ -0,0 +1,114 @@ +// Copyright ©2015 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mat + +import "gonum.org/v1/gonum/blas/blas64" + +func asBasicMatrix(d *Dense) *basicMatrix { return (*basicMatrix)(d) } +func asBasicVector(d *VecDense) *basicVector { return (*basicVector)(d) } +func asBasicSymmetric(s *SymDense) *basicSymmetric { return (*basicSymmetric)(s) } +func asBasicTriangular(t *TriDense) *basicTriangular { return (*basicTriangular)(t) } +func asBasicBanded(b *BandDense) *basicBanded { return (*basicBanded)(b) } +func asBasicSymBanded(s *SymBandDense) *basicSymBanded { return (*basicSymBanded)(s) } +func asBasicTriBanded(t *TriBandDense) *basicTriBanded { return (*basicTriBanded)(t) } +func asBasicDiagonal(d *DiagDense) *basicDiagonal { return (*basicDiagonal)(d) } + +type basicMatrix Dense + +var _ Matrix = &basicMatrix{} + +func (m *basicMatrix) At(r, c int) float64 { return (*Dense)(m).At(r, c) } +func (m *basicMatrix) Dims() (r, c int) { return (*Dense)(m).Dims() } +func (m *basicMatrix) T() Matrix { return Transpose{m} } + +type basicVector VecDense + +var _ Vector = &basicVector{} + +func (v *basicVector) At(r, c int) float64 { return (*VecDense)(v).At(r, c) } +func (v *basicVector) Dims() (r, c int) { return (*VecDense)(v).Dims() } +func (v *basicVector) T() Matrix { return Transpose{v} } +func (v *basicVector) AtVec(i int) float64 { return (*VecDense)(v).AtVec(i) } +func (v *basicVector) Len() int { return (*VecDense)(v).Len() } + +type rawVector struct { + *basicVector +} + +func (v *rawVector) RawVector() blas64.Vector { + return v.basicVector.mat +} + +type basicSymmetric SymDense + +var _ Symmetric = &basicSymmetric{} + +func (m *basicSymmetric) At(r, c int) float64 { return (*SymDense)(m).At(r, c) } +func (m *basicSymmetric) Dims() (r, c int) { return (*SymDense)(m).Dims() } +func (m *basicSymmetric) T() Matrix { return m } +func (m *basicSymmetric) Symmetric() int { return (*SymDense)(m).Symmetric() } + +type basicTriangular TriDense + +var _ Triangular = &basicTriangular{} + +func (m *basicTriangular) At(r, c int) float64 { return (*TriDense)(m).At(r, c) } +func (m *basicTriangular) Dims() (r, c int) { return (*TriDense)(m).Dims() } +func (m *basicTriangular) T() Matrix { return Transpose{m} } +func (m *basicTriangular) Triangle() (int, TriKind) { return (*TriDense)(m).Triangle() } +func (m *basicTriangular) TTri() Triangular { return TransposeTri{m} } + +type basicBanded BandDense + +var _ Banded = &basicBanded{} + +func (m *basicBanded) At(r, c int) float64 { return (*BandDense)(m).At(r, c) } +func (m *basicBanded) Dims() (r, c int) { return (*BandDense)(m).Dims() } +func (m *basicBanded) T() Matrix { return Transpose{m} } +func (m *basicBanded) Bandwidth() (kl, ku int) { return (*BandDense)(m).Bandwidth() } +func (m *basicBanded) TBand() Banded { return TransposeBand{m} } + +type basicSymBanded SymBandDense + +var _ SymBanded = &basicSymBanded{} + +func (m *basicSymBanded) At(r, c int) float64 { return (*SymBandDense)(m).At(r, c) } +func (m *basicSymBanded) Dims() (r, c int) { return (*SymBandDense)(m).Dims() } +func (m *basicSymBanded) T() Matrix { return m } +func (m *basicSymBanded) Bandwidth() (kl, ku int) { return (*SymBandDense)(m).Bandwidth() } +func (m *basicSymBanded) TBand() Banded { return m } +func (m *basicSymBanded) Symmetric() int { return (*SymBandDense)(m).Symmetric() } +func (m *basicSymBanded) SymBand() (n, k int) { return (*SymBandDense)(m).SymBand() } + +type basicTriBanded TriBandDense + +var _ TriBanded = &basicTriBanded{} + +func (m *basicTriBanded) At(r, c int) float64 { return (*TriBandDense)(m).At(r, c) } +func (m *basicTriBanded) Dims() (r, c int) { return (*TriBandDense)(m).Dims() } +func (m *basicTriBanded) T() Matrix { return Transpose{m} } +func (m *basicTriBanded) Triangle() (int, TriKind) { return (*TriBandDense)(m).Triangle() } +func (m *basicTriBanded) TTri() Triangular { return TransposeTri{m} } +func (m *basicTriBanded) Bandwidth() (kl, ku int) { return (*TriBandDense)(m).Bandwidth() } +func (m *basicTriBanded) TBand() Banded { return TransposeBand{m} } +func (m *basicTriBanded) TriBand() (n, k int, kind TriKind) { return (*TriBandDense)(m).TriBand() } +func (m *basicTriBanded) TTriBand() TriBanded { return TransposeTriBand{m} } + +type basicDiagonal DiagDense + +var _ Diagonal = &basicDiagonal{} + +func (m *basicDiagonal) At(r, c int) float64 { return (*DiagDense)(m).At(r, c) } +func (m *basicDiagonal) Dims() (r, c int) { return (*DiagDense)(m).Dims() } +func (m *basicDiagonal) T() Matrix { return Transpose{m} } +func (m *basicDiagonal) Diag() int { return (*DiagDense)(m).Diag() } +func (m *basicDiagonal) Symmetric() int { return (*DiagDense)(m).Symmetric() } +func (m *basicDiagonal) SymBand() (n, k int) { return (*DiagDense)(m).SymBand() } +func (m *basicDiagonal) Bandwidth() (kl, ku int) { return (*DiagDense)(m).Bandwidth() } +func (m *basicDiagonal) TBand() Banded { return TransposeBand{m} } +func (m *basicDiagonal) Triangle() (int, TriKind) { return (*DiagDense)(m).Triangle() } +func (m *basicDiagonal) TTri() Triangular { return TransposeTri{m} } +func (m *basicDiagonal) TriBand() (n, k int, kind TriKind) { return (*DiagDense)(m).TriBand() } +func (m *basicDiagonal) TTriBand() TriBanded { return TransposeTriBand{m} } diff --git a/mat/dense_test.go b/mat/dense_test.go index 739156fe..bf9a983c 100644 --- a/mat/dense_test.go +++ b/mat/dense_test.go @@ -17,15 +17,6 @@ import ( "gonum.org/v1/gonum/floats" ) -func asBasicMatrix(d *Dense) Matrix { return (*basicMatrix)(d) } -func asBasicSymmetric(s *SymDense) Symmetric { return (*basicSymmetric)(s) } -func asBasicTriangular(t *TriDense) Triangular { return (*basicTriangular)(t) } -func asBasicBanded(b *BandDense) Banded { return (*basicBanded)(b) } -func asBasicSymBanded(s *SymBandDense) SymBanded { return (*basicSymBanded)(s) } -func asBasicTriBanded(t *TriBandDense) TriBanded { return (*basicTriBanded)(t) } -func asBasicDiagonal(d *DiagDense) Diagonal { return (*basicDiagonal)(d) } -func asBasicVector(d *VecDense) Vector { return (*basicVector)(d) } - func TestNewDense(t *testing.T) { t.Parallel() for i, test := range []struct { diff --git a/mat/matrix_test.go b/mat/matrix_test.go index 70c8cf85..3d0dee09 100644 --- a/mat/matrix_test.go +++ b/mat/matrix_test.go @@ -740,7 +740,7 @@ func TestMulVecToer(t *testing.T) { } x = dst case 2: - x = &rawVector{(*basicVector)(NewVecDense(n, random(n)))} + x = &rawVector{asBasicVector(NewVecDense(n, random(n)))} case 3: x = asBasicVector(NewVecDense(n, random(n))) default: diff --git a/mat/mul_test.go b/mat/mul_test.go index 78f7deab..5897409c 100644 --- a/mat/mul_test.go +++ b/mat/mul_test.go @@ -202,232 +202,6 @@ func testMul(t *testing.T, a, b Matrix, c *Dense, acomp, bcomp, ccomp matComp, c } } -type basicMatrix Dense - -var _ Matrix = &basicMatrix{} - -func (m *basicMatrix) At(r, c int) float64 { - return (*Dense)(m).At(r, c) -} - -func (m *basicMatrix) Dims() (r, c int) { - return (*Dense)(m).Dims() -} - -func (m *basicMatrix) T() Matrix { - return Transpose{m} -} - -type basicVector VecDense - -var _ Vector = &basicVector{} - -func (v *basicVector) At(r, c int) float64 { return (*VecDense)(v).At(r, c) } -func (v *basicVector) AtVec(i int) float64 { return (*VecDense)(v).AtVec(i) } -func (v *basicVector) Dims() (r, c int) { return (*VecDense)(v).Dims() } -func (v *basicVector) Len() int { return (*VecDense)(v).Len() } -func (v *basicVector) T() Matrix { return Transpose{v} } - -type rawVector struct { - *basicVector -} - -func (v *rawVector) RawVector() blas64.Vector { - return v.basicVector.mat -} - -type basicBanded BandDense - -var _ Banded = &basicBanded{} - -func (m *basicBanded) At(r, c int) float64 { - return (*BandDense)(m).At(r, c) -} - -func (m *basicBanded) Dims() (r, c int) { - return (*BandDense)(m).Dims() -} - -func (m *basicBanded) Bandwidth() (kl, ku int) { - return (*BandDense)(m).Bandwidth() -} - -func (m *basicBanded) T() Matrix { - return Transpose{m} -} - -func (m *basicBanded) TBand() Banded { - return TransposeBand{m} -} - -type basicSymmetric SymDense - -var _ Symmetric = &basicSymmetric{} - -func (m *basicSymmetric) At(r, c int) float64 { - return (*SymDense)(m).At(r, c) -} - -func (m *basicSymmetric) Dims() (r, c int) { - return (*SymDense)(m).Dims() -} - -func (m *basicSymmetric) T() Matrix { - return m -} - -func (m *basicSymmetric) Symmetric() int { - return (*SymDense)(m).Symmetric() -} - -type basicTriangular TriDense - -var _ Triangular = &basicTriangular{} - -func (m *basicTriangular) At(r, c int) float64 { - return (*TriDense)(m).At(r, c) -} - -func (m *basicTriangular) Dims() (r, c int) { - return (*TriDense)(m).Dims() -} - -func (m *basicTriangular) T() Matrix { - return Transpose{m} -} - -func (m *basicTriangular) Triangle() (int, TriKind) { - return (*TriDense)(m).Triangle() -} - -func (m *basicTriangular) TTri() Triangular { - return TransposeTri{m} -} - -type basicSymBanded SymBandDense - -var _ SymBanded = &basicSymBanded{} - -func (m *basicSymBanded) At(r, c int) float64 { - return (*SymBandDense)(m).At(r, c) -} - -func (m *basicSymBanded) Dims() (r, c int) { - return (*SymBandDense)(m).Dims() -} - -func (m *basicSymBanded) T() Matrix { - return m -} - -func (m *basicSymBanded) TBand() Banded { - return m -} - -func (m *basicSymBanded) Symmetric() int { - return (*SymBandDense)(m).Symmetric() -} - -func (m *basicSymBanded) SymBand() (n, k int) { - return (*SymBandDense)(m).SymBand() -} - -func (m *basicSymBanded) Bandwidth() (kl, ku int) { - return (*SymBandDense)(m).Bandwidth() -} - -type basicTriBanded TriBandDense - -var _ TriBanded = &basicTriBanded{} - -func (m *basicTriBanded) At(r, c int) float64 { - return (*TriBandDense)(m).At(r, c) -} - -func (m *basicTriBanded) Dims() (r, c int) { - return (*TriBandDense)(m).Dims() -} - -func (m *basicTriBanded) T() Matrix { - return Transpose{m} -} - -func (m *basicTriBanded) TTri() Triangular { - return TransposeTri{m} -} - -func (m *basicTriBanded) TBand() Banded { - return TransposeBand{m} -} - -func (m *basicTriBanded) TTriBand() TriBanded { - return TransposeTriBand{m} -} - -func (m *basicTriBanded) Bandwidth() (kl, ku int) { - return (*TriBandDense)(m).Bandwidth() -} - -func (m *basicTriBanded) Triangle() (int, TriKind) { - return (*TriBandDense)(m).Triangle() -} - -func (m *basicTriBanded) TriBand() (n, k int, kind TriKind) { - return (*TriBandDense)(m).TriBand() -} - -type basicDiagonal DiagDense - -var _ Diagonal = &basicDiagonal{} - -func (m *basicDiagonal) At(r, c int) float64 { - return (*DiagDense)(m).At(r, c) -} - -func (m *basicDiagonal) Dims() (r, c int) { - return (*DiagDense)(m).Dims() -} - -func (m *basicDiagonal) Diag() int { - return (*DiagDense)(m).Diag() -} - -func (m *basicDiagonal) T() Matrix { - return Transpose{m} -} - -func (m *basicDiagonal) TTri() Triangular { - return TransposeTri{m} -} - -func (m *basicDiagonal) TBand() Banded { - return TransposeBand{m} -} - -func (m *basicDiagonal) TTriBand() TriBanded { - return TransposeTriBand{m} -} - -func (m *basicDiagonal) Bandwidth() (kl, ku int) { - return (*DiagDense)(m).Bandwidth() -} - -func (m *basicDiagonal) Symmetric() int { - return (*DiagDense)(m).Symmetric() -} - -func (m *basicDiagonal) SymBand() (n, k int) { - return (*DiagDense)(m).SymBand() -} - -func (m *basicDiagonal) Triangle() (int, TriKind) { - return (*DiagDense)(m).Triangle() -} - -func (m *basicDiagonal) TriBand() (n, k int, kind TriKind) { - return (*DiagDense)(m).TriBand() -} - func denseEqual(a *Dense, acomp matComp) bool { ar2, ac2 := a.Dims() if ar2 != acomp.r {