mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-30 02:11:45 +08:00
Add CustomDataTensor type
- This change introduces the CustomDataTensor type, which implements the ArbitraryTensor interface but, unlike the typed Tensor[T], is backed by an arbitrary slice of user-provided bytes. The user is responsible for providing the type of data the tensor is supposed to contain, as well as responsible for ensuring the data slice is in the correct format for the specified shape. - Added some test cases for the new CustomDataTensor type, which most notably will enable users to use float16 tensors (provided they converted the float16 array into bytes on their own).
This commit is contained in:
@@ -3,6 +3,7 @@ package onnxruntime_go
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"runtime"
|
||||
@@ -600,6 +601,156 @@ func TestWrongInputs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func randomBytes(seed, n int64) []byte {
|
||||
toReturn := make([]byte, n)
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
rng.Read(toReturn)
|
||||
return toReturn
|
||||
}
|
||||
|
||||
func TestCustomDataTensors(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
shape := NewShape(2, 3, 4, 5)
|
||||
tensorData := randomBytes(123, 2*shape.FlattenedSize())
|
||||
// This could have been created using a Tensor[uint16], but we'll make sure
|
||||
// it works this way, too.
|
||||
v, e := NewCustomDataTensor(shape, tensorData, TensorElementDataTypeUint16)
|
||||
if e != nil {
|
||||
t.Logf("Error creating uint16 CustomDataTensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
shape[0] = 6
|
||||
if v.GetShape().Equals(shape) {
|
||||
t.Logf("CustomDataTensor didn't correctly create a Clone of its shape")
|
||||
t.FailNow()
|
||||
}
|
||||
e = v.Destroy()
|
||||
if e != nil {
|
||||
t.Logf("Error destroying CustomDataTensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
tensorData = randomBytes(1234, 2*shape.FlattenedSize())
|
||||
v, e = NewCustomDataTensor(shape, tensorData, TensorElementDataTypeFloat16)
|
||||
if e != nil {
|
||||
t.Logf("Error creating float16 tensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
e = v.Destroy()
|
||||
if e != nil {
|
||||
t.Logf("Error destroying float16 tensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
// Make sure we don't fail if providing more data than necessary
|
||||
shape[0] = 1
|
||||
v, e = NewCustomDataTensor(shape, tensorData,
|
||||
TensorElementDataTypeBFloat16)
|
||||
if e != nil {
|
||||
t.Logf("Got error when creating a tensor with more data than "+
|
||||
"necessary: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
v.Destroy()
|
||||
|
||||
// Make sure we fail when using a bad shape
|
||||
shape = NewShape(0, -1, -2)
|
||||
v, e = NewCustomDataTensor(shape, tensorData, TensorElementDataTypeFloat16)
|
||||
if e == nil {
|
||||
v.Destroy()
|
||||
t.Logf("Didn't get error when creating custom tensor with an " +
|
||||
"invalid shape\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error creating tensor with invalid shape: %s\n", e)
|
||||
shape = NewShape(1, 2, 3, 4, 5)
|
||||
tensorData = []byte{1, 2, 3, 4}
|
||||
v, e = NewCustomDataTensor(shape, tensorData, TensorElementDataTypeUint8)
|
||||
if e == nil {
|
||||
v.Destroy()
|
||||
t.Logf("Didn't get error when creating custom tensor with too " +
|
||||
"little data\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating custom data tensor with "+
|
||||
"too little data: %s\n", e)
|
||||
|
||||
// Make sure we fail when using a bad type
|
||||
tensorData = []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
badType := TensorElementDataType(0xffffff)
|
||||
v, e = NewCustomDataTensor(NewShape(2), tensorData, badType)
|
||||
if e == nil {
|
||||
v.Destroy()
|
||||
t.Logf("Didn't get error when creating custom tensor with bad type\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating custom data tensor with bad "+
|
||||
"type: %s\n", e)
|
||||
}
|
||||
|
||||
// Converts a slice of floats to their representation as bfloat16 bytes.
|
||||
func floatsToBfloat16(f []float32) []byte {
|
||||
toReturn := make([]byte, 2*len(f))
|
||||
// bfloat16 is just a truncated version of a float32
|
||||
for i := range f {
|
||||
bf16Bits := uint16(math.Float32bits(f[i]) >> 16)
|
||||
toReturn[i*2] = uint8(bf16Bits)
|
||||
toReturn[i*2+1] = uint8(bf16Bits >> 8)
|
||||
}
|
||||
return toReturn
|
||||
}
|
||||
|
||||
func TestFloat16Network(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
// The network takes a 1x2x2x2 float16 input
|
||||
inputData := []byte{
|
||||
// 0.0, 1.0, 2.0, 3.0
|
||||
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42,
|
||||
// 4.0, 5.0, 6.0, 7.0
|
||||
0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||
}
|
||||
// The network produces a 1x2x2x2 bfloat16 output: the input multiplied
|
||||
// by 3
|
||||
expectedOutput := floatsToBfloat16([]float32{0, 3, 6, 9, 12, 15, 18, 21})
|
||||
outputData := make([]byte, len(expectedOutput))
|
||||
inputTensor, e := NewCustomDataTensor(NewShape(1, 2, 2, 2), inputData,
|
||||
TensorElementDataTypeFloat16)
|
||||
if e != nil {
|
||||
t.Logf("Error creating input tensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer inputTensor.Destroy()
|
||||
outputTensor, e := NewCustomDataTensor(NewShape(1, 2, 2, 2), outputData,
|
||||
TensorElementDataTypeBFloat16)
|
||||
if e != nil {
|
||||
t.Logf("Error creating output tensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer outputTensor.Destroy()
|
||||
|
||||
session, e := NewAdvancedSession("test_data/example_float16.onnx",
|
||||
[]string{"InputA"}, []string{"OutputA"},
|
||||
[]ArbitraryTensor{inputTensor}, []ArbitraryTensor{outputTensor}, nil)
|
||||
if e != nil {
|
||||
t.Logf("Error creating session: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer session.Destroy()
|
||||
e = session.Run()
|
||||
if e != nil {
|
||||
t.Logf("Error running session: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
for i := range outputData {
|
||||
if outputData[i] != expectedOutput[i] {
|
||||
t.Logf("Incorrect output byte at index %d: 0x%02x (expected "+
|
||||
"0x%02x)\n", i, outputData[i], expectedOutput[i])
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See the comment in generate_network_big_compute.py for information about
|
||||
// the inputs and outputs used for testing or benchmarking session options.
|
||||
func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32],
|
||||
@@ -633,8 +784,8 @@ func testBigSessionWithOptions(t *testing.T, options *SessionOptions) {
|
||||
defer input.Destroy()
|
||||
defer output.Destroy()
|
||||
session, e := NewAdvancedSession("test_data/example_big_compute.onnx",
|
||||
[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
|
||||
[]ArbitraryTensor{output}, options)
|
||||
[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
|
||||
[]ArbitraryTensor{output}, options)
|
||||
if e != nil {
|
||||
t.Logf("Error creating session: %s\n", e)
|
||||
t.FailNow()
|
||||
|
||||
Reference in New Issue
Block a user