Files
gonum/native/dlarfb.go
btracey ec100cf00f Working implementation of blocked QR
Improved function documentation

Fixed dlarfb and dlarft and added full tests

Added dgelq2

Working Dgels

Fix many comments and tests

Many PR comment responses

Responded to more PR comments

Many PR comments
2015-07-15 00:43:15 -07:00

425 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
"github.com/gonum/lapack"
)
// Dlarfb applies a block reflector to a matrix.
//
// In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
// c = h * c if side == Left and trans == NoTrans
// c = c * h if side == Right and trans == NoTrans
// c = h^T * c if side == Left and trans == Trans
// c = c * h^T if side == Right and trans == Trans
// h is a product of elementary reflectors. direct sets the direction of multiplication
// h = h_1 * h_2 * ... * h_k if direct == Forward
// h = h_k * h_k-1 * ... * h_1 if direct == Backward
// The combination of direct and store defines the orientation of the elementary
// reflectors. In all cases the ones on the diagonal are implicitly represented.
//
// If direct == lapack.Forward and store == lapack.ColumnWise
//  V = (  1       )
//      ( v1  1    )
//      ( v1 v2  1 )
//      ( v1 v2 v3 )
//      ( v1 v2 v3 )
// If direct == lapack.Forward and store == lapack.RowWise
//  V = (  1 v1 v1 v1 v1 )
//      (     1 v2 v2 v2 )
//      (        1 v3 v3 )
// If direct == lapack.Backward and store == lapack.ColumnWise
//  V = ( v1 v2 v3 )
//      ( v1 v2 v3 )
//      (  1 v2 v3 )
//      (     1 v3 )
//      (        1 )
// If direct == lapack.Backward and store == lapack.RowWise
//  V = ( v1 v1  1       )
//      ( v2 v2 v2  1    )
//      ( v3 v3 v3 v3  1 )
// An elementary reflector can be explicitly constructed by extracting the
// corresponding elements of v, placing a 1 where the diagonal would be, and
// placing zeros in the remaining elements.
//
// t is a k×k matrix containing the block reflector, and this function will panic
// if t is not of sufficient size. See Dlarft for more information.
//
// work is a temporary storage matrix with stride ldwork.
// work must be of size at least n×k if side == Left and m×k if side == Right, and
// this function will panic if this size is not met.
func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
c []float64, ldc int, work []float64, ldwork int) {
// Overview: every branch below follows the same recipe. A k-column
// workspace W (stored in work) is built from the part of C that lines up
// with the triangular block of V, the contribution of the rectangular
// block of V is added with Dgemm, W is multiplied by T (or T^T), and the
// result is subtracted back from C. All matrices are indexed row-major
// (element (i,j) of c is c[i*ldc+j]).

// NOTE(review): checkMatrix is defined elsewhere; presumably it panics if
// c cannot hold an m×n matrix with stride ldc — confirm against its definition.
checkMatrix(m, n, c, ldc)
// Quick return for an empty C. Note that this happens before the
// remaining arguments are validated, so invalid side/trans/direct/store
// values are not reported when m == 0 or n == 0.
if m == 0 || n == 0 {
return
}
if k < 0 {
panic("lapack: negative number of transforms")
}
if side != blas.Left && side != blas.Right {
panic(badSide)
}
if trans != blas.Trans && trans != blas.NoTrans {
panic(badTrans)
}
if direct != lapack.Forward && direct != lapack.Backward {
panic(badDirect)
}
if store != lapack.ColumnWise && store != lapack.RowWise {
panic(badStore)
}
// W has k columns and as many rows as the dimension of C that is not
// being reduced: n rows when side == Left, m rows when side == Right.
rowsWork := n
if side == blas.Right {
rowsWork = m
}
checkMatrix(rowsWork, k, work, ldwork)
// BLAS implementation used for every Dcopy/Dtrmm/Dgemm call below.
bi := blas64.Implementation()
// transt is the opposite of trans. The side == Left branches work with
// W = C^T * V (the transpose of the quantity being applied), so they
// multiply by T using transt, whereas the side == Right branches use
// trans directly.
transt := blas.Trans
if trans == blas.Trans {
transt = blas.NoTrans
}
// TODO(btracey): This follows the original Lapack code where the
// elements are copied into the columns of the working array. The
// loops should go in the other direction so the data is written
// into the rows of work so the copy is not strided. A bigger change
// would be to replace work with work^T, but benchmarks would be
// needed to see if the change is merited.
if store == lapack.ColumnWise {
if direct == lapack.Forward {
// V = (V1; V2) where V1 is the first k rows of V and is unit lower
// triangular. V2 is the remaining rows.
if side == blas.Left {
// Form H * C or H^T * C where C = (C1; C2).
// W = C^T V = C1^T V1 + C2^T V2 (stored in work).
// W = C1^T: row j of C (contiguous) is copied into column j of work
// (stride ldwork).
for j := 0; j < k; j++ {
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
}
// W = W * V1.
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
n, k, 1,
v, ldv,
work, ldwork)
if m > k {
// W = W + C2^T V2.
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
1, c[k*ldc:], ldc, v[k*ldv:], ldv,
1, work, ldwork)
}
// W = W * T^T or W * T.
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
1, t, ldt,
work, ldwork)
// C -= V * W^T.
if m > k {
// C2 -= V2 * W^T.
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
-1, v[k*ldv:], ldv, work, ldwork,
1, c[k*ldc:], ldc)
}
// W *= V1^T.
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
1, v, ldv,
work, ldwork)
// C1 -= W^T.
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < n; i++ {
for j := 0; j < k; j++ {
c[j*ldc+i] -= work[i*ldwork+j]
}
}
return
}
// Form C = C * H or C * H^T, where C = (C1 C2).
// W = C * V = C1 * V1 + C2 * V2 (stored in work).
// W = C1: column i of C (stride ldc) is copied into column i of work.
for i := 0; i < k; i++ {
bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
}
// W *= V1.
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
1, v, ldv,
work, ldwork)
if n > k {
// W += C2 * V2.
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
1, c[k:], ldc, v[k*ldv:], ldv,
1, work, ldwork)
}
// W *= T or T^T.
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
1, t, ldt,
work, ldwork)
// C -= W * V^T.
if n > k {
// C2 -= W * V2^T.
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
-1, work, ldwork, v[k*ldv:], ldv,
1, c[k:], ldc)
}
// W *= V1^T.
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
1, v, ldv,
work, ldwork)
// C1 -= W.
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < m; i++ {
for j := 0; j < k; j++ {
c[i*ldc+j] -= work[i*ldwork+j]
}
}
return
}
// direct == Backward, store == ColumnWise.
// V = (V1)
// = (V2) (last k rows)
// Where V2 is unit upper triangular.
if side == blas.Left {
// Form H * C or H^T * C where C = (C1; C2).
// W = C^T V = C1^T V1 + C2^T V2 (stored in work).
// W = C2^T: the last k rows of C are copied into the columns of work.
for j := 0; j < k; j++ {
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
}
// W *= V2.
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
1, v[(m-k)*ldv:], ldv,
work, ldwork)
if m > k {
// W += C1^T * V1.
bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
1, c, ldc, v, ldv,
1, work, ldwork)
}
// W *= T or T^T.
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
1, t, ldt,
work, ldwork)
// C -= V * W^T.
if m > k {
// C1 -= V1 * W^T.
bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
-1, v, ldv, work, ldwork,
1, c, ldc)
}
// W *= V2^T.
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
1, v[(m-k)*ldv:], ldv,
work, ldwork)
// C2 -= W^T.
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < n; i++ {
for j := 0; j < k; j++ {
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
}
}
return
}
// Form C * H or C * H^T where C = (C1 C2).
// W = C * V = C1 * V1 + C2 * V2 (stored in work).
// W = C2: the last k columns of C are copied into the columns of work.
for j := 0; j < k; j++ {
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
}
// W = W * V2.
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
1, v[(n-k)*ldv:], ldv,
work, ldwork)
if n > k {
// W += C1 * V1.
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
1, c, ldc, v, ldv,
1, work, ldwork)
}
// W *= T or T^T.
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
1, t, ldt,
work, ldwork)
// C -= W * V^T.
if n > k {
// C1 -= W * V1^T.
bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
-1, work, ldwork, v, ldv,
1, c, ldc)
}
// W *= V2^T.
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
1, v[(n-k)*ldv:], ldv,
work, ldwork)
// C2 -= W.
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < m; i++ {
for j := 0; j < k; j++ {
c[i*ldc+n-k+j] -= work[i*ldwork+j]
}
}
return
}
// Store = Rowwise.
if direct == lapack.Forward {
// V = (V1 V2) where V1 is the first k columns and is unit upper triangular.
if side == blas.Left {
// Form H * C or H^T * C where C = (C1; C2).
// W = C^T * V^T = C1^T * V1^T + C2^T * V2^T (stored in work).
// W = C1^T.
for j := 0; j < k; j++ {
bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
}
// W *= V1^T.
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
1, v, ldv,
work, ldwork)
if m > k {
// W += C2^T * V2^T.
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
1, c[k*ldc:], ldc, v[k:], ldv,
1, work, ldwork)
}
// W *= T or T^T.
bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
1, t, ldt,
work, ldwork)
// C -= V^T * W^T.
if m > k {
// C2 -= V2^T * W^T.
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
-1, v[k:], ldv, work, ldwork,
1, c[k*ldc:], ldc)
}
// W *= V1.
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
1, v, ldv,
work, ldwork)
// C1 -= W^T.
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < n; i++ {
for j := 0; j < k; j++ {
c[j*ldc+i] -= work[i*ldwork+j]
}
}
return
}
// Form C * H or C * H^T where C = (C1 C2).
// W = C * V^T = C1 * V1^T + C2 * V2^T (stored in work).
// W = C1.
for j := 0; j < k; j++ {
bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
}
// W *= V1^T.
bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
1, v, ldv,
work, ldwork)
if n > k {
// W += C2 * V2^T.
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
1, c[k:], ldc, v[k:], ldv,
1, work, ldwork)
}
// W *= T or T^T.
bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
1, t, ldt,
work, ldwork)
// C -= W * V.
if n > k {
// C2 -= W * V2.
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
-1, work, ldwork, v[k:], ldv,
1, c[k:], ldc)
}
// W *= V1.
bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
1, v, ldv,
work, ldwork)
// C1 -= W.
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < m; i++ {
for j := 0; j < k; j++ {
c[i*ldc+j] -= work[i*ldwork+j]
}
}
return
}
// direct == Backward, store == RowWise.
// V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
if side == blas.Left {
// Form H * C or H^T C where C = (C1 ; C2).
// W = C^T * V^T = C1^T * V1^T + C2^T * V2^T (stored in work).
// W = C2^T.
for j := 0; j < k; j++ {
bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
}
// W *= V2^T.
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
1, v[m-k:], ldv,
work, ldwork)
if m > k {
// W += C1^T * V1^T.
bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
1, c, ldc, v, ldv,
1, work, ldwork)
}
// W *= T or T^T.
bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
1, t, ldt,
work, ldwork)
// C -= V^T * W^T.
if m > k {
// C1 -= V1^T * W^T.
bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
-1, v, ldv, work, ldwork,
1, c, ldc)
}
// W *= V2.
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
1, v[m-k:], ldv,
work, ldwork)
// C2 -= W^T.
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < n; i++ {
for j := 0; j < k; j++ {
c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
}
}
return
}
// Form C * H or C * H^T where C = (C1 C2).
// W = C * V^T = C1 * V1^T + C2 * V2^T (stored in work).
// W = C2.
for j := 0; j < k; j++ {
bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
}
// W *= V2^T.
bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
1, v[n-k:], ldv,
work, ldwork)
if n > k {
// W += C1 * V1^T.
bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
1, c, ldc, v, ldv,
1, work, ldwork)
}
// W *= T or T^T.
bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
1, t, ldt,
work, ldwork)
// C -= W * V.
if n > k {
// C1 -= W * V1.
bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
-1, work, ldwork, v, ldv,
1, c, ldc)
}
// W *= V2.
bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
1, v[n-k:], ldv,
work, ldwork)
// C2 -= W (the last k columns of C).
// TODO(btracey): This should use blas.Axpy.
for i := 0; i < m; i++ {
for j := 0; j < k; j++ {
c[i*ldc+n-k+j] -= work[i*ldwork+j]
}
}
}