mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-05 07:06:51 +08:00
Initial Sequence support
- This adds support for ONNX sequences, along with basic unit tests for the sequence types. - This also introduces a new test network generated via sklearn in an included script. - The test using sklearn_randomforest.onnx has not been written yet, but I wanted to commit after doing all the work so far. The existing tests all pass.
This commit is contained in:
@@ -488,6 +488,131 @@ func (t TensorElementDataType) String() string {
|
|||||||
return fmt.Sprintf("Unknown tensor element data type: %d", int(t))
|
return fmt.Sprintf("Unknown tensor element data type: %d", int(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This wraps an ONNX_TYPE_SEQUENCE OrtValue. Individual elements must be
|
||||||
|
// accessed using GetValue. Satisfies the Value interface, though Tensor-
|
||||||
|
// related functions such as ZeroContents() may be no-ops.
|
||||||
|
type Sequence struct {
|
||||||
|
ortValue *C.OrtValue
|
||||||
|
// We'll stash the number of values in the sequence here, so we don't need
|
||||||
|
// to call C.GetValueCount more than once.
|
||||||
|
valueCount int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new ONNX sequence with the given contents. The returned Sequence
|
||||||
|
// must be Destroyed by the caller when no longer needed. Destroying the
|
||||||
|
// Sequence created by this function does _not_ destroy the Values it contains,
|
||||||
|
// so the caller is still responsible for destroying them as well.
|
||||||
|
//
|
||||||
|
// The contents of a sequence are subject to additional constraints. I can't
|
||||||
|
// find mention of some of these in the C API docs, but they are enforced by
|
||||||
|
// the onnxruntime API. Notably: all elements of the sequence must have the
|
||||||
|
// same type, and all elements must be either maps or tensors. Finally, the
|
||||||
|
// sequence must contain at least one element, and none of the elements may be
|
||||||
|
// nil. There may be other constraints that I am unaware of, as well.
|
||||||
|
func NewSequence(contents []Value) (*Sequence, error) {
|
||||||
|
if !IsInitialized() {
|
||||||
|
return nil, NotInitializedError
|
||||||
|
}
|
||||||
|
length := int64(len(contents))
|
||||||
|
if length == 0 {
|
||||||
|
return nil, fmt.Errorf("Sequences must contain at least 1 element")
|
||||||
|
}
|
||||||
|
ortValues := make([]*C.OrtValue, length)
|
||||||
|
for i, v := range contents {
|
||||||
|
if v == nil {
|
||||||
|
// I don't actually know if NULL OrtValue pointers are allowed in
|
||||||
|
// sequences, but I'm assuming not.
|
||||||
|
return nil, fmt.Errorf("Sequences must not contain nil (index "+
|
||||||
|
"%d was nil)", i)
|
||||||
|
}
|
||||||
|
ortValues[i] = v.GetInternals().ortValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avoid dereferencing ortValues[0] if the list is empty.
|
||||||
|
var valuesPtr **C.OrtValue
|
||||||
|
if length > 0 {
|
||||||
|
valuesPtr = &(ortValues[0])
|
||||||
|
}
|
||||||
|
var sequence *C.OrtValue
|
||||||
|
status := C.CreateOrtValue(valuesPtr, C.size_t(length),
|
||||||
|
C.ONNX_TYPE_SEQUENCE, &sequence)
|
||||||
|
if status != nil {
|
||||||
|
return nil, fmt.Errorf("Error creating ORT sequence: %s",
|
||||||
|
statusToError(status))
|
||||||
|
}
|
||||||
|
return &Sequence{
|
||||||
|
ortValue: sequence,
|
||||||
|
valueCount: length,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sequence) Destroy() error {
|
||||||
|
C.ReleaseOrtValue(s.ortValue)
|
||||||
|
s.ortValue = nil
|
||||||
|
s.valueCount = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of elements in the sequence.
|
||||||
|
func (s *Sequence) GetValueCount() int64 {
|
||||||
|
return s.valueCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// This returns a 1-dimensional Shape containing a single element: the number
|
||||||
|
// of elements the sequence. Typically, Sequence users should prefer
|
||||||
|
// GetValueCount() to this function, since this function only exists to
|
||||||
|
// maintain compatibility with the Value interface.
|
||||||
|
func (s *Sequence) GetShape() Shape {
|
||||||
|
return NewShape(s.valueCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sequence) GetONNXType() ONNXType {
|
||||||
|
return ONNXTypeSequence
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function is meaningless for a Sequence and shouldn't be used. The
|
||||||
|
// return value is always TENSOR_ELEMENT_DATA_TYPE_UNDEFINED for now, but this
|
||||||
|
// may change in the future. This function is only present for compatibility
|
||||||
|
// with the Value interface and should not be relied on for sequences.
|
||||||
|
func (s *Sequence) DataType() C.ONNXTensorElementDataType {
|
||||||
|
return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function does nothing for a Sequence, and is only present for
|
||||||
|
// compatibility with the Value interface.
|
||||||
|
func (s *Sequence) ZeroContents() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sequence) GetInternals() *ValueInternalData {
|
||||||
|
return &ValueInternalData{
|
||||||
|
ortValue: s.ortValue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the value at the given index in the sequence. Will return an error
|
||||||
|
// if the index exceeds the number of elements in the sequence, or for any
|
||||||
|
// other reason.
|
||||||
|
//
|
||||||
|
// IMPORTANT: The Value returned by this function must be destroyed by the
|
||||||
|
// caller when it is no longer needed. This is because the GetValue C API for
|
||||||
|
// sequences internally allocates a new OrtValue. Put another way, this does
|
||||||
|
// _not_ return the same Value instance that was provided to NewSequence, even
|
||||||
|
// though, if it's a tensor, it should still refer to the same underlying data
|
||||||
|
// slice.
|
||||||
|
func (s *Sequence) GetValue(index int64) (Value, error) {
|
||||||
|
if (index < 0) || (index >= s.valueCount) {
|
||||||
|
return nil, fmt.Errorf("Invalid index (%d) into sequence of length %d",
|
||||||
|
index, s.valueCount)
|
||||||
|
}
|
||||||
|
var result *C.OrtValue
|
||||||
|
status := C.GetValue(s.ortValue, C.int(index), &result)
|
||||||
|
if status != nil {
|
||||||
|
return nil, fmt.Errorf("Error getting value of index %d: %s", index,
|
||||||
|
statusToError(status))
|
||||||
|
}
|
||||||
|
return createGoValueFromOrtValue(result)
|
||||||
|
}
|
||||||
|
|
||||||
// Wraps the ONNXType enum in C.
|
// Wraps the ONNXType enum in C.
|
||||||
type ONNXType int
|
type ONNXType int
|
||||||
|
|
||||||
@@ -1369,7 +1494,53 @@ func getShapeFromInfo(t *C.OrtTensorTypeAndShapeInfo) (Shape, error) {
|
|||||||
return shape, nil
|
return shape, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Make createTensorFromOrtValue work for non-Tensor OrtValues.
|
// Returns the ONNXType associated with a C OrtValue.
|
||||||
|
func getValueType(v *C.OrtValue) (ONNXType, error) {
|
||||||
|
var t C.enum_ONNXType
|
||||||
|
status := C.GetValueType(v, &t)
|
||||||
|
if status != nil {
|
||||||
|
return ONNXTypeUnknown, fmt.Errorf("Error looking up type for "+
|
||||||
|
"OrtValue: %s", statusToError(status))
|
||||||
|
}
|
||||||
|
return ONNXType(t), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the "count" associated with an OrtValue. Mostly useful for
|
||||||
|
// sequences. Should always return 2 for a map. Not sure what it returns for
|
||||||
|
// Tensors, but that shouldn't matter.
|
||||||
|
func getValueCount(v *C.OrtValue) (int64, error) {
|
||||||
|
var size C.size_t
|
||||||
|
status := C.GetValueCount(v, &size)
|
||||||
|
if status != nil {
|
||||||
|
return 0, fmt.Errorf("Error getting non tensor count for OrtValue: %s",
|
||||||
|
statusToError(status))
|
||||||
|
}
|
||||||
|
return int64(size), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Takes a OrtValue and returns an appropriate Go Value wrapping it. Does not
|
||||||
|
// copy the value. This function assumes that v will be destroyed later when
|
||||||
|
// the returned go Value is destroyed.
|
||||||
|
func createGoValueFromOrtValue(v *C.OrtValue) (Value, error) {
|
||||||
|
// TODO: How to handle v == nil? Is it even possible to get nil here?
|
||||||
|
valueType, e := getValueType(v)
|
||||||
|
if e != nil {
|
||||||
|
return nil, e
|
||||||
|
}
|
||||||
|
switch valueType {
|
||||||
|
case ONNXTypeTensor:
|
||||||
|
return createTensorFromOrtValue(v)
|
||||||
|
case ONNXTypeSequence:
|
||||||
|
return createSequenceFromOrtValue(v)
|
||||||
|
default:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("It is currently not supported to create a Go "+
|
||||||
|
"value from OrtValues with ONNXType = %s", valueType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must only be called if v is known to be of type ONNXTensor. Returns a Tensor
|
||||||
|
// wrapping v with the correct Go type.
|
||||||
func createTensorFromOrtValue(v *C.OrtValue) (Value, error) {
|
func createTensorFromOrtValue(v *C.OrtValue) (Value, error) {
|
||||||
var pInfo *C.OrtTensorTypeAndShapeInfo
|
var pInfo *C.OrtTensorTypeAndShapeInfo
|
||||||
status := C.GetTensorTypeAndShape(v, &pInfo)
|
status := C.GetTensorTypeAndShape(v, &pInfo)
|
||||||
@@ -1424,11 +1595,24 @@ func createTensorFromOrtValue(v *C.OrtValue) (Value, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Must only be called if v is already known to be an ONNXTypeSequence. Returns
|
||||||
|
// a Sequence go type wrapping v.
|
||||||
|
func createSequenceFromOrtValue(v *C.OrtValue) (Value, error) {
|
||||||
|
length, e := getValueCount(v)
|
||||||
|
if e != nil {
|
||||||
|
return nil, fmt.Errorf("Error determining sequence length: %w", e)
|
||||||
|
}
|
||||||
|
return &Sequence{
|
||||||
|
ortValue: v,
|
||||||
|
valueCount: length,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Runs the network on the given input and output tensors. The number of input
|
// Runs the network on the given input and output tensors. The number of input
|
||||||
// and output tensors must match the number (and order) of the input and output
|
// and output tensors must match the number (and order) of the input and output
|
||||||
// names specified to NewDynamicAdvancedSession.
|
// names specified to NewDynamicAdvancedSession. If a given output is nil, it
|
||||||
// If a given output is nil, it will be allocated and the slice will be modified
|
// will be allocated and the slice will be modified to include the new Value.
|
||||||
// to include the new tensor. The new tensor must be freed by calling Destroy on it.
|
// Any new Value allocated in this way must be freed by calling Destroy on it.
|
||||||
func (s *DynamicAdvancedSession) Run(inputs, outputs []Value) error {
|
func (s *DynamicAdvancedSession) Run(inputs, outputs []Value) error {
|
||||||
if len(inputs) != len(s.s.inputNames) {
|
if len(inputs) != len(s.s.inputNames) {
|
||||||
return fmt.Errorf("The session specified %d input names, but Run() "+
|
return fmt.Errorf("The session specified %d input names, but Run() "+
|
||||||
@@ -1446,25 +1630,30 @@ func (s *DynamicAdvancedSession) Run(inputs, outputs []Value) error {
|
|||||||
}
|
}
|
||||||
outputValues := make([]*C.OrtValue, len(outputs))
|
outputValues := make([]*C.OrtValue, len(outputs))
|
||||||
for i, v := range outputs {
|
for i, v := range outputs {
|
||||||
if v != nil {
|
if v == nil {
|
||||||
|
// Leave any output that needs to be allocated as nil.
|
||||||
|
continue
|
||||||
|
}
|
||||||
outputValues[i] = v.GetInternals().ortValue
|
outputValues[i] = v.GetInternals().ortValue
|
||||||
}
|
}
|
||||||
}
|
|
||||||
status := C.RunOrtSession(s.s.ortSession, &inputValues[0],
|
status := C.RunOrtSession(s.s.ortSession, &inputValues[0],
|
||||||
&s.s.inputNames[0], C.int(len(inputs)), &outputValues[0],
|
&s.s.inputNames[0], C.int(len(inputs)), &outputValues[0],
|
||||||
&s.s.outputNames[0], C.int(len(outputs)))
|
&s.s.outputNames[0], C.int(len(outputs)))
|
||||||
if status != nil {
|
if status != nil {
|
||||||
return fmt.Errorf("Error running network: %w", statusToError(status))
|
return fmt.Errorf("Error running network: %w", statusToError(status))
|
||||||
}
|
}
|
||||||
|
// Convert any automatically-allocated output to a go Value.
|
||||||
for i, v := range outputs {
|
for i, v := range outputs {
|
||||||
if v == nil {
|
if v != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
outputs[i], err = createTensorFromOrtValue(outputValues[i])
|
outputs[i], err = createTensorFromOrtValue(outputValues[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error creating tensor from ort: %w", err)
|
return fmt.Errorf("Error creating tensor from ort: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1480,22 +1669,65 @@ func (s *DynamicAdvancedSession) GetModelMetadata() (*ModelMetadata, error) {
|
|||||||
type InputOutputInfo struct {
|
type InputOutputInfo struct {
|
||||||
// The name of the input or output
|
// The name of the input or output
|
||||||
Name string
|
Name string
|
||||||
// The input or output's dimensions
|
// The higher-level "type" of the output; whether it's a tensor, sequence,
|
||||||
|
// map, etc.
|
||||||
|
OrtValueType ONNXType
|
||||||
|
// The input or output's dimensions, if it's a tensor. This should be
|
||||||
|
// ignored for non-tensor types.
|
||||||
Dimensions Shape
|
Dimensions Shape
|
||||||
// The input or output's data type
|
// The type of element in the input or output, if it's a tensor. This
|
||||||
|
// should be ignored for non-tensor types.
|
||||||
DataType TensorElementDataType
|
DataType TensorElementDataType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *InputOutputInfo) String() string {
|
func (n *InputOutputInfo) String() string {
|
||||||
return fmt.Sprintf("\"%s\": %s, %s", n.Name, n.Dimensions, n.DataType)
|
switch n.OrtValueType {
|
||||||
|
case ONNXTypeUnknown:
|
||||||
|
return fmt.Sprintf("Unknown ONNX type: %s", n.Name)
|
||||||
|
case ONNXTypeTensor:
|
||||||
|
return fmt.Sprintf("Tensor \"%s\": %s, %s", n.Name, n.Dimensions,
|
||||||
|
n.DataType)
|
||||||
|
case ONNXTypeSequence:
|
||||||
|
return fmt.Sprintf("Sequence \"%s\"", n.Name)
|
||||||
|
case ONNXTypeMap:
|
||||||
|
return fmt.Sprintf("Map \"%s\"", n.Name)
|
||||||
|
case ONNXTypeOpaque:
|
||||||
|
return fmt.Sprintf("Opaque \"%s\"", n.Name)
|
||||||
|
case ONNXTypeSparseTensor:
|
||||||
|
return fmt.Sprintf("Sparse tensor \"%s\": dense shape %s, %s",
|
||||||
|
n.Name, n.Dimensions, n.DataType)
|
||||||
|
case ONNXTypeOptional:
|
||||||
|
return fmt.Sprintf("Optional \"%s\"", n.Name)
|
||||||
|
default:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// We'll use the ONNXType String() output if we don't know the type.
|
||||||
|
return fmt.Sprintf("%s: \"%s\"", n.OrtValueType, n.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets o.Dimensions and o.DataType from the contents of t.
|
// Sets o.OrtValueType, o.DataType, and o.Dimensions from the contents of t.
|
||||||
func (o *InputOutputInfo) fillFromTypeInfo(t *C.OrtTypeInfo) error {
|
func (o *InputOutputInfo) fillFromTypeInfo(t *C.OrtTypeInfo) error {
|
||||||
|
var onnxType C.enum_ONNXType
|
||||||
|
status := C.GetONNXTypeFromTypeInfo(t, &onnxType)
|
||||||
|
if status != nil {
|
||||||
|
return fmt.Errorf("Error getting ONNX type: %s", statusToError(status))
|
||||||
|
}
|
||||||
|
o.OrtValueType = ONNXType(onnxType)
|
||||||
|
o.Dimensions = nil
|
||||||
|
o.DataType = TensorElementDataTypeUndefined
|
||||||
|
|
||||||
|
// We only fill in element type and dimensions if we're dealing with a
|
||||||
|
// tensor of some sort.
|
||||||
|
isTensorType := (o.OrtValueType == ONNXTypeTensor) ||
|
||||||
|
(o.OrtValueType == ONNXTypeSparseTensor)
|
||||||
|
if !isTensorType {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// OrtTensorTypeAndShapeInfo pointers should *not* be released if they're
|
// OrtTensorTypeAndShapeInfo pointers should *not* be released if they're
|
||||||
// obtained via CastTypeInfoToTensorInfo.
|
// obtained via CastTypeInfoToTensorInfo.
|
||||||
var typeAndShapeInfo *C.OrtTensorTypeAndShapeInfo
|
var typeAndShapeInfo *C.OrtTensorTypeAndShapeInfo
|
||||||
status := C.CastTypeInfoToTensorInfo(t, &typeAndShapeInfo)
|
status = C.CastTypeInfoToTensorInfo(t, &typeAndShapeInfo)
|
||||||
if status != nil {
|
if status != nil {
|
||||||
return fmt.Errorf("Error getting type and shape info: %w",
|
return fmt.Errorf("Error getting type and shape info: %w",
|
||||||
statusToError(status))
|
statusToError(status))
|
||||||
|
@@ -691,7 +691,6 @@ func TestDynamicInputOutputAxes(t *testing.T) {
|
|||||||
t.Fatalf("Error loading %s: %s\n", netPath, e)
|
t.Fatalf("Error loading %s: %s\n", netPath, e)
|
||||||
}
|
}
|
||||||
defer session.Destroy()
|
defer session.Destroy()
|
||||||
rng := rand.New(rand.NewSource(1234))
|
|
||||||
maxBatchSize := 99
|
maxBatchSize := 99
|
||||||
// The example network takes a dynamic batch size of vectors containing 10
|
// The example network takes a dynamic batch size of vectors containing 10
|
||||||
// elements each.
|
// elements each.
|
||||||
@@ -708,10 +707,7 @@ func TestDynamicInputOutputAxes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Populate the input with new random floats.
|
// Populate the input with new random floats.
|
||||||
inputData := input.GetData()
|
fillRandomFloats(input.GetData(), 1234)
|
||||||
for i := range inputData {
|
|
||||||
inputData[i] = rng.Float32()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run the session; make onnxruntime allocate the output tensor for us.
|
// Run the session; make onnxruntime allocate the output tensor for us.
|
||||||
outputs := []Value{nil}
|
outputs := []Value{nil}
|
||||||
@@ -725,6 +721,7 @@ func TestDynamicInputOutputAxes(t *testing.T) {
|
|||||||
// The checkVectorSum function will destroy the input and output tensor
|
// The checkVectorSum function will destroy the input and output tensor
|
||||||
// regardless of their correctness.
|
// regardless of their correctness.
|
||||||
checkVectorSum(input, outputs[0].(*Tensor[float32]), t)
|
checkVectorSum(input, outputs[0].(*Tensor[float32]), t)
|
||||||
|
input.Destroy()
|
||||||
t.Logf("Batch size %d seems OK!\n", i)
|
t.Logf("Batch size %d seems OK!\n", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -962,6 +959,13 @@ func randomBytes(seed, n int64) []byte {
|
|||||||
return toReturn
|
return toReturn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fillRandomFloats(dst []float32, seed int64) {
|
||||||
|
rng := rand.New(rand.NewSource(seed))
|
||||||
|
for i := range dst {
|
||||||
|
dst[i] = rng.Float32()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCustomDataTensors(t *testing.T) {
|
func TestCustomDataTensors(t *testing.T) {
|
||||||
InitializeRuntime(t)
|
InitializeRuntime(t)
|
||||||
defer CleanupRuntime(t)
|
defer CleanupRuntime(t)
|
||||||
@@ -1091,6 +1095,128 @@ func TestFloat16Network(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a 10-element tensor randomly filled values using the given rng seed.
|
||||||
|
func randomSmallTensor(seed int64, t testing.TB) *Tensor[float32] {
|
||||||
|
toReturn, e := NewEmptyTensor[float32](NewShape(10))
|
||||||
|
if e != nil {
|
||||||
|
t.Fatalf("Error creating small tensor: %s\n", e)
|
||||||
|
}
|
||||||
|
fillRandomFloats(toReturn.GetData(), seed)
|
||||||
|
return toReturn
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestONNXSequence(t *testing.T) {
|
||||||
|
InitializeRuntime(t)
|
||||||
|
defer CleanupRuntime(t)
|
||||||
|
sequenceLength := int64(123)
|
||||||
|
|
||||||
|
values := make([]Value, sequenceLength)
|
||||||
|
for i := range values {
|
||||||
|
values[i] = randomSmallTensor(int64(i)+123, t)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
for _, v := range values {
|
||||||
|
v.Destroy()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
sequence, e := NewSequence(values)
|
||||||
|
if e != nil {
|
||||||
|
t.Fatalf("Error creating sequence: %s\n", e)
|
||||||
|
}
|
||||||
|
defer sequence.Destroy()
|
||||||
|
if int64(sequence.GetValueCount()) != sequenceLength {
|
||||||
|
t.Fatalf("Got %d values in sequence, expected %d\n",
|
||||||
|
sequence.GetValueCount(), sequenceLength)
|
||||||
|
}
|
||||||
|
if sequence.GetONNXType() != ONNXTypeSequence {
|
||||||
|
t.Fatalf("Got incorrect ONNX type for sequence: %s\n",
|
||||||
|
sequence.GetONNXType())
|
||||||
|
}
|
||||||
|
// Not guaranteed by onnxruntime, but it _is_ something that I wrote in the
|
||||||
|
// docs.
|
||||||
|
if !sequence.GetShape().Equals(NewShape(sequenceLength)) {
|
||||||
|
t.Fatalf("Sequence.GetShape() returned incorrect shape: %s\n",
|
||||||
|
sequence.GetShape())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, e = sequence.GetValue(9999)
|
||||||
|
if e == nil {
|
||||||
|
t.Fatalf("Did not get an error when accessing out of a " +
|
||||||
|
"sequence's bounds\n")
|
||||||
|
}
|
||||||
|
t.Logf("Got expected error when accessing out of bounds: %s\n", e)
|
||||||
|
_, e = sequence.GetValue(-1)
|
||||||
|
if e == nil {
|
||||||
|
t.Fatalf("Did not get an error when accessing a negative index\n")
|
||||||
|
}
|
||||||
|
t.Logf("Got expected error when accessing a negative index: %s\n", e)
|
||||||
|
|
||||||
|
// We know from the C API docs that this needs to be destroyed
|
||||||
|
selectedIndex := int64(44)
|
||||||
|
selectedValue, e := sequence.GetValue(selectedIndex)
|
||||||
|
if e != nil {
|
||||||
|
t.Fatalf("Error getting sequence value at index %d: %s\n",
|
||||||
|
selectedIndex, e)
|
||||||
|
}
|
||||||
|
defer selectedValue.Destroy()
|
||||||
|
|
||||||
|
if selectedValue.GetONNXType() != ONNXTypeTensor {
|
||||||
|
t.Fatalf("Got incorrect ONNXType for value at index %d: "+
|
||||||
|
"expected %s, got %s\n", selectedIndex, ONNXType(ONNXTypeTensor),
|
||||||
|
selectedValue.GetONNXType())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBadSequences(t *testing.T) {
|
||||||
|
InitializeRuntime(t)
|
||||||
|
defer CleanupRuntime(t)
|
||||||
|
|
||||||
|
// Sequences containing no elements or nil entries shouldn't be allowed
|
||||||
|
_, e := NewSequence([]Value{})
|
||||||
|
if e == nil {
|
||||||
|
t.Fatalf("Didn't get expected error when creating an empty sequence\n")
|
||||||
|
}
|
||||||
|
t.Logf("Got expected error when creating an empty sequence: %s\n", e)
|
||||||
|
_, e = NewSequence([]Value{nil})
|
||||||
|
if e == nil {
|
||||||
|
t.Fatalf("Didn't get expected error when creating sequence with a " +
|
||||||
|
"nil entry.\n")
|
||||||
|
}
|
||||||
|
t.Logf("Got expected error when creating sequence with nil entry: %s\n", e)
|
||||||
|
|
||||||
|
// Sequences containing mixed data types shouldn't be allowed
|
||||||
|
tensor := randomSmallTensor(1337, t)
|
||||||
|
defer tensor.Destroy()
|
||||||
|
innerSequence, e := NewSequence([]Value{tensor})
|
||||||
|
if e != nil {
|
||||||
|
t.Fatalf("Error creating 1-element sequence: %s\n", e)
|
||||||
|
}
|
||||||
|
defer innerSequence.Destroy()
|
||||||
|
_, e = NewSequence([]Value{tensor, innerSequence})
|
||||||
|
if e == nil {
|
||||||
|
t.Fatalf("Didn't get expected error when attempting to create a "+
|
||||||
|
"mixed sequence: %s\n", e)
|
||||||
|
}
|
||||||
|
t.Logf("Got expected error when attempting a mixed sequence: %s\n", e)
|
||||||
|
|
||||||
|
// Nested sequences also aren't allowed; the C API docs don't seem to
|
||||||
|
// mention this either.
|
||||||
|
_, e = NewSequence([]Value{innerSequence, innerSequence})
|
||||||
|
if e == nil {
|
||||||
|
t.Fatalf("Didn't get an error creating a sequence with nested " +
|
||||||
|
"sequences.\n")
|
||||||
|
}
|
||||||
|
t.Logf("Got expected error when creating a sequence with nested "+
|
||||||
|
"sequences: %s\n", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSklearnNetwork(t *testing.T) {
|
||||||
|
// TODO: TestSklearnNetwork
|
||||||
|
// - One of its outputs is a sequence of maps, make sure it works.
|
||||||
|
// - Check the values printed by the python script for test inputs and
|
||||||
|
// expected outputs.
|
||||||
|
}
|
||||||
|
|
||||||
// See the comment in generate_network_big_compute.py for information about
|
// See the comment in generate_network_big_compute.py for information about
|
||||||
// the inputs and outputs used for testing or benchmarking session options.
|
// the inputs and outputs used for testing or benchmarking session options.
|
||||||
func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32],
|
func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32],
|
||||||
|
@@ -289,6 +289,10 @@ void ReleaseTypeInfo(OrtTypeInfo *o) {
|
|||||||
ort_api->ReleaseTypeInfo(o);
|
ort_api->ReleaseTypeInfo(o);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OrtStatus *GetONNXTypeFromTypeInfo(OrtTypeInfo *info, enum ONNXType *out) {
|
||||||
|
return ort_api->GetOnnxTypeFromTypeInfo(info, out);
|
||||||
|
}
|
||||||
|
|
||||||
OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info,
|
OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info,
|
||||||
OrtTensorTypeAndShapeInfo **out) {
|
OrtTensorTypeAndShapeInfo **out) {
|
||||||
return ort_api->CastTypeInfoToTensorInfo(type_info,
|
return ort_api->CastTypeInfoToTensorInfo(type_info,
|
||||||
@@ -359,6 +363,28 @@ OrtStatus *ModelMetadataGetVersion(OrtModelMetadata *m, int64_t *version) {
|
|||||||
return ort_api->ModelMetadataGetVersion(m, version);
|
return ort_api->ModelMetadataGetVersion(m, version);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OrtStatus *GetValue(OrtValue *container, int index, OrtValue **dst) {
|
||||||
|
OrtAllocator *allocator = NULL;
|
||||||
|
OrtStatus *status = NULL;
|
||||||
|
status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
|
||||||
|
if (status) return status;
|
||||||
|
return ort_api->GetValue(container, index, allocator, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtStatus *GetValueType(OrtValue *v, enum ONNXType *out) {
|
||||||
|
return ort_api->GetValueType(v, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtStatus *GetValueCount(OrtValue *v, size_t *out) {
|
||||||
|
return ort_api->GetValueCount(v, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
|
||||||
|
enum ONNXType value_type, OrtValue **out) {
|
||||||
|
return ort_api->CreateValue((const OrtValue* const*) in, num_values,
|
||||||
|
value_type, out);
|
||||||
|
}
|
||||||
|
|
||||||
// TRAINING API WRAPPER
|
// TRAINING API WRAPPER
|
||||||
|
|
||||||
static const OrtTrainingApi *ort_training_api = NULL;
|
static const OrtTrainingApi *ort_training_api = NULL;
|
||||||
@@ -504,3 +530,4 @@ void ReleaseOrtTrainingSession(OrtTrainingSession *session) {
|
|||||||
void ReleaseCheckpointState(OrtCheckpointState *checkpoint) {
|
void ReleaseCheckpointState(OrtCheckpointState *checkpoint) {
|
||||||
ort_training_api->ReleaseCheckpointState(checkpoint);
|
ort_training_api->ReleaseCheckpointState(checkpoint);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -198,6 +198,9 @@ OrtStatus *SessionGetOutputTypeInfo(OrtSession *session, size_t i,
|
|||||||
OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info,
|
OrtStatus *CastTypeInfoToTensorInfo(OrtTypeInfo *type_info,
|
||||||
OrtTensorTypeAndShapeInfo **out);
|
OrtTensorTypeAndShapeInfo **out);
|
||||||
|
|
||||||
|
// Wraps ort_api->GetOnnxTypeFromTypeInfo.
|
||||||
|
OrtStatus *GetONNXTypeFromTypeInfo(OrtTypeInfo *info, enum ONNXType *out);
|
||||||
|
|
||||||
// Wraps ort_api->FreeTypeInfo.
|
// Wraps ort_api->FreeTypeInfo.
|
||||||
void ReleaseTypeInfo(OrtTypeInfo *o);
|
void ReleaseTypeInfo(OrtTypeInfo *o);
|
||||||
|
|
||||||
@@ -232,6 +235,19 @@ OrtStatus *ModelMetadataGetCustomMetadataMapKeys(OrtModelMetadata *m,
|
|||||||
// Wraps ort_api->ModelMetadataGetVersion.
|
// Wraps ort_api->ModelMetadataGetVersion.
|
||||||
OrtStatus *ModelMetadataGetVersion(OrtModelMetadata *m, int64_t *version);
|
OrtStatus *ModelMetadataGetVersion(OrtModelMetadata *m, int64_t *version);
|
||||||
|
|
||||||
|
// Wraps ort_api->GetValue. Uses the default allocator.
|
||||||
|
OrtStatus *GetValue(OrtValue *container, int index, OrtValue **dst);
|
||||||
|
|
||||||
|
// Wraps ort_api->GetValueType.
|
||||||
|
OrtStatus *GetValueType(OrtValue *v, enum ONNXType *out);
|
||||||
|
|
||||||
|
// Wraps ort_api->GetValueCount.
|
||||||
|
OrtStatus *GetValueCount(OrtValue *v, size_t *out);
|
||||||
|
|
||||||
|
// Wraps ort_api->CreateValue to create a map or a sequence.
|
||||||
|
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
|
||||||
|
enum ONNXType value_type, OrtValue **out);
|
||||||
|
|
||||||
// TRAINING API WRAPPER
|
// TRAINING API WRAPPER
|
||||||
|
|
||||||
void SetTrainingApi();
|
void SetTrainingApi();
|
||||||
@@ -241,11 +257,12 @@ int IsTrainingApiSupported();
|
|||||||
|
|
||||||
// Wraps ort_training_api->CreateSessionFromBuffer.
|
// Wraps ort_training_api->CreateSessionFromBuffer.
|
||||||
// Creates and ORT checkpoint from the checkpoint data.
|
// Creates and ORT checkpoint from the checkpoint data.
|
||||||
OrtStatus *CreateCheckpoint(void *checkpoint_data, size_t checkpoint_data_length, OrtCheckpointState **out);
|
OrtStatus *CreateCheckpoint(void *checkpoint_data,
|
||||||
|
size_t checkpoint_data_length, OrtCheckpointState **out);
|
||||||
|
|
||||||
// Wraps ort_training_api->CreateTrainingSessionFromBuffer.
|
// Wraps ort_training_api->CreateTrainingSessionFromBuffer. Creates an ORT
|
||||||
// Creates an ORT training session using the given models and checkpoint.
|
// training session using the given models and checkpoint. The given options
|
||||||
// The given options pointer may be NULL; if it is, then we'll use default options.
|
// pointer may be NULL; if it is, then we'll use default options.
|
||||||
OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state,
|
OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state,
|
||||||
void *training_model_data, size_t training_model_data_length,
|
void *training_model_data, size_t training_model_data_length,
|
||||||
void *eval_model_data, size_t eval_model_data_length,
|
void *eval_model_data, size_t eval_model_data_length,
|
||||||
|
45
test_data/generate_sklearn_network.py
Normal file
45
test_data/generate_sklearn_network.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# This script is a modified version of the example from
|
||||||
|
# https://pypi.org/project/skl2onnx/, which we use to produce
|
||||||
|
# sklearn_randomforest.onnx. sklearn makes heavy use of onnxruntime maps and
|
||||||
|
# sequences in its networks, so this is used for testing those data types.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
|
||||||
|
iris = load_iris()
|
||||||
|
inputs, outputs = iris.data, iris.target
|
||||||
|
inputs = inputs.astype(np.float32)
|
||||||
|
inputs_train, inputs_test, outputs_train, outputs_test = train_test_split(inputs, outputs)
|
||||||
|
classifier = RandomForestClassifier()
|
||||||
|
classifier.fit(inputs_train, outputs_train)
|
||||||
|
|
||||||
|
# Convert into ONNX format.
|
||||||
|
from skl2onnx import to_onnx
|
||||||
|
output_filename = "sklearn_randomforest.onnx"
|
||||||
|
onnx_content = to_onnx(classifier, inputs[:1])
|
||||||
|
with open(output_filename, "wb") as f:
|
||||||
|
f.write(onnx_content.SerializeToString())
|
||||||
|
|
||||||
|
# Compute the prediction with onnxruntime.
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
def float_formatter(f):
|
||||||
|
return f"{float(f):.06f}"
|
||||||
|
|
||||||
|
np.set_printoptions(formatter = {'float_kind': float_formatter})
|
||||||
|
session = ort.InferenceSession(output_filename)
|
||||||
|
print(f"Input names: {[n.name for n in session.get_inputs()]!s}")
|
||||||
|
print(f"Output names: {[o.name for o in session.get_outputs()]!s}")
|
||||||
|
example_inputs = inputs_test.astype(np.float32)[:6]
|
||||||
|
print(f"Inputs shape = {example_inputs.shape!s}")
|
||||||
|
onnx_predictions = session.run(["output_label", "output_probability"],
|
||||||
|
{"X": example_inputs})
|
||||||
|
labels = onnx_predictions[0]
|
||||||
|
probabilities = onnx_predictions[1]
|
||||||
|
|
||||||
|
print(f"Inputs to network: {example_inputs.astype(np.float32)}")
|
||||||
|
print(f"ONNX predicted labels: {labels!s}")
|
||||||
|
print(f"ONNX predicted probabilities: {probabilities!s}")
|
||||||
|
|
BIN
test_data/sklearn_randomforest.onnx
Normal file
BIN
test_data/sklearn_randomforest.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user