commit 815e8b37ff5c1d045778d2e29b4c9765d19ed446 Author: swdee Date: Mon Apr 8 22:45:28 2024 +1200 imported initial code diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..311f778 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea +prototype +python diff --git a/cgo.go b/cgo.go new file mode 100644 index 0000000..f053e46 --- /dev/null +++ b/cgo.go @@ -0,0 +1,6 @@ +package rknnlite + +/* +#cgo LDFLAGS: -lrknnrt +*/ +import "C" diff --git a/example/data/.gitignore b/example/data/.gitignore new file mode 100644 index 0000000..335d0d1 --- /dev/null +++ b/example/data/.gitignore @@ -0,0 +1,2 @@ +*.jpg +*.rknn \ No newline at end of file diff --git a/example/data/README.md b/example/data/README.md new file mode 100644 index 0000000..de04172 --- /dev/null +++ b/example/data/README.md @@ -0,0 +1,5 @@ + +# Example Data Files + +Use the `download.sh` script to download the data files (models and images) +needed to run the example code. \ No newline at end of file diff --git a/example/data/download.sh b/example/data/download.sh new file mode 100755 index 0000000..f06074b --- /dev/null +++ b/example/data/download.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# download mobilenet_v1 data files +wget https://github.com/airockchip/rknn-toolkit2/raw/v1.6.0/rknpu2/examples/rknn_mobilenet_demo/model/cat_224x224.jpg +wget https://github.com/airockchip/rknn-toolkit2/raw/v1.6.0/rknpu2/examples/rknn_mobilenet_demo/model/dog_224x224.jpg +wget -O mobilenet_v1-rk3588.rknn https://github.com/airockchip/rknn-toolkit2/raw/v1.6.0/rknpu2/examples/rknn_mobilenet_demo/model/RK3588/mobilenet_v1.rknn \ No newline at end of file diff --git a/example/mobilenet/mobilenet.go b/example/mobilenet/mobilenet.go new file mode 100644 index 0000000..010a16f --- /dev/null +++ b/example/mobilenet/mobilenet.go @@ -0,0 +1,119 @@ +package main + +import ( + "flag" + "fmt" + rknnlite "go-rknnlite" + "gocv.io/x/gocv" + "image" + "log" +) + +func main() { + // disable logging timestamps + log.SetFlags(0) + + // read in cli flags + modelFile := flag.String("m", "../data/mobilenet_v1-rk3588.rknn", "RKNN compiled model file") + imgFile := flag.String("i", "../data/cat_224x224.jpg", "Image file to run inference on") + flag.Parse() + + // load image + img := gocv.IMRead(*imgFile, gocv.IMReadColor) + + if img.Empty() { + log.Fatal("Error reading image from: ", *imgFile) + } + + // convert colorspace and resize image + rgbImg := gocv.NewMat() + gocv.CvtColor(img, &rgbImg, gocv.ColorBGRToRGB) + + cropImg := rgbImg.Clone() + gocv.Resize(rgbImg, &cropImg, image.Pt(224, 224), 0, 0, gocv.InterpolationArea) + + defer img.Close() + defer rgbImg.Close() + defer cropImg.Close() + + // create rknn runtime instance + rt, err := rknnlite.NewRuntime(*modelFile, rknnlite.NPUCoreAuto) + + if err != nil { + log.Fatal("Error initialing RKNN runtime: ", err) + } + + // optional querying of model file tensors and SDK version. not necessary + // for production inference code + optionalQueries(rt) + + // perform inference on image file + outputs, err := rt.Inference([]gocv.Mat{cropImg}) + + if err != nil { + log.Fatal("Runtime inferencing failed with error: ", err) + } + + // post process outputs and show top5 matches + log.Println(" --- Top5 ---") + + for _, next := range rknnlite.GetTop5(outputs) { + log.Printf("%3d: %8.6f\n", next.LabelIndex, next.Probability) + } + + // close runtime and release resources + err = rt.Close() + + if err != nil { + log.Fatal("Error closing RKNN runtime: ", err) + } + + log.Println("done") +} + +func optionalQueries(rt *rknnlite.Runtime) { + + // get SDK version + ver, err := rt.SDKVersion() + + if err != nil { + log.Fatal("Error initialing RKNN runtime: ", err) + } + + fmt.Printf("Driver Version: %s, API Version: %s\n", ver.DriverVersion, ver.APIVersion) + + // get model input and output numbers + num, err := rt.QueryModelIONumber() + + if err != nil { + log.Fatal("Error querying IO Numbers: ", err) + } + + log.Printf("Model Input Number: %d, Ouput Number: %d\n", num.NumberInput, num.NumberOutput) + + // query Input tensors + inputAttrs, err := rt.QueryInputTensors() + + if err != nil { + log.Fatal("Error querying Input Tensors: ", err) + } + + log.Println("Input tensors:") + + for _, attr := range inputAttrs { + log.Printf(" %s\n", attr.String()) + } + + // query Output tensors + outputAttrs, err := rt.QueryOutputTensors() + + if err != nil { + log.Fatal("Error querying Output Tensors: ", err) + } + + log.Println("Output tensors:") + + for _, attr := range outputAttrs { + log.Printf(" %s\n", attr.String()) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2337517 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module go-rknnlite + +go 1.21 + +require gocv.io/x/gocv v0.36.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f5c721f --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +gocv.io/x/gocv v0.36.1 h1:6XkEaPOk7h/umjy+MXgSEtSeCIgcPJhccUjrJFhjdTY= +gocv.io/x/gocv v0.36.1/go.mod h1:lmS802zoQmnNvXETpmGriBqWrENPei2GxYx5KUxJsMA= diff --git a/inference.go b/inference.go new file mode 100644 index 0000000..40be341 --- /dev/null +++ b/inference.go @@ -0,0 +1,271 @@ +package rknnlite + +/* +#include "rknn_api.h" +#include +*/ +import "C" +import ( + "fmt" + "gocv.io/x/gocv" + "unsafe" +) + +// Input represents the C.rknn_input struct and defines the Input used for +// inference +type Input struct { + // Index is the input index + Index uint32 + // Buf is the gocv Mat input + Buf unsafe.Pointer + // Size is the number of bytes of Buf + Size uint32 + // Passthrough defines the mode, if True the buf data is passed directly to + // the input node of the rknn model without any conversion. If False the + // buf data is converted into an input consistent with the model according + // to the following type and fmt + PassThrough bool + // Type is the data type of Buf. This is a required parameter if Passthrough + // is False + Type TensorType + // Fmt is the data format of Buf. This is a required parameter if Passthrough + // is False + Fmt TensorFormat +} + +// Inference runs the model inference on the given inputs +func (r *Runtime) Inference(mats []gocv.Mat) ([]Output, error) { + + // convert the cv Mat's into RKNN inputs + inputs := make([]Input, len(mats)) + + for idx, mat := range mats { + + // make mat continuous + if !mat.IsContinuous() { + mat = mat.Clone() + } + + // cast to float32, as PassThrough below is set to false then RKNN + // we convert the input values to that of the tensor inputs in the model, + // eg: INT8 + data, err := mat.DataPtrUint8() + + if err != nil { + return nil, fmt.Errorf("error converting image to float32: %w", err) + } + + inputs[idx] = Input{ + Index: uint32(idx), + Type: TensorUint8, + Size: uint32(mat.Cols() * mat.Rows() * mat.Channels()), + Fmt: TensorNHWC, + Buf: unsafe.Pointer(&data[0]), + PassThrough: false, + } + } + + // set the Inputs + err := r.SetInputs(inputs) + + if err != nil { + return nil, fmt.Errorf("error setting inputs: %w", err) + } + + // run the model + err = r.RunModel() + + if err != nil { + return nil, fmt.Errorf("error running model: %w", err) + } + + // get Outputs + return r.GetOutputs(r.ioNum.NumberOutput) +} + +// setInputs wraps C.rknn_inputs_set +func (r *Runtime) SetInputs(inputs []Input) error { + + nInputs := C.uint32_t(len(inputs)) + // make a C array of inputs + cInputs := make([]C.rknn_input, len(inputs)) + + for i, input := range inputs { + cInputs[i].index = C.uint32_t(input.Index) + cInputs[i].buf = input.Buf + cInputs[i].size = C.uint32_t(input.Size) + cInputs[i].pass_through = C.uint8_t(0) + if input.PassThrough { + cInputs[i].pass_through = C.uint8_t(1) + } + cInputs[i]._type = C.rknn_tensor_type(input.Type) + cInputs[i].fmt = C.rknn_tensor_format(input.Fmt) + } + + ret := C.rknn_inputs_set(r.ctx, nInputs, &cInputs[0]) + + if ret != 0 { + return fmt.Errorf("C.rknn_inputs_set failed with code %d, error: %s", + int(ret), ErrorCodes(ret).String()) + } + + return nil +} + +// RunModel wraps C.rknn_run +func (r *Runtime) RunModel() error { + + ret := C.rknn_run(r.ctx, nil) + + if ret < 0 { + return fmt.Errorf("C.rknn_run failed with code %d, error: %s", + int(ret), ErrorCodes(ret).String()) + } + + return nil +} + +// Output wraps C.rknn_output +type Output struct { + WantFloat uint8 // want transfer output data to float + IsPrealloc uint8 // whether buf is pre-allocated + Index uint32 // the output index + Buf []float32 // the output buf + Size uint32 // the size of output buf +} + +// GetOutputs returns the Output results +func (r *Runtime) GetOutputs(nOutputs uint32) ([]Output, error) { + + outputs := make([]Output, nOutputs) + + // prepare the outputs array in C + cOutputs := make([]C.rknn_output, nOutputs) + // release cOutputs from memory + defer r.releaseOutputs(cOutputs) + + // set want float for all outputs + for idx := range cOutputs { + cOutputs[idx].want_float = 1 + } + + // call C function + ret := C.rknn_outputs_get(r.ctx, C.uint32_t(nOutputs), + (*C.rknn_output)(unsafe.Pointer(&cOutputs[0])), nil) + + if ret < 0 { + return nil, fmt.Errorf("C.rknn_outputs_get failed with code %d, error: %s", + int(ret), ErrorCodes(ret).String()) + } + + // convert C.rknn_output array back to Go Output array + for i, cOutput := range cOutputs { + // convert buffer to []float32 + buffer := (*[1 << 30]float32)(cOutputs[i].buf)[:cOutputs[i].size/4] + + outputs[i] = Output{ + WantFloat: uint8(cOutput.want_float), + IsPrealloc: uint8(cOutput.is_prealloc), + Index: uint32(cOutput.index), + Buf: buffer, + Size: uint32(cOutput.size), + } + } + + return outputs, nil +} + +// releaseOutputs releases the memory allocated for the outputs by the RKNN +// toolkit directly using C rknn_output structs +func (r *Runtime) releaseOutputs(cOutputs []C.rknn_output) error { + + // directly use the C array of rknn_output obtained from getOutputs or similar. + outputsPtr := (*C.rknn_output)(unsafe.Pointer(&cOutputs[0])) + + // call C.rknn_outputs_release with the context and the outputs pointer + ret := C.rknn_outputs_release(r.ctx, C.uint32_t(len(cOutputs)), outputsPtr) + + if ret != 0 { + return fmt.Errorf("C.rknn_outputs_release failed with code %d, error: %s", + ret, ErrorCodes(ret).String()) + } + + return nil +} + +type Probability struct { + LabelIndex int32 + Probability float32 +} + +// GetTop5 outputs the Top5 matches in the model, with left column as label +// index and right column the match probability. The results are returned +// in the Probability slice in descending order from top match. +func GetTop5(outputs []Output) []Probability { + + probs := make([]Probability, 5) + + for i := 0; i < len(outputs); i++ { + var MaxClass [5]int32 + var fMaxProb [5]float32 + + GetTop(outputs[i].Buf, fMaxProb[:], MaxClass[:], int32(len(outputs[i].Buf)), 5) + + for i := 0; i < 5; i++ { + probs[i] = Probability{ + LabelIndex: MaxClass[i], + Probability: fMaxProb[i], + } + } + } + + return probs +} + +const MAX_TOP_NUM = 20 + +// GetTop takes outputs and produces a top list of matches by probability +func GetTop(pfProb []float32, pfMaxProb []float32, pMaxClass []int32, + outputCount int32, topNum int32) int { + + if topNum > MAX_TOP_NUM { + return 0 + } + + // initialize pfMaxProb with default values, ie: 0 + for j := range pfMaxProb { + pfMaxProb[j] = 0 + } + // initialize pMaxClass with default values, ie: -1 + for j := range pMaxClass { + pMaxClass[j] = -1 + } + + for j := int32(0); j < topNum; j++ { + for i := int32(0); i < outputCount; i++ { + + // skip if the current class is already in the top list + skip := false + + for k := 0; k < len(pMaxClass); k++ { + if i == pMaxClass[k] { + skip = true + break + } + } + + if skip { + continue + } + + // if the current probability is greater than the j'th max + // probability, update pfMaxProb and pMaxClass + if pfProb[i] > pfMaxProb[j] && pfProb[i] > 0.000001 { + pfMaxProb[j] = pfProb[i] + pMaxClass[j] = i + } + } + } + + return 1 +} diff --git a/io.go b/io.go new file mode 100644 index 0000000..af63b26 --- /dev/null +++ b/io.go @@ -0,0 +1,38 @@ +package rknnlite + +/* +#include "rknn_api.h" +#include +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +// QueryModelIONumber queries the number of Input and Output tensors of the model +func (r *Runtime) QueryModelIONumber() (ioNum IONumber, err error) { + + // prepare the structure to receive the Input/Output number + var cIONum C.rknn_input_output_num + + // call the C function + ret := C.rknn_query(r.ctx, C.RKNN_QUERY_IN_OUT_NUM, unsafe.Pointer(&cIONum), C.uint(C.sizeof_rknn_input_output_num)) + + if ret != C.RKNN_SUCC { + return IONumber{}, fmt.Errorf("rknn_query failed with return code %d", int(ret)) + } + + ioNum = IONumber{ + NumberInput: uint32(cIONum.n_input), + NumberOutput: uint32(cIONum.n_output), + } + + return ioNum, nil +} + +// IONumber represents the C.rknn_input_output_num struct +type IONumber struct { + NumberInput uint32 + NumberOutput uint32 +} diff --git a/runtime.go b/runtime.go new file mode 100644 index 0000000..28c9b14 --- /dev/null +++ b/runtime.go @@ -0,0 +1,216 @@ +package rknnlite + +/* +#include "rknn_api.h" +#include +*/ +import "C" +import ( + "fmt" + "os" + "unsafe" +) + +// CoreMask wraps C.rknn_core_mask +type CoreMask int + +// rknn_core_mask values used to target which cores on the NPU the model is run +// on. The rk3588 has three cores, auto will pick an idle core to run the model +// on, whilst the others specify the specific core or combined number of cores +// to run. For multi-core modes the following ops have better acceleration: Conv, +// DepthwiseConvolution, Add, Concat, Relu, Clip, Relu6, ThresholdedRelu, Prelu, +// and LeakyRelu. Other type of ops will fallback to Core0 to continue running +const ( + NPUCoreAuto CoreMask = C.RKNN_NPU_CORE_AUTO + NPUCore0 CoreMask = C.RKNN_NPU_CORE_0 + NPUCore1 CoreMask = C.RKNN_NPU_CORE_1 + NPUCore2 CoreMask = C.RKNN_NPU_CORE_2 + NPUCore01 CoreMask = C.RKNN_NPU_CORE_0_1 + NPUCore012 CoreMask = C.RKNN_NPU_CORE_0_1_2 +) + +// ErrorCodes +type ErrorCodes int + +// error code values returned by the C API +const ( + Success ErrorCodes = C.RKNN_SUCC + ErrFail ErrorCodes = C.RKNN_ERR_FAIL + ErrTimeout ErrorCodes = C.RKNN_ERR_TIMEOUT + ErrDeviceUnavailable ErrorCodes = C.RKNN_ERR_DEVICE_UNAVAILABLE + ErrMallocFail ErrorCodes = C.RKNN_ERR_MALLOC_FAIL + ErrParamInvalid ErrorCodes = C.RKNN_ERR_PARAM_INVALID + ErrModelInvalid ErrorCodes = C.RKNN_ERR_MODEL_INVALID + ErrCtxInvalid ErrorCodes = C.RKNN_ERR_CTX_INVALID + ErrInputInvalid ErrorCodes = C.RKNN_ERR_INPUT_INVALID + ErrOutputInvalid ErrorCodes = C.RKNN_ERR_OUTPUT_INVALID + ErrDeviceMismatch ErrorCodes = C.RKNN_ERR_DEVICE_UNMATCH + ErrPreCompiledModel ErrorCodes = C.RKNN_ERR_INCOMPATILE_PRE_COMPILE_MODEL + ErrOptimizationVersion ErrorCodes = C.RKNN_ERR_INCOMPATILE_OPTIMIZATION_LEVEL_VERSION + ErrPlatformMismatch ErrorCodes = C.RKNN_ERR_TARGET_PLATFORM_UNMATCH +) + +// String returns a readable description of the error code +func (e ErrorCodes) String() string { + switch e { + case Success: + return "execution successful" + case ErrFail: + return "execution failed" + case ErrTimeout: + return "execution timed out" + case ErrDeviceUnavailable: + return "device is unavailable" + case ErrMallocFail: + return "C memory allocation failed" + case ErrParamInvalid: + return "parameter is invalid" + case ErrModelInvalid: + return "model file is invalid" + case ErrCtxInvalid: + return "context is invalid" + case ErrInputInvalid: + return "input is invalid" + case ErrOutputInvalid: + return "output is invalid" + case ErrDeviceMismatch: + return "device mismatch, please update rknn sdk and npu driver/firmware" + case ErrPreCompiledModel: + return "the RKNN model uses pre_compile mode, but is not compatible with current driver" + case ErrOptimizationVersion: + return "the RKNN model optimization level is not compatible with current driver" + case ErrPlatformMismatch: + return "the RKNN model target platform is not compatible with the current platform" + default: + return fmt.Sprintf("unknown error code %d", e) + } +} + +// Runtime defines the RKNN run time instance +type Runtime struct { + // ctx is the C runtime context + ctx C.rknn_context + // ioNum caches the IONumber of Model Input/Output tensors + ioNum IONumber +} + +// NewRuntime returns a RKNN run time instance. Provide the full path and +// filename of the RKNN compiled model file to run. +func NewRuntime(modelFile string, core CoreMask) (*Runtime, error) { + + r := &Runtime{} + + err := r.init(modelFile) + + if err != nil { + return nil, err + } + + err = r.setCoreMask(core) + + if err != nil { + return nil, err + } + + // cache IONumber + r.ioNum, err = r.QueryModelIONumber() + + if err != nil { + return nil, err + } + + return r, nil +} + +// init wraps C.rknn_init which initializes the RKNN context with the given +// model. The modelFile is the full path and filename of the RKNN compiled +// model file to run. +func (r *Runtime) init(modelFile string) error { + + // check file exists in Go, before passing to C + info, err := os.Stat(modelFile) + + if err != nil { + return fmt.Errorf("model file does not exist at %s, error: %w", + modelFile, err) + } + + if info.IsDir() { + return fmt.Errorf("model file is a directory") + } + + // convert the Go string to a C string + cModelFile := C.CString(modelFile) + defer C.free(unsafe.Pointer(cModelFile)) + + // call the C function. + ret := C.rknn_init(&r.ctx, unsafe.Pointer(cModelFile), 0, 0, nil) + + if ret != C.RKNN_SUCC { + return fmt.Errorf("C.rknn_init call failed with code %d, error: %s", + ret, ErrorCodes(ret).String()) + } + + return nil +} + +// setCoreMark wraps C.rknn_set_core_mask and specifies the NPU core configuration +// to run the model on +func (r *Runtime) setCoreMask(mask CoreMask) error { + + ret := C.rknn_set_core_mask(r.ctx, C.rknn_core_mask(mask)) + + if ret != C.RKNN_SUCC { + return fmt.Errorf("C.rknn_set_core_mask failed with code %d, error: %s", + ret, ErrorCodes(ret).String()) + } + + return nil +} + +// Close wraps C.rknn_destroy which unloads the RKNN model from the runtime and +// destroys the context releasing all C resources +func (r *Runtime) Close() error { + + ret := C.rknn_destroy(r.ctx) + + if ret != C.RKNN_SUCC { + return fmt.Errorf("C.rknn_destroy failed with code %d, error: %s", + ret, ErrorCodes(ret).String()) + } + + return nil +} + +// SDKVersion represents the C.rknn_sdk_version struct +type SDKVersion struct { + DriverVersion string + APIVersion string +} + +// SDKVersion returns the RKNN API and Driver versions +func (r *Runtime) SDKVersion() (SDKVersion, error) { + + // prepare the structure to receive the SDK version info + var cSdkVer C.rknn_sdk_version + + // call the C function + ret := C.rknn_query( + r.ctx, + C.RKNN_QUERY_SDK_VERSION, + unsafe.Pointer(&cSdkVer), + C.uint(C.sizeof_rknn_sdk_version), + ) + + if ret != C.RKNN_SUCC { + return SDKVersion{}, fmt.Errorf("rknn_query failed with return code %d", int(ret)) + } + + // convert the C rknn_sdk_version to Go rknn_sdk_version + version := SDKVersion{ + DriverVersion: C.GoString(&(cSdkVer.drv_version[0])), + APIVersion: C.GoString(&(cSdkVer.api_version[0])), + } + + return version, nil +} diff --git a/tensor.go b/tensor.go new file mode 100644 index 0000000..0b1572d --- /dev/null +++ b/tensor.go @@ -0,0 +1,241 @@ +package rknnlite + +/* +#include "rknn_api.h" +#include +*/ +import "C" +import ( + "fmt" + "strings" + "unsafe" +) + +// TensorFormat wraps C.rknn_tensor_format +type TensorFormat int + +const ( + TensorNCHW TensorFormat = C.RKNN_TENSOR_NCHW + TensorNHWC TensorFormat = C.RKNN_TENSOR_NHWC + TensorNC1HWC2 TensorFormat = C.RKNN_TENSOR_NC1HWC2 + TensorUndefined TensorFormat = C.RKNN_TENSOR_UNDEFINED +) + +// TensorType wraps C.rknn_tensor_type +type TensorType int + +const ( + TensorFloat32 TensorType = C.RKNN_TENSOR_FLOAT32 + TensorFloat16 TensorType = C.RKNN_TENSOR_FLOAT16 + TensorInt8 TensorType = C.RKNN_TENSOR_INT8 + TensorUint8 TensorType = C.RKNN_TENSOR_UINT8 + TensorInt16 TensorType = C.RKNN_TENSOR_INT16 + TensorUint16 TensorType = C.RKNN_TENSOR_UINT16 + TensorInt32 TensorType = C.RKNN_TENSOR_INT32 + TensorUint32 TensorType = C.RKNN_TENSOR_UINT32 + TensorInt64 TensorType = C.RKNN_TENSOR_INT64 + TensorBool TensorType = C.RKNN_TENSOR_BOOL + TensorInt4 TensorType = C.RKNN_TENSOR_INT4 +) + +// TensorQntType wraps C.rknn_tensor_qnt_type +type TensorQntType int + +const ( + TensorQntNone TensorQntType = C.RKNN_TENSOR_QNT_NONE + TensorQntDFP TensorQntType = C.RKNN_TENSOR_QNT_DFP + TensorQntAffine TensorQntType = C.RKNN_TENSOR_QNT_AFFINE_ASYMMETRIC +) + +// AttrMaxDimensions are the maximum dimensions for an attribute in a tensor +type AttrMaxDimensions int + +// maximum field lengths of attributes in a tensor +const ( + AttrMaxDimension AttrMaxDimensions = C.RKNN_MAX_DIMS + AttrMaxChannels AttrMaxDimensions = C.RKNN_MAX_NUM_CHANNEL + AttrMaxNameLength AttrMaxDimensions = C.RKNN_MAX_NAME_LEN + AttrMaxDynShape AttrMaxDimensions = C.RKNN_MAX_DYNAMIC_SHAPE_NUM +) + +// TensorAttr represents the C.rknn_tensor_attr structure +type TensorAttr struct { + Index uint32 + NDims uint32 + Dims [AttrMaxDimension]uint32 + Name string + NElems uint32 + Size uint32 + Fmt TensorFormat + Type TensorType + QntType TensorQntType + FL int8 + ZP int32 + Scale float32 + WStride uint32 + SizeWithStride uint32 + PassThrough bool + HStride uint32 +} + +// convertTensorAttr converts a C.rknn_tensor_attr to a Go TensorAttr +func (r *Runtime) convertTensorAttr(cAttr *C.rknn_tensor_attr) TensorAttr { + + // convert C char array to Go string for Name field + nameBytes := C.GoBytes(unsafe.Pointer(&cAttr.name[0]), C.int(AttrMaxNameLength)) + goName := string(nameBytes) + + // find the first null byte to correctly end the string (if present) + nullIndex := strings.IndexByte(goName, 0) + + if nullIndex != -1 { + // Trim the string at the first null character + goName = goName[:nullIndex] + } + + return TensorAttr{ + Index: uint32(cAttr.index), + NDims: uint32(cAttr.n_dims), + Dims: *(*[AttrMaxDimension]uint32)(unsafe.Pointer(&cAttr.dims)), + Name: goName, + NElems: uint32(cAttr.n_elems), + Size: uint32(cAttr.size), + Fmt: TensorFormat(cAttr.fmt), + Type: TensorType(cAttr._type), + QntType: TensorQntType(cAttr.qnt_type), + FL: int8(cAttr.fl), + ZP: int32(cAttr.zp), + Scale: float32(cAttr.scale), + WStride: uint32(cAttr.w_stride), + SizeWithStride: uint32(cAttr.size_with_stride), + PassThrough: cAttr.pass_through != 0, + HStride: uint32(cAttr.h_stride), + } +} + +// QueryInputTensors gets the model Input Tensor attributes +func (r *Runtime) QueryInputTensors() ([]TensorAttr, error) { + + // allocate memory for input attributes in C + cInputAttrs := make([]C.rknn_tensor_attr, r.ioNum.NumberInput) + + for i := uint32(0); i < r.ioNum.NumberInput; i++ { + cInputAttrs[i].index = C.uint32_t(i) + + ret := C.rknn_query(r.ctx, C.RKNN_QUERY_INPUT_ATTR, + unsafe.Pointer(&cInputAttrs[i]), C.uint(unsafe.Sizeof(cInputAttrs[i]))) + + if ret != C.RKNN_SUCC { + return nil, fmt.Errorf("C.rknn_query RKNN_QUERY_INPUT_ATTR failed with code %d, error: %s", + int(ret), ErrorCodes(ret).String()) + } + } + + // convert the C.rknn_tensor_attr array to a TensorAttr + inputAttrs := make([]TensorAttr, r.ioNum.NumberInput) + + for i, cAttr := range cInputAttrs { + inputAttrs[i] = r.convertTensorAttr(&cAttr) + } + + return inputAttrs, nil +} + +// QueryOutputTensors gets the model Output Tensor attributes +func (r *Runtime) QueryOutputTensors() ([]TensorAttr, error) { + + // allocate memory for input attributes in C + cOutputAttrs := make([]C.rknn_tensor_attr, r.ioNum.NumberOutput) + + for i := uint32(0); i < r.ioNum.NumberOutput; i++ { + cOutputAttrs[i].index = C.uint32_t(i) + + ret := C.rknn_query(r.ctx, C.RKNN_QUERY_OUTPUT_ATTR, unsafe.Pointer(&cOutputAttrs[i]), C.uint(unsafe.Sizeof(cOutputAttrs[i]))) + + if ret != C.RKNN_SUCC { + return nil, fmt.Errorf("rknn_query RKNN_QUERY_OUTPUT_ATTR failed with code %d, error: %s", + int(ret), ErrorCodes(ret).String()) + } + } + + // convert the C rknn_tensor_attr array to a Go slice of RKNNTensorAttr + outputAttrs := make([]TensorAttr, r.ioNum.NumberOutput) + + for i, cAttr := range cOutputAttrs { + outputAttrs[i] = r.convertTensorAttr(&cAttr) + } + + return outputAttrs, nil + +} + +// String returns the TensorAttr's attributes formatted as a string +func (a TensorAttr) String() string { + + return fmt.Sprintf("index=%d, name=%s, n_dims=%d, "+ + "dims=[%d, %d, %d, %d], n_elems=%d, "+ + "size=%d, fmt=%s, type=%s, qnt_type=%s, zp=%d, scale=%f", + a.Index, a.Name, a.NDims, a.Dims[0], a.Dims[1], a.Dims[2], a.Dims[3], + a.NElems, a.Size, a.Fmt.String(), a.Type.String(), a.QntType.String(), a.ZP, a.Scale, + ) +} + +// String returns a readable description of the TensorType +func (t TensorType) String() string { + switch t { + case TensorFloat32: + return "FP32" + case TensorFloat16: + return "FP16" + case TensorInt8: + return "INT8" + case TensorUint8: + return "UINT8" + case TensorInt16: + return "INT16" + case TensorUint16: + return "UINT16" + case TensorInt32: + return "INT32" + case TensorUint32: + return "UINT32" + case TensorInt64: + return "INT64" + case TensorBool: + return "BOOL" + case TensorInt4: + return "INT4" + default: + return "UNKNOW" + } +} + +// String returns a readable description of the TensorQntType +func (t TensorQntType) String() string { + switch t { + case TensorQntNone: + return "NONE" + case TensorQntDFP: + return "DFP" + case TensorQntAffine: + return "AFFINE" + default: + return "UNKNOW" + } +} + +// String returns a readable description of the TensorFormat +func (t TensorFormat) String() string { + switch t { + case TensorNCHW: + return "NCHW" + case TensorNHWC: + return "NHWC" + case TensorNC1HWC2: + return "NC1HWC2" + case TensorUndefined: + return "UNDEFINED" + default: + return "UNKNOW" + } +}