mirror of
https://github.com/gonum/gonum.git
synced 2025-10-21 06:09:26 +08:00
Working implementation of blocked QR
Improved function documentation Fixed dlarfb and dlarft and added full tests Added dgelq2 Working Dgels Fix many comments and tests Many PR comment responses Responded to more PR comments Many PR comments
This commit is contained in:
50
lapack.go
50
lapack.go
@@ -10,23 +10,51 @@ const None = 'N'
|
|||||||
|
|
||||||
type Job byte
|
type Job byte
|
||||||
|
|
||||||
const (
|
// CompSV determines if the singular values are to be computed in compact form.
|
||||||
All (Job) = 'A'
|
|
||||||
Slim (Job) = 'S'
|
|
||||||
Overwrite (Job) = 'O'
|
|
||||||
)
|
|
||||||
|
|
||||||
type CompSV byte
|
type CompSV byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Compact (CompSV) = 'P'
|
Compact CompSV = 'P'
|
||||||
Explicit (CompSV) = 'I'
|
Explicit CompSV = 'I'
|
||||||
)
|
)
|
||||||
|
|
||||||
// Float64 defines the float64 interface for the Lapack function. This interface
|
// Complex128 defines the public complex128 LAPACK API supported by gonum/lapack.
|
||||||
// contains the functions needed in the gonum suite.
|
type Complex128 interface{}
|
||||||
|
|
||||||
|
// Float64 defines the public float64 LAPACK API supported by gonum/lapack.
|
||||||
type Float64 interface {
|
type Float64 interface {
|
||||||
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Complex128 interface{}
|
// Direct specifies the direction of the multiplication for the Householder matrix.
|
||||||
|
type Direct byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
Forward Direct = 'F' // Reflectors are right-multiplied, H_1 * H_2 * ... * H_k
|
||||||
|
Backward Direct = 'B' // Reflectors are left-multiplied, H_k * ... * H_2 * H_1
|
||||||
|
)
|
||||||
|
|
||||||
|
// StoreV indicates the storage direction of elementary reflectors.
|
||||||
|
type StoreV byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
ColumnWise StoreV = 'C' // Reflector stored in a column of the matrix.
|
||||||
|
RowWise StoreV = 'R' // Reflector stored in a row of the matrix.
|
||||||
|
)
|
||||||
|
|
||||||
|
// MatrixNorm represents the kind of matrix norm to compute.
|
||||||
|
type MatrixNorm byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxAbs MatrixNorm = 'M' // max(abs(A(i,j))) ('M')
|
||||||
|
MaxColumnSum MatrixNorm = 'O' // Maximum column sum (one norm) ('1', 'O')
|
||||||
|
MaxRowSum MatrixNorm = 'I' // Maximum row sum (infinity norm) ('I', 'i')
|
||||||
|
NormFrob MatrixNorm = 'F' // Frobenium norm (sqrt of sum of squares) ('F', 'f', E, 'e')
|
||||||
|
)
|
||||||
|
|
||||||
|
// MatrixType represents the kind of matrix represented in the data.
|
||||||
|
type MatrixType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
General MatrixType = 'G' // A dense matrix (like blas64.General).
|
||||||
|
)
|
||||||
|
44
native/dgelq2.go
Normal file
44
native/dgelq2.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
|
// Dgelq2 computes the LQ factorization of the m×n matrix a.
|
||||||
|
//
|
||||||
|
// During Dgelq2, a is modified to contain the information to construct Q and L.
|
||||||
|
// The lower triangle of a contains the matrix L. The upper triangular elements
|
||||||
|
// (not including the diagonal) contain the elementary reflectors. Tau is modified
|
||||||
|
// to contain the reflector scales. Tau must have length of at least k = min(m,n)
|
||||||
|
// and this function will panic otherwise.
|
||||||
|
//
|
||||||
|
// See Dgeqr2 for a description of the elementary reflectors and orthonormal
|
||||||
|
// matrix Q. Q is constructed as a product of these elementary reflectors,
|
||||||
|
// Q = H_k ... H_2*H_1.
|
||||||
|
//
|
||||||
|
// Work is temporary storage of length at least m and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dgelq2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
k := min(m, n)
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
if len(work) < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
a[i*lda+i], tau[i] = impl.Dlarfg(n-i, a[i*lda+i], a[i*lda+min(i+1, n-1):], 1)
|
||||||
|
if i < m-1 {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(blas.Right, m-i-1, n-i,
|
||||||
|
a[i*lda+i:], 1,
|
||||||
|
tau[i],
|
||||||
|
a[(i+1)*lda+i:], lda,
|
||||||
|
work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
84
native/dgelqf.go
Normal file
84
native/dgelqf.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dgelqf computes the LQ factorization of the m×n matrix a using a blocked
|
||||||
|
// algorithm. Please see the documentation for Dgelq2 for a description of the
|
||||||
|
// parameters at entry and exit.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
|
// At minimum, lwork >= m, and this function will panic otherwise.
|
||||||
|
// Dgelqf is a blocked LQ factorization, but the block size is limited
|
||||||
|
// by the temporary space available. If lwork == -1, instead of performing Dgelqf,
|
||||||
|
// the optimal work length will be stored into work[0].
|
||||||
|
//
|
||||||
|
// tau must have length at least min(m,n), and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||||
|
nb := impl.Ilaenv(1, "DGELQF", " ", m, n, -1, -1)
|
||||||
|
lworkopt := m * max(nb, 1)
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(lworkopt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(work) < lwork {
|
||||||
|
panic(shortWork)
|
||||||
|
}
|
||||||
|
if lwork < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
k := min(m, n)
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
if k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Find the optimal blocking size based on the size of available memory
|
||||||
|
// and optimal machine parameters.
|
||||||
|
nbmin := 2
|
||||||
|
var nx int
|
||||||
|
iws := m
|
||||||
|
ldwork := nb
|
||||||
|
if nb > 1 && k > nb {
|
||||||
|
nx = max(0, impl.Ilaenv(3, "DGELQF", " ", m, n, -1, -1))
|
||||||
|
if nx < k {
|
||||||
|
iws = m * nb
|
||||||
|
if lwork < iws {
|
||||||
|
nb = lwork / m
|
||||||
|
nbmin = max(2, impl.Ilaenv(2, "DGELQF", " ", m, n, -1, -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Computed blocked LQ factorization.
|
||||||
|
var i int
|
||||||
|
if nb >= nbmin && nb < k && nx < k {
|
||||||
|
for i = 0; i < k-nx; i += nb {
|
||||||
|
ib := min(k-i, nb)
|
||||||
|
impl.Dgelq2(ib, n-i, a[i*lda+i:], lda, tau[i:], work)
|
||||||
|
if i+ib < m {
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.RowWise, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
work, ldwork)
|
||||||
|
impl.Dlarfb(blas.Right, blas.NoTrans, lapack.Forward, lapack.RowWise,
|
||||||
|
m-i-ib, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
work, ldwork,
|
||||||
|
a[(i+ib)*lda+i:], lda,
|
||||||
|
work[ib*ldwork:], ldwork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Perform unblocked LQ factorization on the remainder.
|
||||||
|
if i < k {
|
||||||
|
impl.Dgelq2(m-i, n-i, a[i*lda+i:], lda, tau[i:], work)
|
||||||
|
}
|
||||||
|
}
|
200
native/dgels.go
Normal file
200
native/dgels.go
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dgels finds a minimum-norm solution based on the matrices a and b using the
|
||||||
|
// QR or LQ factorization. Dgels returns false if the matrix
|
||||||
|
// A is singular, and true if this solution was successfully found.
|
||||||
|
//
|
||||||
|
// The minimization problem solved depends on the input parameters.
|
||||||
|
//
|
||||||
|
// 1. If m >= n and trans == blas.NoTrans, Dgels finds X such that || A*X - B||_2
|
||||||
|
// is minimized.
|
||||||
|
// 2. If m < n and trans == blas.NoTrans, Dgels finds the minimum norm solution of
|
||||||
|
// A * X = B.
|
||||||
|
// 3. If m >= n and trans == blas.Trans, Dgels finds the minimum norm solution of
|
||||||
|
// A^T * X = B.
|
||||||
|
// 4. If m < n and trans == blas.Trans, Dgels finds X such that || A*X - B||_2
|
||||||
|
// is minimized.
|
||||||
|
// Note that the least-squares solutions (cases 1 and 3) perform the minimization
|
||||||
|
// per column of B. This is not the same as finding the minimum-norm matrix.
|
||||||
|
//
|
||||||
|
// The matrix a is a general matrix of size m×n and is modified during this call.
|
||||||
|
// The input matrix b is of size max(m,n)×nrhs, and serves two purposes. On entry,
|
||||||
|
// the elements of b specify the input matrix B. B has size m×nrhs if
|
||||||
|
// trans == blas.NoTrans, and n×nrhs if trans == blas.Trans. On exit, the
|
||||||
|
// leading submatrix of b contains the solution vectors X. If trans == blas.NoTrans,
|
||||||
|
// this submatrix is of size n×nrhs, and of size m×nrhs otherwise.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
|
// At minimum, lwork >= max(m,n) + max(m,n,nrhs), and this function will panic
|
||||||
|
// otherwise. A longer work will enable blocked algorithms to be called.
|
||||||
|
// In the special case that lwork == -1, work[0] will be set to the optimal working
|
||||||
|
// length.
|
||||||
|
func (impl Implementation) Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool {
|
||||||
|
notran := trans == blas.NoTrans
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
mn := min(m, n)
|
||||||
|
checkMatrix(mn, nrhs, b, ldb)
|
||||||
|
|
||||||
|
// Find optimal block size.
|
||||||
|
tpsd := true
|
||||||
|
if notran {
|
||||||
|
tpsd = false
|
||||||
|
}
|
||||||
|
var nb int
|
||||||
|
if m >= n {
|
||||||
|
nb = impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
|
||||||
|
if tpsd {
|
||||||
|
nb = max(nb, impl.Ilaenv(1, "DORMQR", "LN", m, nrhs, n, -1))
|
||||||
|
} else {
|
||||||
|
nb = max(nb, impl.Ilaenv(1, "DORMQR", "LT", m, nrhs, n, -1))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nb = impl.Ilaenv(1, "DGELQF", " ", m, n, -1, -1)
|
||||||
|
if tpsd {
|
||||||
|
nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LT", n, nrhs, m, -1))
|
||||||
|
} else {
|
||||||
|
nb = max(nb, impl.Ilaenv(1, "DORMLQ", "LN", n, nrhs, m, -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(max(1, mn+max(mn, nrhs)*nb))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(work) < lwork {
|
||||||
|
panic(shortWork)
|
||||||
|
}
|
||||||
|
if lwork < mn+max(mn, nrhs) {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
if m == 0 || n == 0 || nrhs == 0 {
|
||||||
|
impl.Dlaset(blas.All, max(m, n), nrhs, 0, 0, b, ldb)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scale the input matrices if they contain extreme values.
|
||||||
|
smlnum := dlamchS / dlamchP
|
||||||
|
bignum := 1 / smlnum
|
||||||
|
anrm := impl.Dlange(lapack.MaxAbs, m, n, a, lda, nil)
|
||||||
|
var iascl int
|
||||||
|
if anrm > 0 && anrm < smlnum {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, anrm, smlnum, m, n, a, lda)
|
||||||
|
iascl = 1
|
||||||
|
} else if anrm > bignum {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, anrm, bignum, m, n, a, lda)
|
||||||
|
} else if anrm == 0 {
|
||||||
|
// Matrix all zeros
|
||||||
|
impl.Dlaset(blas.All, max(m, n), nrhs, 0, 0, b, ldb)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
brow := m
|
||||||
|
if tpsd {
|
||||||
|
brow = n
|
||||||
|
}
|
||||||
|
bnrm := impl.Dlange(lapack.MaxAbs, brow, nrhs, b, ldb, nil)
|
||||||
|
ibscl := 0
|
||||||
|
if bnrm > 0 && bnrm < smlnum {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, bnrm, smlnum, brow, nrhs, b, ldb)
|
||||||
|
ibscl = 1
|
||||||
|
} else if bnrm > bignum {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, bnrm, bignum, brow, nrhs, b, ldb)
|
||||||
|
ibscl = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve the minimization problem using a QR or an LQ decomposition.
|
||||||
|
var scllen int
|
||||||
|
if m >= n {
|
||||||
|
impl.Dgeqrf(m, n, a, lda, work, work[mn:], lwork-mn)
|
||||||
|
if !tpsd {
|
||||||
|
impl.Dormqr(blas.Left, blas.Trans, m, nrhs, n,
|
||||||
|
a, lda,
|
||||||
|
work,
|
||||||
|
b, ldb,
|
||||||
|
work[mn:], lwork-mn)
|
||||||
|
ok := impl.Dtrtrs(blas.Upper, blas.NoTrans, blas.NonUnit, n, nrhs,
|
||||||
|
a, lda,
|
||||||
|
b, ldb)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
scllen = n
|
||||||
|
} else {
|
||||||
|
ok := impl.Dtrtrs(blas.Upper, blas.Trans, blas.NonUnit, n, nrhs,
|
||||||
|
a, lda,
|
||||||
|
b, ldb)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := n; i < m; i++ {
|
||||||
|
for j := 0; j < nrhs; j++ {
|
||||||
|
b[i*ldb+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl.Dormqr(blas.Left, blas.NoTrans, m, nrhs, n,
|
||||||
|
a, lda,
|
||||||
|
work,
|
||||||
|
b, ldb,
|
||||||
|
work[mn:], lwork-mn)
|
||||||
|
scllen = m
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
impl.Dgelqf(m, n, a, lda, work, work[mn:], lwork-mn)
|
||||||
|
if !tpsd {
|
||||||
|
ok := impl.Dtrtrs(blas.Lower, blas.NoTrans, blas.NonUnit,
|
||||||
|
m, nrhs,
|
||||||
|
a, lda,
|
||||||
|
b, ldb)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := m; i < n; i++ {
|
||||||
|
for j := 0; j < nrhs; j++ {
|
||||||
|
b[i*ldb+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl.Dormlq(blas.Left, blas.Trans, n, nrhs, m,
|
||||||
|
a, lda,
|
||||||
|
work,
|
||||||
|
b, ldb,
|
||||||
|
work[mn:], lwork-mn)
|
||||||
|
scllen = n
|
||||||
|
} else {
|
||||||
|
impl.Dormlq(blas.Left, blas.NoTrans, n, nrhs, m,
|
||||||
|
a, lda,
|
||||||
|
work,
|
||||||
|
b, ldb,
|
||||||
|
work[mn:], lwork-mn)
|
||||||
|
ok := impl.Dtrtrs(blas.Lower, blas.Trans, blas.NonUnit,
|
||||||
|
m, nrhs,
|
||||||
|
a, lda,
|
||||||
|
b, ldb)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust answer vector based on scaling.
|
||||||
|
if iascl == 1 {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, anrm, smlnum, scllen, nrhs, b, ldb)
|
||||||
|
}
|
||||||
|
if iascl == 2 {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, anrm, bignum, scllen, nrhs, b, ldb)
|
||||||
|
}
|
||||||
|
if ibscl == 1 {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, smlnum, bnrm, scllen, nrhs, b, ldb)
|
||||||
|
}
|
||||||
|
if ibscl == 2 {
|
||||||
|
impl.Dlascl(lapack.General, 0, 0, bignum, bnrm, scllen, nrhs, b, ldb)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
57
native/dgeqr2.go
Normal file
57
native/dgeqr2.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
|
// Dgeqr2 computes a QR factorization of the m×n matrix a.
|
||||||
|
//
|
||||||
|
// In a QR factorization, Q is an m×m orthonormal matrix, and R is an
|
||||||
|
// upper triangular m×n matrix.
|
||||||
|
//
|
||||||
|
// During Dgeqr2, a is modified to contain the information to construct Q and R.
|
||||||
|
// The upper triangle of a contains the matrix R. The lower triangular elements
|
||||||
|
// (not including the diagonal) contain the elementary reflectors. Tau is modified
|
||||||
|
// to contain the reflector scales. Tau must have length at least k = min(m,n), and
|
||||||
|
// this function will panic otherwise.
|
||||||
|
//
|
||||||
|
// The ith elementary reflector can be explicitly constructed by first extracting
|
||||||
|
// the
|
||||||
|
// v[j] = 0 j < i
|
||||||
|
// v[j] = i j == i
|
||||||
|
// v[j] = a[i*lda+j] j > i
|
||||||
|
// and computing h_i = I - tau[i] * v * v^T.
|
||||||
|
//
|
||||||
|
// The orthonormal matrix Q can be constucted from a product of these elementary
|
||||||
|
// reflectors, Q = H_1*H_2 ... H_k, where k = min(m,n).
|
||||||
|
//
|
||||||
|
// Work is temporary storage of length at least n and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dgeqr2(m, n int, a []float64, lda int, tau, work []float64) {
|
||||||
|
// TODO(btracey): This is oriented such that columns of a are eliminated.
|
||||||
|
// This likely could be re-arranged to take better advantage of row-major
|
||||||
|
// storage.
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
k := min(m, n)
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
// Generate elementary reflector H(i).
|
||||||
|
a[i*lda+i], tau[i] = impl.Dlarfg(m-i, a[i*lda+i], a[min((i+1), m-1)*lda+i:], lda)
|
||||||
|
if i < n-1 {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(blas.Left, m-i, n-i-1,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i],
|
||||||
|
a[i*lda+i+1:], lda,
|
||||||
|
work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
98
native/dgeqrf.go
Normal file
98
native/dgeqrf.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dgeqrf computes the QR factorization of the m×n matrix a using a blocked
|
||||||
|
// algorithm. Please see the documentation for Dgeqr2 for a description of the
|
||||||
|
// parameters at entry and exit.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
|
// At minimum, lwork >= m and this function will panic otherwise.
|
||||||
|
// Dgeqrf is a blocked LQ factorization, but the block size is limited
|
||||||
|
// by the temporary space available. If lwork == -1, instead of performing Dgelqf,
|
||||||
|
// the optimal work length will be stored into work[0].
|
||||||
|
//
|
||||||
|
// tau must be at least len min(m,n), and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int) {
|
||||||
|
// TODO(btracey): This algorithm is oriented for column-major storage.
|
||||||
|
// Consider modifying the algorithm to better suit row-major storage.
|
||||||
|
|
||||||
|
// nb is the optimal blocksize, i.e. the number of columns transformed at a time.
|
||||||
|
nb := impl.Ilaenv(1, "DGEQRF", " ", m, n, -1, -1)
|
||||||
|
lworkopt := n * max(nb, 1)
|
||||||
|
lworkopt = max(n, lworkopt)
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(lworkopt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if len(work) < lwork {
|
||||||
|
panic(shortWork)
|
||||||
|
}
|
||||||
|
if lwork < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
k := min(m, n)
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
if k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nbmin := 2 // Minimal number of blocks
|
||||||
|
var nx int // Use unblocked (unless changed in the next for loop)
|
||||||
|
iws := n
|
||||||
|
ldwork := nb
|
||||||
|
// Only consider blocked if the suggested number of blocks is > 1 and the
|
||||||
|
// number of columns is sufficiently large.
|
||||||
|
if nb > 1 && k > nb {
|
||||||
|
// nx is the crossover point. Above this value the blocked routine should be used.
|
||||||
|
nx = max(0, impl.Ilaenv(3, "DGEQRF", " ", m, n, -1, -1))
|
||||||
|
if k > nx {
|
||||||
|
iws = ldwork * n
|
||||||
|
if lwork < iws {
|
||||||
|
// Not enough workspace to use the optimal number of blocks. Instead,
|
||||||
|
// get the maximum allowable number of blocks.
|
||||||
|
nb = lwork / n
|
||||||
|
nbmin = max(2, impl.Ilaenv(2, "DGEQRF", " ", m, n, -1, -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := range work {
|
||||||
|
work[i] = 0
|
||||||
|
}
|
||||||
|
// Compute QR using a blocked algorithm.
|
||||||
|
var i int
|
||||||
|
if nb >= nbmin && nb < k && nx < k {
|
||||||
|
for i = 0; i < k-nx; i += nb {
|
||||||
|
ib := min(k-i, nb)
|
||||||
|
// Compute the QR factorization of the current block.
|
||||||
|
impl.Dgeqr2(m-i, ib, a[i*lda+i:], lda, tau[i:], work)
|
||||||
|
if i+ib < n {
|
||||||
|
// Form the triangular factor of the block reflector and apply H^T
|
||||||
|
// In Dlarft, work becomes the T matrix.
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
work, ldwork)
|
||||||
|
impl.Dlarfb(blas.Left, blas.Trans, lapack.Forward, lapack.ColumnWise,
|
||||||
|
m-i, n-i-ib, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
work, ldwork,
|
||||||
|
a[i*lda+i+ib:], lda,
|
||||||
|
work[ib*ldwork:], ldwork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Call unblocked code on the remaining columns.
|
||||||
|
if i < k {
|
||||||
|
impl.Dgeqr2(m-i, n-i, a[i*lda+i:], lda, tau[i:], work)
|
||||||
|
}
|
||||||
|
}
|
76
native/dlange.go
Normal file
76
native/dlange.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dlange computes the matrix norm of the general m×n matrix a. The input norm
|
||||||
|
// specifies the norm computed.
|
||||||
|
// lapack.MaxAbs: the maximum absolute value of an element.
|
||||||
|
// lapack.MaxColumnSum: the maximum column sum of the absolute values of the entries.
|
||||||
|
// lapack.MaxRowSum: the maximum row sum of the absolute values of the entries.
|
||||||
|
// lapack.Frobenius: the square root of the sum of the squares of the entries.
|
||||||
|
// If norm == lapack.MaxColumnSum, work must be of length n, and this function will panic otherwise.
|
||||||
|
// There are no restrictions on work for the other matrix norms.
|
||||||
|
func (impl Implementation) Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64 {
|
||||||
|
// TODO(btracey): These should probably be refactored to use BLAS calls.
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if m == 0 && n == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if norm == lapack.MaxAbs {
|
||||||
|
var value float64
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
value = math.Max(value, math.Abs(a[i*lda+j]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if norm == lapack.MaxColumnSum {
|
||||||
|
if len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
work[i] = 0
|
||||||
|
}
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
work[j] += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var value float64
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
value = math.Max(value, work[i])
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if norm == lapack.MaxRowSum {
|
||||||
|
var value float64
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
var sum float64
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
sum += math.Abs(a[i*lda+j])
|
||||||
|
}
|
||||||
|
value = math.Max(value, sum)
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if norm == lapack.NormFrob {
|
||||||
|
var value float64
|
||||||
|
scale := 0.0
|
||||||
|
sum := 1.0
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
scale, sum = impl.Dlassq(n, a[i*lda:], 1, scale, sum)
|
||||||
|
}
|
||||||
|
value = scale * math.Sqrt(sum)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
panic("lapack: bad matrix norm")
|
||||||
|
}
|
12
native/dlapy2.go
Normal file
12
native/dlapy2.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "math"
|
||||||
|
|
||||||
|
// Dlapy2 is the LAPACK version of math.Hypot.
|
||||||
|
func (Implementation) Dlapy2(x, y float64) float64 {
|
||||||
|
return math.Hypot(x, y)
|
||||||
|
}
|
@@ -17,25 +17,37 @@ import (
|
|||||||
// h = 1 - tau * v * v^T
|
// h = 1 - tau * v * v^T
|
||||||
// and c is an m * n matrix.
|
// and c is an m * n matrix.
|
||||||
//
|
//
|
||||||
// Work is pre-allocated memory of size at least m if side == Left and at least
|
|
||||||
// n if side == Right. Dlarf will panic if this length requirement is not met.
|
// Work is temporary storage of length at least m if side == Left and at least
|
||||||
|
// n if side == Right. This function will panic if this length requirement is not met.
|
||||||
func (impl Implementation) Dlarf(side blas.Side, m, n int, v []float64, incv int, tau float64, c []float64, ldc int, work []float64) {
|
func (impl Implementation) Dlarf(side blas.Side, m, n int, v []float64, incv int, tau float64, c []float64, ldc int, work []float64) {
|
||||||
applyleft := side == blas.Left
|
applyleft := side == blas.Left
|
||||||
if (applyleft && len(work) < n) || (!applyleft && len(work) < m) {
|
if (applyleft && len(work) < n) || (!applyleft && len(work) < m) {
|
||||||
panic("dlarf: insufficient work length")
|
panic(badWork)
|
||||||
}
|
}
|
||||||
|
checkMatrix(m, n, c, ldc)
|
||||||
|
|
||||||
|
// v has length m if applyleft and n otherwise.
|
||||||
|
lenV := n
|
||||||
|
if applyleft {
|
||||||
|
lenV = m
|
||||||
|
}
|
||||||
|
|
||||||
|
checkVector(lenV, v, incv)
|
||||||
|
|
||||||
lastv := 0 // last non-zero element of v
|
lastv := 0 // last non-zero element of v
|
||||||
lastc := 0 // last non-zero row/column of c
|
lastc := 0 // last non-zero row/column of c
|
||||||
if tau != 0 {
|
if tau != 0 {
|
||||||
var i int
|
var i int
|
||||||
if applyleft {
|
if applyleft {
|
||||||
lastv = m
|
lastv = m - 1
|
||||||
} else {
|
} else {
|
||||||
lastv = n
|
lastv = n - 1
|
||||||
}
|
}
|
||||||
if incv > 0 {
|
if incv > 0 {
|
||||||
i = (lastv - 1) * incv
|
i = lastv * incv
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look for the last non-zero row in v.
|
// Look for the last non-zero row in v.
|
||||||
for lastv >= 0 && v[i] == 0 {
|
for lastv >= 0 && v[i] == 0 {
|
||||||
lastv--
|
lastv--
|
||||||
@@ -43,10 +55,10 @@ func (impl Implementation) Dlarf(side blas.Side, m, n int, v []float64, incv int
|
|||||||
}
|
}
|
||||||
if applyleft {
|
if applyleft {
|
||||||
// Scan for the last non-zero column in C[0:lastv, :]
|
// Scan for the last non-zero column in C[0:lastv, :]
|
||||||
lastc = impl.Iladlc(lastv, n, c, ldc)
|
lastc = impl.Iladlc(lastv+1, n, c, ldc)
|
||||||
} else {
|
} else {
|
||||||
// Scan for the last non-zero row in C[:, 0:lastv]
|
// Scan for the last non-zero row in C[:, 0:lastv]
|
||||||
lastc = impl.Iladlr(m, lastv, c, ldc)
|
lastc = impl.Iladlr(m, lastv+1, c, ldc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if lastv == -1 || lastc == -1 {
|
if lastv == -1 || lastc == -1 {
|
||||||
|
424
native/dlarfb.go
Normal file
424
native/dlarfb.go
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dlarfb applies a block reflector to a matrix.
|
||||||
|
//
|
||||||
|
// In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
|
||||||
|
// c = h * c if side == Left and trans == NoTrans
|
||||||
|
// c = c * h if side == Right and trans == NoTrans
|
||||||
|
// c = h^T * c if side == Left and trans == Trans
|
||||||
|
// c = c * h^t if side == Right and trans == Trans
|
||||||
|
// h is a product of elementary reflectors. direct sets the direction of multiplication
|
||||||
|
// h = h_1 * h_2 * ... * h_k if direct == Forward
|
||||||
|
// h = h_k * h_k-1 * ... * h_1 if direct == Backward
|
||||||
|
// The combination of direct and store defines the orientation of the elementary
|
||||||
|
// reflectors. In all cases the ones on the diagonal are implicitly represented.
|
||||||
|
//
|
||||||
|
// If direct == lapack.Forward and store == lapack.ColumnWise
|
||||||
|
// V = ( 1 )
|
||||||
|
// ( v1 1 )
|
||||||
|
// ( v1 v2 1 )
|
||||||
|
// ( v1 v2 v3 )
|
||||||
|
// ( v1 v2 v3 )
|
||||||
|
// If direct == lapack.Forward and store == lapack.RowWise
|
||||||
|
// V = ( 1 v1 v1 v1 v1 )
|
||||||
|
// ( 1 v2 v2 v2 )
|
||||||
|
// ( 1 v3 v3 )
|
||||||
|
// If direct == lapack.Backward and store == lapack.ColumnWise
|
||||||
|
// V = ( v1 v2 v3 )
|
||||||
|
// ( v1 v2 v3 )
|
||||||
|
// ( 1 v2 v3 )
|
||||||
|
// ( 1 v3 )
|
||||||
|
// ( 1 )
|
||||||
|
// If direct == lapack.Backward and store == lapack.RowWise
|
||||||
|
// V = ( v1 v1 1 )
|
||||||
|
// ( v2 v2 v2 1 )
|
||||||
|
// ( v3 v3 v3 v3 1 )
|
||||||
|
// An elementary reflector can be explicitly constructed by extracting the
|
||||||
|
// corresponding elements of v, placing a 1 where the diagonal would be, and
|
||||||
|
// placing zeros in the remaining elements.
|
||||||
|
//
|
||||||
|
// t is a k×k matrix containing the block reflector, and this function will panic
|
||||||
|
// if t is not of sufficient size. See Dlarft for more information.
|
||||||
|
//
|
||||||
|
// Work is a temporary storage matrix with stride ldwork.
|
||||||
|
// Work must be of size at least n×k side == Left and m×k if side == Right, and
|
||||||
|
// this function will panic if this size is not met.
|
||||||
|
func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
|
||||||
|
store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
|
||||||
|
c []float64, ldc int, work []float64, ldwork int) {
|
||||||
|
|
||||||
|
checkMatrix(m, n, c, ldc)
|
||||||
|
if m == 0 || n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if k < 0 {
|
||||||
|
panic("lapack: negative number of transforms")
|
||||||
|
}
|
||||||
|
if side != blas.Left && side != blas.Right {
|
||||||
|
panic(badSide)
|
||||||
|
}
|
||||||
|
if trans != blas.Trans && trans != blas.NoTrans {
|
||||||
|
panic(badTrans)
|
||||||
|
}
|
||||||
|
if direct != lapack.Forward && direct != lapack.Backward {
|
||||||
|
panic(badDirect)
|
||||||
|
}
|
||||||
|
if store != lapack.ColumnWise && store != lapack.RowWise {
|
||||||
|
panic(badStore)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsWork := n
|
||||||
|
if side == blas.Right {
|
||||||
|
rowsWork = m
|
||||||
|
}
|
||||||
|
checkMatrix(rowsWork, k, work, ldwork)
|
||||||
|
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
|
||||||
|
transt := blas.Trans
|
||||||
|
if trans == blas.Trans {
|
||||||
|
transt = blas.NoTrans
|
||||||
|
}
|
||||||
|
// TODO(btracey): This follows the original Lapack code where the
|
||||||
|
// elements are copied into the columns of the working array. The
|
||||||
|
// loops should go in the other direction so the data is written
|
||||||
|
// into the rows of work so the copy is not strided. A bigger change
|
||||||
|
// would be to replace work with work^T, but benchmarks would be
|
||||||
|
// needed to see if the change is merited.
|
||||||
|
if store == lapack.ColumnWise {
|
||||||
|
if direct == lapack.Forward {
|
||||||
|
// V1 is the first k rows of C. V2 is the remaining rows.
|
||||||
|
if side == blas.Left {
|
||||||
|
// W = C^T V = C1^T V1 + C2^T V2 (stored in work).
|
||||||
|
|
||||||
|
// W = C1.
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
|
||||||
|
}
|
||||||
|
// W = W * V1.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
|
||||||
|
n, k, 1,
|
||||||
|
v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if m > k {
|
||||||
|
// W = W + C2^T V2.
|
||||||
|
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
|
||||||
|
1, c[k*ldc:], ldc, v[k*ldv:], ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W = W * T^T or W * T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= V * W^T.
|
||||||
|
if m > k {
|
||||||
|
// C2 -= V2 * W^T.
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
|
||||||
|
-1, v[k*ldv:], ldv, work, ldwork,
|
||||||
|
1, c[k*ldc:], ldc)
|
||||||
|
}
|
||||||
|
// W *= V1^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
|
||||||
|
1, v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C1 -= W^T.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[j*ldc+i] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Form C = C * H or C * H^T, where C = (C1 C2).
|
||||||
|
|
||||||
|
// W = C1.
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
|
||||||
|
}
|
||||||
|
// W *= V1.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
|
||||||
|
1, v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if n > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
|
||||||
|
1, c[k:], ldc, v[k*ldv:], ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W *= T or T^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
if n > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
|
||||||
|
-1, work, ldwork, v[k*ldv:], ldv,
|
||||||
|
1, c[k:], ldc)
|
||||||
|
}
|
||||||
|
// C -= W * V^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
|
||||||
|
1, v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= W.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[i*ldc+j] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// V = (V1)
|
||||||
|
// = (V2) (last k rows)
|
||||||
|
// Where V2 is unit upper triangular.
|
||||||
|
if side == blas.Left {
|
||||||
|
// Form H * C or
|
||||||
|
// W = C^T V.
|
||||||
|
|
||||||
|
// W = C2^T.
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
|
||||||
|
}
|
||||||
|
// W *= V2.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
|
||||||
|
1, v[(m-k)*ldv:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if m > k {
|
||||||
|
// W += C1^T * V1.
|
||||||
|
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
|
||||||
|
1, c, ldc, v, ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W *= T or T^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= V * W^T.
|
||||||
|
if m > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
|
||||||
|
-1, v, ldv, work, ldwork,
|
||||||
|
1, c, ldc)
|
||||||
|
}
|
||||||
|
// W *= V2^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
|
||||||
|
1, v[(m-k)*ldv:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C2 -= W^T.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Form C * H or C * H^T where C = (C1 C2).
|
||||||
|
// W = C * V.
|
||||||
|
|
||||||
|
// W = C2.
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
|
||||||
|
}
|
||||||
|
|
||||||
|
// W = W * V2.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
|
||||||
|
1, v[(n-k)*ldv:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if n > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
|
||||||
|
1, c, ldc, v, ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W *= T or T^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= W * V^T.
|
||||||
|
if n > k {
|
||||||
|
// C1 -= W * V1^T.
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
|
||||||
|
-1, work, ldwork, v, ldv,
|
||||||
|
1, c, ldc)
|
||||||
|
}
|
||||||
|
// W *= V2^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
|
||||||
|
1, v[(n-k)*ldv:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C2 -= W.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[i*ldc+n-k+j] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Store = Rowwise.
|
||||||
|
if direct == lapack.Forward {
|
||||||
|
// V = (V1 V2) where v1 is unit upper triangular.
|
||||||
|
if side == blas.Left {
|
||||||
|
// Form H * C or H^T * C where C = (C1; C2).
|
||||||
|
// W = C^T * V^T.
|
||||||
|
|
||||||
|
// W = C1^T.
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
|
||||||
|
}
|
||||||
|
// W *= V1^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
|
||||||
|
1, v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if m > k {
|
||||||
|
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
|
||||||
|
1, c[k*ldc:], ldc, v[k:], ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W *= T or T^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= V^T * W^T.
|
||||||
|
if m > k {
|
||||||
|
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
|
||||||
|
-1, v[k:], ldv, work, ldwork,
|
||||||
|
1, c[k*ldc:], ldc)
|
||||||
|
}
|
||||||
|
// W *= V1.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
|
||||||
|
1, v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C1 -= W^T.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[j*ldc+i] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Form C * H or C * H^T where C = (C1 C2).
|
||||||
|
// W = C * V^T.
|
||||||
|
|
||||||
|
// W = C1.
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
|
||||||
|
}
|
||||||
|
// W *= V1^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
|
||||||
|
1, v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if n > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
|
||||||
|
1, c[k:], ldc, v[k:], ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W *= T or T^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= W * V.
|
||||||
|
if n > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
|
||||||
|
-1, work, ldwork, v[k:], ldv,
|
||||||
|
1, c[k:], ldc)
|
||||||
|
}
|
||||||
|
// W *= V1.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
|
||||||
|
1, v, ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C1 -= W.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[i*ldc+j] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
|
||||||
|
if side == blas.Left {
|
||||||
|
// Form H * C or H^T C where C = (C1 ; C2).
|
||||||
|
// W = C^T * V^T.
|
||||||
|
|
||||||
|
// W = C2^T.
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
|
||||||
|
}
|
||||||
|
// W *= V2^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
|
||||||
|
1, v[m-k:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if m > k {
|
||||||
|
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
|
||||||
|
1, c, ldc, v, ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W *= T or T^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= V^T * W^T.
|
||||||
|
if m > k {
|
||||||
|
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
|
||||||
|
-1, v, ldv, work, ldwork,
|
||||||
|
1, c, ldc)
|
||||||
|
}
|
||||||
|
// W *= V2.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
|
||||||
|
1, v[m-k:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C2 -= W^T.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Form C * H or C * H^T where C = (C1 C2).
|
||||||
|
// W = C * V^T.
|
||||||
|
// W = C2.
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
|
||||||
|
}
|
||||||
|
// W *= V2^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
|
||||||
|
1, v[n-k:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
if n > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
|
||||||
|
1, c, ldc, v, ldv,
|
||||||
|
1, work, ldwork)
|
||||||
|
}
|
||||||
|
// W *= T or T^T.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
|
||||||
|
1, t, ldt,
|
||||||
|
work, ldwork)
|
||||||
|
// C -= W * V.
|
||||||
|
if n > k {
|
||||||
|
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
|
||||||
|
-1, work, ldwork, v, ldv,
|
||||||
|
1, c, ldc)
|
||||||
|
}
|
||||||
|
// W *= V2.
|
||||||
|
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
|
||||||
|
1, v[n-k:], ldv,
|
||||||
|
work, ldwork)
|
||||||
|
// C1 -= W.
|
||||||
|
// TODO(btracey): This should use blas.Axpy.
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
c[i*ldc+n-k+j] -= work[i*ldwork+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
60
native/dlarfg.go
Normal file
60
native/dlarfg.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dlarfg generates an elementary reflector for a Householder matrix. It creates
|
||||||
|
// a real elementary reflector of order n such that
|
||||||
|
// H * (alpha) = (beta)
|
||||||
|
// ( x) ( 0)
|
||||||
|
// H^T * H = I
|
||||||
|
// H is represented in the form
|
||||||
|
// H = 1 - tau * (1; v) * (1 v^T)
|
||||||
|
// where tau is a real scalar.
|
||||||
|
//
|
||||||
|
// On entry, x contains the vector x, on exit it contains v.
|
||||||
|
func (impl Implementation) Dlarfg(n int, alpha float64, x []float64, incX int) (beta, tau float64) {
|
||||||
|
if n < 0 {
|
||||||
|
panic(nLT0)
|
||||||
|
}
|
||||||
|
if n <= 1 {
|
||||||
|
return alpha, 0
|
||||||
|
}
|
||||||
|
checkVector(n-1, x, incX)
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
xnorm := bi.Dnrm2(n-1, x, incX)
|
||||||
|
if xnorm == 0 {
|
||||||
|
return alpha, 0
|
||||||
|
}
|
||||||
|
beta = -math.Copysign(impl.Dlapy2(alpha, xnorm), alpha)
|
||||||
|
safmin := dlamchS / dlamchE
|
||||||
|
knt := 0
|
||||||
|
if math.Abs(beta) < safmin {
|
||||||
|
// xnorm and beta may be innacurate, scale x and recompute.
|
||||||
|
rsafmn := 1 / safmin
|
||||||
|
for {
|
||||||
|
knt++
|
||||||
|
bi.Dscal(n-1, rsafmn, x, incX)
|
||||||
|
beta *= rsafmn
|
||||||
|
alpha *= rsafmn
|
||||||
|
if math.Abs(beta) >= safmin {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xnorm = bi.Dnrm2(n-1, x, incX)
|
||||||
|
beta = -math.Copysign(impl.Dlapy2(alpha, xnorm), alpha)
|
||||||
|
}
|
||||||
|
tau = (beta - alpha) / beta
|
||||||
|
bi.Dscal(n-1, 1/(alpha-beta), x, incX)
|
||||||
|
for j := 0; j < knt; j++ {
|
||||||
|
beta *= safmin
|
||||||
|
}
|
||||||
|
return beta, tau
|
||||||
|
}
|
148
native/dlarft.go
Normal file
148
native/dlarft.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dlarft forms the triangular factor t of a block reflector, storing the answer
|
||||||
|
// in t.
|
||||||
|
// H = 1 - V * T * V^T if store == lapack.ColumnWise
|
||||||
|
// H = 1 - V^T * T * V if store == lapack.RowWise
|
||||||
|
// H is defined by a product of the elementary reflectors where
|
||||||
|
// H = H_1 * H_2 * ... * H_k if direct == lapack.Forward
|
||||||
|
// H = H_k * H_k-1 * ... * H_1 if direct == lapack.Backward
|
||||||
|
//
|
||||||
|
// t is a k×k triangular matrix. t is upper triangular if direct = lapack.Forward
|
||||||
|
// and lower triangular otherwise. This function will panic if t is not of
|
||||||
|
// sufficient size.
|
||||||
|
//
|
||||||
|
// store describes the storage of the elementary reflectors in v. Please see
|
||||||
|
// Dlarfb for a description of layout.
|
||||||
|
//
|
||||||
|
// tau contains the scalar factor of the elementary reflectors h.
|
||||||
|
func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int,
|
||||||
|
v []float64, ldv int, tau []float64, t []float64, ldt int) {
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n < 0 || k < 0 {
|
||||||
|
panic(negDimension)
|
||||||
|
}
|
||||||
|
if direct != lapack.Forward && direct != lapack.Backward {
|
||||||
|
panic(badDirect)
|
||||||
|
}
|
||||||
|
if store != lapack.RowWise && store != lapack.ColumnWise {
|
||||||
|
panic(badStore)
|
||||||
|
}
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
checkMatrix(k, k, t, ldt)
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
// TODO(btracey): There are a number of minor obvious loop optimizations here.
|
||||||
|
// TODO(btracey): It may be possible to rearrange some of the code so that
|
||||||
|
// index of 1 is more common in the Dgemv.
|
||||||
|
if direct == lapack.Forward {
|
||||||
|
prevlastv := n - 1
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
prevlastv = max(i, prevlastv)
|
||||||
|
if tau[i] == 0 {
|
||||||
|
for j := 0; j <= i; j++ {
|
||||||
|
t[j*ldt+i] = 0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var lastv int
|
||||||
|
if store == lapack.ColumnWise {
|
||||||
|
// skip trailing zeros
|
||||||
|
for lastv = n - 1; lastv >= i+1; lastv-- {
|
||||||
|
if v[lastv*ldv+i] != 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
t[j*ldt+i] = -tau[i] * v[i*ldv+j]
|
||||||
|
}
|
||||||
|
j := min(lastv, prevlastv)
|
||||||
|
bi.Dgemv(blas.Trans, j-i, i,
|
||||||
|
-tau[i], v[(i+1)*ldv:], ldv, v[(i+1)*ldv+i:], ldv,
|
||||||
|
1, t[i:], ldt)
|
||||||
|
} else {
|
||||||
|
for lastv = n - 1; lastv >= i+1; lastv-- {
|
||||||
|
if v[i*ldv+lastv] != 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
t[j*ldt+i] = -tau[i] * v[j*ldv+i]
|
||||||
|
}
|
||||||
|
j := min(lastv, prevlastv)
|
||||||
|
bi.Dgemv(blas.NoTrans, i, j-i,
|
||||||
|
-tau[i], v[i+1:], ldv, v[i*ldv+i+1:], 1,
|
||||||
|
1, t[i:], ldt)
|
||||||
|
}
|
||||||
|
bi.Dtrmv(blas.Upper, blas.NoTrans, blas.NonUnit, i, t, ldt, t[i:], ldt)
|
||||||
|
t[i*ldt+i] = tau[i]
|
||||||
|
if i > 1 {
|
||||||
|
prevlastv = max(prevlastv, lastv)
|
||||||
|
} else {
|
||||||
|
prevlastv = lastv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
prevlastv := 0
|
||||||
|
for i := k - 1; i >= 0; i-- {
|
||||||
|
if tau[i] == 0 {
|
||||||
|
for j := i; j < k; j++ {
|
||||||
|
t[j*ldt+i] = 0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var lastv int
|
||||||
|
if i < k-1 {
|
||||||
|
if store == lapack.ColumnWise {
|
||||||
|
for lastv = 0; lastv < i; lastv++ {
|
||||||
|
if v[lastv*ldv+i] != 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j := i + 1; j < k; j++ {
|
||||||
|
t[j*ldt+i] = -tau[i] * v[(n-k+i)*ldv+j]
|
||||||
|
}
|
||||||
|
j := max(lastv, prevlastv)
|
||||||
|
bi.Dgemv(blas.Trans, n-k+i-j, k-i-1,
|
||||||
|
-tau[i], v[j*ldv+i+1:], ldv, v[j*ldv+i:], ldv,
|
||||||
|
1, t[(i+1)*ldt+i:], ldt)
|
||||||
|
} else {
|
||||||
|
for lastv := 0; lastv < i; lastv++ {
|
||||||
|
if v[i*ldv+lastv] != 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j := i + 1; j < k; j++ {
|
||||||
|
t[j*ldt+i] = -tau[i] * v[j*ldv+n-k+i]
|
||||||
|
}
|
||||||
|
j := max(lastv, prevlastv)
|
||||||
|
bi.Dgemv(blas.NoTrans, k-i-1, n-k+i-j,
|
||||||
|
-tau[i], v[(i+1)*ldv+j:], ldv, v[i*ldv+j:], 1,
|
||||||
|
1, t[(i+1)*ldt+i:], ldt)
|
||||||
|
}
|
||||||
|
bi.Dtrmv(blas.Lower, blas.NoTrans, blas.NonUnit, k-i-1,
|
||||||
|
t[(i+1)*ldt+i+1:], ldt,
|
||||||
|
t[(i+1)*ldt+i:], ldt)
|
||||||
|
if i > 0 {
|
||||||
|
prevlastv = min(prevlastv, lastv)
|
||||||
|
} else {
|
||||||
|
prevlastv = lastv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t[i*ldt+i] = tau[i]
|
||||||
|
}
|
||||||
|
}
|
72
native/dlascl.go
Normal file
72
native/dlascl.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dlascl multiplies a rectangular matrix by a scalar.
|
||||||
|
func (impl Implementation) Dlascl(kind lapack.MatrixType, kl, ku int, cfrom, cto float64, m, n int, a []float64, lda int) {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if cfrom == 0 {
|
||||||
|
panic("dlascl: zero divisor")
|
||||||
|
}
|
||||||
|
if math.IsNaN(cfrom) || math.IsNaN(cto) {
|
||||||
|
panic("dlascl: NaN scale factor")
|
||||||
|
}
|
||||||
|
if n == 0 || m == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
smlnum := dlamchS
|
||||||
|
bignum := 1 / smlnum
|
||||||
|
cfromc := cfrom
|
||||||
|
ctoc := cto
|
||||||
|
cfrom1 := cfromc * smlnum
|
||||||
|
for {
|
||||||
|
var done bool
|
||||||
|
var mul, ctol float64
|
||||||
|
if cfrom1 == cfromc {
|
||||||
|
// cfromc is inf
|
||||||
|
mul = ctoc / cfromc
|
||||||
|
done = true
|
||||||
|
ctol = ctoc
|
||||||
|
} else {
|
||||||
|
ctol = ctoc / bignum
|
||||||
|
if ctol == ctoc {
|
||||||
|
// ctoc is either 0 or inf.
|
||||||
|
mul = ctoc
|
||||||
|
done = true
|
||||||
|
cfromc = 1
|
||||||
|
} else if math.Abs(cfrom1) > math.Abs(ctoc) && ctoc != 0 {
|
||||||
|
mul = smlnum
|
||||||
|
done = false
|
||||||
|
cfromc = cfrom1
|
||||||
|
} else if math.Abs(ctol) > math.Abs(cfromc) {
|
||||||
|
mul = bignum
|
||||||
|
done = false
|
||||||
|
ctoc = ctol
|
||||||
|
} else {
|
||||||
|
mul = ctoc / cfromc
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch kind {
|
||||||
|
default:
|
||||||
|
panic("lapack: not implemented")
|
||||||
|
case lapack.General:
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
a[i*lda+j] = a[i*lda+j] * mul
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
37
native/dlaset.go
Normal file
37
native/dlaset.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
|
// Dlaset sets the off-diagonal elements of a to alpha, and the diagonal elements
|
||||||
|
// of a to beta. If uplo == blas.Upper, only the upper diagonal elements are set.
|
||||||
|
// If uplo == blas.Lower, only the lower diagonal elements are set. If uplo is
|
||||||
|
// otherwise, all of the elements of a are set.
|
||||||
|
func (impl Implementation) Dlaset(uplo blas.Uplo, m, n int, alpha, beta float64, a []float64, lda int) {
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
if uplo == blas.Upper {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i + 1; j < n; j++ {
|
||||||
|
a[i*lda+j] = alpha
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if uplo == blas.Lower {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
a[i*lda+j] = alpha
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
a[i*lda+j] = alpha
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 0; i < min(m, n); i++ {
|
||||||
|
a[i*lda+i] = beta
|
||||||
|
}
|
||||||
|
}
|
29
native/dlassq.go
Normal file
29
native/dlassq.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "math"
|
||||||
|
|
||||||
|
// Dlassq updates a sum of squares in scaled form. The input parameters scale and
|
||||||
|
// sumsq represent the current scale and total sum of squares. These values are
|
||||||
|
// updated with the information in the first n elements of the vector specified
|
||||||
|
// by x and incX.
|
||||||
|
func (impl Implementation) Dlassq(n int, x []float64, incx int, scale float64, sumsq float64) (scl, smsq float64) {
|
||||||
|
if n <= 0 {
|
||||||
|
return scale, sumsq
|
||||||
|
}
|
||||||
|
for ix := 0; ix <= (n-1)*incx; ix += incx {
|
||||||
|
absxi := math.Abs(x[ix])
|
||||||
|
if absxi > 0 || math.IsNaN(absxi) {
|
||||||
|
if scale < absxi {
|
||||||
|
sumsq = 1 + sumsq*(scale/absxi)*(scale/absxi)
|
||||||
|
scale = absxi
|
||||||
|
} else {
|
||||||
|
sumsq += (absxi / scale) * (absxi / scale)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scale, sumsq
|
||||||
|
}
|
86
native/dorm2r.go
Normal file
86
native/dorm2r.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
|
// Dorm2r multiplies a general matrix c by an orthogonal matrix from a QR factorization
|
||||||
|
// determined by Dgeqrf.
|
||||||
|
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||||
|
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||||
|
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||||
|
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||||
|
// If side == blas.Left, a is a matrix of size m×k, and if side == blas.Right
|
||||||
|
// a is of size n×k.
|
||||||
|
//
|
||||||
|
// Tau contains the householder factors and is of length at least k and this function
|
||||||
|
// will panic otherwise.
|
||||||
|
//
|
||||||
|
// Work is temporary storage of length at least n if side == blas.Left
|
||||||
|
// and at least m if side == blas.Right and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dorm2r(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
|
||||||
|
if side != blas.Left && side != blas.Right {
|
||||||
|
panic(badSide)
|
||||||
|
}
|
||||||
|
if trans != blas.Trans && trans != blas.NoTrans {
|
||||||
|
panic(badTrans)
|
||||||
|
}
|
||||||
|
|
||||||
|
left := side == blas.Left
|
||||||
|
notran := trans == blas.NoTrans
|
||||||
|
if left {
|
||||||
|
// Q is m x m
|
||||||
|
checkMatrix(m, k, a, lda)
|
||||||
|
if len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Q is n x n
|
||||||
|
checkMatrix(n, k, a, lda)
|
||||||
|
if len(work) < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkMatrix(m, n, c, ldc)
|
||||||
|
if m == 0 || n == 0 || k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
if left {
|
||||||
|
if notran {
|
||||||
|
for i := k - 1; i >= 0; i-- {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m-i, n, a[i*lda+i:], lda, tau[i], c[i*ldc:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m-i, n, a[i*lda+i:], lda, tau[i], c[i*ldc:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if notran {
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m, n-i, a[i*lda+i:], lda, tau[i], c[i:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := k - 1; i >= 0; i-- {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m, n-i, a[i*lda+i:], lda, tau[i], c[i:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
}
|
83
native/dorml2.go
Normal file
83
native/dorml2.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import "github.com/gonum/blas"
|
||||||
|
|
||||||
|
// Dorml2 multiplies a general matrix c by an orthogonal matrix from an LQ factorization
|
||||||
|
// determined by Dgelqf.
|
||||||
|
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||||
|
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||||
|
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||||
|
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||||
|
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
|
||||||
|
// a is of size k×n.
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// Tau contains the householder factors and is of length at least k and this function will
|
||||||
|
// panic otherwise.
|
||||||
|
//
|
||||||
|
// Work is temporary storage of length at least n if side == blas.Left
|
||||||
|
// and at least m if side == blas.Right and this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64) {
|
||||||
|
if side != blas.Left && side != blas.Right {
|
||||||
|
panic(badSide)
|
||||||
|
}
|
||||||
|
if trans != blas.Trans && trans != blas.NoTrans {
|
||||||
|
panic(badTrans)
|
||||||
|
}
|
||||||
|
|
||||||
|
left := side == blas.Left
|
||||||
|
notran := trans == blas.NoTrans
|
||||||
|
if left {
|
||||||
|
checkMatrix(k, m, a, lda)
|
||||||
|
if len(work) < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
checkMatrix(k, n, a, lda)
|
||||||
|
if len(work) < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkMatrix(m, n, c, ldc)
|
||||||
|
if m == 0 || n == 0 || k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case left && notran:
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m-i, n, a[i*lda+i:], 1, tau[i], c[i*ldc:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case left && !notran:
|
||||||
|
for i := k - 1; i >= 0; i-- {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m-i, n, a[i*lda+i:], 1, tau[i], c[i*ldc:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case !left && notran:
|
||||||
|
for i := k - 1; i >= 0; i-- {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m, n-i, a[i*lda+i:], 1, tau[i], c[i:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case !left && !notran:
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
aii := a[i*lda+i]
|
||||||
|
a[i*lda+i] = 1
|
||||||
|
impl.Dlarf(side, m, n-i, a[i*lda+i:], 1, tau[i], c[i:], ldc, work)
|
||||||
|
a[i*lda+i] = aii
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
155
native/dormlq.go
Normal file
155
native/dormlq.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dormlq multiplies the matrix c by the othogonal matrix q defined by the
|
||||||
|
// slices a and tau. A and tau are as returned from Dgelqf.
|
||||||
|
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||||
|
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||||
|
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||||
|
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||||
|
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
|
||||||
|
// a is of size k×n. This uses a blocked algorithm.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
|
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
|
||||||
|
// and this function will panic otherwise.
|
||||||
|
// Dormlq uses a block algorithm, but the block size is limited
|
||||||
|
// by the temporary space available. If lwork == -1, instead of performing Dormlq,
|
||||||
|
// the optimal work length will be stored into work[0].
|
||||||
|
//
|
||||||
|
// Tau contains the householder scales and must have length at least k, and
|
||||||
|
// this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
|
||||||
|
if side != blas.Left && side != blas.Right {
|
||||||
|
panic(badSide)
|
||||||
|
}
|
||||||
|
if trans != blas.Trans && trans != blas.NoTrans {
|
||||||
|
panic(badTrans)
|
||||||
|
}
|
||||||
|
left := side == blas.Left
|
||||||
|
notran := trans == blas.NoTrans
|
||||||
|
if left {
|
||||||
|
checkMatrix(k, m, a, lda)
|
||||||
|
} else {
|
||||||
|
checkMatrix(k, n, a, lda)
|
||||||
|
}
|
||||||
|
checkMatrix(m, n, c, ldc)
|
||||||
|
if len(tau) < k {
|
||||||
|
panic(badTau)
|
||||||
|
}
|
||||||
|
|
||||||
|
const nbmax = 64
|
||||||
|
nw := n
|
||||||
|
if !left {
|
||||||
|
nw = m
|
||||||
|
}
|
||||||
|
opts := string(side) + string(trans)
|
||||||
|
nb := min(nbmax, impl.Ilaenv(1, "DORMLQ", opts, m, n, k, -1))
|
||||||
|
lworkopt := max(1, nw) * nb
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(lworkopt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if left {
|
||||||
|
if lwork < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if lwork < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m == 0 || n == 0 || k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nbmin := 2
|
||||||
|
|
||||||
|
ldwork := nb
|
||||||
|
if nb > 1 && nb < k {
|
||||||
|
iws := nw * nb
|
||||||
|
if lwork < iws {
|
||||||
|
nb = lwork / nw
|
||||||
|
nbmin = max(2, impl.Ilaenv(2, "DORMLQ", opts, m, n, k, -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nb < nbmin || nb >= k {
|
||||||
|
// Call unblocked code
|
||||||
|
impl.Dorml2(side, trans, m, n, k, a, lda, tau, c, ldc, work)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ldt := nb
|
||||||
|
t := make([]float64, nb*ldt)
|
||||||
|
|
||||||
|
transt := blas.NoTrans
|
||||||
|
if notran {
|
||||||
|
transt = blas.Trans
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case left && notran:
|
||||||
|
for i := 0; i < k; i += nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.RowWise, m-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, transt, lapack.Forward, lapack.RowWise, m-i, n, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i*ldc:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case left && !notran:
|
||||||
|
for i := ((k - 1) / nb) * nb; i >= 0; i -= nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.RowWise, m-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, transt, lapack.Forward, lapack.RowWise, m-i, n, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i*ldc:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case !left && notran:
|
||||||
|
for i := ((k - 1) / nb) * nb; i >= 0; i -= nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.RowWise, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, transt, lapack.Forward, lapack.RowWise, m, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case !left && !notran:
|
||||||
|
for i := 0; i < k; i += nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.RowWise, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, transt, lapack.Forward, lapack.RowWise, m, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
139
native/dormqr.go
Normal file
139
native/dormqr.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dormqr multiplies the matrix c by the othogonal matrix q defined by the
|
||||||
|
// slices a and tau. A and tau are as returned from Dgeqrf.
|
||||||
|
// C = Q * C if side == blas.Left and trans == blas.NoTrans
|
||||||
|
// C = Q^T * C if side == blas.Left and trans == blas.Trans
|
||||||
|
// C = C * Q if side == blas.Right and trans == blas.NoTrans
|
||||||
|
// C = C * Q^T if side == blas.Right and trans == blas.Trans
|
||||||
|
// If side == blas.Left, a is a matrix of side k×m, and if side == blas.Right
|
||||||
|
// a is of size k×n. This uses a blocked algorithm.
|
||||||
|
//
|
||||||
|
// Work is temporary storage, and lwork specifies the usable memory length.
|
||||||
|
// At minimum, lwork >= m if side == blas.Left and lwork >= n if side == blas.Right,
|
||||||
|
// and this function will panic otherwise.
|
||||||
|
// Dormqr uses a block algorithm, but the block size is limited
|
||||||
|
// by the temporary space available. If lwork == -1, instead of performing Dormqr,
|
||||||
|
// the optimal work length will be stored into work[0].
|
||||||
|
//
|
||||||
|
// Tau contains the householder scales and must have length at least k, and
|
||||||
|
// this function will panic otherwise.
|
||||||
|
func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
|
||||||
|
left := side == blas.Left
|
||||||
|
notran := trans == blas.NoTrans
|
||||||
|
if left {
|
||||||
|
checkMatrix(m, k, a, lda)
|
||||||
|
} else {
|
||||||
|
checkMatrix(n, k, a, lda)
|
||||||
|
}
|
||||||
|
checkMatrix(m, n, c, ldc)
|
||||||
|
|
||||||
|
const nbmax = 64
|
||||||
|
nw := n
|
||||||
|
if side == blas.Right {
|
||||||
|
nw = m
|
||||||
|
}
|
||||||
|
opts := string(side) + string(trans)
|
||||||
|
nb := min(nbmax, impl.Ilaenv(1, "DORMQR", opts, m, n, k, -1))
|
||||||
|
lworkopt := max(1, nw) * nb
|
||||||
|
if lwork == -1 {
|
||||||
|
work[0] = float64(lworkopt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if left {
|
||||||
|
if lwork < n {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if lwork < m {
|
||||||
|
panic(badWork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if m == 0 || n == 0 || k == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nbmin := 2
|
||||||
|
|
||||||
|
ldwork := nb
|
||||||
|
if nb > 1 && nb < k {
|
||||||
|
iws := nw * nb
|
||||||
|
if lwork < iws {
|
||||||
|
nb = lwork / nw
|
||||||
|
nbmin = max(2, impl.Ilaenv(2, "DORMQR", opts, m, n, k, -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nb < nbmin || nb >= k {
|
||||||
|
// Call unblocked code
|
||||||
|
impl.Dorm2r(side, trans, m, n, k, a, lda, tau, c, ldc, work)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ldt := nb
|
||||||
|
t := make([]float64, nb*ldt)
|
||||||
|
switch {
|
||||||
|
case left && notran:
|
||||||
|
for i := ((k - 1) / nb) * nb; i >= 0; i -= nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m-i, n, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i*ldc:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case left && !notran:
|
||||||
|
for i := 0; i < k; i += nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m-i, n, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i*ldc:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case !left && notran:
|
||||||
|
for i := 0; i < k; i += nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case !left && !notran:
|
||||||
|
for i := ((k - 1) / nb) * nb; i >= 0; i -= nb {
|
||||||
|
ib := min(nb, k-i)
|
||||||
|
impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
tau[i:],
|
||||||
|
t, ldt)
|
||||||
|
impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m, n-i, ib,
|
||||||
|
a[i*lda+i:], lda,
|
||||||
|
t, ldt,
|
||||||
|
c[i:], ldc,
|
||||||
|
work, ldwork)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
@@ -43,7 +43,9 @@ func (Implementation) Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool
|
|||||||
ajj = math.Sqrt(ajj)
|
ajj = math.Sqrt(ajj)
|
||||||
a[j*lda+j] = ajj
|
a[j*lda+j] = ajj
|
||||||
if j < n-1 {
|
if j < n-1 {
|
||||||
bi.Dgemv(blas.Trans, j, n-j-1, -1, a[j+1:], lda, a[j:], lda, 1, a[j*lda+j+1:], 1)
|
bi.Dgemv(blas.Trans, j, n-j-1,
|
||||||
|
-1, a[j+1:], lda, a[j:], lda,
|
||||||
|
1, a[j*lda+j+1:], 1)
|
||||||
bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
|
bi.Dscal(n-j-1, 1/ajj, a[j*lda+j+1:], 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -61,7 +63,9 @@ func (Implementation) Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool
|
|||||||
ajj = math.Sqrt(ajj)
|
ajj = math.Sqrt(ajj)
|
||||||
a[j*lda+j] = ajj
|
a[j*lda+j] = ajj
|
||||||
if j < n-1 {
|
if j < n-1 {
|
||||||
bi.Dgemv(blas.NoTrans, n-j-1, j, -1, a[(j+1)*lda:], lda, a[j*lda:], 1, 1, a[(j+1)*lda+j:], lda)
|
bi.Dgemv(blas.NoTrans, n-j-1, j,
|
||||||
|
-1, a[(j+1)*lda:], lda, a[j*lda:], 1,
|
||||||
|
1, a[(j+1)*lda+j:], lda)
|
||||||
bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
|
bi.Dscal(n-j-1, 1/ajj, a[(j+1)*lda+j:], lda)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -28,39 +28,47 @@ func (impl Implementation) Dpotrf(ul blas.Uplo, n int, a []float64, lda int) (ok
|
|||||||
if n == 0 {
|
if n == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
nb := blockSize()
|
nb := impl.Ilaenv(1, "DPOTRF", string(ul), n, -1, -1, -1)
|
||||||
if n <= nb {
|
if n <= nb {
|
||||||
return impl.Dpotf2(ul, n, a, lda)
|
return impl.Dpotf2(ul, n, a, lda)
|
||||||
}
|
}
|
||||||
if ul == blas.Upper {
|
if ul == blas.Upper {
|
||||||
for j := 0; j < n; j += nb {
|
for j := 0; j < n; j += nb {
|
||||||
jb := min(nb, n-j)
|
jb := min(nb, n-j)
|
||||||
bi.Dsyrk(blas.Upper, blas.Trans, jb, j, -1, a[j:], lda, 1, a[j*lda+j:], lda)
|
bi.Dsyrk(blas.Upper, blas.Trans, jb, j,
|
||||||
|
-1, a[j:], lda,
|
||||||
|
1, a[j*lda+j:], lda)
|
||||||
ok = impl.Dpotf2(blas.Upper, jb, a[j*lda+j:], lda)
|
ok = impl.Dpotf2(blas.Upper, jb, a[j*lda+j:], lda)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
if j+jb < n {
|
if j+jb < n {
|
||||||
bi.Dgemm(blas.Trans, blas.NoTrans, jb, n-j-jb, j, -1,
|
bi.Dgemm(blas.Trans, blas.NoTrans, jb, n-j-jb, j,
|
||||||
a[j:], lda, a[j+jb:], lda, 1, a[j*lda+j+jb:], lda)
|
-1, a[j:], lda, a[j+jb:], lda,
|
||||||
bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, jb, n-j-jb, 1,
|
1, a[j*lda+j+jb:], lda)
|
||||||
a[j*lda+j:], lda, a[j*lda+j+jb:], lda)
|
bi.Dtrsm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, jb, n-j-jb,
|
||||||
|
1, a[j*lda+j:], lda,
|
||||||
|
a[j*lda+j+jb:], lda)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
for j := 0; j < n; j += nb {
|
for j := 0; j < n; j += nb {
|
||||||
jb := min(nb, n-j)
|
jb := min(nb, n-j)
|
||||||
bi.Dsyrk(blas.Lower, blas.NoTrans, jb, j, -1, a[j*lda:], lda, 1, a[j*lda+j:], lda)
|
bi.Dsyrk(blas.Lower, blas.NoTrans, jb, j,
|
||||||
|
-1, a[j*lda:], lda,
|
||||||
|
1, a[j*lda+j:], lda)
|
||||||
ok := impl.Dpotf2(blas.Lower, jb, a[j*lda+j:], lda)
|
ok := impl.Dpotf2(blas.Lower, jb, a[j*lda+j:], lda)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
if j+jb < n {
|
if j+jb < n {
|
||||||
bi.Dgemm(blas.NoTrans, blas.Trans, n-j-jb, jb, j, -1,
|
bi.Dgemm(blas.NoTrans, blas.Trans, n-j-jb, jb, j,
|
||||||
a[(j+jb)*lda:], lda, a[j*lda:], lda, 1, a[(j+jb)*lda+j:], lda)
|
-1, a[(j+jb)*lda:], lda, a[j*lda:], lda,
|
||||||
bi.Dtrsm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n-j-jb, jb, 1,
|
1, a[(j+jb)*lda+j:], lda)
|
||||||
a[j*lda+j:], lda, a[(j+jb)*lda+j:], lda)
|
bi.Dtrsm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n-j-jb, jb,
|
||||||
|
1, a[j*lda+j:], lda,
|
||||||
|
a[(j+jb)*lda+j:], lda)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
31
native/dtrtrs.go
Normal file
31
native/dtrtrs.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dtrtrs solves a triangular system of the form a * x = b or a^T * x = b. Dtrtrs
|
||||||
|
// checks for singularity in a. If a is singular, false is returned and no solve
|
||||||
|
// is performed. True is returned otherwise.
|
||||||
|
func (impl Implementation) Dtrtrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, nrhs int, a []float64, lda int, b []float64, ldb int) (ok bool) {
|
||||||
|
nounit := diag == blas.NonUnit
|
||||||
|
if n == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Check for singularity.
|
||||||
|
if nounit {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if a[i*lda+i] == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bi := blas64.Implementation()
|
||||||
|
bi.Dtrsm(blas.Left, uplo, trans, diag, n, nrhs, 1, a, lda, b, ldb)
|
||||||
|
return true
|
||||||
|
}
|
@@ -4,20 +4,57 @@
|
|||||||
|
|
||||||
package native
|
package native
|
||||||
|
|
||||||
import "github.com/gonum/lapack"
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Implementation is the native Go implementation of LAPACK routines. It
|
||||||
|
// is built on top of calls to the return of blas64.Implementation(), so while
|
||||||
|
// this code is in pure Go, the underlying BLAS implementation may not be.
|
||||||
type Implementation struct{}
|
type Implementation struct{}
|
||||||
|
|
||||||
var _ lapack.Float64 = Implementation{}
|
var _ lapack.Float64 = Implementation{}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
badUplo = "lapack: illegal triangle"
|
badDirect = "lapack: bad direct"
|
||||||
nLT0 = "lapack: n < 0"
|
badLdA = "lapack: index of a out of range"
|
||||||
badLdA = "lapack: index of a out of range"
|
badSide = "lapack: bad side"
|
||||||
|
badStore = "lapack: bad store"
|
||||||
|
badTau = "lapack: tau has insufficient length"
|
||||||
|
badTrans = "lapack: bad trans"
|
||||||
|
badUplo = "lapack: illegal triangle"
|
||||||
|
badWork = "lapack: insufficient working memory"
|
||||||
|
badWorkStride = "lapack: insufficient working array stride"
|
||||||
|
negDimension = "lapack: negative matrix dimension"
|
||||||
|
nLT0 = "lapack: n < 0"
|
||||||
|
shortWork = "lapack: working array shorter than declared"
|
||||||
)
|
)
|
||||||
|
|
||||||
func blockSize() int {
|
// checkMatrix verifies the parameters of a matrix input.
|
||||||
return 64
|
func checkMatrix(m, n int, a []float64, lda int) {
|
||||||
|
if m < 0 {
|
||||||
|
panic("lapack: has negative number of rows")
|
||||||
|
}
|
||||||
|
if m < 0 {
|
||||||
|
panic("lapack: has negative number of columns")
|
||||||
|
}
|
||||||
|
if lda < n {
|
||||||
|
panic("lapack: stride less than number of columns")
|
||||||
|
}
|
||||||
|
if len(a) < (m-1)*lda+n {
|
||||||
|
panic("lapack: insufficient matrix slice length")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkVector(n int, v []float64, inc int) {
|
||||||
|
if n < 0 {
|
||||||
|
panic("lapack: negative matrix length")
|
||||||
|
}
|
||||||
|
if (inc > 0 && (n-1)*inc >= len(v)) || (inc < 0 && (1-n)*inc >= len(v)) {
|
||||||
|
panic("lapack: insufficient vector slice length")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func min(a, b int) int {
|
func min(a, b int) int {
|
||||||
@@ -33,3 +70,23 @@ func max(a, b int) int {
|
|||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dlamch is a function in fortran, but since go forces IEEE-754 these are all
|
||||||
|
// fixed values. Probably a way to get them as constants.
|
||||||
|
// TODO(btracey): Is there a better way to find the smallest number such that 1+E > 1
|
||||||
|
|
||||||
|
var dlamchE, dlamchS, dlamchP float64
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
onePlusEps := math.Nextafter(1, math.Inf(1))
|
||||||
|
eps := (math.Nextafter(1, math.Inf(1)) - 1) * 0.5
|
||||||
|
dlamchE = eps
|
||||||
|
sfmin := math.SmallestNonzeroFloat64
|
||||||
|
small := 1 / math.MaxFloat64
|
||||||
|
if small >= sfmin {
|
||||||
|
sfmin = small * onePlusEps
|
||||||
|
}
|
||||||
|
dlamchS = sfmin
|
||||||
|
radix := 2.0
|
||||||
|
dlamchP = radix * eps
|
||||||
|
}
|
||||||
|
@@ -7,9 +7,11 @@ package native
|
|||||||
// Iladlc scans a matrix for its last non-zero column. Returns -1 if the matrix
|
// Iladlc scans a matrix for its last non-zero column. Returns -1 if the matrix
|
||||||
// is all zeros.
|
// is all zeros.
|
||||||
func (Implementation) Iladlc(m, n int, a []float64, lda int) int {
|
func (Implementation) Iladlc(m, n int, a []float64, lda int) int {
|
||||||
if n == 0 {
|
if n == 0 || m == 0 {
|
||||||
return n - 1
|
return n - 1
|
||||||
}
|
}
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
|
||||||
// Test common case where corner is non-zero.
|
// Test common case where corner is non-zero.
|
||||||
if a[n-1] != 0 || a[(m-1)*lda+(n-1)] != 0 {
|
if a[n-1] != 0 || a[(m-1)*lda+(n-1)] != 0 {
|
||||||
return n - 1
|
return n - 1
|
||||||
|
@@ -10,6 +10,9 @@ func (Implementation) Iladlr(m, n int, a []float64, lda int) int {
|
|||||||
if m == 0 {
|
if m == 0 {
|
||||||
return m - 1
|
return m - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkMatrix(m, n, a, lda)
|
||||||
|
|
||||||
// Check the common case where the corner is non-zero
|
// Check the common case where the corner is non-zero
|
||||||
if a[(m-1)*lda] != 0 || a[(m-1)*lda+n-1] != 0 {
|
if a[(m-1)*lda] != 0 || a[(m-1)*lda+n-1] != 0 {
|
||||||
return m - 1
|
return m - 1
|
||||||
|
375
native/ilaenv.go
Normal file
375
native/ilaenv.go
Normal file
@@ -0,0 +1,375 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package native
|
||||||
|
|
||||||
|
// Ilaenv returns algorithm tuning parameters for the algorithm given by the
|
||||||
|
// input string. ispec specifies the parameter to return.
|
||||||
|
// 1: The optimal block size
|
||||||
|
// 2: The minimum block size for which the algorithm should be used.
|
||||||
|
// 3: The crossover point below which an unblocked routine should be used.
|
||||||
|
// 4: The number of shifts.
|
||||||
|
// 5: The minumum column dimension for blocking to be used.
|
||||||
|
// 6: The crossover point for SVD (to use QR factorization or not).
|
||||||
|
// 7: The number of processors.
|
||||||
|
// 8: The crossover point for multishift in QR and QZ methods for nonsymmetric eigenvalue problems.
|
||||||
|
// 9: Maximum size of the subproblems in divide-and-conquer algorithms.
|
||||||
|
// 10: ieee NaN arithmetic can be trusted not to trap.
|
||||||
|
// 11: infinity arithmetic can be trusted not to trap.
|
||||||
|
func (Implementation) Ilaenv(ispec int, s string, opts string, n1, n2, n3, n4 int) int {
|
||||||
|
// TODO(btracey): Replace this with a constant lookup? A list of constants?
|
||||||
|
// TODO: What is the difference between 2 and 3?
|
||||||
|
sname := s[0] == 'S' || s[0] == 'D'
|
||||||
|
cname := s[0] == 'C' || s[0] == 'Z'
|
||||||
|
if !sname && !cname {
|
||||||
|
panic("lapack: bad name")
|
||||||
|
}
|
||||||
|
c2 := s[1:3]
|
||||||
|
c3 := s[3:6]
|
||||||
|
c4 := c3[1:3]
|
||||||
|
|
||||||
|
switch ispec {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad ispec")
|
||||||
|
case 1:
|
||||||
|
switch c2 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "GE":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRF":
|
||||||
|
if sname {
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
return 64
|
||||||
|
case "QRF", "RQF", "LQF", "QLF":
|
||||||
|
if sname {
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
return 32
|
||||||
|
case "HRD":
|
||||||
|
if sname {
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
return 32
|
||||||
|
case "BRD":
|
||||||
|
if sname {
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
return 32
|
||||||
|
case "TRI":
|
||||||
|
if sname {
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
case "PO":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRF":
|
||||||
|
if sname {
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
case "SY":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRF":
|
||||||
|
if sname {
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
return 64
|
||||||
|
case "TRD":
|
||||||
|
return 32
|
||||||
|
case "GST":
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
case "HE":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRF":
|
||||||
|
return 64
|
||||||
|
case "TRD":
|
||||||
|
return 32
|
||||||
|
case "GST":
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
case "OR":
|
||||||
|
switch c3[0] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case 'G':
|
||||||
|
switch c3[1:] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
case 'M':
|
||||||
|
switch c3[1:] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "UN":
|
||||||
|
switch c3[0] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case 'G':
|
||||||
|
switch c3[1:] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
case 'M':
|
||||||
|
switch c3[1:] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "GB":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRF":
|
||||||
|
if sname {
|
||||||
|
if n4 <= 64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
if n4 <= 64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
case "PB":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRF":
|
||||||
|
if sname {
|
||||||
|
if n4 <= 64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
if n4 <= 64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
case "TR":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRI":
|
||||||
|
if sname {
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
case "LA":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "UUM":
|
||||||
|
if sname {
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
return 64
|
||||||
|
}
|
||||||
|
case "ST":
|
||||||
|
if sname && c3 == "EBZ" {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
switch c2 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "GE":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QRF", "RQF", "LQF", "QLF":
|
||||||
|
if sname {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
return 2
|
||||||
|
case "HRD":
|
||||||
|
if sname {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
return 2
|
||||||
|
case "BRD":
|
||||||
|
if sname {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
return 2
|
||||||
|
case "TRI":
|
||||||
|
if sname {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
case "SY":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "TRF":
|
||||||
|
if sname {
|
||||||
|
return 8
|
||||||
|
}
|
||||||
|
return 8
|
||||||
|
case "TRD":
|
||||||
|
if sname {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
}
|
||||||
|
case "HE":
|
||||||
|
if c3 == "TRD" {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "OR":
|
||||||
|
if !sname {
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
}
|
||||||
|
switch c3[0] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case 'G':
|
||||||
|
switch c4 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
case 'M':
|
||||||
|
switch c4 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "UN":
|
||||||
|
switch c3[0] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case 'G':
|
||||||
|
switch c4 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
case 'M':
|
||||||
|
switch c4 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
switch c2 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "GE":
|
||||||
|
switch c3 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QRF", "RQF", "LQF", "QLF":
|
||||||
|
if sname {
|
||||||
|
return 128
|
||||||
|
}
|
||||||
|
return 128
|
||||||
|
case "HRD":
|
||||||
|
if sname {
|
||||||
|
return 128
|
||||||
|
}
|
||||||
|
return 128
|
||||||
|
case "BRD":
|
||||||
|
if sname {
|
||||||
|
return 128
|
||||||
|
}
|
||||||
|
return 128
|
||||||
|
}
|
||||||
|
case "SY":
|
||||||
|
if sname && c3 == "TRD" {
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "HE":
|
||||||
|
if c3 == "TRD" {
|
||||||
|
return 32
|
||||||
|
}
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "OR":
|
||||||
|
switch c3[0] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case 'G':
|
||||||
|
switch c4 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 128
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "UN":
|
||||||
|
switch c3[0] {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case 'G':
|
||||||
|
switch c4 {
|
||||||
|
default:
|
||||||
|
panic("lapack: bad function name")
|
||||||
|
case "QR", "RQ", "LQ", "QL", "HR", "TR", "BR":
|
||||||
|
return 128
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case 4:
|
||||||
|
// Used by xHSEQR
|
||||||
|
return 6
|
||||||
|
case 5:
|
||||||
|
// Not used
|
||||||
|
return 2
|
||||||
|
case 6:
|
||||||
|
// Used by xGELSS and xGESVD
|
||||||
|
return min(n1, n2) * 1e6
|
||||||
|
case 7:
|
||||||
|
// Not used
|
||||||
|
return 1
|
||||||
|
case 8:
|
||||||
|
// Used by xHSEQR
|
||||||
|
return 50
|
||||||
|
case 9:
|
||||||
|
// used by xGELSD and xGESDD
|
||||||
|
return 25
|
||||||
|
case 10:
|
||||||
|
// Go guarantees ieee
|
||||||
|
return 1
|
||||||
|
case 11:
|
||||||
|
// Go guarantees ieee
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
@@ -12,10 +12,62 @@ import (
|
|||||||
|
|
||||||
var impl = Implementation{}
|
var impl = Implementation{}
|
||||||
|
|
||||||
|
func TestDgelqf(t *testing.T) {
|
||||||
|
testlapack.DgelqfTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDgelq2(t *testing.T) {
|
||||||
|
testlapack.Dgelq2Test(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDgels(t *testing.T) {
|
||||||
|
testlapack.DgelsTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDgeqr2(t *testing.T) {
|
||||||
|
testlapack.Dgeqr2Test(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDgeqrf(t *testing.T) {
|
||||||
|
testlapack.DgeqrfTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDlange(t *testing.T) {
|
||||||
|
testlapack.DlangeTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDlarfb(t *testing.T) {
|
||||||
|
testlapack.DlarfbTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDlarf(t *testing.T) {
|
func TestDlarf(t *testing.T) {
|
||||||
testlapack.DlarfTest(t, impl)
|
testlapack.DlarfTest(t, impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDlarfg(t *testing.T) {
|
||||||
|
testlapack.DlarfgTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDlarft(t *testing.T) {
|
||||||
|
testlapack.DlarftTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDorml2(t *testing.T) {
|
||||||
|
testlapack.Dorml2Test(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDormlq(t *testing.T) {
|
||||||
|
testlapack.DormlqTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDormqr(t *testing.T) {
|
||||||
|
testlapack.DormqrTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDorm2r(t *testing.T) {
|
||||||
|
testlapack.Dorm2rTest(t, impl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDpotf2(t *testing.T) {
|
func TestDpotf2(t *testing.T) {
|
||||||
testlapack.Dpotf2Test(t, impl)
|
testlapack.Dpotf2Test(t, impl)
|
||||||
}
|
}
|
||||||
|
112
testlapack/dgelq2.go
Normal file
112
testlapack/dgelq2.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgelq2er interface {
|
||||||
|
Dgelq2(m, n int, a []float64, lda int, tau, work []float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dgelq2Test(t *testing.T, impl Dgelq2er) {
|
||||||
|
for c, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{1, 1, 0},
|
||||||
|
{2, 2, 0},
|
||||||
|
{3, 2, 0},
|
||||||
|
{2, 3, 0},
|
||||||
|
{1, 12, 0},
|
||||||
|
{2, 6, 0},
|
||||||
|
{3, 4, 0},
|
||||||
|
{4, 3, 0},
|
||||||
|
{6, 2, 0},
|
||||||
|
{1, 12, 0},
|
||||||
|
{1, 1, 20},
|
||||||
|
{2, 2, 20},
|
||||||
|
{3, 2, 20},
|
||||||
|
{2, 3, 20},
|
||||||
|
{1, 12, 20},
|
||||||
|
{2, 6, 20},
|
||||||
|
{3, 4, 20},
|
||||||
|
{4, 3, 20},
|
||||||
|
{6, 2, 20},
|
||||||
|
{1, 12, 20},
|
||||||
|
} {
|
||||||
|
n := test.n
|
||||||
|
m := test.m
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = test.n
|
||||||
|
}
|
||||||
|
k := min(m, n)
|
||||||
|
tau := make([]float64, k)
|
||||||
|
for i := range tau {
|
||||||
|
tau[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
work := make([]float64, m)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := 0; i < m*lda; i++ {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
impl.Dgelq2(m, n, a, lda, tau, work)
|
||||||
|
|
||||||
|
Q := constructQ("LQ", m, n, a, lda, tau)
|
||||||
|
|
||||||
|
// Check that Q is orthonormal
|
||||||
|
for i := 0; i < Q.Rows; i++ {
|
||||||
|
nrm := blas64.Nrm2(Q.Cols, blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]})
|
||||||
|
if math.Abs(nrm-1) > 1e-14 {
|
||||||
|
t.Errorf("Q not normal. Norm is %v", nrm)
|
||||||
|
}
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
dot := blas64.Dot(Q.Rows,
|
||||||
|
blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]},
|
||||||
|
blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]},
|
||||||
|
)
|
||||||
|
if math.Abs(dot) > 1e-14 {
|
||||||
|
t.Errorf("Q not orthogonal. Dot is %v", dot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
L := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, m*n),
|
||||||
|
}
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j <= min(i, n-1); j++ {
|
||||||
|
L.Data[i*L.Stride+j] = a[i*lda+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ans := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: lda,
|
||||||
|
Data: make([]float64, m*lda),
|
||||||
|
}
|
||||||
|
copy(ans.Data, aCopy)
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, Q, 0, ans)
|
||||||
|
if !floats.EqualApprox(aCopy, ans.Data, 1e-14) {
|
||||||
|
t.Errorf("Case %v, LQ mismatch. Want %v, got %v.", c, aCopy, ans.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
94
testlapack/dgelqf.go
Normal file
94
testlapack/dgelqf.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgelqfer interface {
|
||||||
|
Dgelq2er
|
||||||
|
Dgelqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DgelqfTest(t *testing.T, impl Dgelqfer) {
|
||||||
|
for c, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{10, 5, 0},
|
||||||
|
{5, 10, 0},
|
||||||
|
{10, 10, 0},
|
||||||
|
{300, 5, 0},
|
||||||
|
{3, 500, 0},
|
||||||
|
{200, 200, 0},
|
||||||
|
{300, 200, 0},
|
||||||
|
{204, 300, 0},
|
||||||
|
{1, 3000, 0},
|
||||||
|
{3000, 1, 0},
|
||||||
|
{10, 5, 30},
|
||||||
|
{5, 10, 30},
|
||||||
|
{10, 10, 30},
|
||||||
|
{300, 5, 500},
|
||||||
|
{3, 500, 600},
|
||||||
|
{200, 200, 300},
|
||||||
|
{300, 200, 300},
|
||||||
|
{204, 300, 400},
|
||||||
|
{1, 3000, 4000},
|
||||||
|
{3000, 1, 4000},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
a[i*lda+j] = rand.Float64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tau := make([]float64, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
tau[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
ans := make([]float64, len(a))
|
||||||
|
copy(ans, a)
|
||||||
|
work := make([]float64, m)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
// Compute unblocked QR.
|
||||||
|
impl.Dgelq2(m, n, ans, lda, tau, work)
|
||||||
|
// Compute blocked QR with small work.
|
||||||
|
impl.Dgelqf(m, n, a, lda, tau, work, len(work))
|
||||||
|
if !floats.EqualApprox(ans, a, 1e-14) {
|
||||||
|
t.Errorf("Case %v, mismatch small work.", c)
|
||||||
|
}
|
||||||
|
// Try the full length of work.
|
||||||
|
impl.Dgelqf(m, n, a, lda, tau, work, -1)
|
||||||
|
lwork := int(work[0])
|
||||||
|
work = make([]float64, lwork)
|
||||||
|
copy(a, aCopy)
|
||||||
|
impl.Dgelqf(m, n, a, lda, tau, work, lwork)
|
||||||
|
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||||
|
t.Errorf("Case %v, mismatch large work.", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try a slightly smaller version of work to test blocking code.
|
||||||
|
work = work[1:]
|
||||||
|
lwork--
|
||||||
|
copy(a, aCopy)
|
||||||
|
impl.Dgelqf(m, n, a, lda, tau, work, lwork)
|
||||||
|
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||||
|
t.Errorf("Case %v, mismatch large work.", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
181
testlapack/dgels.go
Normal file
181
testlapack/dgels.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgelser interface {
|
||||||
|
Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func DgelsTest(t *testing.T, impl Dgelser) {
|
||||||
|
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, nrhs, lda, ldb int
|
||||||
|
}{
|
||||||
|
{3, 4, 5, 0, 0},
|
||||||
|
{3, 5, 4, 0, 0},
|
||||||
|
{4, 3, 5, 0, 0},
|
||||||
|
{4, 5, 3, 0, 0},
|
||||||
|
{5, 3, 4, 0, 0},
|
||||||
|
{5, 4, 3, 0, 0},
|
||||||
|
{3, 4, 5, 10, 20},
|
||||||
|
{3, 5, 4, 10, 20},
|
||||||
|
{4, 3, 5, 10, 20},
|
||||||
|
{4, 5, 3, 10, 20},
|
||||||
|
{5, 3, 4, 10, 20},
|
||||||
|
{5, 4, 3, 10, 20},
|
||||||
|
{3, 4, 5, 20, 10},
|
||||||
|
{3, 5, 4, 20, 10},
|
||||||
|
{4, 3, 5, 20, 10},
|
||||||
|
{4, 5, 3, 20, 10},
|
||||||
|
{5, 3, 4, 20, 10},
|
||||||
|
{5, 4, 3, 20, 10},
|
||||||
|
{200, 300, 400, 0, 0},
|
||||||
|
{200, 400, 300, 0, 0},
|
||||||
|
{300, 200, 400, 0, 0},
|
||||||
|
{300, 400, 200, 0, 0},
|
||||||
|
{400, 200, 300, 0, 0},
|
||||||
|
{400, 300, 200, 0, 0},
|
||||||
|
{200, 300, 400, 500, 600},
|
||||||
|
{200, 400, 300, 500, 600},
|
||||||
|
{300, 200, 400, 500, 600},
|
||||||
|
{300, 400, 200, 500, 600},
|
||||||
|
{400, 200, 300, 500, 600},
|
||||||
|
{400, 300, 200, 500, 600},
|
||||||
|
{200, 300, 400, 600, 500},
|
||||||
|
{200, 400, 300, 600, 500},
|
||||||
|
{300, 200, 400, 600, 500},
|
||||||
|
{300, 400, 200, 600, 500},
|
||||||
|
{400, 200, 300, 600, 500},
|
||||||
|
{400, 300, 200, 600, 500},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
nrhs := test.nrhs
|
||||||
|
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
|
||||||
|
// Size of b is the same trans or no trans, because the number of rows
|
||||||
|
// has to be the max of (m,n).
|
||||||
|
mb := max(m, n)
|
||||||
|
nb := nrhs
|
||||||
|
ldb := test.ldb
|
||||||
|
if ldb == 0 {
|
||||||
|
ldb = nb
|
||||||
|
}
|
||||||
|
b := make([]float64, mb*ldb)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
bCopy := make([]float64, len(b))
|
||||||
|
copy(bCopy, b)
|
||||||
|
|
||||||
|
// Find optimal work length.
|
||||||
|
work := make([]float64, 1)
|
||||||
|
impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1)
|
||||||
|
|
||||||
|
// Perform linear solve
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
lwork := len(work)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
|
||||||
|
|
||||||
|
// Check that the answer is correct by comparing to the normal equations.
|
||||||
|
aMat := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: lda,
|
||||||
|
Data: make([]float64, len(aCopy)),
|
||||||
|
}
|
||||||
|
copy(aMat.Data, aCopy)
|
||||||
|
szAta := n
|
||||||
|
if trans == blas.Trans {
|
||||||
|
szAta = m
|
||||||
|
}
|
||||||
|
aTA := blas64.General{
|
||||||
|
Rows: szAta,
|
||||||
|
Cols: szAta,
|
||||||
|
Stride: szAta,
|
||||||
|
Data: make([]float64, szAta*szAta),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute A^T * A if notrans and A * A^T otherwise.
|
||||||
|
if trans == blas.NoTrans {
|
||||||
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA)
|
||||||
|
} else {
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply by X.
|
||||||
|
X := blas64.General{
|
||||||
|
Rows: szAta,
|
||||||
|
Cols: nrhs,
|
||||||
|
Stride: ldb,
|
||||||
|
Data: b,
|
||||||
|
}
|
||||||
|
ans := blas64.General{
|
||||||
|
Rows: aTA.Rows,
|
||||||
|
Cols: X.Cols,
|
||||||
|
Stride: X.Cols,
|
||||||
|
Data: make([]float64, aTA.Rows*X.Cols),
|
||||||
|
}
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans)
|
||||||
|
|
||||||
|
B := blas64.General{
|
||||||
|
Rows: szAta,
|
||||||
|
Cols: nrhs,
|
||||||
|
Stride: ldb,
|
||||||
|
Data: make([]float64, len(bCopy)),
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(B.Data, bCopy)
|
||||||
|
var ans2 blas64.General
|
||||||
|
if trans == blas.NoTrans {
|
||||||
|
ans2 = blas64.General{
|
||||||
|
Rows: aMat.Cols,
|
||||||
|
Cols: B.Cols,
|
||||||
|
Stride: B.Cols,
|
||||||
|
Data: make([]float64, aMat.Cols*B.Cols),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ans2 = blas64.General{
|
||||||
|
Rows: aMat.Rows,
|
||||||
|
Cols: B.Cols,
|
||||||
|
Stride: B.Cols,
|
||||||
|
Data: make([]float64, aMat.Rows*B.Cols),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute A^T B if Trans or A * B otherwise
|
||||||
|
if trans == blas.NoTrans {
|
||||||
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2)
|
||||||
|
} else {
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2)
|
||||||
|
}
|
||||||
|
if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) {
|
||||||
|
t.Errorf("Normal equations not satisfied")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
110
testlapack/dgeqr2.go
Normal file
110
testlapack/dgeqr2.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgeqr2er interface {
|
||||||
|
Dgeqr2(m, n int, a []float64, lda int, tau []float64, work []float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dgeqr2Test(t *testing.T, impl Dgeqr2er) {
|
||||||
|
for c, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{1, 1, 0},
|
||||||
|
{2, 2, 0},
|
||||||
|
{3, 2, 0},
|
||||||
|
{2, 3, 0},
|
||||||
|
{1, 12, 0},
|
||||||
|
{2, 6, 0},
|
||||||
|
{3, 4, 0},
|
||||||
|
{4, 3, 0},
|
||||||
|
{6, 2, 0},
|
||||||
|
{12, 1, 0},
|
||||||
|
{1, 1, 20},
|
||||||
|
{2, 2, 20},
|
||||||
|
{3, 2, 20},
|
||||||
|
{2, 3, 20},
|
||||||
|
{1, 12, 20},
|
||||||
|
{2, 6, 20},
|
||||||
|
{3, 4, 20},
|
||||||
|
{4, 3, 20},
|
||||||
|
{6, 2, 20},
|
||||||
|
{12, 1, 20},
|
||||||
|
} {
|
||||||
|
n := test.n
|
||||||
|
m := test.m
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = test.n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
k := min(m, n)
|
||||||
|
tau := make([]float64, k)
|
||||||
|
for i := range tau {
|
||||||
|
tau[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
work := make([]float64, n)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
copy(aCopy, a)
|
||||||
|
impl.Dgeqr2(m, n, a, lda, tau, work)
|
||||||
|
|
||||||
|
// Test that the QR factorization has completed successfully. Compute
|
||||||
|
// Q based on the vectors.
|
||||||
|
q := constructQ("QR", m, n, a, lda, tau)
|
||||||
|
|
||||||
|
// Check that q is orthonormal
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
nrm := blas64.Nrm2(m, blas64.Vector{1, q.Data[i*m:]})
|
||||||
|
if math.Abs(nrm-1) > 1e-14 {
|
||||||
|
t.Errorf("Case %v, q not normal", c)
|
||||||
|
}
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
dot := blas64.Dot(m, blas64.Vector{1, q.Data[i*m:]}, blas64.Vector{1, q.Data[j*m:]})
|
||||||
|
if math.Abs(dot) > 1e-14 {
|
||||||
|
t.Errorf("Case %v, q not orthogonal", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check that A = Q * R
|
||||||
|
r := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, m*n),
|
||||||
|
}
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
r.Data[i*n+j] = a[i*lda+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
atmp := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: lda,
|
||||||
|
Data: make([]float64, m*lda),
|
||||||
|
}
|
||||||
|
copy(atmp.Data, a)
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, atmp)
|
||||||
|
if !floats.EqualApprox(atmp.Data, aCopy, 1e-14) {
|
||||||
|
t.Errorf("Q*R != a")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
91
testlapack/dgeqrf.go
Normal file
91
testlapack/dgeqrf.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dgeqrfer interface {
|
||||||
|
Dgeqr2er
|
||||||
|
Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
|
||||||
|
for c, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{10, 5, 0},
|
||||||
|
{5, 10, 0},
|
||||||
|
{10, 10, 0},
|
||||||
|
{300, 5, 0},
|
||||||
|
{3, 500, 0},
|
||||||
|
{200, 200, 0},
|
||||||
|
{300, 200, 0},
|
||||||
|
{204, 300, 0},
|
||||||
|
{1, 3000, 0},
|
||||||
|
{3000, 1, 0},
|
||||||
|
{10, 5, 20},
|
||||||
|
{5, 10, 20},
|
||||||
|
{10, 10, 20},
|
||||||
|
{300, 5, 400},
|
||||||
|
{3, 500, 600},
|
||||||
|
{200, 200, 300},
|
||||||
|
{300, 200, 300},
|
||||||
|
{204, 300, 400},
|
||||||
|
{1, 3000, 4000},
|
||||||
|
{3000, 1, 4000},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = test.n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
a[i*lda+j] = rand.Float64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tau := make([]float64, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
tau[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
ans := make([]float64, len(a))
|
||||||
|
copy(ans, a)
|
||||||
|
work := make([]float64, n)
|
||||||
|
// Compute unblocked QR.
|
||||||
|
impl.Dgeqr2(m, n, ans, lda, tau, work)
|
||||||
|
// Compute blocked QR with small work.
|
||||||
|
impl.Dgeqrf(m, n, a, lda, tau, work, len(work))
|
||||||
|
if !floats.EqualApprox(ans, a, 1e-14) {
|
||||||
|
t.Errorf("Case %v, mismatch small work.", c)
|
||||||
|
}
|
||||||
|
// Try the full length of work.
|
||||||
|
impl.Dgeqrf(m, n, a, lda, tau, work, -1)
|
||||||
|
lwork := int(work[0])
|
||||||
|
work = make([]float64, lwork)
|
||||||
|
copy(a, aCopy)
|
||||||
|
impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
|
||||||
|
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||||
|
t.Errorf("Case %v, mismatch large work.", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try a slightly smaller version of work to test blocking.
|
||||||
|
work = work[1:]
|
||||||
|
lwork--
|
||||||
|
copy(a, aCopy)
|
||||||
|
impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
|
||||||
|
if !floats.EqualApprox(ans, a, 1e-12) {
|
||||||
|
t.Errorf("Case %v, mismatch large work.", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
92
testlapack/dlange.go
Normal file
92
testlapack/dlange.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dlanger interface {
|
||||||
|
Dlange(norm lapack.MatrixNorm, m, n int, a []float64, lda int, work []float64) float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func DlangeTest(t *testing.T, impl Dlanger) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, lda int
|
||||||
|
}{
|
||||||
|
{4, 3, 0},
|
||||||
|
{3, 4, 0},
|
||||||
|
{4, 3, 100},
|
||||||
|
{3, 4, 100},
|
||||||
|
} {
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = (rand.Float64() - 0.5)
|
||||||
|
}
|
||||||
|
work := make([]float64, n)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
|
||||||
|
// Test MaxAbs norm.
|
||||||
|
norm := impl.Dlange(lapack.MaxAbs, m, n, a, lda, work)
|
||||||
|
var ans float64
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
idx := blas64.Iamax(n, blas64.Vector{1, aCopy[i*lda:]})
|
||||||
|
ans = math.Max(ans, math.Abs(a[i*lda+idx]))
|
||||||
|
}
|
||||||
|
// Should be strictly equal because there is no floating point summation error.
|
||||||
|
if ans != norm {
|
||||||
|
t.Errorf("MaxAbs mismatch. Want %v, got %v.", ans, norm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxColumnSum norm.
|
||||||
|
norm = impl.Dlange(lapack.MaxColumnSum, m, n, a, lda, work)
|
||||||
|
ans = 0
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
sum := blas64.Asum(m, blas64.Vector{lda, aCopy[i:]})
|
||||||
|
ans = math.Max(ans, sum)
|
||||||
|
}
|
||||||
|
if math.Abs(norm-ans) > 1e-14 {
|
||||||
|
t.Errorf("MaxColumnSum mismatch. Want %v, got %v.", ans, norm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test MaxRowSum norm.
|
||||||
|
norm = impl.Dlange(lapack.MaxRowSum, m, n, a, lda, work)
|
||||||
|
ans = 0
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
sum := blas64.Asum(n, blas64.Vector{1, aCopy[i*lda:]})
|
||||||
|
ans = math.Max(ans, sum)
|
||||||
|
}
|
||||||
|
if math.Abs(norm-ans) > 1e-14 {
|
||||||
|
t.Errorf("MaxRowSum mismatch. Want %v, got %v.", ans, norm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Frobenius norm
|
||||||
|
norm = impl.Dlange(lapack.NormFrob, m, n, a, lda, work)
|
||||||
|
ans = 0
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
sum := blas64.Nrm2(n, blas64.Vector{1, aCopy[i*lda:]})
|
||||||
|
ans += sum * sum
|
||||||
|
}
|
||||||
|
ans = math.Sqrt(ans)
|
||||||
|
if math.Abs(norm-ans) > 1e-14 {
|
||||||
|
t.Errorf("NormFrob mismatch. Want %v, got %v.", ans, norm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -63,6 +63,19 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
|||||||
|
|
||||||
tau: 2,
|
tau: 2,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
m: 2,
|
||||||
|
n: 3,
|
||||||
|
ldc: 3,
|
||||||
|
|
||||||
|
incv: 4,
|
||||||
|
lastv: 0,
|
||||||
|
|
||||||
|
lastr: 0,
|
||||||
|
lastc: 1,
|
||||||
|
|
||||||
|
tau: 2,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
m: 10,
|
m: 10,
|
||||||
n: 10,
|
n: 10,
|
||||||
@@ -93,7 +106,7 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
|||||||
sz := max(test.m, test.n) // so v works for both right and left side.
|
sz := max(test.m, test.n) // so v works for both right and left side.
|
||||||
v := make([]float64, test.incv*sz+1)
|
v := make([]float64, test.incv*sz+1)
|
||||||
// Fill with nonzero entries up until lastv.
|
// Fill with nonzero entries up until lastv.
|
||||||
for i := 0; i < test.lastv; i++ {
|
for i := 0; i <= test.lastv; i++ {
|
||||||
v[i*test.incv] = rand.Float64()
|
v[i*test.incv] = rand.Float64()
|
||||||
}
|
}
|
||||||
// Construct h explicitly to compare.
|
// Construct h explicitly to compare.
|
||||||
@@ -132,7 +145,7 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
|||||||
work := make([]float64, sz)
|
work := make([]float64, sz)
|
||||||
impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
||||||
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
||||||
t.Errorf("Dlarf mismatch case %v. Want %v, got %v", i, cMat.Data, c)
|
t.Errorf("Dlarf mismatch right, case %v. Want %v, got %v", i, cMat.Data, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test on the left side.
|
// Test on the left side.
|
||||||
@@ -153,7 +166,7 @@ func DlarfTest(t *testing.T, impl Dlarfer) {
|
|||||||
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
|
||||||
impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
|
||||||
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
|
||||||
t.Errorf("Dlarf mismatch case %v. Want %v, got %v", i, cMat.Data, c)
|
t.Errorf("Dlarf mismatch left, case %v. Want %v, got %v", i, cMat.Data, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
160
testlapack/dlarfb.go
Normal file
160
testlapack/dlarfb.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dlarfber interface {
|
||||||
|
Dlarfter
|
||||||
|
Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
|
||||||
|
store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
|
||||||
|
c []float64, ldc int, work []float64, ldwork int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DlarfbTest(t *testing.T, impl Dlarfber) {
|
||||||
|
for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
|
||||||
|
for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
|
||||||
|
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||||
|
for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} {
|
||||||
|
for cas, test := range []struct {
|
||||||
|
ma, na, cdim, lda, ldt, ldc int
|
||||||
|
}{
|
||||||
|
{6, 6, 6, 0, 0, 0},
|
||||||
|
{6, 8, 10, 0, 0, 0},
|
||||||
|
{6, 10, 8, 0, 0, 0},
|
||||||
|
{8, 6, 10, 0, 0, 0},
|
||||||
|
{8, 10, 6, 0, 0, 0},
|
||||||
|
{10, 6, 8, 0, 0, 0},
|
||||||
|
{10, 8, 6, 0, 0, 0},
|
||||||
|
{6, 6, 6, 12, 15, 30},
|
||||||
|
{6, 8, 10, 12, 15, 30},
|
||||||
|
{6, 10, 8, 12, 15, 30},
|
||||||
|
{8, 6, 10, 12, 15, 30},
|
||||||
|
{8, 10, 6, 12, 15, 30},
|
||||||
|
{10, 6, 8, 12, 15, 30},
|
||||||
|
{10, 8, 6, 12, 15, 30},
|
||||||
|
{6, 6, 6, 15, 12, 30},
|
||||||
|
{6, 8, 10, 15, 12, 30},
|
||||||
|
{6, 10, 8, 15, 12, 30},
|
||||||
|
{8, 6, 10, 15, 12, 30},
|
||||||
|
{8, 10, 6, 15, 12, 30},
|
||||||
|
{10, 6, 8, 15, 12, 30},
|
||||||
|
{10, 8, 6, 15, 12, 30},
|
||||||
|
} {
|
||||||
|
// Generate a matrix for QR
|
||||||
|
ma := test.ma
|
||||||
|
na := test.na
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = na
|
||||||
|
}
|
||||||
|
a := make([]float64, ma*lda)
|
||||||
|
for i := 0; i < ma; i++ {
|
||||||
|
for j := 0; j < lda; j++ {
|
||||||
|
a[i*lda+j] = rand.Float64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
k := min(ma, na)
|
||||||
|
|
||||||
|
// H is always ma x ma
|
||||||
|
var m, n, rowsWork int
|
||||||
|
switch {
|
||||||
|
default:
|
||||||
|
panic("not implemented")
|
||||||
|
case side == blas.Left:
|
||||||
|
m = test.ma
|
||||||
|
n = test.cdim
|
||||||
|
rowsWork = n
|
||||||
|
case side == blas.Right:
|
||||||
|
m = test.cdim
|
||||||
|
n = test.ma
|
||||||
|
rowsWork = m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use dgeqr2 to find the v vectors
|
||||||
|
tau := make([]float64, na)
|
||||||
|
work := make([]float64, na)
|
||||||
|
impl.Dgeqr2(ma, k, a, lda, tau, work)
|
||||||
|
|
||||||
|
// Correct the v vectors based on the direct and store
|
||||||
|
vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
|
||||||
|
vMat := constructVMat(vMatTmp, store, direct)
|
||||||
|
v := vMat.Data
|
||||||
|
ldv := vMat.Stride
|
||||||
|
|
||||||
|
// Use dlarft to find the t vector
|
||||||
|
ldt := test.ldt
|
||||||
|
if ldt == 0 {
|
||||||
|
ldt = k
|
||||||
|
}
|
||||||
|
tm := make([]float64, k*ldt)
|
||||||
|
|
||||||
|
impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)
|
||||||
|
|
||||||
|
// Generate c matrix
|
||||||
|
ldc := test.ldc
|
||||||
|
if ldc == 0 {
|
||||||
|
ldc = n
|
||||||
|
}
|
||||||
|
c := make([]float64, m*ldc)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < ldc; j++ {
|
||||||
|
c[i*ldc+j] = rand.Float64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cCopy := make([]float64, len(c))
|
||||||
|
copy(cCopy, c)
|
||||||
|
|
||||||
|
ldwork := k
|
||||||
|
work = make([]float64, rowsWork*k)
|
||||||
|
|
||||||
|
// Call Dlarfb with this information
|
||||||
|
impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)
|
||||||
|
|
||||||
|
h := constructH(tau, vMat, store, direct)
|
||||||
|
|
||||||
|
cMat := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: ldc,
|
||||||
|
Data: make([]float64, m*ldc),
|
||||||
|
}
|
||||||
|
copy(cMat.Data, cCopy)
|
||||||
|
ans := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: n,
|
||||||
|
Stride: ldc,
|
||||||
|
Data: make([]float64, m*ldc),
|
||||||
|
}
|
||||||
|
copy(ans.Data, cMat.Data)
|
||||||
|
switch {
|
||||||
|
default:
|
||||||
|
panic("not implemented")
|
||||||
|
case side == blas.Left && trans == blas.NoTrans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
|
||||||
|
case side == blas.Left && trans == blas.Trans:
|
||||||
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
|
||||||
|
case side == blas.Right && trans == blas.NoTrans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
|
||||||
|
case side == blas.Right && trans == blas.Trans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
|
||||||
|
}
|
||||||
|
if !floats.EqualApprox(ans.Data, c, 1e-14) {
|
||||||
|
t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
133
testlapack/dlarfg.go
Normal file
133
testlapack/dlarfg.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dlarfger interface {
|
||||||
|
Dlarfg(n int, alpha float64, x []float64, incX int) (beta, tau float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DlarfgTest(t *testing.T, impl Dlarfger) {
|
||||||
|
for i, test := range []struct {
|
||||||
|
alpha float64
|
||||||
|
n int
|
||||||
|
x []float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
alpha: 4,
|
||||||
|
n: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
alpha: -2,
|
||||||
|
n: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
alpha: 0,
|
||||||
|
n: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
alpha: 1,
|
||||||
|
n: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
alpha: 1,
|
||||||
|
n: 2,
|
||||||
|
x: []float64{4, 5, 6},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
n := test.n
|
||||||
|
incX := 1
|
||||||
|
var x []float64
|
||||||
|
if test.x == nil {
|
||||||
|
x = make([]float64, n-1)
|
||||||
|
for i := range x {
|
||||||
|
x[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
x = make([]float64, n-1)
|
||||||
|
copy(x, test.x)
|
||||||
|
}
|
||||||
|
xcopy := make([]float64, n-1)
|
||||||
|
copy(xcopy, x)
|
||||||
|
alpha := test.alpha
|
||||||
|
beta, tau := impl.Dlarfg(n, alpha, x, incX)
|
||||||
|
|
||||||
|
// Verify the returns and the values in v. Construct h and perform
|
||||||
|
// the explicit multiplication.
|
||||||
|
h := make([]float64, n*n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
h[i*n+i] = 1
|
||||||
|
}
|
||||||
|
hmat := blas64.General{
|
||||||
|
Rows: n,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: h,
|
||||||
|
}
|
||||||
|
v := make([]float64, n)
|
||||||
|
copy(v[1:], x)
|
||||||
|
v[0] = 1
|
||||||
|
vVec := blas64.Vector{
|
||||||
|
Inc: 1,
|
||||||
|
Data: v,
|
||||||
|
}
|
||||||
|
blas64.Ger(-tau, vVec, vVec, hmat)
|
||||||
|
eye := blas64.General{
|
||||||
|
Rows: n,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, n*n),
|
||||||
|
}
|
||||||
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye)
|
||||||
|
iseye := true
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
if i == j {
|
||||||
|
if math.Abs(eye.Data[i*n+j]-1) > 1e-14 {
|
||||||
|
iseye = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if math.Abs(eye.Data[i*n+j]) > 1e-14 {
|
||||||
|
iseye = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !iseye {
|
||||||
|
t.Errorf("H^T * H is not I %V", eye)
|
||||||
|
}
|
||||||
|
|
||||||
|
xVec := blas64.Vector{
|
||||||
|
Inc: 1,
|
||||||
|
Data: make([]float64, n),
|
||||||
|
}
|
||||||
|
xVec.Data[0] = test.alpha
|
||||||
|
copy(xVec.Data[1:], xcopy)
|
||||||
|
|
||||||
|
ans := make([]float64, n)
|
||||||
|
ansVec := blas64.Vector{
|
||||||
|
Inc: 1,
|
||||||
|
Data: ans,
|
||||||
|
}
|
||||||
|
blas64.Gemv(blas.NoTrans, 1, hmat, xVec, 0, ansVec)
|
||||||
|
if math.Abs(ans[0]-beta) > 1e-14 {
|
||||||
|
t.Errorf("Case %v, beta mismatch. Want %v, got %v", i, ans[0], beta)
|
||||||
|
}
|
||||||
|
for i := 1; i < n; i++ {
|
||||||
|
if math.Abs(ans[i]) > 1e-14 {
|
||||||
|
t.Errorf("Case %v, nonzero answer %v", i, ans)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
167
testlapack/dlarft.go
Normal file
167
testlapack/dlarft.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dlarfter interface {
|
||||||
|
Dgeqr2er
|
||||||
|
Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DlarftTest(t *testing.T, impl Dlarfter) {
|
||||||
|
for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
|
||||||
|
for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
m, n, ldv, ldt int
|
||||||
|
}{
|
||||||
|
{6, 6, 0, 0},
|
||||||
|
{8, 6, 0, 0},
|
||||||
|
{6, 8, 0, 0},
|
||||||
|
{6, 6, 10, 15},
|
||||||
|
{8, 6, 10, 15},
|
||||||
|
{6, 8, 10, 15},
|
||||||
|
{6, 6, 15, 10},
|
||||||
|
{8, 6, 15, 10},
|
||||||
|
{6, 8, 15, 10},
|
||||||
|
} {
|
||||||
|
// Generate a matrix
|
||||||
|
m := test.m
|
||||||
|
n := test.n
|
||||||
|
lda := n
|
||||||
|
if lda == 0 {
|
||||||
|
lda = n
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]float64, m*lda)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < lda; j++ {
|
||||||
|
a[i*lda+j] = rand.Float64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Use dgeqr2 to find the v vectors
|
||||||
|
tau := make([]float64, n)
|
||||||
|
work := make([]float64, n)
|
||||||
|
impl.Dgeqr2(m, n, a, lda, tau, work)
|
||||||
|
|
||||||
|
// Construct H using these answers
|
||||||
|
vMatTmp := extractVMat(m, n, a, lda, lapack.Forward, lapack.ColumnWise)
|
||||||
|
vMat := constructVMat(vMatTmp, store, direct)
|
||||||
|
v := vMat.Data
|
||||||
|
ldv := vMat.Stride
|
||||||
|
|
||||||
|
h := constructH(tau, vMat, store, direct)
|
||||||
|
|
||||||
|
k := min(m, n)
|
||||||
|
ldt := test.ldt
|
||||||
|
if ldt == 0 {
|
||||||
|
ldt = k
|
||||||
|
}
|
||||||
|
// Find T from the actual function
|
||||||
|
tm := make([]float64, k*ldt)
|
||||||
|
for i := range tm {
|
||||||
|
tm[i] = 100 + rand.Float64()
|
||||||
|
}
|
||||||
|
// The v data has been put into a.
|
||||||
|
impl.Dlarft(direct, store, m, k, v, ldv, tau, tm, ldt)
|
||||||
|
|
||||||
|
tData := make([]float64, len(tm))
|
||||||
|
copy(tData, tm)
|
||||||
|
if direct == lapack.Forward {
|
||||||
|
// Zero out the lower traingular portion.
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
tData[i*ldt+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Zero out the upper traingular portion.
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
for j := i + 1; j < k; j++ {
|
||||||
|
tData[i*ldt+j] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
T := blas64.General{
|
||||||
|
Rows: k,
|
||||||
|
Cols: k,
|
||||||
|
Stride: ldt,
|
||||||
|
Data: tData,
|
||||||
|
}
|
||||||
|
|
||||||
|
vMatT := blas64.General{
|
||||||
|
Rows: vMat.Cols,
|
||||||
|
Cols: vMat.Rows,
|
||||||
|
Stride: vMat.Rows,
|
||||||
|
Data: make([]float64, vMat.Cols*vMat.Rows),
|
||||||
|
}
|
||||||
|
for i := 0; i < vMat.Rows; i++ {
|
||||||
|
for j := 0; j < vMat.Cols; j++ {
|
||||||
|
vMatT.Data[j*vMatT.Stride+i] = vMat.Data[i*vMat.Stride+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var comp blas64.General
|
||||||
|
if store == lapack.ColumnWise {
|
||||||
|
// H = I - V * T * V^T
|
||||||
|
tmp := blas64.General{
|
||||||
|
Rows: T.Rows,
|
||||||
|
Cols: vMatT.Cols,
|
||||||
|
Stride: vMatT.Cols,
|
||||||
|
Data: make([]float64, T.Rows*vMatT.Cols),
|
||||||
|
}
|
||||||
|
// T * V^T
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMatT, 0, tmp)
|
||||||
|
comp = blas64.General{
|
||||||
|
Rows: vMat.Rows,
|
||||||
|
Cols: tmp.Cols,
|
||||||
|
Stride: tmp.Cols,
|
||||||
|
Data: make([]float64, vMat.Rows*tmp.Cols),
|
||||||
|
}
|
||||||
|
// V * (T * V^T)
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMat, tmp, 0, comp)
|
||||||
|
} else {
|
||||||
|
// H = I - V^T * T * V
|
||||||
|
tmp := blas64.General{
|
||||||
|
Rows: T.Rows,
|
||||||
|
Cols: vMat.Cols,
|
||||||
|
Stride: vMat.Cols,
|
||||||
|
Data: make([]float64, T.Rows*vMat.Cols),
|
||||||
|
}
|
||||||
|
// T * V
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMat, 0, tmp)
|
||||||
|
comp = blas64.General{
|
||||||
|
Rows: vMatT.Rows,
|
||||||
|
Cols: tmp.Cols,
|
||||||
|
Stride: tmp.Cols,
|
||||||
|
Data: make([]float64, vMatT.Rows*tmp.Cols),
|
||||||
|
}
|
||||||
|
// V^T * (T * V)
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMatT, tmp, 0, comp)
|
||||||
|
}
|
||||||
|
// I - V^T * T * V
|
||||||
|
for i := 0; i < comp.Rows; i++ {
|
||||||
|
for j := 0; j < comp.Cols; j++ {
|
||||||
|
comp.Data[i*m+j] *= -1
|
||||||
|
if i == j {
|
||||||
|
comp.Data[i*m+j] += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !floats.EqualApprox(comp.Data, h.Data, 1e-14) {
|
||||||
|
t.Errorf("T does not construct proper H. Store = %v, Direct = %v.\nWant %v\ngot %v.", string(store), string(direct), h.Data, comp.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
138
testlapack/dorm2r.go
Normal file
138
testlapack/dorm2r.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dorm2rer interface {
|
||||||
|
Dgeqrfer
|
||||||
|
Dorm2r(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dorm2rTest(t *testing.T, impl Dorm2rer) {
|
||||||
|
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||||
|
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
common, adim, cdim, lda, ldc int
|
||||||
|
}{
|
||||||
|
{3, 4, 5, 0, 0},
|
||||||
|
{3, 5, 4, 0, 0},
|
||||||
|
{4, 3, 5, 0, 0},
|
||||||
|
{4, 5, 3, 0, 0},
|
||||||
|
{5, 3, 4, 0, 0},
|
||||||
|
{5, 4, 3, 0, 0},
|
||||||
|
{3, 4, 5, 6, 20},
|
||||||
|
{3, 5, 4, 6, 20},
|
||||||
|
{4, 3, 5, 6, 20},
|
||||||
|
{4, 5, 3, 6, 20},
|
||||||
|
{5, 3, 4, 6, 20},
|
||||||
|
{5, 4, 3, 6, 20},
|
||||||
|
{3, 4, 5, 20, 6},
|
||||||
|
{3, 5, 4, 20, 6},
|
||||||
|
{4, 3, 5, 20, 6},
|
||||||
|
{4, 5, 3, 20, 6},
|
||||||
|
{5, 3, 4, 20, 6},
|
||||||
|
{5, 4, 3, 20, 6},
|
||||||
|
} {
|
||||||
|
var ma, na, mc, nc int
|
||||||
|
if side == blas.Left {
|
||||||
|
ma = test.common
|
||||||
|
na = test.adim
|
||||||
|
mc = test.common
|
||||||
|
nc = test.cdim
|
||||||
|
} else {
|
||||||
|
ma = test.common
|
||||||
|
na = test.adim
|
||||||
|
mc = test.cdim
|
||||||
|
nc = test.common
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random matrix
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = na
|
||||||
|
}
|
||||||
|
a := make([]float64, ma*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
ldc := test.ldc
|
||||||
|
if ldc == 0 {
|
||||||
|
ldc = nc
|
||||||
|
}
|
||||||
|
// Compute random C matrix
|
||||||
|
c := make([]float64, mc*ldc)
|
||||||
|
for i := range c {
|
||||||
|
c[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute QR
|
||||||
|
k := min(ma, na)
|
||||||
|
tau := make([]float64, k)
|
||||||
|
work := make([]float64, 1)
|
||||||
|
impl.Dgeqrf(ma, na, a, lda, tau, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
impl.Dgeqrf(ma, na, a, lda, tau, work, len(work))
|
||||||
|
|
||||||
|
// Build Q from result
|
||||||
|
q := constructQ("QR", ma, na, a, lda, tau)
|
||||||
|
|
||||||
|
cMat := blas64.General{
|
||||||
|
Rows: mc,
|
||||||
|
Cols: nc,
|
||||||
|
Stride: ldc,
|
||||||
|
Data: make([]float64, len(c)),
|
||||||
|
}
|
||||||
|
copy(cMat.Data, c)
|
||||||
|
cMatCopy := blas64.General{
|
||||||
|
Rows: cMat.Rows,
|
||||||
|
Cols: cMat.Cols,
|
||||||
|
Stride: cMat.Stride,
|
||||||
|
Data: make([]float64, len(cMat.Data)),
|
||||||
|
}
|
||||||
|
copy(cMatCopy.Data, cMat.Data)
|
||||||
|
switch {
|
||||||
|
default:
|
||||||
|
panic("bad test")
|
||||||
|
case side == blas.Left && trans == blas.NoTrans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||||
|
case side == blas.Left && trans == blas.Trans:
|
||||||
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||||
|
case side == blas.Right && trans == blas.NoTrans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
|
||||||
|
case side == blas.Right && trans == blas.Trans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
|
||||||
|
}
|
||||||
|
// Do Dorm2r ard compare
|
||||||
|
if side == blas.Left {
|
||||||
|
work = make([]float64, nc)
|
||||||
|
} else {
|
||||||
|
work = make([]float64, mc)
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
tauCopy := make([]float64, len(tau))
|
||||||
|
copy(tauCopy, tau)
|
||||||
|
impl.Dorm2r(side, trans, mc, nc, k, a, lda, tau, c, ldc, work)
|
||||||
|
if !floats.Equal(a, aCopy) {
|
||||||
|
t.Errorf("a changed in call")
|
||||||
|
}
|
||||||
|
if !floats.Equal(tau, tauCopy) {
|
||||||
|
t.Errorf("tau changed in call")
|
||||||
|
}
|
||||||
|
if !floats.EqualApprox(cMat.Data, c, 1e-14) {
|
||||||
|
t.Errorf("Multiplication mismatch.\n Want %v \n got %v.", cMat.Data, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
137
testlapack/dorml2.go
Normal file
137
testlapack/dorml2.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dorml2er interface {
|
||||||
|
Dgelqfer
|
||||||
|
Dorml2(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dorml2Test(t *testing.T, impl Dorml2er) {
|
||||||
|
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||||
|
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
common, adim, cdim, lda, ldc int
|
||||||
|
}{
|
||||||
|
{3, 4, 5, 0, 0},
|
||||||
|
{3, 5, 4, 0, 0},
|
||||||
|
{4, 3, 5, 0, 0},
|
||||||
|
{4, 5, 3, 0, 0},
|
||||||
|
{5, 3, 4, 0, 0},
|
||||||
|
{5, 4, 3, 0, 0},
|
||||||
|
{3, 4, 5, 6, 20},
|
||||||
|
{3, 5, 4, 6, 20},
|
||||||
|
{4, 3, 5, 6, 20},
|
||||||
|
{4, 5, 3, 6, 20},
|
||||||
|
{5, 3, 4, 6, 20},
|
||||||
|
{5, 4, 3, 6, 20},
|
||||||
|
{3, 4, 5, 20, 6},
|
||||||
|
{3, 5, 4, 20, 6},
|
||||||
|
{4, 3, 5, 20, 6},
|
||||||
|
{4, 5, 3, 20, 6},
|
||||||
|
{5, 3, 4, 20, 6},
|
||||||
|
{5, 4, 3, 20, 6},
|
||||||
|
} {
|
||||||
|
var ma, na, mc, nc int
|
||||||
|
if side == blas.Left {
|
||||||
|
ma = test.adim
|
||||||
|
na = test.common
|
||||||
|
mc = test.common
|
||||||
|
nc = test.cdim
|
||||||
|
} else {
|
||||||
|
ma = test.adim
|
||||||
|
na = test.common
|
||||||
|
mc = test.cdim
|
||||||
|
nc = test.common
|
||||||
|
}
|
||||||
|
// Generate a random matrix
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = na
|
||||||
|
}
|
||||||
|
a := make([]float64, ma*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
ldc := test.ldc
|
||||||
|
if ldc == 0 {
|
||||||
|
ldc = nc
|
||||||
|
}
|
||||||
|
// Compute random C matrix
|
||||||
|
c := make([]float64, mc*ldc)
|
||||||
|
for i := range c {
|
||||||
|
c[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute LQ
|
||||||
|
k := min(ma, na)
|
||||||
|
tau := make([]float64, k)
|
||||||
|
work := make([]float64, 1)
|
||||||
|
impl.Dgelqf(ma, na, a, lda, tau, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
impl.Dgelqf(ma, na, a, lda, tau, work, len(work))
|
||||||
|
|
||||||
|
// Build Q from result
|
||||||
|
q := constructQ("LQ", ma, na, a, lda, tau)
|
||||||
|
|
||||||
|
cMat := blas64.General{
|
||||||
|
Rows: mc,
|
||||||
|
Cols: nc,
|
||||||
|
Stride: ldc,
|
||||||
|
Data: make([]float64, len(c)),
|
||||||
|
}
|
||||||
|
copy(cMat.Data, c)
|
||||||
|
cMatCopy := blas64.General{
|
||||||
|
Rows: cMat.Rows,
|
||||||
|
Cols: cMat.Cols,
|
||||||
|
Stride: cMat.Stride,
|
||||||
|
Data: make([]float64, len(cMat.Data)),
|
||||||
|
}
|
||||||
|
copy(cMatCopy.Data, cMat.Data)
|
||||||
|
switch {
|
||||||
|
default:
|
||||||
|
panic("bad test")
|
||||||
|
case side == blas.Left && trans == blas.NoTrans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||||
|
case side == blas.Left && trans == blas.Trans:
|
||||||
|
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
|
||||||
|
case side == blas.Right && trans == blas.NoTrans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
|
||||||
|
case side == blas.Right && trans == blas.Trans:
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
|
||||||
|
}
|
||||||
|
// Do Dorm2r ard compare
|
||||||
|
if side == blas.Left {
|
||||||
|
work = make([]float64, nc)
|
||||||
|
} else {
|
||||||
|
work = make([]float64, mc)
|
||||||
|
}
|
||||||
|
aCopy := make([]float64, len(a))
|
||||||
|
copy(aCopy, a)
|
||||||
|
tauCopy := make([]float64, len(tau))
|
||||||
|
copy(tauCopy, tau)
|
||||||
|
impl.Dorml2(side, trans, mc, nc, k, a, lda, tau, c, ldc, work)
|
||||||
|
if !floats.Equal(a, aCopy) {
|
||||||
|
t.Errorf("a changed in call")
|
||||||
|
}
|
||||||
|
if !floats.Equal(tau, tauCopy) {
|
||||||
|
t.Errorf("tau changed in call")
|
||||||
|
}
|
||||||
|
if !floats.EqualApprox(cMat.Data, c, 1e-14) {
|
||||||
|
t.Errorf("Multiplication mismatch.\n Want %v \n got %v.", cMat.Data, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
149
testlapack/dormlq.go
Normal file
149
testlapack/dormlq.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dormlqer interface {
|
||||||
|
Dorml2er
|
||||||
|
Dormlq(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DormlqTest(t *testing.T, impl Dormlqer) {
|
||||||
|
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||||
|
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
common, adim, cdim, lda, ldc int
|
||||||
|
}{
|
||||||
|
{6, 7, 8, 0, 0},
|
||||||
|
{6, 8, 7, 0, 0},
|
||||||
|
{7, 6, 8, 0, 0},
|
||||||
|
{7, 8, 6, 0, 0},
|
||||||
|
{8, 6, 7, 0, 0},
|
||||||
|
{8, 7, 6, 0, 0},
|
||||||
|
{100, 200, 300, 0, 0},
|
||||||
|
{100, 300, 200, 0, 0},
|
||||||
|
{200, 100, 300, 0, 0},
|
||||||
|
{200, 300, 100, 0, 0},
|
||||||
|
{300, 100, 200, 0, 0},
|
||||||
|
{300, 200, 100, 0, 0},
|
||||||
|
{100, 200, 300, 400, 500},
|
||||||
|
{100, 300, 200, 400, 500},
|
||||||
|
{200, 100, 300, 400, 500},
|
||||||
|
{200, 300, 100, 400, 500},
|
||||||
|
{300, 100, 200, 400, 500},
|
||||||
|
{300, 200, 100, 400, 500},
|
||||||
|
{100, 200, 300, 500, 400},
|
||||||
|
{100, 300, 200, 500, 400},
|
||||||
|
{200, 100, 300, 500, 400},
|
||||||
|
{200, 300, 100, 500, 400},
|
||||||
|
{300, 100, 200, 500, 400},
|
||||||
|
{300, 200, 100, 500, 400},
|
||||||
|
} {
|
||||||
|
var ma, na, mc, nc int
|
||||||
|
if side == blas.Left {
|
||||||
|
ma = test.adim
|
||||||
|
na = test.common
|
||||||
|
mc = test.common
|
||||||
|
nc = test.cdim
|
||||||
|
} else {
|
||||||
|
ma = test.adim
|
||||||
|
na = test.common
|
||||||
|
mc = test.cdim
|
||||||
|
nc = test.common
|
||||||
|
}
|
||||||
|
// Generate a random matrix
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = na
|
||||||
|
}
|
||||||
|
a := make([]float64, ma*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
// Compute random C matrix
|
||||||
|
ldc := test.ldc
|
||||||
|
if ldc == 0 {
|
||||||
|
ldc = nc
|
||||||
|
}
|
||||||
|
c := make([]float64, mc*ldc)
|
||||||
|
for i := range c {
|
||||||
|
c[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute LQ
|
||||||
|
k := min(ma, na)
|
||||||
|
tau := make([]float64, k)
|
||||||
|
work := make([]float64, 1)
|
||||||
|
impl.Dgelqf(ma, na, a, lda, tau, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
impl.Dgelqf(ma, na, a, lda, tau, work, len(work))
|
||||||
|
|
||||||
|
cCopy := make([]float64, len(c))
|
||||||
|
copy(cCopy, c)
|
||||||
|
ans := make([]float64, len(c))
|
||||||
|
copy(ans, cCopy)
|
||||||
|
|
||||||
|
if side == blas.Left {
|
||||||
|
work = make([]float64, nc)
|
||||||
|
} else {
|
||||||
|
work = make([]float64, mc)
|
||||||
|
}
|
||||||
|
impl.Dorml2(side, trans, mc, nc, k, a, lda, tau, ans, ldc, work)
|
||||||
|
|
||||||
|
// Make sure Dorml2 and Dormlq match with small work
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
lwork := len(work)
|
||||||
|
copy(c, cCopy)
|
||||||
|
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||||
|
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||||
|
t.Errorf("Dormqr and Dorm2r mismatch for small work")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try with the optimum amount of work
|
||||||
|
copy(c, cCopy)
|
||||||
|
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
lwork = len(work)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||||
|
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||||
|
t.Errorf("Dormqr and Dorm2r mismatch for full work")
|
||||||
|
fmt.Println("ccopy")
|
||||||
|
for i := 0; i < mc; i++ {
|
||||||
|
fmt.Println(cCopy[i*ldc : (i+1)*ldc])
|
||||||
|
}
|
||||||
|
fmt.Println("ans =")
|
||||||
|
for i := 0; i < mc; i++ {
|
||||||
|
fmt.Println(ans[i*ldc : (i+1)*ldc])
|
||||||
|
}
|
||||||
|
fmt.Println("c =")
|
||||||
|
for i := 0; i < mc; i++ {
|
||||||
|
fmt.Println(c[i*ldc : (i+1)*ldc])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try with less than the optimum amount of work
|
||||||
|
copy(c, cCopy)
|
||||||
|
work = work[1:]
|
||||||
|
lwork--
|
||||||
|
impl.Dormlq(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||||
|
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||||
|
t.Errorf("Dormqr and Dorm2r mismatch for medium work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
150
testlapack/dormqr.go
Normal file
150
testlapack/dormqr.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/floats"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dormqrer interface {
|
||||||
|
Dorm2rer
|
||||||
|
Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DormqrTest(t *testing.T, impl Dormqrer) {
|
||||||
|
for _, side := range []blas.Side{blas.Left, blas.Right} {
|
||||||
|
for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
|
||||||
|
for _, test := range []struct {
|
||||||
|
common, adim, cdim, lda, ldc int
|
||||||
|
}{
|
||||||
|
{6, 7, 8, 0, 0},
|
||||||
|
{6, 8, 7, 0, 0},
|
||||||
|
{7, 6, 8, 0, 0},
|
||||||
|
{7, 8, 6, 0, 0},
|
||||||
|
{8, 6, 7, 0, 0},
|
||||||
|
{8, 7, 6, 0, 0},
|
||||||
|
{100, 200, 300, 0, 0},
|
||||||
|
{100, 300, 200, 0, 0},
|
||||||
|
{200, 100, 300, 0, 0},
|
||||||
|
{200, 300, 100, 0, 0},
|
||||||
|
{300, 100, 200, 0, 0},
|
||||||
|
{300, 200, 100, 0, 0},
|
||||||
|
{100, 200, 300, 400, 500},
|
||||||
|
{100, 300, 200, 400, 500},
|
||||||
|
{200, 100, 300, 400, 500},
|
||||||
|
{200, 300, 100, 400, 500},
|
||||||
|
{300, 100, 200, 400, 500},
|
||||||
|
{300, 200, 100, 400, 500},
|
||||||
|
{100, 200, 300, 500, 400},
|
||||||
|
{100, 300, 200, 500, 400},
|
||||||
|
{200, 100, 300, 500, 400},
|
||||||
|
{200, 300, 100, 500, 400},
|
||||||
|
{300, 100, 200, 500, 400},
|
||||||
|
{300, 200, 100, 500, 400},
|
||||||
|
} {
|
||||||
|
var ma, na, mc, nc int
|
||||||
|
if side == blas.Left {
|
||||||
|
ma = test.common
|
||||||
|
na = test.adim
|
||||||
|
mc = test.common
|
||||||
|
nc = test.cdim
|
||||||
|
} else {
|
||||||
|
ma = test.common
|
||||||
|
na = test.adim
|
||||||
|
mc = test.cdim
|
||||||
|
nc = test.common
|
||||||
|
}
|
||||||
|
// Generate a random matrix
|
||||||
|
lda := test.lda
|
||||||
|
if lda == 0 {
|
||||||
|
lda = na
|
||||||
|
}
|
||||||
|
a := make([]float64, ma*lda)
|
||||||
|
for i := range a {
|
||||||
|
a[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
// Compute random C matrix
|
||||||
|
ldc := test.ldc
|
||||||
|
if ldc == 0 {
|
||||||
|
ldc = nc
|
||||||
|
}
|
||||||
|
c := make([]float64, mc*ldc)
|
||||||
|
for i := range c {
|
||||||
|
c[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute QR
|
||||||
|
k := min(ma, na)
|
||||||
|
tau := make([]float64, k)
|
||||||
|
work := make([]float64, 1)
|
||||||
|
impl.Dgeqrf(ma, na, a, lda, tau, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
impl.Dgeqrf(ma, na, a, lda, tau, work, len(work))
|
||||||
|
|
||||||
|
cCopy := make([]float64, len(c))
|
||||||
|
copy(cCopy, c)
|
||||||
|
ans := make([]float64, len(c))
|
||||||
|
copy(ans, cCopy)
|
||||||
|
|
||||||
|
if side == blas.Left {
|
||||||
|
work = make([]float64, nc)
|
||||||
|
} else {
|
||||||
|
work = make([]float64, mc)
|
||||||
|
}
|
||||||
|
impl.Dorm2r(side, trans, mc, nc, k, a, lda, tau, ans, ldc, work)
|
||||||
|
|
||||||
|
// Make sure Dorm2r and Dormqr match with small work
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
lwork := len(work)
|
||||||
|
copy(c, cCopy)
|
||||||
|
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||||
|
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||||
|
t.Errorf("Dormqr and Dorm2r mismatch for small work")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try with the optimum amount of work
|
||||||
|
copy(c, cCopy)
|
||||||
|
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, -1)
|
||||||
|
work = make([]float64, int(work[0]))
|
||||||
|
lwork = len(work)
|
||||||
|
for i := range work {
|
||||||
|
work[i] = rand.Float64()
|
||||||
|
}
|
||||||
|
_ = lwork
|
||||||
|
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||||
|
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||||
|
t.Errorf("Dormqr and Dorm2r mismatch for full work")
|
||||||
|
fmt.Println("ccopy")
|
||||||
|
for i := 0; i < mc; i++ {
|
||||||
|
fmt.Println(cCopy[i*ldc : (i+1)*ldc])
|
||||||
|
}
|
||||||
|
fmt.Println("ans =")
|
||||||
|
for i := 0; i < mc; i++ {
|
||||||
|
fmt.Println(ans[i*ldc : (i+1)*ldc])
|
||||||
|
}
|
||||||
|
fmt.Println("c =")
|
||||||
|
for i := 0; i < mc; i++ {
|
||||||
|
fmt.Println(c[i*ldc : (i+1)*ldc])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try with less than the optimum amount of work
|
||||||
|
copy(c, cCopy)
|
||||||
|
work = work[1:]
|
||||||
|
lwork--
|
||||||
|
impl.Dormqr(side, trans, mc, nc, k, a, lda, tau, c, ldc, work, lwork)
|
||||||
|
if !floats.EqualApprox(c, ans, 1e-12) {
|
||||||
|
t.Errorf("Dormqr and Dorm2r mismatch for medium work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -4,9 +4,289 @@
|
|||||||
|
|
||||||
package testlapack
|
package testlapack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gonum/blas"
|
||||||
|
"github.com/gonum/blas/blas64"
|
||||||
|
"github.com/gonum/lapack"
|
||||||
|
)
|
||||||
|
|
||||||
func max(a, b int) int {
|
func max(a, b int) int {
|
||||||
if a > b {
|
if a > b {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractVMat collects the single reflectors from a into a matrix.
|
||||||
|
func extractVMat(m, n int, a []float64, lda int, direct lapack.Direct, store lapack.StoreV) blas64.General {
|
||||||
|
k := min(m, n)
|
||||||
|
switch {
|
||||||
|
default:
|
||||||
|
panic("not implemented")
|
||||||
|
case direct == lapack.Forward && store == lapack.ColumnWise:
|
||||||
|
v := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: k,
|
||||||
|
Stride: k,
|
||||||
|
Data: make([]float64, m*k),
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
v.Data[j*v.Stride+i] = 0
|
||||||
|
}
|
||||||
|
v.Data[i*v.Stride+i] = 1
|
||||||
|
for j := i + 1; j < m; j++ {
|
||||||
|
v.Data[j*v.Stride+i] = a[j*lda+i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
case direct == lapack.Forward && store == lapack.RowWise:
|
||||||
|
v := blas64.General{
|
||||||
|
Rows: k,
|
||||||
|
Cols: n,
|
||||||
|
Stride: n,
|
||||||
|
Data: make([]float64, k*n),
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
v.Data[i*v.Stride+j] = 0
|
||||||
|
}
|
||||||
|
v.Data[i*v.Stride+i] = 1
|
||||||
|
for j := i + 1; j < n; j++ {
|
||||||
|
v.Data[i*v.Stride+j] = a[i*lda+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// constructVMat transforms the v matrix based on the storage.
|
||||||
|
func constructVMat(vMat blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
|
||||||
|
m := vMat.Rows
|
||||||
|
k := vMat.Cols
|
||||||
|
switch {
|
||||||
|
default:
|
||||||
|
panic("not implemented")
|
||||||
|
case store == lapack.ColumnWise && direct == lapack.Forward:
|
||||||
|
ldv := k
|
||||||
|
v := make([]float64, m*k)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
if j > i {
|
||||||
|
v[i*ldv+j] = 0
|
||||||
|
} else if j == i {
|
||||||
|
v[i*ldv+i] = 1
|
||||||
|
} else {
|
||||||
|
v[i*ldv+j] = vMat.Data[i*vMat.Stride+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: k,
|
||||||
|
Stride: k,
|
||||||
|
Data: v,
|
||||||
|
}
|
||||||
|
case store == lapack.RowWise && direct == lapack.Forward:
|
||||||
|
ldv := m
|
||||||
|
v := make([]float64, m*k)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
if j > i {
|
||||||
|
v[j*ldv+i] = 0
|
||||||
|
} else if j == i {
|
||||||
|
v[j*ldv+i] = 1
|
||||||
|
} else {
|
||||||
|
v[j*ldv+i] = vMat.Data[i*vMat.Stride+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return blas64.General{
|
||||||
|
Rows: k,
|
||||||
|
Cols: m,
|
||||||
|
Stride: m,
|
||||||
|
Data: v,
|
||||||
|
}
|
||||||
|
case store == lapack.ColumnWise && direct == lapack.Backward:
|
||||||
|
rowsv := m
|
||||||
|
ldv := k
|
||||||
|
v := make([]float64, m*k)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
vrow := rowsv - i - 1
|
||||||
|
vcol := k - j - 1
|
||||||
|
if j > i {
|
||||||
|
v[vrow*ldv+vcol] = 0
|
||||||
|
} else if j == i {
|
||||||
|
v[vrow*ldv+vcol] = 1
|
||||||
|
} else {
|
||||||
|
v[vrow*ldv+vcol] = vMat.Data[i*vMat.Stride+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return blas64.General{
|
||||||
|
Rows: rowsv,
|
||||||
|
Cols: ldv,
|
||||||
|
Stride: ldv,
|
||||||
|
Data: v,
|
||||||
|
}
|
||||||
|
case store == lapack.RowWise && direct == lapack.Backward:
|
||||||
|
rowsv := k
|
||||||
|
ldv := m
|
||||||
|
v := make([]float64, m*k)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
for j := 0; j < k; j++ {
|
||||||
|
vcol := ldv - i - 1
|
||||||
|
vrow := k - j - 1
|
||||||
|
if j > i {
|
||||||
|
v[vrow*ldv+vcol] = 0
|
||||||
|
} else if j == i {
|
||||||
|
v[vrow*ldv+vcol] = 1
|
||||||
|
} else {
|
||||||
|
v[vrow*ldv+vcol] = vMat.Data[i*vMat.Stride+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return blas64.General{
|
||||||
|
Rows: rowsv,
|
||||||
|
Cols: ldv,
|
||||||
|
Stride: ldv,
|
||||||
|
Data: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func constructH(tau []float64, v blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
|
||||||
|
m := v.Rows
|
||||||
|
k := v.Cols
|
||||||
|
if store == lapack.RowWise {
|
||||||
|
m, k = k, m
|
||||||
|
}
|
||||||
|
h := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: m,
|
||||||
|
Stride: m,
|
||||||
|
Data: make([]float64, m*m),
|
||||||
|
}
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
h.Data[i*m+i] = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
vecData := make([]float64, m)
|
||||||
|
if store == lapack.ColumnWise {
|
||||||
|
for j := 0; j < m; j++ {
|
||||||
|
vecData[j] = v.Data[j*v.Cols+i]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for j := 0; j < m; j++ {
|
||||||
|
vecData[j] = v.Data[i*v.Cols+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vec := blas64.Vector{
|
||||||
|
Inc: 1,
|
||||||
|
Data: vecData,
|
||||||
|
}
|
||||||
|
|
||||||
|
hi := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: m,
|
||||||
|
Stride: m,
|
||||||
|
Data: make([]float64, m*m),
|
||||||
|
}
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
hi.Data[i*m+i] = 1
|
||||||
|
}
|
||||||
|
// hi = I - tau * v * v^T
|
||||||
|
blas64.Ger(-tau[i], vec, vec, hi)
|
||||||
|
|
||||||
|
hcopy := blas64.General{
|
||||||
|
Rows: m,
|
||||||
|
Cols: m,
|
||||||
|
Stride: m,
|
||||||
|
Data: make([]float64, m*m),
|
||||||
|
}
|
||||||
|
copy(hcopy.Data, h.Data)
|
||||||
|
if direct == lapack.Forward {
|
||||||
|
// H = H * H_I in forward mode
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hcopy, hi, 0, h)
|
||||||
|
} else {
|
||||||
|
// H = H_I * H in backward mode
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hi, hcopy, 0, h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// constructQ constructs the Q matrix from the result of dgeqrf and dgeqr2
|
||||||
|
func constructQ(kind string, m, n int, a []float64, lda int, tau []float64) blas64.General {
|
||||||
|
k := min(m, n)
|
||||||
|
var sz int
|
||||||
|
switch kind {
|
||||||
|
case "QR":
|
||||||
|
sz = m
|
||||||
|
case "LQ":
|
||||||
|
sz = n
|
||||||
|
}
|
||||||
|
|
||||||
|
q := blas64.General{
|
||||||
|
Rows: sz,
|
||||||
|
Cols: sz,
|
||||||
|
Stride: sz,
|
||||||
|
Data: make([]float64, sz*sz),
|
||||||
|
}
|
||||||
|
for i := 0; i < sz; i++ {
|
||||||
|
q.Data[i*sz+i] = 1
|
||||||
|
}
|
||||||
|
qCopy := blas64.General{
|
||||||
|
Rows: q.Rows,
|
||||||
|
Cols: q.Cols,
|
||||||
|
Stride: q.Stride,
|
||||||
|
Data: make([]float64, len(q.Data)),
|
||||||
|
}
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
h := blas64.General{
|
||||||
|
Rows: sz,
|
||||||
|
Cols: sz,
|
||||||
|
Stride: sz,
|
||||||
|
Data: make([]float64, sz*sz),
|
||||||
|
}
|
||||||
|
for j := 0; j < sz; j++ {
|
||||||
|
h.Data[j*sz+j] = 1
|
||||||
|
}
|
||||||
|
vVec := blas64.Vector{
|
||||||
|
Inc: 1,
|
||||||
|
Data: make([]float64, sz),
|
||||||
|
}
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
vVec.Data[j] = 0
|
||||||
|
}
|
||||||
|
vVec.Data[i] = 1
|
||||||
|
switch kind {
|
||||||
|
case "QR":
|
||||||
|
for j := i + 1; j < sz; j++ {
|
||||||
|
vVec.Data[j] = a[lda*j+i]
|
||||||
|
}
|
||||||
|
case "LQ":
|
||||||
|
for j := i + 1; j < sz; j++ {
|
||||||
|
vVec.Data[j] = a[i*lda+j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
blas64.Ger(-tau[i], vVec, vVec, h)
|
||||||
|
copy(qCopy.Data, q.Data)
|
||||||
|
// Mulitply q by the new h
|
||||||
|
switch kind {
|
||||||
|
case "QR":
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q)
|
||||||
|
case "LQ":
|
||||||
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qCopy, 0, q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user