Initial Sequence support

 - This adds support for ONNX sequences, along with basic unit tests for
   the sequence types.

 - This also introduces a new test network generated via sklearn in an
   included script.

 - The test using sklearn_randomforest.onnx has not been written yet,
   but I wanted to commit after doing all the work so far. The existing
   tests all pass.
Nathan
2024-07-22 21:13:10 -04:00
parent 0334e9617f
commit 4aa8513549
6 changed files with 474 additions and 27 deletions
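For orientation, here is a minimal usage sketch of the new Sequence API, written in the same in-package style as the tests in this commit. It is illustrative only: error handling is omitted, the tensor shapes are arbitrary, and nothing here is part of the committed code. The two ownership rules it demonstrates are that NewSequence does not take ownership of its contents, and that Values returned by GetValue must be destroyed separately.

    // Build a few small tensors to place in a sequence.
    contents := make([]Value, 4)
    for i := range contents {
        tensor, _ := NewEmptyTensor[float32](NewShape(10))
        contents[i] = tensor
    }
    sequence, _ := NewSequence(contents)
    fmt.Printf("The sequence holds %d values.\n", sequence.GetValueCount())
    // GetValue allocates a new OrtValue internally, so the returned Value must
    // be destroyed separately from both the sequence and the original tensors.
    element, _ := sequence.GetValue(2)
    element.Destroy()
    // Destroying the sequence does not destroy the tensors passed to
    // NewSequence; the caller still owns and must destroy those.
    sequence.Destroy()
    for _, v := range contents {
        v.Destroy()
    }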


@@ -488,6 +488,131 @@ func (t TensorElementDataType) String() string {
return fmt.Sprintf("Unknown tensor element data type: %d", int(t)) return fmt.Sprintf("Unknown tensor element data type: %d", int(t))
} }
// This wraps an ONNX_TYPE_SEQUENCE OrtValue. Individual elements must be
// accessed using GetValue. Satisfies the Value interface, though Tensor-
// related functions such as ZeroContents() may be no-ops.
type Sequence struct {
ortValue *C.OrtValue
// We'll stash the number of values in the sequence here, so we don't need
// to call C.GetValueCount more than once.
valueCount int64
}
// Creates a new ONNX sequence with the given contents. The returned Sequence
// must be Destroyed by the caller when no longer needed. Destroying the
// Sequence created by this function does _not_ destroy the Values it contains,
// so the caller is still responsible for destroying them as well.
//
// The contents of a sequence are subject to additional constraints. I can't
// find mention of some of these in the C API docs, but they are enforced by
// the onnxruntime API. Notably: all elements of the sequence must have the
// same type, and all elements must be either maps or tensors. Finally, the
// sequence must contain at least one element, and none of the elements may be
// nil. There may be other constraints that I am unaware of, as well.
func NewSequence(contents []Value) (*Sequence, error) {
if !IsInitialized() {
return nil, NotInitializedError
}
length := int64(len(contents))
if length == 0 {
return nil, fmt.Errorf("Sequences must contain at least 1 element")
}
ortValues := make([]*C.OrtValue, length)
for i, v := range contents {
if v == nil {
// I don't actually know if NULL OrtValue pointers are allowed in
// sequences, but I'm assuming not.
return nil, fmt.Errorf("Sequences must not contain nil (index "+
"%d was nil)", i)
}
ortValues[i] = v.GetInternals().ortValue
}
// Avoid dereferencing ortValues[0] if the list is empty.
var valuesPtr **C.OrtValue
if length > 0 {
valuesPtr = &(ortValues[0])
}
var sequence *C.OrtValue
status := C.CreateOrtValue(valuesPtr, C.size_t(length),
C.ONNX_TYPE_SEQUENCE, &sequence)
if status != nil {
return nil, fmt.Errorf("Error creating ORT sequence: %s",
statusToError(status))
}
return &Sequence{
ortValue: sequence,
valueCount: length,
}, nil
}
func (s *Sequence) Destroy() error {
C.ReleaseOrtValue(s.ortValue)
s.ortValue = nil
s.valueCount = 0
return nil
}
// Returns the number of elements in the sequence.
func (s *Sequence) GetValueCount() int64 {
return s.valueCount
}
// This returns a 1-dimensional Shape containing a single element: the number
// of elements in the sequence. Typically, Sequence users should prefer
// GetValueCount() to this function, since this function only exists to
// maintain compatibility with the Value interface.
func (s *Sequence) GetShape() Shape {
return NewShape(s.valueCount)
}
func (s *Sequence) GetONNXType() ONNXType {
return ONNXTypeSequence
}
// This function is meaningless for a Sequence and shouldn't be used. The
// return value is always TENSOR_ELEMENT_DATA_TYPE_UNDEFINED for now, but this
// may change in the future. This function is only present for compatibility
// with the Value interface and should not be relied on for sequences.
func (s *Sequence) DataType() C.ONNXTensorElementDataType {
return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
}
// This function does nothing for a Sequence, and is only present for
// compatibility with the Value interface.
func (s *Sequence) ZeroContents() {
}
func (s *Sequence) GetInternals() *ValueInternalData {
return &ValueInternalData{
ortValue: s.ortValue,
}
}
// Returns the value at the given index in the sequence. Returns an error if
// the index is out of range, or if the underlying onnxruntime call fails for
// any other reason.
//
// IMPORTANT: The Value returned by this function must be destroyed by the
// caller when it is no longer needed. This is because the GetValue C API for
// sequences internally allocates a new OrtValue. Put another way, this does
// _not_ return the same Value instance that was provided to NewSequence, even
// though, if it's a tensor, it should still refer to the same underlying data
// slice.
func (s *Sequence) GetValue(index int64) (Value, error) {
if (index < 0) || (index >= s.valueCount) {
return nil, fmt.Errorf("Invalid index (%d) into sequence of length %d",
index, s.valueCount)
}
var result *C.OrtValue
status := C.GetValue(s.ortValue, C.int(index), &result)
if status != nil {
return nil, fmt.Errorf("Error getting value of index %d: %s", index,
statusToError(status))
}
return createGoValueFromOrtValue(result)
}
// Wraps the ONNXType enum in C.
type ONNXType int
@@ -1369,7 +1494,53 @@ func getShapeFromInfo(t *C.OrtTensorTypeAndShapeInfo) (Shape, error) {
return shape, nil
}
// Returns the ONNXType associated with a C OrtValue.
func getValueType(v *C.OrtValue) (ONNXType, error) {
var t C.enum_ONNXType
status := C.GetValueType(v, &t)
if status != nil {
return ONNXTypeUnknown, fmt.Errorf("Error looking up type for "+
"OrtValue: %s", statusToError(status))
}
return ONNXType(t), nil
}
// Returns the "count" associated with an OrtValue. Mostly useful for
// sequences. Should always return 2 for a map. Not sure what it returns for
// Tensors, but that shouldn't matter.
func getValueCount(v *C.OrtValue) (int64, error) {
var size C.size_t
status := C.GetValueCount(v, &size)
if status != nil {
return 0, fmt.Errorf("Error getting non tensor count for OrtValue: %s",
statusToError(status))
}
return int64(size), nil
}
// Takes a OrtValue and returns an appropriate Go Value wrapping it. Does not
// copy the value. This function assumes that v will be destroyed later when
// the returned go Value is destroyed.
func createGoValueFromOrtValue(v *C.OrtValue) (Value, error) {
// TODO: How to handle v == nil? Is it even possible to get nil here?
valueType, e := getValueType(v)
if e != nil {
return nil, e
}
switch valueType {
case ONNXTypeTensor:
return createTensorFromOrtValue(v)
case ONNXTypeSequence:
return createSequenceFromOrtValue(v)
default:
break
}
return nil, fmt.Errorf("It is currently not supported to create a Go "+
"value from OrtValues with ONNXType = %s", valueType)
}
// Must only be called if v is known to be of type ONNXTensor. Returns a Tensor
// wrapping v with the correct Go type.
func createTensorFromOrtValue(v *C.OrtValue) (Value, error) {
var pInfo *C.OrtTensorTypeAndShapeInfo
status := C.GetTensorTypeAndShape(v, &pInfo)
@@ -1424,11 +1595,24 @@ func createTensorFromOrtValue(v *C.OrtValue) (Value, error) {
}
}
// Must only be called if v is already known to be an ONNXTypeSequence. Returns
// a Sequence go type wrapping v.
func createSequenceFromOrtValue(v *C.OrtValue) (Value, error) {
length, e := getValueCount(v)
if e != nil {
return nil, fmt.Errorf("Error determining sequence length: %w", e)
}
return &Sequence{
ortValue: v,
valueCount: length,
}, nil
}
// Runs the network on the given input and output tensors. The number of input
// and output tensors must match the number (and order) of the input and output
// names specified to NewDynamicAdvancedSession. If a given output is nil, it
// will be allocated and the slice will be modified to include the new Value.
// Any new Value allocated in this way must be freed by calling Destroy on it.
func (s *DynamicAdvancedSession) Run(inputs, outputs []Value) error {
if len(inputs) != len(s.s.inputNames) {
return fmt.Errorf("The session specified %d input names, but Run() "+
@@ -1446,25 +1630,30 @@ func (s *DynamicAdvancedSession) Run(inputs, outputs []Value) error {
}
outputValues := make([]*C.OrtValue, len(outputs))
for i, v := range outputs {
if v == nil {
// Leave any output that needs to be allocated as nil.
continue
}
outputValues[i] = v.GetInternals().ortValue
}
status := C.RunOrtSession(s.s.ortSession, &inputValues[0],
&s.s.inputNames[0], C.int(len(inputs)), &outputValues[0],
&s.s.outputNames[0], C.int(len(outputs)))
if status != nil {
return fmt.Errorf("Error running network: %w", statusToError(status))
}
// Convert any automatically-allocated output to a go Value.
for i, v := range outputs {
if v != nil {
continue
}
var err error
outputs[i], err = createTensorFromOrtValue(outputValues[i])
if err != nil {
return fmt.Errorf("Error creating tensor from ort: %w", err)
}
}
return nil
}
@@ -1480,22 +1669,65 @@ func (s *DynamicAdvancedSession) GetModelMetadata() (*ModelMetadata, error) {
type InputOutputInfo struct {
// The name of the input or output
Name string
// The higher-level "type" of the input or output; whether it's a tensor,
// sequence, map, etc.
OrtValueType ONNXType
// The input or output's dimensions, if it's a tensor. This should be
// ignored for non-tensor types.
Dimensions Shape
// The type of element in the input or output, if it's a tensor. This
// should be ignored for non-tensor types.
DataType TensorElementDataType
}
func (n *InputOutputInfo) String() string {
switch n.OrtValueType {
case ONNXTypeUnknown:
return fmt.Sprintf("Unknown ONNX type: %s", n.Name)
case ONNXTypeTensor:
return fmt.Sprintf("Tensor \"%s\": %s, %s", n.Name, n.Dimensions,
n.DataType)
case ONNXTypeSequence:
return fmt.Sprintf("Sequence \"%s\"", n.Name)
case ONNXTypeMap:
return fmt.Sprintf("Map \"%s\"", n.Name)
case ONNXTypeOpaque:
return fmt.Sprintf("Opaque \"%s\"", n.Name)
case ONNXTypeSparseTensor:
return fmt.Sprintf("Sparse tensor \"%s\": dense shape %s, %s",
n.Name, n.Dimensions, n.DataType)
case ONNXTypeOptional:
return fmt.Sprintf("Optional \"%s\"", n.Name)
default:
break
}
// We'll use the ONNXType String() output if we don't know the type.
return fmt.Sprintf("%s: \"%s\"", n.OrtValueType, n.Name)
}
// Sets o.OrtValueType, o.DataType, and o.Dimensions from the contents of t.
func (o *InputOutputInfo) fillFromTypeInfo(t *C.OrtTypeInfo) error {
var onnxType C.enum_ONNXType
status := C.GetONNXTypeFromTypeInfo(t, &onnxType)
if status != nil {
return fmt.Errorf("Error getting ONNX type: %s", statusToError(status))
}
o.OrtValueType = ONNXType(onnxType)
o.Dimensions = nil
o.DataType = TensorElementDataTypeUndefined
// We only fill in element type and dimensions if we're dealing with a
// tensor of some sort.
isTensorType := (o.OrtValueType == ONNXTypeTensor) ||
(o.OrtValueType == ONNXTypeSparseTensor)
if !isTensorType {
return nil
}
// OrtTensorTypeAndShapeInfo pointers should *not* be released if they're
// obtained via CastTypeInfoToTensorInfo.
var typeAndShapeInfo *C.OrtTensorTypeAndShapeInfo
status = C.CastTypeInfoToTensorInfo(t, &typeAndShapeInfo)
if status != nil {
return fmt.Errorf("Error getting type and shape info: %w",
statusToError(status))
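To make the nil-output convention described above concrete, here is a short sketch (not part of the commit) of requesting an automatically-allocated output. It assumes session is an existing *DynamicAdvancedSession and input is an already-prepared tensor:

    // One nil entry per output that onnxruntime should allocate for us.
    outputs := []Value{nil}
    e := session.Run([]Value{input}, outputs)
    if e != nil {
        // Handle the error here.
    }
    // Run replaced the nil entry with a newly-allocated Value, which the
    // caller now owns and must destroy.
    defer outputs[0].Destroy()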


@@ -691,7 +691,6 @@ func TestDynamicInputOutputAxes(t *testing.T) {
t.Fatalf("Error loading %s: %s\n", netPath, e) t.Fatalf("Error loading %s: %s\n", netPath, e)
} }
defer session.Destroy() defer session.Destroy()
rng := rand.New(rand.NewSource(1234))
maxBatchSize := 99
// The example network takes a dynamic batch size of vectors containing 10
// elements each.
@@ -708,10 +707,7 @@ func TestDynamicInputOutputAxes(t *testing.T) {
}
// Populate the input with new random floats.
fillRandomFloats(input.GetData(), 1234)
// Run the session; make onnxruntime allocate the output tensor for us.
outputs := []Value{nil}
@@ -725,6 +721,7 @@ func TestDynamicInputOutputAxes(t *testing.T) {
// The checkVectorSum function will destroy the input and output tensor
// regardless of their correctness.
checkVectorSum(input, outputs[0].(*Tensor[float32]), t)
input.Destroy()
t.Logf("Batch size %d seems OK!\n", i) t.Logf("Batch size %d seems OK!\n", i)
} }
} }
@@ -962,6 +959,13 @@ func randomBytes(seed, n int64) []byte {
return toReturn
}
func fillRandomFloats(dst []float32, seed int64) {
rng := rand.New(rand.NewSource(seed))
for i := range dst {
dst[i] = rng.Float32()
}
}
func TestCustomDataTensors(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
@@ -1091,6 +1095,128 @@ func TestFloat16Network(t *testing.T) {
}
}
// Returns a 10-element tensor filled with random values using the given rng
// seed.
func randomSmallTensor(seed int64, t testing.TB) *Tensor[float32] {
toReturn, e := NewEmptyTensor[float32](NewShape(10))
if e != nil {
t.Fatalf("Error creating small tensor: %s\n", e)
}
fillRandomFloats(toReturn.GetData(), seed)
return toReturn
}
func TestONNXSequence(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
sequenceLength := int64(123)
values := make([]Value, sequenceLength)
for i := range values {
values[i] = randomSmallTensor(int64(i)+123, t)
}
defer func() {
for _, v := range values {
v.Destroy()
}
}()
sequence, e := NewSequence(values)
if e != nil {
t.Fatalf("Error creating sequence: %s\n", e)
}
defer sequence.Destroy()
if int64(sequence.GetValueCount()) != sequenceLength {
t.Fatalf("Got %d values in sequence, expected %d\n",
sequence.GetValueCount(), sequenceLength)
}
if sequence.GetONNXType() != ONNXTypeSequence {
t.Fatalf("Got incorrect ONNX type for sequence: %s\n",
sequence.GetONNXType())
}
// Not guaranteed by onnxruntime, but it _is_ something that I wrote in the
// docs.
if !sequence.GetShape().Equals(NewShape(sequenceLength)) {
t.Fatalf("Sequence.GetShape() returned incorrect shape: %s\n",
sequence.GetShape())
}
_, e = sequence.GetValue(9999)
if e == nil {
t.Fatalf("Did not get an error when accessing out of a " +
"sequence's bounds\n")
}
t.Logf("Got expected error when accessing out of bounds: %s\n", e)
_, e = sequence.GetValue(-1)
if e == nil {
t.Fatalf("Did not get an error when accessing a negative index\n")
}
t.Logf("Got expected error when accessing a negative index: %s\n", e)
// We know from the C API docs that this needs to be destroyed
selectedIndex := int64(44)
selectedValue, e := sequence.GetValue(selectedIndex)
if e != nil {
t.Fatalf("Error getting sequence value at index %d: %s\n",
selectedIndex, e)
}
defer selectedValue.Destroy()
if selectedValue.GetONNXType() != ONNXTypeTensor {
t.Fatalf("Got incorrect ONNXType for value at index %d: "+
"expected %s, got %s\n", selectedIndex, ONNXType(ONNXTypeTensor),
selectedValue.GetONNXType())
}
}
func TestBadSequences(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
// Sequences containing no elements or nil entries shouldn't be allowed
_, e := NewSequence([]Value{})
if e == nil {
t.Fatalf("Didn't get expected error when creating an empty sequence\n")
}
t.Logf("Got expected error when creating an empty sequence: %s\n", e)
_, e = NewSequence([]Value{nil})
if e == nil {
t.Fatalf("Didn't get expected error when creating sequence with a " +
"nil entry.\n")
}
t.Logf("Got expected error when creating sequence with nil entry: %s\n", e)
// Sequences containing mixed data types shouldn't be allowed
tensor := randomSmallTensor(1337, t)
defer tensor.Destroy()
innerSequence, e := NewSequence([]Value{tensor})
if e != nil {
t.Fatalf("Error creating 1-element sequence: %s\n", e)
}
defer innerSequence.Destroy()
_, e = NewSequence([]Value{tensor, innerSequence})
if e == nil {
t.Fatalf("Didn't get expected error when attempting to create a "+
"mixed sequence: %s\n", e)
}
t.Logf("Got expected error when attempting a mixed sequence: %s\n", e)
// Nested sequences also aren't allowed; the C API docs don't seem to
// mention this either.
_, e = NewSequence([]Value{innerSequence, innerSequence})
if e == nil {
t.Fatalf("Didn't get an error creating a sequence with nested " +
"sequences.\n")
}
t.Logf("Got expected error when creating a sequence with nested "+
"sequences: %s\n", e)
}
func TestSklearnNetwork(t *testing.T) {
// TODO: TestSklearnNetwork
// - One of its outputs is a sequence of maps, make sure it works.
// - Check the values printed by the python script for test inputs and
// expected outputs.
}
// See the comment in generate_network_big_compute.py for information about
// the inputs and outputs used for testing or benchmarking session options.
func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32],


@@ -289,6 +289,10 @@ void ReleaseTypeInfo(OrtTypeInfo *o) {
ort_api->ReleaseTypeInfo(o);
}
OrtStatus *GetONNXTypeFromTypeInfo(OrtTypeInfo *info, enum ONNXType *out) {
return ort_api->GetOnnxTypeFromTypeInfo(info, out);
}
OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info,
OrtTensorTypeAndShapeInfo **out) {
return ort_api->CastTypeInfoToTensorInfo(type_info,
@@ -359,6 +363,28 @@ OrtStatus *ModelMetadataGetVersion(OrtModelMetadata *m, int64_t *version) {
return ort_api->ModelMetadataGetVersion(m, version);
}
OrtStatus *GetValue(OrtValue *container, int index, OrtValue **dst) {
OrtAllocator *allocator = NULL;
OrtStatus *status = NULL;
status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
if (status) return status;
return ort_api->GetValue(container, index, allocator, dst);
}
OrtStatus *GetValueType(OrtValue *v, enum ONNXType *out) {
return ort_api->GetValueType(v, out);
}
OrtStatus *GetValueCount(OrtValue *v, size_t *out) {
return ort_api->GetValueCount(v, out);
}
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
enum ONNXType value_type, OrtValue **out) {
return ort_api->CreateValue((const OrtValue* const*) in, num_values,
value_type, out);
}
// TRAINING API WRAPPER
static const OrtTrainingApi *ort_training_api = NULL;
@@ -504,3 +530,4 @@ void ReleaseOrtTrainingSession(OrtTrainingSession *session) {
void ReleaseCheckpointState(OrtCheckpointState *checkpoint) {
ort_training_api->ReleaseCheckpointState(checkpoint);
}


@@ -198,6 +198,9 @@ OrtStatus *SessionGetOutputTypeInfo(OrtSession *session, size_t i,
OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info,
OrtTensorTypeAndShapeInfo **out);
// Wraps ort_api->GetOnnxTypeFromTypeInfo.
OrtStatus *GetONNXTypeFromTypeInfo(OrtTypeInfo *info, enum ONNXType *out);
// Wraps ort_api->ReleaseTypeInfo.
void ReleaseTypeInfo(OrtTypeInfo *o);
@@ -232,6 +235,19 @@ OrtStatus *ModelMetadataGetCustomMetadataMapKeys(OrtModelMetadata *m,
// Wraps ort_api->ModelMetadataGetVersion.
OrtStatus *ModelMetadataGetVersion(OrtModelMetadata *m, int64_t *version);
// Wraps ort_api->GetValue. Uses the default allocator.
OrtStatus *GetValue(OrtValue *container, int index, OrtValue **dst);
// Wraps ort_api->GetValueType.
OrtStatus *GetValueType(OrtValue *v, enum ONNXType *out);
// Wraps ort_api->GetValueCount.
OrtStatus *GetValueCount(OrtValue *v, size_t *out);
// Wraps ort_api->CreateValue to create a map or a sequence.
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
enum ONNXType value_type, OrtValue **out);
// TRAINING API WRAPPER
void SetTrainingApi();
@@ -241,11 +257,12 @@ int IsTrainingApiSupported();
// Wraps ort_training_api->CreateSessionFromBuffer.
// Creates an ORT checkpoint from the checkpoint data.
OrtStatus *CreateCheckpoint(void *checkpoint_data,
size_t checkpoint_data_length, OrtCheckpointState **out);
// Wraps ort_training_api->CreateTrainingSessionFromBuffer. Creates an ORT
// training session using the given models and checkpoint. The given options
// pointer may be NULL; if it is, then we'll use default options.
OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state,
void *training_model_data, size_t training_model_data_length,
void *eval_model_data, size_t eval_model_data_length,


@@ -0,0 +1,45 @@
# This script is a modified version of the example from
# https://pypi.org/project/skl2onnx/, which we use to produce
# sklearn_randomforest.onnx. sklearn makes heavy use of onnxruntime maps and
# sequences in its networks, so this is used for testing those data types.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
inputs, outputs = iris.data, iris.target
inputs = inputs.astype(np.float32)
inputs_train, inputs_test, outputs_train, outputs_test = train_test_split(inputs, outputs)
classifier = RandomForestClassifier()
classifier.fit(inputs_train, outputs_train)
# Convert into ONNX format.
from skl2onnx import to_onnx
output_filename = "sklearn_randomforest.onnx"
onnx_content = to_onnx(classifier, inputs[:1])
with open(output_filename, "wb") as f:
f.write(onnx_content.SerializeToString())
# Compute the prediction with onnxruntime.
import onnxruntime as ort
def float_formatter(f):
return f"{float(f):.06f}"
np.set_printoptions(formatter = {'float_kind': float_formatter})
session = ort.InferenceSession(output_filename)
print(f"Input names: {[n.name for n in session.get_inputs()]!s}")
print(f"Output names: {[o.name for o in session.get_outputs()]!s}")
example_inputs = inputs_test.astype(np.float32)[:6]
print(f"Inputs shape = {example_inputs.shape!s}")
onnx_predictions = session.run(["output_label", "output_probability"],
{"X": example_inputs})
labels = onnx_predictions[0]
probabilities = onnx_predictions[1]
print(f"Inputs to network: {example_inputs.astype(np.float32)}")
print(f"ONNX predicted labels: {labels!s}")
print(f"ONNX predicted probabilities: {probabilities!s}")

Binary file not shown.