Check for valid shapes more thoroughly

- This change adds several more detailed error types that may be
   returned when attempting to work with shapes that are invalid for
   some reason.

 - Added the Shape.Validate() function, which may be called by users,
   but is internally called when creating new tensors.
This commit is contained in:
yalue
2023-05-11 10:11:39 -04:00
parent 1acf4f2a2e
commit ca658dac00
2 changed files with 111 additions and 5 deletions

View File

@@ -28,6 +28,23 @@ var ortMemoryInfo *C.OrtMemoryInfo
var NotInitializedError error = fmt.Errorf("InitializeRuntime() has either " +
"not yet been called, or did not return successfully")
var ZeroShapeLengthError error = fmt.Errorf("The shape has no dimensions")
var ShapeOverflowError error = fmt.Errorf("The shape's flattened size " +
"overflows an int64")
// This type of error is returned when we attempt to validate a tensor that has
// a negative or 0 dimension.
type BadShapeDimensionError struct {
DimensionIndex int
DimensionSize int64
}
func (e *BadShapeDimensionError) Error() string {
return fmt.Sprintf("Dimension %d of the shape has invalid value %d",
e.DimensionIndex, e.DimensionSize)
}
// Does two things: converts the given OrtStatus to a Go error, and releases
// the status. If the status is nil, this does nothing and returns nil.
func statusToError(status *C.OrtStatus) error {
@@ -121,7 +138,10 @@ func NewShape(dimensions ...int64) Shape {
return Shape(dimensions)
}
// Returns the total number of elements in a tensor with the given shape.
// Returns the total number of elements in a tensor with the given shape. Note
// that this may be an invalid value due to overflow or negative dimensions. If
// a shape comes from an untrusted source, it may be a good practice to call
// Validate() prior to trusting the FlattenedSize.
func (s Shape) FlattenedSize() int64 {
if len(s) == 0 {
return 0
@@ -133,6 +153,38 @@ func (s Shape) FlattenedSize() int64 {
return toReturn
}
// Returns a non-nil error if the shape has bad or zero dimensions. May return
// a ZeroShapeLengthError, a ShapeOverflowError, or an BadShapeDimensionError.
// In the future, this may return other types of errors if it others become
// necessary.
func (s Shape) Validate() error {
if len(s) == 0 {
return ZeroShapeLengthError
}
if s[0] <= 0 {
return &BadShapeDimensionError{
DimensionIndex: 0,
DimensionSize: s[0],
}
}
flattenedSize := int64(s[0])
for i := 1; i < len(s); i++ {
d := s[i]
if d <= 0 {
return &BadShapeDimensionError{
DimensionIndex: i,
DimensionSize: d,
}
}
tmp := flattenedSize * d
if tmp < flattenedSize {
return ShapeOverflowError
}
flattenedSize = tmp
}
return nil
}
// Makes and returns a deep copy of the Shape.
func (s Shape) Clone() Shape {
toReturn := make([]int64, len(s))
@@ -205,10 +257,11 @@ func (t *Tensor[T]) Clone() (*Tensor[T], error) {
// Creates a new empty tensor with the given shape. The shape provided to this
// function is copied, and is no longer needed after this function returns.
func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) {
elementCount := s.FlattenedSize()
if elementCount == 0 {
return nil, fmt.Errorf("Got invalid shape containing 0 elements")
e := s.Validate()
if e != nil {
return nil, fmt.Errorf("Invalid tensor shape: %w", e)
}
elementCount := s.FlattenedSize()
data := make([]T, elementCount)
return NewTensor(s, data)
}
@@ -221,7 +274,10 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
if !IsInitialized() {
return nil, NotInitializedError
}
e := s.Validate()
if e != nil {
return nil, fmt.Errorf("Invalid tensor shape: %w", e)
}
elementCount := s.FlattenedSize()
if elementCount > int64(len(data)) {
return nil, fmt.Errorf("The tensor's shape (%s) requires %d "+

View File

@@ -136,6 +136,56 @@ func TestCreateTensor(t *testing.T) {
}
}
func TestBadTensorShapes(t *testing.T) {
InitializeRuntime(t)
defer DestroyEnvironment()
s := NewShape()
_, e := NewEmptyTensor[float64](s)
if e == nil {
t.Logf("Didn't get an error when creating a tensor with an empty " +
"shape.\n")
t.FailNow()
}
t.Logf("Got expected error when creating a tensor with an empty shape: "+
"%s\n", e)
s = NewShape(10, 0, 10)
_, e = NewEmptyTensor[uint16](s)
if e == nil {
t.Logf("Didn't get an error when creating a tensor with a shape " +
"containing a 0 dimension.\n")
t.FailNow()
}
t.Logf("Got expected error when creating a tensor with a 0 dimension: "+
"%s\n", e)
s = NewShape(10, 10, -10)
_, e = NewEmptyTensor[int32](s)
if e == nil {
t.Logf("Didn't get an error when creating a tensor with a negative " +
"dimension.\n")
t.FailNow()
}
t.Logf("Got expected error when creating a tensor with a negative "+
"dimension: %s\n", e)
s = NewShape(10, -10, -10)
_, e = NewEmptyTensor[uint64](s)
if e == nil {
t.Logf("Didn't get an error when creating a tensor with two " +
"negative dimensions.\n")
t.FailNow()
}
t.Logf("Got expected error when creating a tensor with two negative "+
"dimensions: %s\n", e)
s = NewShape(int64(1)<<62, 1, int64(1)<<62)
_, e = NewEmptyTensor[float32](s)
if e == nil {
t.Logf("Didn't get an error when creating a tensor with an " +
"overflowing shape.\n")
t.FailNow()
}
t.Logf("Got expected error when creating a tensor with an overflowing "+
"shape: %s\n", e)
}
func TestCloneTensor(t *testing.T) {
InitializeRuntime(t)
originalData := []float32{1, 2, 3, 4}