mirror of
https://github.com/dev6699/yolotriton.git
synced 2025-09-26 19:51:13 +08:00
feat: support yolo FP16 inputs/outputs
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
"moby": "true"
|
"moby": "true"
|
||||||
},
|
},
|
||||||
"ghcr.io/devcontainers/features/go:1": {
|
"ghcr.io/devcontainers/features/go:1": {
|
||||||
"version": "1.19"
|
"version": "1.25"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"privileged": true,
|
"privileged": true,
|
||||||
|
14
cmd/main.go
14
cmd/main.go
@@ -26,7 +26,7 @@ func parseFlags() Flags {
|
|||||||
var flags Flags
|
var flags Flags
|
||||||
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
|
||||||
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
|
||||||
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
|
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolofp16, yolofp32]")
|
||||||
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
|
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
|
||||||
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
|
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
|
||||||
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
|
||||||
@@ -51,16 +51,24 @@ func main() {
|
|||||||
|
|
||||||
var model yolotriton.Model
|
var model yolotriton.Model
|
||||||
switch yolotriton.ModelType(FLAGS.ModelType) {
|
switch yolotriton.ModelType(FLAGS.ModelType) {
|
||||||
case yolotriton.ModelTypeYoloV8:
|
case yolotriton.ModelTypeYoloFP16:
|
||||||
cfg.NumClasses = 80
|
cfg.NumClasses = 80
|
||||||
cfg.NumObjects = 8400
|
cfg.NumObjects = 8400
|
||||||
model = yolotriton.NewYoloV8(cfg)
|
model = yolotriton.NewYolo(cfg, true)
|
||||||
|
|
||||||
|
case yolotriton.ModelTypeYoloFP32:
|
||||||
|
cfg.NumClasses = 80
|
||||||
|
cfg.NumObjects = 8400
|
||||||
|
model = yolotriton.NewYolo(cfg, false)
|
||||||
|
|
||||||
case yolotriton.ModelTypeYoloNAS:
|
case yolotriton.ModelTypeYoloNAS:
|
||||||
cfg.NumClasses = 80
|
cfg.NumClasses = 80
|
||||||
cfg.NumObjects = 8400
|
cfg.NumObjects = 8400
|
||||||
model = yolotriton.NewYoloNAS(cfg)
|
model = yolotriton.NewYoloNAS(cfg)
|
||||||
|
|
||||||
case yolotriton.ModelTypeYoloNASInt8:
|
case yolotriton.ModelTypeYoloNASInt8:
|
||||||
model = yolotriton.NewYoloNASInt8(cfg)
|
model = yolotriton.NewYoloNASInt8(cfg)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
|
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
|
||||||
}
|
}
|
||||||
|
3
go.mod
3
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module github.com/dev6699/yolotriton
|
module github.com/dev6699/yolotriton
|
||||||
|
|
||||||
go 1.19
|
go 1.25
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
||||||
@@ -12,6 +12,7 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/golang/protobuf v1.5.3 // indirect
|
github.com/golang/protobuf v1.5.3 // indirect
|
||||||
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
golang.org/x/net v0.14.0 // indirect
|
golang.org/x/net v0.14.0 // indirect
|
||||||
golang.org/x/sys v0.11.0 // indirect
|
golang.org/x/sys v0.11.0 // indirect
|
||||||
golang.org/x/text v0.12.0 // indirect
|
golang.org/x/text v0.12.0 // indirect
|
||||||
|
2
go.sum
2
go.sum
@@ -7,6 +7,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
|
@@ -1,2 +1,2 @@
|
|||||||
name: "yolov8"
|
name: "yolov12"
|
||||||
platform: "tensorrt_plan"
|
platform: "tensorrt_plan"
|
@@ -3,10 +3,29 @@ package yolotriton
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
|
"github.com/x448/float16"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func bytesFP16ToFloat32Slice(data []byte) ([]float32, error) {
|
||||||
|
if len(data)%2 != 0 {
|
||||||
|
return nil, errors.New("byte slice length must be divisible by 2 for FP16")
|
||||||
|
}
|
||||||
|
|
||||||
|
count := len(data) / 2
|
||||||
|
floats := make([]float32, count)
|
||||||
|
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
bits := uint16(data[i*2]) | uint16(data[i*2+1])<<8
|
||||||
|
floats[i] = float16.Frombits(bits).Float32()
|
||||||
|
}
|
||||||
|
|
||||||
|
return floats, nil
|
||||||
|
}
|
||||||
|
|
||||||
func bytesToFloat32Slice(data []byte) ([]float32, error) {
|
func bytesToFloat32Slice(data []byte) ([]float32, error) {
|
||||||
t := []float32{}
|
t := []float32{}
|
||||||
|
|
||||||
|
@@ -1,12 +1,14 @@
|
|||||||
package yolotriton
|
package yolotriton
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"image"
|
"image"
|
||||||
"image/color"
|
"image/color"
|
||||||
"image/draw"
|
"image/draw"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/nfnt/resize"
|
"github.com/nfnt/resize"
|
||||||
|
"github.com/x448/float16"
|
||||||
)
|
)
|
||||||
|
|
||||||
func resizeImage(img image.Image, width, heigth uint) image.Image {
|
func resizeImage(img image.Image, width, heigth uint) image.Image {
|
||||||
@@ -18,6 +20,38 @@ func pixelRGBA(c color.Color) (r, g, b, a uint32) {
|
|||||||
return r >> 8, g >> 8, b >> 8, a >> 8
|
return r >> 8, g >> 8, b >> 8, a >> 8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func imageToFloat16ByteSlice(img image.Image) [][]byte {
|
||||||
|
bounds := img.Bounds()
|
||||||
|
width, height := bounds.Max.X, bounds.Max.Y
|
||||||
|
|
||||||
|
size := width * height
|
||||||
|
|
||||||
|
// FP16 = 2 bytes
|
||||||
|
result := make([]byte, size*3*2)
|
||||||
|
|
||||||
|
// precompute byte offsets
|
||||||
|
rOff := 0
|
||||||
|
gOff := size * 2
|
||||||
|
bOff := size * 4
|
||||||
|
|
||||||
|
idx := 0
|
||||||
|
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||||
|
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||||
|
r, g, b, _ := img.At(x, y).RGBA()
|
||||||
|
r8 := float32(uint8(r>>8)) / 255
|
||||||
|
g8 := float32(uint8(g>>8)) / 255
|
||||||
|
b8 := float32(uint8(b>>8)) / 255
|
||||||
|
|
||||||
|
binary.LittleEndian.PutUint16(result[rOff+idx*2:], float16.Fromfloat32(r8).Bits())
|
||||||
|
binary.LittleEndian.PutUint16(result[gOff+idx*2:], float16.Fromfloat32(g8).Bits())
|
||||||
|
binary.LittleEndian.PutUint16(result[bOff+idx*2:], float16.Fromfloat32(b8).Bits())
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [][]byte{result}
|
||||||
|
}
|
||||||
|
|
||||||
func imageToFloat32Slice(img image.Image) []float32 {
|
func imageToFloat32Slice(img image.Image) []float32 {
|
||||||
bounds := img.Bounds()
|
bounds := img.Bounds()
|
||||||
width, height := bounds.Max.X, bounds.Max.Y
|
width, height := bounds.Max.X, bounds.Max.Y
|
||||||
|
28
yolo.go
28
yolo.go
@@ -13,7 +13,8 @@ import (
|
|||||||
type ModelType string
|
type ModelType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ModelTypeYoloV8 ModelType = "yolov8"
|
ModelTypeYoloFP16 ModelType = "yolofp16"
|
||||||
|
ModelTypeYoloFP32 ModelType = "yolofp32"
|
||||||
ModelTypeYoloNAS ModelType = "yolonas"
|
ModelTypeYoloNAS ModelType = "yolonas"
|
||||||
ModelTypeYoloNASInt8 ModelType = "yolonasint8"
|
ModelTypeYoloNASInt8 ModelType = "yolonasint8"
|
||||||
)
|
)
|
||||||
@@ -46,7 +47,10 @@ func New(url string, model Model) (*YoloTriton, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := triton.NewGRPCInferenceServiceClient(conn)
|
||||||
|
|
||||||
return &YoloTriton{
|
return &YoloTriton{
|
||||||
|
client: client,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
model: model,
|
model: model,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@@ -57,6 +61,7 @@ func New(url string, model Model) (*YoloTriton, error) {
|
|||||||
type YoloTriton struct {
|
type YoloTriton struct {
|
||||||
model Model
|
model Model
|
||||||
cfg YoloTritonConfig
|
cfg YoloTritonConfig
|
||||||
|
client triton.GRPCInferenceServiceClient
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
modelMetadata *modelMetadata
|
modelMetadata *modelMetadata
|
||||||
}
|
}
|
||||||
@@ -74,8 +79,7 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
|
|||||||
|
|
||||||
modelInferRequest := y.modelMetadata.formInferRequest(inputs)
|
modelInferRequest := y.modelMetadata.formInferRequest(inputs)
|
||||||
|
|
||||||
client := triton.NewGRPCInferenceServiceClient(y.conn)
|
inferResponse, err := ModelInferRequest(y.client, modelInferRequest)
|
||||||
inferResponse, err := ModelInferRequest(client, modelInferRequest)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -91,10 +95,11 @@ func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
|
|||||||
|
|
||||||
result := []Box{}
|
result := []Box{}
|
||||||
for len(boxes) > 0 {
|
for len(boxes) > 0 {
|
||||||
result = append(result, boxes[0])
|
chosen := boxes[0]
|
||||||
|
result = append(result, chosen)
|
||||||
tmp := []Box{}
|
tmp := []Box{}
|
||||||
for _, box := range boxes {
|
for _, box := range boxes[1:] {
|
||||||
if iou(boxes[0], box) < y.cfg.MaxIOU {
|
if iou(chosen, box) < y.cfg.MaxIOU {
|
||||||
tmp = append(tmp, box)
|
tmp = append(tmp, box)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -145,7 +150,7 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &triton.ModelInferRequest{
|
req := &triton.ModelInferRequest{
|
||||||
ModelName: m.modelName,
|
ModelName: m.modelName,
|
||||||
ModelVersion: m.modelVersion,
|
ModelVersion: m.modelVersion,
|
||||||
Inputs: []*triton.ModelInferRequest_InferInputTensor{
|
Inputs: []*triton.ModelInferRequest_InferInputTensor{
|
||||||
@@ -153,9 +158,16 @@ func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) *
|
|||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
Datatype: input.Datatype,
|
Datatype: input.Datatype,
|
||||||
Shape: input.Shape,
|
Shape: input.Shape,
|
||||||
Contents: contents,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Outputs: outputs,
|
Outputs: outputs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(contents.BytesContents) > 0 && input.Datatype == "FP16" {
|
||||||
|
req.RawInputContents = contents.BytesContents
|
||||||
|
} else {
|
||||||
|
req.Inputs[0].Contents = contents
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
}
|
}
|
||||||
|
@@ -6,47 +6,65 @@ import (
|
|||||||
triton "github.com/dev6699/yolotriton/grpc-client"
|
triton "github.com/dev6699/yolotriton/grpc-client"
|
||||||
)
|
)
|
||||||
|
|
||||||
type YoloV8 struct {
|
type Yolo struct {
|
||||||
YoloTritonConfig
|
YoloTritonConfig
|
||||||
metadata struct {
|
metadata struct {
|
||||||
scaleFactorW float32
|
scaleFactorW float32
|
||||||
scaleFactorH float32
|
scaleFactorH float32
|
||||||
}
|
}
|
||||||
|
io16 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewYoloV8(cfg YoloTritonConfig) Model {
|
func NewYolo(cfg YoloTritonConfig, io16 bool) Model {
|
||||||
return &YoloV8{
|
return &Yolo{
|
||||||
YoloTritonConfig: cfg,
|
YoloTritonConfig: cfg,
|
||||||
|
io16: io16,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Model = &YoloV8{}
|
var _ Model = &Yolo{}
|
||||||
|
|
||||||
func (y *YoloV8) GetConfig() YoloTritonConfig {
|
func (y *Yolo) GetConfig() YoloTritonConfig {
|
||||||
return y.YoloTritonConfig
|
return y.YoloTritonConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
func (y *Yolo) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
|
||||||
width := img.Bounds().Dx()
|
width := img.Bounds().Dx()
|
||||||
height := img.Bounds().Dy()
|
height := img.Bounds().Dy()
|
||||||
|
|
||||||
preprocessedImg := resizeImage(img, targetWidth, targetHeight)
|
preprocessedImg := resizeImage(img, targetWidth, targetHeight)
|
||||||
|
|
||||||
fp32Contents := imageToFloat32Slice(preprocessedImg)
|
|
||||||
|
|
||||||
y.metadata.scaleFactorW = float32(width) / float32(targetWidth)
|
y.metadata.scaleFactorW = float32(width) / float32(targetWidth)
|
||||||
y.metadata.scaleFactorH = float32(height) / float32(targetHeight)
|
y.metadata.scaleFactorH = float32(height) / float32(targetHeight)
|
||||||
|
|
||||||
contents := &triton.InferTensorContents{
|
if y.io16 {
|
||||||
Fp32Contents: fp32Contents,
|
bytesContents := imageToFloat16ByteSlice(preprocessedImg)
|
||||||
|
contents := &triton.InferTensorContents{
|
||||||
|
BytesContents: bytesContents,
|
||||||
|
}
|
||||||
|
return contents, nil
|
||||||
|
} else {
|
||||||
|
fp32Contents := imageToFloat32Slice(preprocessedImg)
|
||||||
|
contents := &triton.InferTensorContents{
|
||||||
|
Fp32Contents: fp32Contents,
|
||||||
|
}
|
||||||
|
return contents, nil
|
||||||
}
|
}
|
||||||
return contents, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
func (y *Yolo) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
|
||||||
output, err := bytesToFloat32Slice(rawOutputContents[0])
|
var output []float32
|
||||||
if err != nil {
|
var err error
|
||||||
return nil, err
|
if y.io16 {
|
||||||
|
output, err = bytesFP16ToFloat32Slice(rawOutputContents[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
output, err = bytesToFloat32Slice(rawOutputContents[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
numObjects := y.NumObjects
|
numObjects := y.NumObjects
|
Reference in New Issue
Block a user