mirror of
https://github.com/dev6699/yolotriton.git
synced 2025-09-26 19:51:13 +08:00
feat: extract config, added MinProbability, MaxIOU flag
This commit is contained in:
@@ -77,6 +77,10 @@ go run cmd/main.go --help
|
|||||||
Name of model being served (Required) (default "yolonas")
|
Name of model being served (Required) (default "yolonas")
|
||||||
-n int
|
-n int
|
||||||
Number of benchmark run. (default 1)
|
Number of benchmark run. (default 1)
|
||||||
|
-o float
|
||||||
|
Intersection over Union (IoU) (default 0.7)
|
||||||
|
-p float
|
||||||
|
Minimum probability (default 0.5)
|
||||||
-t string
|
-t string
|
||||||
Type of model. Available options: [yolonas, yolonasint8, yolov8] (default "yolonas")
|
Type of model. Available options: [yolonas, yolonasint8, yolov8] (default "yolonas")
|
||||||
-u string
|
-u string
|
||||||
|
2
class.go
2
class.go
@@ -1,6 +1,6 @@
|
|||||||
package yolotriton
|
package yolotriton
|
||||||
|
|
||||||
var yoloClasses = []string{
|
var YoloClasses = []string{
|
||||||
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
||||||
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
|
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
|
||||||
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
|
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
|
||||||
|
22
cmd/main.go
22
cmd/main.go
@@ -14,6 +14,8 @@ type Flags struct {
|
|||||||
ModelName string
|
ModelName string
|
||||||
ModelVersion string
|
ModelVersion string
|
||||||
ModelType string
|
ModelType string
|
||||||
|
MinProbability float64
|
||||||
|
MaxIOU float64
|
||||||
URL string
|
URL string
|
||||||
Image string
|
Image string
|
||||||
Benchmark bool
|
Benchmark bool
|
||||||
@@ -25,6 +27,8 @@ func parseFlags() Flags {
|
|||||||
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
||||||
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
||||||
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
|
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
|
||||||
|
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
|
||||||
|
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
|
||||||
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
||||||
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
|
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
|
||||||
flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.")
|
flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.")
|
||||||
@@ -37,14 +41,26 @@ func main() {
|
|||||||
FLAGS := parseFlags()
|
FLAGS := parseFlags()
|
||||||
fmt.Println("FLAGS:", FLAGS)
|
fmt.Println("FLAGS:", FLAGS)
|
||||||
|
|
||||||
|
cfg := yolotriton.YoloTritonConfig{
|
||||||
|
ModelName: FLAGS.ModelName,
|
||||||
|
ModelVersion: FLAGS.ModelVersion,
|
||||||
|
MinProbability: float32(FLAGS.MinProbability),
|
||||||
|
MaxIOU: FLAGS.MaxIOU,
|
||||||
|
Classes: yolotriton.YoloClasses,
|
||||||
|
}
|
||||||
|
|
||||||
var model yolotriton.Model
|
var model yolotriton.Model
|
||||||
switch yolotriton.ModelType(FLAGS.ModelType) {
|
switch yolotriton.ModelType(FLAGS.ModelType) {
|
||||||
case yolotriton.ModelTypeYoloV8:
|
case yolotriton.ModelTypeYoloV8:
|
||||||
model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion)
|
cfg.NumClasses = 80
|
||||||
|
cfg.NumObjects = 8400
|
||||||
|
model = yolotriton.NewYoloV8(cfg)
|
||||||
case yolotriton.ModelTypeYoloNAS:
|
case yolotriton.ModelTypeYoloNAS:
|
||||||
model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion)
|
cfg.NumClasses = 80
|
||||||
|
cfg.NumObjects = 8400
|
||||||
|
model = yolotriton.NewYoloNAS(cfg)
|
||||||
case yolotriton.ModelTypeYoloNASInt8:
|
case yolotriton.ModelTypeYoloNASInt8:
|
||||||
model = yolotriton.NewYoloNASInt8(FLAGS.ModelName, FLAGS.ModelVersion)
|
model = yolotriton.NewYoloNASInt8(cfg)
|
||||||
default:
|
default:
|
||||||
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
|
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
|
||||||
}
|
}
|
||||||
|
2
yolo.go
2
yolo.go
@@ -22,7 +22,6 @@ type Model interface {
|
|||||||
GetConfig() YoloTritonConfig
|
GetConfig() YoloTritonConfig
|
||||||
PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error)
|
PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error)
|
||||||
PostProcess(rawOutputContents [][]byte) ([]Box, error)
|
PostProcess(rawOutputContents [][]byte) ([]Box, error)
|
||||||
GetClass(index int) string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type YoloTritonConfig struct {
|
type YoloTritonConfig struct {
|
||||||
@@ -32,6 +31,7 @@ type YoloTritonConfig struct {
|
|||||||
ModelVersion string
|
ModelVersion string
|
||||||
MinProbability float32
|
MinProbability float32
|
||||||
MaxIOU float64
|
MaxIOU float64
|
||||||
|
Classes []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(url string, model Model) (*YoloTriton, error) {
|
func New(url string, model Model) (*YoloTriton, error) {
|
||||||
|
17
yolonas.go
17
yolonas.go
@@ -16,16 +16,9 @@ type YoloNAS struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewYoloNAS(modelName string, modelVersion string) Model {
|
func NewYoloNAS(cfg YoloTritonConfig) Model {
|
||||||
return &YoloNAS{
|
return &YoloNAS{
|
||||||
YoloTritonConfig: YoloTritonConfig{
|
YoloTritonConfig: cfg,
|
||||||
NumClasses: 80,
|
|
||||||
NumObjects: 8400,
|
|
||||||
MinProbability: 0.5,
|
|
||||||
MaxIOU: 0.7,
|
|
||||||
ModelName: modelName,
|
|
||||||
ModelVersion: modelVersion,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,10 +28,6 @@ func (y *YoloNAS) GetConfig() YoloTritonConfig {
|
|||||||
return y.YoloTritonConfig
|
return y.YoloTritonConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (y *YoloNAS) GetClass(index int) string {
|
|
||||||
return yoloClasses[index]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||||
height := img.Bounds().Dy()
|
height := img.Bounds().Dy()
|
||||||
width := img.Bounds().Dx()
|
width := img.Bounds().Dx()
|
||||||
@@ -94,7 +83,7 @@ func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
label := y.GetClass(classID)
|
label := y.Classes[classID]
|
||||||
idx := (index * 4)
|
idx := (index * 4)
|
||||||
x1raw := predBoxes[idx]
|
x1raw := predBoxes[idx]
|
||||||
y1raw := predBoxes[idx+1]
|
y1raw := predBoxes[idx+1]
|
||||||
|
@@ -16,14 +16,9 @@ type YoloNASInt8 struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewYoloNASInt8(modelName string, modelVersion string) Model {
|
func NewYoloNASInt8(cfg YoloTritonConfig) Model {
|
||||||
return &YoloNASInt8{
|
return &YoloNASInt8{
|
||||||
YoloTritonConfig: YoloTritonConfig{
|
YoloTritonConfig: cfg,
|
||||||
MinProbability: 0.5,
|
|
||||||
MaxIOU: 0.7,
|
|
||||||
ModelName: modelName,
|
|
||||||
ModelVersion: modelVersion,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,10 +28,6 @@ func (y *YoloNASInt8) GetConfig() YoloTritonConfig {
|
|||||||
return y.YoloTritonConfig
|
return y.YoloTritonConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (y *YoloNASInt8) GetClass(index int) string {
|
|
||||||
return yoloClasses[index]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (y *YoloNASInt8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
func (y *YoloNASInt8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||||
height := img.Bounds().Dy()
|
height := img.Bounds().Dy()
|
||||||
width := img.Bounds().Dx()
|
width := img.Bounds().Dx()
|
||||||
@@ -89,7 +80,7 @@ func (y *YoloNASInt8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
classID := predClasses[index]
|
classID := predClasses[index]
|
||||||
label := y.GetClass(int(classID))
|
label := y.Classes[classID]
|
||||||
idx := (index * 4)
|
idx := (index * 4)
|
||||||
x1raw := predBoxes[idx]
|
x1raw := predBoxes[idx]
|
||||||
y1raw := predBoxes[idx+1]
|
y1raw := predBoxes[idx+1]
|
||||||
|
17
yolov8.go
17
yolov8.go
@@ -14,16 +14,9 @@ type YoloV8 struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewYoloV8(modelName string, modelVersion string) Model {
|
func NewYoloV8(cfg YoloTritonConfig) Model {
|
||||||
return &YoloV8{
|
return &YoloV8{
|
||||||
YoloTritonConfig: YoloTritonConfig{
|
YoloTritonConfig: cfg,
|
||||||
NumClasses: 80,
|
|
||||||
NumObjects: 8400,
|
|
||||||
MinProbability: 0.5,
|
|
||||||
MaxIOU: 0.7,
|
|
||||||
ModelName: modelName,
|
|
||||||
ModelVersion: modelVersion,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,10 +26,6 @@ func (y *YoloV8) GetConfig() YoloTritonConfig {
|
|||||||
return y.YoloTritonConfig
|
return y.YoloTritonConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (y *YoloV8) GetClass(index int) string {
|
|
||||||
return yoloClasses[index]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||||
width := img.Bounds().Dx()
|
width := img.Bounds().Dx()
|
||||||
height := img.Bounds().Dy()
|
height := img.Bounds().Dy()
|
||||||
@@ -81,7 +70,7 @@ func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
label := y.GetClass(classID)
|
label := y.Classes[classID]
|
||||||
x1raw := output[index]
|
x1raw := output[index]
|
||||||
y1raw := output[numObjects+index]
|
y1raw := output[numObjects+index]
|
||||||
w := output[2*numObjects+index]
|
w := output[2*numObjects+index]
|
||||||
|
Reference in New Issue
Block a user