Add Tensor.ZeroContents() function

- Added a function that sets the underlying data slice to 0 for any
   tensor type.

 - Added a test for the new function.

 - Nobody was asking for this, but I plan to use it myself.

 - It's probably worth benchmarking whether this is better to do using
   C.memset or just a loop in go. Probably depends on tensor size and
   the overhead of the C call.
This commit is contained in:
yalue
2023-11-27 14:09:25 -05:00
parent b0799d6c42
commit c3635af6f5
2 changed files with 60 additions and 2 deletions

View File

@@ -124,7 +124,7 @@ func TestTensorTypes(t *testing.T) {
func TestCreateTensor(t *testing.T) {
InitializeRuntime(t)
defer DestroyEnvironment()
defer CleanupRuntime(t)
s := NewShape(1, 2, 3)
tensor1, e := NewEmptyTensor[uint8](s)
if e != nil {
@@ -177,7 +177,7 @@ func TestCreateTensor(t *testing.T) {
func TestBadTensorShapes(t *testing.T) {
InitializeRuntime(t)
defer DestroyEnvironment()
defer CleanupRuntime(t)
s := NewShape()
_, e := NewEmptyTensor[float64](s)
if e == nil {
@@ -227,6 +227,7 @@ func TestBadTensorShapes(t *testing.T) {
func TestCloneTensor(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
originalData := []float32{1, 2, 3, 4}
originalTensor, e := NewTensor(NewShape(2, 2), originalData)
if e != nil {
@@ -258,6 +259,48 @@ func TestCloneTensor(t *testing.T) {
}
}
func TestZeroTensorContents(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
a := newTestTensor[float64](t, NewShape(3, 4, 5))
defer a.Destroy()
data := a.GetData()
for i := range data {
data[i] = float64(i)
}
t.Logf("Before zeroing: a[%d] = %f\n", len(data)-1, data[len(data)-1])
a.ZeroContents()
for i, v := range data {
if v != 0.0 {
t.Logf("a[%d] = %f, expected it to be set to 0.\n", i, v)
t.FailNow()
}
}
// Do the same basic test with a CustomDataTensor
shape := NewShape(2, 3, 4, 5)
customData := randomBytes(123, 2*shape.FlattenedSize())
b, e := NewCustomDataTensor(shape, customData, TensorElementDataTypeUint16)
if e != nil {
t.Logf("Error creating custom data tensor: %s\n", e)
t.FailNow()
}
defer b.Destroy()
for i := range customData {
// This will wrap around, but doesn't matter. We just need arbitrary
// nonzero data for the test.
customData[i] = uint8(i)
}
t.Logf("Start of custom data before zeroing: % x\n", customData[0:10])
b.ZeroContents()
for i, v := range customData {
if v != 0 {
t.Logf("b[%d] = %d, expected it to be set to 0.\n", i, v)
t.FailNow()
}
}
}
func TestExampleNetwork(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)