Implement string tensor support

- This change adds support for tensors of
   ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING. Unfortunately, this can't be
   shoehorned into Tensor[T], since there's no obvious way to have Go
   manage the backing buffer of memory. Instead, this change adds a
   StringTensor type that satisfies the OrtValue interface.

 - Added a new test .onnx file that takes a vector of strings and
   converts each string into uppercase and lowercase outputs.
This commit is contained in:
yalue
2025-12-04 12:02:03 -05:00
parent b832b6b775
commit 737d7d2d9e
6 changed files with 452 additions and 15 deletions

View File

@@ -1026,6 +1026,177 @@ func (t *CustomDataTensor) GetData() []byte {
return t.data
}
// This represents an onnxruntime tensor containing strings. This still
// satisfies the Value interface, but has several important differences with
// Tensor[T] instances for numerical values. Most notably,
// StringTensor.GetContents() returns a _copy_ of the tensor's contents, and
// modifying these strings will not modify the contents of the underlying
// tensor. Instead, users must use StringTensor.SetElement(...) or
// StringTensor.SetContents(...) to modify the contents of an existing string
// tensor.
type StringTensor struct {
	// The tensor's dimensions; a defensive copy (see NewStringTensor), since
	// the backing data itself is owned by onnxruntime rather than Go.
	shape Shape
	// The underlying onnxruntime value. Set to nil by Destroy().
	ortValue *C.OrtValue
}
// Creates and returns a string tensor. The contents are _not_ initialized yet,
// and must be initialized using SetContents or, less efficiently, using
// SetElement to set each string individually. As with all Values,
// StringTensors must be freed using Destroy() when no longer needed.
func NewStringTensor(shape Shape) (*StringTensor, error) {
	if !IsInitialized() {
		return nil, NotInitializedError
	}
	if e := shape.Validate(); e != nil {
		return nil, fmt.Errorf("Invalid string tensor shape: %w", e)
	}
	// Unlike numeric tensors, the backing buffer here is owned by
	// onnxruntime's default allocator rather than by Go.
	var v *C.OrtValue
	status := C.CreateTensorAsOrtValue((*C.int64_t)(&shape[0]),
		C.int64_t(len(shape)), C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &v)
	if status != nil {
		return nil, statusToError(status)
	}
	toReturn := StringTensor{
		shape:    shape.Clone(),
		ortValue: v,
	}
	return &toReturn, nil
}
// This sets all of the strings in t to the contents from the flattened slice
// of strings. The length of the contents slice must match
// t.GetShape().FlattenedSize(), otherwise this will return an error.
func (t *StringTensor) SetContents(contents []string) error {
	requiredCount := t.shape.FlattenedSize()
	if requiredCount != int64(len(contents)) {
		// The C API presumably catches this too, but checking up front avoids
		// allocating a bunch of C strings for input we know is invalid.
		return fmt.Errorf("Was provided %d strings, but %d are required",
			len(contents), requiredCount)
	}
	if len(contents) == 0 {
		return nil
	}
	// Convert each Go string to a NUL-terminated C copy; all of the copies
	// are freed once FillStringTensor has copied them into the tensor.
	cStrings := make([]*C.char, 0, len(contents))
	for _, s := range contents {
		cStrings = append(cStrings, C.CString(s))
	}
	defer freeCStrings(cStrings)
	status := C.FillStringTensor(t.ortValue, &(cStrings[0]),
		C.size_t(len(cStrings)))
	if status != nil {
		return statusToError(status)
	}
	return nil
}
// Sets the string at the flattened index in t to s.
func (t *StringTensor) SetElement(index int64, s string) error {
	// FillStringTensorElement copies the string, so we can free our C copy
	// as soon as the call returns.
	cs := C.CString(s)
	defer C.free(unsafe.Pointer(cs))
	if status := C.FillStringTensorElement(t.ortValue, cs,
		C.size_t(index)); status != nil {
		return statusToError(status)
	}
	return nil
}
// Returns all contents of the string tensor, in order by flattened index.
// If you need to get all contents in the tensor, this should be more efficient
// than using GetElement(...) to retrieve all strings individually. This
// copies the tensor's contents; modifying the returned slice will not modify
// the tensor.
func (t *StringTensor) GetContents() ([]string, error) {
	elementCount := t.shape.FlattenedSize()
	var totalBytes C.size_t
	status := C.GetStringTensorDataLength(t.ortValue, &totalBytes)
	if status != nil {
		return nil, statusToError(status)
	}
	if totalBytes == 0 {
		// Every string in the tensor is empty, so skip the C calls entirely.
		// This also avoids needing to take pointers into empty slices.
		return make([]string, elementCount), nil
	}
	// I'm assuming nonzero data length implies a non-empty shape, so indexing
	// element 0 of these slices below is safe.
	data := make([]byte, totalBytes)
	offsets := make([]C.size_t, elementCount)
	status = C.GetStringTensorContent(t.ortValue, unsafe.Pointer(&(data[0])),
		totalBytes, &(offsets[0]), C.size_t(len(offsets)))
	if status != nil {
		return nil, statusToError(status)
	}
	// Each string spans from its own offset to the next string's offset; the
	// final string simply runs to the end of the data buffer.
	results := make([]string, len(offsets))
	for i, start := range offsets {
		end := C.size_t(len(data))
		if (i + 1) < len(offsets) {
			end = offsets[i+1]
		}
		results[i] = string(data[start:end])
	}
	return results, nil
}
// Returns a single string from t, at the given flattened index.
func (t *StringTensor) GetElement(index int64) (string, error) {
	// First query the string's length so we know how big a buffer to provide.
	var byteCount C.size_t
	status := C.GetStringTensorElementLength(t.ortValue, C.size_t(index),
		&byteCount)
	if status != nil {
		return "", statusToError(status)
	}
	if byteCount == 0 {
		// Avoid taking a pointer into an empty slice for empty strings.
		return "", nil
	}
	buffer := make([]byte, byteCount)
	status = C.GetStringTensorElement(t.ortValue, byteCount, C.size_t(index),
		unsafe.Pointer(&(buffer[0])))
	if status != nil {
		return "", statusToError(status)
	}
	return string(buffer), nil
}
// Always returns C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, since a StringTensor
// only ever holds strings.
func (t *StringTensor) DataType() C.ONNXTensorElementDataType {
	return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
}
// Returns a copy of the tensor's shape; modifying the returned Shape will not
// affect t.
func (t *StringTensor) GetShape() Shape {
	return t.shape.Clone()
}
// Releases the underlying onnxruntime value and clears t's fields. t must not
// be used after this returns. Always returns nil.
func (t *StringTensor) Destroy() error {
	C.ReleaseOrtValue(t.ortValue)
	t.ortValue = nil
	t.shape = nil
	return nil
}
// Returns the library-internal data for this Value; i.e. the underlying
// OrtValue pointer.
func (t *StringTensor) GetInternals() *ValueInternalData {
	return &ValueInternalData{
		ortValue: t.ortValue,
	}
}
// ZeroContents() is unsupported for string tensors. To clear all strings, use
// err := t.SetContents(make([]string, t.GetShape().FlattenedSize())) instead.
func (t *StringTensor) ZeroContents() {
	// Intentionally a no-op; see the comment above for the supported way to
	// clear a string tensor's contents.
}
// Always returns ONNXTypeTensor; string tensors are still ordinary tensors as
// far as their ONNX type is concerned.
func (t *StringTensor) GetONNXType() ONNXType {
	return ONNXTypeTensor
}
// Scalar is like a tensor but the underlying go slice is of length 1 and it
// has no dimension. It was introduced for use with the training API, but
// remains supported since it may be useful apart from the training API.
@@ -2387,41 +2558,58 @@ func createGoValueFromOrtValue(v *C.OrtValue) (Value, error) {
}
// Must only be called if v is known to be of type ONNXTensor. Returns a Tensor
// wrapping v with the correct Go type. This function always copies v's
// contents into a new Tensor backed by a Go-managed slice and releases v.
// wrapping v with the correct Go type. For non-string tensors this will copy
// the tensor's data into a go-managed slice. In case of errors, this will
// destroy the original value. In all non-error cases, the returned value must
// be destroyed by the caller.
//
// The issue is that GetTensorMutableData() becomes invalid after v is
// Released, so we can't release the original OrtValue if a reference to the
// data slice returned by GetData is still in use elsewhere in go code. We work
// around this by copying the data into a new OrtValue with a Go-backed buffer.
// (String tensors are easier, since they can't have mutable data buffers in
// the first place.)
func createTensorFromOrtValue(v *C.OrtValue) (Value, error) {
// Either in the case of error or otherwise, we'll release v. The issue is
// that GetTensorMutableData() becomes invalid after v is Released, so we
// can't release v if a reference to the slice returned by GetData is
// still referred to outside of the tensor. We work around this by copying
// the data into a new tensor and releasing the original.
defer C.ReleaseOrtValue(v)
var pInfo *C.OrtTensorTypeAndShapeInfo
status := C.GetTensorTypeAndShape(v, &pInfo)
if status != nil {
C.ReleaseOrtValue(v)
return nil, fmt.Errorf("Error getting type and shape: %w",
statusToError(status))
}
shape, e := getShapeFromInfo(pInfo)
if e != nil {
C.ReleaseOrtValue(v)
return nil, fmt.Errorf("Error getting shape from TypeAndShapeInfo: %w",
e)
}
var tensorElementType C.ONNXTensorElementDataType
status = C.GetTensorElementType(pInfo, (*uint32)(&tensorElementType))
if status != nil {
C.ReleaseOrtValue(v)
return nil, fmt.Errorf("Error getting tensor element type: %w",
statusToError(status))
}
C.ReleaseTensorTypeAndShapeInfo(pInfo)
if tensorElementType == C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING {
// String tensors. No go-managed data, so no copying needed.
return &StringTensor{
shape: shape,
ortValue: v,
}, nil
}
// For non-string tensors, we will always release the original OrtValue.
defer C.ReleaseOrtValue(v)
// Now we start the process of copying the data into a Go-backed OrtValue.
var tensorData unsafe.Pointer
status = C.GetTensorMutableData(v, &tensorData)
if status != nil {
return nil, fmt.Errorf("Error getting tensor mutable data: %w",
statusToError(status))
}
switch tensorType := TensorElementDataType(tensorElementType); tensorType {
case TensorElementDataTypeFloat:
return createTensorWithCData[float32](shape, tensorData)

View File

@@ -6,6 +6,7 @@ import (
"math/rand"
"os"
"runtime"
"strings"
"testing"
"time"
)
@@ -236,6 +237,124 @@ func TestBoolTensor(t *testing.T) {
}
}
// Tests creating, modifying, and reading StringTensors, including error cases
// for mismatched sizes and out-of-range indices, and then runs a network that
// takes a string tensor input and produces two string tensor outputs: one
// pre-allocated by the caller and one auto-allocated by onnxruntime.
func TestStringTensor(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)
	// Logs every element of s to the test output, prefixed by the label;
	// fails the test if the tensor's contents can't be read.
	logContents := func(s *StringTensor, label string) {
		contents, e := s.GetContents()
		if e != nil {
			t.Fatalf("Error getting contents for %s tensor: %s\n", label, e)
		}
		t.Logf("Contents of tensor %s:\n", label)
		for i, v := range contents {
			t.Logf(" index %d = %s\n", i, v)
		}
	}
	// Start with simple string tensor manipulation and verifying error cases.
	input, e := NewStringTensor(NewShape(2))
	if e != nil {
		t.Fatalf("Error creating input tensor: %s\n", e)
	}
	defer input.Destroy()
	// I assume these will be blank.
	logContents(input, "initial values")
	// Three strings into a 2-element tensor must be rejected.
	e = input.SetContents([]string{"I", "eat", "popcorn!"})
	if e == nil {
		t.Errorf("Didn't get expected error when providing an incorrect " +
			"number of strings to SetContents.\n")
	}
	t.Logf("Got expected error when providing an incorrect number of "+
		"strings to SetContents: %s\n", e)
	e = input.SetContents([]string{"what's", "up?"})
	if e != nil {
		t.Fatalf("Got error when setting initial tensor contents: %s\n", e)
	}
	logContents(input, "after SetContents")
	tmpString, e := input.GetElement(1)
	if e != nil {
		t.Fatalf("Error getting tensor element: %s\n", e)
	}
	if tmpString != "up?" {
		t.Errorf("Got incorrect value for tensor element 1. Got \"%s\".\n",
			tmpString)
	}
	// Index 45 is far outside the 2-element tensor, so this must fail.
	e = input.SetElement(45, "???")
	if e == nil {
		t.Fatalf("Didn't get expected error when setting an element at an " +
			"invalid index")
	}
	t.Logf("Got expected error when setting an element at an invalid "+
		"index: %s\n", e)
	e = input.SetElement(1, "")
	if e != nil {
		t.Fatalf("Error setting element at index 1: %s\n", e)
	}
	logContents(input, "after SetElement")
	// Next, we'll test executing a network, including allowing onnxruntime to
	// auto-allocate a string tensor.
	inputContents := []string{"Green Beans", "Something Else"}
	e = input.SetContents(inputContents)
	if e != nil {
		t.Fatalf("Error setting final input contents: %s\n", e)
	}
	outputUppercase, e := NewStringTensor(NewShape(2))
	if e != nil {
		t.Fatalf("Error creating output tensor: %s\n", e)
	}
	defer outputUppercase.Destroy()
	// The nil entry asks the session to allocate the second output itself.
	outputs := []Value{outputUppercase, nil}
	filePath := "test_data/example_strings.onnx"
	session, e := NewDynamicAdvancedSession(filePath, []string{"input"},
		[]string{"output_upper", "output_lower"}, nil)
	if e != nil {
		t.Fatalf("Error creating session for %s: %s\n", filePath, e)
	}
	defer session.Destroy()
	e = session.Run([]Value{input}, outputs)
	if e != nil {
		t.Fatalf("Error running %s: %s\n", filePath, e)
	}
	defer outputs[1].Destroy()
	// The auto-allocated output must have come back as a StringTensor.
	outputLowercase, ok := outputs[1].(*StringTensor)
	if !ok {
		t.Fatalf("Running %s didn't create a StringTensor output", filePath)
	}
	logContents(input, "input")
	logContents(outputLowercase, "outputLowercase")
	logContents(outputUppercase, "outputUppercase")
	uppercaseStrings, e := outputUppercase.GetContents()
	if e != nil {
		t.Fatalf("Error getting uppercase contents: %s\n", e)
	}
	lowercaseStrings, e := outputLowercase.GetContents()
	if e != nil {
		t.Fatalf("Error getting lowercase contents: %s\n", e)
	}
	// Verify the network's case conversions against Go's own.
	for i, original := range inputContents {
		if strings.ToLower(original) != lowercaseStrings[i] {
			t.Errorf("Didn't get expected lowercase version of %s, got %s\n",
				original, lowercaseStrings[i])
		}
		if strings.ToUpper(original) != uppercaseStrings[i] {
			t.Errorf("Didn't get expected uppercase version of %s, got %s\n",
				original, uppercaseStrings[i])
		}
	}
}
func TestBadTensorShapes(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
@@ -2251,7 +2370,8 @@ func TestCancelWithRunOptions_AdvancedSession(t *testing.T) {
defer output.Destroy()
filePath := "test_data/example ż 大 김.onnx"
session, e := NewAdvancedSession(filePath, []string{"in"}, []string{"out"}, []Value{input}, []Value{output}, nil)
session, e := NewAdvancedSession(filePath, []string{"in"}, []string{"out"},
[]Value{input}, []Value{output}, nil)
if e != nil {
t.Fatalf("Failed creating session for %s: %s\n", filePath, e)
}

View File

@@ -366,6 +366,17 @@ OrtStatus *CreateOrtTensorWithShape(void *data, size_t data_size,
return status;
}
// Creates a tensor whose backing buffer is owned by onnxruntime's default
// allocator rather than by the caller. Wraps ort_api->CreateTensorAsOrtValue.
// Returns NULL on success, or the failing call's OrtStatus otherwise.
OrtStatus *CreateTensorAsOrtValue(int64_t *shape, int64_t shape_size,
  ONNXTensorElementDataType dtype, OrtValue **out) {
  OrtAllocator *allocator = NULL;
  OrtStatus *status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
  if (status != NULL) return status;
  return ort_api->CreateTensorAsOrtValue(allocator, shape, shape_size, dtype,
    out);
}
// Wraps ort_api->GetTensorTypeAndShape.
OrtStatus *GetTensorTypeAndShape(const OrtValue *value, OrtTensorTypeAndShapeInfo **out) {
return ort_api->GetTensorTypeAndShape(value, out);
}
@@ -523,3 +534,33 @@ OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
return ort_api->CreateValue((const OrtValue* const*) in, num_values,
value_type, out);
}
// Wraps ort_api->FillStringTensor; the cast satisfies the API's
// const char* const* parameter type.
OrtStatus *FillStringTensor(OrtValue *v, char **strings, size_t num_strings) {
return ort_api->FillStringTensor(v, (const char* const*) strings,
num_strings);
}
// Wraps ort_api->GetStringTensorDataLength.
OrtStatus *GetStringTensorDataLength(OrtValue *v, size_t *length) {
return ort_api->GetStringTensorDataLength(v, length);
}
// Wraps ort_api->GetStringTensorContent.
OrtStatus *GetStringTensorContent(OrtValue *v, void *data_buffer,
size_t data_size, size_t *offsets_buffer, size_t offsets_length) {
return ort_api->GetStringTensorContent(v, data_buffer, data_size,
offsets_buffer, offsets_length);
}
// Wraps ort_api->FillStringTensorElement.
OrtStatus *FillStringTensorElement(OrtValue *v, char *s, size_t index) {
return ort_api->FillStringTensorElement(v, s, index);
}
// Wraps ort_api->GetStringTensorElementLength.
OrtStatus *GetStringTensorElementLength(OrtValue *v, size_t index,
size_t *result) {
return ort_api->GetStringTensorElementLength(v, index, result);
}
// Wraps ort_api->GetStringTensorElement.
OrtStatus *GetStringTensorElement(OrtValue *v, size_t buffer_length,
size_t index, void *buffer) {
return ort_api->GetStringTensorElement(v, buffer_length, index, buffer);
}

View File

@@ -242,17 +242,26 @@ OrtStatus *CreateOrtTensorWithShape(void *data, size_t data_size,
int64_t *shape, int64_t shape_size, OrtMemoryInfo *mem_info,
ONNXTensorElementDataType dtype, OrtValue **out);
// Creates an OrtValue managed by onnxruntime's default allocator rather than
// using Go-managed memory. Wraps ort_api->CreateTensorAsOrtValue.
OrtStatus *CreateTensorAsOrtValue(int64_t *shape, int64_t shape_size,
ONNXTensorElementDataType dtype, OrtValue **out);
// Wraps ort_api->GetTensorTypeAndShape
OrtStatus *GetTensorTypeAndShape(const OrtValue *value, OrtTensorTypeAndShapeInfo **out);
OrtStatus *GetTensorTypeAndShape(const OrtValue *value,
OrtTensorTypeAndShapeInfo **out);
// Wraps ort_api->GetDimensionsCount
OrtStatus *GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info, size_t *out);
OrtStatus *GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info,
size_t *out);
// Wraps ort_api->GetDimensions
OrtStatus *GetDimensions(const OrtTensorTypeAndShapeInfo *info, int64_t *dim_values, size_t dim_values_length);
OrtStatus *GetDimensions(const OrtTensorTypeAndShapeInfo *info,
int64_t *dim_values, size_t dim_values_length);
// Wraps ort_api->GetTensorElementType
OrtStatus *GetTensorElementType(const OrtTensorTypeAndShapeInfo *info, enum ONNXTensorElementDataType *out);
OrtStatus *GetTensorElementType(const OrtTensorTypeAndShapeInfo *info,
enum ONNXTensorElementDataType *out);
// Wraps ort_api->ReleaseTensorTypeAndShapeInfo
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input);
@@ -335,6 +344,27 @@ OrtStatus *GetValueCount(OrtValue *v, size_t *out);
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
enum ONNXType value_type, OrtValue **out);
// Wraps ort_api->FillStringTensor
OrtStatus *FillStringTensor(OrtValue *v, char **strings, size_t num_strings);
// Wraps ort_api->GetStringTensorDataLength
OrtStatus *GetStringTensorDataLength(OrtValue *v, size_t *length);
// Wraps ort_api->GetStringTensorContent
OrtStatus *GetStringTensorContent(OrtValue *v, void *data_buffer,
size_t data_size, size_t *offsets_buffer, size_t offsets_length);
// Wraps ort_api->FillStringTensorElement
OrtStatus *FillStringTensorElement(OrtValue *v, char *s, size_t index);
// Wraps ort_api->GetStringTensorElementLength
OrtStatus *GetStringTensorElementLength(OrtValue *v, size_t index,
size_t *result);
// Wraps ort_api->GetStringTensorElement
OrtStatus *GetStringTensorElement(OrtValue *v, size_t buffer_length,
size_t index, void *buffer);
#ifdef __cplusplus
} // extern "C"
#endif

Binary file not shown.

View File

@@ -0,0 +1,58 @@
# This script generates "example_strings.onnx". This example takes a 1xN tensor
# of N strings, and produces two 1xN outputs: one with the strings converted to
# lowercase and one with the strings converted to uppercase.
import numpy as np
import onnx
from onnx import helper, TensorProto
import onnxruntime as ort
def main():
    """Builds, saves, and sanity-checks example_strings.onnx.

    The model takes one dynamically-sized 1-D tensor of strings named "input"
    and produces two 1-D string outputs: "output_lower" and "output_upper".
    """
    # Describe the inputs and outputs; None leaves the dimension dynamic.
    input_info = helper.make_tensor_value_info("input", TensorProto.STRING,
        [None])
    output_lower_info = helper.make_tensor_value_info("output_lower",
        TensorProto.STRING, [None])
    output_upper_info = helper.make_tensor_value_info("output_upper",
        TensorProto.STRING, [None])
    # One StringNormalizer node per output handles each case conversion.
    node_lower = helper.make_node(
        "StringNormalizer",
        inputs=["input"],
        outputs=["output_lower"],
        case_change_action="LOWER",
    )
    node_upper = helper.make_node(
        "StringNormalizer",
        inputs=["input"],
        outputs=["output_upper"],
        case_change_action="UPPER",
    )
    graph = helper.make_graph(
        [node_lower, node_upper],
        "strings_example_graph",
        [input_info],
        [output_lower_info, output_upper_info],
    )
    model = helper.make_model(graph,
        producer_name="generate_strings_example.py")
    onnx.checker.check_model(model)
    filename = "example_strings.onnx"
    onnx.save(model, filename)
    print(f"Saved {filename} OK. Testing...")
    # Sanity-check the saved model by running it through onnxruntime.
    session = ort.InferenceSession(filename)
    inputs = np.array(["I", "eAt", "POTATOEs!!"])
    output_lower, output_upper = session.run(None, {"input": inputs})
    print("Inputs: " + str(inputs))
    print("Lowercase: " + str(output_lower))
    print("Uppercase: " + str(output_upper))


if __name__ == "__main__":
    main()