mirror of
				https://github.com/yalue/onnxruntime_go.git
				synced 2025-10-31 10:46:24 +08:00 
			
		
		
		
	Can now load and run a network
- The library is now capable of loading the test session and running it successfully. - Updated the example application, which is really the same as the test and should probably be deleted. - TODO: The input names actually *do* need to match the names when the network was created. The hardcoded names in onnxruntime_wrapper.c will *not* do!
This commit is contained in:
		| @@ -3,12 +3,39 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/yalue/onnxruntime" | 	"github.com/yalue/onnxruntime" | ||||||
| 	"os" | 	"os" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // This type is read from JSON and used to determine the inputs and expected | ||||||
|  | // outputs for an ONNX network. | ||||||
|  | type testInputsInfo struct { | ||||||
|  | 	InputShape      []int64   `json:"input_shape"` | ||||||
|  | 	FlattenedInput  []float32 `json:"flattened_input"` | ||||||
|  | 	OutputShape     []int64   `json:"output_shape"` | ||||||
|  | 	FlattenedOutput []float32 `json:"flattened_output"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Loads JSON that contains the shapes and data used by the test ONNX network. | ||||||
|  | // Requires the path to the JSON file. | ||||||
|  | func loadInputsJSON(path string) (*testInputsInfo, error) { | ||||||
|  | 	f, e := os.Open(path) | ||||||
|  | 	if e != nil { | ||||||
|  | 		return nil, fmt.Errorf("Error opening %s: %w", path, e) | ||||||
|  | 	} | ||||||
|  | 	defer f.Close() | ||||||
|  | 	d := json.NewDecoder(f) | ||||||
|  | 	var toReturn testInputsInfo | ||||||
|  | 	e = d.Decode(&toReturn) | ||||||
|  | 	if e != nil { | ||||||
|  | 		return nil, fmt.Errorf("Error decoding %s: %w", path, e) | ||||||
|  | 	} | ||||||
|  | 	return &toReturn, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func run() int { | func run() int { | ||||||
| 	if runtime.GOOS == "windows" { | 	if runtime.GOOS == "windows" { | ||||||
| 		onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll") | 		onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll") | ||||||
| @@ -26,6 +53,52 @@ func run() int { | |||||||
| 	} | 	} | ||||||
| 	fmt.Printf("The onnxruntime environment initialized OK.\n") | 	fmt.Printf("The onnxruntime environment initialized OK.\n") | ||||||
|  |  | ||||||
|  | 	// Load the JSON with the test input and output data. | ||||||
|  | 	testInputs, e := loadInputsJSON( | ||||||
|  | 		"../test_data/example_network_results.json") | ||||||
|  | 	if e != nil { | ||||||
|  | 		fmt.Printf("Error reading example inputs from JSON: %s\n", e) | ||||||
|  | 		return 1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create the session with the test onnx network | ||||||
|  | 	session, e := onnxruntime.NewSimpleSession[float32]( | ||||||
|  | 		"../test_data/example_network.onnx") | ||||||
|  | 	if e != nil { | ||||||
|  | 		fmt.Printf("Error initializing the ONNX session: %s\n", e) | ||||||
|  | 		return 1 | ||||||
|  | 	} | ||||||
|  | 	defer session.Destroy() | ||||||
|  |  | ||||||
|  | 	// Create input and output tensors. | ||||||
|  | 	inputShape := onnxruntime.Shape(testInputs.InputShape) | ||||||
|  | 	outputShape := onnxruntime.Shape(testInputs.OutputShape) | ||||||
|  | 	inputTensor, e := onnxruntime.NewTensor(inputShape, | ||||||
|  | 		testInputs.FlattenedInput) | ||||||
|  | 	if e != nil { | ||||||
|  | 		fmt.Printf("Failed getting input tensor: %s\n", e) | ||||||
|  | 		return 1 | ||||||
|  | 	} | ||||||
|  | 	defer inputTensor.Destroy() | ||||||
|  | 	outputTensor, e := onnxruntime.NewEmptyTensor[float32](outputShape) | ||||||
|  | 	if e != nil { | ||||||
|  | 		fmt.Printf("Failed creating output tensor: %s\n", e) | ||||||
|  | 		return 1 | ||||||
|  | 	} | ||||||
|  | 	defer outputTensor.Destroy() | ||||||
|  |  | ||||||
|  | 	// Actually run the network. | ||||||
|  | 	e = session.SimpleRun(inputTensor, outputTensor) | ||||||
|  | 	if e != nil { | ||||||
|  | 		fmt.Printf("Failed running network: %s\n", e) | ||||||
|  | 		return 1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range outputTensor.GetData() { | ||||||
|  | 		fmt.Printf("Output value %d: expected %f, got %f\n", i, | ||||||
|  | 			outputTensor.GetData()[i], testInputs.FlattenedOutput[i]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Ordinarily, it is probably fine to call this using defer, but we do it | 	// Ordinarily, it is probably fine to call this using defer, but we do it | ||||||
| 	// here just so we can print a status message after the cleanup completes. | 	// here just so we can print a status message after the cleanup completes. | ||||||
| 	e = onnxruntime.CleanupEnvironment() | 	e = onnxruntime.CleanupEnvironment() | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ package onnxruntime | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"os" | ||||||
| 	"unsafe" | 	"unsafe" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -161,6 +162,13 @@ func (t *Tensor[_]) GetShape() Shape { | |||||||
| 	return t.shape.Clone() | 	return t.shape.Clone() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Makes a deep copy of the tensor, including its ONNXRuntime value. The Tensor | ||||||
|  | // returned by this function must be destroyed when no longer needed. | ||||||
|  | func (t *Tensor[T]) Clone() (*Tensor[T], error) { | ||||||
|  | 	// TODO: Implement Tensor.Clone() | ||||||
|  | 	return nil, fmt.Errorf("Not yet implemented") | ||||||
|  | } | ||||||
|  |  | ||||||
| // Creates a new empty tensor with the given shape. The shape provided to this | // Creates a new empty tensor with the given shape. The shape provided to this | ||||||
| // function is copied, and is no longer needed after this function returns. | // function is copied, and is no longer needed after this function returns. | ||||||
| func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) { | func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) { | ||||||
| @@ -187,8 +195,8 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) { | |||||||
| 	dataType := GetTensorElementDataType[T]() | 	dataType := GetTensorElementDataType[T]() | ||||||
| 	dataSize := unsafe.Sizeof(data[0]) * uintptr(elementCount) | 	dataSize := unsafe.Sizeof(data[0]) * uintptr(elementCount) | ||||||
|  |  | ||||||
| 	status := C.CreateOrtTensorWithShape(unsafe.Pointer(&(data[0])), | 	status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]), | ||||||
| 		C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&(s[0]))), | 		C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&s[0])), | ||||||
| 		C.int64_t(len(s)), ortMemoryInfo, dataType, &ortValue) | 		C.int64_t(len(s)), ortMemoryInfo, dataType, &ortValue) | ||||||
| 	if status != nil { | 	if status != nil { | ||||||
| 		return nil, fmt.Errorf("ORT API error creating tensor: %s", | 		return nil, fmt.Errorf("ORT API error creating tensor: %s", | ||||||
| @@ -200,42 +208,63 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) { | |||||||
| 		shape:    s.Clone(), | 		shape:    s.Clone(), | ||||||
| 		ortValue: ortValue, | 		ortValue: ortValue, | ||||||
| 	} | 	} | ||||||
|  | 	// TODO (next): Set a finalizer on new Tensors. | ||||||
|  | 	// - Idea: use a "destroyable" interface | ||||||
| 	return &toReturn, nil | 	return &toReturn, nil | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // The Session type, generally speaking, wraps operations supported by the | // A simple wrapper around the OrtSession C struct. Requires the user to | ||||||
| // OrtSession struct in the C API, such as actually executing the network. We | // maintain all input and output tensors, and to use the same data type for | ||||||
| // define it as an interface here. | // input and output tensors. | ||||||
| type Session interface { | type SimpleSession[T TensorData] struct { | ||||||
| 	// Executes the neural network, using the given (flattened) input data. The | 	ortSession *C.OrtSession | ||||||
| 	// size of the input data must match the input size specified when the |  | ||||||
| 	// Session was created. |  | ||||||
| 	Run(input []float32) error |  | ||||||
|  |  | ||||||
| 	// Returns the shape of the output tensor obtained by running the network. |  | ||||||
| 	// Returns an error if one occurs, including if the network hasn't been run |  | ||||||
| 	// yet (i.e. the output shape is still unknown). |  | ||||||
| 	OutputShape() (Shape, error) |  | ||||||
|  |  | ||||||
| 	// Returns a slice of results of the neural network's most recent |  | ||||||
| 	// execution. Note that the slice returned by this may change the next time |  | ||||||
| 	// the neural network is run (as calling Run() typically shouldn't result |  | ||||||
| 	// in re-allocating the result slice). |  | ||||||
| 	GetResults() ([]float32, error) |  | ||||||
|  |  | ||||||
| 	// Copies the results of the neural network's execution into the given |  | ||||||
| 	// slice. |  | ||||||
| 	CopyResults(dst []float32) error |  | ||||||
|  |  | ||||||
| 	// Destroys the Session, cleaning up resources. Must be called when the |  | ||||||
| 	// session is no longer needed. |  | ||||||
| 	Destroy() error |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // TODO (next): Keep implementing CreateSimpleSession. | // Loads the ONNX network at the given path, and initializes a SimpleSession | ||||||
| // - Allocate input and output tensors. | // instance. If this returns successfully, the caller must call Destroy() on | ||||||
| // - Implement and use the CreateTensorWithShape function in onnxruntime_wrapper.c. | // the returned session when it is no longer needed. | ||||||
| // - When calling Run() it will likely be an error if the tensors are the wrong shape!! | func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T], | ||||||
|  | 	error) { | ||||||
|  | 	// We load content this way in order to avoid a mess of wide-character | ||||||
|  | 	// paths on Windows if we use CreateSession rather than | ||||||
|  | 	// CreateSessionFromArray. | ||||||
|  | 	fileContent, e := os.ReadFile(onnxFilePath) | ||||||
|  | 	if e != nil { | ||||||
|  | 		return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e) | ||||||
|  | 	} | ||||||
|  | 	var ortSession *C.OrtSession | ||||||
|  | 	status := C.CreateSimpleSession(unsafe.Pointer(&(fileContent[0])), | ||||||
|  | 		C.size_t(len(fileContent)), ortEnv, &ortSession) | ||||||
|  | 	if status != nil { | ||||||
|  | 		return nil, fmt.Errorf("Error creating session from %s: %w", | ||||||
|  | 			onnxFilePath, statusToError(status)) | ||||||
|  | 	} | ||||||
|  | 	// ONNXRuntime copies the file content unless a specific flag is provided | ||||||
|  | 	// when creating the session (and we don't provide it!) | ||||||
|  | 	fileContent = nil | ||||||
|  | 	return &SimpleSession[T]{ | ||||||
|  | 		ortSession: ortSession, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| // We'll also allocate the input data and set up the input tensor here. | func (s *SimpleSession[_]) Destroy() error { | ||||||
|  | 	if s.ortSession != nil { | ||||||
|  | 		C.ReleaseOrtSession(s.ortSession) | ||||||
|  | 		s.ortSession = nil | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // This function assumes the SimpleSession takes a single input tensor and | ||||||
|  | // produces a single output, both of which have the same type. | ||||||
|  | func (s *SimpleSession[T]) SimpleRun(input *Tensor[T], | ||||||
|  | 	output *Tensor[T]) error { | ||||||
|  | 	status := C.RunSimpleSession(s.ortSession, input.ortValue, | ||||||
|  | 		output.ortValue) | ||||||
|  | 	if status != nil { | ||||||
|  | 		return fmt.Errorf("Error running network: %w", statusToError(status)) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TODO (next): Test SimpleRun | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package onnxruntime | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"testing" | 	"testing" | ||||||
| @@ -10,10 +11,10 @@ import ( | |||||||
| // This type is read from JSON and used to determine the inputs and expected | // This type is read from JSON and used to determine the inputs and expected | ||||||
| // outputs for an ONNX network. | // outputs for an ONNX network. | ||||||
| type testInputsInfo struct { | type testInputsInfo struct { | ||||||
| 	InputShape      []int     `json:input_shape` | 	InputShape      []int64   `json:"input_shape"` | ||||||
| 	FlattenedInput  []float32 `json:flattened_input` | 	FlattenedInput  []float32 `json:"flattened_input"` | ||||||
| 	OutputShape     []int     `json:output_shape` | 	OutputShape     []int64   `json:"output_shape"` | ||||||
| 	FlattenedOutput []float32 `json:flattened_output` | 	FlattenedOutput []float32 `json:"flattened_output"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // This must be called prior to running each test. | // This must be called prior to running each test. | ||||||
| @@ -52,6 +53,25 @@ func parseInputsJSON(path string, t *testing.T) *testInputsInfo { | |||||||
| 	return &toReturn | 	return &toReturn | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Returns an error if any element between a and b don't match. | ||||||
|  | func floatsEqual(a, b []float32) error { | ||||||
|  | 	if len(a) != len(b) { | ||||||
|  | 		return fmt.Errorf("Length mismatch: %d vs %d", len(a), len(b)) | ||||||
|  | 	} | ||||||
|  | 	for i := range a { | ||||||
|  | 		diff := a[i] - b[i] | ||||||
|  | 		if diff < 0 { | ||||||
|  | 			diff = -diff | ||||||
|  | 			// Arbitrarily chosen precision. | ||||||
|  | 			if diff >= 0.00000001 { | ||||||
|  | 				return fmt.Errorf("Data element %d doesn't match: %f vs %v", | ||||||
|  | 					i, a[i], b[i]) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestTensorTypes(t *testing.T) { | func TestTensorTypes(t *testing.T) { | ||||||
| 	// It would be nice to compare this, but doing that would require exposing | 	// It would be nice to compare this, but doing that would require exposing | ||||||
| 	// the underlying C types in Go; the testing package doesn't support cgo. | 	// the underlying C types in Go; the testing package doesn't support cgo. | ||||||
| @@ -115,11 +135,42 @@ func TestCreateTensor(t *testing.T) { | |||||||
|  |  | ||||||
| func TestExampleNetwork(t *testing.T) { | func TestExampleNetwork(t *testing.T) { | ||||||
| 	InitializeRuntime(t) | 	InitializeRuntime(t) | ||||||
| 	_ = parseInputsJSON("test_data/example_network_results.json", t) |  | ||||||
|  |  | ||||||
| 	// TODO: More tests here to run the network, once that's supported. | 	// Create input and output tensors | ||||||
|  | 	inputs := parseInputsJSON("test_data/example_network_results.json", t) | ||||||
|  | 	inputTensor, e := NewTensor(Shape(inputs.InputShape), | ||||||
|  | 		inputs.FlattenedInput) | ||||||
|  | 	if e != nil { | ||||||
|  | 		t.Logf("Failed creating input tensor: %s\n", e) | ||||||
|  | 		t.FailNow() | ||||||
|  | 	} | ||||||
|  | 	defer inputTensor.Destroy() | ||||||
|  | 	outputTensor, e := NewEmptyTensor[float32](Shape(inputs.OutputShape)) | ||||||
|  | 	if e != nil { | ||||||
|  | 		t.Logf("Failed creating output tensor: %s\n", e) | ||||||
|  | 		t.FailNow() | ||||||
|  | 	} | ||||||
|  | 	defer outputTensor.Destroy() | ||||||
|  |  | ||||||
| 	e := CleanupEnvironment() | 	// Set up and run the session. | ||||||
|  | 	session, e := NewSimpleSession[float32]("test_data/example_network.onnx") | ||||||
|  | 	if e != nil { | ||||||
|  | 		t.Logf("Failed creating simple session: %s\n", e) | ||||||
|  | 		t.FailNow() | ||||||
|  | 	} | ||||||
|  | 	defer session.Destroy() | ||||||
|  | 	e = session.SimpleRun(inputTensor, outputTensor) | ||||||
|  | 	if e != nil { | ||||||
|  | 		t.Logf("Failed to run the session: %s\n", e) | ||||||
|  | 		t.FailNow() | ||||||
|  | 	} | ||||||
|  | 	e = floatsEqual(outputTensor.GetData(), inputs.FlattenedOutput) | ||||||
|  | 	if e != nil { | ||||||
|  | 		t.Logf("The neural network didn't produce the correct result: %s\n", e) | ||||||
|  | 		t.FailNow() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	e = CleanupEnvironment() | ||||||
| 	if e != nil { | 	if e != nil { | ||||||
| 		t.Logf("Failed cleaning up the environment: %s\n", e) | 		t.Logf("Failed cleaning up the environment: %s\n", e) | ||||||
| 		t.FailNow() | 		t.FailNow() | ||||||
|   | |||||||
| @@ -48,6 +48,17 @@ OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length, | |||||||
|   return status; |   return status; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input, | ||||||
|  |   OrtValue *output) { | ||||||
|  |   // TODO (next): We actually *do* need to pass in these dang input names. | ||||||
|  |   const char *input_name[] = {"1x4 Input Vector"}; | ||||||
|  |   const char *output_name[] = {"1x2 Output Vector"}; | ||||||
|  |   OrtStatus *status = NULL; | ||||||
|  |   status = ort_api->Run(session, NULL, input_name, | ||||||
|  |     (const OrtValue* const*) &input, 1, output_name, 1, &output); | ||||||
|  |   return status; | ||||||
|  | } | ||||||
|  |  | ||||||
| void ReleaseOrtSession(OrtSession *session) { | void ReleaseOrtSession(OrtSession *session) { | ||||||
|   ort_api->ReleaseSession(session); |   ort_api->ReleaseSession(session); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -48,6 +48,10 @@ const char *GetErrorMessage(OrtStatus *status); | |||||||
| OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length, | OrtStatus *CreateSimpleSession(void *model_data, size_t model_data_length, | ||||||
|   OrtEnv *env, OrtSession **out); |   OrtEnv *env, OrtSession **out); | ||||||
|  |  | ||||||
|  | // Runs a session with single, user-allocated, input and output tensors. | ||||||
|  | OrtStatus *RunSimpleSession(OrtSession *session, OrtValue *input, | ||||||
|  |   OrtValue *output); | ||||||
|  |  | ||||||
| // Wraps ort_api->ReleaseSession | // Wraps ort_api->ReleaseSession | ||||||
| void ReleaseOrtSession(OrtSession *session); | void ReleaseOrtSession(OrtSession *session); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 yalue
					yalue