mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-30 10:16:19 +08:00
- Named the cleanup function something more consistent with the other cleanup functions. - Fixed a couple bugs where deferred stuff was executed after the environment was destroyed, causing a segfault.
115 lines
3.1 KiB
Go
115 lines
3.1 KiB
Go
// This application loads a test ONNX network and executes it on some fixed
|
|
// data. It serves as an example of how to use the onnxruntime wrapper library.
|
|
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/yalue/onnxruntime"
|
|
"os"
|
|
"runtime"
|
|
)
|
|
|
|
// This type is read from JSON and used to determine the inputs and expected
|
|
// outputs for an ONNX network.
|
|
type testInputsInfo struct {
|
|
InputShape []int64 `json:"input_shape"`
|
|
FlattenedInput []float32 `json:"flattened_input"`
|
|
OutputShape []int64 `json:"output_shape"`
|
|
FlattenedOutput []float32 `json:"flattened_output"`
|
|
}
|
|
|
|
// Loads JSON that contains the shapes and data used by the test ONNX network.
|
|
// Requires the path to the JSON file.
|
|
func loadInputsJSON(path string) (*testInputsInfo, error) {
|
|
f, e := os.Open(path)
|
|
if e != nil {
|
|
return nil, fmt.Errorf("Error opening %s: %w", path, e)
|
|
}
|
|
defer f.Close()
|
|
d := json.NewDecoder(f)
|
|
var toReturn testInputsInfo
|
|
e = d.Decode(&toReturn)
|
|
if e != nil {
|
|
return nil, fmt.Errorf("Error decoding %s: %w", path, e)
|
|
}
|
|
return &toReturn, nil
|
|
}
|
|
|
|
func run() int {
|
|
if runtime.GOOS == "windows" {
|
|
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll")
|
|
} else {
|
|
if runtime.GOARCH == "arm64" {
|
|
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime_arm64.so")
|
|
} else {
|
|
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.so")
|
|
}
|
|
}
|
|
e := onnxruntime.InitializeEnvironment()
|
|
if e != nil {
|
|
fmt.Printf("Error initializing the onnxruntime environment: %s\n", e)
|
|
return 1
|
|
}
|
|
fmt.Printf("The onnxruntime environment initialized OK.\n")
|
|
defer func() {
|
|
e := onnxruntime.DestroyEnvironment()
|
|
if e != nil {
|
|
fmt.Printf("Error destroying onnx environment: %s\n", e)
|
|
} else {
|
|
fmt.Printf("Environment cleaned up OK.\n")
|
|
}
|
|
}()
|
|
|
|
// Load the JSON with the test input and output data.
|
|
testInputs, e := loadInputsJSON(
|
|
"../test_data/example_network_results.json")
|
|
if e != nil {
|
|
fmt.Printf("Error reading example inputs from JSON: %s\n", e)
|
|
return 1
|
|
}
|
|
|
|
// Create the session with the test onnx network
|
|
session, e := onnxruntime.NewSimpleSession[float32](
|
|
"../test_data/example_network.onnx")
|
|
if e != nil {
|
|
fmt.Printf("Error initializing the ONNX session: %s\n", e)
|
|
return 1
|
|
}
|
|
defer session.Destroy()
|
|
|
|
// Create input and output tensors.
|
|
inputShape := onnxruntime.Shape(testInputs.InputShape)
|
|
outputShape := onnxruntime.Shape(testInputs.OutputShape)
|
|
inputTensor, e := onnxruntime.NewTensor(inputShape,
|
|
testInputs.FlattenedInput)
|
|
if e != nil {
|
|
fmt.Printf("Failed getting input tensor: %s\n", e)
|
|
return 1
|
|
}
|
|
defer inputTensor.Destroy()
|
|
outputTensor, e := onnxruntime.NewEmptyTensor[float32](outputShape)
|
|
if e != nil {
|
|
fmt.Printf("Failed creating output tensor: %s\n", e)
|
|
return 1
|
|
}
|
|
defer outputTensor.Destroy()
|
|
|
|
// Actually run the network.
|
|
e = session.SimpleRun(inputTensor, outputTensor)
|
|
if e != nil {
|
|
fmt.Printf("Failed running network: %s\n", e)
|
|
return 1
|
|
}
|
|
|
|
for i := range outputTensor.GetData() {
|
|
fmt.Printf("Output value %d: expected %f, got %f\n", i,
|
|
outputTensor.GetData()[i], testInputs.FlattenedOutput[i])
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func main() {
|
|
os.Exit(run())
|
|
}
|