mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-22 08:09:28 +08:00
fix some usage bugs for detection models
This commit is contained in:
@@ -54,8 +54,8 @@ std::vector<int> toVec(const nvinfer1::Dims& dim) {
|
||||
|
||||
bool CheckDynamicShapeConfig(const paddle2onnx::OnnxReader& reader,
|
||||
const TrtBackendOption& option) {
|
||||
//paddle2onnx::ModelTensorInfo inputs[reader.NumInputs()];
|
||||
//std::string input_shapes[reader.NumInputs()];
|
||||
// paddle2onnx::ModelTensorInfo inputs[reader.NumInputs()];
|
||||
// std::string input_shapes[reader.NumInputs()];
|
||||
std::vector<paddle2onnx::ModelTensorInfo> inputs(reader.NumInputs());
|
||||
std::vector<std::string> input_shapes(reader.NumInputs());
|
||||
for (int i = 0; i < reader.NumInputs(); ++i) {
|
||||
@@ -374,27 +374,27 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
1U << static_cast<uint32_t>(
|
||||
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
|
||||
|
||||
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(
|
||||
builder_ = SampleUniquePtr<nvinfer1::IBuilder>(
|
||||
nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
|
||||
if (!builder) {
|
||||
if (!builder_) {
|
||||
FDERROR << "Failed to call createInferBuilder()." << std::endl;
|
||||
return false;
|
||||
}
|
||||
auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(
|
||||
builder->createNetworkV2(explicitBatch));
|
||||
if (!network) {
|
||||
network_ = SampleUniquePtr<nvinfer1::INetworkDefinition>(
|
||||
builder_->createNetworkV2(explicitBatch));
|
||||
if (!network_) {
|
||||
FDERROR << "Failed to call createNetworkV2()." << std::endl;
|
||||
return false;
|
||||
}
|
||||
auto config =
|
||||
SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
|
||||
auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(
|
||||
builder_->createBuilderConfig());
|
||||
if (!config) {
|
||||
FDERROR << "Failed to call createBuilderConfig()." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (option.enable_fp16) {
|
||||
if (!builder->platformHasFastFp16()) {
|
||||
if (!builder_->platformHasFastFp16()) {
|
||||
FDWARNING << "Detected FP16 is not supported in the current GPU, "
|
||||
"will use FP32 instead."
|
||||
<< std::endl;
|
||||
@@ -403,25 +403,25 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
}
|
||||
}
|
||||
|
||||
auto parser = SampleUniquePtr<nvonnxparser::IParser>(
|
||||
nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
|
||||
if (!parser) {
|
||||
parser_ = SampleUniquePtr<nvonnxparser::IParser>(
|
||||
nvonnxparser::createParser(*network_, sample::gLogger.getTRTLogger()));
|
||||
if (!parser_) {
|
||||
FDERROR << "Failed to call createParser()." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (!parser->parse(onnx_model.data(), onnx_model.size())) {
|
||||
if (!parser_->parse(onnx_model.data(), onnx_model.size())) {
|
||||
FDERROR << "Failed to parse ONNX model by TensorRT." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
FDINFO << "Start to building TensorRT Engine..." << std::endl;
|
||||
bool fp16 = builder->platformHasFastFp16();
|
||||
builder->setMaxBatchSize(option.max_batch_size);
|
||||
bool fp16 = builder_->platformHasFastFp16();
|
||||
builder_->setMaxBatchSize(option.max_batch_size);
|
||||
|
||||
config->setMaxWorkspaceSize(option.max_workspace_size);
|
||||
|
||||
if (option.max_shape.size() > 0) {
|
||||
auto profile = builder->createOptimizationProfile();
|
||||
auto profile = builder_->createOptimizationProfile();
|
||||
FDASSERT(option.max_shape.size() == option.min_shape.size() &&
|
||||
option.min_shape.size() == option.opt_shape.size(),
|
||||
"[TrtBackend] Size of max_shape/opt_shape/min_shape in "
|
||||
@@ -459,7 +459,7 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
}
|
||||
|
||||
SampleUniquePtr<IHostMemory> plan{
|
||||
builder->buildSerializedNetwork(*network, *config)};
|
||||
builder_->buildSerializedNetwork(*network_, *config)};
|
||||
if (!plan) {
|
||||
FDERROR << "Failed to call buildSerializedNetwork()." << std::endl;
|
||||
return false;
|
||||
|
@@ -85,6 +85,9 @@ class TrtBackend : public BaseBackend {
|
||||
private:
|
||||
std::shared_ptr<nvinfer1::ICudaEngine> engine_;
|
||||
std::shared_ptr<nvinfer1::IExecutionContext> context_;
|
||||
SampleUniquePtr<nvonnxparser::IParser> parser_;
|
||||
SampleUniquePtr<nvinfer1::IBuilder> builder_;
|
||||
SampleUniquePtr<nvinfer1::INetworkDefinition> network_;
|
||||
cudaStream_t stream_{};
|
||||
std::vector<void*> bindings_;
|
||||
std::vector<TrtValueInfo> inputs_desc_;
|
||||
|
@@ -53,7 +53,7 @@ bool FastDeployModel::InitRuntime() {
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
runtime_ = new Runtime();
|
||||
runtime_ = std::unique_ptr<Runtime>(new Runtime());
|
||||
if (!runtime_->Init(runtime_option)) {
|
||||
return false;
|
||||
}
|
||||
@@ -88,7 +88,7 @@ bool FastDeployModel::CreateCpuBackend() {
|
||||
continue;
|
||||
}
|
||||
runtime_option.backend = valid_cpu_backends[i];
|
||||
runtime_ = new Runtime();
|
||||
runtime_ = std::unique_ptr<Runtime>(new Runtime());
|
||||
if (!runtime_->Init(runtime_option)) {
|
||||
return false;
|
||||
}
|
||||
@@ -111,7 +111,7 @@ bool FastDeployModel::CreateGpuBackend() {
|
||||
continue;
|
||||
}
|
||||
runtime_option.backend = valid_gpu_backends[i];
|
||||
runtime_ = new Runtime();
|
||||
runtime_ = std::unique_ptr<Runtime>(new Runtime());
|
||||
if (!runtime_->Init(runtime_option)) {
|
||||
return false;
|
||||
}
|
||||
|
@@ -18,7 +18,7 @@ namespace fastdeploy {
|
||||
|
||||
class FASTDEPLOY_DECL FastDeployModel {
|
||||
public:
|
||||
virtual std::string ModelName() const { return "NameUndefined"; };
|
||||
virtual std::string ModelName() const { return "NameUndefined"; }
|
||||
|
||||
virtual bool InitRuntime();
|
||||
virtual bool CreateCpuBackend();
|
||||
@@ -47,21 +47,21 @@ class FASTDEPLOY_DECL FastDeployModel {
|
||||
virtual bool DebugEnabled();
|
||||
|
||||
private:
|
||||
Runtime* runtime_ = nullptr;
|
||||
std::unique_ptr<Runtime> runtime_;
|
||||
bool runtime_initialized_ = false;
|
||||
bool debug_ = false;
|
||||
};
|
||||
|
||||
#define TIMERECORD_START(id) \
|
||||
TimeCounter tc_##id; \
|
||||
#define TIMERECORD_START(id) \
|
||||
TimeCounter tc_##id; \
|
||||
tc_##id.Start();
|
||||
|
||||
#define TIMERECORD_END(id, prefix) \
|
||||
if (DebugEnabled()) { \
|
||||
tc_##id.End(); \
|
||||
FDLogger() << __FILE__ << "(" << __LINE__ << "):" << __FUNCTION__ << " " \
|
||||
<< prefix << " duration = " << tc_##id.Duration() << "s." \
|
||||
<< std::endl; \
|
||||
#define TIMERECORD_END(id, prefix) \
|
||||
if (DebugEnabled()) { \
|
||||
tc_##id.End(); \
|
||||
FDLogger() << __FILE__ << "(" << __LINE__ << "):" << __FUNCTION__ << " " \
|
||||
<< prefix << " duration = " << tc_##id.Duration() << "s." \
|
||||
<< std::endl; \
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
} // namespace fastdeploy
|
||||
|
@@ -12,7 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/ppdet/ppyoloe.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
bool BuildPreprocessPipelineFromConfig(
|
||||
std::vector<std::shared_ptr<Processor>>* processors,
|
||||
@@ -22,7 +27,7 @@ bool BuildPreprocessPipelineFromConfig(
|
||||
try {
|
||||
cfg = YAML::LoadFile(config_file);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file_
|
||||
FDERROR << "Failed to load yaml file " << config_file
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
@@ -76,3 +81,6 @@ bool BuildPreprocessPipelineFromConfig(
|
||||
processors->push_back(std::make_shared<HWC2CHW>());
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
@@ -13,7 +13,6 @@
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/ppdet/centernet.h"
|
||||
#include "fastdeploy/vision/ppdet/picodet.h"
|
||||
#include "fastdeploy/vision/ppdet/ppyolo.h"
|
||||
#include "fastdeploy/vision/ppdet/ppyoloe.h"
|
||||
|
@@ -27,21 +27,60 @@ void BindPPDet(pybind11::module& m) {
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
pybind11::class_<vision::ppdet::PPYOLO, vision::ppdet::PPYOLOE>(ppdet_module,
|
||||
"PPYOLO")
|
||||
|
||||
pybind11::class_<vision::ppdet::PPYOLO, FastDeployModel>(ppdet_module,
|
||||
"PPYOLO")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>());
|
||||
pybind11::class_<vision::ppdet::PicoDet, vision::ppdet::PPYOLOE>(ppdet_module,
|
||||
"PicoDet")
|
||||
Frontend>())
|
||||
.def("predict", [](vision::ppdet::PPYOLO& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::ppdet::PicoDet, FastDeployModel>(ppdet_module,
|
||||
"PicoDet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>());
|
||||
pybind11::class_<vision::ppdet::YOLOX, vision::ppdet::PPYOLOE>(ppdet_module,
|
||||
"YOLOX")
|
||||
Frontend>())
|
||||
.def("predict", [](vision::ppdet::PicoDet& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::ppdet::YOLOX, FastDeployModel>(ppdet_module, "YOLOX")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>());
|
||||
pybind11::class_<vision::ppdet::FasterRCNN, vision::ppdet::PPYOLOE>(
|
||||
ppdet_module, "FasterRCNN")
|
||||
Frontend>())
|
||||
.def("predict", [](vision::ppdet::YOLOX& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::ppdet::FasterRCNN, FastDeployModel>(ppdet_module,
|
||||
"FasterRCNN")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>());
|
||||
Frontend>())
|
||||
.def("predict",
|
||||
[](vision::ppdet::FasterRCNN& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::ppdet::YOLOv3, FastDeployModel>(ppdet_module,
|
||||
"YOLOv3")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def("predict", [](vision::ppdet::YOLOv3& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
}
|
||||
} // namespace fastdeploy
|
||||
|
@@ -49,14 +49,12 @@ bool PPYOLO::Initialize() {
|
||||
bool PPYOLO::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
int origin_w = mat->Width();
|
||||
int origin_h = mat->Height();
|
||||
mat->PrintInfo("Origin");
|
||||
for (size_t i = 0; i < processors_.size(); ++i) {
|
||||
if (!(*(processors_[i].get()))(mat)) {
|
||||
FDERROR << "Failed to process image data in " << processors_[i]->Name()
|
||||
<< "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
mat->PrintInfo(processors_[i]->Name());
|
||||
}
|
||||
|
||||
outputs->resize(3);
|
||||
|
@@ -139,14 +139,12 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
|
||||
bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
int origin_w = mat->Width();
|
||||
int origin_h = mat->Height();
|
||||
mat->PrintInfo("Origin");
|
||||
for (size_t i = 0; i < processors_.size(); ++i) {
|
||||
if (!(*(processors_[i].get()))(mat)) {
|
||||
FDERROR << "Failed to process image data in " << processors_[i]->Name()
|
||||
<< "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
mat->PrintInfo(processors_[i]->Name());
|
||||
}
|
||||
|
||||
outputs->resize(2);
|
||||
@@ -239,7 +237,6 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
processed_data[0].PrintInfo("Before infer");
|
||||
float* tmp = static_cast<float*>(processed_data[1].Data());
|
||||
std::vector<FDTensor> infer_result;
|
||||
if (!Infer(processed_data, &infer_result)) {
|
||||
@@ -248,8 +245,6 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
infer_result[0].PrintInfo("Boxes");
|
||||
infer_result[1].PrintInfo("Num");
|
||||
if (!Postprocess(infer_result, result)) {
|
||||
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
|
||||
<< std::endl;
|
||||
|
@@ -45,9 +45,6 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
|
||||
|
||||
protected:
|
||||
PPYOLOE() {}
|
||||
// This function will used to check if this model contains multiclass_nms
|
||||
// and get parameters from the operator
|
||||
void GetNmsInfo();
|
||||
|
||||
std::vector<std::shared_ptr<Processor>> processors_;
|
||||
std::string config_file_;
|
||||
|
@@ -50,7 +50,6 @@ bool FasterRCNN::Initialize() {
|
||||
bool FasterRCNN::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
int origin_w = mat->Width();
|
||||
int origin_h = mat->Height();
|
||||
mat->PrintInfo("Origin");
|
||||
float scale[2] = {1.0, 1.0};
|
||||
for (size_t i = 0; i < processors_.size(); ++i) {
|
||||
if (!(*(processors_[i].get()))(mat)) {
|
||||
@@ -62,7 +61,6 @@ bool FasterRCNN::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
scale[0] = mat->Height() * 1.0 / origin_h;
|
||||
scale[1] = mat->Width() * 1.0 / origin_w;
|
||||
}
|
||||
mat->PrintInfo(processors_[i]->Name());
|
||||
}
|
||||
|
||||
outputs->resize(3);
|
||||
@@ -78,9 +76,6 @@ bool FasterRCNN::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
mat->ShareWithTensor(&((*outputs)[1]));
|
||||
// reshape to [1, c, h, w]
|
||||
(*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1);
|
||||
(*outputs)[0].PrintInfo("im_shape");
|
||||
(*outputs)[1].PrintInfo("image");
|
||||
(*outputs)[2].PrintInfo("scale_factor");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@@ -35,14 +35,12 @@ YOLOv3::YOLOv3(const std::string& model_file, const std::string& params_file,
|
||||
bool YOLOv3::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
int origin_w = mat->Width();
|
||||
int origin_h = mat->Height();
|
||||
mat->PrintInfo("Origin");
|
||||
for (size_t i = 0; i < processors_.size(); ++i) {
|
||||
if (!(*(processors_[i].get()))(mat)) {
|
||||
FDERROR << "Failed to process image data in " << processors_[i]->Name()
|
||||
<< "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
mat->PrintInfo(processors_[i]->Name());
|
||||
}
|
||||
|
||||
outputs->resize(3);
|
||||
|
@@ -42,14 +42,12 @@ bool YOLOX::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
int origin_w = mat->Width();
|
||||
int origin_h = mat->Height();
|
||||
float scale[2] = {1.0, 1.0};
|
||||
mat->PrintInfo("Origin");
|
||||
for (size_t i = 0; i < processors_.size(); ++i) {
|
||||
if (!(*(processors_[i].get()))(mat)) {
|
||||
FDERROR << "Failed to process image data in " << processors_[i]->Name()
|
||||
<< "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
mat->PrintInfo(processors_[i]->Name());
|
||||
if (processors_[i]->Name().find("Resize") != std::string::npos) {
|
||||
scale[0] = mat->Height() * 1.0 / origin_h;
|
||||
scale[1] = mat->Width() * 1.0 / origin_w;
|
||||
|
@@ -16,6 +16,11 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
import paddle
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def add_dll_search_dir(dir_path):
|
||||
os.environ["path"] = dir_path + ";" + os.environ["path"]
|
||||
|
@@ -45,7 +45,7 @@ class PPYOLO(PPYOLOE):
|
||||
config_file,
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
super(PPYOLO, self).__init__(runtime_option)
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "PPYOLO model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.ppdet.PPYOLO(model_file, params_file,
|
||||
@@ -61,7 +61,7 @@ class YOLOX(PPYOLOE):
|
||||
config_file,
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
super(YOLOX, self).__init__(runtime_option)
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "YOLOX model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.ppdet.YOLOX(model_file, params_file,
|
||||
@@ -77,7 +77,7 @@ class PicoDet(PPYOLOE):
|
||||
config_file,
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
super(PicoDet, self).__init__(runtime_option)
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "PicoDet model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.ppdet.PicoDet(model_file, params_file,
|
||||
@@ -93,10 +93,26 @@ class FasterRCNN(PPYOLOE):
|
||||
config_file,
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
super(FasterRCNN, self).__init__(runtime_option)
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "FasterRCNN model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.ppdet.FasterRCNN(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "FasterRCNN model initialize failed."
|
||||
|
||||
|
||||
class YOLOv3(PPYOLOE):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file,
|
||||
config_file,
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == Frontend.PADDLE, "YOLOv3 model only support model format of Frontend.Paddle now."
|
||||
self._model = C.vision.ppdet.YOLOv3(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "YOLOv3 model initialize failed."
|
||||
|
@@ -1,52 +0,0 @@
|
||||
# PaddleDetection/PPYOLOE部署示例
|
||||
|
||||
- 当前支持PaddleDetection版本为[release/2.4](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4)
|
||||
|
||||
本文档说明如何进行[PPYOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)的快速部署推理。本目录结构如下
|
||||
```
|
||||
.
|
||||
├── cpp # C++ 代码目录
|
||||
│ ├── CMakeLists.txt # C++ 代码编译CMakeLists文件
|
||||
│ ├── README.md # C++ 代码编译部署文档
|
||||
│ └── ppyoloe.cc # C++ 示例代码
|
||||
├── README.md # PPYOLOE 部署文档
|
||||
└── ppyoloe.py # Python示例代码
|
||||
```
|
||||
|
||||
## 安装FastDeploy
|
||||
|
||||
使用如下命令安装FastDeploy,注意到此处安装的是`vision-cpu`,也可根据需求安装`vision-gpu`
|
||||
```
|
||||
# 安装fastdeploy-python工具
|
||||
pip install fastdeploy-python
|
||||
```
|
||||
|
||||
## Python部署
|
||||
|
||||
执行如下代码即会自动下载PPYOLOE模型和测试图片
|
||||
```
|
||||
python ppyoloe.py
|
||||
```
|
||||
|
||||
执行完成后会将可视化结果保存在本地`vis_result.jpg`,同时输出检测结果如下
|
||||
```
|
||||
DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
|
||||
162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33
|
||||
414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0
|
||||
163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0
|
||||
267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0
|
||||
581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0
|
||||
104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0
|
||||
348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0
|
||||
364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0
|
||||
75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56
|
||||
328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0
|
||||
504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0
|
||||
379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0
|
||||
25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0
|
||||
```
|
||||
|
||||
## 其它文档
|
||||
|
||||
- [C++部署](./cpp/README.md)
|
||||
- [PPYOLOE API文档](./api.md)
|
@@ -1,68 +0,0 @@
|
||||
# PPYOLOE API说明
|
||||
|
||||
## Python API
|
||||
|
||||
### PPYOLOE类
|
||||
```
|
||||
fastdeploy.vision.ppdet.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=fd.Frontend.PADDLE)
|
||||
```
|
||||
PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径
|
||||
> * **config_file**(str): 模型推理配置文件
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式
|
||||
|
||||
#### predict函数
|
||||
> ```
|
||||
> PPYOLOE.predict(image_data)
|
||||
> ```
|
||||
> 模型预测结口,输入图像直接输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式
|
||||
|
||||
示例代码参考[ppyoloe.py](./ppyoloe.py)
|
||||
|
||||
|
||||
## C++ API
|
||||
|
||||
### PPYOLOE类
|
||||
```
|
||||
fastdeploy::vision::ppdet::PPYOLOE(
|
||||
const string& model_file,
|
||||
const string& params_file,
|
||||
const string& config_file,
|
||||
const RuntimeOption& runtime_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE)
|
||||
```
|
||||
PPYOLOE模型加载和初始化,需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径
|
||||
> * **config_file**(str): 模型推理配置文件
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式
|
||||
|
||||
#### Predict函数
|
||||
> ```
|
||||
> PPYOLOE::Predict(cv::Mat* im, DetectionResult* result)
|
||||
> ```
|
||||
> 模型预测接口,输入图像直接输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **im**: 输入图像,注意需为HWC,BGR格式
|
||||
> > * **result**: 检测结果,包括检测框,各个框的置信度
|
||||
|
||||
示例代码参考[cpp/ppyoloe.cc](cpp/ppyoloe.cc)
|
||||
|
||||
## 其它API使用
|
||||
|
||||
- [模型部署RuntimeOption配置](../../../docs/api/runtime_option.md)
|
@@ -1,17 +0,0 @@
|
||||
PROJECT(ppyoloe_demo C CXX)
|
||||
CMAKE_MINIMUM_REQUIRED (VERSION 3.16)
|
||||
|
||||
# 在低版本ABI环境中,通过如下代码进行兼容性编译
|
||||
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||
|
||||
# 指定下载解压后的fastdeploy库路径
|
||||
set(FASTDEPLOY_INSTALL_DIR /fastdeploy/CustomOp/FastDeploy/build1/fastdeploy-linux-x64-gpu-0.3.0)
|
||||
|
||||
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||
|
||||
# 添加FastDeploy依赖头文件
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
add_executable(ppyoloe_demo ${PROJECT_SOURCE_DIR}/ppyoloe.cc)
|
||||
# 添加FastDeploy库依赖
|
||||
target_link_libraries(ppyoloe_demo ${FASTDEPLOY_LIBS})
|
@@ -1,39 +0,0 @@
|
||||
# 编译PPYOLOE示例
|
||||
|
||||
|
||||
```
|
||||
# 下载和解压预测库
|
||||
wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
|
||||
tar xvf fastdeploy-linux-x64-0.0.3.tgz
|
||||
|
||||
# 编译示例代码
|
||||
mkdir build & cd build
|
||||
cmake ..
|
||||
make -j
|
||||
|
||||
# 下载模型和图片
|
||||
wget https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz
|
||||
tar xvf ppyoloe_crn_l_300e_coco.tgz
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg
|
||||
|
||||
# 执行
|
||||
./ppyoloe_demo
|
||||
```
|
||||
|
||||
执行完后可视化的结果保存在本地`vis_result.jpg`,同时会将检测框输出在终端,如下所示
|
||||
```
|
||||
DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
|
||||
162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33
|
||||
414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0
|
||||
163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0
|
||||
267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0
|
||||
581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0
|
||||
104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0
|
||||
348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0
|
||||
364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0
|
||||
75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56
|
||||
328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0
|
||||
504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0
|
||||
379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0
|
||||
25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0
|
||||
```
|
@@ -1,51 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision.h"
|
||||
|
||||
int main() {
|
||||
namespace vis = fastdeploy::vision;
|
||||
|
||||
std::string model_file = "ppyoloe_crn_l_300e_coco/model.pdmodel";
|
||||
std::string params_file = "ppyoloe_crn_l_300e_coco/model.pdiparams";
|
||||
std::string config_file = "ppyoloe_crn_l_300e_coco/infer_cfg.yml";
|
||||
std::string img_path = "000000014439_640x640.jpg";
|
||||
std::string vis_path = "vis.jpeg";
|
||||
|
||||
auto model = vis::ppdet::PPYOLOE(model_file, params_file, config_file);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Init Failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
cv::Mat im = cv::imread(img_path);
|
||||
cv::Mat vis_im = im.clone();
|
||||
|
||||
vis::DetectionResult res;
|
||||
if (!model.Predict(&im, &res)) {
|
||||
std::cerr << "Prediction Failed." << std::endl;
|
||||
return -1;
|
||||
} else {
|
||||
std::cout << "Prediction Done!" << std::endl;
|
||||
}
|
||||
|
||||
// 输出预测框结果
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
// 可视化预测结果
|
||||
vis::Visualize::VisDetection(&vis_im, res);
|
||||
cv::imwrite(vis_path, vis_im);
|
||||
std::cout << "Detect Done! Saved: " << vis_path << std::endl;
|
||||
return 0;
|
||||
}
|
@@ -1,24 +0,0 @@
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
|
||||
# 下载模型和测试图片
|
||||
model_url = "https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz"
|
||||
test_jpg_url = "https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg"
|
||||
fd.download_and_decompress(model_url, ".")
|
||||
fd.download(test_jpg_url, ".", show_progress=True)
|
||||
|
||||
# 加载模型
|
||||
model = fd.vision.ppdet.PPYOLOE("ppyoloe_crn_l_300e_coco/model.pdmodel",
|
||||
"ppyoloe_crn_l_300e_coco/model.pdiparams",
|
||||
"ppyoloe_crn_l_300e_coco/infer_cfg.yml")
|
||||
|
||||
# 预测图片
|
||||
im = cv2.imread("000000014439_640x640.jpg")
|
||||
result = model.predict(im)
|
||||
|
||||
# 可视化结果
|
||||
fd.vision.visualize.vis_detection(im, result)
|
||||
cv2.imwrite("vis_result.jpg", im)
|
||||
|
||||
# 输出预测结果
|
||||
print(result)
|
13
setup.py
13
setup.py
@@ -371,9 +371,13 @@ if sys.argv[1] == "install" or sys.argv[1] == "bdist_wheel":
|
||||
for f1 in os.listdir(lib_dir_name):
|
||||
release_dir = os.path.join(lib_dir_name, f1)
|
||||
if f1 == "Release" and not os.path.isfile(release_dir):
|
||||
if os.path.exists(os.path.join("fastdeploy/libs/third_libs", f)):
|
||||
shutil.rmtree(os.path.join("fastdeploy/libs/third_libs", f))
|
||||
shutil.copytree(release_dir, os.path.join("fastdeploy/libs/third_libs", f, "lib"))
|
||||
if os.path.exists(
|
||||
os.path.join("fastdeploy/libs/third_libs", f)):
|
||||
shutil.rmtree(
|
||||
os.path.join("fastdeploy/libs/third_libs", f))
|
||||
shutil.copytree(release_dir,
|
||||
os.path.join("fastdeploy/libs/third_libs",
|
||||
f, "lib"))
|
||||
|
||||
if platform.system().lower() == "windows":
|
||||
release_dir = os.path.join(".setuptools-cmake-build", "Release")
|
||||
@@ -398,6 +402,9 @@ if sys.argv[1] == "install" or sys.argv[1] == "bdist_wheel":
|
||||
path))
|
||||
rpaths = ":".join(rpaths)
|
||||
command = "patchelf --set-rpath '{}' ".format(rpaths) + pybind_so_file
|
||||
print(
|
||||
"=========================Set rpath for library===================")
|
||||
print(command)
|
||||
# The sw_64 not suppot patchelf, so we just disable that.
|
||||
if platform.machine() != 'sw_64' and platform.machine() != 'mips64':
|
||||
assert os.system(
|
||||
|
Reference in New Issue
Block a user