From f338c40b9373cf28b2bf70b1dde1c1db180b0607 Mon Sep 17 00:00:00 2001 From: kweijack Date: Tue, 19 Sep 2023 12:43:10 +0000 Subject: [PATCH] feat: extract config, added MinProbability, MaxIOU flag --- README.md | 4 ++++ class.go | 2 +- cmd/main.go | 22 +++++++++++++++++++--- yolo.go | 2 +- yolonas.go | 17 +++-------------- yolonasint8.go | 15 +++------------ yolov8.go | 17 +++-------------- 7 files changed, 34 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 4a77f52..7166693 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,10 @@ go run cmd/main.go --help Name of model being served (Required) (default "yolonas") -n int Number of benchmark run. (default 1) + -o float + Intersection over Union (IoU) (default 0.7) + -p float + Minimum probability (default 0.5) -t string Type of model. Available options: [yolonas, yolonasint8, yolov8] (default "yolonas") -u string diff --git a/class.go b/class.go index dd5712a..f80c8de 100644 --- a/class.go +++ b/class.go @@ -1,6 +1,6 @@ package yolotriton -var yoloClasses = []string{ +var YoloClasses = []string{ "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", diff --git a/cmd/main.go b/cmd/main.go index cecd15d..5bff165 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -14,6 +14,8 @@ type Flags struct { ModelName string ModelVersion string ModelType string + MinProbability float64 + MaxIOU float64 URL string Image string Benchmark bool @@ -25,6 +27,8 @@ func parseFlags() Flags { flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)") flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version") flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]") + flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability") + flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)") flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.") flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.") flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.") @@ -37,14 +41,26 @@ func main() { FLAGS := parseFlags() fmt.Println("FLAGS:", FLAGS) + cfg := yolotriton.YoloTritonConfig{ + ModelName: FLAGS.ModelName, + ModelVersion: FLAGS.ModelVersion, + MinProbability: float32(FLAGS.MinProbability), + MaxIOU: FLAGS.MaxIOU, + Classes: yolotriton.YoloClasses, + } + var model yolotriton.Model switch yolotriton.ModelType(FLAGS.ModelType) { case yolotriton.ModelTypeYoloV8: - model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion) + cfg.NumClasses = 80 + cfg.NumObjects = 8400 + model = yolotriton.NewYoloV8(cfg) case yolotriton.ModelTypeYoloNAS: - model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion) + cfg.NumClasses = 80 + cfg.NumObjects = 8400 + model = yolotriton.NewYoloNAS(cfg) case yolotriton.ModelTypeYoloNASInt8: - model = yolotriton.NewYoloNASInt8(FLAGS.ModelName, FLAGS.ModelVersion) + model = yolotriton.NewYoloNASInt8(cfg) default: log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType) } diff --git a/yolo.go b/yolo.go index 4011c7d..31987c5 100644 --- a/yolo.go +++ b/yolo.go @@ -22,7 +22,6 @@ type Model interface { GetConfig() YoloTritonConfig PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) PostProcess(rawOutputContents [][]byte) ([]Box, error) - GetClass(index int) string } type YoloTritonConfig struct { @@ -32,6 +31,7 @@ type YoloTritonConfig struct { ModelVersion string MinProbability float32 MaxIOU float64 + Classes []string } func New(url string, model Model) (*YoloTriton, error) { diff --git a/yolonas.go b/yolonas.go index 2ae16ce..711be5e 100644 --- a/yolonas.go +++ b/yolonas.go @@ -16,16 +16,9 @@ type YoloNAS struct { } } -func NewYoloNAS(modelName string, modelVersion string) Model { +func NewYoloNAS(cfg YoloTritonConfig) Model { return &YoloNAS{ - YoloTritonConfig: YoloTritonConfig{ - NumClasses: 80, - NumObjects: 8400, - MinProbability: 0.5, - MaxIOU: 0.7, - ModelName: modelName, - ModelVersion: modelVersion, - }, + YoloTritonConfig: cfg, } } @@ -35,10 +28,6 @@ func (y *YoloNAS) GetConfig() YoloTritonConfig { return y.YoloTritonConfig } -func (y *YoloNAS) GetClass(index int) string { - return yoloClasses[index] -} - func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { height := img.Bounds().Dy() width := img.Bounds().Dx() @@ -94,7 +83,7 @@ func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) { continue } - label := y.GetClass(classID) + label := y.Classes[classID] idx := (index * 4) x1raw := predBoxes[idx] y1raw := predBoxes[idx+1] diff --git a/yolonasint8.go b/yolonasint8.go index 13be622..7029173 100644 --- a/yolonasint8.go +++ b/yolonasint8.go @@ -16,14 +16,9 @@ type YoloNASInt8 struct { } } -func NewYoloNASInt8(modelName string, modelVersion string) Model { +func NewYoloNASInt8(cfg YoloTritonConfig) Model { return &YoloNASInt8{ - YoloTritonConfig: YoloTritonConfig{ - MinProbability: 0.5, - MaxIOU: 0.7, - ModelName: modelName, - ModelVersion: modelVersion, - }, + YoloTritonConfig: cfg, } } @@ -33,10 +28,6 @@ func (y *YoloNASInt8) GetConfig() YoloTritonConfig { return y.YoloTritonConfig } -func (y *YoloNASInt8) GetClass(index int) string { - return yoloClasses[index] -} - func (y *YoloNASInt8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { height := img.Bounds().Dy() width := img.Bounds().Dx() @@ -89,7 +80,7 @@ func (y *YoloNASInt8) PostProcess(rawOutputContents [][]byte) ([]Box, error) { } classID := predClasses[index] - label := y.GetClass(int(classID)) + label := y.Classes[classID] idx := (index * 4) x1raw := predBoxes[idx] y1raw := predBoxes[idx+1] diff --git a/yolov8.go b/yolov8.go index 3cf2116..bfdffc5 100644 --- a/yolov8.go +++ b/yolov8.go @@ -14,16 +14,9 @@ type YoloV8 struct { } } -func NewYoloV8(modelName string, modelVersion string) Model { +func NewYoloV8(cfg YoloTritonConfig) Model { return &YoloV8{ - YoloTritonConfig: YoloTritonConfig{ - NumClasses: 80, - NumObjects: 8400, - MinProbability: 0.5, - MaxIOU: 0.7, - ModelName: modelName, - ModelVersion: modelVersion, - }, + YoloTritonConfig: cfg, } } @@ -33,10 +26,6 @@ func (y *YoloV8) GetConfig() YoloTritonConfig { return y.YoloTritonConfig } -func (y *YoloV8) GetClass(index int) string { - return yoloClasses[index] -} - func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { width := img.Bounds().Dx() height := img.Bounds().Dy() @@ -81,7 +70,7 @@ func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) { continue } - label := y.GetClass(classID) + label := y.Classes[classID] x1raw := output[index] y1raw := output[numObjects+index] w := output[2*numObjects+index]