mirror of
https://github.com/dev6699/yolotriton.git
synced 2025-09-26 19:51:13 +08:00
feat: support yolo FP16 inputs/outputs
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
"moby": "true"
|
||||
},
|
||||
"ghcr.io/devcontainers/features/go:1": {
|
||||
"version": "1.19"
|
||||
"version": "1.25"
|
||||
}
|
||||
},
|
||||
"privileged": true,
|
||||
|
14
cmd/main.go
14
cmd/main.go
@@ -26,7 +26,7 @@ func parseFlags() Flags {
|
||||
var flags Flags
|
||||
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
||||
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
||||
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
|
||||
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolofp16, yolofp32]")
|
||||
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
|
||||
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
|
||||
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
||||
@@ -51,16 +51,24 @@ func main() {
|
||||
|
||||
var model yolotriton.Model
|
||||
switch yolotriton.ModelType(FLAGS.ModelType) {
|
||||
case yolotriton.ModelTypeYoloV8:
|
||||
case yolotriton.ModelTypeYoloFP16:
|
||||
cfg.NumClasses = 80
|
||||
cfg.NumObjects = 8400
|
||||
model = yolotriton.NewYoloV8(cfg)
|
||||
model = yolotriton.NewYolo(cfg, true)
|
||||
|
||||
case yolotriton.ModelTypeYoloFP32:
|
||||
cfg.NumClasses = 80
|
||||
cfg.NumObjects = 8400
|
||||
model = yolotriton.NewYolo(cfg, false)
|
||||
|
||||
case yolotriton.ModelTypeYoloNAS:
|
||||
cfg.NumClasses = 80
|
||||
cfg.NumObjects = 8400
|
||||
model = yolotriton.NewYoloNAS(cfg)
|
||||
|
||||
case yolotriton.ModelTypeYoloNASInt8:
|
||||
model = yolotriton.NewYoloNASInt8(cfg)
|
||||
|
||||
default:
|
||||
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
|
||||
}
|
||||
|
3
go.mod
3
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/dev6699/yolotriton
|
||||
|
||||
go 1.19
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
||||
@@ -12,6 +12,7 @@ require (
|
||||
|
||||
require (
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
golang.org/x/net v0.14.0 // indirect
|
||||
golang.org/x/sys v0.11.0 // indirect
|
||||
golang.org/x/text v0.12.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@@ -7,6 +7,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
|
@@ -1,2 +1,2 @@
|
||||
name: "yolov8"
|
||||
name: "yolov12"
|
||||
platform: "tensorrt_plan"
|
@@ -3,10 +3,29 @@ package yolotriton
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func bytesFP16ToFloat32Slice(data []byte) ([]float32, error) {
|
||||
if len(data)%2 != 0 {
|
||||
return nil, errors.New("byte slice length must be divisible by 2 for FP16")
|
||||
}
|
||||
|
||||
count := len(data) / 2
|
||||
floats := make([]float32, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
bits := uint16(data[i*2]) | uint16(data[i*2+1])<<8
|
||||
floats[i] = float16.Frombits(bits).Float32()
|
||||
}
|
||||
|
||||
return floats, nil
|
||||
}
|
||||
|
||||
func bytesToFloat32Slice(data []byte) ([]float32, error) {
|
||||
t := []float32{}
|
||||
|
||||
|
@@ -1,12 +1,14 @@
|
||||
package yolotriton
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"math"
|
||||
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func resizeImage(img image.Image, width, heigth uint) image.Image {
|
||||
@@ -18,6 +20,38 @@ func pixelRGBA(c color.Color) (r, g, b, a uint32) {
|
||||
return r >> 8, g >> 8, b >> 8, a >> 8
|
||||
}
|
||||
|
||||
func imageToFloat16ByteSlice(img image.Image) [][]byte {
|
||||
bounds := img.Bounds()
|
||||
width, height := bounds.Max.X, bounds.Max.Y
|
||||
|
||||
size := width * height
|
||||
|
||||
// FP16 = 2 bytes
|
||||
result := make([]byte, size*3*2)
|
||||
|
||||
// precompute byte offsets
|
||||
rOff := 0
|
||||
gOff := size * 2
|
||||
bOff := size * 4
|
||||
|
||||
idx := 0
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
r, g, b, _ := img.At(x, y).RGBA()
|
||||
r8 := float32(uint8(r>>8)) / 255
|
||||
g8 := float32(uint8(g>>8)) / 255
|
||||
b8 := float32(uint8(b>>8)) / 255
|
||||
|
||||
binary.LittleEndian.PutUint16(result[rOff+idx*2:], float16.Fromfloat32(r8).Bits())
|
||||
binary.LittleEndian.PutUint16(result[gOff+idx*2:], float16.Fromfloat32(g8).Bits())
|
||||
binary.LittleEndian.PutUint16(result[bOff+idx*2:], float16.Fromfloat32(b8).Bits())
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
return [][]byte{result}
|
||||
}
|
||||
|
||||
func imageToFloat32Slice(img image.Image) []float32 {
|
||||
bounds := img.Bounds()
|
||||
width, height := bounds.Max.X, bounds.Max.Y
|
||||
|
28
yolo.go
28
yolo.go
@@ -13,7 +13,8 @@ import (
|
||||
type ModelType string
|
||||
|
||||
const (
|
||||
ModelTypeYoloV8 ModelType = "yolov8"
|
||||
ModelTypeYoloFP16 ModelType = "yolofp16"
|
||||
ModelTypeYoloFP32 ModelType = "yolofp32"
|
||||
ModelTypeYoloNAS ModelType = "yolonas"
|
||||
ModelTypeYoloNASInt8 ModelType = "yolonasint8"
|
||||
)
|
||||
@@ -46,7 +47,10 @@ func New(url string, model Model) (*YoloTriton, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := triton.NewGRPCInferenceServiceClient(conn)
|
||||
|
||||
return &YoloTriton{
|
||||
client: client,
|
||||
conn: conn,
|
||||
model: model,
|
||||
cfg: cfg,
|
||||
@@ -57,6 +61,7 @@ func New(url string, model Model) (*YoloTriton, error) {
|
||||
type YoloTriton struct {
|
||||
model Model
|
||||
cfg YoloTritonConfig
|
||||
client triton.GRPCInferenceServiceClient
|
||||
conn *grpc.ClientConn
|
||||
modelMetadata *modelMetadata
|
||||
}
|
||||
@@ -74,8 +79,7 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
|
||||
|
||||
modelInferRequest := y.modelMetadata.formInferRequest(inputs)
|
||||
|
||||
client := triton.NewGRPCInferenceServiceClient(y.conn)
|
||||
inferResponse, err := ModelInferRequest(client, modelInferRequest)
|
||||
inferResponse, err := ModelInferRequest(y.client, modelInferRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -91,10 +95,11 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
|
||||
|
||||
result := []Box{}
|
||||
for len(boxes) > 0 {
|
||||
result = append(result, boxes[0])
|
||||
chosen := boxes[0]
|
||||
result = append(result, chosen)
|
||||
tmp := []Box{}
|
||||
for _, box := range boxes {
|
||||
if iou(boxes[0], box) < y.cfg.MaxIOU {
|
||||
for _, box := range boxes[1:] {
|
||||
if iou(chosen, box) < y.cfg.MaxIOU {
|
||||
tmp = append(tmp, box)
|
||||
}
|
||||
}
|
||||
@@ -145,7 +150,7 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) *
|
||||
}
|
||||
}
|
||||
|
||||
return &triton.ModelInferRequest{
|
||||
req := &triton.ModelInferRequest{
|
||||
ModelName: m.modelName,
|
||||
ModelVersion: m.modelVersion,
|
||||
Inputs: []*triton.ModelInferRequest_InferInputTensor{
|
||||
@@ -153,9 +158,16 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) *
|
||||
Name: input.Name,
|
||||
Datatype: input.Datatype,
|
||||
Shape: input.Shape,
|
||||
Contents: contents,
|
||||
},
|
||||
},
|
||||
Outputs: outputs,
|
||||
}
|
||||
|
||||
if len(contents.BytesContents) > 0 && input.Datatype == "FP16" {
|
||||
req.RawInputContents = contents.BytesContents
|
||||
} else {
|
||||
req.Inputs[0].Contents = contents
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
@@ -6,48 +6,66 @@ import (
|
||||
triton "github.com/dev6699/yolotriton/grpc-client"
|
||||
)
|
||||
|
||||
type YoloV8 struct {
|
||||
type Yolo struct {
|
||||
YoloTritonConfig
|
||||
metadata struct {
|
||||
scaleFactorW float32
|
||||
scaleFactorH float32
|
||||
}
|
||||
io16 bool
|
||||
}
|
||||
|
||||
func NewYoloV8(cfg YoloTritonConfig) Model {
|
||||
return &YoloV8{
|
||||
func NewYolo(cfg YoloTritonConfig, io16 bool) Model {
|
||||
return &Yolo{
|
||||
YoloTritonConfig: cfg,
|
||||
io16: io16,
|
||||
}
|
||||
}
|
||||
|
||||
var _ Model = &YoloV8{}
|
||||
var _ Model = &Yolo{}
|
||||
|
||||
func (y *YoloV8) GetConfig() YoloTritonConfig {
|
||||
func (y *Yolo) GetConfig() YoloTritonConfig {
|
||||
return y.YoloTritonConfig
|
||||
}
|
||||
|
||||
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||
func (y *Yolo) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||
width := img.Bounds().Dx()
|
||||
height := img.Bounds().Dy()
|
||||
|
||||
preprocessedImg := resizeImage(img, targetWidth, targetHeight)
|
||||
|
||||
fp32Contents := imageToFloat32Slice(preprocessedImg)
|
||||
|
||||
y.metadata.scaleFactorW = float32(width) / float32(targetWidth)
|
||||
y.metadata.scaleFactorH = float32(height) / float32(targetHeight)
|
||||
|
||||
if y.io16 {
|
||||
bytesContents := imageToFloat16ByteSlice(preprocessedImg)
|
||||
contents := &triton.InferTensorContents{
|
||||
BytesContents: bytesContents,
|
||||
}
|
||||
return contents, nil
|
||||
} else {
|
||||
fp32Contents := imageToFloat32Slice(preprocessedImg)
|
||||
contents := &triton.InferTensorContents{
|
||||
Fp32Contents: fp32Contents,
|
||||
}
|
||||
return contents, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||
output, err := bytesToFloat32Slice(rawOutputContents[0])
|
||||
func (y *Yolo) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||
var output []float32
|
||||
var err error
|
||||
if y.io16 {
|
||||
output, err = bytesFP16ToFloat32Slice(rawOutputContents[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
output, err = bytesToFloat32Slice(rawOutputContents[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
numObjects := y.NumObjects
|
||||
numClasses := y.NumClasses
|
Reference in New Issue
Block a user