// This library wraps the C "onnxruntime" library maintained at // https://github.com/microsoft/onnxruntime. It seeks to provide as simple an // interface as possible to load and run ONNX-format neural networks from // Go code. package onnxruntime_go import ( "fmt" "strings" "unsafe" ) // #cgo CFLAGS: -O2 -g // // #include "onnxruntime_wrapper.h" import "C" // This string should be the path to onnxruntime.so, or onnxruntime.dll. var onnxSharedLibraryPath string // For simplicity, this library maintains a single ORT environment internally. var ortEnv *C.OrtEnv // We also keep a single OrtMemoryInfo value around, since we only support CPU // allocations for now. var ortMemoryInfo *C.OrtMemoryInfo var NotInitializedError error = fmt.Errorf("InitializeRuntime() has either " + "not yet been called, or did not return successfully") var ZeroShapeLengthError error = fmt.Errorf("The shape has no dimensions") var ShapeOverflowError error = fmt.Errorf("The shape's flattened size " + "overflows an int64") // This type of error is returned when we attempt to validate a tensor that has // a negative or 0 dimension. type BadShapeDimensionError struct { DimensionIndex int DimensionSize int64 } func (e *BadShapeDimensionError) Error() string { return fmt.Sprintf("Dimension %d of the shape has invalid value %d", e.DimensionIndex, e.DimensionSize) } // GetVersion return version of the Onnxruntime library for logging. func GetVersion() string { return C.GoString(C.GetVersion()) } // Does two things: converts the given OrtStatus to a Go error, and releases // the status. If the status is nil, this does nothing and returns nil. func statusToError(status *C.OrtStatus) error { if status == nil { return nil } msg := C.GetErrorMessage(status) goMsg := C.GoString(msg) C.ReleaseOrtStatus(status) return fmt.Errorf("%s", strings.TrimSpace(goMsg)) } // Use this function to set the path to the "onnxruntime.so" or // "onnxruntime.dll" function. By default, it will be set to "onnxruntime.so" // on non-Windows systems, and "onnxruntime.dll" on Windows. Users wishing to // specify a particular location of this library must call this function prior // to calling onnxruntime.InitializeEnvironment(). func SetSharedLibraryPath(path string) { onnxSharedLibraryPath = path } // Returns false if the onnxruntime package is not initialized. Called // internally by several functions, to avoid segfaulting if // InitializeEnvironment hasn't been called yet. func IsInitialized() bool { return ortEnv != nil } // Call this function, optionally with one or more EnvironmentOption, to // initialize the internal onnxruntime environment. If this doesn't return an // error, the caller will be responsible for calling DestroyEnvironment to free // the onnxruntime state when no longer needed. func InitializeEnvironment(opts ...EnvironmentOption) error { if IsInitialized() { return fmt.Errorf("The onnxruntime has already been initialized") } // Do the windows- or linux- specific initialization first. e := platformInitializeEnvironment() if e != nil { return fmt.Errorf("Platform-specific initialization failed: %w", e) } name := C.CString("Golang onnxruntime environment") defer C.free(unsafe.Pointer(name)) status := C.CreateOrtEnv(name, &ortEnv) if status != nil { return fmt.Errorf("Error creating ORT environment: %w", statusToError(status)) } for _, opt := range opts { if status := opt(ortEnv); status != nil { DestroyEnvironment() return fmt.Errorf("Error applying ORT environment option: %w", statusToError(status)) } } status = C.CreateOrtMemoryInfo(&ortMemoryInfo) if status != nil { DestroyEnvironment() return fmt.Errorf("Error creating ORT memory info: %w", statusToError(status)) } return nil } // Call this function to cleanup the internal onnxruntime environment when it // is no longer needed. func DestroyEnvironment() error { var e error if !IsInitialized() { return NotInitializedError } if ortMemoryInfo != nil { C.ReleaseOrtMemoryInfo(ortMemoryInfo) ortMemoryInfo = nil } if ortEnv != nil { C.ReleaseOrtEnv(ortEnv) ortEnv = nil } // platformCleanup primarily unloads the library, so we need to call it // last, after any functions that make use of the ORT API. e = platformCleanup() if e != nil { return fmt.Errorf("Platform-specific cleanup failed: %w", e) } return nil } // EnvironmentOption is a functional option that can be provided during // initialization of an ORT Environment. type EnvironmentOption func(*C.OrtEnv) *C.OrtStatus // Wraps the OrtLoggingLevel enum type LoggingLevel int const ( LoggingLevelVerbose = C.ORT_LOGGING_LEVEL_VERBOSE LoggingLevelInfo = C.ORT_LOGGING_LEVEL_INFO LoggingLevelWarning = C.ORT_LOGGING_LEVEL_WARNING LoggingLevelError = C.ORT_LOGGING_LEVEL_ERROR LoggingLevelFatal = C.ORT_LOGGING_LEVEL_FATAL ) func (l LoggingLevel) String() string { switch l { case LoggingLevelVerbose: return "ORT_LOGGING_LEVEL_VERBOSE" case LoggingLevelInfo: return "ORT_LOGGING_LEVEL_INFO" case LoggingLevelWarning: return "ORT_LOGGING_LEVEL_WARNING" case LoggingLevelError: return "ORT_LOGGING_LEVEL_ERROR" case LoggingLevelFatal: return "ORT_LOGGING_LEVEL_FATAL" } return fmt.Sprintf("", int(l)) } // WithLogLevelVerbose is an EnvironmentOption that will set the ORT // Environment logging to emit verbose informational messages (least // severe) along with all messages of greater severity. func WithLogLevelVerbose() EnvironmentOption { return func(e *C.OrtEnv) *C.OrtStatus { return C.UpdateEnvWithCustomLogLevel(e, C.OrtLoggingLevel(LoggingLevelVerbose)) } } // WithLogLevelInfo is an EnvironmentOption that will set the ORT Environment // logging to emit informational messages along with all messages of greater // severity. func WithLogLevelInfo() EnvironmentOption { return func(e *C.OrtEnv) *C.OrtStatus { return C.UpdateEnvWithCustomLogLevel(e, C.OrtLoggingLevel(LoggingLevelInfo)) } } // WithLogLevelWarning is an EnvironmentOption that will set the ORT // Environment logging to emit warning messages along with all messages of // greater severity. func WithLogLevelWarning() EnvironmentOption { return func(e *C.OrtEnv) *C.OrtStatus { return C.UpdateEnvWithCustomLogLevel(e, C.OrtLoggingLevel(LoggingLevelWarning)) } } // WithLogLevelError is an EnvironmentOption that will set the ORT // Environment logging to emit error messages along with all messages of // greater severity. This is the default logging level. func WithLogLevelError() EnvironmentOption { return func(e *C.OrtEnv) *C.OrtStatus { return C.UpdateEnvWithCustomLogLevel(e, C.OrtLoggingLevel(LoggingLevelError)) } } // WithLogLevelFatal is an EnvironmentOption that will set the ORT // Environment logging to emit only fatal error messages (most severe). func WithLogLevelFatal() EnvironmentOption { return func(e *C.OrtEnv) *C.OrtStatus { return C.UpdateEnvWithCustomLogLevel(e, C.OrtLoggingLevel(LoggingLevelFatal)) } } // Disables telemetry events for the onnxruntime environment. Must be called // after initializing the environment using InitializeEnvironment(). It is // unclear from the onnxruntime docs whether this will cause an error or // silently return if telemetry is already disabled. func DisableTelemetry() error { if !IsInitialized() { return NotInitializedError } status := C.DisableTelemetry(ortEnv) if status != nil { return fmt.Errorf("Error disabling onnxruntime telemetry: %w", statusToError(status)) } return nil } // Enables telemetry events for the onnxruntime environment. Must be called // after initializing the environment using InitializeEnvironment(). It is // unclear from the onnxruntime docs whether this will cause an error or // silently return if telemetry is already enabled. func EnableTelemetry() error { if !IsInitialized() { return NotInitializedError } status := C.EnableTelemetry(ortEnv) if status != nil { return fmt.Errorf("Error enabling onnxruntime telemetry: %w", statusToError(status)) } return nil } // Sets the environmnent-wide log severity level. The argument must be one of // LoggingLevelVerbose, LoggingLevelInfo, LoggingLevelWarning, // LoggingLevelError, or LoggingLevelFatal. Must only be used after the // environment has been initialized. func SetEnvironmentLogLevel(level LoggingLevel) error { if !IsInitialized() { return NotInitializedError } status := C.UpdateEnvWithCustomLogLevel(ortEnv, C.OrtLoggingLevel(level)) if status != nil { return statusToError(status) } return nil } // The Shape type holds the shape of the tensors used by the network input and // outputs. type Shape []int64 // Returns a Shape, with the given dimensions. func NewShape(dimensions ...int64) Shape { return Shape(dimensions) } // Returns the total number of elements in a tensor with the given shape. Note // that this may be an invalid value due to overflow or negative dimensions. If // a shape comes from an untrusted source, it may be a good practice to call // Validate() prior to trusting the FlattenedSize. func (s Shape) FlattenedSize() int64 { if len(s) == 0 { return 0 } toReturn := s[0] for i := 1; i < len(s); i++ { toReturn *= s[i] } return toReturn } // Returns a non-nil error if the shape has bad or zero dimensions. May return // a ZeroShapeLengthError, a ShapeOverflowError, or a BadShapeDimensionError. // In the future, this may return other types of errors if it others become // necessary. func (s Shape) Validate() error { if len(s) == 0 { return ZeroShapeLengthError } hasZeroDim := false // If any dimension is negative, return an error. Also keep track of // whether any dimension is 0. for i, v := range s { if v < 0 { return &BadShapeDimensionError{ DimensionIndex: i, DimensionSize: v, } } if v == 0 { hasZeroDim = true } } // We don't need to check for overflow if one or more dimension was 0. if hasZeroDim { return nil } // All dimensions are positive and nonzero, so make sure that multiplying // them together won't overflow an int64. flattenedSize := s[0] for i := 1; i < len(s); i++ { tmp := flattenedSize * s[i] if tmp < flattenedSize { return ShapeOverflowError } flattenedSize = tmp } return nil } // Makes and returns a deep copy of the Shape. func (s Shape) Clone() Shape { toReturn := make([]int64, len(s)) copy(toReturn, []int64(s)) return Shape(toReturn) } func (s Shape) String() string { return fmt.Sprintf("%v", []int64(s)) } // Returns true if both shapes match in every dimension. func (s Shape) Equals(other Shape) bool { if len(s) != len(other) { return false } for i := 0; i < len(s); i++ { if s[i] != other[i] { return false } } return true } // This wraps internal implementation details to avoid exposing them to users // via the Value interface. type ValueInternalData struct { ortValue *C.OrtValue } // An interface for managing tensors or other onnxruntime values where we don't // necessarily need to access the underlying data slice. All typed tensors will // support this interface regardless of the underlying data type. type Value interface { DataType() C.ONNXTensorElementDataType GetShape() Shape Destroy() error GetInternals() *ValueInternalData ZeroContents() GetONNXType() ONNXType } // Used to manage all input and output data for onnxruntime networks. A Tensor // always has an associated type and refers to data contained in an underlying // Go slice. New tensors should be created using the NewTensor or // NewEmptyTensor functions, and must be destroyed using the Destroy function // when no longer needed. type Tensor[T TensorData] struct { // The shape of the tensor shape Shape // The go slice containing the flattened data that backs the ONNX tensor. data []T // The number of bytes taken by the data slice. dataSize uintptr // The underlying ONNX value we use with the C API. ortValue *C.OrtValue } // Cleans up and frees the memory associated with this tensor. func (t *Tensor[_]) Destroy() error { C.ReleaseOrtValue(t.ortValue) t.ortValue = nil t.data = nil t.dataSize = 0 t.shape = nil return nil } // Returns the slice containing the tensor's underlying data. The contents of // the slice can be read or written to get or set the tensor's contents. func (t *Tensor[T]) GetData() []T { return t.data } // Returns the value from the ONNXTensorElementDataType C enum corresponding to // the type of data held by this tensor. // // NOTE: This function was added prior to the introduction of the // Go TensorElementDataType int wrapping the C enum, so it still returns the // CGo type. func (t *Tensor[T]) DataType() C.ONNXTensorElementDataType { return GetTensorElementDataType[T]() } // Always returns ONNXTypeTensor for any Tensor[T] even if the underlying // tensor is invalid for some reason. func (t *Tensor[_]) GetONNXType() ONNXType { return ONNXTypeTensor } // Returns the shape of the tensor. The returned shape is only a copy; // modifying this does *not* change the shape of the underlying tensor. // (Modifying the tensor's shape can only be accomplished by Destroying and // recreating the tensor with the same data.) func (t *Tensor[_]) GetShape() Shape { return t.shape.Clone() } func (t *Tensor[_]) GetInternals() *ValueInternalData { return &ValueInternalData{ ortValue: t.ortValue, } } // Sets every element in the tensor's underlying data slice to 0. func (t *Tensor[T]) ZeroContents() { C.memset(unsafe.Pointer(&t.data[0]), 0, C.size_t(t.dataSize)) } // Makes a deep copy of the tensor, including its ONNXRuntime value. The Tensor // returned by this function must be destroyed when no longer needed. The // returned tensor will also no longer refer to the same underlying data; use // GetData() to obtain the new underlying slice. func (t *Tensor[T]) Clone() (*Tensor[T], error) { toReturn, e := NewEmptyTensor[T](t.shape) if e != nil { return nil, fmt.Errorf("Error allocating tensor clone: %w", e) } copy(toReturn.GetData(), t.data) return toReturn, nil } // Creates a new empty tensor with the given shape. The shape provided to this // function is copied, and is no longer needed after this function returns. func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) { e := s.Validate() if e != nil { return nil, fmt.Errorf("Invalid tensor shape: %w", e) } elementCount := s.FlattenedSize() data := make([]T, elementCount) return NewTensor(s, data) } // Creates a new tensor backed by an existing data slice. The shape provided to // this function is copied, and is no longer needed after this function // returns. If the data slice is longer than s.FlattenedSize(), then only the // first portion of the data will be used. func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) { if !IsInitialized() { return nil, NotInitializedError } e := s.Validate() if e != nil { return nil, fmt.Errorf("Invalid tensor shape: %w", e) } elementCount := s.FlattenedSize() if elementCount > int64(len(data)) { return nil, fmt.Errorf("The tensor's shape (%s) requires %d "+ "elements, but only %d were provided", s, elementCount, len(data)) } var ortValue *C.OrtValue dataType := GetTensorElementDataType[T]() var dataPtr unsafe.Pointer var dataSize uintptr if elementCount != 0 { // Only do operations that require accessing the data slice if we know // the Tensor's shape doesn't contain any 0 dimensions dataSize = unsafe.Sizeof(data[0]) * uintptr(elementCount) dataPtr = unsafe.Pointer(&data[0]) } status := C.CreateOrtTensorWithShape(dataPtr, C.size_t(dataSize), (*C.int64_t)(unsafe.Pointer(&s[0])), C.int64_t(len(s)), ortMemoryInfo, dataType, &ortValue) if status != nil { return nil, fmt.Errorf("ORT API error creating tensor: %s", statusToError(status)) } toReturn := Tensor[T]{ data: data[0:elementCount], dataSize: dataSize, shape: s.Clone(), ortValue: ortValue, } return &toReturn, nil } // Wraps the ONNXTEnsorElementDataType enum in C. type TensorElementDataType int const ( TensorElementDataTypeUndefined = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED TensorElementDataTypeFloat = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT TensorElementDataTypeUint8 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 TensorElementDataTypeInt8 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 TensorElementDataTypeUint16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 TensorElementDataTypeInt16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 TensorElementDataTypeInt32 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 TensorElementDataTypeInt64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 TensorElementDataTypeString = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING TensorElementDataTypeBool = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL TensorElementDataTypeFloat16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 TensorElementDataTypeDouble = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE TensorElementDataTypeUint32 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 TensorElementDataTypeUint64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 // Not supported by onnxruntime (as of onnxruntime version 1.22.0) TensorElementDataTypeComplex64 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 // Not supported by onnxruntime (as of onnxruntime version 1.22.0) TensorElementDataTypeComplex128 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 // Non-IEEE floating-point format based on IEEE754 single-precision TensorElementDataTypeBFloat16 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // 8-bit float types, introduced in onnx 1.14. See // https://onnx.ai/onnx/technical/float8.html TensorElementDataTypeFloat8E4M3FN = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN TensorElementDataTypeFloat8E4M3FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ TensorElementDataTypeFloat8E5M2 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 TensorElementDataTypeFloat8E5M2FNUZ = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ // Int4 types were introduced in ONNX 1.16. See // https://onnx.ai/onnx/technical/int4.html TensorElementDataTypeUint4 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4 TensorElementDataTypeInt4 = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 ) func (t TensorElementDataType) String() string { switch t { case TensorElementDataTypeUndefined: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED" case TensorElementDataTypeFloat: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT" case TensorElementDataTypeUint8: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8" case TensorElementDataTypeInt8: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8" case TensorElementDataTypeUint16: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16" case TensorElementDataTypeInt16: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16" case TensorElementDataTypeInt32: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32" case TensorElementDataTypeInt64: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64" case TensorElementDataTypeString: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING" case TensorElementDataTypeBool: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL" case TensorElementDataTypeFloat16: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16" case TensorElementDataTypeDouble: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE" case TensorElementDataTypeUint32: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32" case TensorElementDataTypeUint64: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64" case TensorElementDataTypeComplex64: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64" case TensorElementDataTypeComplex128: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128" case TensorElementDataTypeBFloat16: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16" case TensorElementDataTypeFloat8E4M3FN: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN" case TensorElementDataTypeFloat8E4M3FNUZ: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ" case TensorElementDataTypeFloat8E5M2: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2" case TensorElementDataTypeFloat8E5M2FNUZ: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ" case TensorElementDataTypeUint4: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4" case TensorElementDataTypeInt4: return "ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4" } return fmt.Sprintf("Unknown tensor element data type: %d", int(t)) } // Wraps the GraphOptimizationLevel enum in C. type GraphOptimizationLevel int const ( GraphOptimizationLevelDisableAll = C.ORT_DISABLE_ALL GraphOptimizationLevelEnableBasic = C.ORT_ENABLE_BASIC GraphOptimizationLevelEnableExtended = C.ORT_ENABLE_EXTENDED GraphOptimizationLevelEnableAll = C.ORT_ENABLE_ALL ) func (l GraphOptimizationLevel) String() string { switch l { case GraphOptimizationLevelDisableAll: return "ORT_DISABLE_ALL" case GraphOptimizationLevelEnableBasic: return "ORT_ENABLE_BASIC" case GraphOptimizationLevelEnableExtended: return "ORT_ENABLE_EXTENDED" case GraphOptimizationLevelEnableAll: return "ORT_ENABLE_ALL" } return fmt.Sprintf("", int(l)) } // This wraps an ONNX_TYPE_SEQUENCE OrtValue. Satisfies the Value interface, // though Tensor-related functions such as ZeroContents() may be no-ops. type Sequence struct { ortValue *C.OrtValue // We'll stash the values in the sequence here, so we don't need to look // them up, and so that users don't need to remember to free them. contents []Value } // Returns the value at the given index in the sequence or map. (In a map, // index 0 is for keys, and 1 is for values.) Used internally when initializing // a go Sequence or Map object. func getSequenceOrMapValue(sequenceOrMap *C.OrtValue, index int64) (Value, error) { var result *C.OrtValue status := C.GetValue(sequenceOrMap, C.int(index), &result) if status != nil { return nil, fmt.Errorf("Error getting value of index %d: %s", index, statusToError(status)) } return createGoValueFromOrtValue(result) } // Creates a new ONNX sequence with the given contents. The returned Sequence // must be Destroyed by the caller when no longer needed. Destroying the // Sequence created by this function does _not_ destroy the Values it was // created with, so the caller is still responsible for destroying them // as well. // // The contents of a sequence are subject to additional constraints. I can't // find mention of some of these in the C API docs, but they are enforced by // the onnxruntime API. Notably: all elements of the sequence must have the // same type, and all elements must be either maps or tensors. Finally, the // sequence must contain at least one element, and none of the elements may be // nil. There may be other constraints that I am unaware of, as well. func NewSequence(contents []Value) (*Sequence, error) { if !IsInitialized() { return nil, NotInitializedError } length := int64(len(contents)) if length == 0 { return nil, fmt.Errorf("Sequences must contain at least 1 element") } ortValues := make([]*C.OrtValue, length) for i, v := range contents { if v == nil { return nil, fmt.Errorf("Sequences must not contain nil (index "+ "%d was nil)", i) } ortValues[i] = v.GetInternals().ortValue } var sequence *C.OrtValue status := C.CreateOrtValue(&(ortValues[0]), C.size_t(length), C.ONNX_TYPE_SEQUENCE, &sequence) if status != nil { return nil, fmt.Errorf("Error creating ORT sequence: %s", statusToError(status)) } // Finally, we want to get each OrtValue from the sequence itself, but we // already have a function to do this in the case of onnxruntime-allocated // sequences. toReturn, e := createSequenceFromOrtValue(sequence) if e != nil { // createSequenceFromOrtValue destroys the sequence on error. return nil, fmt.Errorf("Error creating go Sequence from sequence "+ "OrtValue: %w", e) } return toReturn, nil } // Returns the list of values in the sequence. Each of these values should // _not_ be Destroy()'ed by the caller, they will be automatically destroyed // upon calling Destroy() on the sequence. If this sequence was created via // NewSequence, these are not the same Values that the sequence was created // with, though if they are tensors they should still refer to the same // underlying data. func (s *Sequence) GetValues() ([]Value, error) { return s.contents, nil } func (s *Sequence) Destroy() error { C.ReleaseOrtValue(s.ortValue) var e error for _, v := range s.contents { if v != nil { // Just return the last error if any of these returns an error. e2 := v.Destroy() if e2 != nil { e = e2 } } } s.ortValue = nil s.contents = nil return e } // This returns a 1-dimensional Shape containing a single element: the number // of elements the sequence. Typically, Sequence users should prefer calling // len(s.GetValues()) over this function. This function only exists to maintain // compatibility with the Value interface. func (s *Sequence) GetShape() Shape { return NewShape(int64(len(s.contents))) } // Always returns ONNXTypeSequence func (s *Sequence) GetONNXType() ONNXType { return ONNXTypeSequence } // This function is meaningless for a Sequence and shouldn't be used. The // return value is always TENSOR_ELEMENT_DATA_TYPE_UNDEFINED for now, but this // may change in the future. This function is only present for compatibility // with the Value interface and should not be relied on for sequences. func (s *Sequence) DataType() C.ONNXTensorElementDataType { return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED } // This function does nothing for a Sequence, and is only present for // compatibility with the Value interface. func (s *Sequence) ZeroContents() { } func (s *Sequence) GetInternals() *ValueInternalData { return &ValueInternalData{ ortValue: s.ortValue, } } // This wraps an ONNX_TYPE_MAP OrtValue. Satisfies the Value interface, // though Tensor-related functions such as ZeroContents() may be no-ops. type Map struct { ortValue *C.OrtValue // An onnxruntime map is really just two tensors, keys and values, that // must be the same length. These Values will be cleaned up when calling // Map.Destroy. keys Value values Value } // Creates a new ONNX map that maps the given keys tensor to the given values // tensor. Destroying the Map created by this function does _not_ destroy these // keys and values tensors; the caller is still responsible for destroying // them. // // Internally, creating a Map requires two tensors of the same length, and // with constraints on type. For example, keys are not allowed to be floats // (at least currently). (At the time of writing, this has only been confirmed // to work with int64 keys.) There may be many other constraints enforced by // the underlying C API. func NewMap(keys, values Value) (*Map, error) { if !IsInitialized() { return nil, NotInitializedError } newMapArgs := []*C.OrtValue{ keys.GetInternals().ortValue, values.GetInternals().ortValue, } var result *C.OrtValue status := C.CreateOrtValue(&(newMapArgs[0]), 2, C.ONNX_TYPE_MAP, &result) if status != nil { return nil, fmt.Errorf("Error creating ORT map: %s", statusToError(status)) } // We need to obtain internal references to the keys and values allocated // by onnxruntime. createMapFromOrtValue does this for us. toReturn, e := createMapFromOrtValue(result) if e != nil { // createMapFromOrtValue already destroys the OrtValue on error. return nil, fmt.Errorf("Error creating Map instance from map "+ "OrtValue: %w", e) } return toReturn, nil } // Wraps the creation of an ONNX map from a Go map. K is the key type, and V is // the value type. Be aware that constraints on these types exist based on // what ONNX supports. See the comment on NewMap. func NewMapFromGoMap[K, V TensorData](m map[K]V) (*Map, error) { if !IsInitialized() { return nil, NotInitializedError } keysSlice := make([]K, len(m)) valuesSlice := make([]V, len(m)) i := 0 for k, v := range m { keysSlice[i] = k valuesSlice[i] = v i++ } tensorShape := NewShape(int64(len(m))) keysTensor, e := NewTensor(tensorShape, keysSlice) if e != nil { return nil, fmt.Errorf("Error creating keys tensor for map: %w", e) } defer keysTensor.Destroy() valuesTensor, e := NewTensor(tensorShape, valuesSlice) if e != nil { return nil, fmt.Errorf("Error creating values tensor for map: %w", e) } defer valuesTensor.Destroy() toReturn, e := NewMap(keysTensor, valuesTensor) if e != nil { return nil, fmt.Errorf("Error creating map from key and value "+ "tensors: %w", e) } return toReturn, nil } // Returns two Tensors containing the keys and values, respectively. These // tensors should _not_ be Destroyed by users; they will be automatically // cleaned up when m.Destroy() is called. These are _not_ the same Value // instances that were passed to NewMap, and these should not be modified by // users. func (m *Map) GetKeysAndValues() (Value, Value, error) { return m.keys, m.values, nil } func (m *Map) Destroy() error { C.ReleaseOrtValue(m.ortValue) // Just return the last error if either of these returns an error. var e error e2 := m.keys.Destroy() if e2 != nil { e = e2 } e2 = m.values.Destroy() if e2 != nil { e = e2 } m.ortValue = nil m.keys = nil m.values = nil return e } // Always returns ONNXTypeMap func (m *Map) GetONNXType() ONNXType { return ONNXTypeMap } // Returns the shape of the map's keys Tensor. Essentially, this can be used // to determine the number of key/value pairs in the map. func (m *Map) GetShape() Shape { return m.keys.GetShape() } func (m *Map) GetInternals() *ValueInternalData { return &ValueInternalData{ ortValue: m.ortValue, } } // As with Sequence.ZeroContents(), this is a no-op (at least for now), and is // only present for compatibility with the Value interface. func (m *Map) ZeroContents() { } // As with a Sequence, this always returns the undefined data type and is only // present for compatibility with the Value interface. func (m *Map) DataType() C.ONNXTensorElementDataType { return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED } // Wraps the ONNXType enum in C. type ONNXType int const ( ONNXTypeUnknown = C.ONNX_TYPE_UNKNOWN ONNXTypeTensor = C.ONNX_TYPE_TENSOR ONNXTypeSequence = C.ONNX_TYPE_SEQUENCE ONNXTypeMap = C.ONNX_TYPE_MAP ONNXTypeOpaque = C.ONNX_TYPE_OPAQUE ONNXTypeSparseTensor = C.ONNX_TYPE_SPARSETENSOR ONNXTypeOptional = C.ONNX_TYPE_OPTIONAL ) func (t ONNXType) String() string { switch t { case ONNXTypeUnknown: return "ONNX_TYPE_UNKNOWN" case ONNXTypeTensor: return "ONNX_TYPE_TENSOR" case ONNXTypeSequence: return "ONNX_TYPE_SEQUENCE" case ONNXTypeMap: return "ONNX_TYPE_MAP" case ONNXTypeOpaque: return "ONNX_TYPE_OPAQUE" case ONNXTypeSparseTensor: return "ONNX_TYPE_SPARSE_TENSOR" case ONNXTypeOptional: return "ONNX_TYPE_OPTIONAL" } return fmt.Sprintf("Unknown ONNX type: %d", int(t)) } // This satisfies the Value interface, but is intended to allow users to // provide tensors of types that may not be supported by the generic typed // Tensor[T] struct. Instead, CustomDataTensors are backed by a slice of bytes, // using a user-provided shape and type from the ONNXTensorElementDataType // enum. type CustomDataTensor struct { data []byte dataType C.ONNXTensorElementDataType shape Shape ortValue *C.OrtValue } // Creates and returns a new CustomDataTensor using the given bytes as the // underlying data slice. Apart from ensuring that the provided data slice is // non-empty, this function mostly delegates validation of the provided data to // the C onnxruntime library. For example, it is the caller's responsibility to // ensure that the provided dataType and data slice are valid and correctly // sized for the specified shape. If this returns successfully, the caller must // call the returned tensor's Destroy() function to free it when no longer in // use. func NewCustomDataTensor(s Shape, data []byte, dataType TensorElementDataType) (*CustomDataTensor, error) { if !IsInitialized() { return nil, NotInitializedError } e := s.Validate() if e != nil { return nil, fmt.Errorf("Invalid tensor shape: %w", e) } if len(data) == 0 { return nil, fmt.Errorf("A CustomDataTensor requires at least one " + "byte of data") } dt := C.ONNXTensorElementDataType(dataType) var ortValue *C.OrtValue status := C.CreateOrtTensorWithShape(unsafe.Pointer(&data[0]), C.size_t(len(data)), (*C.int64_t)(unsafe.Pointer(&s[0])), C.int64_t(len(s)), ortMemoryInfo, dt, &ortValue) if status != nil { return nil, fmt.Errorf("ORT API error creating tensor: %s", statusToError(status)) } toReturn := CustomDataTensor{ data: data, dataType: dt, shape: s.Clone(), ortValue: ortValue, } return &toReturn, nil } func (t *CustomDataTensor) Destroy() error { C.ReleaseOrtValue(t.ortValue) t.ortValue = nil t.data = nil t.shape = nil t.dataType = C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED return nil } func (t *CustomDataTensor) DataType() C.ONNXTensorElementDataType { return t.dataType } func (t *CustomDataTensor) GetShape() Shape { return t.shape.Clone() } func (t *CustomDataTensor) GetInternals() *ValueInternalData { return &ValueInternalData{ ortValue: t.ortValue, } } // Always returns ONNXTypeTensor, even if the CustomDataTensor is invalid for // some reason. func (t *CustomDataTensor) GetONNXType() ONNXType { return ONNXTypeTensor } // Sets all bytes in the data slice to 0. func (t *CustomDataTensor) ZeroContents() { if len(t.data) == 0 { return } C.memset(unsafe.Pointer(&t.data[0]), 0, C.size_t(len(t.data))) } // Returns the same slice that was passed to NewCustomDataTensor. func (t *CustomDataTensor) GetData() []byte { return t.data } // Scalar is like a tensor but the underlying go slice is of length 1 and it // has no dimension. It was introduced for use with the training API, but // remains supported since it may be useful apart from the training API. type Scalar[T TensorData] struct { data []T dataSize uintptr ortValue *C.OrtValue } // Always returns nil for Scalars. func (s *Scalar[T]) GetShape() Shape { return nil } func (s *Scalar[T]) ZeroContents() { C.memset(unsafe.Pointer(&s.data[0]), 0, C.size_t(s.dataSize)) } func (s *Scalar[T]) Destroy() error { C.ReleaseOrtValue(s.ortValue) s.ortValue = nil s.data = nil s.dataSize = 0 return nil } // GetData returns the undelying data for the scalar. If you want to set the // scalar's data, use Set. func (t *Scalar[T]) GetData() T { return t.data[0] } // Changes the underlying value of the scalar to the new value. func (t *Scalar[T]) Set(value T) { t.data = []T{value} } func (t *Scalar[T]) DataType() C.ONNXTensorElementDataType { return GetTensorElementDataType[T]() } func (t *Scalar[_]) GetInternals() *ValueInternalData { return &ValueInternalData{ ortValue: t.ortValue, } } func (t *Scalar[_]) GetONNXType() ONNXType { return ONNXTypeTensor } // NewEmptyScalar creates a new scalar of type T. func NewEmptyScalar[T TensorData]() (*Scalar[T], error) { var data T return NewScalar(data) } // NewScalar creates a new scalar of type T backed by a value of type T. // Note that, differently from tensors, this is not a []T but just a value T. func NewScalar[T TensorData](data T) (*Scalar[T], error) { if !IsInitialized() { return nil, NotInitializedError } dataSlice := []T{data} var ortValue *C.OrtValue dataType := GetTensorElementDataType[T]() dataSize := unsafe.Sizeof(dataSlice[0]) * uintptr(1) status := C.CreateOrtTensorWithShape(unsafe.Pointer(&dataSlice[0]), C.size_t(dataSize), nil, C.int64_t(0), ortMemoryInfo, dataType, &ortValue) if status != nil { return nil, statusToError(status) } toReturn := Scalar[T]{ data: dataSlice, dataSize: dataSize, ortValue: ortValue, } return &toReturn, nil } // Holds options required when enabling the CUDA backend for a session. This // struct wraps C onnxruntime types; users must create instances of this using // the NewCUDAProviderOptions() function. So, to enable CUDA for a session, // follow these steps: // // 1. Call NewSessionOptions() to create a SessionOptions struct. // 2. Call NewCUDAProviderOptions() to obtain a CUDAProviderOptions struct. // 3. Call the CUDAProviderOptions struct's Update(...) function to pass a // list of settings to CUDA. (See the comment on the Update() function.) // 4. Pass the CUDA options struct pointer to the // SessionOptions.AppendExecutionProviderCUDA(...) function. // 5. Call the Destroy() function on the CUDA provider options. // 6. Call NewAdvancedSession(...), passing the SessionOptions struct to it. // 7. Call the Destroy() function on the SessionOptions struct. // // Admittedly, this is a bit of a mess, but that's how it's handled by the C // API internally. (The onnxruntime python API hides a bunch of this complexity // using getter and setter functions, for which Go does not have a terse // equivalent.) type CUDAProviderOptions struct { o *C.OrtCUDAProviderOptionsV2 } // Used when setting key-value pair options with certain obnoxious C APIs. // The entries in each of the returned slices must be freed when they're // no longer needed. func mapToCStrings(options map[string]string) ([]*C.char, []*C.char) { keys := make([]*C.char, 0, len(options)) values := make([]*C.char, 0, len(options)) for k, v := range options { keys = append(keys, C.CString(k)) values = append(values, C.CString(v)) } return keys, values } // Calls free on each entry in the array of C strings. func freeCStrings(s []*C.char) { for i := range s { C.free(unsafe.Pointer(s[i])) s[i] = nil } } // Wraps the call to the UpdateCUDAProviderOptions in the onnxruntime C API. // Requires a map of string keys to values for configuring the CUDA backend. // For example, set the key "device_id" to "1" to use GPU 1 rather than 0. // // The onnxruntime headers refer users to // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options // for a full list of available keys and values. func (o *CUDAProviderOptions) Update(options map[string]string) error { if len(options) == 0 { return nil } keys, values := mapToCStrings(options) defer freeCStrings(keys) defer freeCStrings(values) status := C.UpdateCUDAProviderOptions(o.o, &(keys[0]), &(values[0]), C.int(len(options))) if status != nil { return statusToError(status) } return nil } // Must be called when the CUDAProviderOptions struct is no longer needed; // frees internal C-allocated state. Note that the CUDAProviderOptions struct // can be destroyed as soon as options.AppendExecutionProviderCUDA has been // called. func (o *CUDAProviderOptions) Destroy() error { if o.o == nil { return fmt.Errorf("The CUDAProviderOptions are not initialized") } C.ReleaseCUDAProviderOptions(o.o) o.o = nil return nil } // Initializes and returns a CUDAProviderOptions struct, used when enabling // CUDA in a SessionOptions instance. (i.e., a CUDAProviderOptions must be // configured, then passed to SessionOptions.AppendExecutionProviderCUDA.) // The caller must call the Destroy() function on the returned struct when it's // no longer needed. func NewCUDAProviderOptions() (*CUDAProviderOptions, error) { if !IsInitialized() { return nil, NotInitializedError } var o *C.OrtCUDAProviderOptionsV2 status := C.CreateCUDAProviderOptions(&o) if status != nil { return nil, statusToError(status) } return &CUDAProviderOptions{ o: o, }, nil } // Like the CUDAProviderOptions struct, but used for configuring TensorRT // options. Instances of this struct must be initialized using // NewTensorRTProviderOptions() and cleaned up by calling their Destroy() // function when they are no longer needed. type TensorRTProviderOptions struct { o *C.OrtTensorRTProviderOptionsV2 } // Wraps the call to the UpdateTensorRTProviderOptions in the C API. Requires // a map of string keys to values. // // The onnxruntime headers refer users to // https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc // for the list of available keys and values. func (o *TensorRTProviderOptions) Update(options map[string]string) error { if len(options) == 0 { return nil } keys, values := mapToCStrings(options) defer freeCStrings(keys) defer freeCStrings(values) status := C.UpdateTensorRTProviderOptions(o.o, &(keys[0]), &(values[0]), C.int(len(options))) if status != nil { return statusToError(status) } return nil } // Must be called when the TensorRTProviderOptions are no longer needed, in // order to free internal state. The struct is not needed as soon as you have // passed it to the AppendExecutionProviderTensorRT function. func (o *TensorRTProviderOptions) Destroy() error { if o.o == nil { return fmt.Errorf("The TensorRTProviderOptions are not initialized") } C.ReleaseTensorRTProviderOptions(o.o) o.o = nil return nil } // Initializes and returns a TensorRTProviderOptions struct, used when enabling // the TensorRT backend. The caller must call the Destroy() function on the // returned struct when it's no longer needed. func NewTensorRTProviderOptions() (*TensorRTProviderOptions, error) { if !IsInitialized() { return nil, NotInitializedError } var o *C.OrtTensorRTProviderOptionsV2 status := C.CreateTensorRTProviderOptions(&o) if status != nil { return nil, statusToError(status) } return &TensorRTProviderOptions{ o: o, }, nil } // Wraps the ExecutionMode enum in C. type ExecutionMode int const ( ExecutionModeSequential = C.ORT_SEQUENTIAL ExecutionModeParallel = C.ORT_PARALLEL ) func (m ExecutionMode) String() string { switch m { case ExecutionModeSequential: return "ORT_SEQUENTIAL" case ExecutionModeParallel: return "ORT_PARALLEL" } return fmt.Sprintf("Invalid/unknown execution mode: %d", int(m)) } // Used to set options when creating an ONNXRuntime session. There is currently // not a way to change options after the session is created, apart from // destroying the session and creating a new one. This struct opaquely wraps a // C OrtSessionOptions struct, which users must modify via function calls. (The // OrtSessionOptions struct is opaque in the C API, too.) // // Users must instantiate this struct using the NewSessionOptions function. // Instances must be destroyed by calling the Destroy() method after the // options are no longer needed (after NewAdvancedSession(...) has returned). type SessionOptions struct { o *C.OrtSessionOptions } func (o *SessionOptions) Destroy() error { if o.o == nil { return fmt.Errorf("The SessionOptions are not initialized") } C.ReleaseSessionOptions(o.o) o.o = nil return nil } // Sets the session's execution mode. The newMode must be // ExecutionModeSequential or ExecutionModeParallel. func (o *SessionOptions) SetExecutionMode(newMode ExecutionMode) error { status := C.SetSessionExecutionMode(o.o, C.int(newMode)) if status != nil { return statusToError(status) } return nil } // Sets the optimization level to apply when loading a graph. Refer to // the C API documentation for SetSessionGraphOptimizationLevel. func (o *SessionOptions) SetGraphOptimizationLevel( level GraphOptimizationLevel) error { status := C.SetSessionGraphOptimizationLevel(o.o, C.int(level)) if status != nil { return statusToError(status) } return nil } // Sets the sessions log severity level. Must be one of LoggingLevelVerbose, // LoggingLevelInfo, LoggingLevelWarning, LoggingLevelError, or // LoggingLevelFatal. func (o *SessionOptions) SetLogSeverityLevel(level LoggingLevel) error { status := C.SetSessionLogSeverityLevel(o.o, C.int(level)) if status != nil { return statusToError(status) } return nil } // Returns true, nil if the SessionOptions has a configuration entry with the // given key. Returns false if the key isn't defined. Returns an error if // onnxruntime indicates an error, though it isn't clear from the docs what // may cause an error to occur here. See also GetSessionConfigEntry and // AddSessionConfigEntry. func (o *SessionOptions) HasSessionConfigEntry(key string) (bool, error) { cKey := C.CString(key) defer C.free(unsafe.Pointer(cKey)) var out C.int status := C.HasSessionConfigEntry(o.o, cKey, &out) if status != nil { return false, statusToError(status) } return out != 0, nil } // Returns the session config entry corresponding to the given key, or an error // if one occurs. Returns an error if the key doesn't exist, so it may be // cheaper to check for the key with HasSessionConfigEntry first. See also // AddSessionConfigEntry. func (o *SessionOptions) GetSessionConfigEntry(key string) (string, error) { var neededSize C.size_t cKey := C.CString(key) defer C.free(unsafe.Pointer(cKey)) // First check for the size of the key. (This includes the null terminator) status := C.GetSessionConfigEntry(o.o, cKey, nil, &neededSize) if status != nil { return "", fmt.Errorf("Error determining size of key %s: %w", key, statusToError(status)) } // Should catch the case if the entry was an empty string. if neededSize <= 1 { return "", nil } // We'll allocate the buffer to hold the string in Go to keep the C simpler resultBuffer := make([]C.char, neededSize) status = C.GetSessionConfigEntry(o.o, cKey, &(resultBuffer[0]), &neededSize) if status != nil { return "", fmt.Errorf("Error getting contents of key %s: %w", key, statusToError(status)) } toReturn := C.GoString(&(resultBuffer[0])) return toReturn, nil } // Sets a session configuration key to the given value. See the // onnxruntime_session_options_config_keys.h file in the onnxruntime sources // for documentation on valid keys and values. If the key was already set, this // will overwrite its old setting with the given value. func (o *SessionOptions) AddSessionConfigEntry(key, value string) error { cKey := C.CString(key) defer C.free(unsafe.Pointer(cKey)) cValue := C.CString(value) defer C.free(unsafe.Pointer(cValue)) status := C.AddSessionConfigEntry(o.o, cKey, cValue) if status != nil { return statusToError(status) } return nil } // Sets the number of threads used to parallelize execution within onnxruntime // graph nodes. A value of 0 uses the default number of threads. func (o *SessionOptions) SetIntraOpNumThreads(n int) error { if n < 0 { return fmt.Errorf("Number of threads must be at least 0, got %d", n) } status := C.SetIntraOpNumThreads(o.o, C.int(n)) if status != nil { return statusToError(status) } return nil } // Sets the number of threads used to parallelize execution across separate // onnxruntime graph nodes. A value of 0 uses the default number of threads. func (o *SessionOptions) SetInterOpNumThreads(n int) error { if n < 0 { return fmt.Errorf("Number of threads must be at least 0, got %d", n) } status := C.SetInterOpNumThreads(o.o, C.int(n)) if status != nil { return statusToError(status) } return nil } // Enable/Disable the usage of the memory arena on CPU. // Arena may pre-allocate memory for future usage. func (o *SessionOptions) SetCpuMemArena(isEnabled bool) error { n := 0 if isEnabled { n = 1 } status := C.SetCpuMemArena(o.o, C.int(n)) if status != nil { return statusToError(status) } return nil } // Enable/Disable the memory pattern optimization. // If this is enabled memory is preallocated if all shapes are known. func (o *SessionOptions) SetMemPattern(isEnabled bool) error { n := 0 if isEnabled { n = 1 } status := C.SetMemPattern(o.o, C.int(n)) if status != nil { return statusToError(status) } return nil } // Takes a pointer to an initialized CUDAProviderOptions instance, and applies // them to the session options. This is what you'll need to call if you want // the session to use CUDA. Returns an error if your device (or onnxruntime // library) does not support CUDA. The CUDAProviderOptions struct can be // destroyed after this. func (o *SessionOptions) AppendExecutionProviderCUDA( cudaOptions *CUDAProviderOptions) error { status := C.AppendExecutionProviderCUDAV2(o.o, cudaOptions.o) if status != nil { return statusToError(status) } return nil } // Takes an initialized TensorRTProviderOptions instance, and applies them to // the session options. You'll need to call this if you want the session to use // TensorRT. Returns an error if your device (or onnxruntime library version) // does not support TensorRT. The TensorRTProviderOptions can be destroyed // after this. func (o *SessionOptions) AppendExecutionProviderTensorRT( tensorRTOptions *TensorRTProviderOptions) error { status := C.AppendExecutionProviderTensorRTV2(o.o, tensorRTOptions.o) if status != nil { return statusToError(status) } return nil } // Enables the CoreML backend for the given session options on supported // platforms. // The meanings of the flag bits are currently defined in the // coreml_provider_factory.h file which is provided in the include/ directory of // the onnxruntime releases for Apple platforms. // AppendExecutionProviderCoreML is now deprecated. Please use AppendExecutionProviderCoreMLV2 instead. // See: https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html func (o *SessionOptions) AppendExecutionProviderCoreML(flags uint32) error { status := C.AppendExecutionProviderCoreML(o.o, C.uint32_t(flags)) if status != nil { return statusToError(status) } return nil } // AppendExecutionProviderCoreMLV2 is the new API for adding CoreML provider to ONNX Runtime. // This is the recommended way to add CoreML provider as of ONNX Runtime 1.20.0. // // For CoreML options, see: // https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html func (o *SessionOptions) AppendExecutionProviderCoreMLV2(options map[string]string) error { // Handle case with no options if len(options) == 0 { status := C.AppendExecutionProviderCoreMLV2(o.o, nil, nil, 0) if status != nil { return statusToError(status) } return nil } // Convert map to arrays of keys and values keys, values := mapToCStrings(options) defer freeCStrings(keys) defer freeCStrings(values) // Call C function status := C.AppendExecutionProviderCoreMLV2(o.o, &keys[0], &values[0], C.size_t(len(options))) if status != nil { return statusToError(status) } return nil } // Enables the DirectML backend for the given session options on supported // platforms. See the notes on device_id in coreml_provider_factory.h in the // onnxruntime source code, but a device ID of 0 should correspond to the // default device, "which is typically the primary display GPU" according to // the docs. func (o *SessionOptions) AppendExecutionProviderDirectML(deviceID int) error { status := C.AppendExecutionProviderDirectML(o.o, C.int(deviceID)) if status != nil { return statusToError(status) } return nil } // Enables the OpenVINO backend for the given session options on supported // platforms. See // https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#summary-of-options // for a list of supported keys and values that can be passed in the options // map. func (o *SessionOptions) AppendExecutionProviderOpenVINO( options map[string]string) error { // There's probably a more concise way to do this, but we don't want to // do "&(keys[0])" if keys is an empty slice, so we'll declare the null // ptrs ahead of time and only set them if we know the slices aren't empty. var keysPtr, valuesPtr **C.char if len(options) != 0 { keys, values := mapToCStrings(options) defer freeCStrings(keys) defer freeCStrings(values) keysPtr = &(keys[0]) valuesPtr = &(values[0]) } status := C.AppendExecutionProviderOpenVINOV2(o.o, keysPtr, valuesPtr, C.int(len(options))) if status != nil { return statusToError(status) } return nil } // Initializes and returns a SessionOptions struct, used when setting options // in new AdvancedSession instances. The caller must call the Destroy() // function on the returned struct when it's no longer needed. func NewSessionOptions() (*SessionOptions, error) { if !IsInitialized() { return nil, NotInitializedError } var o *C.OrtSessionOptions status := C.CreateSessionOptions(&o) if status != nil { return nil, statusToError(status) } return &SessionOptions{ o: o, }, nil } // A wrapper around the OrtModelMetadata C struct. Must be freed by calling // Destroy() on it when it's no longer needed. type ModelMetadata struct { m *C.OrtModelMetadata } // Frees internal state required by the model metadata. Users are responsible // for calling this on any ModelMetadata instance after it's no longer needed. func (m *ModelMetadata) Destroy() error { if m.m != nil { C.ReleaseModelMetadata(m.m) m.m = nil } return nil } // Takes a C string allocated using the default ORT allocator, converts it to // a Go string, and frees the C copy. Returns an error if one occurs. Returns // an empty string with no error if s is nil. Obviously, s is invalid after // this returns. func convertORTString(s *C.char) (string, error) { if s == nil { return "", nil } // Unfortunately, onnxruntime wants to use custom allocators to allocate // data such as strings, which are rather obtuse to customize. Therefore, // our C code always specifies the default ORT allocator when possible. We // move any strings ORT allocates into Go strings so we can free the C // versions as soon as possible. toReturn := C.GoString(s) status := C.FreeWithDefaultORTAllocator(unsafe.Pointer(s)) if status != nil { return toReturn, statusToError(status) } return toReturn, nil } // Returns the producer name associated with the model metadata, or an error if // the name can't be obtained. func (m *ModelMetadata) GetProducerName() (string, error) { var cName *C.char status := C.ModelMetadataGetProducerName(m.m, &cName) if status != nil { return "", statusToError(status) } return convertORTString(cName) } // Returns the graph name associated with the model metadata, or an error if // the name can't be obtained. func (m *ModelMetadata) GetGraphName() (string, error) { var cName *C.char status := C.ModelMetadataGetGraphName(m.m, &cName) if status != nil { return "", statusToError(status) } return convertORTString(cName) } // Returns the domain associated with the model metadata, or an error if the // domain can't be obtained. func (m *ModelMetadata) GetDomain() (string, error) { var cDomain *C.char status := C.ModelMetadataGetDomain(m.m, &cDomain) if status != nil { return "", statusToError(status) } return convertORTString(cDomain) } // Returns the description associated with the model metadata, or an error if // the description can't be obtained. func (m *ModelMetadata) GetDescription() (string, error) { var cDescription *C.char status := C.ModelMetadataGetDescription(m.m, &cDescription) if status != nil { return "", statusToError(status) } return convertORTString(cDescription) } // Returns the version number in the model metadata, or an error if one occurs. func (m *ModelMetadata) GetVersion() (int64, error) { var version C.int64_t status := C.ModelMetadataGetVersion(m.m, &version) if status != nil { return 0, statusToError(status) } return int64(version), nil } // Looks up and returns the string associated with the given key in the custom // metadata map. Returns a blank string and 'false' if the key isn't in the // map. (A key that's in the map but set to a blank string will // return "" and true instead.) // // NOTE: It is unclear from the onnxruntime documentation for this function // whether an error will be returned if the key isn't present. At the time of // writing (1.17.1) the docs only state that no value is returned, not whether // an error occurs. func (m *ModelMetadata) LookupCustomMetadataMap(key string) (string, bool, error) { var cValue *C.char cKey := C.CString(key) defer C.free(unsafe.Pointer(cKey)) status := C.ModelMetadataLookupCustomMetadataMap(m.m, cKey, &cValue) if status != nil { return "", false, statusToError(status) } if cValue == nil { return "", false, nil } value, e := convertORTString(cValue) return value, true, e } // Returns a list of keys that are present in the custom metadata map. Returns // an empty slice or nil if no keys are in the map. // // NOTE: It is unclear from the docs whether an empty custom metadata map will // cause the underlying C function to return an error along with a NULL list, // or whether it will only return a NULL list with no error. func (m *ModelMetadata) GetCustomMetadataMapKeys() ([]string, error) { var keyCount C.int64_t var cKeys **C.char status := C.ModelMetadataGetCustomMetadataMapKeys(m.m, &cKeys, &keyCount) if status != nil { return nil, statusToError(status) } if cKeys == nil { // We got no keys in the map and no error return return nil, nil } if keyCount == 0 { // We have a non-NULL but empty list of C pointers, so we'll still // free it here. status := C.FreeWithDefaultORTAllocator(unsafe.Pointer(cKeys)) if status != nil { return nil, statusToError(status) } return nil, nil } // The slice allows us to index into the array of C-string pointers. cKeySlice := unsafe.Slice(cKeys, int64(keyCount)) toReturn := make([]string, len(cKeySlice)) var e error for i, s := range cKeySlice { // We won't check for errors until after the loop, because we want to // continue trying to free all of the strings regardless of whether an // error occurs for one of them. toReturn[i], e = convertORTString(s) cKeySlice[i] = nil } // At this point, we've done our best to convert and free all of the ORT- // allocated C strings, but we still need to free the array itself, which // we attempt regardless of whether an error occurred during the string // processing. status = C.FreeWithDefaultORTAllocator(unsafe.Pointer(cKeys)) cKeySlice = nil cKeys = nil if e != nil { return nil, fmt.Errorf("Error copying one or more C strings to Go: %w", e) } if status != nil { return nil, fmt.Errorf("Error freeing array of C strings: %w", statusToError(status)) } return toReturn, nil } // A wrapper around the OrtSession C struct. Requires the user to maintain all // input and output tensors, and to use the same data type for input and output // tensors. Created using NewAdvancedSession(...) or // NewAdvancedSessionWithONNXData(...). The caller is responsible for calling // the Destroy() function on each session when it is no longer needed. type AdvancedSession struct { ortSession *C.OrtSession // We convert the tensor names to C strings only once, and keep them around // here for future calls to Run(). inputNames []*C.char outputNames []*C.char // We only need the OrtValue pointers from the tensors when working with // the C API. Also, these fields aren't used with a DynamicAdvancedSession. inputs []*C.OrtValue outputs []*C.OrtValue } func createCSession(onnxData []byte, options *SessionOptions) (*C.OrtSession, error) { if !IsInitialized() { return nil, NotInitializedError } if len(onnxData) == 0 { return nil, fmt.Errorf("Missing ONNX data") } var ortSession *C.OrtSession var ortSessionOptions *C.OrtSessionOptions if options != nil { ortSessionOptions = options.o } status := C.CreateSession(unsafe.Pointer(&(onnxData[0])), C.size_t(len(onnxData)), ortEnv, &ortSession, ortSessionOptions) if status != nil { return nil, statusToError(status) } return ortSession, nil } // Basically identical to createCSession, except uses a file path rather than // a buffer of .onnx content. func createCSessionFromFile(path string, options *SessionOptions) (*C.OrtSession, error) { if !IsInitialized() { return nil, NotInitializedError } cPath, e := createOrtCharString(path) if e != nil { return nil, fmt.Errorf("Unable to convert path to C path: %w", e) } var ortSession *C.OrtSession var ortSessionOptions *C.OrtSessionOptions if options != nil { ortSessionOptions = options.o } status := C.CreateSessionFromFile(cPath, ortEnv, &ortSession, ortSessionOptions) if status != nil { return nil, statusToError(status) } return ortSession, nil } // Initializes an AdvancedSession object without creating the session; // essentially converting input and output names. Set the dynamicInputs // argument to true if this will be used for a DynamicAdvancedSession; it will // skip checks on the inputs and outputs []Values. func newAdvancedSessionInternal(inputNames, outputNames []string, inputs, outputs []Value, dynamicInputs bool) (*AdvancedSession, error) { if !IsInitialized() { return nil, NotInitializedError } if !dynamicInputs { if len(inputs) == 0 { return nil, fmt.Errorf("No inputs were provided") } if len(outputs) == 0 { return nil, fmt.Errorf("No outputs were provided") } if len(inputs) != len(inputNames) { return nil, fmt.Errorf("Got %d inputs, but %d input names", len(inputs), len(inputNames)) } if len(outputs) != len(outputNames) { return nil, fmt.Errorf("Got %d outputs, but %d output names", len(outputs), len(outputNames)) } } // Collect the inputs and outputs, along with their names, into a format // more convenient for passing to the Run() function in the C API. cInputNames := make([]*C.char, len(inputNames)) cOutputNames := make([]*C.char, len(outputNames)) for i, v := range inputNames { cInputNames[i] = C.CString(v) } for i, v := range outputNames { cOutputNames[i] = C.CString(v) } var inputOrtValues, outputOrtValues []*C.OrtValue if !dynamicInputs { inputOrtValues = make([]*C.OrtValue, len(inputs)) outputOrtValues = make([]*C.OrtValue, len(outputs)) for i, v := range inputs { inputOrtValues[i] = v.GetInternals().ortValue } for i, v := range outputs { outputOrtValues[i] = v.GetInternals().ortValue } } return &AdvancedSession{ ortSession: nil, inputNames: cInputNames, outputNames: cOutputNames, inputs: inputOrtValues, outputs: outputOrtValues, }, nil } // The same as NewAdvancedSession, but takes a slice of bytes containing the // .onnx network rather than a file path. func NewAdvancedSessionWithONNXData(onnxData []byte, inputNames, outputNames []string, inputs, outputs []Value, options *SessionOptions) (*AdvancedSession, error) { toReturn, e := newAdvancedSessionInternal(inputNames, outputNames, inputs, outputs, false) if e != nil { return nil, e } toReturn.ortSession, e = createCSession(onnxData, options) if e != nil { toReturn.Destroy() return nil, fmt.Errorf("Error creating C session: %w", e) } return toReturn, nil } // Loads the ONNX network at the given path, and initializes an AdvancedSession // instance. If this returns successfully, the caller must call Destroy() on // the returned session when it is no longer needed. We require the user to // provide the input and output tensors and names at this point, in order to // not need to re-allocate them every time Run() is called. The user instead // can just update or access the input/output tensor data after calling Run(). // The input and output tensors MUST outlive this session, and calling // session.Destroy() will not destroy the input or output tensors. If the // provided SessionOptions pointer is nil, then the new session will use // default options. func NewAdvancedSession(onnxFilePath string, inputNames, outputNames []string, inputs, outputs []Value, options *SessionOptions) (*AdvancedSession, error) { toReturn, e := newAdvancedSessionInternal(inputNames, outputNames, inputs, outputs, false) if e != nil { return nil, e } toReturn.ortSession, e = createCSessionFromFile(onnxFilePath, options) if e != nil { toReturn.Destroy() return nil, fmt.Errorf("Error creating C session from file: %w", e) } return toReturn, nil } func (s *AdvancedSession) Destroy() error { // Including the check that ortSession is not nil allows the Destroy() // function to be used on AdvancedSessions that are partially initialized, // which we do internally. if s.ortSession != nil { C.ReleaseOrtSession(s.ortSession) s.ortSession = nil } for i := range s.inputNames { C.free(unsafe.Pointer(s.inputNames[i])) } s.inputNames = nil for i := range s.outputNames { C.free(unsafe.Pointer(s.outputNames[i])) } s.outputNames = nil s.inputs = nil s.outputs = nil return nil } // Runs the session, updating the contents of the output tensors on success. func (s *AdvancedSession) Run() error { status := C.RunOrtSession(s.ortSession, &s.inputs[0], &s.inputNames[0], C.int(len(s.inputs)), &s.outputs[0], &s.outputNames[0], C.int(len(s.outputs))) if status != nil { return fmt.Errorf("Error running network: %w", statusToError(status)) } return nil } // Wraps the OrtIoBinding instance. Must be created using // DynamicAdvancedSession's CreateIoBinding method and Destroy'ed when no // longer needed. (Only DynamicAdvancedSession is supported for this, since // a regular AdvancedSession requires specifying input and output tensors at // session creation time.) type IoBinding struct { o *C.OrtIoBinding } // Creates and returns an IoBinding instance associated with the session. The // I/O binding can be used to avoid unecessary copies to or from device memory, // for sessions on different devices. The returned IoBinding must be freed // using Destroy() when it is no longer needed. func (s *DynamicAdvancedSession) CreateIoBinding() (*IoBinding, error) { var o *C.OrtIoBinding status := C.CreateIoBinding(s.s.ortSession, &o) if status != nil { return nil, statusToError(status) } return &IoBinding{ o: o, }, nil } // Must be called to free resources associated with the IoBinding once it's // no longer needed. func (b *IoBinding) Destroy() error { if b.o != nil { C.ReleaseIoBinding(b.o) b.o = nil } return nil } // Binds a value to the named input, to be used when RunWithBinding is called. func (b *IoBinding) BindInput(name string, value Value) error { cName := C.CString(name) defer C.free(unsafe.Pointer(cName)) ortValue := value.GetInternals().ortValue status := C.BindInput(b.o, cName, ortValue) if status != nil { return statusToError(status) } return nil } // Binds a value to the named output, to be used when RunWithBinding is called. func (b *IoBinding) BindOutput(name string, value Value) error { cName := C.CString(name) defer C.free(unsafe.Pointer(cName)) ortValue := value.GetInternals().ortValue status := C.BindOutput(b.o, cName, ortValue) if status != nil { return statusToError(status) } return nil } // Returns a list of bound output names, which will be returned in the same // order that outputs will be returned when GetBoundOutputValues is called. func (b *IoBinding) GetBoundOutputNames() ([]string, error) { var resultBuffer *C.char var resultCount C.size_t var resultSizes *C.size_t status := C.GetBoundOutputNames(b.o, &resultBuffer, &resultSizes, &resultCount) if status != nil { return nil, statusToError(status) } // The doc says no allocation occurs if the number of results is 0. if resultCount == 0 { return []string{}, nil } // Start by getting go slice views of the C buffers to make it easier to // work with. sizesSlice := unsafe.Slice(resultSizes, resultCount) totalSize := uint64(0) // There is likely a more efficient way, but it's much easier to just get // a go slice of the entire buffer. NOTE: The strings in this buffer are // *not* null terminated! for _, s := range sizesSlice { totalSize += uint64(s) } charsSlice := unsafe.Slice((*byte)(unsafe.Pointer(resultBuffer)), totalSize) // We'll take advantage of the byte-slice-to-string conversion to copy the // data into go-managed memory. toReturn := make([]string, resultCount) prevEndOffset := uint64(0) for i, stringLength := range sizesSlice { toReturn[i] = string(charsSlice[prevEndOffset:stringLength]) prevEndOffset += uint64(stringLength) } // Finally, free the onnxruntime-allocated memory buffers. Always attempt // to free both buffers, even if one returns an error. status1 := C.FreeWithDefaultORTAllocator(unsafe.Pointer(resultSizes)) status2 := C.FreeWithDefaultORTAllocator(unsafe.Pointer(resultBuffer)) if status1 != nil { return toReturn, fmt.Errorf("Error freeing list of string lengths: %w", statusToError(status1)) } if status2 != nil { return toReturn, fmt.Errorf("Error freeing buffer of output names: %w", statusToError(status2)) } return toReturn, nil } // Returns a list of Values containing results of a model run using // RunWithBinding. The returned slice contains the same number of values as the // number of names returned by GetOutputNames, and/or in the same order as they // were bound using IoBinding.BindOutput. IMPORTANT: Each Value returned by // this function must be freed by the caller; they are _copies_. // // Note: Using this function will cause Tensor contents to be copied from // C-managed to Go-managed memory to avoid leaks (this is similar to behavior // when DynamicAdvancedSession.Run is allowed to automatically allocate // output tensors). Note that this may be expensive for larger tensors. func (b *IoBinding) GetBoundOutputValues() ([]Value, error) { var valuesBuffer **C.OrtValue var numValues C.size_t var e error status := C.GetBoundOutputValues(b.o, &valuesBuffer, &numValues) if status != nil { return nil, statusToError(status) } valuesSlice := unsafe.Slice(valuesBuffer, numValues) toReturn := make([]Value, numValues) for i := range valuesSlice { toReturn[i], e = createGoValueFromOrtValue(valuesSlice[i]) if e != nil { // Upon error, we have a lot to clean up: // All OrtValues from C that haven't been converted... for _, v := range valuesSlice { if v != nil { C.ReleaseOrtValue(v) } } // the valuesBuffer... C.FreeWithDefaultORTAllocator(unsafe.Pointer(valuesBuffer)) // ... and any Go values that we've already converted for _, v := range toReturn { if v != nil { v.Destroy() } } // Finally, we can actually return the error. Hopefully this // doesn't happen often! return nil, fmt.Errorf("Error converting output index %d to "+ "a Go-managed Value: %w", i, e) } // createGoValueFromOrtValue will automatically destroy the OrtValue, // or it will be destroyed when the corresponding Go Value is // Destroy()'d. Either way, we don't want to double-free it if we // attempt to clean up, so make it nil to indicate we're done with it. valuesSlice[i] = nil } // If we're here, everything was successfully converted to a Go value, // which implies the C-managed OrtValues have already been released or // will be released when the corresponding Go Value is destroyed. status = C.FreeWithDefaultORTAllocator(unsafe.Pointer(valuesBuffer)) if status != nil { // The API is cleaner if the caller doesn't need to worry about // cleaning up if an error is returned. for _, v := range toReturn { v.Destroy() } return nil, fmt.Errorf("Error destroying C-managed buffer to hold "+ "OrtValue pointers: %w", statusToError(status)) } return toReturn, nil } // Clears any previously set inputs. Can't cause errors in the ORT C API. func (b *IoBinding) ClearBoundInputs() { if b.o == nil { return } C.ClearBoundInputs(b.o) } // Clears any previously set outputs. Can't cause errors in the ORT C API. func (b *IoBinding) ClearBoundOutputs() { if b.o == nil { return } C.ClearBoundOutputs(b.o) } // Creates and returns a ModelMetadata instance for this session's model. The // returned metadata must be freed using its Destroy() function when no longer // needed. func (s *AdvancedSession) GetModelMetadata() (*ModelMetadata, error) { var m *C.OrtModelMetadata status := C.SessionGetModelMetadata(s.ortSession, &m) if status != nil { return nil, statusToError(status) } return &ModelMetadata{ m: m, }, nil } // This type of session does not require specifying input and output tensors // ahead of time, but allows users to pass the list of input and output tensors // when calling Run(). As with AdvancedSession, users must still call Destroy() // on an DynamicAdvancedSession that is no longer needed. type DynamicAdvancedSession struct { // We may have further performance optimizations to this in the future, but // for now it's just a regular AdvancedSession. s *AdvancedSession } // Like NewAdvancedSessionWithONNXData, but does not require specifying input // and output tensors. func NewDynamicAdvancedSessionWithONNXData(onnxData []byte, inputNames, outputNames []string, options *SessionOptions) (*DynamicAdvancedSession, error) { s, e := newAdvancedSessionInternal(inputNames, outputNames, nil, nil, true) if e != nil { return nil, fmt.Errorf("Error creating internal AdvancedSession: %w", e) } s.ortSession, e = createCSession(onnxData, options) if e != nil { s.Destroy() return nil, fmt.Errorf("Error creating C session: %w", e) } return &DynamicAdvancedSession{ s: s, }, nil } // Like NewAdvancedSession, but does not require specifying input and output // tensors. Input and output names can be nil or empty, but _only if_ this // session will only be used via RunWithBinding, which manages names // separately. func NewDynamicAdvancedSession(onnxFilePath string, inputNames, outputNames []string, options *SessionOptions) (*DynamicAdvancedSession, error) { s, e := newAdvancedSessionInternal(inputNames, outputNames, nil, nil, true) if e != nil { return nil, fmt.Errorf("Error creating internal AdvancedSession: %w", e) } s.ortSession, e = createCSessionFromFile(onnxFilePath, options) if e != nil { s.Destroy() return nil, fmt.Errorf("Error creating C session from file: %w", e) } return &DynamicAdvancedSession{ s: s, }, nil } func (s *DynamicAdvancedSession) Destroy() error { return s.s.Destroy() } func createTensorWithCData[T TensorData](shape Shape, data unsafe.Pointer) (*Tensor[T], error) { totalCount := shape.FlattenedSize() actualData := unsafe.Slice((*T)(data), totalCount) dataCopy := make([]T, totalCount) copy(dataCopy, actualData) return NewTensor[T](shape, dataCopy) } // Returns the Shape described by a TensorTypeAndShapeInfo instance. func getShapeFromInfo(t *C.OrtTensorTypeAndShapeInfo) (Shape, error) { var dimCount C.size_t status := C.GetDimensionsCount(t, &dimCount) if status != nil { return nil, fmt.Errorf("Error getting dimension count: %w", statusToError(status)) } if dimCount == 0 { // Apparently some models will report shapes with zero dimensions return Shape{}, nil } shape := make(Shape, dimCount) status = C.GetDimensions(t, (*C.int64_t)(&shape[0]), dimCount) if status != nil { return nil, fmt.Errorf("Error getting shape dimensions: %w", statusToError(status)) } return shape, nil } // Returns the ONNXType associated with a C OrtValue. func getValueType(v *C.OrtValue) (ONNXType, error) { var t C.enum_ONNXType status := C.GetValueType(v, &t) if status != nil { return ONNXTypeUnknown, fmt.Errorf("Error looking up type for "+ "OrtValue: %s", statusToError(status)) } return ONNXType(t), nil } // Returns the "count" associated with an OrtValue. Mostly useful for // sequences. Should always return 2 for a map. Not sure what it returns for // Tensors, but that shouldn't matter. func getValueCount(v *C.OrtValue) (int64, error) { var size C.size_t status := C.GetValueCount(v, &size) if status != nil { return 0, fmt.Errorf("Error getting non tensor count for OrtValue: %s", statusToError(status)) } return int64(size), nil } // Takes an OrtValue and returns an appropriate Go value wrapping it, or at // least an equivalent go value in case v is a Tensor. The Value v should // _not_ be released after calling this function; it will either be released // internally or released when the returned Value is Destroy()'d. (Callers must // destroy the returned value.) // // If this function fails, v will be released. func createGoValueFromOrtValue(v *C.OrtValue) (Value, error) { if v == nil { return nil, fmt.Errorf("Internal error: got nil argument to " + "createGoValueFromOrtValue") } valueType, e := getValueType(v) if e != nil { C.ReleaseOrtValue(v) return nil, e } switch valueType { case ONNXTypeTensor: return createTensorFromOrtValue(v) case ONNXTypeSequence: return createSequenceFromOrtValue(v) case ONNXTypeMap: return createMapFromOrtValue(v) default: break } C.ReleaseOrtValue(v) return nil, fmt.Errorf("It is currently not supported to create a Go "+ "value from OrtValues with ONNXType = %s", valueType) } // Must only be called if v is known to be of type ONNXTensor. Returns a Tensor // wrapping v with the correct Go type. This function always copies v's // contents into a new Tensor backed by a Go-managed slice and releases v. func createTensorFromOrtValue(v *C.OrtValue) (Value, error) { // Either in the case of error or otherwise, we'll release v. The issue is // that GetTensorMutableData() becomes invalid after v is Released, so we // can't release v if a reference to the slice returned by GetData is // still referred to outside of the tensor. We work around this by copying // the data into a new tensor and releasing the original. defer C.ReleaseOrtValue(v) var pInfo *C.OrtTensorTypeAndShapeInfo status := C.GetTensorTypeAndShape(v, &pInfo) if status != nil { return nil, fmt.Errorf("Error getting type and shape: %w", statusToError(status)) } shape, e := getShapeFromInfo(pInfo) if e != nil { return nil, fmt.Errorf("Error getting shape from TypeAndShapeInfo: %w", e) } var tensorElementType C.ONNXTensorElementDataType status = C.GetTensorElementType(pInfo, (*uint32)(&tensorElementType)) if status != nil { return nil, fmt.Errorf("Error getting tensor element type: %w", statusToError(status)) } C.ReleaseTensorTypeAndShapeInfo(pInfo) var tensorData unsafe.Pointer status = C.GetTensorMutableData(v, &tensorData) if status != nil { return nil, fmt.Errorf("Error getting tensor mutable data: %w", statusToError(status)) } switch tensorType := TensorElementDataType(tensorElementType); tensorType { case TensorElementDataTypeFloat: return createTensorWithCData[float32](shape, tensorData) case TensorElementDataTypeUint8: return createTensorWithCData[uint8](shape, tensorData) case TensorElementDataTypeInt8: return createTensorWithCData[int8](shape, tensorData) case TensorElementDataTypeUint16: return createTensorWithCData[uint16](shape, tensorData) case TensorElementDataTypeInt16: return createTensorWithCData[int16](shape, tensorData) case TensorElementDataTypeInt32: return createTensorWithCData[int32](shape, tensorData) case TensorElementDataTypeInt64: return createTensorWithCData[int64](shape, tensorData) case TensorElementDataTypeDouble: return createTensorWithCData[float64](shape, tensorData) case TensorElementDataTypeUint32: return createTensorWithCData[uint32](shape, tensorData) case TensorElementDataTypeUint64: return createTensorWithCData[uint64](shape, tensorData) case TensorElementDataTypeBool: return createTensorWithCData[bool](shape, tensorData) default: totalSize := shape.FlattenedSize() actualData := unsafe.Slice((*byte)(tensorData), totalSize) dataCopy := make([]byte, totalSize) copy(dataCopy, actualData) return NewCustomDataTensor(shape, dataCopy, tensorType) } } // Must only be called if v is already known to be an ONNXTypeSequence. Returns // a Sequence go type wrapping v. Releases v if an error occurs; otherwise v // will be released when the returned Sequence is destroyed. func createSequenceFromOrtValue(v *C.OrtValue) (*Sequence, error) { length, e := getValueCount(v) if e != nil { C.ReleaseOrtValue(v) return nil, fmt.Errorf("Error determining sequence length: %w", e) } // Retrieve all of the sequence's contents as Go values, too. internalValues := make([]Value, length) for i := range internalValues { internalValues[i], e = getSequenceOrMapValue(v, int64(i)) if e != nil { // Clean up whatever values we already created. for j := 0; j < i; j++ { internalValues[i].Destroy() } C.ReleaseOrtValue(v) return nil, fmt.Errorf("Error retrieving sequence contents at "+ "index %d: %w", i, e) } } return &Sequence{ ortValue: v, contents: internalValues, }, nil } // Must only be called if v is already known to be an ONNXTypeMap. Returns a // Map go type wrapping v. Releases v if an error occurs, otherwise v will be // released when the returned Map is destroyed. func createMapFromOrtValue(v *C.OrtValue) (*Map, error) { // Obtain the keys and values as tensors from the Map instance. keys, e := getSequenceOrMapValue(v, 0) if e != nil { C.ReleaseOrtValue(v) return nil, fmt.Errorf("Error getting keys tensor from map: %w", e) } values, e := getSequenceOrMapValue(v, 1) if e != nil { keys.Destroy() C.ReleaseOrtValue(v) return nil, fmt.Errorf("Error getting values tensor from map: %w", e) } return &Map{ ortValue: v, keys: keys, values: values, }, nil } // Runs the network on the given input and output tensors. The number of input // and output tensors must match the number (and order) of the input and output // names specified to NewDynamicAdvancedSession. If a given output is nil, it // will be allocated and the slice will be modified to include the new Value. // Any new Value allocated in this way must be freed by calling Destroy on it. func (s *DynamicAdvancedSession) Run(inputs, outputs []Value) error { if len(inputs) != len(s.s.inputNames) { return fmt.Errorf("The session specified %d input names, but Run() "+ "was called with %d input tensors", len(s.s.inputNames), len(inputs)) } if len(outputs) != len(s.s.outputNames) { return fmt.Errorf("The session specified %d output names, but Run() "+ "was called with %d output tensors", len(s.s.outputNames), len(outputs)) } inputValues := make([]*C.OrtValue, len(inputs)) for i, v := range inputs { inputValues[i] = v.GetInternals().ortValue } outputValues := make([]*C.OrtValue, len(outputs)) for i, v := range outputs { if v == nil { // Leave any output that needs to be allocated as nil. continue } outputValues[i] = v.GetInternals().ortValue } status := C.RunOrtSession(s.s.ortSession, &inputValues[0], &s.s.inputNames[0], C.int(len(inputs)), &outputValues[0], &s.s.outputNames[0], C.int(len(outputs))) if status != nil { return fmt.Errorf("Error running network: %w", statusToError(status)) } // Convert any automatically-allocated output to a go Value. for i, v := range outputs { if v != nil { continue } var err error outputs[i], err = createGoValueFromOrtValue(outputValues[i]) if err != nil { return fmt.Errorf("Error creating tensor from ort: %w", err) } } return nil } // Runs the session using the given IoBinding instance. The IoBinding must // have been created from this session's CreateIoBinding() function. func (s *DynamicAdvancedSession) RunWithBinding(b *IoBinding) error { status := C.RunSessionWithBinding(s.s.ortSession, b.o) if status != nil { return statusToError(status) } return nil } // Creates and returns a ModelMetadata instance for this session's model. The // returned metadata must be freed using its Destroy() function when no longer // needed. func (s *DynamicAdvancedSession) GetModelMetadata() (*ModelMetadata, error) { return s.s.GetModelMetadata() } // Holds information about the name, shape, and type of an input or output to a // ONNX network. type InputOutputInfo struct { // The name of the input or output Name string // The higher-level "type" of the output; whether it's a tensor, sequence, // map, etc. OrtValueType ONNXType // The input or output's dimensions, if it's a tensor. This should be // ignored for non-tensor types. Dimensions Shape // The type of element in the input or output, if it's a tensor. This // should be ignored for non-tensor types. DataType TensorElementDataType } func (n *InputOutputInfo) String() string { switch n.OrtValueType { case ONNXTypeUnknown: return fmt.Sprintf("Unknown ONNX type: %s", n.Name) case ONNXTypeTensor: return fmt.Sprintf("Tensor \"%s\": %s, %s", n.Name, n.Dimensions, n.DataType) case ONNXTypeSequence: return fmt.Sprintf("Sequence \"%s\"", n.Name) case ONNXTypeMap: return fmt.Sprintf("Map \"%s\"", n.Name) case ONNXTypeOpaque: return fmt.Sprintf("Opaque \"%s\"", n.Name) case ONNXTypeSparseTensor: return fmt.Sprintf("Sparse tensor \"%s\": dense shape %s, %s", n.Name, n.Dimensions, n.DataType) case ONNXTypeOptional: return fmt.Sprintf("Optional \"%s\"", n.Name) default: break } // We'll use the ONNXType String() output if we don't know the type. return fmt.Sprintf("%s: \"%s\"", n.OrtValueType, n.Name) } // Sets o.OrtValueType, o.DataType, and o.Dimensions from the contents of t. func (o *InputOutputInfo) fillFromTypeInfo(t *C.OrtTypeInfo) error { var onnxType C.enum_ONNXType status := C.GetONNXTypeFromTypeInfo(t, &onnxType) if status != nil { return fmt.Errorf("Error getting ONNX type: %s", statusToError(status)) } o.OrtValueType = ONNXType(onnxType) o.Dimensions = nil o.DataType = TensorElementDataTypeUndefined // We only fill in element type and dimensions if we're dealing with a // tensor of some sort. isTensorType := (o.OrtValueType == ONNXTypeTensor) || (o.OrtValueType == ONNXTypeSparseTensor) if !isTensorType { return nil } // OrtTensorTypeAndShapeInfo pointers should *not* be released if they're // obtained via CastTypeInfoToTensorInfo. var typeAndShapeInfo *C.OrtTensorTypeAndShapeInfo status = C.CastTypeInfoToTensorInfo(t, &typeAndShapeInfo) if status != nil { return fmt.Errorf("Error getting type and shape info: %w", statusToError(status)) } if typeAndShapeInfo == nil { return fmt.Errorf("Didn't get type and shape info for an OrtTypeInfo" + "(it may not be a tensor type?)") } var e error o.Dimensions, e = getShapeFromInfo(typeAndShapeInfo) if e != nil { return fmt.Errorf("Error getting shape from typeAndShapeInfo: %w", e) } var tensorElementType C.ONNXTensorElementDataType status = C.GetTensorElementType(typeAndShapeInfo, (*uint32)(&tensorElementType)) if status != nil { return fmt.Errorf("Error getting data type from typeAndShapeInfo: %w", statusToError(status)) } o.DataType = TensorElementDataType(tensorElementType) return nil } // Fills dst with information about the session's i'th input. func getSessionInputInfo(s *C.OrtSession, i int, dst *InputOutputInfo) error { var cName *C.char var e error status := C.SessionGetInputName(s, C.size_t(i), &cName) if status != nil { return fmt.Errorf("Error getting name: %w", statusToError(status)) } dst.Name, e = convertORTString(cName) if e != nil { return fmt.Errorf("Error converting C name to Go string: %w", e) } // Session inputs are reported as OrtTypeInfo structs, though usually we // want a tensor-specific OrtTensorTypeAndShapeInfo struct, which we can // get from the type info. var typeInfo *C.OrtTypeInfo status = C.SessionGetInputTypeInfo(s, C.size_t(i), &typeInfo) if status != nil { return fmt.Errorf("Error getting type info: %w", statusToError(status)) } defer C.ReleaseTypeInfo(typeInfo) e = dst.fillFromTypeInfo(typeInfo) if e != nil { return e } return nil } // Fills dst with information about the session's i'th output. func getSessionOutputInfo(s *C.OrtSession, i int, dst *InputOutputInfo) error { // This is basically identical to getSessionInputInfo. var cName *C.char var e error status := C.SessionGetOutputName(s, C.size_t(i), &cName) if status != nil { return fmt.Errorf("Error getting name: %w", statusToError(status)) } dst.Name, e = convertORTString(cName) if e != nil { return fmt.Errorf("Error converting C name to Go string: %w", e) } var typeInfo *C.OrtTypeInfo status = C.SessionGetOutputTypeInfo(s, C.size_t(i), &typeInfo) if status != nil { return fmt.Errorf("Error getting type info: %w", statusToError(status)) } defer C.ReleaseTypeInfo(typeInfo) e = dst.fillFromTypeInfo(typeInfo) if e != nil { return e } return nil } // Takes an initialized OrtSession and returns slices of info for each input // and output, respectively. Used internally by GetInputOutputInfo, etc. func getInputOutputInfoFromCSession(s *C.OrtSession) ([]InputOutputInfo, []InputOutputInfo, error) { var e error // Allocate the structs to hold the results. var inputCount, outputCount C.size_t status := C.SessionGetInputCount(s, &inputCount) if status != nil { return nil, nil, statusToError(status) } inputs := make([]InputOutputInfo, inputCount) status = C.SessionGetOutputCount(s, &outputCount) if status != nil { return nil, nil, statusToError(status) } outputs := make([]InputOutputInfo, outputCount) // Get the results for each input and output. for i := 0; i < int(inputCount); i++ { e = getSessionInputInfo(s, i, &(inputs[i])) if e != nil { return nil, nil, fmt.Errorf("Error getting information about "+ "input %d: %w", i, e) } } for i := 0; i < int(outputCount); i++ { e = getSessionOutputInfo(s, i, &(outputs[i])) if e != nil { return nil, nil, fmt.Errorf("Error getting information about "+ "output %d: %w", i, e) } } return inputs, outputs, nil } // Takes a path to a .onnx file, and returns a list of inputs and a list of // outputs, respectively. Will open, read, and close the .onnx file to get the // information. InitializeEnvironment() must have been called prior to using // this function. Warning: this function requires loading the .onnx file into a // temporary onnxruntime session, which may be an expensive operation. // // For now, this may fail if the network has any non-tensor inputs or inputs // that don't have a concrete shape and type. In the future, a new API may be // added to support cases requiring more advanced usage of the C.OrtTypeInfo // struct. func GetInputOutputInfo(path string) ([]InputOutputInfo, []InputOutputInfo, error) { return GetInputOutputInfoWithOptions(path, nil) } // Identical in behavior to GetInputOutputInfo, but addtionally takes // session options to handle models that require options to load. func GetInputOutputInfoWithOptions(path string, options *SessionOptions) ([]InputOutputInfo, []InputOutputInfo, error) { s, e := createCSessionFromFile(path, options) if e != nil { return nil, nil, fmt.Errorf("Error loading temporary session: %w", e) } defer C.ReleaseOrtSession(s) return getInputOutputInfoFromCSession(s) } // Identical in behavior to GetInputOutputInfo, but takes a slice of bytes // containing the .onnx network rather than a file path. func GetInputOutputInfoWithONNXData(data []byte) ([]InputOutputInfo, []InputOutputInfo, error) { var e error s, e := createCSession(data, nil) if e != nil { return nil, nil, fmt.Errorf("Error creating temporary session: %w", e) } defer C.ReleaseOrtSession(s) return getInputOutputInfoFromCSession(s) } func getModelMetadataFromCSession(s *C.OrtSession) (*ModelMetadata, error) { var m *C.OrtModelMetadata status := C.SessionGetModelMetadata(s, &m) if status != nil { return nil, statusToError(status) } return &ModelMetadata{ m: m, }, nil } // Takes a path to a .onnx file and returns the ModelMetadata associated with // it. The returned metadata must be freed using its Destroy() function when // it's no longer needed. InitializeEnvironment() must be called before using // this function. // // Warning: This function loads the onnx content into a temporary onnxruntime // session, so it may be computationally expensive. func GetModelMetadata(path string) (*ModelMetadata, error) { return GetModelMetadataWithOptions(path, nil) } // Identical in behavior to GetModelMetadata, but addtionally takes // session options to handle models that require options to load. func GetModelMetadataWithOptions(path string, options *SessionOptions) (*ModelMetadata, error) { s, e := createCSessionFromFile(path, options) if e != nil { return nil, fmt.Errorf("Error loading %s: %w", path, e) } defer C.ReleaseOrtSession(s) return getModelMetadataFromCSession(s) } // Identical in behavior to GetModelMetadata, but takes a slice of bytes // containing the .onnx network rather than a file path. func GetModelMetadataWithONNXData(data []byte) (*ModelMetadata, error) { if !IsInitialized() { return nil, NotInitializedError } // Create the temporary ORT session from which we'll get the metadata. s, e := createCSession(data, nil) if e != nil { return nil, fmt.Errorf("Error creating temporary session: %w", e) } defer C.ReleaseOrtSession(s) return getModelMetadataFromCSession(s) }