diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index e3c31e8..7333f01 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -16,7 +16,7 @@ "moby": "true" }, "ghcr.io/devcontainers/features/go:1": { - "version": "1.19" + "version": "1.25" } }, "privileged": true, diff --git a/cmd/main.go b/cmd/main.go index 5bff165..3e36cb6 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -26,7 +26,7 @@ func parseFlags() Flags { var flags Flags flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)") flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version") - flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]") + flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolofp16, yolofp32]") flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability") flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)") flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.") @@ -51,16 +51,24 @@ func main() { var model yolotriton.Model switch yolotriton.ModelType(FLAGS.ModelType) { - case yolotriton.ModelTypeYoloV8: + case yolotriton.ModelTypeYoloFP16: cfg.NumClasses = 80 cfg.NumObjects = 8400 - model = yolotriton.NewYoloV8(cfg) + model = yolotriton.NewYolo(cfg, true) + + case yolotriton.ModelTypeYoloFP32: + cfg.NumClasses = 80 + cfg.NumObjects = 8400 + model = yolotriton.NewYolo(cfg, false) + case yolotriton.ModelTypeYoloNAS: cfg.NumClasses = 80 cfg.NumObjects = 8400 model = yolotriton.NewYoloNAS(cfg) + case yolotriton.ModelTypeYoloNASInt8: model = yolotriton.NewYoloNASInt8(cfg) + default: log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType) } diff --git a/go.mod b/go.mod index 8dd7694..40f1160 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/dev6699/yolotriton -go 1.19 +go 1.25 require ( github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 @@ -12,6 +12,7 @@ require ( require ( github.com/golang/protobuf v1.5.3 // indirect + github.com/x448/float16 v0.8.4 // indirect golang.org/x/net v0.14.0 // indirect golang.org/x/sys v0.11.0 // indirect golang.org/x/text v0.12.0 // indirect diff --git a/go.sum b/go.sum index 30c72c2..2cc2f0e 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/model_repository/yolov8/1/.gitkeep b/model_repository/yolov12/1/.gitkeep similarity index 100% rename from model_repository/yolov8/1/.gitkeep rename to model_repository/yolov12/1/.gitkeep diff --git a/model_repository/yolov8/config.pbtxt b/model_repository/yolov12/config.pbtxt similarity index 60% rename from model_repository/yolov8/config.pbtxt rename to model_repository/yolov12/config.pbtxt index fe5f2e4..925aa5e 100644 --- a/model_repository/yolov8/config.pbtxt +++ b/model_repository/yolov12/config.pbtxt @@ -1,2 +1,2 @@ -name: "yolov8" +name: "yolov12" platform: "tensorrt_plan" \ No newline at end of file diff --git a/postprocess.go b/postprocess.go index 07d2809..b774b5a 100644 --- a/postprocess.go +++ b/postprocess.go @@ -3,10 +3,29 @@ package yolotriton import ( "bytes" "encoding/binary" + "errors" "io" "math" + + "github.com/x448/float16" ) +func bytesFP16ToFloat32Slice(data []byte) ([]float32, error) { + if len(data)%2 != 0 { + return nil, errors.New("byte slice length must be divisible by 2 for FP16") + } + + count := len(data) / 2 + floats := make([]float32, count) + + for i := 0; i < count; i++ { + bits := uint16(data[i*2]) | uint16(data[i*2+1])<<8 + floats[i] = float16.Frombits(bits).Float32() + } + + return floats, nil +} + func bytesToFloat32Slice(data []byte) ([]float32, error) { t := []float32{} diff --git a/preprocess.go b/preprocess.go index e583737..7a5d233 100644 --- a/preprocess.go +++ b/preprocess.go @@ -1,12 +1,14 @@ package yolotriton import ( + "encoding/binary" "image" "image/color" "image/draw" "math" "github.com/nfnt/resize" + "github.com/x448/float16" ) func resizeImage(img image.Image, width, heigth uint) image.Image { @@ -18,6 +20,38 @@ func pixelRGBA(c color.Color) (r, g, b, a uint32) { return r >> 8, g >> 8, b >> 8, a >> 8 } +func imageToFloat16ByteSlice(img image.Image) [][]byte { + bounds := img.Bounds() + width, height := bounds.Max.X, bounds.Max.Y + + size := width * height + + // FP16 = 2 bytes + result := make([]byte, size*3*2) + + // precompute byte offsets + rOff := 0 + gOff := size * 2 + bOff := size * 4 + + idx := 0 + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + r, g, b, _ := img.At(x, y).RGBA() + r8 := float32(uint8(r>>8)) / 255 + g8 := float32(uint8(g>>8)) / 255 + b8 := float32(uint8(b>>8)) / 255 + + binary.LittleEndian.PutUint16(result[rOff+idx*2:], float16.Fromfloat32(r8).Bits()) + binary.LittleEndian.PutUint16(result[gOff+idx*2:], float16.Fromfloat32(g8).Bits()) + binary.LittleEndian.PutUint16(result[bOff+idx*2:], float16.Fromfloat32(b8).Bits()) + idx++ + } + } + + return [][]byte{result} +} + func imageToFloat32Slice(img image.Image) []float32 { bounds := img.Bounds() width, height := bounds.Max.X, bounds.Max.Y diff --git a/yolo.go b/yolo.go index 31987c5..758946d 100644 --- a/yolo.go +++ b/yolo.go @@ -13,7 +13,8 @@ import ( type ModelType string const ( - ModelTypeYoloV8 ModelType = "yolov8" + ModelTypeYoloFP16 ModelType = "yolofp16" + ModelTypeYoloFP32 ModelType = "yolofp32" ModelTypeYoloNAS ModelType = "yolonas" ModelTypeYoloNASInt8 ModelType = "yolonasint8" ) @@ -46,7 +47,10 @@ func New(url string, model Model) (*YoloTriton, error) { return nil, err } + client := triton.NewGRPCInferenceServiceClient(conn) + return &YoloTriton{ + client: client, conn: conn, model: model, cfg: cfg, @@ -57,6 +61,7 @@ func New(url string, model Model) (*YoloTriton, error) { type YoloTriton struct { model Model cfg YoloTritonConfig + client triton.GRPCInferenceServiceClient conn *grpc.ClientConn modelMetadata *modelMetadata } @@ -74,8 +79,7 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) { modelInferRequest := y.modelMetadata.formInferRequest(inputs) - client := triton.NewGRPCInferenceServiceClient(y.conn) - inferResponse, err := ModelInferRequest(client, modelInferRequest) + inferResponse, err := ModelInferRequest(y.client, modelInferRequest) if err != nil { return nil, err } @@ -91,10 +95,11 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) { result := []Box{} for len(boxes) > 0 { - result = append(result, boxes[0]) + chosen := boxes[0] + result = append(result, chosen) tmp := []Box{} - for _, box := range boxes { - if iou(boxes[0], box) < y.cfg.MaxIOU { + for _, box := range boxes[1:] { + if iou(chosen, box) < y.cfg.MaxIOU { tmp = append(tmp, box) } } @@ -145,7 +150,7 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) * } } - return &triton.ModelInferRequest{ + req := &triton.ModelInferRequest{ ModelName: m.modelName, ModelVersion: m.modelVersion, Inputs: []*triton.ModelInferRequest_InferInputTensor{ @@ -153,9 +158,16 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) * Name: input.Name, Datatype: input.Datatype, Shape: input.Shape, - Contents: contents, }, }, Outputs: outputs, } + + if len(contents.BytesContents) > 0 && input.Datatype == "FP16" { + req.RawInputContents = contents.BytesContents + } else { + req.Inputs[0].Contents = contents + } + + return req } diff --git a/yolov8.go b/yolovx.go similarity index 59% rename from yolov8.go rename to yolovx.go index bfdffc5..05459ad 100644 --- a/yolov8.go +++ b/yolovx.go @@ -6,47 +6,65 @@ import ( triton "github.com/dev6699/yolotriton/grpc-client" ) -type YoloV8 struct { +type Yolo struct { YoloTritonConfig metadata struct { scaleFactorW float32 scaleFactorH float32 } + io16 bool } -func NewYoloV8(cfg YoloTritonConfig) Model { - return &YoloV8{ +func NewYolo(cfg YoloTritonConfig, io16 bool) Model { + return &Yolo{ YoloTritonConfig: cfg, + io16: io16, } } -var _ Model = &YoloV8{} +var _ Model = &Yolo{} -func (y *YoloV8) GetConfig() YoloTritonConfig { +func (y *Yolo) GetConfig() YoloTritonConfig { return y.YoloTritonConfig } -func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { +func (y *Yolo) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { width := img.Bounds().Dx() height := img.Bounds().Dy() preprocessedImg := resizeImage(img, targetWidth, targetHeight) - fp32Contents := imageToFloat32Slice(preprocessedImg) - y.metadata.scaleFactorW = float32(width) / float32(targetWidth) y.metadata.scaleFactorH = float32(height) / float32(targetHeight) - contents := &triton.InferTensorContents{ - Fp32Contents: fp32Contents, + if y.io16 { + bytesContents := imageToFloat16ByteSlice(preprocessedImg) + contents := &triton.InferTensorContents{ + BytesContents: bytesContents, + } + return contents, nil + } else { + fp32Contents := imageToFloat32Slice(preprocessedImg) + contents := &triton.InferTensorContents{ + Fp32Contents: fp32Contents, + } + return contents, nil } - return contents, nil } -func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) { - output, err := bytesToFloat32Slice(rawOutputContents[0]) - if err != nil { - return nil, err +func (y *Yolo) PostProcess(rawOutputContents [][]byte) ([]Box, error) { + var output []float32 + var err error + if y.io16 { + output, err = bytesFP16ToFloat32Slice(rawOutputContents[0]) + if err != nil { + return nil, err + } + } else { + output, err = bytesToFloat32Slice(rawOutputContents[0]) + if err != nil { + return nil, err + } } numObjects := y.NumObjects