mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-29 01:42:27 +08:00
Initial Sequence support
- This adds support for ONNX sequences, along with basic unit tests for the sequence types. - This also introduces a new test network generated via sklearn in an included script. - The test using sklearn_randomforest.onnx has not been written yet, but I wanted to commit after doing all the work so far. The existing tests all pass.
This commit is contained in:
@@ -691,7 +691,6 @@ func TestDynamicInputOutputAxes(t *testing.T) {
|
||||
t.Fatalf("Error loading %s: %s\n", netPath, e)
|
||||
}
|
||||
defer session.Destroy()
|
||||
rng := rand.New(rand.NewSource(1234))
|
||||
maxBatchSize := 99
|
||||
// The example network takes a dynamic batch size of vectors containing 10
|
||||
// elements each.
|
||||
@@ -708,10 +707,7 @@ func TestDynamicInputOutputAxes(t *testing.T) {
|
||||
}
|
||||
|
||||
// Populate the input with new random floats.
|
||||
inputData := input.GetData()
|
||||
for i := range inputData {
|
||||
inputData[i] = rng.Float32()
|
||||
}
|
||||
fillRandomFloats(input.GetData(), 1234)
|
||||
|
||||
// Run the session; make onnxruntime allocate the output tensor for us.
|
||||
outputs := []Value{nil}
|
||||
@@ -725,6 +721,7 @@ func TestDynamicInputOutputAxes(t *testing.T) {
|
||||
// The checkVectorSum function will destroy the input and output tensor
|
||||
// regardless of their correctness.
|
||||
checkVectorSum(input, outputs[0].(*Tensor[float32]), t)
|
||||
input.Destroy()
|
||||
t.Logf("Batch size %d seems OK!\n", i)
|
||||
}
|
||||
}
|
||||
@@ -962,6 +959,13 @@ func randomBytes(seed, n int64) []byte {
|
||||
return toReturn
|
||||
}
|
||||
|
||||
func fillRandomFloats(dst []float32, seed int64) {
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
for i := range dst {
|
||||
dst[i] = rng.Float32()
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomDataTensors(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
@@ -1091,6 +1095,128 @@ func TestFloat16Network(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a 10-element tensor randomly filled values using the given rng seed.
|
||||
func randomSmallTensor(seed int64, t testing.TB) *Tensor[float32] {
|
||||
toReturn, e := NewEmptyTensor[float32](NewShape(10))
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating small tensor: %s\n", e)
|
||||
}
|
||||
fillRandomFloats(toReturn.GetData(), seed)
|
||||
return toReturn
|
||||
}
|
||||
|
||||
func TestONNXSequence(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
sequenceLength := int64(123)
|
||||
|
||||
values := make([]Value, sequenceLength)
|
||||
for i := range values {
|
||||
values[i] = randomSmallTensor(int64(i)+123, t)
|
||||
}
|
||||
defer func() {
|
||||
for _, v := range values {
|
||||
v.Destroy()
|
||||
}
|
||||
}()
|
||||
sequence, e := NewSequence(values)
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating sequence: %s\n", e)
|
||||
}
|
||||
defer sequence.Destroy()
|
||||
if int64(sequence.GetValueCount()) != sequenceLength {
|
||||
t.Fatalf("Got %d values in sequence, expected %d\n",
|
||||
sequence.GetValueCount(), sequenceLength)
|
||||
}
|
||||
if sequence.GetONNXType() != ONNXTypeSequence {
|
||||
t.Fatalf("Got incorrect ONNX type for sequence: %s\n",
|
||||
sequence.GetONNXType())
|
||||
}
|
||||
// Not guaranteed by onnxruntime, but it _is_ something that I wrote in the
|
||||
// docs.
|
||||
if !sequence.GetShape().Equals(NewShape(sequenceLength)) {
|
||||
t.Fatalf("Sequence.GetShape() returned incorrect shape: %s\n",
|
||||
sequence.GetShape())
|
||||
}
|
||||
|
||||
_, e = sequence.GetValue(9999)
|
||||
if e == nil {
|
||||
t.Fatalf("Did not get an error when accessing out of a " +
|
||||
"sequence's bounds\n")
|
||||
}
|
||||
t.Logf("Got expected error when accessing out of bounds: %s\n", e)
|
||||
_, e = sequence.GetValue(-1)
|
||||
if e == nil {
|
||||
t.Fatalf("Did not get an error when accessing a negative index\n")
|
||||
}
|
||||
t.Logf("Got expected error when accessing a negative index: %s\n", e)
|
||||
|
||||
// We know from the C API docs that this needs to be destroyed
|
||||
selectedIndex := int64(44)
|
||||
selectedValue, e := sequence.GetValue(selectedIndex)
|
||||
if e != nil {
|
||||
t.Fatalf("Error getting sequence value at index %d: %s\n",
|
||||
selectedIndex, e)
|
||||
}
|
||||
defer selectedValue.Destroy()
|
||||
|
||||
if selectedValue.GetONNXType() != ONNXTypeTensor {
|
||||
t.Fatalf("Got incorrect ONNXType for value at index %d: "+
|
||||
"expected %s, got %s\n", selectedIndex, ONNXType(ONNXTypeTensor),
|
||||
selectedValue.GetONNXType())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadSequences(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
// Sequences containing no elements or nil entries shouldn't be allowed
|
||||
_, e := NewSequence([]Value{})
|
||||
if e == nil {
|
||||
t.Fatalf("Didn't get expected error when creating an empty sequence\n")
|
||||
}
|
||||
t.Logf("Got expected error when creating an empty sequence: %s\n", e)
|
||||
_, e = NewSequence([]Value{nil})
|
||||
if e == nil {
|
||||
t.Fatalf("Didn't get expected error when creating sequence with a " +
|
||||
"nil entry.\n")
|
||||
}
|
||||
t.Logf("Got expected error when creating sequence with nil entry: %s\n", e)
|
||||
|
||||
// Sequences containing mixed data types shouldn't be allowed
|
||||
tensor := randomSmallTensor(1337, t)
|
||||
defer tensor.Destroy()
|
||||
innerSequence, e := NewSequence([]Value{tensor})
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating 1-element sequence: %s\n", e)
|
||||
}
|
||||
defer innerSequence.Destroy()
|
||||
_, e = NewSequence([]Value{tensor, innerSequence})
|
||||
if e == nil {
|
||||
t.Fatalf("Didn't get expected error when attempting to create a "+
|
||||
"mixed sequence: %s\n", e)
|
||||
}
|
||||
t.Logf("Got expected error when attempting a mixed sequence: %s\n", e)
|
||||
|
||||
// Nested sequences also aren't allowed; the C API docs don't seem to
|
||||
// mention this either.
|
||||
_, e = NewSequence([]Value{innerSequence, innerSequence})
|
||||
if e == nil {
|
||||
t.Fatalf("Didn't get an error creating a sequence with nested " +
|
||||
"sequences.\n")
|
||||
}
|
||||
t.Logf("Got expected error when creating a sequence with nested "+
|
||||
"sequences: %s\n", e)
|
||||
}
|
||||
|
||||
func TestSklearnNetwork(t *testing.T) {
|
||||
// TODO: TestSklearnNetwork
|
||||
// - One of its outputs is a sequence of maps, make sure it works.
|
||||
// - Check the values printed by the python script for test inputs and
|
||||
// expected outputs.
|
||||
}
|
||||
|
||||
// See the comment in generate_network_big_compute.py for information about
|
||||
// the inputs and outputs used for testing or benchmarking session options.
|
||||
func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32],
|
||||
|
||||
Reference in New Issue
Block a user