feat: added support for YOLO-NAS

This commit is contained in:
kweijack
2023-09-06 03:46:09 +00:00
parent cbcfdb9189
commit fea4ba6dc0
18 changed files with 394 additions and 138 deletions

View File

@@ -4,7 +4,7 @@
[![Go Report Card](https://goreportcard.com/badge/github.com/dev6699/yolotriton)](https://goreportcard.com/report/github.com/dev6699/yolotriton) [![Go Report Card](https://goreportcard.com/badge/github.com/dev6699/yolotriton)](https://goreportcard.com/report/github.com/dev6699/yolotriton)
[![License](https://img.shields.io/github/license/dev6699/yolotriton)](LICENSE) [![License](https://img.shields.io/github/license/dev6699/yolotriton)](LICENSE)
Go (Golang) gRPC client for YOLOv8 inference using the Triton Inference Server. Go (Golang) gRPC client for YOLO-NAS, YOLOv8 inference using the Triton Inference Server.
## Installation ## Installation
@@ -14,11 +14,12 @@ Use `go get` to install this package:
go get github.com/dev6699/yolotriton go get github.com/dev6699/yolotriton
``` ```
### Get YOLOv8 TensorRT model ### Get YOLO-NAS, YOLOv8 TensorRT model
Replace `yolov8m.pt` with your desired model
```bash ```bash
pip install ultralytics pip install ultralytics
yolo export model=yolov8m.pt format=onnx yolo export model=yolov8m.pt format=onnx
trtexec --onnx=yolov8m.onnx --saveEngine=model_repository/yolov8_tensorrt/1/model.plan trtexec --onnx=yolov8m.onnx --saveEngine=model_repository/yolov8/1/model.plan
``` ```
References: References:
@@ -39,20 +40,39 @@ Check [cmd/main.go](cmd/main.go) for more details.
Available args: Available args:
```bash ```bash
-i string -i string
Inference Image. Default: images/1.jpg (default "images/1.jpg") Inference Image. (default "images/1.jpg")
-m string -m string
Name of model being served. (Required) (default "yolov8_tensorrt") Name of model being served (Required) (default "yolonas")
-t string
Type of model. Available options: [yolonas, yolov8] (default "yolonas")
-u string -u string
Inference Server URL. Default: tritonserver:8001 (default "tritonserver:8001") Inference Server URL. (default "tritonserver:8001")
-x string -x string
Version of model. Default: Latest Version. Version of model. Default: Latest Version
``` ```
```bash ```bash
go run cmd/main.go go run cmd/main.go
``` ```
### Results ### Results
| Input | Ouput | ```
| --------------------------- | ------------------------------- | prediction: 0
| <img src="images/1.jpg" /> | <img src="images/1_out.jpg" /> | class: dog
| <img src="images/2.jpg" /> | <img src="images/2_out.jpg" /> | confidence: 0.96
bboxes: [ 669 130 1061 563 ]
---------------------
prediction: 1
class: person
confidence: 0.96
bboxes: [ 440 30 760 541 ]
---------------------
prediction: 2
class: dog
confidence: 0.93
bboxes: [ 168 83 495 592 ]
---------------------
```
| Input | YOLO-NAS Ouput | YOLOv8 Output |
| --------------------------- | --------------------------------------- | -------------------------------------- |
| <img src="images/1.jpg" /> | <img src="images/1_yolonas_out.jpg" /> | <img src="images/1_yolonas_out.jpg" /> |
| <img src="images/2.jpg" /> | <img src="images/2_yolonas_out.jpg" /> | <img src="images/2_yolonas_out.jpg" /> |

View File

@@ -12,16 +12,18 @@ import (
type Flags struct { type Flags struct {
ModelName string ModelName string
ModelVersion string ModelVersion string
ModelType string
URL string URL string
Image string Image string
} }
func parseFlags() Flags { func parseFlags() Flags {
var flags Flags var flags Flags
flag.StringVar(&flags.ModelName, "m", "yolov8_tensorrt", "Name of model being served. (Required)") flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version.") flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL. Default: tritonserver:8001") flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolov8]")
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image. Default: images/1.jpg") flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
flag.Parse() flag.Parse()
return flags return flags
} }
@@ -30,20 +32,17 @@ func main() {
FLAGS := parseFlags() FLAGS := parseFlags()
fmt.Println("FLAGS:", FLAGS) fmt.Println("FLAGS:", FLAGS)
ygt, err := yolotriton.New( var model yolotriton.Model
FLAGS.URL, switch yolotriton.ModelType(FLAGS.ModelType) {
yolotriton.YoloTritonConfig{ case yolotriton.ModelTypeYoloV8:
BatchSize: 1, model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion)
NumChannels: 84, case yolotriton.ModelTypeYoloNAS:
NumObjects: 8400, model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion)
Width: 640, default:
Height: 640, log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolov8]", FLAGS.ModelType)
ModelName: FLAGS.ModelName, }
ModelVersion: FLAGS.ModelVersion,
MinProbability: 0.5,
MaxIOU: 0.7,
})
ygt, err := yolotriton.New(FLAGS.URL, model)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -59,17 +58,24 @@ func main() {
} }
for i, r := range results { for i, r := range results {
fmt.Printf("---%d---", i) fmt.Println("prediction: ", i)
fmt.Println(r.Class, r.Probability) fmt.Println("class: ", r.Class)
fmt.Println("[x1,x2,y1,y2]", int(r.X1), int(r.X2), int(r.Y1), int(r.Y2)) fmt.Printf("confidence: %.2f\n", r.Probability)
fmt.Println("bboxes: [", int(r.X1), int(r.Y1), int(r.X2), int(r.Y2), "]")
fmt.Println("---------------------")
} }
out, err := yolotriton.DrawBoundingBoxes(img, results, 5) out, err := yolotriton.DrawBoundingBoxes(
img,
results,
int(float64(img.Bounds().Dx())*0.005),
float64(img.Bounds().Dx())*0.02,
)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
err = yolotriton.SaveImage(out, fmt.Sprintf("%s_out.jpg", strings.Split(FLAGS.Image, ".")[0])) err = yolotriton.SaveImage(out, fmt.Sprintf("%s_%s_out.jpg", strings.Split(FLAGS.Image, ".")[0], FLAGS.ModelName))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 96 KiB

After

Width:  |  Height:  |  Size: 88 KiB

BIN
images/1_yolonas_out.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

BIN
images/1_yolov8_out.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.3 MiB

BIN
images/2_yolonas_out.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

BIN
images/2_yolov8_out.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

View File

@@ -1 +1 @@
yolov8_tensorrt/1/model.plan model.plan

View File

@@ -1,2 +1,2 @@
name: "yolov8_tensorrt" name: "yolonas"
platform: "tensorrt_plan" platform: "tensorrt_plan"

View File

View File

@@ -0,0 +1,2 @@
name: "yolov8"
platform: "tensorrt_plan"

View File

@@ -3,30 +3,30 @@ package yolotriton
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"io"
"math" "math"
"sort"
) )
func (y *YoloTriton) bytesToFloat32Slice(data []byte) ([]float32, error) { func bytesToFloat32Slice(data []byte) ([]float32, error) {
t := []float32{} t := []float32{}
// Create a buffer from the input data // Create a buffer from the input data
buffer := bytes.NewReader(data) buffer := bytes.NewReader(data)
for i := 0; i < y.cfg.BatchSize; i++ { for {
for j := 0; j < y.cfg.NumChannels; j++ { // Read the binary data from the buffer
for k := 0; k < y.cfg.NumObjects; k++ { var binaryValue uint32
// Read the binary data from the buffer err := binary.Read(buffer, binary.LittleEndian, &binaryValue)
var binaryValue uint32 if err != nil {
err := binary.Read(buffer, binary.LittleEndian, &binaryValue) if err == io.EOF {
if err != nil { break
return nil, err
}
t = append(t, math.Float32frombits(binaryValue))
} }
return nil, err
} }
t = append(t, math.Float32frombits(binaryValue))
} }
return t, nil return t, nil
} }
@@ -39,62 +39,6 @@ type Box struct {
Class string Class string
} }
func (y *YoloTriton) parseOutput(output []float32, origImgWidth, origImgHeight int) []Box {
boxes := []Box{}
for index := 0; index < y.cfg.NumObjects; index++ {
classID := 0
prob := float32(0.0)
for col := 0; col < y.cfg.NumChannels-4; col++ {
if output[y.cfg.NumObjects*(col+4)+index] > prob {
prob = output[y.cfg.NumObjects*(col+4)+index]
classID = col
}
}
if prob < float32(y.cfg.MinProbability) {
continue
}
label := yoloClasses[classID]
xc := output[index]
yc := output[y.cfg.NumObjects+index]
w := output[2*y.cfg.NumObjects+index]
h := output[3*y.cfg.NumObjects+index]
x1 := (xc - w/2) / float32(y.cfg.Width) * float32(origImgWidth)
y1 := (yc - h/2) / float32(y.cfg.Height) * float32(origImgHeight)
x2 := (xc + w/2) / float32(y.cfg.Width) * float32(origImgWidth)
y2 := (yc + h/2) / float32(y.cfg.Height) * float32(origImgHeight)
boxes = append(boxes, Box{
X1: float64(x1),
Y1: float64(y1),
X2: float64(x2),
Y2: float64(y2),
Probability: float64(prob),
Class: label,
})
}
sort.Slice(boxes, func(i, j int) bool {
return boxes[i].Probability < boxes[j].Probability
})
result := []Box{}
for len(boxes) > 0 {
result = append(result, boxes[0])
tmp := []Box{}
for _, box := range boxes {
if iou(boxes[0], box) < y.cfg.MaxIOU {
tmp = append(tmp, box)
}
}
boxes = tmp
}
return result
}
func iou(box1, box2 Box) float64 { func iou(box1, box2 Box) float64 {
// Calculate the coordinates of the intersection rectangle // Calculate the coordinates of the intersection rectangle
intersectionX1 := math.Max(box1.X1, box2.X1) intersectionX1 := math.Max(box1.X1, box2.X1)

View File

@@ -44,7 +44,7 @@ func SaveImage(img image.Image, filename string) error {
return nil return nil
} }
func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int) (image.Image, error) { func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int, fontSize float64) (image.Image, error) {
// Create a new RGBA image to draw the bounding boxes and text labels on // Create a new RGBA image to draw the bounding boxes and text labels on
bounds := img.Bounds() bounds := img.Bounds()
@@ -62,7 +62,7 @@ func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int) (image.Image
return nil, err return nil, err
} }
face := truetype.NewFace(ttfFont, &truetype.Options{ face := truetype.NewFace(ttfFont, &truetype.Options{
Size: 36.0, Size: fontSize,
}) })
// Draw the bounding boxes and text labels on the destination image // Draw the bounding boxes and text labels on the destination image

108
yolo.go
View File

@@ -3,39 +3,53 @@ package yolotriton
import ( import (
"image" "image"
_ "image/png" _ "image/png"
"sort"
triton "github.com/dev6699/yolotriton/grpc-client" triton "github.com/dev6699/yolotriton/grpc-client"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
) )
type ModelType string
const (
ModelTypeYoloV8 ModelType = "yolov8"
ModelTypeYoloNAS ModelType = "yolonas"
)
type Model interface {
GetConfig() YoloTritonConfig
PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error)
PostProcess(rawOutputContents [][]byte) ([]Box, error)
}
type YoloTritonConfig struct { type YoloTritonConfig struct {
BatchSize int BatchSize int
NumChannels int NumChannels int
NumObjects int NumObjects int
Width int
Height int
ModelName string ModelName string
ModelVersion string ModelVersion string
MinProbability float64 MinProbability float32
MaxIOU float64 MaxIOU float64
} }
func New(url string, cfg YoloTritonConfig) (*YoloTriton, error) { func New(url string, model Model) (*YoloTriton, error) {
conn, err := grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials())) conn, err := grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &YoloTriton{ return &YoloTriton{
cfg: cfg, conn: conn,
conn: conn, model: model,
cfg: model.GetConfig(),
}, nil }, nil
} }
type YoloTriton struct { type YoloTriton struct {
cfg YoloTritonConfig cfg YoloTritonConfig
conn *grpc.ClientConn conn *grpc.ClientConn
model Model
} }
func (y *YoloTriton) Close() error { func (y *YoloTriton) Close() error {
@@ -44,32 +58,49 @@ func (y *YoloTriton) Close() error {
func (y *YoloTriton) Infer(img image.Image) ([]Box, error) { func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
preprocessedImg := resizeImage(img, uint(y.cfg.Width), uint(y.cfg.Height))
fp32Contents := imageToFloat32Slice(preprocessedImg)
client := triton.NewGRPCInferenceServiceClient(y.conn) client := triton.NewGRPCInferenceServiceClient(y.conn)
inferInputs := []*triton.ModelInferRequest_InferInputTensor{ metaResponse, err := ModelMetadataRequest(client, y.cfg.ModelName, y.cfg.ModelVersion)
{ if err != nil {
Name: "images", return nil, err
Datatype: "FP32",
Shape: []int64{int64(y.cfg.BatchSize), 3, int64(y.cfg.Width), int64(y.cfg.Height)},
Contents: &triton.InferTensorContents{
Fp32Contents: fp32Contents,
},
},
}
inferOutputs := []*triton.ModelInferRequest_InferRequestedOutputTensor{
{
Name: "output0",
},
} }
modelInferRequest := &triton.ModelInferRequest{ modelInferRequest := &triton.ModelInferRequest{
ModelName: y.cfg.ModelName, ModelName: y.cfg.ModelName,
ModelVersion: y.cfg.ModelVersion, ModelVersion: y.cfg.ModelVersion,
Inputs: inferInputs, }
Outputs: inferOutputs,
input := metaResponse.Inputs[0]
if input.Shape[0] == -1 {
input.Shape[0] = 1
}
inputWidth := input.Shape[2]
inputHeight := input.Shape[3]
fp32Contents, err := y.model.PreProcess(img, uint(inputWidth), uint(inputHeight))
if err != nil {
return nil, err
}
modelInferRequest.Inputs = append(modelInferRequest.Inputs,
&triton.ModelInferRequest_InferInputTensor{
Name: input.Name,
Datatype: input.Datatype,
Shape: input.Shape,
Contents: &triton.InferTensorContents{
// Simply assume all are fp32
Fp32Contents: fp32Contents,
},
},
)
for _, o := range metaResponse.Outputs {
modelInferRequest.Outputs = append(modelInferRequest.Outputs,
&triton.ModelInferRequest_InferRequestedOutputTensor{
Name: o.Name,
},
)
} }
inferResponse, err := ModelInferRequest(client, modelInferRequest) inferResponse, err := ModelInferRequest(client, modelInferRequest)
@@ -77,11 +108,26 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
return nil, err return nil, err
} }
t, err := y.bytesToFloat32Slice(inferResponse.RawOutputContents[0]) boxes, err := y.model.PostProcess(inferResponse.RawOutputContents)
if err != nil { if err != nil {
return nil, err return nil, err
} }
boxes := y.parseOutput(t, img.Bounds().Dx(), img.Bounds().Dy()) sort.Slice(boxes, func(i, j int) bool {
return boxes, nil return boxes[i].Probability > boxes[j].Probability
})
result := []Box{}
for len(boxes) > 0 {
result = append(result, boxes[0])
tmp := []Box{}
for _, box := range boxes {
if iou(boxes[0], box) < y.cfg.MaxIOU {
tmp = append(tmp, box)
}
}
boxes = tmp
}
return result, nil
} }

140
yolonas.go Normal file
View File

@@ -0,0 +1,140 @@
package yolotriton
import (
"image"
"image/color"
"image/draw"
"math"
)
type YoloNAS struct {
YoloTritonConfig
metadata struct {
xOffset float32
yOffset float32
scaleFactor float32
}
}
func NewYoloNAS(modelName string, modelVersion string) Model {
return &YoloNAS{
YoloTritonConfig: YoloTritonConfig{
BatchSize: 1,
NumChannels: 80,
NumObjects: 8400,
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
}
}
var _ Model = &YoloNAS{}
func (y *YoloNAS) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig
}
func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error) {
height := img.Bounds().Dy()
width := img.Bounds().Dx()
// https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/processing/processing.py#L547
scaleFactor := math.Min(float64(636)/float64(height), float64(636)/float64(width))
if scaleFactor != 1.0 {
newHeight := uint(math.Round(float64(height) * scaleFactor))
newWidth := uint(math.Round(float64(width) * scaleFactor))
img = resizeImage(img, newWidth, newHeight)
}
paddedImage, xOffset, yOffset := padImageToCenterWithGray(img, int(targetWidth), int(targetHeight), 114)
fp32Contents := imageToFloat32Slice(paddedImage)
y.metadata.xOffset = float32(xOffset)
y.metadata.yOffset = float32(yOffset)
y.metadata.scaleFactor = float32(scaleFactor)
return fp32Contents, nil
}
func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
predScores, err := bytesToFloat32Slice(rawOutputContents[0])
if err != nil {
return nil, err
}
predBoxes, err := bytesToFloat32Slice(rawOutputContents[1])
if err != nil {
return nil, err
}
boxes := []Box{}
for index := 0; index < y.NumObjects; index++ {
classID := 0
prob := float32(0.0)
for col := 0; col < y.NumChannels; col++ {
p := predScores[index*y.NumChannels+(col)]
if p > prob {
prob = p
classID = col
}
}
if prob < y.MinProbability {
continue
}
label := yoloClasses[classID]
i := (index * 4)
xc := predBoxes[i]
yc := predBoxes[i+1]
w := predBoxes[i+2]
h := predBoxes[i+3]
scale := y.metadata.scaleFactor
x1 := (xc - y.metadata.xOffset) / scale
y1 := (yc - y.metadata.yOffset) / scale
x2 := (w - y.metadata.xOffset) / scale
y2 := (h - y.metadata.yOffset) / scale
boxes = append(boxes, Box{
X1: float64(x1),
Y1: float64(y1),
X2: float64(x2),
Y2: float64(y2),
Probability: float64(prob),
Class: label,
})
}
return boxes, nil
}
func padImageToCenterWithGray(originalImage image.Image, targetWidth, targetHeight int, grayValue uint8) (image.Image, int, int) {
// Calculate the dimensions of the original image
originalWidth := originalImage.Bounds().Dx()
originalHeight := originalImage.Bounds().Dy()
// Calculate the padding dimensions
padWidth := targetWidth - originalWidth
padHeight := targetHeight - originalHeight
// Create a new RGBA image with the desired dimensions and fill it with gray
paddedImage := image.NewRGBA(image.Rect(0, 0, targetWidth, targetHeight))
grayColor := color.RGBA{grayValue, grayValue, grayValue, 255}
draw.Draw(paddedImage, paddedImage.Bounds(), &image.Uniform{grayColor}, image.Point{}, draw.Src)
// Calculate the position to paste the original image in the center
xOffset := int(math.Floor(float64(padWidth) / 2))
yOffset := int(math.Floor(float64(padHeight) / 2))
// Paste the original image onto the padded image
pasteRect := image.Rect(xOffset, yOffset, xOffset+originalWidth, yOffset+originalHeight)
draw.Draw(paddedImage, pasteRect, originalImage, image.Point{}, draw.Over)
return paddedImage, xOffset, yOffset
}

98
yolov8.go Normal file
View File

@@ -0,0 +1,98 @@
package yolotriton
import (
"image"
)
type YoloV8 struct {
YoloTritonConfig
metadata struct {
scaleFactorW float32
scaleFactorH float32
}
}
func NewYoloV8(modelName string, modelVersion string) Model {
return &YoloV8{
YoloTritonConfig: YoloTritonConfig{
BatchSize: 1,
NumChannels: 84,
NumObjects: 8400,
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
}
}
var _ Model = &YoloV8{}
func (y *YoloV8) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig
}
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) ([]float32, error) {
width := img.Bounds().Dx()
height := img.Bounds().Dy()
preprocessedImg := resizeImage(img, targetWidth, targetHeight)
fp32Contents := imageToFloat32Slice(preprocessedImg)
y.metadata.scaleFactorW = float32(width) / float32(targetWidth)
y.metadata.scaleFactorH = float32(height) / float32(targetHeight)
return fp32Contents, nil
}
func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
output, err := bytesToFloat32Slice(rawOutputContents[0])
if err != nil {
return nil, err
}
numObjects := y.NumObjects
numChannels := y.NumChannels
boxes := []Box{}
for index := 0; index < numObjects; index++ {
classID := 0
prob := float32(0.0)
for col := 0; col < numChannels-4; col++ {
p := output[numObjects*(col+4)+index]
if p > prob {
prob = p
classID = col
}
}
if prob < y.MinProbability {
continue
}
label := yoloClasses[classID]
xc := output[index]
yc := output[numObjects+index]
w := output[2*numObjects+index]
h := output[3*numObjects+index]
x1 := (xc - w/2) * y.metadata.scaleFactorW
y1 := (yc - h/2) * y.metadata.scaleFactorH
x2 := (xc + w/2) * y.metadata.scaleFactorW
y2 := (yc + h/2) * y.metadata.scaleFactorH
boxes = append(boxes, Box{
X1: float64(x1),
Y1: float64(y1),
X2: float64(x2),
Y2: float64(y2),
Probability: float64(prob),
Class: label,
})
}
return boxes, nil
}