go-rknnlite/example/sahi/sahi.go
package main

import (
	"flag"
	"fmt"
	"github.com/swdee/go-rknnlite"
	"github.com/swdee/go-rknnlite/postprocess"
	"github.com/swdee/go-rknnlite/preprocess"
	"github.com/swdee/go-rknnlite/render"
	"gocv.io/x/gocv"
	"log"
	"strings"
	"sync"
	"time"
)
func main() {

	// disable logging timestamps
	log.SetFlags(0)

	// read in cli flags
	modelFile := flag.String("m", "../data/models/rk3588/yolov5s-rk3588.rknn", "RKNN compiled YOLO model file")
	imgFile := flag.String("i", "../data/protest.jpg", "Image file to run object detection on")
	poolSize := flag.Int("s", 1, "Size of RKNN runtime pool, choose 1, 2, 3, or multiples of 3")
	labelFile := flag.String("l", "../data/coco_80_labels_list.txt", "Text file containing model labels")
	saveFile := flag.String("o", "../data/protest-sahi-out.jpg", "The output JPG file with object detection markers")
	rkPlatform := flag.String("p", "rk3588", "Rockchip CPU Model number [rk3562|rk3566|rk3568|rk3576|rk3582|rk3588]")

	flag.Parse()

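	// pin this process to the platform's fast (performance) cores so inference
	// is not scheduled onto the slower efficiency cores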
	err := rknnlite.SetCPUAffinityByPlatform(*rkPlatform, rknnlite.FastCores)

	if err != nil {
		log.Printf("Failed to set CPU Affinity: %v\n", err)
	}

	// check if user specified model file or if default is being used. if
	// default then pick the default platform model to use.
	if f := flag.Lookup("m"); f != nil && f.Value.String() == f.DefValue && *rkPlatform != "rk3588" {
		*modelFile = strings.ReplaceAll(*modelFile, "rk3588", *rkPlatform)
	}

	// create new pool
	pool, err := rknnlite.NewPoolByPlatform(*rkPlatform, *poolSize, *modelFile)

	if err != nil {
		log.Fatalf("Error creating RKNN pool: %v\n", err)
	}

	// set runtime to leave output tensors as int8
	pool.SetWantFloat(false)
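	// (keeping the outputs quantized skips a float32 conversion step; the
	// YOLOv5 post processor works with the int8 tensors directly)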

	// create YOLOv5 post processor
	yoloProcesser := postprocess.NewYOLOv5(postprocess.YOLOv5COCOParams())
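	// (YOLOv5COCOParams provides the library's default post processing
	// parameters for the 80 class COCO trained model)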

	// load in Model class names
	classNames, err := rknnlite.LoadLabels(*labelFile)

	if err != nil {
		log.Fatal("Error loading model labels: ", err)
	}

	// load image
	img := gocv.IMRead(*imgFile, gocv.IMReadColor)

	if img.Empty() {
		log.Fatal("Error reading image from: ", *imgFile)
	}

	// output dimensions of source image
	log.Printf("Source image dimensions %dx%d\n", img.Cols(), img.Rows())

	// get the tensor input dimensions by briefly borrowing a runtime from
	// the pool and returning it once read
	rt := pool.Get()
	tensorWidth := int(rt.InputAttrs()[0].Dims[1])
	tensorHeight := int(rt.InputAttrs()[0].Dims[2])
	pool.Return(rt)

	start := time.Now()

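	// the SAHI slicer cuts the source image into tiles matching the model's
	// input tensor size with a 20% overlap in each direction; each slice
	// records its offset so detections can be mapped back to the full image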
	sahi := preprocess.NewSAHI(tensorWidth, tensorHeight, 0.2, 0.2)
	slices := sahi.Slice(img)

	// waitgroup used to wait for all go-routines to complete before closing
	// the pool
	var wg sync.WaitGroup

	// create mutex so each slice's results are printed to stdout as a single
	// block and not interleaved with other go-routines
	var printMu sync.Mutex

	// run inference on all the slices
	for _, slice := range slices {

		// Pin the current slice in a new variable
		sl := slice

		wg.Add(1)

		// build one big string for this slice's output
		var sb strings.Builder

		sb.WriteString(fmt.Sprintf("\nProcessing Slice (%d %d %d %d) with box size (%d %d)\n",
			sl.X, sl.Y, sl.X2, sl.Y2, sl.X2-sl.X, sl.Y2-sl.Y),
		)

		// pool.Get() blocks if no runtimes are available in the pool
		rt := pool.Get()

		go func(sl preprocess.Slice, rt *rknnlite.Runtime) {

			// perform inference on image file
			outputs, err := rt.Inference([]gocv.Mat{*sl.Mat()})

			if err != nil {
				log.Fatal("Runtime inferencing failed with error: ", err)
			}

			detectObjs := yoloProcesser.DetectObjects(outputs, sl.Resizer())
			detectResults := detectObjs.GetDetectResults()

			// output detection boxes to stdout
			for _, detResult := range detectResults {
				sb.WriteString(fmt.Sprintf(
					"%s @ (%d %d %d %d) %f\n",
					classNames[detResult.Class], detResult.Box.Left,
					detResult.Box.Top, detResult.Box.Right,
					detResult.Box.Bottom, detResult.Probability),
				)
			}

			sahi.AddResult(sl, detectResults)
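			// (AddResult stores this slice's detections together with the
			// slice offsets so they can later be merged back into full image
			// coordinates)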

			// free outputs allocated in C memory after you have finished
			// post processing
			err = outputs.Free()

			if err != nil {
				log.Printf("Error freeing Outputs: %v\n", err)
			}

			sl.Free()

			// print slice object detection results
			printMu.Lock()
			fmt.Print(sb.String())
			printMu.Unlock()

			pool.Return(rt)
			wg.Done()
		}(sl, rt)
	}

	wg.Wait()

	// get the detection results from all slices combined into those which map
	// back onto the source image dimensions
	detectResults := sahi.GetDetectResults(postprocess.YOLOv5COCOParams().NMSThreshold, 0.7)

	fmt.Printf("\nCombined object detection results\n")

	for _, detResult := range detectResults {
		fmt.Printf("%s @ (%d %d %d %d) %f\n", classNames[detResult.Class],
			detResult.Box.Left, detResult.Box.Top, detResult.Box.Right,
			detResult.Box.Bottom, detResult.Probability)
	}

	render.DetectionBoxes(&img, detectResults, classNames,
		render.DefaultFont(), 2)

	log.Printf("SAHI Execution speed=%s, slices=%d, objects=%d\n",
		time.Since(start).String(),
		len(slices),
		len(detectResults),
	)

	// Save the result
	if ok := gocv.IMWrite(*saveFile, img); !ok {
		log.Fatal("Failed to save the image")
	}

	log.Printf("Saved object detection result to %s\n", *saveFile)

	// free results
	sahi.FreeResults()

	// close runtime and release resources
	pool.Close()

	log.Println("done")
}