This commit is contained in:
kweijack
2023-09-03 13:15:02 +00:00
commit cbcfdb9189
30 changed files with 18735 additions and 0 deletions

5
.devcontainer/Dockerfile Normal file
View File

@@ -0,0 +1,5 @@
# Note: You can use any Debian/Ubuntu based image you want.
FROM nvcr.io/nvidia/tensorrt:23.07-py3
# [Optional] Uncomment this section to install additional OS packages.
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends ffmpeg libsm6 libxext6

View File

@@ -0,0 +1,32 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-outside-of-docker-compose
{
"name": "Docker from Docker Compose",
"dockerComposeFile": "docker-compose.yml",
"service": "app",
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
// Use this environment variable if you need to bind mount your local source code into a new container.
"remoteEnv": {
"LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}"
},
"features": {
"ghcr.io/devcontainers/features/docker-outside-of-docker:1": {
"version": "latest",
"enableNonRootDocker": "true",
"moby": "true"
},
"ghcr.io/devcontainers/features/go:1": {
"version": "1.19"
}
},
"privileged": true,
"hostRequirements": {
"gpu": true
}
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "docker --version",
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}

View File

@@ -0,0 +1,26 @@
version: '3'
services:
app:
build:
context: .
dockerfile: Dockerfile
volumes:
# Forwards the local Docker socket to the container.
- /var/run/docker.sock:/var/run/docker-host.sock
# Update this to wherever you want VS Code to mount the folder of your project
- ../..:/workspaces:cached
# Overrides default command so things don't shut down after the process ends.
entrypoint: /usr/local/share/docker-init.sh
command: sleep infinity
# Uncomment the next four lines if you will use a ptrace-based debuggers like C++, Go, and Rust.
# cap_add:
# - SYS_PTRACE
# security_opt:
# - seccomp:unconfined
# Use "forwardPorts" in **devcontainer.json** to forward an app port locally.
# (Adding the "ports" property to this file will not forward from a Codespace.)

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 dev6699
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

58
README.md Normal file
View File

@@ -0,0 +1,58 @@
# yolotriton
[![GoDoc](https://pkg.go.dev/badge/github.com/dev6699/yolotriton)](https://pkg.go.dev/github.com/dev6699/yolotriton)
[![Go Report Card](https://goreportcard.com/badge/github.com/dev6699/yolotriton)](https://goreportcard.com/report/github.com/dev6699/yolotriton)
[![License](https://img.shields.io/github/license/dev6699/yolotriton)](LICENSE)
Go (Golang) gRPC client for YOLOv8 inference using the Triton Inference Server.
## Installation
Use `go get` to install this package:
```bash
go get github.com/dev6699/yolotriton
```
### Get YOLOv8 TensorRT model
```bash
pip install ultralytics
yolo export model=yolov8m.pt format=onnx
trtexec --onnx=yolov8m.onnx --saveEngine=model_repository/yolov8_tensorrt/1/model.plan
```
References:
1. https://docs.nvidia.com/deeplearning/tensorrt/quick-start-guide/index.html
2. https://docs.ultralytics.com/modes/export/
3. https://github.com/NVIDIA/TensorRT/tree/master/samples/trtexec
### Start trinton server
```bash
docker compose up tritonserver
```
References:
1. https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html
### Sample usage
Check [cmd/main.go](cmd/main.go) for more details.
Available args:
```bash
-i string
Inference Image. Default: images/1.jpg (default "images/1.jpg")
-m string
Name of model being served. (Required) (default "yolov8_tensorrt")
-u string
Inference Server URL. Default: tritonserver:8001 (default "tritonserver:8001")
-x string
Version of model. Default: Latest Version.
```
```bash
go run cmd/main.go
```
### Results
| Input | Ouput |
| --------------------------- | ------------------------------- |
| <img src="images/1.jpg" /> | <img src="images/1_out.jpg" /> |
| <img src="images/2.jpg" /> | <img src="images/2_out.jpg" /> |

13
class.go Normal file
View File

@@ -0,0 +1,13 @@
package yolotriton
var yoloClasses = []string{
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
"bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
"cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
"clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
}

76
cmd/main.go Normal file
View File

@@ -0,0 +1,76 @@
package main
import (
"flag"
"fmt"
"log"
"strings"
"github.com/dev6699/yolotriton"
)
type Flags struct {
ModelName string
ModelVersion string
URL string
Image string
}
func parseFlags() Flags {
var flags Flags
flag.StringVar(&flags.ModelName, "m", "yolov8_tensorrt", "Name of model being served. (Required)")
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version.")
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL. Default: tritonserver:8001")
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image. Default: images/1.jpg")
flag.Parse()
return flags
}
func main() {
FLAGS := parseFlags()
fmt.Println("FLAGS:", FLAGS)
ygt, err := yolotriton.New(
FLAGS.URL,
yolotriton.YoloTritonConfig{
BatchSize: 1,
NumChannels: 84,
NumObjects: 8400,
Width: 640,
Height: 640,
ModelName: FLAGS.ModelName,
ModelVersion: FLAGS.ModelVersion,
MinProbability: 0.5,
MaxIOU: 0.7,
})
if err != nil {
log.Fatal(err)
}
img, err := yolotriton.LoadImage(FLAGS.Image)
if err != nil {
log.Fatalf("Failed to preprocess image: %v", err)
}
results, err := ygt.Infer(img)
if err != nil {
log.Fatal(err)
}
for i, r := range results {
fmt.Printf("---%d---", i)
fmt.Println(r.Class, r.Probability)
fmt.Println("[x1,x2,y1,y2]", int(r.X1), int(r.X2), int(r.Y1), int(r.Y2))
}
out, err := yolotriton.DrawBoundingBoxes(img, results, 5)
if err != nil {
log.Fatal(err)
}
err = yolotriton.SaveImage(out, fmt.Sprintf("%s_out.jpg", strings.Split(FLAGS.Image, ".")[0]))
if err != nil {
log.Fatal(err)
}
}

59
conn.go Normal file
View File

@@ -0,0 +1,59 @@
package yolotriton
import (
"context"
"time"
triton "github.com/dev6699/yolotriton/grpc-client"
)
func ServerLiveRequest(client triton.GRPCInferenceServiceClient) (*triton.ServerLiveResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
serverLiveRequest := triton.ServerLiveRequest{}
serverLiveResponse, err := client.ServerLive(ctx, &serverLiveRequest)
if err != nil {
return nil, err
}
return serverLiveResponse, nil
}
func ServerReadyRequest(client triton.GRPCInferenceServiceClient) (*triton.ServerReadyResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
serverReadyRequest := triton.ServerReadyRequest{}
serverReadyResponse, err := client.ServerReady(ctx, &serverReadyRequest)
if err != nil {
return nil, err
}
return serverReadyResponse, nil
}
func ModelMetadataRequest(client triton.GRPCInferenceServiceClient, modelName string, modelVersion string) (*triton.ModelMetadataResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
modelMetadataRequest := triton.ModelMetadataRequest{
Name: modelName,
Version: modelVersion,
}
modelMetadataResponse, err := client.ModelMetadata(ctx, &modelMetadataRequest)
if err != nil {
return nil, err
}
return modelMetadataResponse, nil
}
func ModelInferRequest(client triton.GRPCInferenceServiceClient, modelInferRequest *triton.ModelInferRequest) (*triton.ModelInferResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
modelInferResponse, err := client.ModelInfer(ctx, modelInferRequest)
if err != nil {
return nil, err
}
return modelInferResponse, nil
}

1758
core/grpc_service.proto Normal file

File diff suppressed because it is too large Load Diff

78
core/health.proto Normal file
View File

@@ -0,0 +1,78 @@
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto3";
package grpc.health.v1;
//@@.. cpp:namespace:: grpc.health.v1
//@@
//@@.. cpp:var:: message HealthCheckRequest
//@@
//@@ Request message for HealthCheck
//@@
message HealthCheckRequest
{
string service = 1;
}
//@@
//@@.. cpp:var:: message HealthCheckResponse
//@@
//@@ Response message for HealthCheck
//@@
message HealthCheckResponse
{
//@@
//@@.. cpp:enum:: ServingStatus
//@@
//@@ Statuses supported by GRPC's health check.
//@@
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
SERVICE_UNKNOWN = 3;
}
ServingStatus status = 1;
}
//@@
//@@.. cpp:var:: service Health
//@@
//@@ Health service for GRPC endpoints.
//@@
service Health
{
//@@ .. cpp:var:: rpc Check(HealthCheckRequest) returns
//@@ (HealthCheckResponse)
//@@
//@@ Get serving status of the inference server.
//@@
rpc Check(HealthCheckRequest) returns (HealthCheckResponse);
}
option go_package = "./grpc-client";

2011
core/model_config.proto Normal file

File diff suppressed because it is too large Load Diff

22
docker-compose.yml Normal file
View File

@@ -0,0 +1,22 @@
version: "3.9"
services:
tritonserver:
container_name: tritonserver
image: nvcr.io/nvidia/tritonserver:23.07-py3
command: tritonserver --model-repository=/models
ports:
- 8000:8000
- 8001:8001
- 8002:8002
volumes:
- ${LOCAL_WORKSPACE_FOLDER}/model_repository:/models
privileged: true
deploy:
resources:
reservations:
devices:
- capabilities: [gpu]
networks:
default:
name: yolotriton_devcontainer_default

19
go.mod Normal file
View File

@@ -0,0 +1,19 @@
module github.com/dev6699/yolotriton
go 1.19
require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
golang.org/x/image v0.11.0
google.golang.org/grpc v1.57.0
google.golang.org/protobuf v1.31.0
)
require (
github.com/golang/protobuf v1.5.3 // indirect
golang.org/x/net v0.14.0 // indirect
golang.org/x/sys v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
)

56
go.sum Normal file
View File

@@ -0,0 +1,56 @@
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/image v0.11.0 h1:ds2RoQvBvYTiJkwpSFDwCcDFNX7DqjL2WsUgTNk0Ooo=
golang.org/x/image v0.11.0/go.mod h1:bglhjqbqVuEb9e9+eNR45Jfu7D+T4Qan+NhQk8Ck2P8=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA=
google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw=
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

319
grpc-client/health.pb.go Normal file
View File

@@ -0,0 +1,319 @@
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.31.0
// protoc v3.12.4
// source: health.proto
package grpc_client
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// @@
// @@.. cpp:enum:: ServingStatus
// @@
// @@ Statuses supported by GRPC's health check.
// @@
type HealthCheckResponse_ServingStatus int32
const (
HealthCheckResponse_UNKNOWN HealthCheckResponse_ServingStatus = 0
HealthCheckResponse_SERVING HealthCheckResponse_ServingStatus = 1
HealthCheckResponse_NOT_SERVING HealthCheckResponse_ServingStatus = 2
HealthCheckResponse_SERVICE_UNKNOWN HealthCheckResponse_ServingStatus = 3
)
// Enum value maps for HealthCheckResponse_ServingStatus.
var (
HealthCheckResponse_ServingStatus_name = map[int32]string{
0: "UNKNOWN",
1: "SERVING",
2: "NOT_SERVING",
3: "SERVICE_UNKNOWN",
}
HealthCheckResponse_ServingStatus_value = map[string]int32{
"UNKNOWN": 0,
"SERVING": 1,
"NOT_SERVING": 2,
"SERVICE_UNKNOWN": 3,
}
)
func (x HealthCheckResponse_ServingStatus) Enum() *HealthCheckResponse_ServingStatus {
p := new(HealthCheckResponse_ServingStatus)
*p = x
return p
}
func (x HealthCheckResponse_ServingStatus) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (HealthCheckResponse_ServingStatus) Descriptor() protoreflect.EnumDescriptor {
return file_health_proto_enumTypes[0].Descriptor()
}
func (HealthCheckResponse_ServingStatus) Type() protoreflect.EnumType {
return &file_health_proto_enumTypes[0]
}
func (x HealthCheckResponse_ServingStatus) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use HealthCheckResponse_ServingStatus.Descriptor instead.
func (HealthCheckResponse_ServingStatus) EnumDescriptor() ([]byte, []int) {
return file_health_proto_rawDescGZIP(), []int{1, 0}
}
// @@
// @@.. cpp:var:: message HealthCheckRequest
// @@
// @@ Request message for HealthCheck
// @@
type HealthCheckRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Service string `protobuf:"bytes,1,opt,name=service,proto3" json:"service,omitempty"`
}
func (x *HealthCheckRequest) Reset() {
*x = HealthCheckRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_health_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *HealthCheckRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HealthCheckRequest) ProtoMessage() {}
func (x *HealthCheckRequest) ProtoReflect() protoreflect.Message {
mi := &file_health_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HealthCheckRequest.ProtoReflect.Descriptor instead.
func (*HealthCheckRequest) Descriptor() ([]byte, []int) {
return file_health_proto_rawDescGZIP(), []int{0}
}
func (x *HealthCheckRequest) GetService() string {
if x != nil {
return x.Service
}
return ""
}
// @@
// @@.. cpp:var:: message HealthCheckResponse
// @@
// @@ Response message for HealthCheck
// @@
type HealthCheckResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Status HealthCheckResponse_ServingStatus `protobuf:"varint,1,opt,name=status,proto3,enum=grpc.health.v1.HealthCheckResponse_ServingStatus" json:"status,omitempty"`
}
func (x *HealthCheckResponse) Reset() {
*x = HealthCheckResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_health_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *HealthCheckResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HealthCheckResponse) ProtoMessage() {}
func (x *HealthCheckResponse) ProtoReflect() protoreflect.Message {
mi := &file_health_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HealthCheckResponse.ProtoReflect.Descriptor instead.
func (*HealthCheckResponse) Descriptor() ([]byte, []int) {
return file_health_proto_rawDescGZIP(), []int{1}
}
func (x *HealthCheckResponse) GetStatus() HealthCheckResponse_ServingStatus {
if x != nil {
return x.Status
}
return HealthCheckResponse_UNKNOWN
}
var File_health_proto protoreflect.FileDescriptor
var file_health_proto_rawDesc = []byte{
0x0a, 0x0c, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e,
0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x22, 0x2e,
0x0a, 0x12, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0xb1,
0x01, 0x0a, 0x13, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x49, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x31, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65,
0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68,
0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x53, 0x65, 0x72, 0x76,
0x69, 0x6e, 0x67, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x73, 0x22, 0x4f, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x6e, 0x67, 0x53, 0x74, 0x61, 0x74,
0x75, 0x73, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12,
0x0b, 0x0a, 0x07, 0x53, 0x45, 0x52, 0x56, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x0f, 0x0a, 0x0b,
0x4e, 0x4f, 0x54, 0x5f, 0x53, 0x45, 0x52, 0x56, 0x49, 0x4e, 0x47, 0x10, 0x02, 0x12, 0x13, 0x0a,
0x0f, 0x53, 0x45, 0x52, 0x56, 0x49, 0x43, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e,
0x10, 0x03, 0x32, 0x5a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x50, 0x0a, 0x05,
0x43, 0x68, 0x65, 0x63, 0x6b, 0x12, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x68, 0x65, 0x61,
0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65,
0x63, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63,
0x2e, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74,
0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x0f,
0x5a, 0x0d, 0x2e, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_health_proto_rawDescOnce sync.Once
file_health_proto_rawDescData = file_health_proto_rawDesc
)
func file_health_proto_rawDescGZIP() []byte {
file_health_proto_rawDescOnce.Do(func() {
file_health_proto_rawDescData = protoimpl.X.CompressGZIP(file_health_proto_rawDescData)
})
return file_health_proto_rawDescData
}
var file_health_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_health_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_health_proto_goTypes = []interface{}{
(HealthCheckResponse_ServingStatus)(0), // 0: grpc.health.v1.HealthCheckResponse.ServingStatus
(*HealthCheckRequest)(nil), // 1: grpc.health.v1.HealthCheckRequest
(*HealthCheckResponse)(nil), // 2: grpc.health.v1.HealthCheckResponse
}
var file_health_proto_depIdxs = []int32{
0, // 0: grpc.health.v1.HealthCheckResponse.status:type_name -> grpc.health.v1.HealthCheckResponse.ServingStatus
1, // 1: grpc.health.v1.Health.Check:input_type -> grpc.health.v1.HealthCheckRequest
2, // 2: grpc.health.v1.Health.Check:output_type -> grpc.health.v1.HealthCheckResponse
2, // [2:3] is the sub-list for method output_type
1, // [1:2] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_health_proto_init() }
func file_health_proto_init() {
if File_health_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_health_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HealthCheckRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_health_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HealthCheckResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_health_proto_rawDesc,
NumEnums: 1,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_health_proto_goTypes,
DependencyIndexes: file_health_proto_depIdxs,
EnumInfos: file_health_proto_enumTypes,
MessageInfos: file_health_proto_msgTypes,
}.Build()
File_health_proto = out.File
file_health_proto_rawDesc = nil
file_health_proto_goTypes = nil
file_health_proto_depIdxs = nil
}

View File

@@ -0,0 +1,145 @@
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc v3.12.4
// source: health.proto
package grpc_client
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
const (
Health_Check_FullMethodName = "/grpc.health.v1.Health/Check"
)
// HealthClient is the client API for Health service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type HealthClient interface {
// @@ .. cpp:var:: rpc Check(HealthCheckRequest) returns
// @@ (HealthCheckResponse)
// @@
// @@ Get serving status of the inference server.
// @@
Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error)
}
type healthClient struct {
cc grpc.ClientConnInterface
}
func NewHealthClient(cc grpc.ClientConnInterface) HealthClient {
return &healthClient{cc}
}
func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) {
out := new(HealthCheckResponse)
err := c.cc.Invoke(ctx, Health_Check_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// HealthServer is the server API for Health service.
// All implementations must embed UnimplementedHealthServer
// for forward compatibility
type HealthServer interface {
// @@ .. cpp:var:: rpc Check(HealthCheckRequest) returns
// @@ (HealthCheckResponse)
// @@
// @@ Get serving status of the inference server.
// @@
Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error)
mustEmbedUnimplementedHealthServer()
}
// UnimplementedHealthServer must be embedded to have forward compatible implementations.
type UnimplementedHealthServer struct {
}
func (UnimplementedHealthServer) Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Check not implemented")
}
func (UnimplementedHealthServer) mustEmbedUnimplementedHealthServer() {}
// UnsafeHealthServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to HealthServer will
// result in compilation errors.
type UnsafeHealthServer interface {
mustEmbedUnimplementedHealthServer()
}
func RegisterHealthServer(s grpc.ServiceRegistrar, srv HealthServer) {
s.RegisterService(&Health_ServiceDesc, srv)
}
func _Health_Check_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthCheckRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HealthServer).Check(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Health_Check_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HealthServer).Check(ctx, req.(*HealthCheckRequest))
}
return interceptor(ctx, in, info, handler)
}
// Health_ServiceDesc is the grpc.ServiceDesc for Health service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Health_ServiceDesc = grpc.ServiceDesc{
ServiceName: "grpc.health.v1.Health",
HandlerType: (*HealthServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Check",
Handler: _Health_Check_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "health.proto",
}

File diff suppressed because it is too large Load Diff

BIN
images/1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

BIN
images/1_out.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

BIN
images/2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 MiB

BIN
images/2_out.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

1
model_repository/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
yolov8_tensorrt/1/model.plan

View File

@@ -0,0 +1,2 @@
name: "yolov8_tensorrt"
platform: "tensorrt_plan"

116
postprocess.go Normal file
View File

@@ -0,0 +1,116 @@
package yolotriton
import (
"bytes"
"encoding/binary"
"math"
"sort"
)
func (y *YoloTriton) bytesToFloat32Slice(data []byte) ([]float32, error) {
t := []float32{}
// Create a buffer from the input data
buffer := bytes.NewReader(data)
for i := 0; i < y.cfg.BatchSize; i++ {
for j := 0; j < y.cfg.NumChannels; j++ {
for k := 0; k < y.cfg.NumObjects; k++ {
// Read the binary data from the buffer
var binaryValue uint32
err := binary.Read(buffer, binary.LittleEndian, &binaryValue)
if err != nil {
return nil, err
}
t = append(t, math.Float32frombits(binaryValue))
}
}
}
return t, nil
}
type Box struct {
X1 float64
Y1 float64
X2 float64
Y2 float64
Probability float64
Class string
}
func (y *YoloTriton) parseOutput(output []float32, origImgWidth, origImgHeight int) []Box {
boxes := []Box{}
for index := 0; index < y.cfg.NumObjects; index++ {
classID := 0
prob := float32(0.0)
for col := 0; col < y.cfg.NumChannels-4; col++ {
if output[y.cfg.NumObjects*(col+4)+index] > prob {
prob = output[y.cfg.NumObjects*(col+4)+index]
classID = col
}
}
if prob < float32(y.cfg.MinProbability) {
continue
}
label := yoloClasses[classID]
xc := output[index]
yc := output[y.cfg.NumObjects+index]
w := output[2*y.cfg.NumObjects+index]
h := output[3*y.cfg.NumObjects+index]
x1 := (xc - w/2) / float32(y.cfg.Width) * float32(origImgWidth)
y1 := (yc - h/2) / float32(y.cfg.Height) * float32(origImgHeight)
x2 := (xc + w/2) / float32(y.cfg.Width) * float32(origImgWidth)
y2 := (yc + h/2) / float32(y.cfg.Height) * float32(origImgHeight)
boxes = append(boxes, Box{
X1: float64(x1),
Y1: float64(y1),
X2: float64(x2),
Y2: float64(y2),
Probability: float64(prob),
Class: label,
})
}
sort.Slice(boxes, func(i, j int) bool {
return boxes[i].Probability < boxes[j].Probability
})
result := []Box{}
for len(boxes) > 0 {
result = append(result, boxes[0])
tmp := []Box{}
for _, box := range boxes {
if iou(boxes[0], box) < y.cfg.MaxIOU {
tmp = append(tmp, box)
}
}
boxes = tmp
}
return result
}
func iou(box1, box2 Box) float64 {
// Calculate the coordinates of the intersection rectangle
intersectionX1 := math.Max(box1.X1, box2.X1)
intersectionY1 := math.Max(box1.Y1, box2.Y1)
intersectionX2 := math.Min(box1.X2, box2.X2)
intersectionY2 := math.Min(box1.Y2, box2.Y2)
// Calculate the area of the intersection rectangle
intersectionArea := math.Max(0, intersectionX2-intersectionX1+1) * math.Max(0, intersectionY2-intersectionY1+1)
// Calculate the area of each bounding box
box1Area := (box1.X2 - box1.X1 + 1) * (box1.Y2 - box1.Y1 + 1)
box2Area := (box2.X2 - box2.X1 + 1) * (box2.Y2 - box2.Y1 + 1)
// Calculate the IoU
iou := intersectionArea / (box1Area + box2Area - intersectionArea)
return iou
}

45
preprocess.go Normal file
View File

@@ -0,0 +1,45 @@
package yolotriton
import (
"image"
"image/color"
"github.com/nfnt/resize"
)
func resizeImage(img image.Image, width, heigth uint) image.Image {
return resize.Resize(width, heigth, img, resize.Lanczos3)
}
func pixelRGBA(c color.Color) (r, g, b, a uint32) {
r, g, b, a = c.RGBA()
return r >> 8, g >> 8, b >> 8, a >> 8
}
func imageToFloat32Slice(img image.Image) []float32 {
bounds := img.Bounds()
width, height := bounds.Max.X, bounds.Max.Y
inputContents := make([]float32, width*height*3)
idx := 0
offset := (height * width)
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
pixel := img.At(x, y)
r, g, b, _ := pixelRGBA(pixel)
// Normalize the color values to the range [0, 1]
floatR := float32(r) / 255
floatG := float32(g) / 255
floatB := float32(b) / 255
inputContents[idx] = floatR
inputContents[offset+idx] = floatG
inputContents[2*offset+idx] = floatB
idx++
}
}
return inputContents
}

100
util.go Normal file
View File

@@ -0,0 +1,100 @@
package yolotriton
import (
"fmt"
"image"
"image/color"
"image/draw"
"image/jpeg"
_ "image/png"
"os"
"github.com/golang/freetype/truetype"
"golang.org/x/image/font"
"golang.org/x/image/font/gofont/goregular"
"golang.org/x/image/math/fixed"
)
func LoadImage(imagePath string) (image.Image, error) {
file, err := os.Open(imagePath)
if err != nil {
return nil, err
}
defer file.Close()
img, _, err := image.Decode(file)
if err != nil {
return nil, err
}
return img, nil
}
func SaveImage(img image.Image, filename string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
err = jpeg.Encode(file, img, nil)
if err != nil {
return err
}
return nil
}
func DrawBoundingBoxes(img image.Image, boxes []Box, lineWidth int) (image.Image, error) {
// Create a new RGBA image to draw the bounding boxes and text labels on
bounds := img.Bounds()
dst := image.NewRGBA(bounds)
// Copy the original image to the destination image
draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
// Create a color for the bounding boxes (red in this example)
red := color.RGBA{255, 0, 0, 255}
// Create a font from a TrueType font file with the specified font size
ttfFont, err := truetype.Parse(goregular.TTF)
if err != nil {
return nil, err
}
face := truetype.NewFace(ttfFont, &truetype.Options{
Size: 36.0,
})
// Draw the bounding boxes and text labels on the destination image
for _, box := range boxes {
x1, y1, x2, y2 := box.X1, box.Y1, box.X2, box.Y2
// Draw the bounding box
for x := x1; x <= x2; x++ {
for w := 0; w < lineWidth; w++ {
dst.Set(int(x), int(y1)+w, red)
dst.Set(int(x), int(y2)+w, red)
}
}
for y := y1; y <= y2; y++ {
for w := 0; w < lineWidth; w++ {
dst.Set(int(x1)+w, int(y), red)
dst.Set(int(x2)+w, int(y), red)
}
}
// Draw the text label above the box
label := fmt.Sprintf("%s %f", box.Class, box.Probability)
textX := int(x1)
textY := int(y1) - 5
d := &font.Drawer{
Dst: dst,
Src: image.NewUniform(red),
Face: face,
Dot: fixed.P(textX, textY),
}
d.DrawString(label)
}
return dst, nil
}

87
yolo.go Normal file
View File

@@ -0,0 +1,87 @@
package yolotriton
import (
"image"
_ "image/png"
triton "github.com/dev6699/yolotriton/grpc-client"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
type YoloTritonConfig struct {
BatchSize int
NumChannels int
NumObjects int
Width int
Height int
ModelName string
ModelVersion string
MinProbability float64
MaxIOU float64
}
func New(url string, cfg YoloTritonConfig) (*YoloTriton, error) {
conn, err := grpc.Dial(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
return &YoloTriton{
cfg: cfg,
conn: conn,
}, nil
}
type YoloTriton struct {
cfg YoloTritonConfig
conn *grpc.ClientConn
}
func (y *YoloTriton) Close() error {
return y.conn.Close()
}
func (y *YoloTriton) Infer(img image.Image) ([]Box, error) {
preprocessedImg := resizeImage(img, uint(y.cfg.Width), uint(y.cfg.Height))
fp32Contents := imageToFloat32Slice(preprocessedImg)
client := triton.NewGRPCInferenceServiceClient(y.conn)
inferInputs := []*triton.ModelInferRequest_InferInputTensor{
{
Name: "images",
Datatype: "FP32",
Shape: []int64{int64(y.cfg.BatchSize), 3, int64(y.cfg.Width), int64(y.cfg.Height)},
Contents: &triton.InferTensorContents{
Fp32Contents: fp32Contents,
},
},
}
inferOutputs := []*triton.ModelInferRequest_InferRequestedOutputTensor{
{
Name: "output0",
},
}
modelInferRequest := &triton.ModelInferRequest{
ModelName: y.cfg.ModelName,
ModelVersion: y.cfg.ModelVersion,
Inputs: inferInputs,
Outputs: inferOutputs,
}
inferResponse, err := ModelInferRequest(client, modelInferRequest)
if err != nil {
return nil, err
}
t, err := y.bytesToFloat32Slice(inferResponse.RawOutputContents[0])
if err != nil {
return nil, err
}
boxes := y.parseOutput(t, img.Bounds().Dx(), img.Bounds().Dy())
return boxes, nil
}