Update main.cpp

This commit is contained in:
hpc203
2023-07-30 16:37:12 +08:00
committed by GitHub
parent 5767b8bb19
commit f7c7606aee

View File

@@ -4,7 +4,7 @@
#include <numeric> #include <numeric>
#include <opencv2/imgproc.hpp> #include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp> #include <opencv2/highgui.hpp>
//#include <cuda_provider_factory.h> ///nvidia-cuda<EFBFBD><EFBFBD><EFBFBD><EFBFBD> //#include <cuda_provider_factory.h> ///nvidia-cuda加速
#include <onnxruntime_cxx_api.h> #include <onnxruntime_cxx_api.h>
using namespace cv; using namespace cv;
@@ -34,8 +34,8 @@ private:
vector<string> class_names; vector<string> class_names;
const int max_size = 800; const int max_size = 800;
//<EFBFBD><EFBFBD><EFBFBD>ʼ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>õĿ<EFBFBD>ִ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> //存储初始化获得的可执行网络
Env env = Env(ORT_LOGGING_LEVEL_ERROR, "Head Pose Estimation"); Env env = Env(ORT_LOGGING_LEVEL_ERROR, "Detic");
Ort::Session *ort_session = nullptr; Ort::Session *ort_session = nullptr;
SessionOptions sessionOptions = SessionOptions(); SessionOptions sessionOptions = SessionOptions();
vector<char*> input_names; vector<char*> input_names;
@@ -46,11 +46,11 @@ private:
Detic::Detic(string model_path) Detic::Detic(string model_path)
{ {
//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0); ///nvidia-cuda<EFBFBD><EFBFBD><EFBFBD><EFBFBD> //OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0); ///nvidia-cuda加速
sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
std::wstring widestr = std::wstring(model_path.begin(), model_path.end()); ///<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>windowsϵͳ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ôд std::wstring widestr = std::wstring(model_path.begin(), model_path.end()); ///如果在windows系统就这么写
ort_session = new Session(env, widestr.c_str(), sessionOptions); ///<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>windowsϵͳ<EFBFBD><EFBFBD><EFBFBD><EFBFBD>ôд ort_session = new Session(env, widestr.c_str(), sessionOptions); ///如果在windows系统就这么写
///ort_session = new Session(env, model_path.c_str(), sessionOptions); ///<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>linuxϵͳ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ôд ///ort_session = new Session(env, model_path.c_str(), sessionOptions); ///如果在linux系统就这么写
size_t numInputNodes = ort_session->GetInputCount(); size_t numInputNodes = ort_session->GetInputCount();
size_t numOutputNodes = ort_session->GetOutputCount(); size_t numOutputNodes = ort_session->GetOutputCount();
@@ -76,7 +76,7 @@ Detic::Detic(string model_path)
string line; string line;
while (getline(ifs, line)) while (getline(ifs, line))
{ {
this->class_names.push_back(line); ///<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ÿ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>RGBֵ this->class_names.push_back(line); ///你可以用随机数给每个类别分配RGB
} }
} }
@@ -136,7 +136,7 @@ vector<BoxInfo> Detic::detect(Mat srcimg)
auto allocator_info = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); auto allocator_info = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Value input_tensor_ = Value::CreateTensor<float>(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); Value input_tensor_ = Value::CreateTensor<float>(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());
// <EFBFBD><EFBFBD>ʼ<EFBFBD><EFBFBD><EFBFBD><EFBFBD> // 开始推理
vector<Value> ort_outputs = ort_session->Run(RunOptions{ nullptr }, &input_names[0], &input_tensor_, 1, output_names.data(), output_names.size()); vector<Value> ort_outputs = ort_session->Run(RunOptions{ nullptr }, &input_names[0], &input_tensor_, 1, output_names.data(), output_names.size());
const float *pred_boxes = ort_outputs[0].GetTensorMutableData<float>(); const float *pred_boxes = ort_outputs[0].GetTensorMutableData<float>();