mirror of
https://github.com/swdee/go-rknnlite.git
synced 2025-09-26 19:31:12 +08:00
imported initial code
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
.idea
|
||||
prototype
|
||||
python
|
6
cgo.go
Normal file
6
cgo.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package rknnlite
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lrknnrt
|
||||
*/
|
||||
import "C"
|
2
example/data/.gitignore
vendored
Normal file
2
example/data/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.jpg
|
||||
*.rknn
|
5
example/data/README.md
Normal file
5
example/data/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
|
||||
# Example Data Files
|
||||
|
||||
Use the `download.sh` script to download the data files (models and images)
|
||||
needed to run the example code.
|
6
example/data/download.sh
Executable file
6
example/data/download.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
# download mobilenet_v1 data files
|
||||
wget https://github.com/airockchip/rknn-toolkit2/raw/v1.6.0/rknpu2/examples/rknn_mobilenet_demo/model/cat_224x224.jpg
|
||||
wget https://github.com/airockchip/rknn-toolkit2/raw/v1.6.0/rknpu2/examples/rknn_mobilenet_demo/model/dog_224x224.jpg
|
||||
wget -O mobilenet_v1-rk3588.rknn https://github.com/airockchip/rknn-toolkit2/raw/v1.6.0/rknpu2/examples/rknn_mobilenet_demo/model/RK3588/mobilenet_v1.rknn
|
119
example/mobilenet/mobilenet.go
Normal file
119
example/mobilenet/mobilenet.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
rknnlite "go-rknnlite"
|
||||
"gocv.io/x/gocv"
|
||||
"image"
|
||||
"log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// disable logging timestamps
|
||||
log.SetFlags(0)
|
||||
|
||||
// read in cli flags
|
||||
modelFile := flag.String("m", "../data/mobilenet_v1-rk3588.rknn", "RKNN compiled model file")
|
||||
imgFile := flag.String("i", "../data/cat_224x224.jpg", "Image file to run inference on")
|
||||
flag.Parse()
|
||||
|
||||
// load image
|
||||
img := gocv.IMRead(*imgFile, gocv.IMReadColor)
|
||||
|
||||
if img.Empty() {
|
||||
log.Fatal("Error reading image from: ", *imgFile)
|
||||
}
|
||||
|
||||
// convert colorspace and resize image
|
||||
rgbImg := gocv.NewMat()
|
||||
gocv.CvtColor(img, &rgbImg, gocv.ColorBGRToRGB)
|
||||
|
||||
cropImg := rgbImg.Clone()
|
||||
gocv.Resize(rgbImg, &cropImg, image.Pt(224, 224), 0, 0, gocv.InterpolationArea)
|
||||
|
||||
defer img.Close()
|
||||
defer rgbImg.Close()
|
||||
defer cropImg.Close()
|
||||
|
||||
// create rknn runtime instance
|
||||
rt, err := rknnlite.NewRuntime(*modelFile, rknnlite.NPUCoreAuto)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("Error initialing RKNN runtime: ", err)
|
||||
}
|
||||
|
||||
// optional querying of model file tensors and SDK version. not necessary
|
||||
// for production inference code
|
||||
optionalQueries(rt)
|
||||
|
||||
// perform inference on image file
|
||||
outputs, err := rt.Inference([]gocv.Mat{cropImg})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("Runtime inferencing failed with error: ", err)
|
||||
}
|
||||
|
||||
// post process outputs and show top5 matches
|
||||
log.Println(" --- Top5 ---")
|
||||
|
||||
for _, next := range rknnlite.GetTop5(outputs) {
|
||||
log.Printf("%3d: %8.6f\n", next.LabelIndex, next.Probability)
|
||||
}
|
||||
|
||||
// close runtime and release resources
|
||||
err = rt.Close()
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("Error closing RKNN runtime: ", err)
|
||||
}
|
||||
|
||||
log.Println("done")
|
||||
}
|
||||
|
||||
func optionalQueries(rt *rknnlite.Runtime) {
|
||||
|
||||
// get SDK version
|
||||
ver, err := rt.SDKVersion()
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("Error initialing RKNN runtime: ", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Driver Version: %s, API Version: %s\n", ver.DriverVersion, ver.APIVersion)
|
||||
|
||||
// get model input and output numbers
|
||||
num, err := rt.QueryModelIONumber()
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("Error querying IO Numbers: ", err)
|
||||
}
|
||||
|
||||
log.Printf("Model Input Number: %d, Ouput Number: %d\n", num.NumberInput, num.NumberOutput)
|
||||
|
||||
// query Input tensors
|
||||
inputAttrs, err := rt.QueryInputTensors()
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("Error querying Input Tensors: ", err)
|
||||
}
|
||||
|
||||
log.Println("Input tensors:")
|
||||
|
||||
for _, attr := range inputAttrs {
|
||||
log.Printf(" %s\n", attr.String())
|
||||
}
|
||||
|
||||
// query Output tensors
|
||||
outputAttrs, err := rt.QueryOutputTensors()
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("Error querying Output Tensors: ", err)
|
||||
}
|
||||
|
||||
log.Println("Output tensors:")
|
||||
|
||||
for _, attr := range outputAttrs {
|
||||
log.Printf(" %s\n", attr.String())
|
||||
}
|
||||
}
|
5
go.mod
Normal file
5
go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module go-rknnlite
|
||||
|
||||
go 1.21
|
||||
|
||||
require gocv.io/x/gocv v0.36.1
|
2
go.sum
Normal file
2
go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
gocv.io/x/gocv v0.36.1 h1:6XkEaPOk7h/umjy+MXgSEtSeCIgcPJhccUjrJFhjdTY=
|
||||
gocv.io/x/gocv v0.36.1/go.mod h1:lmS802zoQmnNvXETpmGriBqWrENPei2GxYx5KUxJsMA=
|
271
inference.go
Normal file
271
inference.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package rknnlite
|
||||
|
||||
/*
|
||||
#include "rknn_api.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"gocv.io/x/gocv"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Input represents the C.rknn_input struct and defines the Input used for
|
||||
// inference
|
||||
type Input struct {
|
||||
// Index is the input index
|
||||
Index uint32
|
||||
// Buf is the gocv Mat input
|
||||
Buf unsafe.Pointer
|
||||
// Size is the number of bytes of Buf
|
||||
Size uint32
|
||||
// Passthrough defines the mode, if True the buf data is passed directly to
|
||||
// the input node of the rknn model without any conversion. If False the
|
||||
// buf data is converted into an input consistent with the model according
|
||||
// to the following type and fmt
|
||||
PassThrough bool
|
||||
// Type is the data type of Buf. This is a required parameter if Passthrough
|
||||
// is False
|
||||
Type TensorType
|
||||
// Fmt is the data format of Buf. This is a required parameter if Passthrough
|
||||
// is False
|
||||
Fmt TensorFormat
|
||||
}
|
||||
|
||||
// Inference runs the model inference on the given inputs
|
||||
func (r *Runtime) Inference(mats []gocv.Mat) ([]Output, error) {
|
||||
|
||||
// convert the cv Mat's into RKNN inputs
|
||||
inputs := make([]Input, len(mats))
|
||||
|
||||
for idx, mat := range mats {
|
||||
|
||||
// make mat continuous
|
||||
if !mat.IsContinuous() {
|
||||
mat = mat.Clone()
|
||||
}
|
||||
|
||||
// cast to float32, as PassThrough below is set to false then RKNN
|
||||
// we convert the input values to that of the tensor inputs in the model,
|
||||
// eg: INT8
|
||||
data, err := mat.DataPtrUint8()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting image to float32: %w", err)
|
||||
}
|
||||
|
||||
inputs[idx] = Input{
|
||||
Index: uint32(idx),
|
||||
Type: TensorUint8,
|
||||
Size: uint32(mat.Cols() * mat.Rows() * mat.Channels()),
|
||||
Fmt: TensorNHWC,
|
||||
Buf: unsafe.Pointer(&data[0]),
|
||||
PassThrough: false,
|
||||
}
|
||||
}
|
||||
|
||||
// set the Inputs
|
||||
err := r.SetInputs(inputs)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error setting inputs: %w", err)
|
||||
}
|
||||
|
||||
// run the model
|
||||
err = r.RunModel()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error running model: %w", err)
|
||||
}
|
||||
|
||||
// get Outputs
|
||||
return r.GetOutputs(r.ioNum.NumberOutput)
|
||||
}
|
||||
|
||||
// setInputs wraps C.rknn_inputs_set
|
||||
func (r *Runtime) SetInputs(inputs []Input) error {
|
||||
|
||||
nInputs := C.uint32_t(len(inputs))
|
||||
// make a C array of inputs
|
||||
cInputs := make([]C.rknn_input, len(inputs))
|
||||
|
||||
for i, input := range inputs {
|
||||
cInputs[i].index = C.uint32_t(input.Index)
|
||||
cInputs[i].buf = input.Buf
|
||||
cInputs[i].size = C.uint32_t(input.Size)
|
||||
cInputs[i].pass_through = C.uint8_t(0)
|
||||
if input.PassThrough {
|
||||
cInputs[i].pass_through = C.uint8_t(1)
|
||||
}
|
||||
cInputs[i]._type = C.rknn_tensor_type(input.Type)
|
||||
cInputs[i].fmt = C.rknn_tensor_format(input.Fmt)
|
||||
}
|
||||
|
||||
ret := C.rknn_inputs_set(r.ctx, nInputs, &cInputs[0])
|
||||
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("C.rknn_inputs_set failed with code %d, error: %s",
|
||||
int(ret), ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunModel wraps C.rknn_run
|
||||
func (r *Runtime) RunModel() error {
|
||||
|
||||
ret := C.rknn_run(r.ctx, nil)
|
||||
|
||||
if ret < 0 {
|
||||
return fmt.Errorf("C.rknn_run failed with code %d, error: %s",
|
||||
int(ret), ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Output wraps C.rknn_output
|
||||
type Output struct {
|
||||
WantFloat uint8 // want transfer output data to float
|
||||
IsPrealloc uint8 // whether buf is pre-allocated
|
||||
Index uint32 // the output index
|
||||
Buf []float32 // the output buf
|
||||
Size uint32 // the size of output buf
|
||||
}
|
||||
|
||||
// GetOutputs returns the Output results
|
||||
func (r *Runtime) GetOutputs(nOutputs uint32) ([]Output, error) {
|
||||
|
||||
outputs := make([]Output, nOutputs)
|
||||
|
||||
// prepare the outputs array in C
|
||||
cOutputs := make([]C.rknn_output, nOutputs)
|
||||
// release cOutputs from memory
|
||||
defer r.releaseOutputs(cOutputs)
|
||||
|
||||
// set want float for all outputs
|
||||
for idx := range cOutputs {
|
||||
cOutputs[idx].want_float = 1
|
||||
}
|
||||
|
||||
// call C function
|
||||
ret := C.rknn_outputs_get(r.ctx, C.uint32_t(nOutputs),
|
||||
(*C.rknn_output)(unsafe.Pointer(&cOutputs[0])), nil)
|
||||
|
||||
if ret < 0 {
|
||||
return nil, fmt.Errorf("C.rknn_outputs_get failed with code %d, error: %s",
|
||||
int(ret), ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
// convert C.rknn_output array back to Go Output array
|
||||
for i, cOutput := range cOutputs {
|
||||
// convert buffer to []float32
|
||||
buffer := (*[1 << 30]float32)(cOutputs[i].buf)[:cOutputs[i].size/4]
|
||||
|
||||
outputs[i] = Output{
|
||||
WantFloat: uint8(cOutput.want_float),
|
||||
IsPrealloc: uint8(cOutput.is_prealloc),
|
||||
Index: uint32(cOutput.index),
|
||||
Buf: buffer,
|
||||
Size: uint32(cOutput.size),
|
||||
}
|
||||
}
|
||||
|
||||
return outputs, nil
|
||||
}
|
||||
|
||||
// releaseOutputs releases the memory allocated for the outputs by the RKNN
|
||||
// toolkit directly using C rknn_output structs
|
||||
func (r *Runtime) releaseOutputs(cOutputs []C.rknn_output) error {
|
||||
|
||||
// directly use the C array of rknn_output obtained from getOutputs or similar.
|
||||
outputsPtr := (*C.rknn_output)(unsafe.Pointer(&cOutputs[0]))
|
||||
|
||||
// call C.rknn_outputs_release with the context and the outputs pointer
|
||||
ret := C.rknn_outputs_release(r.ctx, C.uint32_t(len(cOutputs)), outputsPtr)
|
||||
|
||||
if ret != 0 {
|
||||
return fmt.Errorf("C.rknn_outputs_release failed with code %d, error: %s",
|
||||
ret, ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Probability struct {
|
||||
LabelIndex int32
|
||||
Probability float32
|
||||
}
|
||||
|
||||
// GetTop5 outputs the Top5 matches in the model, with left column as label
|
||||
// index and right column the match probability. The results are returned
|
||||
// in the Probability slice in descending order from top match.
|
||||
func GetTop5(outputs []Output) []Probability {
|
||||
|
||||
probs := make([]Probability, 5)
|
||||
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
var MaxClass [5]int32
|
||||
var fMaxProb [5]float32
|
||||
|
||||
GetTop(outputs[i].Buf, fMaxProb[:], MaxClass[:], int32(len(outputs[i].Buf)), 5)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
probs[i] = Probability{
|
||||
LabelIndex: MaxClass[i],
|
||||
Probability: fMaxProb[i],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return probs
|
||||
}
|
||||
|
||||
const MAX_TOP_NUM = 20
|
||||
|
||||
// GetTop takes outputs and produces a top list of matches by probability
|
||||
func GetTop(pfProb []float32, pfMaxProb []float32, pMaxClass []int32,
|
||||
outputCount int32, topNum int32) int {
|
||||
|
||||
if topNum > MAX_TOP_NUM {
|
||||
return 0
|
||||
}
|
||||
|
||||
// initialize pfMaxProb with default values, ie: 0
|
||||
for j := range pfMaxProb {
|
||||
pfMaxProb[j] = 0
|
||||
}
|
||||
// initialize pMaxClass with default values, ie: -1
|
||||
for j := range pMaxClass {
|
||||
pMaxClass[j] = -1
|
||||
}
|
||||
|
||||
for j := int32(0); j < topNum; j++ {
|
||||
for i := int32(0); i < outputCount; i++ {
|
||||
|
||||
// skip if the current class is already in the top list
|
||||
skip := false
|
||||
|
||||
for k := 0; k < len(pMaxClass); k++ {
|
||||
if i == pMaxClass[k] {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
|
||||
// if the current probability is greater than the j'th max
|
||||
// probability, update pfMaxProb and pMaxClass
|
||||
if pfProb[i] > pfMaxProb[j] && pfProb[i] > 0.000001 {
|
||||
pfMaxProb[j] = pfProb[i]
|
||||
pMaxClass[j] = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
38
io.go
Normal file
38
io.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package rknnlite
|
||||
|
||||
/*
|
||||
#include "rknn_api.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// QueryModelIONumber queries the number of Input and Output tensors of the model
|
||||
func (r *Runtime) QueryModelIONumber() (ioNum IONumber, err error) {
|
||||
|
||||
// prepare the structure to receive the Input/Output number
|
||||
var cIONum C.rknn_input_output_num
|
||||
|
||||
// call the C function
|
||||
ret := C.rknn_query(r.ctx, C.RKNN_QUERY_IN_OUT_NUM, unsafe.Pointer(&cIONum), C.uint(C.sizeof_rknn_input_output_num))
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
return IONumber{}, fmt.Errorf("rknn_query failed with return code %d", int(ret))
|
||||
}
|
||||
|
||||
ioNum = IONumber{
|
||||
NumberInput: uint32(cIONum.n_input),
|
||||
NumberOutput: uint32(cIONum.n_output),
|
||||
}
|
||||
|
||||
return ioNum, nil
|
||||
}
|
||||
|
||||
// IONumber represents the C.rknn_input_output_num struct
|
||||
type IONumber struct {
|
||||
NumberInput uint32
|
||||
NumberOutput uint32
|
||||
}
|
216
runtime.go
Normal file
216
runtime.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package rknnlite
|
||||
|
||||
/*
|
||||
#include "rknn_api.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CoreMask wraps C.rknn_core_mask
|
||||
type CoreMask int
|
||||
|
||||
// rknn_core_mask values used to target which cores on the NPU the model is run
|
||||
// on. The rk3588 has three cores, auto will pick an idle core to run the model
|
||||
// on, whilst the others specify the specific core or combined number of cores
|
||||
// to run. For multi-core modes the following ops have better acceleration: Conv,
|
||||
// DepthwiseConvolution, Add, Concat, Relu, Clip, Relu6, ThresholdedRelu, Prelu,
|
||||
// and LeakyRelu. Other type of ops will fallback to Core0 to continue running
|
||||
const (
|
||||
NPUCoreAuto CoreMask = C.RKNN_NPU_CORE_AUTO
|
||||
NPUCore0 CoreMask = C.RKNN_NPU_CORE_0
|
||||
NPUCore1 CoreMask = C.RKNN_NPU_CORE_1
|
||||
NPUCore2 CoreMask = C.RKNN_NPU_CORE_2
|
||||
NPUCore01 CoreMask = C.RKNN_NPU_CORE_0_1
|
||||
NPUCore012 CoreMask = C.RKNN_NPU_CORE_0_1_2
|
||||
)
|
||||
|
||||
// ErrorCodes
|
||||
type ErrorCodes int
|
||||
|
||||
// error code values returned by the C API
|
||||
const (
|
||||
Success ErrorCodes = C.RKNN_SUCC
|
||||
ErrFail ErrorCodes = C.RKNN_ERR_FAIL
|
||||
ErrTimeout ErrorCodes = C.RKNN_ERR_TIMEOUT
|
||||
ErrDeviceUnavailable ErrorCodes = C.RKNN_ERR_DEVICE_UNAVAILABLE
|
||||
ErrMallocFail ErrorCodes = C.RKNN_ERR_MALLOC_FAIL
|
||||
ErrParamInvalid ErrorCodes = C.RKNN_ERR_PARAM_INVALID
|
||||
ErrModelInvalid ErrorCodes = C.RKNN_ERR_MODEL_INVALID
|
||||
ErrCtxInvalid ErrorCodes = C.RKNN_ERR_CTX_INVALID
|
||||
ErrInputInvalid ErrorCodes = C.RKNN_ERR_INPUT_INVALID
|
||||
ErrOutputInvalid ErrorCodes = C.RKNN_ERR_OUTPUT_INVALID
|
||||
ErrDeviceMismatch ErrorCodes = C.RKNN_ERR_DEVICE_UNMATCH
|
||||
ErrPreCompiledModel ErrorCodes = C.RKNN_ERR_INCOMPATILE_PRE_COMPILE_MODEL
|
||||
ErrOptimizationVersion ErrorCodes = C.RKNN_ERR_INCOMPATILE_OPTIMIZATION_LEVEL_VERSION
|
||||
ErrPlatformMismatch ErrorCodes = C.RKNN_ERR_TARGET_PLATFORM_UNMATCH
|
||||
)
|
||||
|
||||
// String returns a readable description of the error code
|
||||
func (e ErrorCodes) String() string {
|
||||
switch e {
|
||||
case Success:
|
||||
return "execution successful"
|
||||
case ErrFail:
|
||||
return "execution failed"
|
||||
case ErrTimeout:
|
||||
return "execution timed out"
|
||||
case ErrDeviceUnavailable:
|
||||
return "device is unavailable"
|
||||
case ErrMallocFail:
|
||||
return "C memory allocation failed"
|
||||
case ErrParamInvalid:
|
||||
return "parameter is invalid"
|
||||
case ErrModelInvalid:
|
||||
return "model file is invalid"
|
||||
case ErrCtxInvalid:
|
||||
return "context is invalid"
|
||||
case ErrInputInvalid:
|
||||
return "input is invalid"
|
||||
case ErrOutputInvalid:
|
||||
return "output is invalid"
|
||||
case ErrDeviceMismatch:
|
||||
return "device mismatch, please update rknn sdk and npu driver/firmware"
|
||||
case ErrPreCompiledModel:
|
||||
return "the RKNN model uses pre_compile mode, but is not compatible with current driver"
|
||||
case ErrOptimizationVersion:
|
||||
return "the RKNN model optimization level is not compatible with current driver"
|
||||
case ErrPlatformMismatch:
|
||||
return "the RKNN model target platform is not compatible with the current platform"
|
||||
default:
|
||||
return fmt.Sprintf("unknown error code %d", e)
|
||||
}
|
||||
}
|
||||
|
||||
// Runtime defines the RKNN run time instance
|
||||
type Runtime struct {
|
||||
// ctx is the C runtime context
|
||||
ctx C.rknn_context
|
||||
// ioNum caches the IONumber of Model Input/Output tensors
|
||||
ioNum IONumber
|
||||
}
|
||||
|
||||
// NewRuntime returns a RKNN run time instance. Provide the full path and
|
||||
// filename of the RKNN compiled model file to run.
|
||||
func NewRuntime(modelFile string, core CoreMask) (*Runtime, error) {
|
||||
|
||||
r := &Runtime{}
|
||||
|
||||
err := r.init(modelFile)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = r.setCoreMask(core)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// cache IONumber
|
||||
r.ioNum, err = r.QueryModelIONumber()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// init wraps C.rknn_init which initializes the RKNN context with the given
|
||||
// model. The modelFile is the full path and filename of the RKNN compiled
|
||||
// model file to run.
|
||||
func (r *Runtime) init(modelFile string) error {
|
||||
|
||||
// check file exists in Go, before passing to C
|
||||
info, err := os.Stat(modelFile)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("model file does not exist at %s, error: %w",
|
||||
modelFile, err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return fmt.Errorf("model file is a directory")
|
||||
}
|
||||
|
||||
// convert the Go string to a C string
|
||||
cModelFile := C.CString(modelFile)
|
||||
defer C.free(unsafe.Pointer(cModelFile))
|
||||
|
||||
// call the C function.
|
||||
ret := C.rknn_init(&r.ctx, unsafe.Pointer(cModelFile), 0, 0, nil)
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
return fmt.Errorf("C.rknn_init call failed with code %d, error: %s",
|
||||
ret, ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setCoreMark wraps C.rknn_set_core_mask and specifies the NPU core configuration
|
||||
// to run the model on
|
||||
func (r *Runtime) setCoreMask(mask CoreMask) error {
|
||||
|
||||
ret := C.rknn_set_core_mask(r.ctx, C.rknn_core_mask(mask))
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
return fmt.Errorf("C.rknn_set_core_mask failed with code %d, error: %s",
|
||||
ret, ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close wraps C.rknn_destroy which unloads the RKNN model from the runtime and
|
||||
// destroys the context releasing all C resources
|
||||
func (r *Runtime) Close() error {
|
||||
|
||||
ret := C.rknn_destroy(r.ctx)
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
return fmt.Errorf("C.rknn_destroy failed with code %d, error: %s",
|
||||
ret, ErrorCodes(ret).String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SDKVersion represents the C.rknn_sdk_version struct
|
||||
type SDKVersion struct {
|
||||
DriverVersion string
|
||||
APIVersion string
|
||||
}
|
||||
|
||||
// SDKVersion returns the RKNN API and Driver versions
|
||||
func (r *Runtime) SDKVersion() (SDKVersion, error) {
|
||||
|
||||
// prepare the structure to receive the SDK version info
|
||||
var cSdkVer C.rknn_sdk_version
|
||||
|
||||
// call the C function
|
||||
ret := C.rknn_query(
|
||||
r.ctx,
|
||||
C.RKNN_QUERY_SDK_VERSION,
|
||||
unsafe.Pointer(&cSdkVer),
|
||||
C.uint(C.sizeof_rknn_sdk_version),
|
||||
)
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
return SDKVersion{}, fmt.Errorf("rknn_query failed with return code %d", int(ret))
|
||||
}
|
||||
|
||||
// convert the C rknn_sdk_version to Go rknn_sdk_version
|
||||
version := SDKVersion{
|
||||
DriverVersion: C.GoString(&(cSdkVer.drv_version[0])),
|
||||
APIVersion: C.GoString(&(cSdkVer.api_version[0])),
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
241
tensor.go
Normal file
241
tensor.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package rknnlite
|
||||
|
||||
/*
|
||||
#include "rknn_api.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// TensorFormat wraps C.rknn_tensor_format
|
||||
type TensorFormat int
|
||||
|
||||
const (
|
||||
TensorNCHW TensorFormat = C.RKNN_TENSOR_NCHW
|
||||
TensorNHWC TensorFormat = C.RKNN_TENSOR_NHWC
|
||||
TensorNC1HWC2 TensorFormat = C.RKNN_TENSOR_NC1HWC2
|
||||
TensorUndefined TensorFormat = C.RKNN_TENSOR_UNDEFINED
|
||||
)
|
||||
|
||||
// TensorType wraps C.rknn_tensor_type
|
||||
type TensorType int
|
||||
|
||||
const (
|
||||
TensorFloat32 TensorType = C.RKNN_TENSOR_FLOAT32
|
||||
TensorFloat16 TensorType = C.RKNN_TENSOR_FLOAT16
|
||||
TensorInt8 TensorType = C.RKNN_TENSOR_INT8
|
||||
TensorUint8 TensorType = C.RKNN_TENSOR_UINT8
|
||||
TensorInt16 TensorType = C.RKNN_TENSOR_INT16
|
||||
TensorUint16 TensorType = C.RKNN_TENSOR_UINT16
|
||||
TensorInt32 TensorType = C.RKNN_TENSOR_INT32
|
||||
TensorUint32 TensorType = C.RKNN_TENSOR_UINT32
|
||||
TensorInt64 TensorType = C.RKNN_TENSOR_INT64
|
||||
TensorBool TensorType = C.RKNN_TENSOR_BOOL
|
||||
TensorInt4 TensorType = C.RKNN_TENSOR_INT4
|
||||
)
|
||||
|
||||
// TensorQntType wraps C.rknn_tensor_qnt_type
|
||||
type TensorQntType int
|
||||
|
||||
const (
|
||||
TensorQntNone TensorQntType = C.RKNN_TENSOR_QNT_NONE
|
||||
TensorQntDFP TensorQntType = C.RKNN_TENSOR_QNT_DFP
|
||||
TensorQntAffine TensorQntType = C.RKNN_TENSOR_QNT_AFFINE_ASYMMETRIC
|
||||
)
|
||||
|
||||
// AttrMaxDimensions are the maximum dimensions for an attribute in a tensor
|
||||
type AttrMaxDimensions int
|
||||
|
||||
// maximum field lengths of attributes in a tensor
|
||||
const (
|
||||
AttrMaxDimension AttrMaxDimensions = C.RKNN_MAX_DIMS
|
||||
AttrMaxChannels AttrMaxDimensions = C.RKNN_MAX_NUM_CHANNEL
|
||||
AttrMaxNameLength AttrMaxDimensions = C.RKNN_MAX_NAME_LEN
|
||||
AttrMaxDynShape AttrMaxDimensions = C.RKNN_MAX_DYNAMIC_SHAPE_NUM
|
||||
)
|
||||
|
||||
// TensorAttr represents the C.rknn_tensor_attr structure
|
||||
type TensorAttr struct {
|
||||
Index uint32
|
||||
NDims uint32
|
||||
Dims [AttrMaxDimension]uint32
|
||||
Name string
|
||||
NElems uint32
|
||||
Size uint32
|
||||
Fmt TensorFormat
|
||||
Type TensorType
|
||||
QntType TensorQntType
|
||||
FL int8
|
||||
ZP int32
|
||||
Scale float32
|
||||
WStride uint32
|
||||
SizeWithStride uint32
|
||||
PassThrough bool
|
||||
HStride uint32
|
||||
}
|
||||
|
||||
// convertTensorAttr converts a C.rknn_tensor_attr to a Go TensorAttr
|
||||
func (r *Runtime) convertTensorAttr(cAttr *C.rknn_tensor_attr) TensorAttr {
|
||||
|
||||
// convert C char array to Go string for Name field
|
||||
nameBytes := C.GoBytes(unsafe.Pointer(&cAttr.name[0]), C.int(AttrMaxNameLength))
|
||||
goName := string(nameBytes)
|
||||
|
||||
// find the first null byte to correctly end the string (if present)
|
||||
nullIndex := strings.IndexByte(goName, 0)
|
||||
|
||||
if nullIndex != -1 {
|
||||
// Trim the string at the first null character
|
||||
goName = goName[:nullIndex]
|
||||
}
|
||||
|
||||
return TensorAttr{
|
||||
Index: uint32(cAttr.index),
|
||||
NDims: uint32(cAttr.n_dims),
|
||||
Dims: *(*[AttrMaxDimension]uint32)(unsafe.Pointer(&cAttr.dims)),
|
||||
Name: goName,
|
||||
NElems: uint32(cAttr.n_elems),
|
||||
Size: uint32(cAttr.size),
|
||||
Fmt: TensorFormat(cAttr.fmt),
|
||||
Type: TensorType(cAttr._type),
|
||||
QntType: TensorQntType(cAttr.qnt_type),
|
||||
FL: int8(cAttr.fl),
|
||||
ZP: int32(cAttr.zp),
|
||||
Scale: float32(cAttr.scale),
|
||||
WStride: uint32(cAttr.w_stride),
|
||||
SizeWithStride: uint32(cAttr.size_with_stride),
|
||||
PassThrough: cAttr.pass_through != 0,
|
||||
HStride: uint32(cAttr.h_stride),
|
||||
}
|
||||
}
|
||||
|
||||
// QueryInputTensors gets the model Input Tensor attributes
|
||||
func (r *Runtime) QueryInputTensors() ([]TensorAttr, error) {
|
||||
|
||||
// allocate memory for input attributes in C
|
||||
cInputAttrs := make([]C.rknn_tensor_attr, r.ioNum.NumberInput)
|
||||
|
||||
for i := uint32(0); i < r.ioNum.NumberInput; i++ {
|
||||
cInputAttrs[i].index = C.uint32_t(i)
|
||||
|
||||
ret := C.rknn_query(r.ctx, C.RKNN_QUERY_INPUT_ATTR,
|
||||
unsafe.Pointer(&cInputAttrs[i]), C.uint(unsafe.Sizeof(cInputAttrs[i])))
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
return nil, fmt.Errorf("C.rknn_query RKNN_QUERY_INPUT_ATTR failed with code %d, error: %s",
|
||||
int(ret), ErrorCodes(ret).String())
|
||||
}
|
||||
}
|
||||
|
||||
// convert the C.rknn_tensor_attr array to a TensorAttr
|
||||
inputAttrs := make([]TensorAttr, r.ioNum.NumberInput)
|
||||
|
||||
for i, cAttr := range cInputAttrs {
|
||||
inputAttrs[i] = r.convertTensorAttr(&cAttr)
|
||||
}
|
||||
|
||||
return inputAttrs, nil
|
||||
}
|
||||
|
||||
// QueryOutputTensors gets the model Output Tensor attributes
|
||||
func (r *Runtime) QueryOutputTensors() ([]TensorAttr, error) {
|
||||
|
||||
// allocate memory for input attributes in C
|
||||
cOutputAttrs := make([]C.rknn_tensor_attr, r.ioNum.NumberOutput)
|
||||
|
||||
for i := uint32(0); i < r.ioNum.NumberOutput; i++ {
|
||||
cOutputAttrs[i].index = C.uint32_t(i)
|
||||
|
||||
ret := C.rknn_query(r.ctx, C.RKNN_QUERY_OUTPUT_ATTR, unsafe.Pointer(&cOutputAttrs[i]), C.uint(unsafe.Sizeof(cOutputAttrs[i])))
|
||||
|
||||
if ret != C.RKNN_SUCC {
|
||||
return nil, fmt.Errorf("rknn_query RKNN_QUERY_OUTPUT_ATTR failed with code %d, error: %s",
|
||||
int(ret), ErrorCodes(ret).String())
|
||||
}
|
||||
}
|
||||
|
||||
// convert the C rknn_tensor_attr array to a Go slice of RKNNTensorAttr
|
||||
outputAttrs := make([]TensorAttr, r.ioNum.NumberOutput)
|
||||
|
||||
for i, cAttr := range cOutputAttrs {
|
||||
outputAttrs[i] = r.convertTensorAttr(&cAttr)
|
||||
}
|
||||
|
||||
return outputAttrs, nil
|
||||
|
||||
}
|
||||
|
||||
// String returns the TensorAttr's attributes formatted as a string
|
||||
func (a TensorAttr) String() string {
|
||||
|
||||
return fmt.Sprintf("index=%d, name=%s, n_dims=%d, "+
|
||||
"dims=[%d, %d, %d, %d], n_elems=%d, "+
|
||||
"size=%d, fmt=%s, type=%s, qnt_type=%s, zp=%d, scale=%f",
|
||||
a.Index, a.Name, a.NDims, a.Dims[0], a.Dims[1], a.Dims[2], a.Dims[3],
|
||||
a.NElems, a.Size, a.Fmt.String(), a.Type.String(), a.QntType.String(), a.ZP, a.Scale,
|
||||
)
|
||||
}
|
||||
|
||||
// String returns a readable description of the TensorType
|
||||
func (t TensorType) String() string {
|
||||
switch t {
|
||||
case TensorFloat32:
|
||||
return "FP32"
|
||||
case TensorFloat16:
|
||||
return "FP16"
|
||||
case TensorInt8:
|
||||
return "INT8"
|
||||
case TensorUint8:
|
||||
return "UINT8"
|
||||
case TensorInt16:
|
||||
return "INT16"
|
||||
case TensorUint16:
|
||||
return "UINT16"
|
||||
case TensorInt32:
|
||||
return "INT32"
|
||||
case TensorUint32:
|
||||
return "UINT32"
|
||||
case TensorInt64:
|
||||
return "INT64"
|
||||
case TensorBool:
|
||||
return "BOOL"
|
||||
case TensorInt4:
|
||||
return "INT4"
|
||||
default:
|
||||
return "UNKNOW"
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a readable description of the TensorQntType
|
||||
func (t TensorQntType) String() string {
|
||||
switch t {
|
||||
case TensorQntNone:
|
||||
return "NONE"
|
||||
case TensorQntDFP:
|
||||
return "DFP"
|
||||
case TensorQntAffine:
|
||||
return "AFFINE"
|
||||
default:
|
||||
return "UNKNOW"
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a readable description of the TensorFormat
|
||||
func (t TensorFormat) String() string {
|
||||
switch t {
|
||||
case TensorNCHW:
|
||||
return "NCHW"
|
||||
case TensorNHWC:
|
||||
return "NHWC"
|
||||
case TensorNC1HWC2:
|
||||
return "NC1HWC2"
|
||||
case TensorUndefined:
|
||||
return "UNDEFINED"
|
||||
default:
|
||||
return "UNKNOW"
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user