// Mirror of https://github.com/swdee/go-rknnlite.git
// (synced 2025-11-03 01:33:51 +08:00; 530 lines, 14 KiB, Go)
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"github.com/swdee/go-rknnlite"
|
|
"github.com/swdee/go-rknnlite/postprocess"
|
|
"github.com/swdee/go-rknnlite/tracker"
|
|
"gocv.io/x/gocv"
|
|
"image"
|
|
"image/color"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Package-level settings shared by the demo: the simulated frame rate and the
// colours available for annotating frames.
var (
	// FPS is the frame rate, in frames per second, at which buffered video
	// frames are fed into the processing pipeline.
	FPS = 30
	// FPSinterval is the time between two consecutive frames at the FPS rate.
	FPSinterval = time.Duration(float64(time.Second) / float64(FPS))

	// Fully opaque RGBA colours; channels that are not listed are zero.
	clrPink   = color.RGBA{R: 255, B: 255, A: 255}
	clrRed    = color.RGBA{R: 255, A: 255}
	clrBlack  = color.RGBA{A: 255}
	clrWhite  = color.RGBA{R: 255, G: 255, B: 255, A: 255}
	clrYellow = color.RGBA{R: 255, G: 255, B: 50, A: 255}
)
|
|
|
|
// ImgScale records the ratio between the size of the source video frames and
// the size of the model input tensor, with one factor per axis.
type ImgScale struct {
	Width, Height float32
}
|
|
|
|
// Timing collects the timestamps taken while a single frame is processed so
// that the duration of each stage can be reported afterwards.
type Timing struct {
	// ProcessStart is taken when processing of the frame begins.
	ProcessStart time.Time
	// Object detection: start of detection, end of inference, and end of
	// post processing.
	DetObjStart, DetObjInferenceEnd, DetObjEnd time.Time
	// Tracker update start and end.
	TrackerStart, TrackerEnd time.Time
	// RenderingStart is taken when annotation begins, ProcessEnd once the
	// detected objects have been drawn.
	RenderingStart, ProcessEnd time.Time
}
|
|
|
|
// ResultFrame is a struct to wrap the gocv byte buffer and error result
|
|
type ResultFrame struct {
|
|
Buf *gocv.NativeByteBuffer
|
|
Err error
|
|
}
|
|
|
|
type YOLOProcessor interface {
|
|
DetectObjects(outputs *rknnlite.Outputs) []postprocess.DetectResult
|
|
}
|
|
|
|
type Processor struct {
|
|
process YOLOProcessor
|
|
}
|
|
|
|
func NewProcessor(process YOLOProcessor) *Processor {
|
|
return &Processor{process: process}
|
|
}
|
|
|
|
func (p *Processor) DetectObjects(outputs *rknnlite.Outputs) []postprocess.DetectResult {
|
|
return p.process.DetectObjects(outputs)
|
|
}
|
|
|
|
// Demo defines the struct for running the object tracking demo
|
|
type Demo struct {
|
|
// vidBuffer buffers the video frames into memory
|
|
vidBuffer []gocv.Mat
|
|
// pool of rknnlite runtimes to perform inference in parallel
|
|
pool *rknnlite.Pool
|
|
// process is a YOLO object detection processor
|
|
process *Processor
|
|
// labels are the COCO labels the YOLO model was trained on
|
|
labels []string
|
|
// inputAttrs are the model tensor input attributes
|
|
inputAttrs []rknnlite.TensorAttr
|
|
// scale holds the scale factor between video frames resolution and
|
|
// model tensor input size
|
|
scale ImgScale
|
|
// limitObjs restricts object detection results to be only those provided
|
|
limitObjs []string
|
|
}
|
|
|
|
// NewDemo returns and instance of Demo, a streaming HTTP server showing
|
|
// video with object detection
|
|
func NewDemo(vidFile, modelFile, labelFile string, poolSize int,
|
|
modelType string) (*Demo, error) {
|
|
|
|
d := &Demo{
|
|
limitObjs: make([]string, 0),
|
|
}
|
|
|
|
err := d.bufferVideo(vidFile)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Error buffering video: %w", err)
|
|
}
|
|
|
|
// create new pool
|
|
d.pool, err = rknnlite.NewPool(poolSize, modelFile)
|
|
|
|
if err != nil {
|
|
log.Fatalf("Error creating RKNN pool: %v\n", err)
|
|
}
|
|
|
|
// set runtime to leave output tensors as int8
|
|
d.pool.SetWantFloat(false)
|
|
|
|
// create YOLOv5 post processor
|
|
switch modelType {
|
|
case "v8":
|
|
d.process = NewProcessor(postprocess.NewYOLOv8(postprocess.YOLOv8COCOParams()))
|
|
case "v5":
|
|
d.process = NewProcessor(postprocess.NewYOLOv5(postprocess.YOLOv5COCOParams()))
|
|
default:
|
|
log.Fatal("Unknown model type, use 'v5' or 'v8'")
|
|
}
|
|
|
|
// load in Model class names
|
|
d.labels, err = rknnlite.LoadLabels(labelFile)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Error loading model labels: %w", err)
|
|
}
|
|
|
|
// query Input tensors to get model size
|
|
rt := d.pool.Get()
|
|
d.inputAttrs, err = rt.QueryInputTensors()
|
|
d.pool.Return(rt)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Error querying Input Tensors: %w", err)
|
|
}
|
|
|
|
log.Printf("Model Input Tensor Dimensions %dx%d", int(d.inputAttrs[0].Dims[1]), int(d.inputAttrs[0].Dims[2]))
|
|
|
|
// get image frame size and calculate scale factor
|
|
d.calcScaleFactor()
|
|
|
|
return d, nil
|
|
}
|
|
|
|
// LimitObjects limits the object detection kind to the labels provided, eg:
|
|
// limit to just "person". Provide a comma delimited list of labels to
|
|
// restrict to.
|
|
func (d *Demo) LimitObjects(lim string) {
|
|
|
|
words := strings.Split(lim, ",")
|
|
|
|
for _, word := range words {
|
|
trimmed := strings.TrimSpace(word)
|
|
|
|
// check if word is an actual label in our labels file
|
|
if containsStr(d.labels, trimmed) {
|
|
d.limitObjs = append(d.limitObjs, trimmed)
|
|
}
|
|
}
|
|
|
|
log.Printf("Limiting object detection class to: %s\n", strings.Join(d.limitObjs, ", "))
|
|
}
|
|
|
|
// containsStr reports whether item is equal to one of the elements of slice.
// A nil or empty slice contains nothing.
func containsStr(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] != item {
			continue
		}

		return true
	}

	return false
}
|
|
|
|
// calcScaleFactor calculates the scale factor between the video image frame
|
|
// size and that of the Model input tensor size
|
|
func (d *Demo) calcScaleFactor() {
|
|
|
|
// get size of video image frame
|
|
width := d.vidBuffer[0].Cols()
|
|
height := d.vidBuffer[0].Rows()
|
|
|
|
d.scale = ImgScale{
|
|
Width: float32(width) / float32(d.inputAttrs[0].Dims[1]),
|
|
Height: float32(height) / float32(d.inputAttrs[0].Dims[2]),
|
|
}
|
|
|
|
log.Printf("Scale factor: Width=%.3f, Height=%.3f\n", d.scale.Width, d.scale.Height)
|
|
}
|
|
|
|
// bufferVideo reads in the video frames and saves them to a buffer
|
|
func (d *Demo) bufferVideo(vidFile string) error {
|
|
|
|
// open handle to read frames of video file
|
|
video, err := gocv.VideoCaptureFile(vidFile)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer video.Close()
|
|
|
|
d.vidBuffer = make([]gocv.Mat, 0)
|
|
|
|
for {
|
|
img := gocv.NewMat()
|
|
|
|
// read the next frame from the video
|
|
if ok := video.Read(&img); !ok {
|
|
// reached last video frame
|
|
break
|
|
}
|
|
|
|
// Check if the frame is empty
|
|
if img.Empty() {
|
|
continue
|
|
}
|
|
|
|
// push frame onto buffer
|
|
d.vidBuffer = append(d.vidBuffer, img)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stream is the HTTP handler function used to stream video frames to browser
|
|
func (d *Demo) Stream(w http.ResponseWriter, r *http.Request) {
|
|
|
|
log.Printf("New client connection established\n")
|
|
|
|
w.Header().Set("Content-Type", "multipart/x-mixed-replace; boundary=frame")
|
|
|
|
// pointer to position in video buffer
|
|
frameNum := -1
|
|
|
|
// create a bytetracker for tracking detected objects
|
|
// you must create a new instance of byteTrack per stream as it keeps a
|
|
// record of past object detections for tracking
|
|
byteTrack := tracker.NewBYTETracker(FPS, FPS*10, 0.5, 0.6, 0.8)
|
|
|
|
// create a trails history
|
|
trail := tracker.NewTrail(90, d.scale.Width, d.scale.Height)
|
|
|
|
// create Mat for annotated image
|
|
resImg := gocv.NewMat()
|
|
defer resImg.Close()
|
|
|
|
// used for calculating FPS
|
|
frameCount := 0
|
|
startTime := time.Now()
|
|
fps := float64(0)
|
|
|
|
ticker := time.NewTicker(FPSinterval)
|
|
defer ticker.Stop()
|
|
|
|
// chan to receive processed frames
|
|
recvFrame := make(chan ResultFrame, 30)
|
|
|
|
loop:
|
|
for {
|
|
select {
|
|
case <-r.Context().Done():
|
|
log.Printf("Client disconnected\n")
|
|
break loop
|
|
|
|
// simulate reading 30FPS web camera
|
|
case <-ticker.C:
|
|
|
|
// increment pointer to next image in the video buffer
|
|
frameNum++
|
|
if frameNum > len(d.vidBuffer)-1 {
|
|
// last frame reached so loop back to start of video
|
|
frameNum = 0
|
|
// clear tracker data
|
|
byteTrack.Reset()
|
|
// clear trail data
|
|
trail.Reset()
|
|
}
|
|
|
|
go d.ProcessFrame(d.vidBuffer[frameNum], recvFrame, fps, frameNum,
|
|
byteTrack, trail)
|
|
|
|
case buf := <-recvFrame:
|
|
|
|
if buf.Err != nil {
|
|
log.Printf("Error occured during ProcessFrame: %v", buf.Err)
|
|
|
|
} else {
|
|
// Write the image to the response writer
|
|
w.Write([]byte("--frame\r\n"))
|
|
w.Write([]byte("Content-Type: image/jpeg\r\n\r\n"))
|
|
w.Write(buf.Buf.GetBytes())
|
|
w.Write([]byte("\r\n"))
|
|
|
|
// Flush the buffer
|
|
flusher, ok := w.(http.Flusher)
|
|
if ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
buf.Buf.Close()
|
|
|
|
// calculate FPS
|
|
frameCount++
|
|
elapsed := time.Since(startTime).Seconds()
|
|
|
|
if elapsed >= 1.0 {
|
|
fps = float64(frameCount) / elapsed
|
|
frameCount = 0
|
|
startTime = time.Now()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ProcessFrame takes an image from the video and runs inference/object
|
|
// detection on it, annotates the image and returns the result encoded
|
|
// as a JPG file
|
|
func (d *Demo) ProcessFrame(img gocv.Mat, retChan chan<- ResultFrame,
|
|
fps float64, frameNum int, byteTrack *tracker.BYTETracker, trail *tracker.Trail) {
|
|
|
|
timing := &Timing{
|
|
ProcessStart: time.Now(),
|
|
}
|
|
|
|
resImg := gocv.NewMat()
|
|
defer resImg.Close()
|
|
|
|
// run object detection on frame
|
|
detObjs, err := d.DetectObjects(img, frameNum, timing)
|
|
|
|
if err != nil {
|
|
log.Printf("Error detecting objects: %v", err)
|
|
}
|
|
|
|
// track detected objects
|
|
timing.TrackerStart = time.Now()
|
|
trackObjs, err := byteTrack.Update(tracker.DetectionsToObjects(detObjs))
|
|
timing.TrackerEnd = time.Now()
|
|
|
|
// add tracked objects to history trail
|
|
for _, trackObj := range trackObjs {
|
|
trail.Add(trackObj)
|
|
}
|
|
|
|
// copy the source image and annotate the copy
|
|
img.CopyTo(&resImg)
|
|
d.AnnotateImg(resImg, trackObjs, trail, fps, frameNum, timing)
|
|
|
|
// Encode the image to JPEG format
|
|
buf, err := gocv.IMEncode(".jpg", resImg)
|
|
|
|
res := ResultFrame{
|
|
Buf: buf,
|
|
Err: err,
|
|
}
|
|
|
|
retChan <- res
|
|
}
|
|
|
|
func (d *Demo) AnnotateImg(img gocv.Mat, trackResults []*tracker.STrack,
|
|
trail *tracker.Trail, fps float64, frameNum int, timing *Timing) {
|
|
|
|
objCnt := 0
|
|
timing.RenderingStart = time.Now()
|
|
|
|
for _, tResult := range trackResults {
|
|
|
|
// exclude objects detected that are not a given class/label
|
|
if len(d.limitObjs) > 0 {
|
|
if !containsStr(d.limitObjs, d.labels[tResult.GetLabel()]) {
|
|
continue
|
|
}
|
|
}
|
|
|
|
objCnt++
|
|
|
|
text := fmt.Sprintf("%s %d",
|
|
d.labels[tResult.GetLabel()], tResult.GetTrackID())
|
|
|
|
// calculate the coordinates in the original image
|
|
originalLeft := int(tResult.GetRect().TLX() * d.scale.Width)
|
|
originalTop := int(tResult.GetRect().TLY() * d.scale.Height)
|
|
originalRight := int(tResult.GetRect().BRX() * d.scale.Width)
|
|
originalBottom := int(tResult.GetRect().BRY() * d.scale.Height)
|
|
|
|
// Draw rectangle around detected object
|
|
rect := image.Rect(originalLeft, originalTop, originalRight, originalBottom)
|
|
gocv.Rectangle(&img, rect, clrPink, 1)
|
|
|
|
// draw trail line showing tracking history
|
|
points := trail.GetPoints(tResult.GetTrackID())
|
|
|
|
if len(points) > 2 {
|
|
// draw trail
|
|
for i := 1; i < len(points); i++ {
|
|
// draw line segment of trail
|
|
gocv.Line(&img,
|
|
image.Pt(points[i-1].X, points[i-1].Y),
|
|
image.Pt(points[i].X, points[i].Y),
|
|
clrYellow, 1,
|
|
)
|
|
|
|
if i == len(points)-1 {
|
|
// draw center point circle on current rect/box
|
|
gocv.Circle(&img, image.Pt(points[i].X, points[i].Y), 3, clrPink, -1)
|
|
}
|
|
}
|
|
}
|
|
|
|
// create box for placing text on
|
|
textSize := gocv.GetTextSize(text, gocv.FontHersheySimplex, 0.4, 1)
|
|
bRect := image.Rect(originalLeft, originalTop-textSize.Y-8, originalLeft+textSize.X+4, originalTop)
|
|
gocv.Rectangle(&img, bRect, clrPink, -1)
|
|
|
|
// put text with detection class/label and tracking id
|
|
gocv.PutTextWithParams(&img, text, image.Pt(originalLeft+4, originalTop-5),
|
|
gocv.FontHersheySimplex, 0.4, clrBlack,
|
|
1, gocv.LineAA, false)
|
|
}
|
|
|
|
timing.ProcessEnd = time.Now()
|
|
|
|
// calculate processing lag
|
|
lag := time.Since(timing.ProcessStart).Milliseconds() - int64(FPS)
|
|
|
|
// blank out background video
|
|
rect := image.Rect(0, 0, img.Cols(), 36)
|
|
gocv.Rectangle(&img, rect, clrBlack, -1) // -1 fills the rectangle
|
|
|
|
// add FPS, object count, and frame number to top of image
|
|
gocv.PutTextWithParams(&img, fmt.Sprintf("Frame: %d, FPS: %.2f, Lag: %dms, Objects: %d", frameNum, fps, lag, objCnt),
|
|
image.Pt(4, 14), gocv.FontHersheySimplex, 0.5, clrPink, 1,
|
|
gocv.LineAA, false)
|
|
|
|
// add inference stats to top of image
|
|
gocv.PutTextWithParams(&img, fmt.Sprintf("Inference: %.2fms, Post Processing: %.2fms, Tracking: %.2fms, Rendering: %.2fms, Total Time: %.2fms",
|
|
float32(timing.DetObjInferenceEnd.Sub(timing.DetObjStart))/float32(time.Millisecond),
|
|
float32(timing.DetObjEnd.Sub(timing.DetObjInferenceEnd))/float32(time.Millisecond),
|
|
float32(timing.TrackerEnd.Sub(timing.TrackerStart))/float32(time.Millisecond),
|
|
float32(timing.ProcessEnd.Sub(timing.RenderingStart))/float32(time.Millisecond),
|
|
float32(timing.ProcessEnd.Sub(timing.ProcessStart))/float32(time.Millisecond),
|
|
),
|
|
image.Pt(4, 30), gocv.FontHersheySimplex, 0.5, clrPink, 1,
|
|
gocv.LineAA, false)
|
|
}
|
|
|
|
// DetectObjects takes a raw video frame and runs YOLO inference on it to detect
|
|
// objects
|
|
func (d *Demo) DetectObjects(img gocv.Mat, frameNum int, timing *Timing) ([]postprocess.DetectResult, error) {
|
|
|
|
timing.DetObjStart = time.Now()
|
|
|
|
// convert colorspace and resize image
|
|
rgbImg := gocv.NewMat()
|
|
defer rgbImg.Close()
|
|
gocv.CvtColor(img, &rgbImg, gocv.ColorBGRToRGB)
|
|
|
|
cropImg := rgbImg.Clone()
|
|
defer cropImg.Close()
|
|
scaleSize := image.Pt(int(d.inputAttrs[0].Dims[1]), int(d.inputAttrs[0].Dims[2]))
|
|
gocv.Resize(rgbImg, &cropImg, scaleSize, 0, 0, gocv.InterpolationArea)
|
|
|
|
// perform inference on image file
|
|
rt := d.pool.Get()
|
|
outputs, err := rt.Inference([]gocv.Mat{cropImg})
|
|
d.pool.Return(rt)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Runtime inferencing failed with error: %w", err)
|
|
}
|
|
|
|
timing.DetObjInferenceEnd = time.Now()
|
|
|
|
detectResults := d.process.DetectObjects(outputs)
|
|
|
|
timing.DetObjEnd = time.Now()
|
|
|
|
// free outputs allocated in C memory after you have finished post processing
|
|
err = outputs.Free()
|
|
|
|
return detectResults, nil
|
|
}
|
|
|
|
func main() {
|
|
// disable logging timestamps
|
|
log.SetFlags(0)
|
|
|
|
// read in cli flags
|
|
modelFile := flag.String("m", "../data/yolov5s-640-640-rk3588.rknn", "RKNN compiled YOLO model file")
|
|
modelType := flag.String("t", "v5", "Version of YOLO model [v5|v8]")
|
|
vidFile := flag.String("v", "../data/palace.mp4", "Video file to run object detection and tracking on")
|
|
labelFile := flag.String("l", "../data/coco_80_labels_list.txt", "Text file containing model labels")
|
|
httpAddr := flag.String("a", "localhost:8080", "HTTP Address to run server on, format address:port")
|
|
poolSize := flag.Int("s", 3, "Size of RKNN runtime pool, choose 1, 2, 3, or multiples of 3")
|
|
limitLabels := flag.String("x", "", "Comma delimited list of labels (COCO) to restrict object tracking to")
|
|
|
|
flag.Parse()
|
|
|
|
err := rknnlite.SetCPUAffinity(rknnlite.RK3588FastCores)
|
|
|
|
if err != nil {
|
|
log.Printf("Failed to set CPU Affinity: %w", err)
|
|
}
|
|
|
|
demo, err := NewDemo(*vidFile, *modelFile, *labelFile, *poolSize, *modelType)
|
|
|
|
if err != nil {
|
|
log.Fatalf("Error creating demo: %v", err)
|
|
}
|
|
|
|
if *limitLabels != "" {
|
|
demo.LimitObjects(*limitLabels)
|
|
}
|
|
|
|
http.HandleFunc("/stream", demo.Stream)
|
|
|
|
// start http server
|
|
log.Println(fmt.Sprintf("Open browser and view video at http://%s/stream",
|
|
*httpAddr))
|
|
log.Fatal(http.ListenAndServe(*httpAddr, nil))
|
|
}
|