mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-11-02 19:44:00 +08:00
- Named the cleanup function something more consistent with the other cleanup functions. - Fixed a couple bugs where deferred stuff was executed after the environment was destroyed, causing a segfault.
271 lines
8.3 KiB
Go
271 lines
8.3 KiB
Go
// This library wraps the C "onnxruntime" library maintained at
|
|
// https://github.com/microsoft/onnxruntime. It seeks to provide as simple an
|
|
// interface as possible to load and run ONNX-format neural networks from
|
|
// Go code.
|
|
package onnxruntime
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"unsafe"
|
|
)
|
|
|
|
// #cgo CFLAGS: -I${SRCDIR}/onnxruntime/include
|
|
//
|
|
// #include "onnxruntime_wrapper.h"
|
|
import "C"
|
|
|
|
// This string should be the path to onnxruntime.so, or onnxruntime.dll.
|
|
var onnxSharedLibraryPath string
|
|
|
|
// For simplicity, this library maintains a single ORT environment internally.
|
|
var ortEnv *C.OrtEnv
|
|
|
|
// We also keep a single OrtMemoryInfo value around, since we only support CPU
|
|
// allocations for now.
|
|
var ortMemoryInfo *C.OrtMemoryInfo
|
|
|
|
// Does two things: converts the given OrtStatus to a Go error, and releases
|
|
// the status. If the status is nil, this does nothing and returns nil.
|
|
func statusToError(status *C.OrtStatus) error {
|
|
if status == nil {
|
|
return nil
|
|
}
|
|
msg := C.GetErrorMessage(status)
|
|
toReturn := C.GoString(msg)
|
|
C.ReleaseOrtStatus(status)
|
|
return fmt.Errorf("%s", toReturn)
|
|
}
|
|
|
|
// Use this function to set the path to the "onnxruntime.so" or
|
|
// "onnxruntime.dll" function. By default, it will be set to "onnxruntime.so"
|
|
// on non-Windows systems, and "onnxruntime.dll" on Windows. Users wishing to
|
|
// specify a particular location of this library must call this function prior
|
|
// to calling onnxruntime.InitializeEnvironment().
|
|
func SetSharedLibraryPath(path string) {
|
|
onnxSharedLibraryPath = path
|
|
}
|
|
|
|
// Call this function to initialize the internal onnxruntime environment. If
|
|
// this doesn't return an error, the caller will be responsible for calling
|
|
// DestroyEnvironment to free the onnxruntime state when no longer needed.
|
|
func InitializeEnvironment() error {
|
|
if ortEnv != nil {
|
|
return fmt.Errorf("The onnxruntime has already been initialized")
|
|
}
|
|
// Do the windows- or linux- specific initialization first.
|
|
e := platformInitializeEnvironment()
|
|
if e != nil {
|
|
return fmt.Errorf("Platform-specific initialization failed: %w", e)
|
|
}
|
|
|
|
name := C.CString("Golang onnxruntime environment")
|
|
defer C.free(unsafe.Pointer(name))
|
|
status := C.CreateOrtEnv(name, &ortEnv)
|
|
if status != nil {
|
|
return fmt.Errorf("Error creating ORT environment: %w",
|
|
statusToError(status))
|
|
}
|
|
|
|
status = C.CreateOrtMemoryInfo(&ortMemoryInfo)
|
|
if status != nil {
|
|
DestroyEnvironment()
|
|
return fmt.Errorf("Error creating ORT memory info: %w",
|
|
statusToError(status))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Call this function to cleanup the internal onnxruntime environment when it
|
|
// is no longer needed.
|
|
func DestroyEnvironment() error {
|
|
var e error
|
|
if ortMemoryInfo != nil {
|
|
C.ReleaseOrtMemoryInfo(ortMemoryInfo)
|
|
ortMemoryInfo = nil
|
|
}
|
|
if ortEnv != nil {
|
|
C.ReleaseOrtEnv(ortEnv)
|
|
ortEnv = nil
|
|
}
|
|
|
|
// platformCleanup primarily unloads the library, so we need to call it
|
|
// last, after any functions that make use of the ORT API.
|
|
e = platformCleanup()
|
|
if e != nil {
|
|
return fmt.Errorf("Platform-specific cleanup failed: %w", e)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// The Shape type describes the dimensions of the tensors used as network
// inputs and outputs. Each entry is the size of one dimension.
type Shape []int64

// Builds a Shape out of the given dimensions.
func NewShape(dimensions ...int64) Shape {
	return dimensions
}

// Returns how many elements a tensor with this shape holds in total, i.e. the
// product of all dimensions. A shape without any dimensions yields 0.
func (s Shape) FlattenedSize() int64 {
	if len(s) == 0 {
		return 0
	}
	total := int64(1)
	for _, dim := range s {
		total *= dim
	}
	return total
}

// Returns a deep copy of the Shape; the copy shares no storage with the
// original.
func (s Shape) Clone() Shape {
	duplicate := make(Shape, len(s))
	copy(duplicate, s)
	return duplicate
}

// Formats the shape the same way a plain []int64 is printed, e.g. "[1 2 3]".
func (s Shape) String() string {
	return fmt.Sprint([]int64(s))
}
|
|
|
|
type Tensor[T TensorData] struct {
|
|
// The shape of the tensor
|
|
shape Shape
|
|
// The go slice containing the flattened data that backs the ONNX tensor.
|
|
data []T
|
|
// The underlying ONNX value we use with the C API.
|
|
ortValue *C.OrtValue
|
|
}
|
|
|
|
// Cleans up and frees the memory associated with this tensor.
|
|
func (t *Tensor[_]) Destroy() {
|
|
C.ReleaseOrtValue(t.ortValue)
|
|
t.ortValue = nil
|
|
t.data = nil
|
|
t.shape = nil
|
|
}
|
|
|
|
// Returns the slice containing the tensor's underlying data. The contents of
|
|
// the slice can be read or written to get or set the tensor's contents.
|
|
func (t *Tensor[T]) GetData() []T {
|
|
return t.data
|
|
}
|
|
|
|
// Returns the shape of the tensor. The returned shape is only a copy;
|
|
// modifying this does *not* change the shape of the underlying tensor.
|
|
// (Modifying the tensor's shape can only be accomplished by Destroying and
|
|
// recreating the tensor with the same data.)
|
|
func (t *Tensor[_]) GetShape() Shape {
|
|
return t.shape.Clone()
|
|
}
|
|
|
|
// Makes a deep copy of the tensor, including its ONNXRuntime value. The Tensor
|
|
// returned by this function must be destroyed when no longer needed.
|
|
func (t *Tensor[T]) Clone() (*Tensor[T], error) {
|
|
// TODO: Implement Tensor.Clone()
|
|
return nil, fmt.Errorf("Not yet implemented")
|
|
}
|
|
|
|
// Creates a new empty tensor with the given shape. The shape provided to this
|
|
// function is copied, and is no longer needed after this function returns.
|
|
func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) {
|
|
elementCount := s.FlattenedSize()
|
|
if elementCount == 0 {
|
|
return nil, fmt.Errorf("Got invalid shape containing 0 elements")
|
|
}
|
|
data := make([]T, elementCount)
|
|
return NewTensor(s, data)
|
|
}
|
|
|
|
// Creates a new tensor backed by an existing data slice. The shape provided to
|
|
// this function is copied, and is no longer needed after this function
|
|
// returns. If the data slice is longer than s.FlattenedSize(), then only the
|
|
// first portion of the data will be used.
|
|
func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
|
|
elementCount := s.FlattenedSize()
|
|
if elementCount > int64(len(data)) {
|
|
return nil, fmt.Errorf("The tensor's shape (%s) requires %d "+
|
|
"elements, but only %d were provided\n", s, elementCount,
|
|
len(data))
|
|
}
|
|
var ortValue *C.OrtValue
|
|
dataType := GetTensorElementDataType[T]()
|
|
dataSize := unsafe.Sizeof(data[0]) * uintptr(elementCount)
|
|
|
|
status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]),
|
|
C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&s[0])),
|
|
C.int64_t(len(s)), ortMemoryInfo, dataType, &ortValue)
|
|
if status != nil {
|
|
return nil, fmt.Errorf("ORT API error creating tensor: %s",
|
|
statusToError(status))
|
|
}
|
|
|
|
toReturn := Tensor[T]{
|
|
data: data[0:elementCount],
|
|
shape: s.Clone(),
|
|
ortValue: ortValue,
|
|
}
|
|
// TODO (next): Set a finalizer on new Tensors.
|
|
// - Idea: use a "destroyable" interface
|
|
return &toReturn, nil
|
|
}
|
|
|
|
// A simple wrapper around the OrtSession C struct. Requires the user to
|
|
// maintain all input and output tensors, and to use the same data type for
|
|
// input and output tensors.
|
|
type SimpleSession[T TensorData] struct {
|
|
ortSession *C.OrtSession
|
|
}
|
|
|
|
// Loads the ONNX network at the given path, and initializes a SimpleSession
|
|
// instance. If this returns successfully, the caller must call Destroy() on
|
|
// the returned session when it is no longer needed.
|
|
func NewSimpleSession[T TensorData](onnxFilePath string) (*SimpleSession[T],
|
|
error) {
|
|
// We load content this way in order to avoid a mess of wide-character
|
|
// paths on Windows if we use CreateSession rather than
|
|
// CreateSessionFromArray.
|
|
fileContent, e := os.ReadFile(onnxFilePath)
|
|
if e != nil {
|
|
return nil, fmt.Errorf("Error reading %s: %w", onnxFilePath, e)
|
|
}
|
|
var ortSession *C.OrtSession
|
|
status := C.CreateSimpleSession(unsafe.Pointer(&(fileContent[0])),
|
|
C.size_t(len(fileContent)), ortEnv, &ortSession)
|
|
if status != nil {
|
|
return nil, fmt.Errorf("Error creating session from %s: %w",
|
|
onnxFilePath, statusToError(status))
|
|
}
|
|
// ONNXRuntime copies the file content unless a specific flag is provided
|
|
// when creating the session (and we don't provide it!)
|
|
fileContent = nil
|
|
return &SimpleSession[T]{
|
|
ortSession: ortSession,
|
|
}, nil
|
|
}
|
|
|
|
func (s *SimpleSession[_]) Destroy() error {
|
|
if s.ortSession != nil {
|
|
C.ReleaseOrtSession(s.ortSession)
|
|
s.ortSession = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// This function assumes the SimpleSession takes a single input tensor and
|
|
// produces a single output, both of which have the same type.
|
|
func (s *SimpleSession[T]) SimpleRun(input *Tensor[T],
|
|
output *Tensor[T]) error {
|
|
status := C.RunSimpleSession(s.ortSession, input.ortValue,
|
|
output.ortValue)
|
|
if status != nil {
|
|
return fmt.Errorf("Error running network: %w", statusToError(status))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// TODO (next): Test SimpleRun
|