mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-29 01:42:27 +08:00
Check for valid shapes more thoroughly
- This change adds several more detailed error types that may be returned when attempting to work with shapes that are invalid for some reason. - Added the Shape.Validate() function, which may be called by users, but is internally called when creating new tensors.
This commit is contained in:
@@ -136,6 +136,56 @@ func TestCreateTensor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadTensorShapes(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer DestroyEnvironment()
|
||||
s := NewShape()
|
||||
_, e := NewEmptyTensor[float64](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with an empty " +
|
||||
"shape.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with an empty shape: "+
|
||||
"%s\n", e)
|
||||
s = NewShape(10, 0, 10)
|
||||
_, e = NewEmptyTensor[uint16](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with a shape " +
|
||||
"containing a 0 dimension.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with a 0 dimension: "+
|
||||
"%s\n", e)
|
||||
s = NewShape(10, 10, -10)
|
||||
_, e = NewEmptyTensor[int32](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with a negative " +
|
||||
"dimension.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with a negative "+
|
||||
"dimension: %s\n", e)
|
||||
s = NewShape(10, -10, -10)
|
||||
_, e = NewEmptyTensor[uint64](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with two " +
|
||||
"negative dimensions.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with two negative "+
|
||||
"dimensions: %s\n", e)
|
||||
s = NewShape(int64(1)<<62, 1, int64(1)<<62)
|
||||
_, e = NewEmptyTensor[float32](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with an " +
|
||||
"overflowing shape.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with an overflowing "+
|
||||
"shape: %s\n", e)
|
||||
}
|
||||
|
||||
func TestCloneTensor(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
originalData := []float32{1, 2, 3, 4}
|
||||
|
||||
Reference in New Issue
Block a user