feat: added support for YOLO-NAS
42
README.md
@@ -4,7 +4,7 @@
|
|||||||
[](https://goreportcard.com/report/github.com/dev6699/yolotriton)
|
[](https://goreportcard.com/report/github.com/dev6699/yolotriton)
|
||||||
[](LICENSE)
|
[](LICENSE)
|
||||||
|
|
||||||
Go (Golang) gRPC client for YOLOv8 inference using the Triton Inference Server.
|
Go (Golang) gRPC client for YOLO-NAS, YOLOv8 inference using the Triton Inference Server.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
@@ -14,11 +14,12 @@ Use `go get` to install this package:
|
|||||||
go get github.com/dev6699/yolotriton
|
go get github.com/dev6699/yolotriton
|
||||||
```
|
```
|
||||||
|
|
||||||
### Get YOLOv8 TensorRT model
|
### Get YOLO-NAS, YOLOv8 TensorRT model
|
||||||
|
Replace `yolov8m.pt` with your desired model
|
||||||
```bash
|
```bash
|
||||||
pip install ultralytics
|
pip install ultralytics
|
||||||
yolo export model=yolov8m.pt format=onnx
|
yolo export model=yolov8m.pt format=onnx
|
||||||
trtexec --onnx=yolov8m.onnx --saveEngine=model_repository/yolov8_tensorrt/1/model.plan
|
trtexec --onnx=yolov8m.onnx --saveEngine=model_repository/yolov8/1/model.plan
|
||||||
```
|
```
|
||||||
|
|
||||||
References:
|
References:
|
||||||
@@ -39,20 +40,39 @@ Check [cmd/main.go](cmd/main.go) for more details.
|
|||||||
Available args:
|
Available args:
|
||||||
```bash
|
```bash
|
||||||
-i string
|
-i string
|
||||||
Inference Image. Default: images/1.jpg (default "images/1.jpg")
|
Inference Image. (default "images/1.jpg")
|
||||||
-m string
|
-m string
|
||||||
Name of model being served. (Required) (default "yolov8_tensorrt")
|
Name of model being served (Required) (default "yolonas")
|
||||||
|
-t string
|
||||||
|
Type of model. Available options: [yolonas, yolov8] (default "yolonas")
|
||||||
-u string
|
-u string
|
||||||
Inference Server URL. Default: tritonserver:8001 (default "tritonserver:8001")
|
Inference Server URL. (default "tritonserver:8001")
|
||||||
-x string
|
-x string
|
||||||
Version of model. Default: Latest Version.
|
Version of model. Default: Latest Version
|
||||||
```
|
```
|
||||||
```bash
|
```bash
|
||||||
go run cmd/main.go
|
go run cmd/main.go
|
||||||
```
|
```
|
||||||
|
|
||||||
### Results
|
### Results
|
||||||
| Input | Ouput |
|
```
|
||||||
| --------------------------- | ------------------------------- |
|
prediction: 0
|
||||||
| <img src="images/1.jpg" /> | <img src="images/1_out.jpg" /> |
|
class: dog
|
||||||
| <img src="images/2.jpg" /> | <img src="images/2_out.jpg" /> |
|
confidence: 0.96
|
||||||
|
bboxes: [ 669 130 1061 563 ]
|
||||||
|
---------------------
|
||||||
|
prediction: 1
|
||||||
|
class: person
|
||||||
|
confidence: 0.96
|
||||||
|
bboxes: [ 440 30 760 541 ]
|
||||||
|
---------------------
|
||||||
|
prediction: 2
|
||||||
|
class: dog
|
||||||
|
confidence: 0.93
|
||||||
|
bboxes: [ 168 83 495 592 ]
|
||||||
|
---------------------
|
||||||
|
```
|
||||||
|
| Input | YOLO-NAS Ouput | YOLOv8 Output |
|
||||||
|
| --------------------------- | --------------------------------------- | -------------------------------------- |
|
||||||
|
| <img src="images/1.jpg" /> | <img src="images/1_yolonas_out.jpg" /> | <img src="images/1_yolonas_out.jpg" /> |
|
||||||
|
| <img src="images/2.jpg" /> | <img src="images/2_yolonas_out.jpg" /> | <img src="images/2_yolonas_out.jpg" /> |
|
50
cmd/main.go
@@ -12,16 +12,18 @@ import (
|
|||||||
type Flags struct {
|
type Flags struct {
|
||||||
ModelName string
|
ModelName string
|
||||||
ModelVersion string
|
ModelVersion string
|
||||||
|
ModelType string
|
||||||
URL string
|
URL string
|
||||||
Image string
|
Image string
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFlags() Flags {
|
func parseFlags() Flags {
|
||||||
var flags Flags
|
var flags Flags
|
||||||
flag.StringVar(&flags.ModelName, "m", "yolov8_tensorrt", "Name of model being served. (Required)")
|
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
||||||
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version.")
|
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
||||||
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL. Default: tritonserver:8001")
|
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolov8]")
|
||||||
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image. Default: images/1.jpg")
|
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
||||||
|
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
@@ -30,20 +32,17 @@ func main() {
|
|||||||
FLAGS := parseFlags()
|
FLAGS := parseFlags()
|
||||||
fmt.Println("FLAGS:", FLAGS)
|
fmt.Println("FLAGS:", FLAGS)
|
||||||
|
|
||||||
ygt, err := yolotriton.New(
|
var model yolotriton.Model
|
||||||
FLAGS.URL,
|
switch yolotriton.ModelType(FLAGS.ModelType) {
|
||||||
yolotriton.YoloTritonConfig{
|
case yolotriton.ModelTypeYoloV8:
|
||||||
BatchSize: 1,
|
model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion)
|
||||||
NumChannels: 84,
|
case yolotriton.ModelTypeYoloNAS:
|
||||||
NumObjects: 8400,
|
model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion)
|
||||||
Width: 640,
|
default:
|
||||||
Height: 640,
|
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolov8]", FLAGS.ModelType)
|
||||||
ModelName: FLAGS.ModelName,
|
}
|
||||||
ModelVersion: FLAGS.ModelVersion,
|
|
||||||
MinProbability: 0.5,
|
|
||||||
MaxIOU: 0.7,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
ygt, err := yolotriton.New(FLAGS.URL, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -59,17 +58,24 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, r := range results {
|
for i, r := range results {
|
||||||
fmt.Printf("---%d---", i)
|
fmt.Println("prediction: ", i)
|
||||||
fmt.Println(r.Class, r.Probability)
|
fmt.Println("class: ", r.Class)
|
||||||
fmt.Println("[x1,x2,y1,y2]", int(r.X1), int(r.X2), int(r.Y1), int(r.Y2))
|
fmt.Printf("confidence: %.2f\n", r.Probability)
|
||||||
|
fmt.Println("bboxes: [", int(r.X1), int(r.Y1), int(r.X2), int(r.Y2), "]")
|
||||||
|
fmt.Println("---------------------")
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := yolotriton.DrawBoundingBoxes(img, results, 5)
|
out, err := yolotriton.DrawBoundingBoxes(
|
||||||
|
img,
|
||||||
|
results,
|
||||||
|
int(float64(img.Bounds().Dx())*0.005),
|
||||||
|
float64(img.Bounds().Dx())*0.02,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = yolotriton.SaveImage(out, fmt.Sprintf("%s_out.jpg", strings.Split(FLAGS.Image, ".")[0]))
|
err = yolotriton.SaveImage(out, fmt.Sprintf("%s_%s_out.jpg", strings.Split(FLAGS.Image, ".")[0], FLAGS.ModelName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
BIN
images/1_out.jpg
Before Width: | Height: | Size: 96 KiB After Width: | Height: | Size: 88 KiB |
BIN
images/1_yolonas_out.jpg
Normal file
After Width: | Height: | Size: 92 KiB |
BIN
images/1_yolov8_out.jpg
Normal file
After Width: | Height: | Size: 90 KiB |
BIN
images/2_out.jpg
Before Width: | Height: | Size: 2.3 MiB |
BIN
images/2_yolonas_out.jpg
Normal file
After Width: | Height: | Size: 2.3 MiB |
BIN
images/2_yolov8_out.jpg
Normal file
After Width: | Height: | Size: 2.3 MiB |
2
model_repository/.gitignore
vendored
@@ -1 +1 @@
|
|||||||
yolov8_tensorrt/1/model.plan
|
model.plan
|
@@ -1,2 +1,2 @@
|
|||||||
name: "yolov8_tensorrt"
|
name: "yolonas"
|
||||||
platform: "tensorrt_plan"
|
platform: "tensorrt_plan"
|
0
model_repository/yolov8/1/.gitkeep
Normal file
2
model_repository/yolov8/config.pbtxt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
name: "yolov8"
|
||||||
|
platform: "tensorrt_plan"
|
@@ -3,30 +3,30 @@ package yolotriton
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"sort"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (y *YoloTriton) bytesToFloat32Slice(data []byte) ([]float32, error) {
|
func bytesToFloat32Slice(data []byte) ([]float32, error) {
|
||||||
t := []float32{}
|
t := []float32{}
|
||||||
|
|
||||||
// Create a buffer from the input data
|
// Create a buffer from the input data
|
||||||
buffer := bytes.NewReader(data)
|
buffer := bytes.NewReader(data)
|
||||||
for i := 0; i < y.cfg.BatchSize; i++ {
|
for {
|
||||||
for j := 0; j < y.cfg.NumChannels; j++ {
|
// Read the binary data from the buffer
|
||||||
for k := 0; k < y.cfg.NumObjects; k++ {
|
var binaryValue uint32
|
||||||
// Read the binary data from the buffer
|
err := binary.Read(buffer, binary.LittleEndian, &binaryValue)
|
||||||
var binaryValue uint32
|
if err != nil {
|
||||||
err := binary.Read(buffer, binary.LittleEndian, &binaryValue)
|
if err == io.EOF {
|
||||||
if err != nil {
|
break
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t = append(t, math.Float32frombits(binaryValue))
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t = append(t, math.Float32frombits(binaryValue))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,62 +39,6 @@ type Box struct {
|
|||||||
Class string
|
Class string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (y *YoloTriton) parseOutput(output []float32, origImgWidth, origImgHeight int) []Box {
|
|
||||||
boxes := []Box{}
|
|
||||||
|
|
||||||
for index := 0; index < y.cfg.NumObjects; index++ {
|
|
||||||
classID := 0
|
|
||||||
prob := float32(0.0)
|
|
||||||
|
|
||||||
for col := 0; col < y.cfg.NumChannels-4; col++ {
|
|
||||||
if output[y.cfg.NumObjects*(col+4)+index] > prob {
|
|
||||||
prob = output[y.cfg.NumObjects*(col+4)+index]
|
|
||||||
classID = col
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if prob < float32(y.cfg.MinProbability) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
label := yoloClasses[classID]
|
|
||||||
xc := output[index]
|
|
||||||
yc := output[y.cfg.NumObjects+index]
|
|
||||||
w := output[2*y.cfg.NumObjects+index]
|
|
||||||
h := output[3*y.cfg.NumObjects+index]
|
|
||||||
x1 := (xc - w/2) / float32(y.cfg.Width) * float32(origImgWidth)
|
|
||||||
y1 := (yc - h/2) / float32(y.cfg.Height) * float32(origImgHeight)
|
|
||||||
x2 := (xc + w/2) / float32(y.cfg.Width) * float32(origImgWidth)
|
|
||||||
y2 := (yc + h/2) / float32(y.cfg.Height) * float32(origImgHeight)
|
|
||||||
boxes = append(boxes, Box{
|
|
||||||
X1: float64(x1),
|
|
||||||
Y1: float64(y1),
|
|
||||||
X2: float64(x2),
|
|
||||||
Y2: float64(y2),
|
|
||||||
Probability: float64(prob),
|
|
||||||
Class: label,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(boxes, func(i, j int) bool {
|
|
||||||
return boxes[i].Probability < boxes[j].Probability
|
|
||||||
})
|
|
||||||
|
|
||||||
result := []Box{}
|
|
||||||
for len(boxes) > 0 {
|
|
||||||
result = append(result, boxes[0])
|
|
||||||
tmp := []Box{}
|
|
||||||
for _, box := range boxes {
|
|
||||||
if iou(boxes[0], box) < y.cfg.MaxIOU {
|
|
||||||
tmp = append(tmp, box)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
boxes = tmp
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func iou(box1, box2 Box) float64 {
|
func iou(box1, box2 Box) float64 {
|
||||||
// Calculate the coordinates of the intersection rectangle
|
// Calculate the coordinates of the intersection rectangle
|
||||||
intersectionX1 := math.Max(box1.X1, box2.X1)
|
intersectionX1 := math.Max(box1.X1, box2.X1)
|
||||||
|
4
util.go
@@ -44,7 +44,7 @@ func SaveImage(img image.Image, filename string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int) (image.Image, error) {
|
func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int, fontSize float64) (image.Image, error) {
|
||||||
|
|
||||||
// Create a new RGBA image to draw the bounding boxes and text labels on
|
// Create a new RGBA image to draw the bounding boxes and text labels on
|
||||||
bounds := img.Bounds()
|
bounds := img.Bounds()
|
||||||
@@ -62,7 +62,7 @@ func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int) (image.Image
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
face := truetype.NewFace(ttfFont, &truetype.Options{
|
face := truetype.NewFace(ttfFont, &truetype.Options{
|
||||||
Size: 36.0,
|
Size: fontSize,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Draw the bounding boxes and text labels on the destination image
|
// Draw the bounding boxes and text labels on the destination image
|
||||||
|
108
yolo.go
@@ -3,39 +3,53 @@ package yolotriton
|
|||||||
import (
|
import (
|
||||||
"image"
|
"image"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
|
"sort"
|
||||||
|
|
||||||
triton "github.com/dev6699/yolotriton/grpc-client"
|
triton "github.com/dev6699/yolotriton/grpc-client"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ModelType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelTypeYoloV8 ModelType = "yolov8"
|
||||||
|
ModelTypeYoloNAS ModelType = "yolonas"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model interface {
|
||||||
|
GetConfig() YoloTritonConfig
|
||||||
|
PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error)
|
||||||
|
PostProcess(rawOutputContents [][]byte) ([]Box, error)
|
||||||
|
}
|
||||||
|
|
||||||
type YoloTritonConfig struct {
|
type YoloTritonConfig struct {
|
||||||
BatchSize int
|
BatchSize int
|
||||||
NumChannels int
|
NumChannels int
|
||||||
NumObjects int
|
NumObjects int
|
||||||
Width int
|
|
||||||
Height int
|
|
||||||
ModelName string
|
ModelName string
|
||||||
ModelVersion string
|
ModelVersion string
|
||||||
MinProbability float64
|
MinProbability float32
|
||||||
MaxIOU float64
|
MaxIOU float64
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(url string, cfg YoloTritonConfig) (*YoloTriton, error) {
|
func New(url string, model Model) (*YoloTriton, error) {
|
||||||
conn, err := grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &YoloTriton{
|
return &YoloTriton{
|
||||||
cfg: cfg,
|
conn: conn,
|
||||||
conn: conn,
|
model: model,
|
||||||
|
cfg: model.GetConfig(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type YoloTriton struct {
|
type YoloTriton struct {
|
||||||
cfg YoloTritonConfig
|
cfg YoloTritonConfig
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
|
model Model
|
||||||
}
|
}
|
||||||
|
|
||||||
func (y *YoloTriton) Close() error {
|
func (y *YoloTriton) Close() error {
|
||||||
@@ -44,32 +58,49 @@ func (y *YoloTriton) Close() error {
|
|||||||
|
|
||||||
func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
|
func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
|
||||||
|
|
||||||
preprocessedImg := resizeImage(img, uint(y.cfg.Width), uint(y.cfg.Height))
|
|
||||||
|
|
||||||
fp32Contents := imageToFloat32Slice(preprocessedImg)
|
|
||||||
|
|
||||||
client := triton.NewGRPCInferenceServiceClient(y.conn)
|
client := triton.NewGRPCInferenceServiceClient(y.conn)
|
||||||
|
|
||||||
inferInputs := []*triton.ModelInferRequest_InferInputTensor{
|
metaResponse, err := ModelMetadataRequest(client, y.cfg.ModelName, y.cfg.ModelVersion)
|
||||||
{
|
if err != nil {
|
||||||
Name: "images",
|
return nil, err
|
||||||
Datatype: "FP32",
|
|
||||||
Shape: []int64{int64(y.cfg.BatchSize), 3, int64(y.cfg.Width), int64(y.cfg.Height)},
|
|
||||||
Contents: &triton.InferTensorContents{
|
|
||||||
Fp32Contents: fp32Contents,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
inferOutputs := []*triton.ModelInferRequest_InferRequestedOutputTensor{
|
|
||||||
{
|
|
||||||
Name: "output0",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelInferRequest := &triton.ModelInferRequest{
|
modelInferRequest := &triton.ModelInferRequest{
|
||||||
ModelName: y.cfg.ModelName,
|
ModelName: y.cfg.ModelName,
|
||||||
ModelVersion: y.cfg.ModelVersion,
|
ModelVersion: y.cfg.ModelVersion,
|
||||||
Inputs: inferInputs,
|
}
|
||||||
Outputs: inferOutputs,
|
|
||||||
|
input := metaResponse.Inputs[0]
|
||||||
|
if input.Shape[0] == -1 {
|
||||||
|
input.Shape[0] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
inputWidth := input.Shape[2]
|
||||||
|
inputHeight := input.Shape[3]
|
||||||
|
|
||||||
|
fp32Contents, err := y.model.PreProcess(img, uint(inputWidth), uint(inputHeight))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelInferRequest.Inputs = append(modelInferRequest.Inputs,
|
||||||
|
&triton.ModelInferRequest_InferInputTensor{
|
||||||
|
Name: input.Name,
|
||||||
|
Datatype: input.Datatype,
|
||||||
|
Shape: input.Shape,
|
||||||
|
Contents: &triton.InferTensorContents{
|
||||||
|
// Simply assume all are fp32
|
||||||
|
Fp32Contents: fp32Contents,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, o := range metaResponse.Outputs {
|
||||||
|
modelInferRequest.Outputs = append(modelInferRequest.Outputs,
|
||||||
|
&triton.ModelInferRequest_InferRequestedOutputTensor{
|
||||||
|
Name: o.Name,
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
inferResponse, err := ModelInferRequest(client, modelInferRequest)
|
inferResponse, err := ModelInferRequest(client, modelInferRequest)
|
||||||
@@ -77,11 +108,26 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := y.bytesToFloat32Slice(inferResponse.RawOutputContents[0])
|
boxes, err := y.model.PostProcess(inferResponse.RawOutputContents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
boxes := y.parseOutput(t, img.Bounds().Dx(), img.Bounds().Dy())
|
sort.Slice(boxes, func(i, j int) bool {
|
||||||
return boxes, nil
|
return boxes[i].Probability > boxes[j].Probability
|
||||||
|
})
|
||||||
|
|
||||||
|
result := []Box{}
|
||||||
|
for len(boxes) > 0 {
|
||||||
|
result = append(result, boxes[0])
|
||||||
|
tmp := []Box{}
|
||||||
|
for _, box := range boxes {
|
||||||
|
if iou(boxes[0], box) < y.cfg.MaxIOU {
|
||||||
|
tmp = append(tmp, box)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
boxes = tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
140
yolonas.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package yolotriton
|
||||||
|
|
||||||
|
import (
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/draw"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
type YoloNAS struct {
|
||||||
|
YoloTritonConfig
|
||||||
|
metadata struct {
|
||||||
|
xOffset float32
|
||||||
|
yOffset float32
|
||||||
|
scaleFactor float32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewYoloNAS(modelName string, modelVersion string) Model {
|
||||||
|
return &YoloNAS{
|
||||||
|
YoloTritonConfig: YoloTritonConfig{
|
||||||
|
BatchSize: 1,
|
||||||
|
NumChannels: 80,
|
||||||
|
NumObjects: 8400,
|
||||||
|
MinProbability: 0.5,
|
||||||
|
MaxIOU: 0.7,
|
||||||
|
ModelName: modelName,
|
||||||
|
ModelVersion: modelVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Model = &YoloNAS{}
|
||||||
|
|
||||||
|
func (y *YoloNAS) GetConfig() YoloTritonConfig {
|
||||||
|
return y.YoloTritonConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error) {
|
||||||
|
height := img.Bounds().Dy()
|
||||||
|
width := img.Bounds().Dx()
|
||||||
|
|
||||||
|
// https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/processing/processing.py#L547
|
||||||
|
scaleFactor := math.Min(float64(636)/float64(height), float64(636)/float64(width))
|
||||||
|
if scaleFactor != 1.0 {
|
||||||
|
newHeight := uint(math.Round(float64(height) * scaleFactor))
|
||||||
|
newWidth := uint(math.Round(float64(width) * scaleFactor))
|
||||||
|
img = resizeImage(img, newWidth, newHeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
paddedImage, xOffset, yOffset := padImageToCenterWithGray(img, int(targetWidth), int(targetHeight), 114)
|
||||||
|
|
||||||
|
fp32Contents := imageToFloat32Slice(paddedImage)
|
||||||
|
|
||||||
|
y.metadata.xOffset = float32(xOffset)
|
||||||
|
y.metadata.yOffset = float32(yOffset)
|
||||||
|
y.metadata.scaleFactor = float32(scaleFactor)
|
||||||
|
|
||||||
|
return fp32Contents, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||||
|
predScores, err := bytesToFloat32Slice(rawOutputContents[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
predBoxes, err := bytesToFloat32Slice(rawOutputContents[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
boxes := []Box{}
|
||||||
|
|
||||||
|
for index := 0; index < y.NumObjects; index++ {
|
||||||
|
|
||||||
|
classID := 0
|
||||||
|
prob := float32(0.0)
|
||||||
|
|
||||||
|
for col := 0; col < y.NumChannels; col++ {
|
||||||
|
p := predScores[index*y.NumChannels+(col)]
|
||||||
|
if p > prob {
|
||||||
|
prob = p
|
||||||
|
classID = col
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prob < y.MinProbability {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
label := yoloClasses[classID]
|
||||||
|
i := (index * 4)
|
||||||
|
xc := predBoxes[i]
|
||||||
|
yc := predBoxes[i+1]
|
||||||
|
w := predBoxes[i+2]
|
||||||
|
h := predBoxes[i+3]
|
||||||
|
|
||||||
|
scale := y.metadata.scaleFactor
|
||||||
|
x1 := (xc - y.metadata.xOffset) / scale
|
||||||
|
y1 := (yc - y.metadata.yOffset) / scale
|
||||||
|
x2 := (w - y.metadata.xOffset) / scale
|
||||||
|
y2 := (h - y.metadata.yOffset) / scale
|
||||||
|
|
||||||
|
boxes = append(boxes, Box{
|
||||||
|
X1: float64(x1),
|
||||||
|
Y1: float64(y1),
|
||||||
|
X2: float64(x2),
|
||||||
|
Y2: float64(y2),
|
||||||
|
Probability: float64(prob),
|
||||||
|
Class: label,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return boxes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func padImageToCenterWithGray(originalImage image.Image, targetWidth, targetHeight int, grayValue uint8) (image.Image, int, int) {
|
||||||
|
// Calculate the dimensions of the original image
|
||||||
|
originalWidth := originalImage.Bounds().Dx()
|
||||||
|
originalHeight := originalImage.Bounds().Dy()
|
||||||
|
|
||||||
|
// Calculate the padding dimensions
|
||||||
|
padWidth := targetWidth - originalWidth
|
||||||
|
padHeight := targetHeight - originalHeight
|
||||||
|
|
||||||
|
// Create a new RGBA image with the desired dimensions and fill it with gray
|
||||||
|
paddedImage := image.NewRGBA(image.Rect(0, 0, targetWidth, targetHeight))
|
||||||
|
grayColor := color.RGBA{grayValue, grayValue, grayValue, 255}
|
||||||
|
draw.Draw(paddedImage, paddedImage.Bounds(), &image.Uniform{grayColor}, image.Point{}, draw.Src)
|
||||||
|
|
||||||
|
// Calculate the position to paste the original image in the center
|
||||||
|
xOffset := int(math.Floor(float64(padWidth) / 2))
|
||||||
|
yOffset := int(math.Floor(float64(padHeight) / 2))
|
||||||
|
|
||||||
|
// Paste the original image onto the padded image
|
||||||
|
pasteRect := image.Rect(xOffset, yOffset, xOffset+originalWidth, yOffset+originalHeight)
|
||||||
|
draw.Draw(paddedImage, pasteRect, originalImage, image.Point{}, draw.Over)
|
||||||
|
|
||||||
|
return paddedImage, xOffset, yOffset
|
||||||
|
}
|
98
yolov8.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package yolotriton
|
||||||
|
|
||||||
|
import (
|
||||||
|
"image"
|
||||||
|
)
|
||||||
|
|
||||||
|
type YoloV8 struct {
|
||||||
|
YoloTritonConfig
|
||||||
|
metadata struct {
|
||||||
|
scaleFactorW float32
|
||||||
|
scaleFactorH float32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewYoloV8(modelName string, modelVersion string) Model {
|
||||||
|
return &YoloV8{
|
||||||
|
YoloTritonConfig: YoloTritonConfig{
|
||||||
|
BatchSize: 1,
|
||||||
|
NumChannels: 84,
|
||||||
|
NumObjects: 8400,
|
||||||
|
MinProbability: 0.5,
|
||||||
|
MaxIOU: 0.7,
|
||||||
|
ModelName: modelName,
|
||||||
|
ModelVersion: modelVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Model = &YoloV8{}
|
||||||
|
|
||||||
|
func (y *YoloV8) GetConfig() YoloTritonConfig {
|
||||||
|
return y.YoloTritonConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error) {
|
||||||
|
width := img.Bounds().Dx()
|
||||||
|
height := img.Bounds().Dy()
|
||||||
|
|
||||||
|
preprocessedImg := resizeImage(img, targetWidth, targetHeight)
|
||||||
|
|
||||||
|
fp32Contents := imageToFloat32Slice(preprocessedImg)
|
||||||
|
|
||||||
|
y.metadata.scaleFactorW = float32(width) / float32(targetWidth)
|
||||||
|
y.metadata.scaleFactorH = float32(height) / float32(targetHeight)
|
||||||
|
|
||||||
|
return fp32Contents, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||||
|
output, err := bytesToFloat32Slice(rawOutputContents[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
numObjects := y.NumObjects
|
||||||
|
numChannels := y.NumChannels
|
||||||
|
|
||||||
|
boxes := []Box{}
|
||||||
|
|
||||||
|
for index := 0; index < numObjects; index++ {
|
||||||
|
classID := 0
|
||||||
|
prob := float32(0.0)
|
||||||
|
|
||||||
|
for col := 0; col < numChannels-4; col++ {
|
||||||
|
p := output[numObjects*(col+4)+index]
|
||||||
|
if p > prob {
|
||||||
|
prob = p
|
||||||
|
classID = col
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prob < y.MinProbability {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
label := yoloClasses[classID]
|
||||||
|
xc := output[index]
|
||||||
|
yc := output[numObjects+index]
|
||||||
|
w := output[2*numObjects+index]
|
||||||
|
h := output[3*numObjects+index]
|
||||||
|
|
||||||
|
x1 := (xc - w/2) * y.metadata.scaleFactorW
|
||||||
|
y1 := (yc - h/2) * y.metadata.scaleFactorH
|
||||||
|
x2 := (xc + w/2) * y.metadata.scaleFactorW
|
||||||
|
y2 := (yc + h/2) * y.metadata.scaleFactorH
|
||||||
|
|
||||||
|
boxes = append(boxes, Box{
|
||||||
|
X1: float64(x1),
|
||||||
|
Y1: float64(y1),
|
||||||
|
X2: float64(x2),
|
||||||
|
Y2: float64(y2),
|
||||||
|
Probability: float64(prob),
|
||||||
|
Class: label,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return boxes, nil
|
||||||
|
}
|