Can now load and run a network

- The library is now capable of loading the test session and running
   it successfully.

 - Updated the example application, which is really the same as the test
   and should probably be deleted.

 - TODO: The input and output names passed when running the network actually
   *do* need to match the names that were used when the network was created.
   The hardcoded names in onnxruntime_wrapper.c will *not* do!
This commit is contained in:
yalue
2023-02-04 10:31:57 -05:00
parent ff910beb76
commit 137ec8f64f
5 changed files with 210 additions and 42 deletions

View File

@@ -3,12 +3,39 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"github.com/yalue/onnxruntime" "github.com/yalue/onnxruntime"
"os" "os"
"runtime" "runtime"
) )
// This type is read from JSON and used to determine the inputs and expected
// outputs for an ONNX network. Each field is populated from the JSON key named
// in its struct tag.
type testInputsInfo struct {
// The dimensions of the input tensor. Converted to an onnxruntime.Shape
// when the input tensor is created.
InputShape []int64 `json:"input_shape"`
// The input tensor's data, as a flat slice of values.
FlattenedInput []float32 `json:"flattened_input"`
// The dimensions of the output tensor. Converted to an onnxruntime.Shape
// when the (empty) output tensor is created.
OutputShape []int64 `json:"output_shape"`
// The values the network is expected to produce, as a flat slice. These are
// reported alongside the values the network actually produced.
FlattenedOutput []float32 `json:"flattened_output"`
}
// loadInputsJSON opens the JSON file at the given path and decodes it into a
// testInputsInfo, which holds the shapes and flattened data used by the test
// ONNX network. Returns a wrapped error if the file can't be opened or if its
// content can't be decoded.
func loadInputsJSON(path string) (*testInputsInfo, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("Error opening %s: %w", path, err)
	}
	defer file.Close()

	// Decode straight from the open file into a freshly allocated struct.
	info := new(testInputsInfo)
	if err = json.NewDecoder(file).Decode(info); err != nil {
		return nil, fmt.Errorf("Error decoding %s: %w", path, err)
	}
	return info, nil
}
func run() int { func run() int {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll") onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll")
@@ -26,6 +53,52 @@ func run() int {
} }
fmt.Printf("The onnxruntime environment initialized OK.\n") fmt.Printf("The onnxruntime environment initialized OK.\n")
// Load the JSON with the test input and output data.
testInputs, e := loadInputsJSON(
"../test_data/example_network_results.json")
if e != nil {
fmt.Printf("Error reading example inputs from JSON: %s\n", e)
return 1
}
// Create the session with the test onnx network
session, e := onnxruntime.NewSimpleSession[float32](
"../test_data/example_network.onnx")
if e != nil {
fmt.Printf("Error initializing the ONNX session: %s\n", e)
return 1
}
defer session.Destroy()
// Create input and output tensors.
inputShape := onnxruntime.Shape(testInputs.InputShape)
outputShape := onnxruntime.Shape(testInputs.OutputShape)
inputTensor, e := onnxruntime.NewTensor(inputShape,
testInputs.FlattenedInput)
if e != nil {
fmt.Printf("Failed getting input tensor: %s\n", e)
return 1
}
defer inputTensor.Destroy()
outputTensor, e := onnxruntime.NewEmptyTensor[float32](outputShape)
if e != nil {
fmt.Printf("Failed creating output tensor: %s\n", e)
return 1
}
defer outputTensor.Destroy()
// Actually run the network.
e = session.SimpleRun(inputTensor, outputTensor)
if e != nil {
fmt.Printf("Failed running network: %s\n", e)
return 1
}
// Report every value the network produced next to the reference value loaded
// from the JSON file. The reference (FlattenedOutput) is the "expected" value
// and the tensor's data is what we actually "got".
for i, got := range outputTensor.GetData() {
	fmt.Printf("Output value %d: expected %f, got %f\n", i,
		testInputs.FlattenedOutput[i], got)
}
// Ordinarily, it is probably fine to call this using defer, but we do it // Ordinarily, it is probably fine to call this using defer, but we do it
// here just so we can print a status message after the cleanup completes. // here just so we can print a status message after the cleanup completes.
e = onnxruntime.CleanupEnvironment() e = onnxruntime.CleanupEnvironment()

View File

@@ -6,6 +6,7 @@ package onnxruntime
import ( import (
"fmt" "fmt"
"os"
"unsafe" "unsafe"
) )
@@ -161,6 +162,13 @@ func (t *Tensor[_]) GetShape() Shape {
return t.shape.Clone() return t.shape.Clone()
} }
// Intended to make a deep copy of the tensor, including its ONNXRuntime value.
// The Tensor returned by this function must be destroyed when no longer needed.
//
// NOTE: This is currently an unimplemented stub. It always returns a nil
// tensor along with a "Not yet implemented" error.
func (t *Tensor[T]) Clone() (*Tensor[T], error) {
// TODO: Implement Tensor.Clone()
return nil, fmt.Errorf("Not yet implemented")
}
// Creates a new empty tensor with the given shape. The shape provided to this // Creates a new empty tensor with the given shape. The shape provided to this
// function is copied, and is no longer needed after this function returns. // function is copied, and is no longer needed after this function returns.
func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) { func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) {
@@ -187,8 +195,8 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
dataType := GetTensorElementDataType[T]() dataType := GetTensorElementDataType[T]()
dataSize := unsafe.Sizeof(data[0]) * uintptr(elementCount) dataSize := unsafe.Sizeof(data[0]) * uintptr(elementCount)
status := C.CreateOrtTensorWithShape(unsafe.Pointer(&(data[0])), status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]),
C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&(s[0]))), C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&s[0])),
C.int64_t(len(s)), ortMemoryInfo, dataType, &ortValue) C.int64_t(len(s)), ortMemoryInfo, dataType, &ortValue)
if status != nil { if status != nil {
return nil, fmt.Errorf("ORT API error creating tensor: %s", return nil, fmt.Errorf("ORT API error creating tensor: %s",
@@ -200,42 +208,63 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
shape: s.Clone(), shape: s.Clone(),
ortValue: ortValue, ortValue: ortValue,
} }
// TODO (next): Set a finalizer on new Tensors.
// - Idea: use a "destroyable" interface
return &toReturn, nil return &toReturn, nil
} }
// The Session type, generally speaking, wraps operations supported by the // A simple wrapper around the OrtSession C struct. Requires the user to
// OrtSession struct in the C API, such as actually executing the network. We // maintain all input and output tensors, and to use the same data type for
// define it as an interface here. // input and output tensors.
type Session interface { type SimpleSession[T TensorData] struct {
// Executes the neural network, using the given (flattened) input data. The ortSession *C.OrtSession
// size of the input data must match the input size specified when the
// Session was created.
Run(input []float32) error
// Returns the shape of the output tensor obtained by running the network.
// Returns an error if one occurs, including if the network hasn't been run
// yet (i.e. the output shape is still unknown).
OutputShape() (Shape, error)
// Returns a slice of results of the neural network's most recent
// execution. Note that the slice returned by this may change the next time
// the neural network is run (as calling Run() typically shouldn't result
// in re-allocating the result slice).
GetResults() ([]float32, error)
// Copies the results of the neural network's execution into the given
// slice.
CopyResults(dst []float32) error
// Destroys the Session, cleaning up resources. Must be called when the
// session is no longer needed.
Destroy() error
} }
// TODO (next): Keep implementing CreateSimpleSession. // Loads the ONNX network at the given path, and initializes a SimpleSession
// - Allocate input and output tensors. // instance. If this returns successfully, the caller must call Destroy() on
// - Implement and use the CreateTensorWithShape function in onnxruntime_wrapper.c. // the returned session when it is no longer needed.
// - When calling Run() it will likely be an error if the tensors are the wrong shape!! func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
error) {
// We load content this way in order to avoid a mess of wide-character
// paths on Windows if we use CreateSession rather than
// CreateSessionFromArray.
fileContent, e := os.ReadFile(onnxFilePath)
if e != nil {
return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e)
}
var ortSession *C.OrtSession
status := C.CreateSimpleSession(unsafe.Pointer(&(fileContent[0])),
C.size_t(len(fileContent)), ortEnv, &ortSession)
if status != nil {
return nil, fmt.Errorf("Error creating session from %s: %w",
onnxFilePath, statusToError(status))
}
// ONNXRuntime copies the file content unless a specific flag is provided
// when creating the session (and we don't provide it!)
fileContent = nil
return &SimpleSession[T]{
ortSession: ortSession,
}, nil
}
// We'll also allocate the input data and set up the input tensor here. func (s *SimpleSession[_]) Destroy() error {
if s.ortSession != nil {
C.ReleaseOrtSession(s.ortSession)
s.ortSession = nil
}
return nil
}
// SimpleRun executes the session's network, passing the ONNXRuntime value of
// the input tensor and the ONNXRuntime value of the output tensor to the C
// wrapper. It assumes the network takes exactly one input tensor and produces
// exactly one output, both with the same element type T. Returns a wrapped
// error if the underlying run reports a non-nil status.
func (s *SimpleSession[T]) SimpleRun(input *Tensor[T],
	output *Tensor[T]) error {
	runStatus := C.RunSimpleSession(s.ortSession, input.ortValue,
		output.ortValue)
	if runStatus == nil {
		return nil
	}
	return fmt.Errorf("Error running network: %w", statusToError(runStatus))
}
// TODO (next): Test SimpleRun

View File

@@ -2,6 +2,7 @@ package onnxruntime
import ( import (
"encoding/json" "encoding/json"
"fmt"
"os" "os"
"runtime" "runtime"
"testing" "testing"
@@ -10,10 +11,10 @@ import (
// This type is read from JSON and used to determine the inputs and expected // This type is read from JSON and used to determine the inputs and expected
// outputs for an ONNX network. // outputs for an ONNX network.
type testInputsInfo struct { type testInputsInfo struct {
InputShape []int `json:input_shape` InputShape []int64 `json:"input_shape"`
FlattenedInput []float32 `json:flattened_input` FlattenedInput []float32 `json:"flattened_input"`
OutputShape []int `json:output_shape` OutputShape []int64 `json:"output_shape"`
FlattenedOutput []float32 `json:flattened_output` FlattenedOutput []float32 `json:"flattened_output"`
} }
// This must be called prior to running each test. // This must be called prior to running each test.
@@ -52,6 +53,25 @@ func parseInputsJSON(path string, t *testing.T) *testInputsInfo {
return &toReturn return &toReturn
} }
// floatsEqual compares a and b element by element. Returns nil if both slices
// have the same length and every pair of corresponding elements differs by
// less than a small fixed tolerance. Otherwise, returns an error describing
// either the length mismatch or the first mismatched element.
func floatsEqual(a, b []float32) error {
	if len(a) != len(b) {
		return fmt.Errorf("Length mismatch: %d vs %d", len(a), len(b))
	}
	for i := range a {
		// Take the absolute value of the difference.
		diff := a[i] - b[i]
		if diff < 0 {
			diff = -diff
		}
		// Arbitrarily chosen precision. This check must sit outside of the
		// sign flip above; otherwise mismatches where a[i] > b[i] would never
		// be reported.
		if diff >= 0.00000001 {
			return fmt.Errorf("Data element %d doesn't match: %f vs %v",
				i, a[i], b[i])
		}
	}
	return nil
}
func TestTensorTypes(t *testing.T) { func TestTensorTypes(t *testing.T) {
// It would be nice to compare this, but doing that would require exposing // It would be nice to compare this, but doing that would require exposing
// the underlying C types in Go; the testing package doesn't support cgo. // the underlying C types in Go; the testing package doesn't support cgo.
@@ -115,11 +135,42 @@ func TestCreateTensor(t *testing.T) {
func TestExampleNetwork(t *testing.T) { func TestExampleNetwork(t *testing.T) {
InitializeRuntime(t) InitializeRuntime(t)
_ = parseInputsJSON("test_data/example_network_results.json", t)
// TODO: More tests here to run the network, once that's supported. // Create input and output tensors
inputs := parseInputsJSON("test_data/example_network_results.json", t)
inputTensor, e := NewTensor(Shape(inputs.InputShape),
inputs.FlattenedInput)
if e != nil {
t.Logf("Failed creating input tensor: %s\n", e)
t.FailNow()
}
defer inputTensor.Destroy()
outputTensor, e := NewEmptyTensor[float32](Shape(inputs.OutputShape))
if e != nil {
t.Logf("Failed creating output tensor: %s\n", e)
t.FailNow()
}
defer outputTensor.Destroy()
e := CleanupEnvironment() // Set up and run the session.
session, e := NewSimpleSession[float32]("test_data/example_network.onnx")
if e != nil {
t.Logf("Failed creating simple session: %s\n", e)
t.FailNow()
}
defer session.Destroy()
e = session.SimpleRun(inputTensor, outputTensor)
if e != nil {
t.Logf("Failed to run the session: %s\n", e)
t.FailNow()
}
e = floatsEqual(outputTensor.GetData(), inputs.FlattenedOutput)
if e != nil {
t.Logf("The neural network didn't produce the correct result: %s\n", e)
t.FailNow()
}
e = CleanupEnvironment()
if e != nil { if e != nil {
t.Logf("Failed cleaning up the environment: %s\n", e) t.Logf("Failed cleaning up the environment: %s\n", e)
t.FailNow() t.FailNow()

View File

@@ -48,6 +48,17 @@ OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
return status; return status;
} }
// Runs the given session with exactly one input value and one output value,
// both supplied by the caller. Returns whatever status ort_api->Run produces.
OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
    OrtValue *output) {
  // TODO (next): We actually *do* need to pass in these dang input names.
  // For now, the input and output names are hardcoded here.
  const char *input_names[] = {"1x4 Input Vector"};
  const char *output_names[] = {"1x2 Output Vector"};
  // The API takes arrays of values; we have a single input and a single
  // output, so pass the addresses of our one-element "arrays".
  const OrtValue* const* inputs = (const OrtValue* const*) &input;
  return ort_api->Run(session, NULL, input_names, inputs, 1, output_names, 1,
    &output);
}
void ReleaseOrtSession(OrtSession *session) { void ReleaseOrtSession(OrtSession *session) {
ort_api->ReleaseSession(session); ort_api->ReleaseSession(session);
} }

View File

@@ -48,6 +48,10 @@ const char *GetErrorMessage(OrtStatus *status);
OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length, OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
OrtEnv *env, OrtSession **out); OrtEnv *env, OrtSession **out);
// Runs a session with a single user-allocated input tensor and a single
// user-allocated output tensor. Returns the OrtStatus pointer produced by the
// run; the Go caller treats a non-NULL status as an error.
OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
OrtValue *output);
// Wraps ort_api->ReleaseSession // Wraps ort_api->ReleaseSession
void ReleaseOrtSession(OrtSession *session); void ReleaseOrtSession(OrtSession *session);