diff --git a/onnxruntime_go.go b/onnxruntime_go.go index 2e2acfe..fb6666a 100644 --- a/onnxruntime_go.go +++ b/onnxruntime_go.go @@ -28,6 +28,23 @@ var ortMemoryInfo *C.OrtMemoryInfo var NotInitializedError error = fmt.Errorf("InitializeRuntime() has either " + "not yet been called, or did not return successfully") +var ZeroShapeLengthError error = fmt.Errorf("The shape has no dimensions") + +var ShapeOverflowError error = fmt.Errorf("The shape's flattened size " + + "overflows an int64") + +// This type of error is returned when we attempt to validate a tensor that has +// a negative or 0 dimension. +type BadShapeDimensionError struct { + DimensionIndex int + DimensionSize int64 +} + +func (e *BadShapeDimensionError) Error() string { + return fmt.Sprintf("Dimension %d of the shape has invalid value %d", + e.DimensionIndex, e.DimensionSize) +} + // Does two things: converts the given OrtStatus to a Go error, and releases // the status. If the status is nil, this does nothing and returns nil. func statusToError(status *C.OrtStatus) error { @@ -121,7 +138,10 @@ func NewShape(dimensions ...int64) Shape { return Shape(dimensions) } -// Returns the total number of elements in a tensor with the given shape. +// Returns the total number of elements in a tensor with the given shape. Note +// that this may be an invalid value due to overflow or negative dimensions. If +// a shape comes from an untrusted source, it may be a good practice to call +// Validate() prior to trusting the FlattenedSize. func (s Shape) FlattenedSize() int64 { if len(s) == 0 { return 0 @@ -133,6 +153,38 @@ func (s Shape) FlattenedSize() int64 { return toReturn } +// Returns a non-nil error if the shape has bad or zero dimensions. May return +// a ZeroShapeLengthError, a ShapeOverflowError, or an BadShapeDimensionError. +// In the future, this may return other types of errors if it others become +// necessary. +func (s Shape) Validate() error { + if len(s) == 0 { + return ZeroShapeLengthError + } + if s[0] <= 0 { + return &BadShapeDimensionError{ + DimensionIndex: 0, + DimensionSize: s[0], + } + } + flattenedSize := int64(s[0]) + for i := 1; i < len(s); i++ { + d := s[i] + if d <= 0 { + return &BadShapeDimensionError{ + DimensionIndex: i, + DimensionSize: d, + } + } + tmp := flattenedSize * d + if tmp < flattenedSize { + return ShapeOverflowError + } + flattenedSize = tmp + } + return nil +} + // Makes and returns a deep copy of the Shape. func (s Shape) Clone() Shape { toReturn := make([]int64, len(s)) @@ -205,10 +257,11 @@ func (t *Tensor[T]) Clone() (*Tensor[T], error) { // Creates a new empty tensor with the given shape. The shape provided to this // function is copied, and is no longer needed after this function returns. func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) { - elementCount := s.FlattenedSize() - if elementCount == 0 { - return nil, fmt.Errorf("Got invalid shape containing 0 elements") + e := s.Validate() + if e != nil { + return nil, fmt.Errorf("Invalid tensor shape: %w", e) } + elementCount := s.FlattenedSize() data := make([]T, elementCount) return NewTensor(s, data) } @@ -221,7 +274,10 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) { if !IsInitialized() { return nil, NotInitializedError } - + e := s.Validate() + if e != nil { + return nil, fmt.Errorf("Invalid tensor shape: %w", e) + } elementCount := s.FlattenedSize() if elementCount > int64(len(data)) { return nil, fmt.Errorf("The tensor's shape (%s) requires %d "+ diff --git a/onnxruntime_test.go b/onnxruntime_test.go index 474df99..804e1a9 100644 --- a/onnxruntime_test.go +++ b/onnxruntime_test.go @@ -136,6 +136,56 @@ func TestCreateTensor(t *testing.T) { } } +func TestBadTensorShapes(t *testing.T) { + InitializeRuntime(t) + defer DestroyEnvironment() + s := NewShape() + _, e := NewEmptyTensor[float64](s) + if e == nil { + t.Logf("Didn't get an error when creating a tensor with an empty " + + "shape.\n") + t.FailNow() + } + t.Logf("Got expected error when creating a tensor with an empty shape: "+ + "%s\n", e) + s = NewShape(10, 0, 10) + _, e = NewEmptyTensor[uint16](s) + if e == nil { + t.Logf("Didn't get an error when creating a tensor with a shape " + + "containing a 0 dimension.\n") + t.FailNow() + } + t.Logf("Got expected error when creating a tensor with a 0 dimension: "+ + "%s\n", e) + s = NewShape(10, 10, -10) + _, e = NewEmptyTensor[int32](s) + if e == nil { + t.Logf("Didn't get an error when creating a tensor with a negative " + + "dimension.\n") + t.FailNow() + } + t.Logf("Got expected error when creating a tensor with a negative "+ + "dimension: %s\n", e) + s = NewShape(10, -10, -10) + _, e = NewEmptyTensor[uint64](s) + if e == nil { + t.Logf("Didn't get an error when creating a tensor with two " + + "negative dimensions.\n") + t.FailNow() + } + t.Logf("Got expected error when creating a tensor with two negative "+ + "dimensions: %s\n", e) + s = NewShape(int64(1)<<62, 1, int64(1)<<62) + _, e = NewEmptyTensor[float32](s) + if e == nil { + t.Logf("Didn't get an error when creating a tensor with an " + + "overflowing shape.\n") + t.FailNow() + } + t.Logf("Got expected error when creating a tensor with an overflowing "+ + "shape: %s\n", e) +} + func TestCloneTensor(t *testing.T) { InitializeRuntime(t) originalData := []float32{1, 2, 3, 4}