// go-rknnlite/preprocess/sahi.go

package preprocess

import (
	"errors"
	"github.com/swdee/go-rknnlite/postprocess/result"
	"gocv.io/x/gocv"
	"image"
	"image/color"
	"math"
	"sort"
)
// SAHI defines the struct used for Slicing Aided Hyper Inference
type SAHI struct {
	// sliceWidth is the width of each image slice
	sliceWidth int
	// sliceHeight is the height of each image slice
	sliceHeight int
	// overlapWidth is a ratio from 0.0 to 1.0 representing the number of pixels
	// to overlap each slice. A value of 0.2 represents 20% of sliceWidth's pixels
	overlapWidth float32
	// overlapHeight is a ratio from 0.0 to 1.0 representing the number of pixels
	// to overlap each slice. A value of 0.2 represents 20% of sliceHeight's pixels
	overlapHeight float32
	// results stores each slice's object detection results
	results []sahiResult
	// idGen provides an incrementing counter used to assign an ID to each
	// detection result
	idGen *result.IDGenerator
}

// sahiResult defines a struct to store a slice and its detection results
type sahiResult struct {
	slice Slice
	det   []result.DetectResult
}

// Slice defines the struct used to store the coordinates for a slice of the
// source image
type Slice struct {
	// X is the coordinate of the slice's left edge
	X int
	// Y is the coordinate of the slice's top edge
	Y int
	// X2 is the coordinate of the slice's right edge
	X2 int
	// Y2 is the coordinate of the slice's bottom edge
	Y2 int
	// slice is the sliced image Mat
	slice gocv.Mat
	// resizer is an instance of the image resizer
	resizer *Resizer
	// destMat is the destination Mat after crop and resize of the slice
	destMat gocv.Mat
}

// NewSAHI returns a SAHI (Slicing Aided Hyper Inference) instance for slicing
// a source image into a series of tiles for inference. The sliceWidth and
// sliceHeight should be the same size as the model's input tensor dimensions.
func NewSAHI(sliceWidth, sliceHeight int, overlapWidth, overlapHeight float32) *SAHI {
	s := &SAHI{
		sliceWidth:    sliceWidth,
		sliceHeight:   sliceHeight,
		overlapWidth:  overlapWidth,
		overlapHeight: overlapHeight,
		results:       make([]sahiResult, 0),
		idGen:         result.NewIDGenerator(),
	}
	return s
}
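
// exampleSAHIWorkflow is an illustrative sketch, not part of the original
// file, showing how the SAHI helper above is intended to be driven end to
// end. The runInference callback is a hypothetical stand-in for the caller's
// model inference and post-processing step, and is assumed to return
// detections in the tile's own pixel coordinates (for example by handing the
// slice's Resizer to the post-processor so the letterbox resize is undone).
func exampleSAHIWorkflow(src gocv.Mat,
	runInference func(tile gocv.Mat, r *Resizer) []result.DetectResult,
) []result.DetectResult {
	sahi := NewSAHI(640, 640, 0.2, 0.2)
	// slice the source image into overlapping tiles
	slices := sahi.Slice(src)
	for i := range slices {
		// run the model on the letterboxed tile and record its detections
		// against the slice they came from
		det := runInference(*slices[i].Mat(), slices[i].Resizer())
		sahi.AddResult(slices[i], det)
	}
	// merge the per-tile detections into global results for the source image
	res := sahi.GetDetectResults(0.45, 0.7)
	// release per-image state and the slice Mats
	sahi.FreeResults()
	for i := range slices {
		_ = slices[i].Free()
	}
	return res
}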

// computePositions returns the start coordinates (0-based) of each tile
// along one axis, and the computed tile length. It guarantees:
//
// - n is the smallest number of tiles such that (n-1)*step + tileLen >= srcLen
// - step = (srcLen-tileLen)/(n-1) is <= sliceLen
// - thus overlap = tileLen - step >= sliceLen*overlapRatio
//
// Any leftover pixels needed to cover the image get spread evenly via
// rounding. The function returns a slice of positions and tileLen.
func (s *SAHI) computePositions(srcLen, sliceLen int, overlapRatio float32) ([]int, int) {
	// minimum pixel overlap
	minOv := int(math.Ceil(float64(sliceLen) * float64(overlapRatio)))
	// tile length = sliceLen + that minimum overlap
	tileLen := sliceLen + minOv
	// how many tiles would you need if you stepped by sliceLen each time?
	// this ensures step <= sliceLen and so overlap >= minOv
	n := int(math.Ceil(float64(srcLen-tileLen)/float64(sliceLen))) + 1
	if n < 1 {
		n = 1
	}
	// actual step (evenly spread)
	denom := n - 1
	var step float64
	if denom > 0 {
		step = float64(srcLen-tileLen) / float64(denom)
	} else {
		step = 0
	}
	// build positions, rounding to distribute leftovers
	positions := make([]int, n)
	for i := 0; i < n; i++ {
		p := int(math.Round(step * float64(i)))
		// clamp to [0, srcLen-tileLen]
		if p < 0 {
			p = 0
		} else if p > srcLen-tileLen {
			p = srcLen - tileLen
		}
		positions[i] = p
	}
	return positions, tileLen
}
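
// As a worked example of computePositions (illustrative numbers only),
// take a 1920 pixel wide source, 640 pixel slices and a 0.2 overlap ratio:
//
//	minOv   = ceil(640 * 0.2)          = 128
//	tileLen = 640 + 128                = 768
//	n       = ceil((1920-768)/640) + 1 = 3
//	step    = (1920-768)/(3-1)         = 576
//
// giving positions [0, 576, 1152], so the last 768 pixel tile ends exactly
// at 1920 and neighbouring tiles overlap by 768-576 = 192 pixels, which
// satisfies the 128 pixel minimum overlap.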

// Slice slices the given input image into a series of tiles
func (s *SAHI) Slice(src gocv.Mat) []Slice {
	// get dimensions of source image
	srcH, srcW := src.Rows(), src.Cols()
	// compute X starts and tileW
	xs, tileW := s.computePositions(srcW, s.sliceWidth, s.overlapWidth)
	// compute Y starts and tileH
	ys, tileH := s.computePositions(srcH, s.sliceHeight, s.overlapHeight)
	slices := make([]Slice, 0, len(xs)*len(ys))
	for _, y := range ys {
		for _, x := range xs {
			rect := image.Rect(x, y, x+tileW, y+tileH)
			// letterbox this tile down to sliceWidth x sliceHeight
			resizer := NewResizer(tileW, tileH, s.sliceWidth, s.sliceHeight)
			slices = append(slices, Slice{
				X:       x,
				Y:       y,
				X2:      x + tileW,
				Y2:      y + tileH,
				slice:   src.Region(rect),
				resizer: resizer,
				destMat: gocv.NewMat(),
			})
		}
	}
	return slices
}
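
// Continuing the worked example above (illustrative numbers only): slicing a
// 1920x1080 source with 640x640 slices and 0.2 overlap yields X starts
// [0, 576, 1152] and Y starts [0, 312], i.e. a 3x2 grid of six 768x768 tiles,
// each of which is letterboxed down to 640x640 before inference.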

// AddResult adds the slice and its detection result
func (s *SAHI) AddResult(slice Slice, res []result.DetectResult) {
	s.results = append(s.results, sahiResult{
		slice: slice,
		det:   res,
	})
}

// GetDetectResults returns the global detection results for the source image,
// made up from the combination of all slices' detection results.
// - iouThreshold is the intersection-over-union value (NMSThreshold) used in
//   the YOLO processor parameters.
// - smallBoxOverlapThresh is a value from 0 to 1 representing the fraction of
//   a small box's area that must be overlapped for it to be discarded, which
//   occurs when an object sits on the slice/tile overlap boundary.
func (s *SAHI) GetDetectResults(iouThreshold, smallBoxOverlapThresh float32) []result.DetectResult {
	// collate objects into a result for returning
	group := make([]result.DetectResult, 0)
	// for each slice
	for _, sr := range s.results {
		// for each object/detection in the slice
		for _, dr := range sr.det {
			// remap the coordinates of the detection result to the global
			// coordinates of the source image
			gresult := result.DetectResult{
				Box: result.BoxRect{
					Left:   sr.slice.X + dr.Box.Left,
					Top:    sr.slice.Y + dr.Box.Top,
					Right:  sr.slice.X + dr.Box.Right,
					Bottom: sr.slice.Y + dr.Box.Bottom,
				},
				Probability: dr.Probability,
				Class:       dr.Class,
				ID:          s.idGen.GetNext(),
			}
			group = append(group, gresult)
		}
	}
	// sort by descending probability
	sort.Slice(group, func(i, j int) bool {
		return group[i].Probability > group[j].Probability
	})
	// run class-agnostic NMS
	return s.nmsCluster(group, iouThreshold, smallBoxOverlapThresh)
}
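
// For example (illustrative numbers only): a detection with box
// (Left 10, Top 20, Right 100, Bottom 120) found in the tile anchored at
// X=576, Y=0 is remapped to the global box (586, 20, 676, 120) before the
// merged results are de-duplicated by nmsCluster below.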

// FreeResults releases stored detection results. This should be called after
// processing of each image to clear out results added by AddResult().
func (s *SAHI) FreeResults() {
	s.results = s.results[:0]
}

// iou computes the Intersection-over-Union of two boxes.
func (s *SAHI) iou(a, b result.BoxRect) float32 {
	x1 := max(a.Left, b.Left)
	y1 := max(a.Top, b.Top)
	x2 := min(a.Right, b.Right)
	y2 := min(a.Bottom, b.Bottom)
	w := float32(max(0, x2-x1))
	h := float32(max(0, y2-y1))
	inter := w * h
	areaA := float32((a.Right - a.Left) * (a.Bottom - a.Top))
	areaB := float32((b.Right - b.Left) * (b.Bottom - b.Top))
	return inter / (areaA + areaB - inter)
}
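
// As a worked example of iou (illustrative numbers only): for boxes
// (0, 0, 100, 100) and (50, 50, 150, 150) the intersection is 50*50 = 2500
// and the union is 10000 + 10000 - 2500 = 17500, giving an IoU of
// 2500/17500, roughly 0.143, which is below a typical 0.45 threshold.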

// max returns the maximum of two numbers
func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

// min returns the minimum of two numbers
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// nmsCluster picks one box per overlapping cluster (class-agnostic),
// choosing the largest area (tie-break on confidence), and uses both IoU
// and small-box overlap tests to form clusters.
// - detections must be sorted descending by Probability before calling.
// - iouThreshold is your overlap cutoff (e.g. 0.45)
// - smallBoxOverlapThresh is the fraction of the small box's area that must
//   be overlapped to consider it a duplicate (e.g. 0.7)
func (s *SAHI) nmsCluster(dets []result.DetectResult,
	iouThreshold, smallBoxOverlapThresh float32) []result.DetectResult {
	n := len(dets)
	suppressed := make([]bool, n)
	keep := make([]result.DetectResult, 0, n)
	for i, base := range dets {
		if suppressed[i] {
			continue
		}
		// start a new cluster with "base"
		cluster := []result.DetectResult{base}
		suppressed[i] = true
		for j := i + 1; j < n; j++ {
			if suppressed[j] {
				continue
			}
			other := dets[j]
			// decide if "other" belongs in this cluster
			inCluster := false
			// IoU test
			if s.iou(base.Box, other.Box) > iouThreshold {
				inCluster = true
			} else {
				// small-box self-overlap test
				inter := s.intersectionArea(base.Box, other.Box)
				areaOther := s.boxArea(other.Box)
				if areaOther > 0 && float32(inter)/float32(areaOther) > smallBoxOverlapThresh {
					inCluster = true
				}
			}
			if !inCluster {
				continue
			}
			// assign to this cluster
			suppressed[j] = true
			cluster = append(cluster, other)
		}
		// pick the single largest-area box (tie-break on probability)
		best := cluster[0]
		bestArea := s.boxArea(best.Box)
		for _, c := range cluster[1:] {
			a := s.boxArea(c.Box)
			if a > bestArea || (a == bestArea && c.Probability > best.Probability) {
				best = c
				bestArea = a
			}
		}
		keep = append(keep, best)
	}
	return keep
}
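
// The small-box test above is what removes the partial duplicates created at
// tile boundaries. For instance (illustrative numbers only): a lower-confidence
// 40x40 fragment of an object from one tile that lies entirely inside a
// higher-confidence 40x120 detection of the same object from the neighbouring
// tile has an IoU of only 1600/4800, roughly 0.33, below a 0.45 iouThreshold,
// but 100% of the fragment's area is overlapped, exceeding a 0.7
// smallBoxOverlapThresh, so it joins the cluster and the larger box is kept.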

// intersectionArea returns the raw pixel-area of overlap between two boxes.
func (s *SAHI) intersectionArea(a, b result.BoxRect) int {
	x1 := max(a.Left, b.Left)
	y1 := max(a.Top, b.Top)
	x2 := min(a.Right, b.Right)
	y2 := min(a.Bottom, b.Bottom)
	if x2 <= x1 || y2 <= y1 {
		return 0
	}
	return (x2 - x1) * (y2 - y1)
}

// boxArea returns the pixel-area of a single box.
func (s *SAHI) boxArea(a result.BoxRect) int {
	return max(0, a.Right-a.Left) * max(0, a.Bottom-a.Top)
}

// Mat returns the slice's Mat after cropping and resizing, to be used
// for running inference on
func (s *Slice) Mat() *gocv.Mat {
	s.resizer.LetterBoxResize(s.slice, &s.destMat, color.RGBA{R: 0, G: 0, B: 0, A: 255})
	return &s.destMat
}

// Resizer returns the slice's letter box resizer
func (s *Slice) Resizer() *Resizer {
	return s.resizer
}

// Free releases the slice from memory
func (s *Slice) Free() error {
	err := s.resizer.Close()
	err2 := s.slice.Close()
	err3 := s.destMat.Close()
	return errors.Join(err, err2, err3)
}