mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-29 09:52:33 +08:00
Check for valid shapes more thoroughly
- This change adds several more detailed error types that may be returned when attempting to work with shapes that are invalid for some reason. - Added the Shape.Validate() function, which may be called by users, but is internally called when creating new tensors.
This commit is contained in:
@@ -28,6 +28,23 @@ var ortMemoryInfo *C.OrtMemoryInfo
|
||||
var NotInitializedError error = fmt.Errorf("InitializeRuntime() has either " +
|
||||
"not yet been called, or did not return successfully")
|
||||
|
||||
var ZeroShapeLengthError error = fmt.Errorf("The shape has no dimensions")
|
||||
|
||||
var ShapeOverflowError error = fmt.Errorf("The shape's flattened size " +
|
||||
"overflows an int64")
|
||||
|
||||
// This type of error is returned when we attempt to validate a tensor that has
|
||||
// a negative or 0 dimension.
|
||||
type BadShapeDimensionError struct {
|
||||
DimensionIndex int
|
||||
DimensionSize int64
|
||||
}
|
||||
|
||||
func (e *BadShapeDimensionError) Error() string {
|
||||
return fmt.Sprintf("Dimension %d of the shape has invalid value %d",
|
||||
e.DimensionIndex, e.DimensionSize)
|
||||
}
|
||||
|
||||
// Does two things: converts the given OrtStatus to a Go error, and releases
|
||||
// the status. If the status is nil, this does nothing and returns nil.
|
||||
func statusToError(status *C.OrtStatus) error {
|
||||
@@ -121,7 +138,10 @@ func NewShape(dimensions ...int64) Shape {
|
||||
return Shape(dimensions)
|
||||
}
|
||||
|
||||
// Returns the total number of elements in a tensor with the given shape.
|
||||
// Returns the total number of elements in a tensor with the given shape. Note
|
||||
// that this may be an invalid value due to overflow or negative dimensions. If
|
||||
// a shape comes from an untrusted source, it may be a good practice to call
|
||||
// Validate() prior to trusting the FlattenedSize.
|
||||
func (s Shape) FlattenedSize() int64 {
|
||||
if len(s) == 0 {
|
||||
return 0
|
||||
@@ -133,6 +153,38 @@ func (s Shape) FlattenedSize() int64 {
|
||||
return toReturn
|
||||
}
|
||||
|
||||
// Returns a non-nil error if the shape has bad or zero dimensions. May return
|
||||
// a ZeroShapeLengthError, a ShapeOverflowError, or an BadShapeDimensionError.
|
||||
// In the future, this may return other types of errors if it others become
|
||||
// necessary.
|
||||
func (s Shape) Validate() error {
|
||||
if len(s) == 0 {
|
||||
return ZeroShapeLengthError
|
||||
}
|
||||
if s[0] <= 0 {
|
||||
return &BadShapeDimensionError{
|
||||
DimensionIndex: 0,
|
||||
DimensionSize: s[0],
|
||||
}
|
||||
}
|
||||
flattenedSize := int64(s[0])
|
||||
for i := 1; i < len(s); i++ {
|
||||
d := s[i]
|
||||
if d <= 0 {
|
||||
return &BadShapeDimensionError{
|
||||
DimensionIndex: i,
|
||||
DimensionSize: d,
|
||||
}
|
||||
}
|
||||
tmp := flattenedSize * d
|
||||
if tmp < flattenedSize {
|
||||
return ShapeOverflowError
|
||||
}
|
||||
flattenedSize = tmp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Makes and returns a deep copy of the Shape.
|
||||
func (s Shape) Clone() Shape {
|
||||
toReturn := make([]int64, len(s))
|
||||
@@ -205,10 +257,11 @@ func (t *Tensor[T]) Clone() (*Tensor[T], error) {
|
||||
// Creates a new empty tensor with the given shape. The shape provided to this
|
||||
// function is copied, and is no longer needed after this function returns.
|
||||
func NewEmptyTensor[T TensorData](s Shape) (*Tensor[T], error) {
|
||||
elementCount := s.FlattenedSize()
|
||||
if elementCount == 0 {
|
||||
return nil, fmt.Errorf("Got invalid shape containing 0 elements")
|
||||
e := s.Validate()
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("Invalid tensor shape: %w", e)
|
||||
}
|
||||
elementCount := s.FlattenedSize()
|
||||
data := make([]T, elementCount)
|
||||
return NewTensor(s, data)
|
||||
}
|
||||
@@ -221,7 +274,10 @@ func NewTensor[T TensorData](s Shape, data []T) (*Tensor[T], error) {
|
||||
if !IsInitialized() {
|
||||
return nil, NotInitializedError
|
||||
}
|
||||
|
||||
e := s.Validate()
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("Invalid tensor shape: %w", e)
|
||||
}
|
||||
elementCount := s.FlattenedSize()
|
||||
if elementCount > int64(len(data)) {
|
||||
return nil, fmt.Errorf("The tensor's shape (%s) requires %d "+
|
||||
|
||||
@@ -136,6 +136,56 @@ func TestCreateTensor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadTensorShapes(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer DestroyEnvironment()
|
||||
s := NewShape()
|
||||
_, e := NewEmptyTensor[float64](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with an empty " +
|
||||
"shape.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with an empty shape: "+
|
||||
"%s\n", e)
|
||||
s = NewShape(10, 0, 10)
|
||||
_, e = NewEmptyTensor[uint16](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with a shape " +
|
||||
"containing a 0 dimension.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with a 0 dimension: "+
|
||||
"%s\n", e)
|
||||
s = NewShape(10, 10, -10)
|
||||
_, e = NewEmptyTensor[int32](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with a negative " +
|
||||
"dimension.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with a negative "+
|
||||
"dimension: %s\n", e)
|
||||
s = NewShape(10, -10, -10)
|
||||
_, e = NewEmptyTensor[uint64](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with two " +
|
||||
"negative dimensions.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with two negative "+
|
||||
"dimensions: %s\n", e)
|
||||
s = NewShape(int64(1)<<62, 1, int64(1)<<62)
|
||||
_, e = NewEmptyTensor[float32](s)
|
||||
if e == nil {
|
||||
t.Logf("Didn't get an error when creating a tensor with an " +
|
||||
"overflowing shape.\n")
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Got expected error when creating a tensor with an overflowing "+
|
||||
"shape: %s\n", e)
|
||||
}
|
||||
|
||||
func TestCloneTensor(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
originalData := []float32{1, 2, 3, 4}
|
||||
|
||||
Reference in New Issue
Block a user