mirror of
https://github.com/dev6699/yolotriton.git
synced 2025-09-26 19:51:13 +08:00
feat: extract config, added MinProbability, MaxIOU flag
This commit is contained in:
@@ -77,6 +77,10 @@ go run cmd/main.go --help
|
||||
Name of model being served (Required) (default "yolonas")
|
||||
-n int
|
||||
Number of benchmark run. (default 1)
|
||||
-o float
|
||||
Intersection over Union (IoU) (default 0.7)
|
||||
-p float
|
||||
Minimum probability (default 0.5)
|
||||
-t string
|
||||
Type of model. Available options: [yolonas, yolonasint8, yolov8] (default "yolonas")
|
||||
-u string
|
||||
|
2
class.go
2
class.go
@@ -1,6 +1,6 @@
|
||||
package yolotriton
|
||||
|
||||
var yoloClasses = []string{
|
||||
var YoloClasses = []string{
|
||||
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
||||
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
|
||||
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
|
||||
|
22
cmd/main.go
22
cmd/main.go
@@ -14,6 +14,8 @@ type Flags struct {
|
||||
ModelName string
|
||||
ModelVersion string
|
||||
ModelType string
|
||||
MinProbability float64
|
||||
MaxIOU float64
|
||||
URL string
|
||||
Image string
|
||||
Benchmark bool
|
||||
@@ -25,6 +27,8 @@ func parseFlags() Flags {
|
||||
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
||||
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
||||
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
|
||||
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
|
||||
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
|
||||
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
||||
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
|
||||
flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.")
|
||||
@@ -37,14 +41,26 @@ func main() {
|
||||
FLAGS := parseFlags()
|
||||
fmt.Println("FLAGS:", FLAGS)
|
||||
|
||||
cfg := yolotriton.YoloTritonConfig{
|
||||
ModelName: FLAGS.ModelName,
|
||||
ModelVersion: FLAGS.ModelVersion,
|
||||
MinProbability: float32(FLAGS.MinProbability),
|
||||
MaxIOU: FLAGS.MaxIOU,
|
||||
Classes: yolotriton.YoloClasses,
|
||||
}
|
||||
|
||||
var model yolotriton.Model
|
||||
switch yolotriton.ModelType(FLAGS.ModelType) {
|
||||
case yolotriton.ModelTypeYoloV8:
|
||||
model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion)
|
||||
cfg.NumClasses = 80
|
||||
cfg.NumObjects = 8400
|
||||
model = yolotriton.NewYoloV8(cfg)
|
||||
case yolotriton.ModelTypeYoloNAS:
|
||||
model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion)
|
||||
cfg.NumClasses = 80
|
||||
cfg.NumObjects = 8400
|
||||
model = yolotriton.NewYoloNAS(cfg)
|
||||
case yolotriton.ModelTypeYoloNASInt8:
|
||||
model = yolotriton.NewYoloNASInt8(FLAGS.ModelName, FLAGS.ModelVersion)
|
||||
model = yolotriton.NewYoloNASInt8(cfg)
|
||||
default:
|
||||
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
|
||||
}
|
||||
|
2
yolo.go
2
yolo.go
@@ -22,7 +22,6 @@ type Model interface {
|
||||
GetConfig() YoloTritonConfig
|
||||
PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error)
|
||||
PostProcess(rawOutputContents [][]byte) ([]Box, error)
|
||||
GetClass(index int) string
|
||||
}
|
||||
|
||||
type YoloTritonConfig struct {
|
||||
@@ -32,6 +31,7 @@ type YoloTritonConfig struct {
|
||||
ModelVersion string
|
||||
MinProbability float32
|
||||
MaxIOU float64
|
||||
Classes []string
|
||||
}
|
||||
|
||||
func New(url string, model Model) (*YoloTriton, error) {
|
||||
|
17
yolonas.go
17
yolonas.go
@@ -16,16 +16,9 @@ type YoloNAS struct {
|
||||
}
|
||||
}
|
||||
|
||||
func NewYoloNAS(modelName string, modelVersion string) Model {
|
||||
func NewYoloNAS(cfg YoloTritonConfig) Model {
|
||||
return &YoloNAS{
|
||||
YoloTritonConfig: YoloTritonConfig{
|
||||
NumClasses: 80,
|
||||
NumObjects: 8400,
|
||||
MinProbability: 0.5,
|
||||
MaxIOU: 0.7,
|
||||
ModelName: modelName,
|
||||
ModelVersion: modelVersion,
|
||||
},
|
||||
YoloTritonConfig: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,10 +28,6 @@ func (y *YoloNAS) GetConfig() YoloTritonConfig {
|
||||
return y.YoloTritonConfig
|
||||
}
|
||||
|
||||
func (y *YoloNAS) GetClass(index int) string {
|
||||
return yoloClasses[index]
|
||||
}
|
||||
|
||||
func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||
height := img.Bounds().Dy()
|
||||
width := img.Bounds().Dx()
|
||||
@@ -94,7 +83,7 @@ func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
label := y.GetClass(classID)
|
||||
label := y.Classes[classID]
|
||||
idx := (index * 4)
|
||||
x1raw := predBoxes[idx]
|
||||
y1raw := predBoxes[idx+1]
|
||||
|
@@ -16,14 +16,9 @@ type YoloNASInt8 struct {
|
||||
}
|
||||
}
|
||||
|
||||
func NewYoloNASInt8(modelName string, modelVersion string) Model {
|
||||
func NewYoloNASInt8(cfg YoloTritonConfig) Model {
|
||||
return &YoloNASInt8{
|
||||
YoloTritonConfig: YoloTritonConfig{
|
||||
MinProbability: 0.5,
|
||||
MaxIOU: 0.7,
|
||||
ModelName: modelName,
|
||||
ModelVersion: modelVersion,
|
||||
},
|
||||
YoloTritonConfig: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,10 +28,6 @@ func (y *YoloNASInt8) GetConfig() YoloTritonConfig {
|
||||
return y.YoloTritonConfig
|
||||
}
|
||||
|
||||
func (y *YoloNASInt8) GetClass(index int) string {
|
||||
return yoloClasses[index]
|
||||
}
|
||||
|
||||
func (y *YoloNASInt8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||
height := img.Bounds().Dy()
|
||||
width := img.Bounds().Dx()
|
||||
@@ -89,7 +80,7 @@ func (y *YoloNASInt8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||
}
|
||||
|
||||
classID := predClasses[index]
|
||||
label := y.GetClass(int(classID))
|
||||
label := y.Classes[classID]
|
||||
idx := (index * 4)
|
||||
x1raw := predBoxes[idx]
|
||||
y1raw := predBoxes[idx+1]
|
||||
|
17
yolov8.go
17
yolov8.go
@@ -14,16 +14,9 @@ type YoloV8 struct {
|
||||
}
|
||||
}
|
||||
|
||||
func NewYoloV8(modelName string, modelVersion string) Model {
|
||||
func NewYoloV8(cfg YoloTritonConfig) Model {
|
||||
return &YoloV8{
|
||||
YoloTritonConfig: YoloTritonConfig{
|
||||
NumClasses: 80,
|
||||
NumObjects: 8400,
|
||||
MinProbability: 0.5,
|
||||
MaxIOU: 0.7,
|
||||
ModelName: modelName,
|
||||
ModelVersion: modelVersion,
|
||||
},
|
||||
YoloTritonConfig: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,10 +26,6 @@ func (y *YoloV8) GetConfig() YoloTritonConfig {
|
||||
return y.YoloTritonConfig
|
||||
}
|
||||
|
||||
func (y *YoloV8) GetClass(index int) string {
|
||||
return yoloClasses[index]
|
||||
}
|
||||
|
||||
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||
width := img.Bounds().Dx()
|
||||
height := img.Bounds().Dy()
|
||||
@@ -81,7 +70,7 @@ func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
label := y.GetClass(classID)
|
||||
label := y.Classes[classID]
|
||||
x1raw := output[index]
|
||||
y1raw := output[numObjects+index]
|
||||
w := output[2*numObjects+index]
|
||||
|
Reference in New Issue
Block a user