mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-11-01 19:22:39 +08:00
Add Tensor.ZeroContents() function
- Added a function that sets the underlying data slice to 0 for any tensor type. - Added a test for the new function. - Nobody was asking for this, but I plan to use it myself. - It's probably worth benchmarking whether this is better to do using C.memset or just a loop in go. Probably depends on tensor size and the overhead of the C call.
This commit is contained in:
@@ -124,7 +124,7 @@ func TestTensorTypes(t *testing.T) {
|
||||
|
||||
func TestCreateTensor(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer DestroyEnvironment()
|
||||
defer CleanupRuntime(t)
|
||||
s := NewShape(1, 2, 3)
|
||||
tensor1, e := NewEmptyTensor[uint8](s)
|
||||
if e != nil {
|
||||
@@ -177,7 +177,7 @@ func TestCreateTensor(t *testing.T) {
|
||||
|
||||
func TestBadTensorShapes(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer DestroyEnvironment()
|
||||
defer CleanupRuntime(t)
|
||||
s := NewShape()
|
||||
_, e := NewEmptyTensor[float64](s)
|
||||
if e == nil {
|
||||
@@ -227,6 +227,7 @@ func TestBadTensorShapes(t *testing.T) {
|
||||
|
||||
func TestCloneTensor(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
originalData := []float32{1, 2, 3, 4}
|
||||
originalTensor, e := NewTensor(NewShape(2, 2), originalData)
|
||||
if e != nil {
|
||||
@@ -258,6 +259,48 @@ func TestCloneTensor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestZeroTensorContents(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
a := newTestTensor[float64](t, NewShape(3, 4, 5))
|
||||
defer a.Destroy()
|
||||
data := a.GetData()
|
||||
for i := range data {
|
||||
data[i] = float64(i)
|
||||
}
|
||||
t.Logf("Before zeroing: a[%d] = %f\n", len(data)-1, data[len(data)-1])
|
||||
a.ZeroContents()
|
||||
for i, v := range data {
|
||||
if v != 0.0 {
|
||||
t.Logf("a[%d] = %f, expected it to be set to 0.\n", i, v)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Do the same basic test with a CustomDataTensor
|
||||
shape := NewShape(2, 3, 4, 5)
|
||||
customData := randomBytes(123, 2*shape.FlattenedSize())
|
||||
b, e := NewCustomDataTensor(shape, customData, TensorElementDataTypeUint16)
|
||||
if e != nil {
|
||||
t.Logf("Error creating custom data tensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer b.Destroy()
|
||||
for i := range customData {
|
||||
// This will wrap around, but doesn't matter. We just need arbitrary
|
||||
// nonzero data for the test.
|
||||
customData[i] = uint8(i)
|
||||
}
|
||||
t.Logf("Start of custom data before zeroing: % x\n", customData[0:10])
|
||||
b.ZeroContents()
|
||||
for i, v := range customData {
|
||||
if v != 0 {
|
||||
t.Logf("b[%d] = %d, expected it to be set to 0.\n", i, v)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExampleNetwork(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
Reference in New Issue
Block a user