mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
mat: add minimal shadow detection for complex matrices
This commit is contained in:
@@ -166,3 +166,8 @@ func (m *CDense) Copy(a CMatrix) (r, c int) {
|
|||||||
}
|
}
|
||||||
return r, c
|
return r, c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RawCMatrix returns the underlying cblas128.General used by the receiver.
|
||||||
|
// Changes to elements in the receiver following the call will be reflected
|
||||||
|
// in returned cblas128.General.
|
||||||
|
func (m *CDense) RawCMatrix() cblas128.General { return m.mat }
|
||||||
|
@@ -8,6 +8,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"math/cmplx"
|
"math/cmplx"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/blas/cblas128"
|
||||||
"gonum.org/v1/gonum/floats"
|
"gonum.org/v1/gonum/floats"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,6 +28,12 @@ type CMatrix interface {
|
|||||||
H() CMatrix
|
H() CMatrix
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data
|
||||||
|
// slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
|
||||||
|
type RawCMatrixer interface {
|
||||||
|
RawCMatrix() cblas128.General
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_ CMatrix = Conjugate{}
|
_ CMatrix = Conjugate{}
|
||||||
_ Unconjugator = Conjugate{}
|
_ Unconjugator = Conjugate{}
|
||||||
|
@@ -18,3 +18,14 @@ func offset(a, b []float64) int {
|
|||||||
// move. See https://golang.org/issue/12445.
|
// move. See https://golang.org/issue/12445.
|
||||||
return int(uintptr(unsafe.Pointer(&b[0]))-uintptr(unsafe.Pointer(&a[0]))) / int(unsafe.Sizeof(float64(0)))
|
return int(uintptr(unsafe.Pointer(&b[0]))-uintptr(unsafe.Pointer(&a[0]))) / int(unsafe.Sizeof(float64(0)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// offsetComplex returns the number of complex128 values b[0] is after a[0].
|
||||||
|
func offsetComplex(a, b []complex128) int {
|
||||||
|
if &a[0] == &b[0] {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// This expression must be atomic with respect to GC moves.
|
||||||
|
// At this stage this is true, because the GC does not
|
||||||
|
// move. See https://golang.org/issue/12445.
|
||||||
|
return int(uintptr(unsafe.Pointer(&b[0]))-uintptr(unsafe.Pointer(&a[0]))) / int(unsafe.Sizeof(complex128(0)))
|
||||||
|
}
|
||||||
|
@@ -22,3 +22,18 @@ func offset(a, b []float64) int {
|
|||||||
// move. See https://golang.org/issue/12445.
|
// move. See https://golang.org/issue/12445.
|
||||||
return int(vb0.UnsafeAddr()-va0.UnsafeAddr()) / sizeOfFloat64
|
return int(vb0.UnsafeAddr()-va0.UnsafeAddr()) / sizeOfFloat64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sizeOfComplex128 = int(reflect.TypeOf(complex128(0)).Size())
|
||||||
|
|
||||||
|
// offsetComplex returns the number of complex128 values b[0] is after a[0].
|
||||||
|
func offsetComplex(a, b []complex128) int {
|
||||||
|
va0 := reflect.ValueOf(a).Index(0)
|
||||||
|
vb0 := reflect.ValueOf(b).Index(0)
|
||||||
|
if va0.Addr() == vb0.Addr() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// This expression must be atomic with respect to GC moves.
|
||||||
|
// At this stage this is true, because the GC does not
|
||||||
|
// move. See https://golang.org/issue/12445.
|
||||||
|
return int(vb0.UnsafeAddr()-va0.UnsafeAddr()) / sizeOfComplex128
|
||||||
|
}
|
||||||
|
@@ -4,23 +4,7 @@
|
|||||||
|
|
||||||
package mat
|
package mat
|
||||||
|
|
||||||
import (
|
import "gonum.org/v1/gonum/blas/blas64"
|
||||||
"gonum.org/v1/gonum/blas/blas64"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// regionOverlap is the panic string used for the general case
|
|
||||||
// of a matrix region overlap between a source and destination.
|
|
||||||
regionOverlap = "mat: bad region: overlap"
|
|
||||||
|
|
||||||
// regionIdentity is the panic string used for the specific
|
|
||||||
// case of complete agreement between a source and a destination.
|
|
||||||
regionIdentity = "mat: bad region: identical"
|
|
||||||
|
|
||||||
// mismatchedStrides is the panic string used for overlapping
|
|
||||||
// data slices with differing strides.
|
|
||||||
mismatchedStrides = "mat: bad region: different strides"
|
|
||||||
)
|
|
||||||
|
|
||||||
// checkOverlap returns false if the receiver does not overlap data elements
|
// checkOverlap returns false if the receiver does not overlap data elements
|
||||||
// referenced by the parameter and panics otherwise.
|
// referenced by the parameter and panics otherwise.
|
||||||
@@ -212,38 +196,3 @@ func generalFromVector(a blas64.Vector, r, c int) blas64.General {
|
|||||||
Data: a.Data,
|
Data: a.Data,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rectanglesOverlap returns whether the strided rectangles a and b overlap
|
|
||||||
// when b is offset by off elements after a but has at least one element before
|
|
||||||
// the end of a. off must be positive. a and b have aCols and bCols respectively.
|
|
||||||
//
|
|
||||||
// rectanglesOverlap works by shifting both matrices left such that the left
|
|
||||||
// column of a is at 0. The column indexes are flattened by obtaining the shifted
|
|
||||||
// relative left and right column positions modulo the common stride. This allows
|
|
||||||
// direct comparison of the column offsets when the matrix backing data slices
|
|
||||||
// are known to overlap.
|
|
||||||
func rectanglesOverlap(off, aCols, bCols, stride int) bool {
|
|
||||||
if stride == 1 {
|
|
||||||
// Unit stride means overlapping data
|
|
||||||
// slices must overlap as matrices.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flatten the shifted matrix column positions
|
|
||||||
// so a starts at 0, modulo the common stride.
|
|
||||||
aTo := aCols
|
|
||||||
// The mod stride operations here make the from
|
|
||||||
// and to indexes comparable between a and b when
|
|
||||||
// the data slices of a and b overlap.
|
|
||||||
bFrom := off % stride
|
|
||||||
bTo := (bFrom + bCols) % stride
|
|
||||||
|
|
||||||
if bTo == 0 || bFrom < bTo {
|
|
||||||
// b matrix is not wrapped: compare for
|
|
||||||
// simple overlap.
|
|
||||||
return bFrom < aTo
|
|
||||||
}
|
|
||||||
|
|
||||||
// b strictly wraps and so must overlap with a.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
54
mat/shadow_common.go
Normal file
54
mat/shadow_common.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mat
|
||||||
|
|
||||||
|
const (
|
||||||
|
// regionOverlap is the panic string used for the general case
|
||||||
|
// of a matrix region overlap between a source and destination.
|
||||||
|
regionOverlap = "mat: bad region: overlap"
|
||||||
|
|
||||||
|
// regionIdentity is the panic string used for the specific
|
||||||
|
// case of complete agreement between a source and a destination.
|
||||||
|
regionIdentity = "mat: bad region: identical"
|
||||||
|
|
||||||
|
// mismatchedStrides is the panic string used for overlapping
|
||||||
|
// data slices with differing strides.
|
||||||
|
mismatchedStrides = "mat: bad region: different strides"
|
||||||
|
)
|
||||||
|
|
||||||
|
// rectanglesOverlap returns whether the strided rectangles a and b overlap
|
||||||
|
// when b is offset by off elements after a but has at least one element before
|
||||||
|
// the end of a. off must be positive. a and b have aCols and bCols respectively.
|
||||||
|
//
|
||||||
|
// rectanglesOverlap works by shifting both matrices left such that the left
|
||||||
|
// column of a is at 0. The column indexes are flattened by obtaining the shifted
|
||||||
|
// relative left and right column positions modulo the common stride. This allows
|
||||||
|
// direct comparison of the column offsets when the matrix backing data slices
|
||||||
|
// are known to overlap.
|
||||||
|
func rectanglesOverlap(off, aCols, bCols, stride int) bool {
|
||||||
|
if stride == 1 {
|
||||||
|
// Unit stride means overlapping data
|
||||||
|
// slices must overlap as matrices.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten the shifted matrix column positions
|
||||||
|
// so a starts at 0, modulo the common stride.
|
||||||
|
aTo := aCols
|
||||||
|
// The mod stride operations here make the from
|
||||||
|
// and to indexes comparable between a and b when
|
||||||
|
// the data slices of a and b overlap.
|
||||||
|
bFrom := off % stride
|
||||||
|
bTo := (bFrom + bCols) % stride
|
||||||
|
|
||||||
|
if bTo == 0 || bFrom < bTo {
|
||||||
|
// b matrix is not wrapped: compare for
|
||||||
|
// simple overlap.
|
||||||
|
return bFrom < aTo
|
||||||
|
}
|
||||||
|
|
||||||
|
// b strictly wraps and so must overlap with a.
|
||||||
|
return true
|
||||||
|
}
|
72
mat/shadow_complex.go
Normal file
72
mat/shadow_complex.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// TODO(kortschak): Generate this file from shadow.go when all complex type are available.
|
||||||
|
|
||||||
|
package mat
|
||||||
|
|
||||||
|
import "gonum.org/v1/gonum/blas/cblas128"
|
||||||
|
|
||||||
|
// checkOverlapComplex returns false if the receiver does not overlap data elements
|
||||||
|
// referenced by the parameter and panics otherwise.
|
||||||
|
//
|
||||||
|
// checkOverlapComplex methods return a boolean to allow the check call to be added to a
|
||||||
|
// boolean expression, making use of short-circuit operators.
|
||||||
|
func checkOverlapComplex(a, b cblas128.General) bool {
|
||||||
|
if cap(a.Data) == 0 || cap(b.Data) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
off := offsetComplex(a.Data[:1], b.Data[:1])
|
||||||
|
|
||||||
|
if off == 0 {
|
||||||
|
// At least one element overlaps.
|
||||||
|
if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride {
|
||||||
|
panic(regionIdentity)
|
||||||
|
}
|
||||||
|
panic(regionOverlap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if off > 0 && len(a.Data) <= off {
|
||||||
|
// We know a is completely before b.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if off < 0 && len(b.Data) <= -off {
|
||||||
|
// We know a is completely after b.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Stride != b.Stride && a.Stride != 1 && b.Stride != 1 {
|
||||||
|
// Too hard, so assume the worst; if either stride
|
||||||
|
// is one it will be caught in rectanglesOverlap.
|
||||||
|
panic(mismatchedStrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
if off < 0 {
|
||||||
|
off = -off
|
||||||
|
a.Cols, b.Cols = b.Cols, a.Cols
|
||||||
|
}
|
||||||
|
if rectanglesOverlap(off, a.Cols, b.Cols, min(a.Stride, b.Stride)) {
|
||||||
|
panic(regionOverlap)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CDense) checkOverlapComplex(a cblas128.General) bool {
|
||||||
|
return checkOverlapComplex(m.RawCMatrix(), a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CDense) checkOverlapMatrix(a CMatrix) bool {
|
||||||
|
if m == a {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var amat cblas128.General
|
||||||
|
switch ar := a.(type) {
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
case RawCMatrixer:
|
||||||
|
amat = ar.RawCMatrix()
|
||||||
|
}
|
||||||
|
return m.checkOverlapComplex(amat)
|
||||||
|
}
|
Reference in New Issue
Block a user