mirror of
https://github.com/gonum/gonum.git
synced 2025-10-24 07:34:11 +08:00

Improved function documentation. Fixed dlarfb and dlarft and added full tests. Added dgelq2. Working Dgels. Fix many comments and tests. Many PR comment responses. Responded to more PR comments. Many PR comments.
425 lines
11 KiB
Go
425 lines
11 KiB
Go
// Copyright ©2015 The gonum Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
package native
|
||
|
||
import (
|
||
"github.com/gonum/blas"
|
||
"github.com/gonum/blas/blas64"
|
||
"github.com/gonum/lapack"
|
||
)
|
||
|
||
// Dlarfb applies a block reflector to a matrix.
|
||
//
|
||
// In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
|
||
// c = h * c if side == Left and trans == NoTrans
|
||
// c = c * h if side == Right and trans == NoTrans
|
||
// c = h^T * c if side == Left and trans == Trans
|
||
// c = c * h^t if side == Right and trans == Trans
|
||
// h is a product of elementary reflectors. direct sets the direction of multiplication
|
||
// h = h_1 * h_2 * ... * h_k if direct == Forward
|
||
// h = h_k * h_k-1 * ... * h_1 if direct == Backward
|
||
// The combination of direct and store defines the orientation of the elementary
|
||
// reflectors. In all cases the ones on the diagonal are implicitly represented.
|
||
//
|
||
// If direct == lapack.Forward and store == lapack.ColumnWise
|
||
// V = ( 1 )
|
||
// ( v1 1 )
|
||
// ( v1 v2 1 )
|
||
// ( v1 v2 v3 )
|
||
// ( v1 v2 v3 )
|
||
// If direct == lapack.Forward and store == lapack.RowWise
|
||
// V = ( 1 v1 v1 v1 v1 )
|
||
// ( 1 v2 v2 v2 )
|
||
// ( 1 v3 v3 )
|
||
// If direct == lapack.Backward and store == lapack.ColumnWise
|
||
// V = ( v1 v2 v3 )
|
||
// ( v1 v2 v3 )
|
||
// ( 1 v2 v3 )
|
||
// ( 1 v3 )
|
||
// ( 1 )
|
||
// If direct == lapack.Backward and store == lapack.RowWise
|
||
// V = ( v1 v1 1 )
|
||
// ( v2 v2 v2 1 )
|
||
// ( v3 v3 v3 v3 1 )
|
||
// An elementary reflector can be explicitly constructed by extracting the
|
||
// corresponding elements of v, placing a 1 where the diagonal would be, and
|
||
// placing zeros in the remaining elements.
|
||
//
|
||
// t is a k×k matrix containing the block reflector, and this function will panic
|
||
// if t is not of sufficient size. See Dlarft for more information.
|
||
//
|
||
// Work is a temporary storage matrix with stride ldwork.
|
||
// Work must be of size at least n×k side == Left and m×k if side == Right, and
|
||
// this function will panic if this size is not met.
|
||
func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
|
||
store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
|
||
c []float64, ldc int, work []float64, ldwork int) {
|
||
|
||
checkMatrix(m, n, c, ldc)
|
||
if m == 0 || n == 0 {
|
||
return
|
||
}
|
||
if k < 0 {
|
||
panic("lapack: negative number of transforms")
|
||
}
|
||
if side != blas.Left && side != blas.Right {
|
||
panic(badSide)
|
||
}
|
||
if trans != blas.Trans && trans != blas.NoTrans {
|
||
panic(badTrans)
|
||
}
|
||
if direct != lapack.Forward && direct != lapack.Backward {
|
||
panic(badDirect)
|
||
}
|
||
if store != lapack.ColumnWise && store != lapack.RowWise {
|
||
panic(badStore)
|
||
}
|
||
|
||
rowsWork := n
|
||
if side == blas.Right {
|
||
rowsWork = m
|
||
}
|
||
checkMatrix(rowsWork, k, work, ldwork)
|
||
|
||
bi := blas64.Implementation()
|
||
|
||
transt := blas.Trans
|
||
if trans == blas.Trans {
|
||
transt = blas.NoTrans
|
||
}
|
||
// TODO(btracey): This follows the original Lapack code where the
|
||
// elements are copied into the columns of the working array. The
|
||
// loops should go in the other direction so the data is written
|
||
// into the rows of work so the copy is not strided. A bigger change
|
||
// would be to replace work with work^T, but benchmarks would be
|
||
// needed to see if the change is merited.
|
||
if store == lapack.ColumnWise {
|
||
if direct == lapack.Forward {
|
||
// V1 is the first k rows of C. V2 is the remaining rows.
|
||
if side == blas.Left {
|
||
// W = C^T V = C1^T V1 + C2^T V2 (stored in work).
|
||
|
||
// W = C1.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W = W * V1.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
|
||
n, k, 1,
|
||
v, ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
// W = W + C2^T V2.
|
||
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
|
||
1, c[k*ldc:], ldc, v[k*ldv:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W = W * T^T or W * T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= V * W^T.
|
||
if m > k {
|
||
// C2 -= V2 * W^T.
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
|
||
-1, v[k*ldv:], ldv, work, ldwork,
|
||
1, c[k*ldc:], ldc)
|
||
}
|
||
// W *= V1^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C1 -= W^T.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[j*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C = C * H or C * H^T, where C = (C1 C2).
|
||
|
||
// W = C1.
|
||
for i := 0; i < k; i++ {
|
||
bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
|
||
}
|
||
// W *= V1.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
|
||
1, c[k:], ldc, v[k*ldv:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or T^T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
|
||
-1, work, ldwork, v[k*ldv:], ldv,
|
||
1, c[k:], ldc)
|
||
}
|
||
// C -= W * V^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// V = (V1)
|
||
// = (V2) (last k rows)
|
||
// Where V2 is unit upper triangular.
|
||
if side == blas.Left {
|
||
// Form H * C or
|
||
// W = C^T V.
|
||
|
||
// W = C2^T.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W *= V2.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
|
||
1, v[(m-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
// W += C1^T * V1.
|
||
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or T^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= V * W^T.
|
||
if m > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
|
||
-1, v, ldv, work, ldwork,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2^T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
|
||
1, v[(m-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
// C2 -= W^T.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C * H or C * H^T where C = (C1 C2).
|
||
// W = C * V.
|
||
|
||
// W = C2.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
|
||
}
|
||
|
||
// W = W * V2.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
|
||
1, v[(n-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or T^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= W * V^T.
|
||
if n > k {
|
||
// C1 -= W * V1^T.
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
|
||
-1, work, ldwork, v, ldv,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2^T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
|
||
1, v[(n-k)*ldv:], ldv,
|
||
work, ldwork)
|
||
// C2 -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+n-k+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Store = Rowwise.
|
||
if direct == lapack.Forward {
|
||
// V = (V1 V2) where v1 is unit upper triangular.
|
||
if side == blas.Left {
|
||
// Form H * C or H^T * C where C = (C1; C2).
|
||
// W = C^T * V^T.
|
||
|
||
// W = C1^T.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W *= V1^T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
|
||
1, c[k*ldc:], ldc, v[k:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or T^T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= V^T * W^T.
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
|
||
-1, v[k:], ldv, work, ldwork,
|
||
1, c[k*ldc:], ldc)
|
||
}
|
||
// W *= V1.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C1 -= W^T.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[j*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C * H or C * H^T where C = (C1 C2).
|
||
// W = C * V^T.
|
||
|
||
// W = C1.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
|
||
}
|
||
// W *= V1^T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
|
||
1, c[k:], ldc, v[k:], ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or T^T.
|
||
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= W * V.
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
|
||
-1, work, ldwork, v[k:], ldv,
|
||
1, c[k:], ldc)
|
||
}
|
||
// W *= V1.
|
||
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
|
||
1, v, ldv,
|
||
work, ldwork)
|
||
// C1 -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
|
||
if side == blas.Left {
|
||
// Form H * C or H^T C where C = (C1 ; C2).
|
||
// W = C^T * V^T.
|
||
|
||
// W = C2^T.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
|
||
}
|
||
// W *= V2^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
|
||
1, v[m-k:], ldv,
|
||
work, ldwork)
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or T^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= V^T * W^T.
|
||
if m > k {
|
||
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
|
||
-1, v, ldv, work, ldwork,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
|
||
1, v[m-k:], ldv,
|
||
work, ldwork)
|
||
// C2 -= W^T.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < n; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
return
|
||
}
|
||
// Form C * H or C * H^T where C = (C1 C2).
|
||
// W = C * V^T.
|
||
// W = C2.
|
||
for j := 0; j < k; j++ {
|
||
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
|
||
}
|
||
// W *= V2^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
|
||
1, v[n-k:], ldv,
|
||
work, ldwork)
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
|
||
1, c, ldc, v, ldv,
|
||
1, work, ldwork)
|
||
}
|
||
// W *= T or T^T.
|
||
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
|
||
1, t, ldt,
|
||
work, ldwork)
|
||
// C -= W * V.
|
||
if n > k {
|
||
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
|
||
-1, work, ldwork, v, ldv,
|
||
1, c, ldc)
|
||
}
|
||
// W *= V2.
|
||
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
|
||
1, v[n-k:], ldv,
|
||
work, ldwork)
|
||
// C1 -= W.
|
||
// TODO(btracey): This should use blas.Axpy.
|
||
for i := 0; i < m; i++ {
|
||
for j := 0; j < k; j++ {
|
||
c[i*ldc+n-k+j] -= work[i*ldwork+j]
|
||
}
|
||
}
|
||
}
|