mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-12-24 13:38:00 +08:00
Implement string tensor support
- This change adds support for tensors of ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING. Unfortunately, this can't be shoehorned into Tensor[T], since there's no obvious way to have Go manage the backing buffer of memory. Instead, this change adds a StringTensor type that satisfies the Value interface. - Added a new test .onnx file that takes a vector of strings and converts each string into uppercase and lowercase outputs.
This commit is contained in:
@@ -1026,6 +1026,177 @@ func (t *CustomDataTensor) GetData() []byte {
|
||||
return t.data
|
||||
}
|
||||
|
||||
// StringTensor represents an onnxruntime tensor containing strings. It still
// satisfies the Value interface, but has several important differences with
// Tensor[T] instances for numerical values. Most notably,
// StringTensor.GetContents() returns a _copy_ of the tensor's contents, and
// modifying these strings will not modify the contents of the underlying
// tensor. Instead, users must use StringTensor.SetElement(...) or
// StringTensor.SetContents(...) to modify the contents of an existing string
// tensor.
type StringTensor struct {
	// The dimensions of the tensor. Set to nil by Destroy().
	shape Shape
	// The underlying onnxruntime value holding the strings. Set to nil by
	// Destroy().
	ortValue *C.OrtValue
}
|
||||
|
||||
// Creates and returns a string tensor. The contents are _not_ initialized yet,
|
||||
// and must be initialized using SetContents or, less efficiently, using
|
||||
// SetElement to set each string individually. As with all Values,
|
||||
// StringTensors must be freed using Destroy() when no longer needed.
|
||||
func NewStringTensor(shape Shape) (*StringTensor, error) {
|
||||
if !IsInitialized() {
|
||||
return nil, NotInitializedError
|
||||
}
|
||||
e := shape.Validate()
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("Invalid string tensor shape: %w", e)
|
||||
}
|
||||
|
||||
var ortValue *C.OrtValue
|
||||
status := C.CreateTensorAsOrtValue((*C.int64_t)(&shape[0]),
|
||||
C.int64_t(len(shape)), C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
|
||||
&ortValue)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
|
||||
return &StringTensor{
|
||||
shape: shape.Clone(),
|
||||
ortValue: ortValue,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// This sets all of the strings in t to the contents from the flattened slice
|
||||
// of strings. The length of the contents slice must match
|
||||
// t.GetShape().FlattenedSize(), otherwise this will return an error.
|
||||
func (t *StringTensor) SetContents(contents []string) error {
|
||||
if int64(len(contents)) != t.shape.FlattenedSize() {
|
||||
// I'm sure the C API detects this as well, but we can check for it
|
||||
// here before needing to allocate a bunch of strings.
|
||||
return fmt.Errorf("Was provided %d strings, but %d are required",
|
||||
len(contents), t.shape.FlattenedSize())
|
||||
}
|
||||
if len(contents) == 0 {
|
||||
return nil
|
||||
}
|
||||
cStrings := make([]*C.char, len(contents))
|
||||
for i, s := range contents {
|
||||
cStrings[i] = C.CString(s)
|
||||
}
|
||||
defer freeCStrings(cStrings)
|
||||
status := C.FillStringTensor(t.ortValue, &(cStrings[0]),
|
||||
C.size_t(len(contents)))
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sets the string at the flattened index in t to s.
|
||||
func (t *StringTensor) SetElement(index int64, s string) error {
|
||||
cString := C.CString(s)
|
||||
defer C.free(unsafe.Pointer(cString))
|
||||
status := C.FillStringTensorElement(t.ortValue, cString, C.size_t(index))
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns all contents of the string tensor, in order by flattened index.
|
||||
// If you need to get all contents in the tensor, this should be more efficient
|
||||
// than using GetElement(...) to retrieve all strings individually. This
|
||||
// copies the tensor's contents; modifying the returned slice will not modify
|
||||
// the tensor.
|
||||
func (t *StringTensor) GetContents() ([]string, error) {
|
||||
var dataSize C.size_t
|
||||
status := C.GetStringTensorDataLength(t.ortValue, &dataSize)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
if dataSize == 0 {
|
||||
// A quick optimization where all the strings are empty, avoids needing
|
||||
// to mess around with nil pointers, too.
|
||||
return make([]string, t.shape.FlattenedSize()), nil
|
||||
}
|
||||
// I'm assuming nonzero data length implies a non-empty shape.
|
||||
dataBuffer := make([]byte, dataSize)
|
||||
offsets := make([]C.size_t, t.shape.FlattenedSize())
|
||||
|
||||
// Invoke the C API
|
||||
status = C.GetStringTensorContent(t.ortValue,
|
||||
unsafe.Pointer(&(dataBuffer[0])), dataSize, &(offsets[0]),
|
||||
C.size_t(len(offsets)))
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
|
||||
// Extract the individual strings from the data buffer using their offsets.
|
||||
toReturn := make([]string, len(offsets))
|
||||
for i := range offsets {
|
||||
if i == (len(offsets) - 1) {
|
||||
// The last string doesn't have an end offset; it just goes to the
|
||||
// end of the buffer.
|
||||
toReturn[i] = string(dataBuffer[offsets[i]:])
|
||||
break
|
||||
}
|
||||
toReturn[i] = string(dataBuffer[offsets[i]:offsets[i+1]])
|
||||
}
|
||||
|
||||
return toReturn, nil
|
||||
}
|
||||
|
||||
// Returns a single string from t, at the given flattened index.
|
||||
func (t *StringTensor) GetElement(index int64) (string, error) {
|
||||
var length C.size_t
|
||||
status := C.GetStringTensorElementLength(t.ortValue, C.size_t(index),
|
||||
&length)
|
||||
if status != nil {
|
||||
return "", statusToError(status)
|
||||
}
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
data := make([]byte, length)
|
||||
status = C.GetStringTensorElement(t.ortValue, length, C.size_t(index),
|
||||
unsafe.Pointer(&(data[0])))
|
||||
if status != nil {
|
||||
return "", statusToError(status)
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// DataType always returns C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, regardless
// of the tensor's state.
func (t *StringTensor) DataType() C.ONNXTensorElementDataType {
	return C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
}
|
||||
|
||||
// GetShape returns a clone of the tensor's shape, produced by Shape.Clone(),
// rather than the Shape value stored in t.
func (t *StringTensor) GetShape() Shape {
	return t.shape.Clone()
}
|
||||
|
||||
func (t *StringTensor) Destroy() error {
|
||||
C.ReleaseOrtValue(t.ortValue)
|
||||
t.ortValue = nil
|
||||
t.shape = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *StringTensor) GetInternals() *ValueInternalData {
|
||||
return &ValueInternalData{
|
||||
ortValue: t.ortValue,
|
||||
}
|
||||
}
|
||||
|
||||
// ZeroContents is unsupported for string tensors: calling it does nothing and
// leaves the tensor's strings unchanged. To clear all strings, use
// err := t.SetContents(make([]string, t.GetShape().FlattenedSize())) instead.
func (t *StringTensor) ZeroContents() {
	// Intentionally empty.
}
|
||||
|
||||
// GetONNXType always returns ONNXTypeTensor.
func (t *StringTensor) GetONNXType() ONNXType {
	return ONNXTypeTensor
}
|
||||
|
||||
// Scalar is like a tensor but the underlying go slice is of length 1 and it
|
||||
// has no dimension. It was introduced for use with the training API, but
|
||||
// remains supported since it may be useful apart from the training API.
|
||||
@@ -2387,41 +2558,58 @@ func createGoValueFromOrtValue(v *C.OrtValue) (Value, error) {
|
||||
}
|
||||
|
||||
// Must only be called if v is known to be of type ONNXTensor. Returns a Tensor
|
||||
// wrapping v with the correct Go type. This function always copies v's
|
||||
// contents into a new Tensor backed by a Go-managed slice and releases v.
|
||||
// wrapping v with the correct Go type. For non-string tensors this will copy
|
||||
// the tensor's data into a go-managed slice. In case of errors, this will
|
||||
// destroy the original value. In all non-error cases, the returned value must
|
||||
// be destroyed by the caller.
|
||||
//
|
||||
// The issue is that GetTensorMutableData() becomes invalid after v is
|
||||
// Released, so we can't release the original OrtValue if a reference to the
|
||||
// data slice returned by GetData is still in use elsewhere in go code. We work
|
||||
// around this by copying the data into a new OrtValue with a Go-backed buffer.
|
||||
// (String tensors are easier, since they can't have mutable data buffers in
|
||||
// the first place.)
|
||||
func createTensorFromOrtValue(v *C.OrtValue) (Value, error) {
|
||||
// Either in the case of error or otherwise, we'll release v. The issue is
|
||||
// that GetTensorMutableData() becomes invalid after v is Released, so we
|
||||
// can't release v if a reference to the slice returned by GetData is
|
||||
// still referred to outside of the tensor. We work around this by copying
|
||||
// the data into a new tensor and releasing the original.
|
||||
defer C.ReleaseOrtValue(v)
|
||||
|
||||
var pInfo *C.OrtTensorTypeAndShapeInfo
|
||||
status := C.GetTensorTypeAndShape(v, &pInfo)
|
||||
if status != nil {
|
||||
C.ReleaseOrtValue(v)
|
||||
return nil, fmt.Errorf("Error getting type and shape: %w",
|
||||
statusToError(status))
|
||||
}
|
||||
shape, e := getShapeFromInfo(pInfo)
|
||||
if e != nil {
|
||||
C.ReleaseOrtValue(v)
|
||||
return nil, fmt.Errorf("Error getting shape from TypeAndShapeInfo: %w",
|
||||
e)
|
||||
}
|
||||
var tensorElementType C.ONNXTensorElementDataType
|
||||
status = C.GetTensorElementType(pInfo, (*uint32)(&tensorElementType))
|
||||
if status != nil {
|
||||
C.ReleaseOrtValue(v)
|
||||
return nil, fmt.Errorf("Error getting tensor element type: %w",
|
||||
statusToError(status))
|
||||
}
|
||||
C.ReleaseTensorTypeAndShapeInfo(pInfo)
|
||||
|
||||
if tensorElementType == C.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING {
|
||||
// String tensors. No go-managed data, so no copying needed.
|
||||
return &StringTensor{
|
||||
shape: shape,
|
||||
ortValue: v,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// For non-string tensors, we will always release the original OrtValue.
|
||||
defer C.ReleaseOrtValue(v)
|
||||
|
||||
// Now we start the process of copying the data into a Go-backed OrtValue.
|
||||
var tensorData unsafe.Pointer
|
||||
status = C.GetTensorMutableData(v, &tensorData)
|
||||
if status != nil {
|
||||
return nil, fmt.Errorf("Error getting tensor mutable data: %w",
|
||||
statusToError(status))
|
||||
}
|
||||
|
||||
switch tensorType := TensorElementDataType(tensorElementType); tensorType {
|
||||
case TensorElementDataTypeFloat:
|
||||
return createTensorWithCData[float32](shape, tensorData)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -236,6 +237,124 @@ func TestBoolTensor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringTensor(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
logContents := func(s *StringTensor, label string) {
|
||||
contents, e := s.GetContents()
|
||||
if e != nil {
|
||||
t.Fatalf("Error getting contents for %s tensor: %s\n", label, e)
|
||||
}
|
||||
t.Logf("Contents of tensor %s:\n", label)
|
||||
for i, v := range contents {
|
||||
t.Logf(" index %d = %s\n", i, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Start with simple string tensor manipulation and verifying error cases.
|
||||
input, e := NewStringTensor(NewShape(2))
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating input tensor: %s\n", e)
|
||||
}
|
||||
defer input.Destroy()
|
||||
// I assume these will be blank.
|
||||
logContents(input, "initial values")
|
||||
|
||||
e = input.SetContents([]string{"I", "eat", "popcorn!"})
|
||||
if e == nil {
|
||||
t.Errorf("Didn't get expected error when providing an incorrect " +
|
||||
"number of strings to SetContents.\n")
|
||||
}
|
||||
t.Logf("Got expected error when providing an incorrect number of "+
|
||||
"strings to SetContents: %s\n", e)
|
||||
e = input.SetContents([]string{"what's", "up?"})
|
||||
if e != nil {
|
||||
t.Fatalf("Got error when setting initial tensor contents: %s\n", e)
|
||||
}
|
||||
logContents(input, "after SetContents")
|
||||
|
||||
tmpString, e := input.GetElement(1)
|
||||
if e != nil {
|
||||
t.Fatalf("Error getting tensor element: %s\n", e)
|
||||
}
|
||||
if tmpString != "up?" {
|
||||
t.Errorf("Got incorrect value for tensor element 1. Got \"%s\".\n",
|
||||
tmpString)
|
||||
}
|
||||
|
||||
e = input.SetElement(45, "???")
|
||||
if e == nil {
|
||||
t.Fatalf("Didn't get expected error when setting an element at an " +
|
||||
"invalid index")
|
||||
}
|
||||
t.Logf("Got expected error when setting an element at an invalid "+
|
||||
"index: %s\n", e)
|
||||
|
||||
e = input.SetElement(1, "")
|
||||
if e != nil {
|
||||
t.Fatalf("Error setting element at index 1: %s\n", e)
|
||||
}
|
||||
logContents(input, "after SetElement")
|
||||
|
||||
// Next, we'll test executing a network, including allowing onnxruntime to
|
||||
// auto-allocate a string tensor.
|
||||
inputContents := []string{"Green Beans", "Something Else"}
|
||||
e = input.SetContents(inputContents)
|
||||
if e != nil {
|
||||
t.Fatalf("Error setting final input contents: %s\n", e)
|
||||
}
|
||||
outputUppercase, e := NewStringTensor(NewShape(2))
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating output tensor: %s\n", e)
|
||||
}
|
||||
defer outputUppercase.Destroy()
|
||||
outputs := []Value{outputUppercase, nil}
|
||||
|
||||
filePath := "test_data/example_strings.onnx"
|
||||
session, e := NewDynamicAdvancedSession(filePath, []string{"input"},
|
||||
[]string{"output_upper", "output_lower"}, nil)
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating session for %s: %s\n", filePath, e)
|
||||
}
|
||||
defer session.Destroy()
|
||||
|
||||
e = session.Run([]Value{input}, outputs)
|
||||
if e != nil {
|
||||
t.Fatalf("Error running %s: %s\n", filePath, e)
|
||||
}
|
||||
defer outputs[1].Destroy()
|
||||
|
||||
outputLowercase, ok := outputs[1].(*StringTensor)
|
||||
if !ok {
|
||||
t.Fatalf("Running %s didn't create a StringTensor output", filePath)
|
||||
}
|
||||
|
||||
logContents(input, "input")
|
||||
logContents(outputLowercase, "outputLowercase")
|
||||
logContents(outputUppercase, "outputUppercase")
|
||||
|
||||
uppercaseStrings, e := outputUppercase.GetContents()
|
||||
if e != nil {
|
||||
t.Fatalf("Error getting uppercase contents: %s\n", e)
|
||||
}
|
||||
lowercaseStrings, e := outputLowercase.GetContents()
|
||||
if e != nil {
|
||||
t.Fatalf("Error getting lowercase contents: %s\n", e)
|
||||
}
|
||||
|
||||
for i, original := range inputContents {
|
||||
if strings.ToLower(original) != lowercaseStrings[i] {
|
||||
t.Errorf("Didn't get expected lowercase version of %s, got %s\n",
|
||||
original, lowercaseStrings[i])
|
||||
}
|
||||
if strings.ToUpper(original) != uppercaseStrings[i] {
|
||||
t.Errorf("Didn't get expected uppercase version of %s, got %s\n",
|
||||
original, uppercaseStrings[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadTensorShapes(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
@@ -2251,7 +2370,8 @@ func TestCancelWithRunOptions_AdvancedSession(t *testing.T) {
|
||||
defer output.Destroy()
|
||||
|
||||
filePath := "test_data/example ż 大 김.onnx"
|
||||
session, e := NewAdvancedSession(filePath, []string{"in"}, []string{"out"}, []Value{input}, []Value{output}, nil)
|
||||
session, e := NewAdvancedSession(filePath, []string{"in"}, []string{"out"},
|
||||
[]Value{input}, []Value{output}, nil)
|
||||
if e != nil {
|
||||
t.Fatalf("Failed creating session for %s: %s\n", filePath, e)
|
||||
}
|
||||
|
||||
@@ -366,6 +366,17 @@ OrtStatus *CreateOrtTensorWithShape(void *data, size_t data_size,
|
||||
return status;
|
||||
}
|
||||
|
||||
OrtStatus *CreateTensorAsOrtValue(int64_t *shape, int64_t shape_size,
|
||||
ONNXTensorElementDataType dtype, OrtValue **out) {
|
||||
OrtStatus *status = NULL;
|
||||
OrtAllocator *allocator = NULL;
|
||||
status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
|
||||
if (status) return status;
|
||||
status = ort_api->CreateTensorAsOrtValue(allocator, shape, shape_size, dtype,
|
||||
out);
|
||||
return status;
|
||||
}
|
||||
|
||||
// Forwards directly to ort_api->GetTensorTypeAndShape.
OrtStatus *GetTensorTypeAndShape(const OrtValue *value,
  OrtTensorTypeAndShapeInfo **out) {
  OrtStatus *status = ort_api->GetTensorTypeAndShape(value, out);
  return status;
}
|
||||
@@ -523,3 +534,33 @@ OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
|
||||
return ort_api->CreateValue((const OrtValue* const*) in, num_values,
|
||||
value_type, out);
|
||||
}
|
||||
|
||||
// Forwards to ort_api->FillStringTensor, casting strings to the
// const-qualified pointer type that the API function is called with.
OrtStatus *FillStringTensor(OrtValue *v, char **strings, size_t num_strings) {
  const char* const* const_strings = (const char* const*) strings;
  return ort_api->FillStringTensor(v, const_strings, num_strings);
}
|
||||
|
||||
// Forwards directly to ort_api->GetStringTensorDataLength, which writes its
// result to *length.
OrtStatus *GetStringTensorDataLength(OrtValue *v, size_t *length) {
  OrtStatus *status = ort_api->GetStringTensorDataLength(v, length);
  return status;
}
|
||||
|
||||
// Forwards directly to ort_api->GetStringTensorContent, passing through the
// data buffer and offsets buffer along with their sizes.
OrtStatus *GetStringTensorContent(OrtValue *v, void *data_buffer,
  size_t data_size, size_t *offsets_buffer, size_t offsets_length) {
  OrtStatus *status = ort_api->GetStringTensorContent(v, data_buffer,
    data_size, offsets_buffer, offsets_length);
  return status;
}
|
||||
|
||||
// Forwards directly to ort_api->FillStringTensorElement for the string s and
// the element at the given index.
OrtStatus *FillStringTensorElement(OrtValue *v, char *s, size_t index) {
  OrtStatus *status = ort_api->FillStringTensorElement(v, s, index);
  return status;
}
|
||||
|
||||
// Forwards directly to ort_api->GetStringTensorElementLength, which writes
// its result for the element at index to *result.
OrtStatus *GetStringTensorElementLength(OrtValue *v, size_t index,
  size_t *result) {
  OrtStatus *status = ort_api->GetStringTensorElementLength(v, index, result);
  return status;
}
|
||||
|
||||
// Forwards directly to ort_api->GetStringTensorElement, passing through the
// destination buffer and its length for the element at index.
OrtStatus *GetStringTensorElement(OrtValue *v, size_t buffer_length,
  size_t index, void *buffer) {
  OrtStatus *status = ort_api->GetStringTensorElement(v, buffer_length, index,
    buffer);
  return status;
}
|
||||
|
||||
|
||||
@@ -242,17 +242,26 @@ OrtStatus *CreateOrtTensorWithShape(void *data, size_t data_size,
|
||||
int64_t *shape, int64_t shape_size, OrtMemoryInfo *mem_info,
|
||||
ONNXTensorElementDataType dtype, OrtValue **out);
|
||||
|
||||
// Creates an OrtValue managed by onnxruntime's default allocator rather than
|
||||
// using Go-managed memory. Wraps ort_api->CreateTensorAsOrtValue.
|
||||
OrtStatus *CreateTensorAsOrtValue(int64_t *shape, int64_t shape_size,
|
||||
ONNXTensorElementDataType dtype, OrtValue **out);
|
||||
|
||||
// Wraps ort_api->GetTensorTypeAndShape
|
||||
OrtStatus *GetTensorTypeAndShape(const OrtValue *value, OrtTensorTypeAndShapeInfo **out);
|
||||
OrtStatus *GetTensorTypeAndShape(const OrtValue *value,
|
||||
OrtTensorTypeAndShapeInfo **out);
|
||||
|
||||
// Wraps ort_api->GetDimensionsCount
|
||||
OrtStatus *GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info, size_t *out);
|
||||
OrtStatus *GetDimensionsCount(const OrtTensorTypeAndShapeInfo *info,
|
||||
size_t *out);
|
||||
|
||||
// Wraps ort_api->GetDimensions
|
||||
OrtStatus *GetDimensions(const OrtTensorTypeAndShapeInfo *info, int64_t *dim_values, size_t dim_values_length);
|
||||
OrtStatus *GetDimensions(const OrtTensorTypeAndShapeInfo *info,
|
||||
int64_t *dim_values, size_t dim_values_length);
|
||||
|
||||
// Wraps ort_api->GetTensorElementType
|
||||
OrtStatus *GetTensorElementType(const OrtTensorTypeAndShapeInfo *info, enum ONNXTensorElementDataType *out);
|
||||
OrtStatus *GetTensorElementType(const OrtTensorTypeAndShapeInfo *info,
|
||||
enum ONNXTensorElementDataType *out);
|
||||
|
||||
// Wraps ort_api->ReleaseTensorTypeAndShapeInfo
|
||||
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input);
|
||||
@@ -335,6 +344,27 @@ OrtStatus *GetValueCount(OrtValue *v, size_t *out);
|
||||
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
|
||||
enum ONNXType value_type, OrtValue **out);
|
||||
|
||||
// Wraps ort_api->FillStringTensor. strings points to an array of num_strings
// C strings. No per-string lengths are passed; the Go caller supplies
// NUL-terminated strings created with C.CString.
OrtStatus *FillStringTensor(OrtValue *v, char **strings, size_t num_strings);

// Wraps ort_api->GetStringTensorDataLength. The result is written to *length;
// the Go caller uses it as the size of the data buffer that it passes to
// GetStringTensorContent.
OrtStatus *GetStringTensorDataLength(OrtValue *v, size_t *length);

// Wraps ort_api->GetStringTensorContent. data_buffer (of data_size bytes)
// receives the string data, and offsets_buffer (of offsets_length entries)
// receives offsets into it. The Go caller treats each offset as the start of
// one string, with the final string running to the end of the data buffer.
OrtStatus *GetStringTensorContent(OrtValue *v, void *data_buffer,
    size_t data_size, size_t *offsets_buffer, size_t offsets_length);

// Wraps ort_api->FillStringTensorElement. s is the C string to store at the
// given index; the Go caller passes a flattened element index.
OrtStatus *FillStringTensorElement(OrtValue *v, char *s, size_t index);

// Wraps ort_api->GetStringTensorElementLength. The result for the element at
// index is written to *result; the Go caller uses it as the buffer size to
// pass to GetStringTensorElement.
OrtStatus *GetStringTensorElementLength(OrtValue *v, size_t index,
    size_t *result);

// Wraps ort_api->GetStringTensorElement. buffer must have room for
// buffer_length bytes; the Go caller converts exactly that many bytes into
// the returned string.
OrtStatus *GetStringTensorElement(OrtValue *v, size_t buffer_length,
    size_t index, void *buffer);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
BIN
test_data/example_strings.onnx
Normal file
BIN
test_data/example_strings.onnx
Normal file
Binary file not shown.
58
test_data/generate_strings_example.py
Normal file
58
test_data/generate_strings_example.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# This script generates "example_strings.onnx". This example takes a 1-D tensor
|
||||
# of N strings, and produces two 1-D outputs of N strings: one with the strings
|
||||
# converted to lowercase and one with the strings converted to uppercase.
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import helper, TensorProto
|
||||
import onnxruntime as ort
|
||||
|
||||
def main():
    """Builds, saves, and smoke-tests the "example_strings.onnx" model.

    The model has one string input named "input" and two string outputs,
    "output_lower" and "output_upper", each produced by a StringNormalizer
    node. After saving the model to the current directory, it is loaded with
    onnxruntime and run once on a small example, and the results are printed.
    """
    # Describe the inputs and outputs. All three are 1-D string tensors with
    # an unspecified length.
    input_info = helper.make_tensor_value_info("input", TensorProto.STRING,
        [None])
    output_lower_info = helper.make_tensor_value_info("output_lower",
        TensorProto.STRING, [None])
    output_upper_info = helper.make_tensor_value_info("output_upper",
        TensorProto.STRING, [None])

    node_lower = helper.make_node(
        "StringNormalizer",
        inputs=["input"],
        outputs=["output_lower"],
        case_change_action="LOWER",
    )

    node_upper = helper.make_node(
        "StringNormalizer",
        inputs=["input"],
        outputs=["output_upper"],
        case_change_action="UPPER",
    )

    graph = helper.make_graph(
        [node_lower, node_upper],
        "strings_example_graph",
        [input_info],
        [output_lower_info, output_upper_info],
    )

    model = helper.make_model(graph,
        producer_name="generate_strings_example.py")
    onnx.checker.check_model(model)
    filename = "example_strings.onnx"
    onnx.save(model, filename)
    print(f"Saved {filename} OK. Testing...")

    # Load the saved file back and run it once as a sanity check.
    session = ort.InferenceSession(filename)

    inputs = np.array(["I", "eAt", "POTATOEs!!"])

    # The outputs come back in the order they were declared in the graph:
    # lowercase first, then uppercase.
    output_lower, output_upper = session.run(None, {"input": inputs})

    print("Inputs: " + str(inputs))
    print("Lowercase: " + str(output_lower))
    print("Uppercase: " + str(output_upper))

if __name__ == "__main__":
    main()
|
||||
|
||||
Reference in New Issue
Block a user