mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-30 02:11:45 +08:00
Enable multiple inputs and outputs
- Modified the Session API so that the user must provide all input and output tensors when creating the session (Run() no longer takes any arguments). This should avoid allocations and fix the incorrect way I was using input and output names before. - Updated the test to use the new API. - Removed the onnx_example_application; it was only doing the same thing as the unit test anyway.
This commit is contained in:
@@ -1,7 +0,0 @@
|
|||||||
An Example Application Using the `onnxruntime` Go Wrapper
|
|
||||||
=========================================================
|
|
||||||
|
|
||||||
To run this application, navigate to this directory and compile it using
|
|
||||||
`go build`. Afterwards, run it using `./onnx_example_application` (or
|
|
||||||
`onnx_example_application.exe` on Windows).
|
|
||||||
|
|
||||||
Binary file not shown.
@@ -1,114 +0,0 @@
|
|||||||
// This application loads a test ONNX network and executes it on some fixed
|
|
||||||
// data. It serves as an example of how to use the onnxruntime wrapper library.
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/yalue/onnxruntime"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This type is read from JSON and used to determine the inputs and expected
|
|
||||||
// outputs for an ONNX network.
|
|
||||||
type testInputsInfo struct {
|
|
||||||
InputShape []int64 `json:"input_shape"`
|
|
||||||
FlattenedInput []float32 `json:"flattened_input"`
|
|
||||||
OutputShape []int64 `json:"output_shape"`
|
|
||||||
FlattenedOutput []float32 `json:"flattened_output"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loads JSON that contains the shapes and data used by the test ONNX network.
|
|
||||||
// Requires the path to the JSON file.
|
|
||||||
func loadInputsJSON(path string) (*testInputsInfo, error) {
|
|
||||||
f, e := os.Open(path)
|
|
||||||
if e != nil {
|
|
||||||
return nil, fmt.Errorf("Error opening %s: %w", path, e)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
d := json.NewDecoder(f)
|
|
||||||
var toReturn testInputsInfo
|
|
||||||
e = d.Decode(&toReturn)
|
|
||||||
if e != nil {
|
|
||||||
return nil, fmt.Errorf("Error decoding %s: %w", path, e)
|
|
||||||
}
|
|
||||||
return &toReturn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func run() int {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll")
|
|
||||||
} else {
|
|
||||||
if runtime.GOARCH == "arm64" {
|
|
||||||
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime_arm64.so")
|
|
||||||
} else {
|
|
||||||
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.so")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
e := onnxruntime.InitializeEnvironment()
|
|
||||||
if e != nil {
|
|
||||||
fmt.Printf("Error initializing the onnxruntime environment: %s\n", e)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
fmt.Printf("The onnxruntime environment initialized OK.\n")
|
|
||||||
defer func() {
|
|
||||||
e := onnxruntime.DestroyEnvironment()
|
|
||||||
if e != nil {
|
|
||||||
fmt.Printf("Error destroying onnx environment: %s\n", e)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("Environment cleaned up OK.\n")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Load the JSON with the test input and output data.
|
|
||||||
testInputs, e := loadInputsJSON(
|
|
||||||
"../test_data/example_network_results.json")
|
|
||||||
if e != nil {
|
|
||||||
fmt.Printf("Error reading example inputs from JSON: %s\n", e)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the session with the test onnx network
|
|
||||||
session, e := onnxruntime.NewSimpleSession[float32](
|
|
||||||
"../test_data/example_network.onnx")
|
|
||||||
if e != nil {
|
|
||||||
fmt.Printf("Error initializing the ONNX session: %s\n", e)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
defer session.Destroy()
|
|
||||||
|
|
||||||
// Create input and output tensors.
|
|
||||||
inputShape := onnxruntime.Shape(testInputs.InputShape)
|
|
||||||
outputShape := onnxruntime.Shape(testInputs.OutputShape)
|
|
||||||
inputTensor, e := onnxruntime.NewTensor(inputShape,
|
|
||||||
testInputs.FlattenedInput)
|
|
||||||
if e != nil {
|
|
||||||
fmt.Printf("Failed getting input tensor: %s\n", e)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
defer inputTensor.Destroy()
|
|
||||||
outputTensor, e := onnxruntime.NewEmptyTensor[float32](outputShape)
|
|
||||||
if e != nil {
|
|
||||||
fmt.Printf("Failed creating output tensor: %s\n", e)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
defer outputTensor.Destroy()
|
|
||||||
|
|
||||||
// Actually run the network.
|
|
||||||
e = session.SimpleRun(inputTensor, outputTensor)
|
|
||||||
if e != nil {
|
|
||||||
fmt.Printf("Failed running network: %s\n", e)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range outputTensor.GetData() {
|
|
||||||
fmt.Printf("Output value %d: expected %f, got %f\n", i,
|
|
||||||
outputTensor.GetData()[i], testInputs.FlattenedOutput[i])
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
os.Exit(run())
|
|
||||||
}
|
|
||||||
@@ -213,18 +213,45 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
|
|||||||
return &toReturn, nil
|
return &toReturn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// A simple wrapper around the OrtSession C struct. Requires the user to
|
// A wrapper around the OrtSession C struct. Requires the user to maintain all
|
||||||
// maintain all input and output tensors, and to use the same data type for
|
// input and output tensors, and to use the same data type for input and output
|
||||||
// input and output tensors.
|
// tensors.
|
||||||
type SimpleSession[T TensorData] struct {
|
type Session[T TensorData] struct {
|
||||||
ortSession *C.OrtSession
|
ortSession *C.OrtSession
|
||||||
|
// We convert the tensor names to C strings only once, and keep them around
|
||||||
|
// here for future calls to Run().
|
||||||
|
inputNames []*C.char
|
||||||
|
outputNames []*C.char
|
||||||
|
// We only actually keep around the OrtValue pointers from the tensors.
|
||||||
|
inputs []*C.OrtValue
|
||||||
|
outputs []*C.OrtValue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Loads the ONNX network at the given path, and initializes a SimpleSession
|
// Loads the ONNX network at the given path, and initializes a Session
|
||||||
// instance. If this returns successfully, the caller must call Destroy() on
|
// instance. If this returns successfully, the caller must call Destroy() on
|
||||||
// the returned session when it is no longer needed.
|
// the returned session when it is no longer needed. We require the user to
|
||||||
func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
|
// provide the input and output tensors and names at this point, in order to
|
||||||
error) {
|
// not need to re-allocate them every time Run() is called. The user instead
|
||||||
|
// can just update or access the input/output tensor data after calling Run().
|
||||||
|
// The input and output tensors MUST outlive this session, and calling
|
||||||
|
// session.Destroy() will not destroy the input or output tensors.
|
||||||
|
func NewSession[T TensorData](onnxFilePath string, inputNames,
|
||||||
|
outputNames []string, inputs, outputs []*Tensor[T]) (*Session[T], error) {
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return nil, fmt.Errorf("No inputs were provided")
|
||||||
|
}
|
||||||
|
if len(outputs) == 0 {
|
||||||
|
return nil, fmt.Errorf("No outputs were provided")
|
||||||
|
}
|
||||||
|
if len(inputs) != len(inputNames) {
|
||||||
|
return nil, fmt.Errorf("Got %d input tensors, but %d input names",
|
||||||
|
len(inputs), len(inputNames))
|
||||||
|
}
|
||||||
|
if len(outputs) != len(outputNames) {
|
||||||
|
return nil, fmt.Errorf("Got %d output tensors, but %d output names",
|
||||||
|
len(outputs), len(outputNames))
|
||||||
|
}
|
||||||
|
|
||||||
// We load content this way in order to avoid a mess of wide-character
|
// We load content this way in order to avoid a mess of wide-character
|
||||||
// paths on Windows if we use CreateSession rather than
|
// paths on Windows if we use CreateSession rather than
|
||||||
// CreateSessionFromArray.
|
// CreateSessionFromArray.
|
||||||
@@ -233,7 +260,7 @@ func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
|
|||||||
return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e)
|
return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e)
|
||||||
}
|
}
|
||||||
var ortSession *C.OrtSession
|
var ortSession *C.OrtSession
|
||||||
status := C.CreateSimpleSession(unsafe.Pointer(&(fileContent[0])),
|
status := C.CreateSession(unsafe.Pointer(&(fileContent[0])),
|
||||||
C.size_t(len(fileContent)), ortEnv, &ortSession)
|
C.size_t(len(fileContent)), ortEnv, &ortSession)
|
||||||
if status != nil {
|
if status != nil {
|
||||||
return nil, fmt.Errorf("Error creating session from %s: %w",
|
return nil, fmt.Errorf("Error creating session from %s: %w",
|
||||||
@@ -242,29 +269,59 @@ func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
|
|||||||
// ONNXRuntime copies the file content unless a specific flag is provided
|
// ONNXRuntime copies the file content unless a specific flag is provided
|
||||||
// when creating the session (and we don't provide it!)
|
// when creating the session (and we don't provide it!)
|
||||||
fileContent = nil
|
fileContent = nil
|
||||||
return &SimpleSession[T]{
|
|
||||||
ortSession: ortSession,
|
// Collect the inputs and outputs, along with their names, into a format
|
||||||
|
// more convenient for passing to the Run() function in the C API.
|
||||||
|
cInputNames := make([]*C.char, len(inputNames))
|
||||||
|
cOutputNames := make([]*C.char, len(outputNames))
|
||||||
|
for i, v := range inputNames {
|
||||||
|
cInputNames[i] = C.CString(v)
|
||||||
|
}
|
||||||
|
for i, v := range outputNames {
|
||||||
|
cOutputNames[i] = C.CString(v)
|
||||||
|
}
|
||||||
|
inputOrtTensors := make([]*C.OrtValue, len(inputs))
|
||||||
|
outputOrtTensors := make([]*C.OrtValue, len(outputs))
|
||||||
|
for i, v := range inputs {
|
||||||
|
inputOrtTensors[i] = v.ortValue
|
||||||
|
}
|
||||||
|
for i, v := range outputs {
|
||||||
|
outputOrtTensors[i] = v.ortValue
|
||||||
|
}
|
||||||
|
return &Session[T]{
|
||||||
|
ortSession: ortSession,
|
||||||
|
inputNames: cInputNames,
|
||||||
|
outputNames: cOutputNames,
|
||||||
|
inputs: inputOrtTensors,
|
||||||
|
outputs: outputOrtTensors,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SimpleSession[_]) Destroy() error {
|
func (s *Session[_]) Destroy() error {
|
||||||
if s.ortSession != nil {
|
if s.ortSession != nil {
|
||||||
C.ReleaseOrtSession(s.ortSession)
|
C.ReleaseOrtSession(s.ortSession)
|
||||||
s.ortSession = nil
|
s.ortSession = nil
|
||||||
}
|
}
|
||||||
|
for i := range s.inputNames {
|
||||||
|
C.free(unsafe.Pointer(s.inputNames[i]))
|
||||||
|
}
|
||||||
|
s.inputNames = nil
|
||||||
|
for i := range s.outputNames {
|
||||||
|
C.free(unsafe.Pointer(s.outputNames[i]))
|
||||||
|
}
|
||||||
|
s.outputNames = nil
|
||||||
|
s.inputs = nil
|
||||||
|
s.outputs = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function assumes the SimpleSession takes a single input tensor and
|
// Runs the session, updating the contents of the output tensors on success.
|
||||||
// produces a single output, both of which have the same type.
|
func (s *Session[T]) Run() error {
|
||||||
func (s *SimpleSession[T]) SimpleRun(input *Tensor[T],
|
status := C.RunOrtSession(s.ortSession, &s.inputs[0], &s.inputNames[0],
|
||||||
output *Tensor[T]) error {
|
C.int(len(s.inputs)), &s.outputs[0], &s.outputNames[0],
|
||||||
status := C.RunSimpleSession(s.ortSession, input.ortValue,
|
C.int(len(s.outputs)))
|
||||||
output.ortValue)
|
|
||||||
if status != nil {
|
if status != nil {
|
||||||
return fmt.Errorf("Error running network: %w", statusToError(status))
|
return fmt.Errorf("Error running network: %w", statusToError(status))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (next): Test SimpleRun
|
|
||||||
|
|||||||
@@ -160,13 +160,15 @@ func TestExampleNetwork(t *testing.T) {
|
|||||||
defer outputTensor.Destroy()
|
defer outputTensor.Destroy()
|
||||||
|
|
||||||
// Set up and run the session.
|
// Set up and run the session.
|
||||||
session, e := NewSimpleSession[float32]("test_data/example_network.onnx")
|
session, e := NewSession[float32]("test_data/example_network.onnx",
|
||||||
|
[]string{"1x4 Input Vector"}, []string{"1x2 Output Vector"},
|
||||||
|
[]*Tensor[float32]{inputTensor}, []*Tensor[float32]{outputTensor})
|
||||||
if e != nil {
|
if e != nil {
|
||||||
t.Logf("Failed creating simple session: %s\n", e)
|
t.Logf("Failed creating session: %s\n", e)
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
defer session.Destroy()
|
defer session.Destroy()
|
||||||
e = session.SimpleRun(inputTensor, outputTensor)
|
e = session.Run()
|
||||||
if e != nil {
|
if e != nil {
|
||||||
t.Logf("Failed to run the session: %s\n", e)
|
t.Logf("Failed to run the session: %s\n", e)
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ const char *GetErrorMessage(OrtStatus *status) {
|
|||||||
return ort_api->GetErrorMessage(status);
|
return ort_api->GetErrorMessage(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
|
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
|
||||||
OrtEnv *env, OrtSession **out) {
|
OrtEnv *env, OrtSession **out) {
|
||||||
OrtStatus *status = NULL;
|
OrtStatus *status = NULL;
|
||||||
OrtSessionOptions *options = NULL;
|
OrtSessionOptions *options = NULL;
|
||||||
@@ -48,14 +48,13 @@ OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
|
OrtStatus *RunOrtSession(OrtSession *session,
|
||||||
OrtValue *output) {
|
OrtValue **inputs, char **input_names, int input_count,
|
||||||
// TODO (next): We actually *do* need to pass in these dang input names.
|
OrtValue **outputs, char **output_names, int output_count) {
|
||||||
const char *input_name[] = {"1x4 Input Vector"};
|
|
||||||
const char *output_name[] = {"1x2 Output Vector"};
|
|
||||||
OrtStatus *status = NULL;
|
OrtStatus *status = NULL;
|
||||||
status = ort_api->Run(session, NULL, input_name,
|
status = ort_api->Run(session, NULL, (const char* const*) input_names,
|
||||||
(const OrtValue* const*) &input, 1, output_name, 1, &output);
|
(const OrtValue* const*) inputs, input_count,
|
||||||
|
(const char* const*) output_names, output_count, outputs);
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,13 +44,15 @@ void ReleaseOrtMemoryInfo(OrtMemoryInfo *info);
|
|||||||
// Returns the message associated with the given ORT status.
|
// Returns the message associated with the given ORT status.
|
||||||
const char *GetErrorMessage(OrtStatus *status);
|
const char *GetErrorMessage(OrtStatus *status);
|
||||||
|
|
||||||
// Creates a "simple" session with a single input and single output.
|
// Creates an ORT session using the given model.
|
||||||
OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
|
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
|
||||||
OrtEnv *env, OrtSession **out);
|
OrtEnv *env, OrtSession **out);
|
||||||
|
|
||||||
// Runs a session with single, user-allocated, input and output tensors.
|
// Runs an ORT session with the given input and output tensors, along with
|
||||||
OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
|
// their names. In our use case, outputs must NOT be NULL.
|
||||||
OrtValue *output);
|
OrtStatus *RunOrtSession(OrtSession *session,
|
||||||
|
OrtValue **inputs, char **input_names, int input_count,
|
||||||
|
OrtValue **outputs, char **output_names, int output_count);
|
||||||
|
|
||||||
// Wraps ort_api->ReleaseSession
|
// Wraps ort_api->ReleaseSession
|
||||||
void ReleaseOrtSession(OrtSession *session);
|
void ReleaseOrtSession(OrtSession *session);
|
||||||
|
|||||||
Reference in New Issue
Block a user