mirror of
https://github.com/dev6699/yolotriton.git
synced 2025-09-27 03:56:07 +08:00
83 lines
2.0 KiB
Go
83 lines
2.0 KiB
Go
package main
|
|
|
|
import (
	"flag"
	"fmt"
	"log"
	"path/filepath"
	"strings"

	"github.com/dev6699/yolotriton"
)
|
|
|
|
// Flags holds the command-line options filled in by parseFlags.
//
// Fields, in declaration order, with the flag that sets each one:
//   - ModelName:    name of the model as served by the inference server (-m).
//   - ModelVersion: model version; left empty to use the latest version,
//     per the -x help text.
//   - ModelType:    model architecture selector, "yolonas" or "yolov8" (-t).
//   - URL:          address of the inference server (-u).
//   - Image:        path of the image to run inference on (-i).
type Flags struct {
	ModelName, ModelVersion, ModelType, URL, Image string
}
|
|
|
|
func parseFlags() Flags {
|
|
var flags Flags
|
|
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
|
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
|
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolov8]")
|
|
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
|
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
|
|
flag.Parse()
|
|
return flags
|
|
}
|
|
|
|
func main() {
|
|
FLAGS := parseFlags()
|
|
fmt.Println("FLAGS:", FLAGS)
|
|
|
|
var model yolotriton.Model
|
|
switch yolotriton.ModelType(FLAGS.ModelType) {
|
|
case yolotriton.ModelTypeYoloV8:
|
|
model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion)
|
|
case yolotriton.ModelTypeYoloNAS:
|
|
model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion)
|
|
default:
|
|
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolov8]", FLAGS.ModelType)
|
|
}
|
|
|
|
ygt, err := yolotriton.New(FLAGS.URL, model)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
img, err := yolotriton.LoadImage(FLAGS.Image)
|
|
if err != nil {
|
|
log.Fatalf("Failed to preprocess image: %v", err)
|
|
}
|
|
|
|
results, err := ygt.Infer(img)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
for i, r := range results {
|
|
fmt.Println("prediction: ", i)
|
|
fmt.Println("class: ", r.Class)
|
|
fmt.Printf("confidence: %.2f\n", r.Probability)
|
|
fmt.Println("bboxes: [", int(r.X1), int(r.Y1), int(r.X2), int(r.Y2), "]")
|
|
fmt.Println("---------------------")
|
|
}
|
|
|
|
out, err := yolotriton.DrawBoundingBoxes(
|
|
img,
|
|
results,
|
|
int(float64(img.Bounds().Dx())*0.005),
|
|
float64(img.Bounds().Dx())*0.02,
|
|
)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
err = yolotriton.SaveImage(out, fmt.Sprintf("%s_%s_out.jpg", strings.Split(FLAGS.Image, ".")[0], FLAGS.ModelName))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|