mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-29 01:42:27 +08:00
- Renamed the cleanup function to be more consistent with the other cleanup functions. - Fixed a couple of bugs where deferred cleanup ran after the environment was destroyed, causing a segfault.
180 lines
4.7 KiB
Go
180 lines
4.7 KiB
Go
package onnxruntime
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"runtime"
|
|
"testing"
|
|
)
|
|
|
|
// testInputsInfo is read from a JSON file (see parseInputsJSON) and
// describes one test case for an ONNX network: the input to feed the network
// and the output it is expected to produce. Each tensor is given as a shape
// plus its contents flattened into a single slice.
type testInputsInfo struct {
	// Shape of the network's input tensor.
	InputShape []int64 `json:"input_shape"`
	// Contents of the input tensor, flattened into one slice. The element
	// order is presumably row-major — TODO confirm against the script that
	// generates the JSON.
	FlattenedInput []float32 `json:"flattened_input"`
	// Shape of the tensor the network is expected to output.
	OutputShape []int64 `json:"output_shape"`
	// Expected contents of the output tensor, flattened the same way as
	// FlattenedInput; compared element-wise against the actual output.
	FlattenedOutput []float32 `json:"flattened_output"`
}
|
|
|
|
// This must be called prior to running each test.
|
|
func InitializeRuntime(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
SetSharedLibraryPath("test_data/onnxruntime.dll")
|
|
} else {
|
|
if runtime.GOARCH == "arm64" {
|
|
SetSharedLibraryPath("test_data/onnxruntime_arm64.so")
|
|
} else {
|
|
SetSharedLibraryPath("test_data/onnxruntime.so")
|
|
}
|
|
}
|
|
e := InitializeEnvironment()
|
|
if e != nil {
|
|
t.Logf("Failed setting up onnxruntime environment: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
}
|
|
|
|
// Used to obtain the shape
|
|
func parseInputsJSON(path string, t *testing.T) *testInputsInfo {
|
|
toReturn := testInputsInfo{}
|
|
f, e := os.Open(path)
|
|
if e != nil {
|
|
t.Logf("Failed opening %s: %s\n", path, e)
|
|
t.FailNow()
|
|
}
|
|
defer f.Close()
|
|
d := json.NewDecoder(f)
|
|
e = d.Decode(&toReturn)
|
|
if e != nil {
|
|
t.Logf("Failed decoding %s: %s\n", path, e)
|
|
t.FailNow()
|
|
}
|
|
return &toReturn
|
|
}
|
|
|
|
// floatsEqualTolerance is the smallest absolute difference between two
// corresponding elements that floatsEqual treats as a mismatch. It is chosen
// with float32 resolution in mind (one ULP near 1.0 is about 1.19e-7);
// anything much smaller effectively demands bit-exact equality.
const floatsEqualTolerance = 1e-6

// floatsEqual returns nil if a and b have the same length and every pair of
// corresponding elements differs by less than floatsEqualTolerance, in
// either direction. Otherwise it returns an error describing the first
// mismatch. Exactly equal elements (including equal infinities) always
// match; an element pair involving NaN never matches.
func floatsEqual(a, b []float32) error {
	if len(a) != len(b) {
		return fmt.Errorf("Length mismatch: %d vs %d", len(a), len(b))
	}
	for i := range a {
		if a[i] == b[i] {
			continue
		}
		diff := a[i] - b[i]
		if diff < 0 {
			diff = -diff
		}
		// Checked for both signs of the difference. Written as !(diff < tol)
		// rather than diff >= tol so that a NaN difference counts as a
		// mismatch instead of silently passing.
		if !(diff < floatsEqualTolerance) {
			return fmt.Errorf("Data element %d doesn't match: %f vs %v",
				i, a[i], b[i])
		}
	}
	return nil
}
|
|
|
|
func TestTensorTypes(t *testing.T) {
|
|
// It would be nice to compare this, but doing that would require exposing
|
|
// the underlying C types in Go; the testing package doesn't support cgo.
|
|
type myFloat float64
|
|
dataType := GetTensorElementDataType[myFloat]()
|
|
t.Logf("Got data type for float64-based double: %d\n", dataType)
|
|
}
|
|
|
|
func TestCreateTensor(t *testing.T) {
|
|
InitializeRuntime(t)
|
|
defer DestroyEnvironment()
|
|
s := NewShape(1, 2, 3)
|
|
tensor1, e := NewEmptyTensor[uint8](s)
|
|
if e != nil {
|
|
t.Logf("Failed creating %s uint8 tensor: %s\n", s, e)
|
|
t.FailNow()
|
|
}
|
|
defer tensor1.Destroy()
|
|
if len(tensor1.GetData()) != 6 {
|
|
t.Logf("Incorrect data length for tensor1: %d\n",
|
|
len(tensor1.GetData()))
|
|
}
|
|
// Make sure that the underlying tensor created a copy of the shape we
|
|
// passed to NewEmptyTensor.
|
|
s[1] = 3
|
|
if tensor1.GetShape()[1] == s[1] {
|
|
t.Logf("Modifying the original shape incorrectly changed the " +
|
|
"tensor's shape.\n")
|
|
t.FailNow()
|
|
}
|
|
|
|
// Try making a tensor with a different data type.
|
|
s = NewShape(2, 5)
|
|
data := []float32{1.0}
|
|
_, e = NewTensor(s, data)
|
|
if e == nil {
|
|
t.Logf("Didn't get error when creating a tensor with too little " +
|
|
"data.\n")
|
|
t.FailNow()
|
|
}
|
|
t.Logf("Got expected error when creating a tensor without enough data: "+
|
|
"%s\n", e)
|
|
|
|
// It shouldn't be an error to create a tensor with too *much* underlying
|
|
// data; we'll just use the first portion of it.
|
|
data = []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
|
|
tensor2, e := NewTensor(s, data)
|
|
if e != nil {
|
|
t.Logf("Error creating tensor with data: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
defer tensor2.Destroy()
|
|
// Make sure the tensor's internal slice only refers to the part we care
|
|
// about, and not the entire slice.
|
|
if len(tensor2.GetData()) != 10 {
|
|
t.Logf("New tensor data contains %d elements, when it should "+
|
|
"contain 10.\n", len(tensor2.GetData()))
|
|
t.FailNow()
|
|
}
|
|
}
|
|
|
|
func TestExampleNetwork(t *testing.T) {
|
|
InitializeRuntime(t)
|
|
defer func() {
|
|
e := DestroyEnvironment()
|
|
if e != nil {
|
|
t.Logf("Error cleaning up environment: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
}()
|
|
|
|
// Create input and output tensors
|
|
inputs := parseInputsJSON("test_data/example_network_results.json", t)
|
|
inputTensor, e := NewTensor(Shape(inputs.InputShape),
|
|
inputs.FlattenedInput)
|
|
if e != nil {
|
|
t.Logf("Failed creating input tensor: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
defer inputTensor.Destroy()
|
|
outputTensor, e := NewEmptyTensor[float32](Shape(inputs.OutputShape))
|
|
if e != nil {
|
|
t.Logf("Failed creating output tensor: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
defer outputTensor.Destroy()
|
|
|
|
// Set up and run the session.
|
|
session, e := NewSimpleSession[float32]("test_data/example_network.onnx")
|
|
if e != nil {
|
|
t.Logf("Failed creating simple session: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
defer session.Destroy()
|
|
e = session.SimpleRun(inputTensor, outputTensor)
|
|
if e != nil {
|
|
t.Logf("Failed to run the session: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
e = floatsEqual(outputTensor.GetData(), inputs.FlattenedOutput)
|
|
if e != nil {
|
|
t.Logf("The neural network didn't produce the correct result: %s\n", e)
|
|
t.FailNow()
|
|
}
|
|
}
|