Can now load and run a network

- The library is now capable of loading the test session and running
   it successfully.

 - Updated the example application, which is really the same as the test
   and should probably be deleted.

 - TODO: The input names actually *do* need to match the names when the
   network was created. The hardcoded names in onnxruntime_wrapper.c
   will *not* do!
This commit is contained in:
yalue
2023-02-04 10:31:57 -05:00
parent ff910beb76
commit 137ec8f64f
5 changed files with 210 additions and 42 deletions

View File

@@ -2,6 +2,7 @@ package onnxruntime
import (
"encoding/json"
"fmt"
"os"
"runtime"
"testing"
@@ -10,10 +11,10 @@ import (
// This type is read from JSON and used to determine the inputs and expected
// outputs for an ONNX network.
type testInputsInfo struct {
InputShape []int `json:input_shape`
FlattenedInput []float32 `json:flattened_input`
OutputShape []int `json:output_shape`
FlattenedOutput []float32 `json:flattened_output`
InputShape []int64 `json:"input_shape"`
FlattenedInput []float32 `json:"flattened_input"`
OutputShape []int64 `json:"output_shape"`
FlattenedOutput []float32 `json:"flattened_output"`
}
// This must be called prior to running each test.
@@ -52,6 +53,25 @@ func parseInputsJSON(path string, t *testing.T) *testInputsInfo {
return &toReturn
}
// Returns an error if any element between a and b don't match.
func floatsEqual(a, b []float32) error {
if len(a) != len(b) {
return fmt.Errorf("Length mismatch: %d vs %d", len(a), len(b))
}
for i := range a {
diff := a[i] - b[i]
if diff < 0 {
diff = -diff
// Arbitrarily chosen precision.
if diff >= 0.00000001 {
return fmt.Errorf("Data element %d doesn't match: %f vs %v",
i, a[i], b[i])
}
}
}
return nil
}
func TestTensorTypes(t *testing.T) {
// It would be nice to compare this, but doing that would require exposing
// the underlying C types in Go; the testing package doesn't support cgo.
@@ -115,11 +135,42 @@ func TestCreateTensor(t *testing.T) {
func TestExampleNetwork(t *testing.T) {
InitializeRuntime(t)
_ = parseInputsJSON("test_data/example_network_results.json", t)
// TODO: More tests here to run the network, once that's supported.
// Create input and output tensors
inputs := parseInputsJSON("test_data/example_network_results.json", t)
inputTensor, e := NewTensor(Shape(inputs.InputShape),
inputs.FlattenedInput)
if e != nil {
t.Logf("Failed creating input tensor: %s\n", e)
t.FailNow()
}
defer inputTensor.Destroy()
outputTensor, e := NewEmptyTensor[float32](Shape(inputs.OutputShape))
if e != nil {
t.Logf("Failed creating output tensor: %s\n", e)
t.FailNow()
}
defer outputTensor.Destroy()
e := CleanupEnvironment()
// Set up and run the session.
session, e := NewSimpleSession[float32]("test_data/example_network.onnx")
if e != nil {
t.Logf("Failed creating simple session: %s\n", e)
t.FailNow()
}
defer session.Destroy()
e = session.SimpleRun(inputTensor, outputTensor)
if e != nil {
t.Logf("Failed to run the session: %s\n", e)
t.FailNow()
}
e = floatsEqual(outputTensor.GetData(), inputs.FlattenedOutput)
if e != nil {
t.Logf("The neural network didn't produce the correct result: %s\n", e)
t.FailNow()
}
e = CleanupEnvironment()
if e != nil {
t.Logf("Failed cleaning up the environment: %s\n", e)
t.FailNow()