diff --git a/README.md b/README.md
index d095518..cf11cdd 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[](https://goreportcard.com/report/github.com/dev6699/yolotriton)
[](LICENSE)
-Go (Golang) gRPC client for YOLOv8 inference using the Triton Inference Server.
+Go (Golang) gRPC client for YOLO-NAS and YOLOv8 inference using the Triton Inference Server.
## Installation
@@ -14,11 +14,12 @@ Use `go get` to install this package:
go get github.com/dev6699/yolotriton
```
-### Get YOLOv8 TensorRT model
+### Get YOLO-NAS, YOLOv8 TensorRT model
+Replace `yolov8m.pt` with the model file you want to export.
```bash
pip install ultralytics
yolo export model=yolov8m.pt format=onnx
-trtexec --onnx=yolov8m.onnx --saveEngine=model_repository/yolov8_tensorrt/1/model.plan
+trtexec --onnx=yolov8m.onnx --saveEngine=model_repository/yolov8/1/model.plan
```
References:
@@ -39,20 +40,39 @@ Check [cmd/main.go](cmd/main.go) for more details.
Available args:
```bash
-i string
- Inference Image. Default: images/1.jpg (default "images/1.jpg")
+ Inference Image. (default "images/1.jpg")
-m string
- Name of model being served. (Required) (default "yolov8_tensorrt")
+ Name of model being served (Required) (default "yolonas")
+ -t string
+ Type of model. Available options: [yolonas, yolov8] (default "yolonas")
-u string
- Inference Server URL. Default: tritonserver:8001 (default "tritonserver:8001")
+ Inference Server URL. (default "tritonserver:8001")
-x string
- Version of model. Default: Latest Version.
+ Version of model. Default: Latest Version
```
```bash
go run cmd/main.go
```
### Results
-| Input | Ouput |
-| --------------------------- | ------------------------------- |
-|
|
|
-|
|
|
\ No newline at end of file
+```
+prediction: 0
+class: dog
+confidence: 0.96
+bboxes: [ 669 130 1061 563 ]
+---------------------
+prediction: 1
+class: person
+confidence: 0.96
+bboxes: [ 440 30 760 541 ]
+---------------------
+prediction: 2
+class: dog
+confidence: 0.93
+bboxes: [ 168 83 495 592 ]
+---------------------
+```
+| Input | YOLO-NAS Output | YOLOv8 Output |
+| --------------------------- | --------------------------------------- | -------------------------------------- |
+|
|
|
|
+|
|
|
|
\ No newline at end of file
diff --git a/cmd/main.go b/cmd/main.go
index 9728001..6f59b80 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -12,16 +12,18 @@ import (
type Flags struct {
ModelName string
ModelVersion string
+ ModelType string
URL string
Image string
}
func parseFlags() Flags {
var flags Flags
- flag.StringVar(&flags.ModelName, "m", "yolov8_tensorrt", "Name of model being served. (Required)")
- flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version.")
- flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL. Default: tritonserver:8001")
- flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image. Default: images/1.jpg")
+ flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
+ flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
+ flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolov8]")
+ flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
+ flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
flag.Parse()
return flags
}
@@ -30,20 +32,17 @@ func main() {
FLAGS := parseFlags()
fmt.Println("FLAGS:", FLAGS)
- ygt, err := yolotriton.New(
- FLAGS.URL,
- yolotriton.YoloTritonConfig{
- BatchSize: 1,
- NumChannels: 84,
- NumObjects: 8400,
- Width: 640,
- Height: 640,
- ModelName: FLAGS.ModelName,
- ModelVersion: FLAGS.ModelVersion,
- MinProbability: 0.5,
- MaxIOU: 0.7,
- })
+ var model yolotriton.Model
+ switch yolotriton.ModelType(FLAGS.ModelType) {
+ case yolotriton.ModelTypeYoloV8:
+ model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion)
+ case yolotriton.ModelTypeYoloNAS:
+ model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion)
+ default:
+ log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolov8]", FLAGS.ModelType)
+ }
+ ygt, err := yolotriton.New(FLAGS.URL, model)
if err != nil {
log.Fatal(err)
}
@@ -59,17 +58,24 @@ func main() {
}
for i, r := range results {
- fmt.Printf("---%d---", i)
- fmt.Println(r.Class, r.Probability)
- fmt.Println("[x1,x2,y1,y2]", int(r.X1), int(r.X2), int(r.Y1), int(r.Y2))
+ fmt.Println("prediction: ", i)
+ fmt.Println("class: ", r.Class)
+ fmt.Printf("confidence: %.2f\n", r.Probability)
+ fmt.Println("bboxes: [", int(r.X1), int(r.Y1), int(r.X2), int(r.Y2), "]")
+ fmt.Println("---------------------")
}
- out, err := yolotriton.DrawBoundingBoxes(img, results, 5)
+ out, err := yolotriton.DrawBoundingBoxes(
+ img,
+ results,
+ int(float64(img.Bounds().Dx())*0.005),
+ float64(img.Bounds().Dx())*0.02,
+ )
if err != nil {
log.Fatal(err)
}
- err = yolotriton.SaveImage(out, fmt.Sprintf("%s_out.jpg", strings.Split(FLAGS.Image, ".")[0]))
+ err = yolotriton.SaveImage(out, fmt.Sprintf("%s_%s_out.jpg", strings.Split(FLAGS.Image, ".")[0], FLAGS.ModelName))
if err != nil {
log.Fatal(err)
}
diff --git a/images/1_out.jpg b/images/1_out.jpg
index 30ace3d..5ca8f10 100644
Binary files a/images/1_out.jpg and b/images/1_out.jpg differ
diff --git a/images/1_yolonas_out.jpg b/images/1_yolonas_out.jpg
new file mode 100644
index 0000000..efcf7c7
Binary files /dev/null and b/images/1_yolonas_out.jpg differ
diff --git a/images/1_yolov8_out.jpg b/images/1_yolov8_out.jpg
new file mode 100644
index 0000000..3561526
Binary files /dev/null and b/images/1_yolov8_out.jpg differ
diff --git a/images/2_out.jpg b/images/2_out.jpg
deleted file mode 100644
index 52d9601..0000000
Binary files a/images/2_out.jpg and /dev/null differ
diff --git a/images/2_yolonas_out.jpg b/images/2_yolonas_out.jpg
new file mode 100644
index 0000000..1497b4c
Binary files /dev/null and b/images/2_yolonas_out.jpg differ
diff --git a/images/2_yolov8_out.jpg b/images/2_yolov8_out.jpg
new file mode 100644
index 0000000..1210287
Binary files /dev/null and b/images/2_yolov8_out.jpg differ
diff --git a/model_repository/.gitignore b/model_repository/.gitignore
index f00c565..66a36b7 100644
--- a/model_repository/.gitignore
+++ b/model_repository/.gitignore
@@ -1 +1 @@
-yolov8_tensorrt/1/model.plan
\ No newline at end of file
+model.plan
\ No newline at end of file
diff --git a/model_repository/yolov8_tensorrt/1/.gitkeep b/model_repository/yolonas/1/.gitkeep
similarity index 100%
rename from model_repository/yolov8_tensorrt/1/.gitkeep
rename to model_repository/yolonas/1/.gitkeep
diff --git a/model_repository/yolov8_tensorrt/config.pbtxt b/model_repository/yolonas/config.pbtxt
similarity index 51%
rename from model_repository/yolov8_tensorrt/config.pbtxt
rename to model_repository/yolonas/config.pbtxt
index 3602cd5..5024667 100644
--- a/model_repository/yolov8_tensorrt/config.pbtxt
+++ b/model_repository/yolonas/config.pbtxt
@@ -1,2 +1,2 @@
-name: "yolov8_tensorrt"
+name: "yolonas"
platform: "tensorrt_plan"
\ No newline at end of file
diff --git a/model_repository/yolov8/1/.gitkeep b/model_repository/yolov8/1/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/model_repository/yolov8/config.pbtxt b/model_repository/yolov8/config.pbtxt
new file mode 100644
index 0000000..fe5f2e4
--- /dev/null
+++ b/model_repository/yolov8/config.pbtxt
@@ -0,0 +1,2 @@
+name: "yolov8"
+platform: "tensorrt_plan"
\ No newline at end of file
diff --git a/postprocess.go b/postprocess.go
index 77c4f66..67fdda2 100644
--- a/postprocess.go
+++ b/postprocess.go
@@ -3,30 +3,30 @@ package yolotriton
import (
"bytes"
"encoding/binary"
+ "io"
"math"
- "sort"
)
-func (y *YoloTriton) bytesToFloat32Slice(data []byte) ([]float32, error) {
+func bytesToFloat32Slice(data []byte) ([]float32, error) {
t := []float32{}
// Create a buffer from the input data
buffer := bytes.NewReader(data)
- for i := 0; i < y.cfg.BatchSize; i++ {
- for j := 0; j < y.cfg.NumChannels; j++ {
- for k := 0; k < y.cfg.NumObjects; k++ {
- // Read the binary data from the buffer
- var binaryValue uint32
- err := binary.Read(buffer, binary.LittleEndian, &binaryValue)
- if err != nil {
- return nil, err
- }
-
- t = append(t, math.Float32frombits(binaryValue))
-
+ for {
+ // Read the binary data from the buffer
+ var binaryValue uint32
+ err := binary.Read(buffer, binary.LittleEndian, &binaryValue)
+ if err != nil {
+ if err == io.EOF {
+ break
}
+ return nil, err
}
+
+ t = append(t, math.Float32frombits(binaryValue))
+
}
+
return t, nil
}
@@ -39,62 +39,6 @@ type Box struct {
Class string
}
-func (y *YoloTriton) parseOutput(output []float32, origImgWidth, origImgHeight int) []Box {
- boxes := []Box{}
-
- for index := 0; index < y.cfg.NumObjects; index++ {
- classID := 0
- prob := float32(0.0)
-
- for col := 0; col < y.cfg.NumChannels-4; col++ {
- if output[y.cfg.NumObjects*(col+4)+index] > prob {
- prob = output[y.cfg.NumObjects*(col+4)+index]
- classID = col
- }
- }
-
- if prob < float32(y.cfg.MinProbability) {
- continue
- }
-
- label := yoloClasses[classID]
- xc := output[index]
- yc := output[y.cfg.NumObjects+index]
- w := output[2*y.cfg.NumObjects+index]
- h := output[3*y.cfg.NumObjects+index]
- x1 := (xc - w/2) / float32(y.cfg.Width) * float32(origImgWidth)
- y1 := (yc - h/2) / float32(y.cfg.Height) * float32(origImgHeight)
- x2 := (xc + w/2) / float32(y.cfg.Width) * float32(origImgWidth)
- y2 := (yc + h/2) / float32(y.cfg.Height) * float32(origImgHeight)
- boxes = append(boxes, Box{
- X1: float64(x1),
- Y1: float64(y1),
- X2: float64(x2),
- Y2: float64(y2),
- Probability: float64(prob),
- Class: label,
- })
- }
-
- sort.Slice(boxes, func(i, j int) bool {
- return boxes[i].Probability < boxes[j].Probability
- })
-
- result := []Box{}
- for len(boxes) > 0 {
- result = append(result, boxes[0])
- tmp := []Box{}
- for _, box := range boxes {
- if iou(boxes[0], box) < y.cfg.MaxIOU {
- tmp = append(tmp, box)
- }
- }
- boxes = tmp
- }
-
- return result
-}
-
func iou(box1, box2 Box) float64 {
// Calculate the coordinates of the intersection rectangle
intersectionX1 := math.Max(box1.X1, box2.X1)
diff --git a/util.go b/util.go
index 39be7f3..754aeb1 100644
--- a/util.go
+++ b/util.go
@@ -44,7 +44,7 @@ func SaveImage(img image.Image, filename string) error {
return nil
}
-func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int) (image.Image, error) {
+func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int, fontSize float64) (image.Image, error) {
// Create a new RGBA image to draw the bounding boxes and text labels on
bounds := img.Bounds()
@@ -62,7 +62,7 @@ func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int) (image.Image
return nil, err
}
face := truetype.NewFace(ttfFont, &truetype.Options{
- Size: 36.0,
+ Size: fontSize,
})
// Draw the bounding boxes and text labels on the destination image
diff --git a/yolo.go b/yolo.go
index 4ee1bf6..1cfd64d 100644
--- a/yolo.go
+++ b/yolo.go
@@ -3,39 +3,53 @@ package yolotriton
import (
"image"
_ "image/png"
+ "sort"
triton "github.com/dev6699/yolotriton/grpc-client"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
+type ModelType string
+
+const (
+ ModelTypeYoloV8 ModelType = "yolov8"
+ ModelTypeYoloNAS ModelType = "yolonas"
+)
+
+type Model interface {
+ GetConfig() YoloTritonConfig
+ PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error)
+ PostProcess(rawOutputContents [][]byte) ([]Box, error)
+}
+
type YoloTritonConfig struct {
BatchSize int
NumChannels int
NumObjects int
- Width int
- Height int
ModelName string
ModelVersion string
- MinProbability float64
+ MinProbability float32
MaxIOU float64
}
-func New(url string, cfg YoloTritonConfig) (*YoloTriton, error) {
+func New(url string, model Model) (*YoloTriton, error) {
conn, err := grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
return &YoloTriton{
- cfg: cfg,
- conn: conn,
+ conn: conn,
+ model: model,
+ cfg: model.GetConfig(),
}, nil
}
type YoloTriton struct {
- cfg YoloTritonConfig
- conn *grpc.ClientConn
+ cfg YoloTritonConfig
+ conn *grpc.ClientConn
+ model Model
}
func (y *YoloTriton) Close() error {
@@ -44,32 +58,49 @@ func (y *YoloTriton) Close() error {
func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
- preprocessedImg := resizeImage(img, uint(y.cfg.Width), uint(y.cfg.Height))
-
- fp32Contents := imageToFloat32Slice(preprocessedImg)
-
client := triton.NewGRPCInferenceServiceClient(y.conn)
- inferInputs := []*triton.ModelInferRequest_InferInputTensor{
- {
- Name: "images",
- Datatype: "FP32",
- Shape: []int64{int64(y.cfg.BatchSize), 3, int64(y.cfg.Width), int64(y.cfg.Height)},
- Contents: &triton.InferTensorContents{
- Fp32Contents: fp32Contents,
- },
- },
- }
- inferOutputs := []*triton.ModelInferRequest_InferRequestedOutputTensor{
- {
- Name: "output0",
- },
+ metaResponse, err := ModelMetadataRequest(client, y.cfg.ModelName, y.cfg.ModelVersion)
+ if err != nil {
+ return nil, err
}
+
modelInferRequest := &triton.ModelInferRequest{
ModelName: y.cfg.ModelName,
ModelVersion: y.cfg.ModelVersion,
- Inputs: inferInputs,
- Outputs: inferOutputs,
+ }
+
+ input := metaResponse.Inputs[0]
+ if input.Shape[0] == -1 {
+ input.Shape[0] = 1
+ }
+
+ inputWidth := input.Shape[2]
+ inputHeight := input.Shape[3]
+
+ fp32Contents, err := y.model.PreProcess(img, uint(inputWidth), uint(inputHeight))
+ if err != nil {
+ return nil, err
+ }
+
+ modelInferRequest.Inputs = append(modelInferRequest.Inputs,
+ &triton.ModelInferRequest_InferInputTensor{
+ Name: input.Name,
+ Datatype: input.Datatype,
+ Shape: input.Shape,
+ Contents: &triton.InferTensorContents{
+ // Simply assume all are fp32
+ Fp32Contents: fp32Contents,
+ },
+ },
+ )
+
+ for _, o := range metaResponse.Outputs {
+ modelInferRequest.Outputs = append(modelInferRequest.Outputs,
+ &triton.ModelInferRequest_InferRequestedOutputTensor{
+ Name: o.Name,
+ },
+ )
}
inferResponse, err := ModelInferRequest(client, modelInferRequest)
@@ -77,11 +108,26 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
return nil, err
}
- t, err := y.bytesToFloat32Slice(inferResponse.RawOutputContents[0])
+ boxes, err := y.model.PostProcess(inferResponse.RawOutputContents)
if err != nil {
return nil, err
}
- boxes := y.parseOutput(t, img.Bounds().Dx(), img.Bounds().Dy())
- return boxes, nil
+ sort.Slice(boxes, func(i, j int) bool {
+ return boxes[i].Probability > boxes[j].Probability
+ })
+
+ result := []Box{}
+ for len(boxes) > 0 {
+ result = append(result, boxes[0])
+ tmp := []Box{}
+ for _, box := range boxes {
+ if iou(boxes[0], box) < y.cfg.MaxIOU {
+ tmp = append(tmp, box)
+ }
+ }
+ boxes = tmp
+ }
+
+ return result, nil
}
diff --git a/yolonas.go b/yolonas.go
new file mode 100644
index 0000000..2940a84
--- /dev/null
+++ b/yolonas.go
@@ -0,0 +1,140 @@
+package yolotriton
+
+import (
+ "image"
+ "image/color"
+ "image/draw"
+ "math"
+)
+
+type YoloNAS struct {
+ YoloTritonConfig
+ metadata struct {
+ xOffset float32
+ yOffset float32
+ scaleFactor float32
+ }
+}
+
+func NewYoloNAS(modelName string, modelVersion string) Model {
+ return &YoloNAS{
+ YoloTritonConfig: YoloTritonConfig{
+ BatchSize: 1,
+ NumChannels: 80,
+ NumObjects: 8400,
+ MinProbability: 0.5,
+ MaxIOU: 0.7,
+ ModelName: modelName,
+ ModelVersion: modelVersion,
+ },
+ }
+}
+
+var _ Model = &YoloNAS{}
+
+func (y *YoloNAS) GetConfig() YoloTritonConfig {
+ return y.YoloTritonConfig
+}
+
+func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error) {
+ height := img.Bounds().Dy()
+ width := img.Bounds().Dx()
+
+ // https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/processing/processing.py#L547
+ scaleFactor := math.Min(float64(636)/float64(height), float64(636)/float64(width))
+ if scaleFactor != 1.0 {
+ newHeight := uint(math.Round(float64(height) * scaleFactor))
+ newWidth := uint(math.Round(float64(width) * scaleFactor))
+ img = resizeImage(img, newWidth, newHeight)
+ }
+
+ paddedImage, xOffset, yOffset := padImageToCenterWithGray(img, int(targetWidth), int(targetHeight), 114)
+
+ fp32Contents := imageToFloat32Slice(paddedImage)
+
+ y.metadata.xOffset = float32(xOffset)
+ y.metadata.yOffset = float32(yOffset)
+ y.metadata.scaleFactor = float32(scaleFactor)
+
+ return fp32Contents, nil
+}
+
+func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
+ predScores, err := bytesToFloat32Slice(rawOutputContents[0])
+ if err != nil {
+ return nil, err
+ }
+ predBoxes, err := bytesToFloat32Slice(rawOutputContents[1])
+ if err != nil {
+ return nil, err
+ }
+
+ boxes := []Box{}
+
+ for index := 0; index < y.NumObjects; index++ {
+
+ classID := 0
+ prob := float32(0.0)
+
+ for col := 0; col < y.NumChannels; col++ {
+ p := predScores[index*y.NumChannels+(col)]
+ if p > prob {
+ prob = p
+ classID = col
+ }
+ }
+
+ if prob < y.MinProbability {
+ continue
+ }
+
+ label := yoloClasses[classID]
+ i := (index * 4)
+ xc := predBoxes[i]
+ yc := predBoxes[i+1]
+ w := predBoxes[i+2]
+ h := predBoxes[i+3]
+
+ scale := y.metadata.scaleFactor
+ x1 := (xc - y.metadata.xOffset) / scale
+ y1 := (yc - y.metadata.yOffset) / scale
+ x2 := (w - y.metadata.xOffset) / scale
+ y2 := (h - y.metadata.yOffset) / scale
+
+ boxes = append(boxes, Box{
+ X1: float64(x1),
+ Y1: float64(y1),
+ X2: float64(x2),
+ Y2: float64(y2),
+ Probability: float64(prob),
+ Class: label,
+ })
+ }
+
+ return boxes, nil
+}
+
+func padImageToCenterWithGray(originalImage image.Image, targetWidth, targetHeight int, grayValue uint8) (image.Image, int, int) {
+ // Calculate the dimensions of the original image
+ originalWidth := originalImage.Bounds().Dx()
+ originalHeight := originalImage.Bounds().Dy()
+
+ // Calculate the padding dimensions
+ padWidth := targetWidth - originalWidth
+ padHeight := targetHeight - originalHeight
+
+ // Create a new RGBA image with the desired dimensions and fill it with gray
+ paddedImage := image.NewRGBA(image.Rect(0, 0, targetWidth, targetHeight))
+ grayColor := color.RGBA{grayValue, grayValue, grayValue, 255}
+ draw.Draw(paddedImage, paddedImage.Bounds(), &image.Uniform{grayColor}, image.Point{}, draw.Src)
+
+ // Calculate the position to paste the original image in the center
+ xOffset := int(math.Floor(float64(padWidth) / 2))
+ yOffset := int(math.Floor(float64(padHeight) / 2))
+
+ // Paste the original image onto the padded image
+ pasteRect := image.Rect(xOffset, yOffset, xOffset+originalWidth, yOffset+originalHeight)
+ draw.Draw(paddedImage, pasteRect, originalImage, image.Point{}, draw.Over)
+
+ return paddedImage, xOffset, yOffset
+}
diff --git a/yolov8.go b/yolov8.go
new file mode 100644
index 0000000..8fe97e4
--- /dev/null
+++ b/yolov8.go
@@ -0,0 +1,98 @@
+package yolotriton
+
+import (
+ "image"
+)
+
+type YoloV8 struct {
+ YoloTritonConfig
+ metadata struct {
+ scaleFactorW float32
+ scaleFactorH float32
+ }
+}
+
+func NewYoloV8(modelName string, modelVersion string) Model {
+ return &YoloV8{
+ YoloTritonConfig: YoloTritonConfig{
+ BatchSize: 1,
+ NumChannels: 84,
+ NumObjects: 8400,
+ MinProbability: 0.5,
+ MaxIOU: 0.7,
+ ModelName: modelName,
+ ModelVersion: modelVersion,
+ },
+ }
+}
+
+var _ Model = &YoloV8{}
+
+func (y *YoloV8) GetConfig() YoloTritonConfig {
+ return y.YoloTritonConfig
+}
+
+func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error) {
+ width := img.Bounds().Dx()
+ height := img.Bounds().Dy()
+
+ preprocessedImg := resizeImage(img, targetWidth, targetHeight)
+
+ fp32Contents := imageToFloat32Slice(preprocessedImg)
+
+ y.metadata.scaleFactorW = float32(width) / float32(targetWidth)
+ y.metadata.scaleFactorH = float32(height) / float32(targetHeight)
+
+ return fp32Contents, nil
+}
+
+func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
+ output, err := bytesToFloat32Slice(rawOutputContents[0])
+ if err != nil {
+ return nil, err
+ }
+
+ numObjects := y.NumObjects
+ numChannels := y.NumChannels
+
+ boxes := []Box{}
+
+ for index := 0; index < numObjects; index++ {
+ classID := 0
+ prob := float32(0.0)
+
+ for col := 0; col < numChannels-4; col++ {
+ p := output[numObjects*(col+4)+index]
+ if p > prob {
+ prob = p
+ classID = col
+ }
+ }
+
+ if prob < y.MinProbability {
+ continue
+ }
+
+ label := yoloClasses[classID]
+ xc := output[index]
+ yc := output[numObjects+index]
+ w := output[2*numObjects+index]
+ h := output[3*numObjects+index]
+
+ x1 := (xc - w/2) * y.metadata.scaleFactorW
+ y1 := (yc - h/2) * y.metadata.scaleFactorH
+ x2 := (xc + w/2) * y.metadata.scaleFactorW
+ y2 := (yc + h/2) * y.metadata.scaleFactorH
+
+ boxes = append(boxes, Box{
+ X1: float64(x1),
+ Y1: float64(y1),
+ X2: float64(x2),
+ Y2: float64(y2),
+ Probability: float64(prob),
+ Class: label,
+ })
+ }
+
+ return boxes, nil
+}