diff --git a/onnx_example_application/onnx_example_application.go b/onnx_example_application/onnx_example_application.go
index f66db51..63cbfa8 100644
--- a/onnx_example_application/onnx_example_application.go
+++ b/onnx_example_application/onnx_example_application.go
@@ -3,12 +3,39 @@ package main

 import (
+	"encoding/json"
 	"fmt"
 	"github.com/yalue/onnxruntime"
 	"os"
 	"runtime"
 )

+// This type is read from JSON and used to determine the inputs and expected
+// outputs for an ONNX network.
+type testInputsInfo struct {
+	InputShape      []int64   `json:"input_shape"`
+	FlattenedInput  []float32 `json:"flattened_input"`
+	OutputShape     []int64   `json:"output_shape"`
+	FlattenedOutput []float32 `json:"flattened_output"`
+}
+
+// Loads JSON that contains the shapes and data used by the test ONNX network.
+// Requires the path to the JSON file.
+func loadInputsJSON(path string) (*testInputsInfo, error) {
+	f, e := os.Open(path)
+	if e != nil {
+		return nil, fmt.Errorf("Error opening %s: %w", path, e)
+	}
+	defer f.Close()
+	d := json.NewDecoder(f)
+	var toReturn testInputsInfo
+	e = d.Decode(&toReturn)
+	if e != nil {
+		return nil, fmt.Errorf("Error decoding %s: %w", path, e)
+	}
+	return &toReturn, nil
+}
+
 func run() int {
 	if runtime.GOOS == "windows" {
 		onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll")
@@ -26,6 +53,52 @@ func run() int {
 	}
 	fmt.Printf("The onnxruntime environment initialized OK.\n")

+	// Load the JSON with the test input and output data.
+	testInputs, e := loadInputsJSON(
+		"../test_data/example_network_results.json")
+	if e != nil {
+		fmt.Printf("Error reading example inputs from JSON: %s\n", e)
+		return 1
+	}
+
+	// Create the session with the test ONNX network.
+	session, e := onnxruntime.NewSimpleSession[float32](
+		"../test_data/example_network.onnx")
+	if e != nil {
+		fmt.Printf("Error initializing the ONNX session: %s\n", e)
+		return 1
+	}
+	defer session.Destroy()
+
+	// Create input and output tensors.
+	inputShape := onnxruntime.Shape(testInputs.InputShape)
+	outputShape := onnxruntime.Shape(testInputs.OutputShape)
+	inputTensor, e := onnxruntime.NewTensor(inputShape,
+		testInputs.FlattenedInput)
+	if e != nil {
+		fmt.Printf("Failed creating input tensor: %s\n", e)
+		return 1
+	}
+	defer inputTensor.Destroy()
+	outputTensor, e := onnxruntime.NewEmptyTensor[float32](outputShape)
+	if e != nil {
+		fmt.Printf("Failed creating output tensor: %s\n", e)
+		return 1
+	}
+	defer outputTensor.Destroy()
+
+	// Actually run the network.
+	e = session.SimpleRun(inputTensor, outputTensor)
+	if e != nil {
+		fmt.Printf("Failed running network: %s\n", e)
+		return 1
+	}
+
+	for i := range outputTensor.GetData() {
+		fmt.Printf("Output value %d: expected %f, got %f\n", i,
+			testInputs.FlattenedOutput[i], outputTensor.GetData()[i])
+	}
+
 	// Ordinarily, it is probably fine to call this using defer, but we do it
 	// here just so we can print a status message after the cleanup completes.
 	e = onnxruntime.CleanupEnvironment()
diff --git a/onnxruntime.go b/onnxruntime.go
index 435ac09..4130c29 100644
--- a/onnxruntime.go
+++ b/onnxruntime.go
@@ -6,6 +6,7 @@ package onnxruntime

 import (
 	"fmt"
+	"os"
 	"unsafe"
 )

@@ -161,6 +162,13 @@ func (t *Tensor[_]) GetShape() Shape {
 	return t.shape.Clone()
 }

+// Makes a deep copy of the tensor, including its ONNXRuntime value. The Tensor
+// returned by this function must be destroyed when no longer needed.
+func (t *Tensor[T]) Clone() (*Tensor[T], error) {
+	// TODO: Implement Tensor.Clone()
+	return nil, fmt.Errorf("Not yet implemented")
+}
+
 // Creates a new empty tensor with the given shape. The shape provided to this
 // function is copied, and is no longer needed after this function returns.
 func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) {
@@ -187,8 +195,8 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
 	dataType := GetTensorElementDataType[T]()
 	dataSize := unsafe.Sizeof(data[0]) * uintptr(elementCount)
-	status := C.CreateOrtTensorWithShape(unsafe.Pointer(&(data[0])),
-		C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&(s[0]))),
+	status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]),
+		C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&s[0])),
 		C.int64_t(len(s)), ortMemoryInfo, dataType, &ortValue)
 	if status != nil {
 		return nil, fmt.Errorf("ORT API error creating tensor: %s",
@@ -200,42 +208,63 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
 		shape:    s.Clone(),
 		ortValue: ortValue,
 	}
+	// TODO (next): Set a finalizer on new Tensors.
+	// - Idea: use a "destroyable" interface
 	return &toReturn, nil
-
 }

-// The Session type, generally speaking, wraps operations supported by the
-// OrtSession struct in the C API, such as actually executing the network. We
-// define it as an interface here.
-type Session interface {
-	// Executes the neural network, using the given (flattened) input data. The
-	// size of the input data must match the input size specified when the
-	// Session was created.
-	Run(input []float32) error
-
-	// Returns the shape of the output tensor obtained by running the network.
-	// Returns an error if one occurs, including if the network hasn't been run
-	// yet (i.e. the output shape is still unknown).
-	OutputShape() (Shape, error)
-
-	// Returns a slice of results of the neural network's most recent
-	// execution. Note that the slice returned by this may change the next time
-	// the neural network is run (as calling Run() typically shouldn't result
-	// in re-allocating the result slice).
-	GetResults() ([]float32, error)
-
-	// Copies the results of the neural network's execution into the given
-	// slice.
-	CopyResults(dst []float32) error
-
-	// Destroys the Session, cleaning up resources. Must be called when the
-	// session is no longer needed.
-	Destroy() error
+// A simple wrapper around the OrtSession C struct. Requires the user to
+// maintain all input and output tensors, and to use the same data type for
+// input and output tensors.
+type SimpleSession[T TensorData] struct {
+	ortSession *C.OrtSession
 }

-// TODO (next): Keep implementing CreateSimpleSession.
-// - Allocate input and output tensors.
-// - Implement and use the CreateTensorWithShape function in onnxruntime_wrapper.c.
-// - When calling Run() it will likely be an error if the tensors are the wrong shape!!
+// Loads the ONNX network at the given path, and initializes a SimpleSession
+// instance. If this returns successfully, the caller must call Destroy() on
+// the returned session when it is no longer needed.
+func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
+	error) {
+	// We load content this way in order to avoid a mess of wide-character
+	// paths on Windows if we use CreateSession rather than
+	// CreateSessionFromArray.
+	fileContent, e := os.ReadFile(onnxFilePath)
+	if e != nil {
+		return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e)
+	}
+	var ortSession *C.OrtSession
+	status := C.CreateSimpleSession(unsafe.Pointer(&(fileContent[0])),
+		C.size_t(len(fileContent)), ortEnv, &ortSession)
+	if status != nil {
+		return nil, fmt.Errorf("Error creating session from %s: %w",
+			onnxFilePath, statusToError(status))
+	}
+	// ONNXRuntime copies the file content unless a specific flag is provided
+	// when creating the session (and we don't provide it!)
+	fileContent = nil
+	return &SimpleSession[T]{
+		ortSession: ortSession,
+	}, nil
+}

-// We'll also allocate the input data and set up the input tensor here.
+func (s *SimpleSession[_]) Destroy() error {
+	if s.ortSession != nil {
+		C.ReleaseOrtSession(s.ortSession)
+		s.ortSession = nil
+	}
+	return nil
+}
+
+// This function assumes the SimpleSession takes a single input tensor and
+// produces a single output, both of which have the same type.
+func (s *SimpleSession[T]) SimpleRun(input *Tensor[T],
+	output *Tensor[T]) error {
+	status := C.RunSimpleSession(s.ortSession, input.ortValue,
+		output.ortValue)
+	if status != nil {
+		return fmt.Errorf("Error running network: %w", statusToError(status))
+	}
+	return nil
+}
+
+// TODO (next): Test SimpleRun
diff --git a/onnxruntime_test.go b/onnxruntime_test.go
index 18ad530..714c4ba 100644
--- a/onnxruntime_test.go
+++ b/onnxruntime_test.go
@@ -2,6 +2,7 @@ package onnxruntime
 import (
 	"encoding/json"
+	"fmt"
 	"os"
 	"runtime"
 	"testing"
 )
@@ -10,10 +11,10 @@ import (
 // This type is read from JSON and used to determine the inputs and expected
 // outputs for an ONNX network.
 type testInputsInfo struct {
-	InputShape      []int     `json:input_shape`
-	FlattenedInput  []float32 `json:flattened_input`
-	OutputShape     []int     `json:output_shape`
-	FlattenedOutput []float32 `json:flattened_output`
+	InputShape      []int64   `json:"input_shape"`
+	FlattenedInput  []float32 `json:"flattened_input"`
+	OutputShape     []int64   `json:"output_shape"`
+	FlattenedOutput []float32 `json:"flattened_output"`
 }

 // This must be called prior to running each test.
@@ -52,6 +53,25 @@ func parseInputsJSON(path string, t *testing.T) *testInputsInfo {
 	return &toReturn
 }

+// Returns an error if any elements of a and b don't match.
+func floatsEqual(a, b []float32) error {
+	if len(a) != len(b) {
+		return fmt.Errorf("Length mismatch: %d vs %d", len(a), len(b))
+	}
+	for i := range a {
+		diff := a[i] - b[i]
+		if diff < 0 {
+			diff = -diff
+		}
+		// Arbitrarily chosen precision.
+		if diff >= 0.00000001 {
+			return fmt.Errorf("Data element %d doesn't match: %f vs %f",
+				i, a[i], b[i])
+		}
+	}
+	return nil
+}
+
 func TestTensorTypes(t *testing.T) {
 	// It would be nice to compare this, but doing that would require exposing
 	// the underlying C types in Go; the testing package doesn't support cgo.
@@ -115,11 +135,42 @@ func TestCreateTensor(t *testing.T) {

 func TestExampleNetwork(t *testing.T) {
 	InitializeRuntime(t)
-	_ = parseInputsJSON("test_data/example_network_results.json", t)
-	// TODO: More tests here to run the network, once that's supported.
+	// Create input and output tensors.
+	inputs := parseInputsJSON("test_data/example_network_results.json", t)
+	inputTensor, e := NewTensor(Shape(inputs.InputShape),
+		inputs.FlattenedInput)
+	if e != nil {
+		t.Logf("Failed creating input tensor: %s\n", e)
+		t.FailNow()
+	}
+	defer inputTensor.Destroy()
+	outputTensor, e := NewEmptyTensor[float32](Shape(inputs.OutputShape))
+	if e != nil {
+		t.Logf("Failed creating output tensor: %s\n", e)
+		t.FailNow()
+	}
+	defer outputTensor.Destroy()

-	e := CleanupEnvironment()
+	// Set up and run the session.
+	session, e := NewSimpleSession[float32]("test_data/example_network.onnx")
+	if e != nil {
+		t.Logf("Failed creating simple session: %s\n", e)
+		t.FailNow()
+	}
+	defer session.Destroy()
+	e = session.SimpleRun(inputTensor, outputTensor)
+	if e != nil {
+		t.Logf("Failed to run the session: %s\n", e)
+		t.FailNow()
+	}
+	e = floatsEqual(outputTensor.GetData(), inputs.FlattenedOutput)
+	if e != nil {
+		t.Logf("The neural network didn't produce the correct result: %s\n", e)
+		t.FailNow()
+	}
+
+	e = CleanupEnvironment()
 	if e != nil {
 		t.Logf("Failed cleaning up the environment: %s\n", e)
 		t.FailNow()
diff --git a/onnxruntime_wrapper.c b/onnxruntime_wrapper.c
index 168b4a9..5918946 100644
--- a/onnxruntime_wrapper.c
+++ b/onnxruntime_wrapper.c
@@ -48,6 +48,17 @@ OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
   return status;
 }

+OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
+  OrtValue *output) {
+  // TODO (next): We actually *do* need to pass in these dang input names.
+  const char *input_name[] = {"1x4 Input Vector"};
+  const char *output_name[] = {"1x2 Output Vector"};
+  OrtStatus *status = NULL;
+  status = ort_api->Run(session, NULL, input_name,
+    (const OrtValue* const*) &input, 1, output_name, 1, &output);
+  return status;
+}
+
 void ReleaseOrtSession(OrtSession *session) {
   ort_api->ReleaseSession(session);
 }
diff --git a/onnxruntime_wrapper.h b/onnxruntime_wrapper.h
index 49b9bfd..4a20765 100644
--- a/onnxruntime_wrapper.h
+++ b/onnxruntime_wrapper.h
@@ -48,6 +48,10 @@ const char *GetErrorMessage(OrtStatus *status);
 OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length,
   OrtEnv *env, OrtSession **out);

+// Runs a session with single, user-allocated, input and output tensors.
+OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input,
+  OrtValue *output);
+
 // Wraps ort_api->ReleaseSession
 void ReleaseOrtSession(OrtSession *session);
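
Not part of the patch, but two of the TODOs above lend themselves to quick sketches. First, the Tensor.Clone() stub could plausibly be filled in using only pieces that already appear in this diff: the Tensor's data and shape fields plus NewTensor, which copies the shape it is given. A sketch of that idea, not the final implementation:

func (t *Tensor[T]) Clone() (*Tensor[T], error) {
	// Copy the flattened data so the clone doesn't share a backing array
	// with the original tensor; NewTensor then wraps the copy in a fresh
	// ONNXRuntime value, giving the deep copy the doc comment promises.
	dataCopy := make([]T, len(t.data))
	copy(dataCopy, t.data)
	clone, e := NewTensor(t.shape, dataCopy)
	if e != nil {
		return nil, fmt.Errorf("Error creating cloned tensor: %w", e)
	}
	return clone, nil
}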
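
Second, the "destroyable" interface mentioned in the finalizer TODO in NewTensor could look roughly like the following. The destroyable and setCleanupFinalizer names are placeholders rather than the library's API, and this assumes "runtime" is added to onnxruntime.go's imports; a constructor such as NewTensor would call setCleanupFinalizer(&toReturn) just before returning.

// Anything with cleanup that must run before it is garbage collected.
// Both *Tensor[T] and *SimpleSession[T] already satisfy this.
type destroyable interface {
	Destroy() error
}

// Registers a finalizer that calls Destroy() on v if the caller never did.
func setCleanupFinalizer(v destroyable) {
	runtime.SetFinalizer(v, func(d destroyable) {
		// Nothing useful can be done with the error here, so ignore it.
		d.Destroy()
	})
}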