Files
yolotriton/yolo.go
2025-08-23 10:25:53 +00:00

174 lines
3.9 KiB
Go

package yolotriton
import (
"image"
_ "image/png"
"sort"
triton "github.com/dev6699/yolotriton/grpc-client"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// ModelType is a string identifier for a YOLO model variant.
type ModelType string

// Known model variants. The string values are the identifiers exposed to
// callers and are kept exactly as-is.
const (
	// ModelTypeYoloFP16 presumably denotes a half-precision YOLO export.
	ModelTypeYoloFP16 = ModelType("yolofp16")
	// ModelTypeYoloFP32 presumably denotes a single-precision YOLO export.
	ModelTypeYoloFP32 = ModelType("yolofp32")
	// ModelTypeYoloNAS denotes a YOLO-NAS model.
	ModelTypeYoloNAS = ModelType("yolonas")
	// ModelTypeYoloNASInt8 presumably denotes an int8-quantized YOLO-NAS model.
	ModelTypeYoloNASInt8 = ModelType("yolonasint8")
)
// Model adapts one YOLO variant to YoloTriton. It supplies the configuration,
// converts an image into the input tensor the served model expects, and
// decodes the raw inference outputs into candidate boxes.
type Model interface {
	// GetConfig returns the settings YoloTriton uses to address the model on
	// the Triton server and to filter detections.
	GetConfig() YoloTritonConfig

	// PreProcess converts img into input tensor contents for a model input of
	// targetWidth x targetHeight.
	PreProcess(img image.Image, targetWidth, targetHeight uint) (*triton.InferTensorContents, error)

	// PostProcess decodes the response's RawOutputContents into candidate
	// boxes. YoloTriton.Infer applies non-maximum suppression to the result
	// afterwards.
	PostProcess(rawOutputContents [][]byte) ([]Box, error)
}
// YoloTritonConfig describes which Triton model to call and how its
// detections are interpreted.
type YoloTritonConfig struct {
	// NumClasses and NumObjects describe the model's output layout; they are
	// presumably consumed by Model.PostProcess (not visible here).
	NumClasses, NumObjects int

	// ModelName and ModelVersion identify the model on the Triton server and
	// are sent with both the metadata request and every inference request.
	ModelName, ModelVersion string

	// MinProbability is a confidence threshold, presumably applied by
	// Model.PostProcess when decoding detections.
	MinProbability float32

	// MaxIOU is the overlap limit used by non-maximum suppression in Infer: a
	// box whose IoU with a higher-confidence kept box is not below MaxIOU is
	// dropped.
	MaxIOU float64

	// Classes holds the class labels, presumably indexed by class id.
	Classes []string
}
// New creates a plaintext (insecure) gRPC client connection to the Triton
// server at url, fetches the metadata of the model named by
// model.GetConfig(), and returns a YoloTriton ready for Infer.
//
// With grpc.Dial's default non-blocking behaviour, connectivity problems
// typically surface as an error from the metadata request rather than from
// the dial itself. In every error path the connection is released before
// returning, so a failed New leaks nothing. On success the caller owns the
// returned value and should call Close when done with it.
func New(url string, model Model) (*YoloTriton, error) {
	conn, err := grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, err
	}

	cfg := model.GetConfig()
	meta, err := newModelMetadata(conn, cfg.ModelName, cfg.ModelVersion)
	if err != nil {
		// The caller never sees conn on this path, so it must be closed here.
		// The metadata error is the one worth reporting; a secondary Close
		// error is intentionally discarded.
		_ = conn.Close()
		return nil, err
	}

	return &YoloTriton{
		client:        triton.NewGRPCInferenceServiceClient(conn),
		conn:          conn,
		model:         model,
		cfg:           cfg,
		modelMetadata: meta,
	}, nil
}
// YoloTriton runs YOLO object detection against a model served by Triton.
// Construct it with New and release it with Close.
type YoloTriton struct {
	// model performs the variant-specific pre- and post-processing.
	model Model
	// cfg is the configuration snapshot taken from model.GetConfig() in New.
	cfg YoloTritonConfig
	// client issues inference RPCs over conn.
	client triton.GRPCInferenceServiceClient
	// conn is the underlying gRPC connection, closed by Close.
	conn *grpc.ClientConn
	// modelMetadata is the server-reported description of the model's
	// inputs and outputs, fetched once in New.
	modelMetadata *modelMetadata
}
// Close shuts down the gRPC connection created by New and returns whatever
// error the connection's Close reports. Inference RPCs issued after Close
// will fail.
func (y *YoloTriton) Close() error {
	conn := y.conn
	return conn.Close()
}
// Infer runs object detection on img and returns the detected boxes.
//
// It works in four steps:
//   - The configured Model converts img into the model's input tensor.
//   - The tensor is sent to Triton.
//   - The Model decodes the raw outputs into candidate boxes.
//   - The candidates are reduced with greedy, class-agnostic non-maximum
//     suppression. Boxes are visited in descending Probability order and each
//     surviving box is kept. Every later box whose IoU with it is not below
//     cfg.MaxIOU is discarded.
//
// On success the result is a non-nil (possibly empty) slice ordered by
// descending Probability.
func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
	width, height := y.modelMetadata.inputWidth(), y.modelMetadata.inputHeight()
	tensor, err := y.model.PreProcess(img, width, height)
	if err != nil {
		return nil, err
	}

	resp, err := ModelInferRequest(y.client, y.modelMetadata.formInferRequest(tensor))
	if err != nil {
		return nil, err
	}

	candidates, err := y.model.PostProcess(resp.RawOutputContents)
	if err != nil {
		return nil, err
	}

	// Most confident first, so each kept box dominates the boxes it suppresses.
	sort.Slice(candidates, func(a, b int) bool {
		return candidates[a].Probability > candidates[b].Probability
	})

	kept := make([]Box, 0, len(candidates))
	suppressed := make([]bool, len(candidates))
	for i := range candidates {
		if suppressed[i] {
			continue
		}
		best := candidates[i]
		kept = append(kept, best)
		for j := i + 1; j < len(candidates); j++ {
			// Written as "< MaxIOU then keep" so that a NaN overlap suppresses the
			// box, exactly as a keep-if-below filter would.
			if suppressed[j] || iou(best, candidates[j]) < y.cfg.MaxIOU {
				continue
			}
			suppressed[j] = true
		}
	}
	return kept, nil
}
// modelMetadata is the server-reported metadata for one model, together with
// the name and version that were used to request it. Inference requests built
// from it address that same model.
type modelMetadata struct {
	modelName, modelVersion string

	// Embedded so the response's Inputs and Outputs can be read directly.
	*triton.ModelMetadataResponse
}
// newModelMetadata asks the Triton server reachable through conn for the
// metadata of modelName at modelVersion. It bundles the response with the
// name and version used, so later inference requests target the same model.
// Any error from the metadata request is returned unchanged.
func newModelMetadata(conn *grpc.ClientConn, modelName string, modelVersion string) (*modelMetadata, error) {
	resp, err := ModelMetadataRequest(triton.NewGRPCInferenceServiceClient(conn), modelName, modelVersion)
	if err != nil {
		return nil, err
	}

	meta := &modelMetadata{ModelMetadataResponse: resp}
	meta.modelName = modelName
	meta.modelVersion = modelVersion
	return meta, nil
}
// inputWidth returns the width, in pixels, of the model's first input.
//
// The input is assumed to be laid out NCHW, i.e. [batch, channels, height,
// width], as is usual for YOLO exports (NOTE(review): confirm for the models
// in use). Width is therefore dimension 3. Reading dimension 2 here would
// silently swap width and height for non-square models.
//
// It panics if the model has no inputs or the first input has fewer than four
// dimensions. A dynamic (-1) dimension is not handled: it wraps around when
// converted to uint.
func (m *modelMetadata) inputWidth() uint {
	return uint(m.Inputs[0].Shape[3])
}

// inputHeight returns the height, in pixels, of the model's first input.
//
// Under the same NCHW assumption as inputWidth, height is dimension 2. The
// same panic and dynamic-dimension caveats apply.
func (m *modelMetadata) inputHeight() uint {
	return uint(m.Inputs[0].Shape[2])
}
// formInferRequest builds an inference request for this model. It feeds
// contents into the model's first input and asks for every output the model
// declares, by name.
//
// A dynamic batch dimension (-1 in the first position of the input shape) is
// pinned to 1, because one image is sent per request. The edit is made on a
// copy of the shape. Inputs[0] is a pointer into the cached metadata, so
// writing through it would mutate state shared by every call and race when
// Infer runs concurrently.
//
// The tensor data is attached in one of two ways:
//   - For an FP16 input whose data arrives in contents.BytesContents, the data
//     goes into RawInputContents. InferTensorContents has no half-precision
//     field.
//   - Otherwise contents is attached as the input's typed Contents.
//
// A nil contents yields a request with no input data, which the server
// rejects, rather than a panic.
func (m *modelMetadata) formInferRequest(contents *triton.InferTensorContents) *triton.ModelInferRequest {
	input := m.Inputs[0]

	// Copy before editing: input.Shape belongs to the cached metadata.
	shape := append(input.Shape[:0:0], input.Shape...)
	if len(shape) > 0 && shape[0] == -1 {
		shape[0] = 1
	}

	outputs := make([]*triton.ModelInferRequest_InferRequestedOutputTensor, 0, len(m.Outputs))
	for _, o := range m.Outputs {
		outputs = append(outputs, &triton.ModelInferRequest_InferRequestedOutputTensor{Name: o.Name})
	}

	tensor := &triton.ModelInferRequest_InferInputTensor{
		Name:     input.Name,
		Datatype: input.Datatype,
		Shape:    shape,
	}
	req := &triton.ModelInferRequest{
		ModelName:    m.modelName,
		ModelVersion: m.modelVersion,
		Inputs:       []*triton.ModelInferRequest_InferInputTensor{tensor},
		Outputs:      outputs,
	}

	if contents != nil && len(contents.BytesContents) > 0 && input.Datatype == "FP16" {
		req.RawInputContents = contents.BytesContents
	} else {
		tensor.Contents = contents
	}
	return req
}