feat: extract config, added MinProbability, MaxIOU flag

This commit is contained in:
kweijack
2023-09-19 12:43:10 +00:00
parent d56aaec916
commit f338c40b93
7 changed files with 34 additions and 45 deletions

View File

@@ -77,6 +77,10 @@ go run cmd/main.go --help
Name of model being served (Required) (default "yolonas") Name of model being served (Required) (default "yolonas")
-n int -n int
Number of benchmark run. (default 1) Number of benchmark run. (default 1)
-o float
Intersection over Union (IoU) (default 0.7)
-p float
Minimum probability (default 0.5)
-t string -t string
Type of model. Available options: [yolonas, yolonasint8, yolov8] (default "yolonas") Type of model. Available options: [yolonas, yolonasint8, yolov8] (default "yolonas")
-u string -u string

View File

@@ -1,6 +1,6 @@
package yolotriton package yolotriton
var yoloClasses = []string{ var YoloClasses = []string{
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",

View File

@@ -14,6 +14,8 @@ type Flags struct {
ModelName string ModelName string
ModelVersion string ModelVersion string
ModelType string ModelType string
MinProbability float64
MaxIOU float64
URL string URL string
Image string Image string
Benchmark bool Benchmark bool
@@ -25,6 +27,8 @@ func parseFlags() Flags {
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)") flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version") flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]") flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.") flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.") flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.") flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.")
@@ -37,14 +41,26 @@ func main() {
FLAGS := parseFlags() FLAGS := parseFlags()
fmt.Println("FLAGS:", FLAGS) fmt.Println("FLAGS:", FLAGS)
cfg := yolotriton.YoloTritonConfig{
ModelName: FLAGS.ModelName,
ModelVersion: FLAGS.ModelVersion,
MinProbability: float32(FLAGS.MinProbability),
MaxIOU: FLAGS.MaxIOU,
Classes: yolotriton.YoloClasses,
}
var model yolotriton.Model var model yolotriton.Model
switch yolotriton.ModelType(FLAGS.ModelType) { switch yolotriton.ModelType(FLAGS.ModelType) {
case yolotriton.ModelTypeYoloV8: case yolotriton.ModelTypeYoloV8:
model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion) cfg.NumClasses = 80
cfg.NumObjects = 8400
model = yolotriton.NewYoloV8(cfg)
case yolotriton.ModelTypeYoloNAS: case yolotriton.ModelTypeYoloNAS:
model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion) cfg.NumClasses = 80
cfg.NumObjects = 8400
model = yolotriton.NewYoloNAS(cfg)
case yolotriton.ModelTypeYoloNASInt8: case yolotriton.ModelTypeYoloNASInt8:
model = yolotriton.NewYoloNASInt8(FLAGS.ModelName, FLAGS.ModelVersion) model = yolotriton.NewYoloNASInt8(cfg)
default: default:
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType) log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
} }

View File

@@ -22,7 +22,6 @@ type Model interface {
GetConfig() YoloTritonConfig GetConfig() YoloTritonConfig
PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error)
PostProcess(rawOutputContents [][]byte) ([]Box, error) PostProcess(rawOutputContents [][]byte) ([]Box, error)
GetClass(index int) string
} }
type YoloTritonConfig struct { type YoloTritonConfig struct {
@@ -32,6 +31,7 @@ type YoloTritonConfig struct {
ModelVersion string ModelVersion string
MinProbability float32 MinProbability float32
MaxIOU float64 MaxIOU float64
Classes []string
} }
func New(url string, model Model) (*YoloTriton, error) { func New(url string, model Model) (*YoloTriton, error) {

View File

@@ -16,16 +16,9 @@ type YoloNAS struct {
} }
} }
func NewYoloNAS(modelName string, modelVersion string) Model { func NewYoloNAS(cfg YoloTritonConfig) Model {
return &YoloNAS{ return &YoloNAS{
YoloTritonConfig: YoloTritonConfig{ YoloTritonConfig: cfg,
NumClasses: 80,
NumObjects: 8400,
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
} }
} }
@@ -35,10 +28,6 @@ func (y *YoloNAS) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig return y.YoloTritonConfig
} }
func (y *YoloNAS) GetClass(index int) string {
return yoloClasses[index]
}
func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
height := img.Bounds().Dy() height := img.Bounds().Dy()
width := img.Bounds().Dx() width := img.Bounds().Dx()
@@ -94,7 +83,7 @@ func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
continue continue
} }
label := y.GetClass(classID) label := y.Classes[classID]
idx := (index * 4) idx := (index * 4)
x1raw := predBoxes[idx] x1raw := predBoxes[idx]
y1raw := predBoxes[idx+1] y1raw := predBoxes[idx+1]

View File

@@ -16,14 +16,9 @@ type YoloNASInt8 struct {
} }
} }
func NewYoloNASInt8(modelName string, modelVersion string) Model { func NewYoloNASInt8(cfg YoloTritonConfig) Model {
return &YoloNASInt8{ return &YoloNASInt8{
YoloTritonConfig: YoloTritonConfig{ YoloTritonConfig: cfg,
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
} }
} }
@@ -33,10 +28,6 @@ func (y *YoloNASInt8) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig return y.YoloTritonConfig
} }
func (y *YoloNASInt8) GetClass(index int) string {
return yoloClasses[index]
}
func (y *YoloNASInt8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { func (y *YoloNASInt8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
height := img.Bounds().Dy() height := img.Bounds().Dy()
width := img.Bounds().Dx() width := img.Bounds().Dx()
@@ -89,7 +80,7 @@ func (y *YoloNASInt8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
} }
classID := predClasses[index] classID := predClasses[index]
label := y.GetClass(int(classID)) label := y.Classes[classID]
idx := (index * 4) idx := (index * 4)
x1raw := predBoxes[idx] x1raw := predBoxes[idx]
y1raw := predBoxes[idx+1] y1raw := predBoxes[idx+1]

View File

@@ -14,16 +14,9 @@ type YoloV8 struct {
} }
} }
func NewYoloV8(modelName string, modelVersion string) Model { func NewYoloV8(cfg YoloTritonConfig) Model {
return &YoloV8{ return &YoloV8{
YoloTritonConfig: YoloTritonConfig{ YoloTritonConfig: cfg,
NumClasses: 80,
NumObjects: 8400,
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
} }
} }
@@ -33,10 +26,6 @@ func (y *YoloV8) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig return y.YoloTritonConfig
} }
func (y *YoloV8) GetClass(index int) string {
return yoloClasses[index]
}
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
width := img.Bounds().Dx() width := img.Bounds().Dx()
height := img.Bounds().Dy() height := img.Bounds().Dy()
@@ -81,7 +70,7 @@ func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
continue continue
} }
label := y.GetClass(classID) label := y.Classes[classID]
x1raw := output[index] x1raw := output[index]
y1raw := output[numObjects+index] y1raw := output[numObjects+index]
w := output[2*numObjects+index] w := output[2*numObjects+index]