feat: support yolo FP16 inputs/outputs

This commit is contained in:
dev6699
2025-08-23 10:25:53 +00:00
parent 74c1760e58
commit 1c4d27d0b9
10 changed files with 123 additions and 29 deletions

View File

@@ -16,7 +16,7 @@
"moby": "true" "moby": "true"
}, },
"ghcr.io/devcontainers/features/go:1": { "ghcr.io/devcontainers/features/go:1": {
"version": "1.19" "version": "1.25"
} }
}, },
"privileged": true, "privileged": true,

View File

@@ -26,7 +26,7 @@ func parseFlags() Flags {
var flags Flags var flags Flags
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)") flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version") flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]") flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolofp16, yolofp32]")
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability") flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)") flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.") flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
@@ -51,16 +51,24 @@ func main() {
var model yolotriton.Model var model yolotriton.Model
switch yolotriton.ModelType(FLAGS.ModelType) { switch yolotriton.ModelType(FLAGS.ModelType) {
case yolotriton.ModelTypeYoloV8: case yolotriton.ModelTypeYoloFP16:
cfg.NumClasses = 80 cfg.NumClasses = 80
cfg.NumObjects = 8400 cfg.NumObjects = 8400
model = yolotriton.NewYoloV8(cfg) model = yolotriton.NewYolo(cfg, true)
case yolotriton.ModelTypeYoloFP32:
cfg.NumClasses = 80
cfg.NumObjects = 8400
model = yolotriton.NewYolo(cfg, false)
case yolotriton.ModelTypeYoloNAS: case yolotriton.ModelTypeYoloNAS:
cfg.NumClasses = 80 cfg.NumClasses = 80
cfg.NumObjects = 8400 cfg.NumObjects = 8400
model = yolotriton.NewYoloNAS(cfg) model = yolotriton.NewYoloNAS(cfg)
case yolotriton.ModelTypeYoloNASInt8: case yolotriton.ModelTypeYoloNASInt8:
model = yolotriton.NewYoloNASInt8(cfg) model = yolotriton.NewYoloNASInt8(cfg)
default: default:
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType) log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
} }

3
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/dev6699/yolotriton module github.com/dev6699/yolotriton
go 1.19 go 1.25
require ( require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
@@ -12,6 +12,7 @@ require (
require ( require (
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/net v0.14.0 // indirect golang.org/x/net v0.14.0 // indirect
golang.org/x/sys v0.11.0 // indirect golang.org/x/sys v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect golang.org/x/text v0.12.0 // indirect

2
go.sum
View File

@@ -7,6 +7,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=

View File

@@ -1,2 +1,2 @@
name: "yolov8" name: "yolov12"
platform: "tensorrt_plan" platform: "tensorrt_plan"

View File

@@ -3,10 +3,29 @@ package yolotriton
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"math" "math"
"github.com/x448/float16"
) )
// bytesFP16ToFloat32Slice decodes data as a sequence of 16-bit values stored
// low byte first, interprets each one through float16.Frombits, and widens
// it to float32.
//
// It returns an error when len(data) is odd, because the buffer then cannot
// be split into whole 2-byte elements. An empty input yields an empty,
// non-nil slice.
func bytesFP16ToFloat32Slice(data []byte) ([]float32, error) {
	if len(data)&1 == 1 {
		return nil, errors.New("byte slice length must be divisible by 2 for FP16")
	}

	out := make([]float32, len(data)/2)
	for i := range out {
		// Element i occupies bytes [2i, 2i+1] of the input.
		half := binary.LittleEndian.Uint16(data[2*i:])
		out[i] = float16.Frombits(half).Float32()
	}
	return out, nil
}
func bytesToFloat32Slice(data []byte) ([]float32, error) { func bytesToFloat32Slice(data []byte) ([]float32, error) {
t := []float32{} t := []float32{}

View File

@@ -1,12 +1,14 @@
package yolotriton package yolotriton
import ( import (
"encoding/binary"
"image" "image"
"image/color" "image/color"
"image/draw" "image/draw"
"math" "math"
"github.com/nfnt/resize" "github.com/nfnt/resize"
"github.com/x448/float16"
) )
func resizeImage(img image.Image, width, heigth uint) image.Image { func resizeImage(img image.Image, width, heigth uint) image.Image {
@@ -18,6 +20,38 @@ func pixelRGBA(c color.Color) (r, g, b, a uint32) {
return r >> 8, g >> 8, b >> 8, a >> 8 return r >> 8, g >> 8, b >> 8, a >> 8
} }
// imageToFloat16ByteSlice serialises img into a single byte buffer of
// little-endian 16-bit floats laid out plane by plane: every red value
// first, then every green value, then every blue value. Pixels are visited
// row by row, left to right. Each channel is reduced to 8 bits and divided
// by 255 before being converted with float16.Fromfloat32; the alpha channel
// is discarded.
//
// The buffer is returned wrapped in a one-element [][]byte.
func imageToFloat16ByteSlice(img image.Image) [][]byte {
	bounds := img.Bounds()

	// Size the buffer from the rectangle's extent (Dx/Dy), not from Max.X and
	// Max.Y: the loops below walk Min..Max, so an image whose bounds do not
	// start at the origin (for example a sub-image) would otherwise get an
	// oversized buffer with misplaced green and blue planes.
	size := bounds.Dx() * bounds.Dy()

	// FP16 = 2 bytes
	const bytesPerValue = 2
	result := make([]byte, size*3*bytesPerValue)

	// Byte offsets at which each colour plane starts.
	rOff := 0
	gOff := size * bytesPerValue
	bOff := 2 * size * bytesPerValue

	idx := 0
	for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
		for x := bounds.Min.X; x < bounds.Max.X; x++ {
			r, g, b, _ := img.At(x, y).RGBA()

			// RGBA returns 16-bit channels; keep the high 8 bits and
			// normalise to the 0..1 range.
			r8 := float32(uint8(r>>8)) / 255
			g8 := float32(uint8(g>>8)) / 255
			b8 := float32(uint8(b>>8)) / 255

			pos := idx * bytesPerValue
			binary.LittleEndian.PutUint16(result[rOff+pos:], float16.Fromfloat32(r8).Bits())
			binary.LittleEndian.PutUint16(result[gOff+pos:], float16.Fromfloat32(g8).Bits())
			binary.LittleEndian.PutUint16(result[bOff+pos:], float16.Fromfloat32(b8).Bits())

			idx++
		}
	}

	return [][]byte{result}
}
func imageToFloat32Slice(img image.Image) []float32 { func imageToFloat32Slice(img image.Image) []float32 {
bounds := img.Bounds() bounds := img.Bounds()
width, height := bounds.Max.X, bounds.Max.Y width, height := bounds.Max.X, bounds.Max.Y

28
yolo.go
View File

@@ -13,7 +13,8 @@ import (
type ModelType string type ModelType string
const ( const (
ModelTypeYoloV8 ModelType = "yolov8" ModelTypeYoloFP16 ModelType = "yolofp16"
ModelTypeYoloFP32 ModelType = "yolofp32"
ModelTypeYoloNAS ModelType = "yolonas" ModelTypeYoloNAS ModelType = "yolonas"
ModelTypeYoloNASInt8 ModelType = "yolonasint8" ModelTypeYoloNASInt8 ModelType = "yolonasint8"
) )
@@ -46,7 +47,10 @@ func New(url string, model Model) (*YoloTriton, error) {
return nil, err return nil, err
} }
client := triton.NewGRPCInferenceServiceClient(conn)
return &YoloTriton{ return &YoloTriton{
client: client,
conn: conn, conn: conn,
model: model, model: model,
cfg: cfg, cfg: cfg,
@@ -57,6 +61,7 @@ func New(url string, model Model) (*YoloTriton, error) {
type YoloTriton struct { type YoloTriton struct {
model Model model Model
cfg YoloTritonConfig cfg YoloTritonConfig
client triton.GRPCInferenceServiceClient
conn *grpc.ClientConn conn *grpc.ClientConn
modelMetadata *modelMetadata modelMetadata *modelMetadata
} }
@@ -74,8 +79,7 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
modelInferRequest := y.modelMetadata.formInferRequest(inputs) modelInferRequest := y.modelMetadata.formInferRequest(inputs)
client := triton.NewGRPCInferenceServiceClient(y.conn) inferResponse, err := ModelInferRequest(y.client, modelInferRequest)
inferResponse, err := ModelInferRequest(client, modelInferRequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -91,10 +95,11 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
result := []Box{} result := []Box{}
for len(boxes) > 0 { for len(boxes) > 0 {
result = append(result, boxes[0]) chosen := boxes[0]
result = append(result, chosen)
tmp := []Box{} tmp := []Box{}
for _, box := range boxes { for _, box := range boxes[1:] {
if iou(boxes[0], box) < y.cfg.MaxIOU { if iou(chosen, box) < y.cfg.MaxIOU {
tmp = append(tmp, box) tmp = append(tmp, box)
} }
} }
@@ -145,7 +150,7 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) *
} }
} }
return &triton.ModelInferRequest{ req := &triton.ModelInferRequest{
ModelName: m.modelName, ModelName: m.modelName,
ModelVersion: m.modelVersion, ModelVersion: m.modelVersion,
Inputs: []*triton.ModelInferRequest_InferInputTensor{ Inputs: []*triton.ModelInferRequest_InferInputTensor{
@@ -153,9 +158,16 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) *
Name: input.Name, Name: input.Name,
Datatype: input.Datatype, Datatype: input.Datatype,
Shape: input.Shape, Shape: input.Shape,
Contents: contents,
}, },
}, },
Outputs: outputs, Outputs: outputs,
} }
if len(contents.BytesContents) > 0 && input.Datatype == "FP16" {
req.RawInputContents = contents.BytesContents
} else {
req.Inputs[0].Contents = contents
}
return req
} }

View File

@@ -6,47 +6,65 @@ import (
triton "github.com/dev6699/yolotriton/grpc-client" triton "github.com/dev6699/yolotriton/grpc-client"
) )
type YoloV8 struct { type Yolo struct {
YoloTritonConfig YoloTritonConfig
metadata struct { metadata struct {
scaleFactorW float32 scaleFactorW float32
scaleFactorH float32 scaleFactorH float32
} }
io16 bool
} }
func NewYoloV8(cfg YoloTritonConfig) Model { func NewYolo(cfg YoloTritonConfig, io16 bool) Model {
return &YoloV8{ return &Yolo{
YoloTritonConfig: cfg, YoloTritonConfig: cfg,
io16: io16,
} }
} }
var _ Model = &YoloV8{} var _ Model = &Yolo{}
func (y *YoloV8) GetConfig() YoloTritonConfig { func (y *Yolo) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig return y.YoloTritonConfig
} }
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) { func (y *Yolo) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
width := img.Bounds().Dx() width := img.Bounds().Dx()
height := img.Bounds().Dy() height := img.Bounds().Dy()
preprocessedImg := resizeImage(img, targetWidth, targetHeight) preprocessedImg := resizeImage(img, targetWidth, targetHeight)
fp32Contents := imageToFloat32Slice(preprocessedImg)
y.metadata.scaleFactorW = float32(width) / float32(targetWidth) y.metadata.scaleFactorW = float32(width) / float32(targetWidth)
y.metadata.scaleFactorH = float32(height) / float32(targetHeight) y.metadata.scaleFactorH = float32(height) / float32(targetHeight)
contents := &triton.InferTensorContents{ if y.io16 {
Fp32Contents: fp32Contents, bytesContents := imageToFloat16ByteSlice(preprocessedImg)
contents := &triton.InferTensorContents{
BytesContents: bytesContents,
}
return contents, nil
} else {
fp32Contents := imageToFloat32Slice(preprocessedImg)
contents := &triton.InferTensorContents{
Fp32Contents: fp32Contents,
}
return contents, nil
} }
return contents, nil
} }
func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) { func (y *Yolo) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
output, err := bytesToFloat32Slice(rawOutputContents[0]) var output []float32
if err != nil { var err error
return nil, err if y.io16 {
output, err = bytesFP16ToFloat32Slice(rawOutputContents[0])
if err != nil {
return nil, err
}
} else {
output, err = bytesToFloat32Slice(rawOutputContents[0])
if err != nil {
return nil, err
}
} }
numObjects := y.NumObjects numObjects := y.NumObjects