added LPRNet example for license plate detection

This commit is contained in:
swdee
2024-05-11 14:15:19 +12:00
parent bece340dfe
commit bb5f0035ba
4 changed files with 361 additions and 0 deletions

193
example/lprnet/lprnet.go Normal file
View File

@@ -0,0 +1,193 @@
/*
Example code showing how to perform inferencing using a LPRnet model
*/
package main
import (
"flag"
"fmt"
"github.com/swdee/go-rknnlite"
"github.com/swdee/go-rknnlite/postprocess"
"gocv.io/x/gocv"
"image"
"log"
"time"
)
func main() {
// disable logging timestamps
log.SetFlags(0)
// read in cli flags
modelFile := flag.String("m", "../data/lprnet-rk3588.rknn", "RKNN compiled model file")
imgFile := flag.String("i", "../data/lplate.jpg", "Image file to run inference on")
flag.Parse()
// create rknn runtime instance
rt, err := rknnlite.NewRuntime(*modelFile, rknnlite.NPUCoreAuto)
if err != nil {
log.Fatal("Error initializing RKNN runtime: ", err)
}
// optional querying of model file tensors and SDK version. not necessary
// for production inference code
inputAttrs := optionalQueries(rt)
// create LPRNet post processor using parameters used during model training
lprnetProcesser := postprocess.NewLPRNet(postprocess.LPRNetParams{
PlatePositions: 18,
PlateChars: []string{
"京", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑",
"苏", "浙", "皖", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤",
"桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁",
"新",
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
"A", "B", "C", "D", "E", "F", "G", "H", "J", "K",
"L", "M", "N", "P", "Q", "R", "S", "T", "U", "V",
"W", "X", "Y", "Z", "I", "O", "-",
},
})
// load image
img := gocv.IMRead(*imgFile, gocv.IMReadColor)
if img.Empty() {
log.Fatal("Error reading image from: ", *imgFile)
}
// resize image to 94x24
cropImg := gocv.NewMat()
scaleSize := image.Pt(int(inputAttrs[0].Dims[2]), int(inputAttrs[0].Dims[1]))
gocv.Resize(img, &cropImg, scaleSize, 0, 0, gocv.InterpolationArea)
defer img.Close()
defer cropImg.Close()
start := time.Now()
// perform inference on image file
outputs, err := rt.Inference([]gocv.Mat{cropImg})
if err != nil {
log.Fatal("Runtime inferencing failed with error: ", err)
}
endInference := time.Now()
// read number plates from outputs
plates := lprnetProcesser.ReadPlates(outputs)
endDetect := time.Now()
log.Printf("Model first run speed: inference=%s, post processing=%s, total time=%s\n",
endInference.Sub(start).String(),
endDetect.Sub(endInference).String(),
endDetect.Sub(start).String(),
)
for _, plate := range plates {
log.Printf("License plate recognition result: %s\n", plate)
}
// free outputs allocated in C memory after you have finished post processing
err = outputs.Free()
if err != nil {
log.Fatal("Error freeing Outputs: ", err)
}
// optional code. run benchmark to get average time of 10 runs
runBenchmark(rt, lprnetProcesser, []gocv.Mat{cropImg})
// close runtime and release resources
err = rt.Close()
if err != nil {
log.Fatal("Error closing RKNN runtime: ", err)
}
log.Println("done")
}
func runBenchmark(rt *rknnlite.Runtime, lprnetProcesser *postprocess.LPRNet,
mats []gocv.Mat) {
count := 10
start := time.Now()
for i := 0; i < count; i++ {
// perform inference on image file
outputs, err := rt.Inference(mats)
if err != nil {
log.Fatal("Runtime inferencing failed with error: ", err)
}
// post process
_ = lprnetProcesser.ReadPlates(outputs)
err = outputs.Free()
if err != nil {
log.Fatal("Error freeing Outputs: ", err)
}
}
end := time.Now()
total := end.Sub(start)
avg := total / time.Duration(count)
log.Printf("Benchmark time=%s, count=%d, average total time=%s\n",
total.String(), count, avg.String(),
)
}
func optionalQueries(rt *rknnlite.Runtime) []rknnlite.TensorAttr {
// get SDK version
ver, err := rt.SDKVersion()
if err != nil {
log.Fatal("Error initializing RKNN runtime: ", err)
}
fmt.Printf("Driver Version: %s, API Version: %s\n", ver.DriverVersion, ver.APIVersion)
// get model input and output numbers
num, err := rt.QueryModelIONumber()
if err != nil {
log.Fatal("Error querying IO Numbers: ", err)
}
log.Printf("Model Input Number: %d, Ouput Number: %d\n", num.NumberInput, num.NumberOutput)
// query Input tensors
inputAttrs, err := rt.QueryInputTensors()
if err != nil {
log.Fatal("Error querying Input Tensors: ", err)
}
log.Println("Input tensors:")
for _, attr := range inputAttrs {
log.Printf(" %s\n", attr.String())
}
// query Output tensors
outputAttrs, err := rt.QueryOutputTensors()
if err != nil {
log.Fatal("Error querying Output Tensors: ", err)
}
log.Println("Output tensors:")
for _, attr := range outputAttrs {
log.Printf(" %s\n", attr.String())
}
return inputAttrs
}