package preprocess

import (
	"errors"
	"github.com/swdee/go-rknnlite/postprocess/result"
	"gocv.io/x/gocv"
	"image"
	"image/color"
	"math"
	"sort"
)

// SAHI defines the struct used for Slicing Aided Hyper Inference
type SAHI struct {
	// sliceWidth is the width of each image slice
	sliceWidth int
	// sliceHeight is the height of each image slice
	sliceHeight int
	// overlapWidth is a ratio from 0.0 to 1.0 representing the number of pixels
	// to overlap each slice. A value of 0.2 represents 20% of sliceWidth's pixels
	overlapWidth float32
	// overlapHeight is a ratio from 0.0 to 1.0 representing the number of pixels
	// to overlap each slice. A value of 0.2 represents 20% of sliceHeight's pixels
	overlapHeight float32
	// results stores each slice's object detection results
	results []sahiResult
	// idGen is a counter that increments and provides the next unique ID
	// for each detection result
	idGen *result.IDGenerator
}

// sahiResult defines a struct to store a slice and its detection results
type sahiResult struct {
	slice Slice
	det   []result.DetectResult
}

// Slice defines the struct used to store the coordinates for a slice of the
// source image
type Slice struct {
	// X is the coordinate of the slice's left edge
	X int
	// Y is the coordinate of the slice's top edge
	Y int
	// X2 is the coordinate of the slice's right edge
	X2 int
	// Y2 is the coordinate of the slice's bottom edge
	Y2 int
	// slice is the sliced image Mat
	slice gocv.Mat
	// resizer is an instance of the image resizer
	resizer *Resizer
	// destMat is the destination Mat after crop and resize of the slice
	destMat gocv.Mat
}

// NewSAHI returns a SAHI (Slicing Aided Hyper Inference) instance for slicing
// a source image into a series of tiles for inference. The sliceWidth and
// sliceHeight should be the same size as the Model's input tensor dimensions.
func NewSAHI(sliceWidth, sliceHeight int, overlapWidth, overlapHeight float32) *SAHI {
	s := &SAHI{
		sliceWidth:    sliceWidth,
		sliceHeight:   sliceHeight,
		overlapWidth:  overlapWidth,
		overlapHeight: overlapHeight,
		results:       make([]sahiResult, 0),
		idGen:         result.NewIDGenerator(),
	}

	return s
}
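
// Usage sketch (illustrative only): a minimal end-to-end SAHI workflow,
// assuming a hypothetical runInference helper that runs the model on a tile
// and returns []result.DetectResult in that tile's coordinate space:
//
//	sahi := preprocess.NewSAHI(640, 640, 0.2, 0.2)
//	slices := sahi.Slice(srcMat)
//
//	for _, sl := range slices {
//		dets := runInference(sl.Mat()) // hypothetical inference call
//		sahi.AddResult(sl, dets)
//	}
//
//	finalDets := sahi.GetDetectResults(0.45, 0.7)
//	sahi.FreeResults()
//
//	for _, sl := range slices {
//		sl.Free()
//	}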

// computePositions returns the start coordinates (0-based) of each tile
// along one axis, and the computed tile length. It guarantees:
//
//   - the smallest n is chosen so the image can be covered with a step of
//     at most sliceLen, i.e. (n-1)*step + tileLen >= srcLen
//   - step = (srcLen-tileLen)/(n-1) is <= sliceLen
//   - thus overlap = tileLen - step >= sliceLen*overlapRatio
//
// Any leftover pixels needed to cover the image are spread evenly via
// rounding. The function returns a slice of positions and tileLen.
func (s *SAHI) computePositions(srcLen, sliceLen int, overlapRatio float32) ([]int, int) {

	// minimum pixel overlap
	minOv := int(math.Ceil(float64(sliceLen) * float64(overlapRatio)))

	// tile length = sliceLen + that minimum overlap
	tileLen := sliceLen + minOv

	// number of tiles needed if stepping by sliceLen each time;
	// this ensures step <= sliceLen and so overlap >= minOv
	n := int(math.Ceil(float64(srcLen-tileLen)/float64(sliceLen))) + 1
	if n < 1 {
		n = 1
	}

	// actual step (evenly spread)
	denom := n - 1
	var step float64

	if denom > 0 {
		step = float64(srcLen-tileLen) / float64(denom)
	} else {
		step = 0
	}

	// build positions, rounding to distribute leftovers
	positions := make([]int, n)

	for i := 0; i < n; i++ {
		p := int(math.Round(step * float64(i)))

		// clamp to [0, srcLen-tileLen]
		if p < 0 {
			p = 0
		} else if p > srcLen-tileLen {
			p = srcLen - tileLen
		}

		positions[i] = p
	}

	return positions, tileLen
}
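
// Worked example (illustrative): computePositions(1920, 640, 0.2) yields
// minOv = ceil(640*0.2) = 128, tileLen = 640+128 = 768 and
// n = ceil((1920-768)/640)+1 = 3, so step = (1920-768)/2 = 576 and
// positions = [0, 576, 1152]. Each 768px tile then overlaps its neighbour
// by tileLen-step = 192px, which satisfies the 128px minimum overlap.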

// Slice slices the given input image into a series of tiles
func (s *SAHI) Slice(src gocv.Mat) []Slice {
	// get dimensions of source image
	srcH, srcW := src.Rows(), src.Cols()

	// compute X starts and tileW
	xs, tileW := s.computePositions(srcW, s.sliceWidth, s.overlapWidth)
	// compute Y starts and tileH
	ys, tileH := s.computePositions(srcH, s.sliceHeight, s.overlapHeight)

	slices := make([]Slice, 0, len(xs)*len(ys))

	for _, y := range ys {
		for _, x := range xs {
			rect := image.Rect(x, y, x+tileW, y+tileH)

			// letterbox this tile down to sliceWidth x sliceHeight
			resizer := NewResizer(tileW, tileH, s.sliceWidth, s.sliceHeight)

			slices = append(slices, Slice{
				X:       x,
				Y:       y,
				X2:      x + tileW,
				Y2:      y + tileH,
				slice:   src.Region(rect),
				resizer: resizer,
				destMat: gocv.NewMat(),
			})
		}
	}

	return slices
}
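
// Worked example (illustrative): slicing a 1920x1080 source with a SAHI
// instance created by NewSAHI(640, 640, 0.2, 0.2) produces x starts
// [0, 576, 1152] and y starts [0, 312], i.e. a 3x2 grid of six 768x768
// tiles, each of which is letterboxed down to the 640x640 model input size
// by its Resizer when Mat() is called.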

// AddResult adds the slice and its detection result
func (s *SAHI) AddResult(slice Slice, res []result.DetectResult) {
	s.results = append(s.results, sahiResult{
		slice: slice,
		det:   res,
	})
}

// GetDetectResults returns the global detection results for the source image
// made up from the combination of all slices' detect results.
//   - iouThreshold is the intersection-over-union value (NMSThreshold) used
//     in the YOLO processor parameters.
//   - smallBoxOverlapThresh is a percentage represented as a value from
//     0 to 1 that small boxes need to overlap by to be discarded, which
//     occurs when an object sits on the slice/tile overlap boundary
func (s *SAHI) GetDetectResults(iouThreshold, smallBoxOverlapThresh float32) []result.DetectResult {

	// collate objects into a result for returning
	group := make([]result.DetectResult, 0)

	// for each slice
	for _, sr := range s.results {

		// for each object/detection in the slice
		for _, dr := range sr.det {
			// remap the coordinates of the detection result to the global
			// coordinates of the source image
			gresult := result.DetectResult{
				Box: result.BoxRect{
					Left:   sr.slice.X + dr.Box.Left,
					Top:    sr.slice.Y + dr.Box.Top,
					Right:  sr.slice.X + dr.Box.Right,
					Bottom: sr.slice.Y + dr.Box.Bottom,
				},
				Probability: dr.Probability,
				Class:       dr.Class,
				ID:          s.idGen.GetNext(),
			}

			group = append(group, gresult)
		}
	}

	// sort by descending probability
	sort.Slice(group, func(i, j int) bool {
		return group[i].Probability > group[j].Probability
	})

	// run class-agnostic NMS clustering
	return s.nmsCluster(group, iouThreshold, smallBoxOverlapThresh)
}
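
// Remap example (illustrative): a detection with a box of (Left 10, Top 20,
// Right 50, Bottom 60), reported inside a tile whose region starts at
// X=1152, Y=312, is shifted to the global box (1162, 332, 1202, 372) before
// the NMS clustering step. This assumes the detection box is already
// expressed in the tile's own coordinate space, as the remapping above does.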

// FreeResults releases stored detection results. This should be called after
// processing of each image to clear out results added by AddResult().
func (s *SAHI) FreeResults() {
	s.results = s.results[:0]
}

// iou computes the Intersection-over-Union of two boxes.
func (s *SAHI) iou(a, b result.BoxRect) float32 {
	x1 := max(a.Left, b.Left)
	y1 := max(a.Top, b.Top)
	x2 := min(a.Right, b.Right)
	y2 := min(a.Bottom, b.Bottom)

	w := float32(max(0, x2-x1))
	h := float32(max(0, y2-y1))
	inter := w * h

	areaA := float32((a.Right - a.Left) * (a.Bottom - a.Top))
	areaB := float32((b.Right - b.Left) * (b.Bottom - b.Top))

	return inter / (areaA + areaB - inter)
}
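
// Worked example (illustrative): for boxes (0,0,100,100) and (50,50,150,150)
// the intersection is 50*50 = 2500 and the union is 10000+10000-2500 = 17500,
// giving an IoU of 2500/17500, approximately 0.14.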

// max returns the maximum of two numbers
func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

// min returns the minimum of two numbers
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// nmsCluster picks one box per overlapping cluster (class-agnostic),
// choosing the largest area (tie-break on confidence), and uses both IoU
// and small-box overlap tests to form clusters.
//   - detections must be sorted descending by Probability before calling.
//   - iouThreshold is the overlap cutoff (e.g. 0.45)
//   - smallBoxOverlapThresh is the fraction of the small box's area that must
//     be overlapped to consider it a duplicate (e.g. 0.7)
func (s *SAHI) nmsCluster(dets []result.DetectResult,
	iouThreshold, smallBoxOverlapThresh float32) []result.DetectResult {

	n := len(dets)
	suppressed := make([]bool, n)
	keep := make([]result.DetectResult, 0, n)

	for i, base := range dets {
		if suppressed[i] {
			continue
		}

		// start a new cluster with "base"
		cluster := []result.DetectResult{base}
		suppressed[i] = true

		for j := i + 1; j < n; j++ {
			if suppressed[j] {
				continue
			}

			other := dets[j]

			// decide if "other" belongs in this cluster
			inCluster := false

			// IoU test
			if s.iou(base.Box, other.Box) > iouThreshold {
				inCluster = true
			} else {
				// small-box self-overlap test
				inter := s.intersectionArea(base.Box, other.Box)
				areaOther := s.boxArea(other.Box)
				if areaOther > 0 && float32(inter)/float32(areaOther) > smallBoxOverlapThresh {
					inCluster = true
				}
			}

			if !inCluster {
				continue
			}

			// assign to this cluster
			suppressed[j] = true
			cluster = append(cluster, other)
		}

		// pick the single largest-area box (tie-break on probability)
		best := cluster[0]
		bestArea := s.boxArea(best.Box)

		for _, c := range cluster[1:] {
			a := s.boxArea(c.Box)

			if a > bestArea || (a == bestArea && c.Probability > best.Probability) {
				best = c
				bestArea = a
			}
		}

		keep = append(keep, best)
	}

	return keep
}
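
// Worked example (illustrative): given detections A (0,0,100,100, p=0.9),
// B (10,10,120,120, p=0.8) and C (300,300,340,340, p=0.7) sorted by
// probability and an iouThreshold of 0.45, A and B form one cluster
// (IoU is roughly 0.58) and B is kept because its area (110*110) is larger
// than A's (100*100); C overlaps neither box so it is kept as its own
// cluster, giving the final result [B, C].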

// intersectionArea returns the raw pixel-area of overlap between two boxes.
func (s *SAHI) intersectionArea(a, b result.BoxRect) int {
	x1 := max(a.Left, b.Left)
	y1 := max(a.Top, b.Top)
	x2 := min(a.Right, b.Right)
	y2 := min(a.Bottom, b.Bottom)

	if x2 <= x1 || y2 <= y1 {
		return 0
	}

	return (x2 - x1) * (y2 - y1)
}

// boxArea returns the pixel-area of a single box.
func (s *SAHI) boxArea(a result.BoxRect) int {
	return max(0, a.Right-a.Left) * max(0, a.Bottom-a.Top)
}

// Mat returns the slice's Mat after cropping and resizing, to be used
// for running inference on
func (s *Slice) Mat() *gocv.Mat {
	s.resizer.LetterBoxResize(s.slice, &s.destMat, color.RGBA{R: 0, G: 0, B: 0, A: 255})
	return &s.destMat
}

// Resizer returns the slice's letterbox resizer
func (s *Slice) Resizer() *Resizer {
	return s.resizer
}

// Free releases the slice from memory
func (s *Slice) Free() error {
	err := s.resizer.Close()
	err2 := s.slice.Close()
	err3 := s.destMat.Close()

	return errors.Join(err, err2, err3)
}