mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-29 09:52:33 +08:00
Can now load and run a network
- The library is now capable of loading the test session and running it successfully. - Updated the example application, which is really the same as the test and should probably be deleted. - TODO: The input names actually *do* need to match the names when the network was created. The hardcoded names in onnxruntime_wrapper.c will *not* do!
This commit is contained in:
@@ -2,6 +2,7 @@ package onnxruntime
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -10,10 +11,10 @@ import (
|
||||
// This type is read from JSON and used to determine the inputs and expected
|
||||
// outputs for an ONNX network.
|
||||
type testInputsInfo struct {
|
||||
InputShape []int `json:input_shape`
|
||||
FlattenedInput []float32 `json:flattened_input`
|
||||
OutputShape []int `json:output_shape`
|
||||
FlattenedOutput []float32 `json:flattened_output`
|
||||
InputShape []int64 `json:"input_shape"`
|
||||
FlattenedInput []float32 `json:"flattened_input"`
|
||||
OutputShape []int64 `json:"output_shape"`
|
||||
FlattenedOutput []float32 `json:"flattened_output"`
|
||||
}
|
||||
|
||||
// This must be called prior to running each test.
|
||||
@@ -52,6 +53,25 @@ func parseInputsJSON(path string, t *testing.T) *testInputsInfo {
|
||||
return &toReturn
|
||||
}
|
||||
|
||||
// Returns an error if any element between a and b don't match.
|
||||
func floatsEqual(a, b []float32) error {
|
||||
if len(a) != len(b) {
|
||||
return fmt.Errorf("Length mismatch: %d vs %d", len(a), len(b))
|
||||
}
|
||||
for i := range a {
|
||||
diff := a[i] - b[i]
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
// Arbitrarily chosen precision.
|
||||
if diff >= 0.00000001 {
|
||||
return fmt.Errorf("Data element %d doesn't match: %f vs %v",
|
||||
i, a[i], b[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTensorTypes(t *testing.T) {
|
||||
// It would be nice to compare this, but doing that would require exposing
|
||||
// the underlying C types in Go; the testing package doesn't support cgo.
|
||||
@@ -115,11 +135,42 @@ func TestCreateTensor(t *testing.T) {
|
||||
|
||||
func TestExampleNetwork(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
_ = parseInputsJSON("test_data/example_network_results.json", t)
|
||||
|
||||
// TODO: More tests here to run the network, once that's supported.
|
||||
// Create input and output tensors
|
||||
inputs := parseInputsJSON("test_data/example_network_results.json", t)
|
||||
inputTensor, e := NewTensor(Shape(inputs.InputShape),
|
||||
inputs.FlattenedInput)
|
||||
if e != nil {
|
||||
t.Logf("Failed creating input tensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer inputTensor.Destroy()
|
||||
outputTensor, e := NewEmptyTensor[float32](Shape(inputs.OutputShape))
|
||||
if e != nil {
|
||||
t.Logf("Failed creating output tensor: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer outputTensor.Destroy()
|
||||
|
||||
e := CleanupEnvironment()
|
||||
// Set up and run the session.
|
||||
session, e := NewSimpleSession[float32]("test_data/example_network.onnx")
|
||||
if e != nil {
|
||||
t.Logf("Failed creating simple session: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer session.Destroy()
|
||||
e = session.SimpleRun(inputTensor, outputTensor)
|
||||
if e != nil {
|
||||
t.Logf("Failed to run the session: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
e = floatsEqual(outputTensor.GetData(), inputs.FlattenedOutput)
|
||||
if e != nil {
|
||||
t.Logf("The neural network didn't produce the correct result: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
e = CleanupEnvironment()
|
||||
if e != nil {
|
||||
t.Logf("Failed cleaning up the environment: %s\n", e)
|
||||
t.FailNow()
|
||||
|
||||
Reference in New Issue
Block a user