Files
yolotriton/yolovx.go
2025-08-23 10:25:53 +00:00

114 lines
2.4 KiB
Go

package yolotriton
import (
	"fmt"
	"image"

	triton "github.com/dev6699/yolotriton/grpc-client"
)
// Yolo is a Model implementation that carries its Triton configuration
// together with the state shared between PreProcess and PostProcess.
type Yolo struct {
	// YoloTritonConfig is embedded, so its fields (NumObjects, NumClasses,
	// MinProbability, Classes, ...) are reachable directly on a Yolo value.
	YoloTritonConfig

	// metadata holds the ratios of the original image size to the target
	// size. PreProcess writes them and PostProcess multiplies box
	// coordinates by them.
	metadata struct {
		scaleFactorW, scaleFactorH float32
	}

	// io16 selects the imageToFloat16ByteSlice / bytesFP16ToFloat32Slice
	// code paths instead of the float32 ones.
	io16 bool
}
// NewYolo returns a Model backed by a freshly allocated *Yolo that stores
// cfg as its embedded configuration and io16 as its FP16 switch.
func NewYolo(cfg YoloTritonConfig, io16 bool) Model {
	y := new(Yolo)
	y.YoloTritonConfig = cfg
	y.io16 = io16
	return y
}
// Compile-time check that *Yolo satisfies the Model interface.
var _ Model = (*Yolo)(nil)
// GetConfig returns, by value, the YoloTritonConfig embedded in y.
func (y *Yolo) GetConfig() YoloTritonConfig {
	cfg := y.YoloTritonConfig
	return cfg
}
// PreProcess resizes img to targetWidth x targetHeight with resizeImage and
// packs the result into Triton tensor contents.
//
// The ratios of the original image dimensions to the target dimensions are
// stored on the receiver, so a later PostProcess call uses the factors from
// the most recent PreProcess call on the same Yolo value.
//
// When io16 is set the resized image is encoded with imageToFloat16ByteSlice
// into BytesContents; otherwise it is encoded with imageToFloat32Slice into
// Fp32Contents. The returned error is always nil.
func (y *Yolo) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
	bounds := img.Bounds()
	srcW, srcH := bounds.Dx(), bounds.Dy()

	resized := resizeImage(img, targetWidth, targetHeight)

	y.metadata.scaleFactorW = float32(srcW) / float32(targetWidth)
	y.metadata.scaleFactorH = float32(srcH) / float32(targetHeight)

	if y.io16 {
		return &triton.InferTensorContents{
			BytesContents: imageToFloat16ByteSlice(resized),
		}, nil
	}
	return &triton.InferTensorContents{
		Fp32Contents: imageToFloat32Slice(resized),
	}, nil
}
// PostProcess converts the first raw output tensor into detection boxes.
//
// The raw bytes are decoded with bytesFP16ToFloat32Slice when io16 is set
// and with bytesToFloat32Slice otherwise. The decoded values are read as a
// flat matrix with NumObjects columns: rows 0-3 are used as box centre x,
// centre y, width and height, and the following NumClasses rows as per-class
// scores. For each column the highest-scoring class is selected; columns
// whose best score is below MinProbability are skipped. Remaining boxes are
// converted to corner coordinates and multiplied by the scale factors
// recorded by the most recent PreProcess call.
//
// An error is returned, rather than panicking, when rawOutputContents is
// empty, when the decoded tensor holds fewer than (4+NumClasses)*NumObjects
// values, or when a selected class index has no entry in Classes. Errors
// from decoding the raw bytes are returned unchanged.
func (y *Yolo) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
	if len(rawOutputContents) == 0 {
		return nil, fmt.Errorf("yolo: post-process received no output tensors")
	}

	var (
		output []float32
		err    error
	)
	if y.io16 {
		output, err = bytesFP16ToFloat32Slice(rawOutputContents[0])
	} else {
		output, err = bytesToFloat32Slice(rawOutputContents[0])
	}
	if err != nil {
		return nil, err
	}

	numObjects := y.NumObjects
	numClasses := y.NumClasses

	// Check the tensor size once so the indexing below cannot run past the
	// end of output.
	rows := 4
	if numClasses > 0 {
		rows += numClasses
	}
	if numObjects > 0 && len(output) < rows*numObjects {
		return nil, fmt.Errorf(
			"yolo: output tensor has %d values, need at least %d (%d rows x %d objects)",
			len(output), rows*numObjects, rows, numObjects,
		)
	}

	scaleW := y.metadata.scaleFactorW
	scaleH := y.metadata.scaleFactorH

	// Kept as a non-nil empty slice so callers see the same value as before
	// when nothing is detected.
	boxes := []Box{}
	for index := 0; index < numObjects; index++ {
		// Pick the class with the highest score for this column.
		classID := 0
		prob := float32(0.0)
		for col := 0; col < numClasses; col++ {
			p := output[numObjects*(col+4)+index]
			if p > prob {
				prob = p
				classID = col
			}
		}
		if prob < y.MinProbability {
			continue
		}

		if classID >= len(y.Classes) {
			return nil, fmt.Errorf(
				"yolo: class index %d has no label (%d classes configured)",
				classID, len(y.Classes),
			)
		}
		label := y.Classes[classID]

		// Rows 0-3 hold centre x, centre y, width and height.
		cx := output[index]
		cy := output[numObjects+index]
		w := output[2*numObjects+index]
		h := output[3*numObjects+index]

		boxes = append(boxes, Box{
			X1:          float64((cx - w/2) * scaleW),
			Y1:          float64((cy - h/2) * scaleH),
			X2:          float64((cx + w/2) * scaleW),
			Y2:          float64((cy + h/2) * scaleH),
			Probability: float64(prob),
			Class:       label,
		})
	}
	return boxes, nil
}