Enable multiple inputs and outputs

- Modified the Session API so that the user must provide all input and
   output tensors when creating the session (Run() no longer takes any
   arguments).  This should avoid allocations and fix the incorrect way
   I was using input and output names before.

 - Updated the test to use the new API.

 - Removed the onnx_example_application; it was only doing the same
   thing as the unit test anyway.
This commit is contained in:
yalue
2023-02-04 13:51:45 -05:00
parent e0cd5f977c
commit bb5039f6ad
7 changed files with 96 additions and 157 deletions

View File

@@ -1,7 +0,0 @@
An Example Application Using the `onnxruntime` Go Wrapper
=========================================================
To run this application, navigate to this directory and compile it using
`go build`. Afterwards, run it using `./onnx_example_application` (or
`onnx_example_application.exe` on Windows).

View File

@@ -1,114 +0,0 @@
// This application loads a test ONNX network and executes it on some fixed
// data. It serves as an example of how to use the onnxruntime wrapper library.
package main
import (
"encoding/json"
"fmt"
"github.com/yalue/onnxruntime"
"os"
"runtime"
)
// testInputsInfo is the in-memory form of the JSON file that accompanies the
// example network. It carries the shape and flattened contents of the input
// to feed the network, plus the shape and flattened contents of the output
// the network is expected to produce.
type testInputsInfo struct {
	InputShape      []int64   `json:"input_shape"`
	FlattenedInput  []float32 `json:"flattened_input"`
	OutputShape     []int64   `json:"output_shape"`
	FlattenedOutput []float32 `json:"flattened_output"`
}

// loadInputsJSON opens the file at path and decodes the first JSON value in
// it into a testInputsInfo. On failure it returns a nil result and an error
// that names the path and wraps the underlying cause.
func loadInputsJSON(path string) (*testInputsInfo, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("Error opening %s: %w", path, err)
	}
	defer file.Close()

	info := new(testInputsInfo)
	if err = json.NewDecoder(file).Decode(info); err != nil {
		return nil, fmt.Errorf("Error decoding %s: %w", path, err)
	}
	return info, nil
}
// run loads the example network and its reference data, executes the network
// once, and prints every computed output value next to the reference value
// read from the JSON file. It returns the process exit status: 0 on success
// and 1 on any failure. All diagnostics are written to standard output.
func run() int {
	// Pick the shared library under ../test_data that matches the platform:
	// the DLL on Windows, otherwise the arm64 or default .so.
	libraryPath := "../test_data/onnxruntime.so"
	switch {
	case runtime.GOOS == "windows":
		libraryPath = "../test_data/onnxruntime.dll"
	case runtime.GOARCH == "arm64":
		libraryPath = "../test_data/onnxruntime_arm64.so"
	}
	onnxruntime.SetSharedLibraryPath(libraryPath)

	err := onnxruntime.InitializeEnvironment()
	if err != nil {
		fmt.Printf("Error initializing the onnxruntime environment: %s\n", err)
		return 1
	}
	fmt.Printf("The onnxruntime environment initialized OK.\n")
	// Deferred first, so it runs last: after the session and tensors below
	// have been destroyed.
	defer func() {
		if err := onnxruntime.DestroyEnvironment(); err != nil {
			fmt.Printf("Error destroying onnx environment: %s\n", err)
		} else {
			fmt.Printf("Environment cleaned up OK.\n")
		}
	}()

	// Load the JSON with the test input and output data.
	testInputs, err := loadInputsJSON(
		"../test_data/example_network_results.json")
	if err != nil {
		fmt.Printf("Error reading example inputs from JSON: %s\n", err)
		return 1
	}

	// Create the session with the test onnx network.
	session, err := onnxruntime.NewSimpleSession[float32](
		"../test_data/example_network.onnx")
	if err != nil {
		fmt.Printf("Error initializing the ONNX session: %s\n", err)
		return 1
	}
	defer session.Destroy()

	// Create input and output tensors.
	inputShape := onnxruntime.Shape(testInputs.InputShape)
	outputShape := onnxruntime.Shape(testInputs.OutputShape)
	inputTensor, err := onnxruntime.NewTensor(inputShape,
		testInputs.FlattenedInput)
	if err != nil {
		fmt.Printf("Failed getting input tensor: %s\n", err)
		return 1
	}
	defer inputTensor.Destroy()
	outputTensor, err := onnxruntime.NewEmptyTensor[float32](outputShape)
	if err != nil {
		fmt.Printf("Failed creating output tensor: %s\n", err)
		return 1
	}
	defer outputTensor.Destroy()

	// Actually run the network.
	err = session.SimpleRun(inputTensor, outputTensor)
	if err != nil {
		fmt.Printf("Failed running network: %s\n", err)
		return 1
	}

	actual := outputTensor.GetData()
	expected := testInputs.FlattenedOutput
	// Every output element is compared against the reference list by index,
	// so a short reference list must be reported instead of panicking.
	if len(expected) < len(actual) {
		fmt.Printf("The JSON lists %d expected output values, but the "+
			"network produced %d\n", len(expected), len(actual))
		return 1
	}
	for i := range actual {
		// The reference value goes with "expected" and the computed value
		// with "got".
		fmt.Printf("Output value %d: expected %f, got %f\n", i,
			expected[i], actual[i])
	}
	return 0
}
// main runs the example and reports its result as the process exit status.
// The work lives in run so that its deferred cleanups finish before os.Exit
// is called; os.Exit does not run deferred functions.
func main() {
	exitCode := run()
	os.Exit(exitCode)
}

View File

@@ -213,18 +213,45 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
return &toReturn, nil
}
// A simple wrapper around the OrtSession C struct. Requires the user to
// maintain all input and output tensors, and to use the same data type for
// input and output tensors.
type SimpleSession[T TensorData] struct {
// A wrapper around the OrtSession C struct. Requires the user to maintain all
// input and output tensors, and to use the same data type for input and output
// tensors.
type Session[T TensorData] struct {
ortSession *C.OrtSession
// We convert the tensor names to C strings only once, and keep them around
// here for future calls to Run().
inputNames []*C.char
outputNames []*C.char
// We only actually keep around the OrtValue pointers from the tensors.
inputs []*C.OrtValue
outputs []*C.OrtValue
}
// Loads the ONNX network at the given path, and initializes a SimpleSession
// Loads the ONNX network at the given path, and initializes a Session
// instance. If this returns successfully, the caller must call Destroy() on
// the returned session when it is no longer needed.
func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
error) {
// the returned session when it is no longer needed. We require the user to
// provide the input and output tensors and names at this point, in order to
// not need to re-allocate them every time Run() is called. The user instead
// can just update or access the input/output tensor data after calling Run().
// The input and output tensors MUST outlive this session, and calling
// session.Destroy() will not destroy the input or output tensors.
func NewSession[T TensorData](onnxFilePath string, inputNames,
outputNames []string, inputs, outputs []*Tensor[T]) (*Session[T], error) {
if len(inputs) == 0 {
return nil, fmt.Errorf("No inputs were provided")
}
if len(outputs) == 0 {
return nil, fmt.Errorf("No outputs were provided")
}
if len(inputs) != len(inputNames) {
return nil, fmt.Errorf("Got %d input tensors, but %d input names",
len(inputs), len(inputNames))
}
if len(outputs) != len(outputNames) {
return nil, fmt.Errorf("Got %d output tensors, but %d output names",
len(outputs), len(outputNames))
}
// We load content this way in order to avoid a mess of wide-character
// paths on Windows if we use CreateSession rather than
// CreateSessionFromArray.
@@ -233,7 +260,7 @@ func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e)
}
var ortSession *C.OrtSession
status := C.CreateSimpleSession(unsafe.Pointer(&(fileContent[0])),
status := C.CreateSession(unsafe.Pointer(&(fileContent[0])),
C.size_t(len(fileContent)), ortEnv, &ortSession)
if status != nil {
return nil, fmt.Errorf("Error creating session from %s: %w",
@@ -242,29 +269,59 @@ func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
// ONNXRuntime copies the file content unless a specific flag is provided
// when creating the session (and we don't provide it!)
fileContent = nil
return &SimpleSession[T]{
// Collect the inputs and outputs, along with their names, into a format
// more convenient for passing to the Run() function in the C API.
cInputNames := make([]*C.char, len(inputNames))
cOutputNames := make([]*C.char, len(outputNames))
for i, v := range inputNames {
cInputNames[i] = C.CString(v)
}
for i, v := range outputNames {
cOutputNames[i] = C.CString(v)
}
inputOrtTensors := make([]*C.OrtValue, len(inputs))
outputOrtTensors := make([]*C.OrtValue, len(outputs))
for i, v := range inputs {
inputOrtTensors[i] = v.ortValue
}
for i, v := range outputs {
outputOrtTensors[i] = v.ortValue
}
return &Session[T]{
ortSession: ortSession,
inputNames: cInputNames,
outputNames: cOutputNames,
inputs: inputOrtTensors,
outputs: outputOrtTensors,
}, nil
}
func (s *SimpleSession[_]) Destroy() error {
func (s *Session[_]) Destroy() error {
if s.ortSession != nil {
C.ReleaseOrtSession(s.ortSession)
s.ortSession = nil
}
for i := range s.inputNames {
C.free(unsafe.Pointer(s.inputNames[i]))
}
s.inputNames = nil
for i := range s.outputNames {
C.free(unsafe.Pointer(s.outputNames[i]))
}
s.outputNames = nil
s.inputs = nil
s.outputs = nil
return nil
}
// This function assumes the SimpleSession takes a single input tensor and
// produces a single output, both of which have the same type.
func (s *SimpleSession[T]) SimpleRun(input *Tensor[T],
output *Tensor[T]) error {
status := C.RunSimpleSession(s.ortSession, input.ortValue,
output.ortValue)
// Runs the session, updating the contents of the output tensors on success.
func (s *Session[T]) Run() error {
status := C.RunOrtSession(s.ortSession, &s.inputs[0], &s.inputNames[0],
C.int(len(s.inputs)), &s.outputs[0], &s.outputNames[0],
C.int(len(s.outputs)))
if status != nil {
return fmt.Errorf("Error running network: %w", statusToError(status))
}
return nil
}
// TODO (next): Test Run

View File

@@ -160,13 +160,15 @@ func TestExampleNetwork(t *testing.T) {
defer outputTensor.Destroy()
// Set up and run the session.
session, e := NewSimpleSession[float32]("test_data/example_network.onnx")
session, e := NewSession[float32]("test_data/example_network.onnx",
[]string{"1x4 Input Vector"}, []string{"1x2 Output Vector"},
[]*Tensor[float32]{inputTensor}, []*Tensor[float32]{outputTensor})
if e != nil {
t.Logf("Failed creating simple session: %s\n", e)
t.Logf("Failed creating session: %s\n", e)
t.FailNow()
}
defer session.Destroy()
e = session.SimpleRun(inputTensor, outputTensor)
e = session.Run()
if e != nil {
t.Logf("Failed to run the session: %s\n", e)
t.FailNow()

View File

@@ -35,7 +35,7 @@ const char *GetErrorMessage(OrtStatus *status) {
return ort_api->GetErrorMessage(status);
}
OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
OrtEnv *env, OrtSession **out) {
OrtStatus *status = NULL;
OrtSessionOptions *options = NULL;
@@ -48,14 +48,13 @@ OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
return status;
}
OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
OrtValue *output) {
// TODO (next): We actually *do* need to pass in these dang input names.
const char *input_name[] = {"1x4 Input Vector"};
const char *output_name[] = {"1x2 Output Vector"};
OrtStatus *RunOrtSession(OrtSession *session,
OrtValue **inputs, char **input_names, int input_count,
OrtValue **outputs, char **output_names, int output_count) {
OrtStatus *status = NULL;
status = ort_api->Run(session, NULL, input_name,
(const OrtValue* const*) &input, 1, output_name, 1, &output);
status = ort_api->Run(session, NULL, (const char* const*) input_names,
(const OrtValue* const*) inputs, input_count,
(const char* const*) output_names, output_count, outputs);
return status;
}

View File

@@ -44,13 +44,15 @@ void ReleaseOrtMemoryInfo(OrtMemoryInfo *info);
// Returns the message associated with the given ORT status.
const char *GetErrorMessage(OrtStatus *status);
// Creates a "simple" session with a single input and single output.
OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
// Creates an ORT session using the given model.
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
OrtEnv *env, OrtSession **out);
// Runs a session with single, user-allocated, input and output tensors.
OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
OrtValue *output);
// Runs an ORT session with the given input and output tensors, along with
// their names. In our use case, outputs must NOT be NULL.
OrtStatus *RunOrtSession(OrtSession *session,
OrtValue **inputs, char **input_names, int input_count,
OrtValue **outputs, char **output_names, int output_count);
// Wraps ort_api->ReleaseSession
void ReleaseOrtSession(OrtSession *session);