Can now load and run a network

- The library is now capable of loading the test session and running
   it successfully.

 - Updated the example application, which is really the same as the test
   and should probably be deleted.

 - TODO: The input names actually *do* need to match the names when the
   network was created. The hardcoded names in onnxruntime_wrapper.c
   will *not* do!
This commit is contained in:
yalue
2023-02-04 10:31:57 -05:00
parent ff910beb76
commit 137ec8f64f
5 changed files with 210 additions and 42 deletions

View File

@@ -3,12 +3,39 @@
package main
import (
"encoding/json"
"fmt"
"github.com/yalue/onnxruntime"
"os"
"runtime"
)
// This type is read from JSON and used to determine the inputs and expected
// outputs for an ONNX network.
type testInputsInfo struct {
InputShape []int64 `json:"input_shape"`
FlattenedInput []float32 `json:"flattened_input"`
OutputShape []int64 `json:"output_shape"`
FlattenedOutput []float32 `json:"flattened_output"`
}
// Loads JSON that contains the shapes and data used by the test ONNX network.
// Requires the path to the JSON file.
func loadInputsJSON(path string) (*testInputsInfo, error) {
f, e := os.Open(path)
if e != nil {
return nil, fmt.Errorf("Error opening %s: %w", path, e)
}
defer f.Close()
d := json.NewDecoder(f)
var toReturn testInputsInfo
e = d.Decode(&toReturn)
if e != nil {
return nil, fmt.Errorf("Error decoding %s: %w", path, e)
}
return &toReturn, nil
}
func run() int {
if runtime.GOOS == "windows" {
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll")
@@ -26,6 +53,52 @@ func run() int {
}
fmt.Printf("The onnxruntime environment initialized OK.\n")
// Load the JSON with the test input and output data.
testInputs, e := loadInputsJSON(
"../test_data/example_network_results.json")
if e != nil {
fmt.Printf("Error reading example inputs from JSON: %s\n", e)
return 1
}
// Create the session with the test onnx network
session, e := onnxruntime.NewSimpleSession[float32](
"../test_data/example_network.onnx")
if e != nil {
fmt.Printf("Error initializing the ONNX session: %s\n", e)
return 1
}
defer session.Destroy()
// Create input and output tensors.
inputShape := onnxruntime.Shape(testInputs.InputShape)
outputShape := onnxruntime.Shape(testInputs.OutputShape)
inputTensor, e := onnxruntime.NewTensor(inputShape,
testInputs.FlattenedInput)
if e != nil {
fmt.Printf("Failed getting input tensor: %s\n", e)
return 1
}
defer inputTensor.Destroy()
outputTensor, e := onnxruntime.NewEmptyTensor[float32](outputShape)
if e != nil {
fmt.Printf("Failed creating output tensor: %s\n", e)
return 1
}
defer outputTensor.Destroy()
// Actually run the network.
e = session.SimpleRun(inputTensor, outputTensor)
if e != nil {
fmt.Printf("Failed running network: %s\n", e)
return 1
}
for i := range outputTensor.GetData() {
fmt.Printf("Output value %d: expected %f, got %f\n", i,
outputTensor.GetData()[i], testInputs.FlattenedOutput[i])
}
// Ordinarily, it is probably fine to call this using defer, but we do it
// here just so we can print a status message after the cleanup completes.
e = onnxruntime.CleanupEnvironment()