Mirror of https://github.com/AndreyGermanov/yolov8_onnx_go.git
Synced 2025-09-27 04:26:22 +08:00
100 lines · 2.9 KiB · Go
package main
|
|
|
|
import (
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
"runtime"
|
|
)
|
|
|
|
func InitYolo8Session(input []float32) (ModelSession, error) {
|
|
ort.SetSharedLibraryPath(getSharedLibPath())
|
|
err := ort.InitializeEnvironment()
|
|
if err != nil {
|
|
return ModelSession{}, err
|
|
}
|
|
|
|
inputShape := ort.NewShape(1, 3, 640, 640)
|
|
inputTensor, err := ort.NewTensor(inputShape, input)
|
|
if err != nil {
|
|
return ModelSession{}, err
|
|
}
|
|
|
|
outputShape := ort.NewShape(1, 84, 8400)
|
|
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
|
if err != nil {
|
|
return ModelSession{}, err
|
|
}
|
|
|
|
options, e := ort.NewSessionOptions()
|
|
if e != nil {
|
|
return ModelSession{}, err
|
|
}
|
|
|
|
if UseCoreML { // If CoreML is enabled, append the CoreML execution provider
|
|
e = options.AppendExecutionProviderCoreML(0)
|
|
if e != nil {
|
|
options.Destroy()
|
|
return ModelSession{}, err
|
|
}
|
|
defer options.Destroy()
|
|
}
|
|
|
|
session, err := ort.NewAdvancedSession(ModelPath,
|
|
[]string{"images"}, []string{"output0"},
|
|
[]ort.ArbitraryTensor{inputTensor}, []ort.ArbitraryTensor{outputTensor}, options)
|
|
|
|
if err != nil {
|
|
return ModelSession{}, err
|
|
}
|
|
|
|
modelSes := ModelSession{
|
|
Session: session,
|
|
Input: inputTensor,
|
|
Output: outputTensor,
|
|
}
|
|
|
|
return modelSes, err
|
|
}
|
|
|
|
// getSharedLibPath returns the path of the onnxruntime shared library under
// ./third_party that matches the platform this binary was built for.
//
// The returned paths are relative, so they depend on the process's working
// directory. Supported platforms are windows/amd64, darwin/arm64, linux/arm64
// and any other linux architecture (which gets the generic onnxruntime.so).
// Every other platform panics, since the model cannot run without the library.
func getSharedLibPath() string {
	switch runtime.GOOS {
	case "windows":
		if runtime.GOARCH == "amd64" {
			return "./third_party/onnxruntime.dll"
		}
	case "darwin":
		if runtime.GOARCH == "arm64" {
			return "./third_party/onnxruntime_arm64.dylib"
		}
	case "linux":
		if runtime.GOARCH == "arm64" {
			// Same ./third_party directory as every other platform's library.
			return "./third_party/onnxruntime_arm64.so"
		}
		return "./third_party/onnxruntime.so"
	}
	panic("Unable to find a version of the onnxruntime library supporting this system.")
}
|
|
|
|
// yolo_classes lists the 80 object class labels of the YOLOv8 model in class
// index order: the label for class i is yolo_classes[i]. Presumably this is
// the order of the 80 class scores that follow the 4 box values in each of the
// model's 84 output channels — confirm against the postprocessing code.
var yolo_classes = []string{
	"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
	"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
	"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
	"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
	"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
	"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
	"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
	"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
	"mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
	"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
}
|
|
|
|
func runInference(modelSes ModelSession, input []float32) ([]float32, error) {
|
|
inTensor := modelSes.Input.GetData()
|
|
copy(inTensor, input)
|
|
err := modelSes.Session.Run()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return modelSes.Output.GetData(), nil
|
|
}
|