mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Model] Add Picodet RKNPU2 (#635)
* * 更新picodet cpp代码 * * 更新文档 * 更新picodet cpp example * * 删除无用的debug代码 * 新增python example * * 修改c++代码 * * 修改python代码 * * 修改postprocess代码 * 修复没有scale_factor导致的bug * 修复错误 * 更正代码格式 * 更正代码格式
This commit is contained in:
@@ -13,7 +13,9 @@ RKNPU部署模型前需要将Paddle模型转换成RKNN模型,具体步骤如
|
||||
|
||||
|
||||
## 模型转换example
|
||||
下面以Picodet-npu为例子,教大家如何转换PaddleDetection模型到RKNN模型。
|
||||
以下步骤均在Ubuntu电脑上完成,请参考配置文档完成转换模型环境配置。下面以Picodet-s为例子,教大家如何转换PaddleDetection模型到RKNN模型。
|
||||
|
||||
### 导出ONNX模型
|
||||
```bash
|
||||
# 下载Paddle静态图模型并解压
|
||||
wget https://paddledet.bj.bcebos.com/deploy/Inference/picodet_s_416_coco_lcnet.tar
|
||||
@@ -26,12 +28,89 @@ paddle2onnx --model_dir picodet_s_416_coco_lcnet \
|
||||
--save_file picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \
|
||||
--enable_dev_version True
|
||||
|
||||
# 固定shape
|
||||
python -m paddle2onnx.optimize --input_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \
|
||||
--output_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \
|
||||
--input_shape_dict "{'image':[1,3,416,416]}"
|
||||
```
|
||||
|
||||
### 编写模型导出配置文件
|
||||
以转化RK3568的RKNN模型为例子,我们需要编辑tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml,来转换ONNX模型到RKNN模型。
|
||||
|
||||
**修改normalize参数**
|
||||
|
||||
如果你需要在NPU上执行normalize操作,请根据你的模型配置normalize参数,例如:
|
||||
```yaml
|
||||
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
|
||||
output_folder: ./picodet_s_416_coco_lcnet
|
||||
target_platform: RK3568
|
||||
normalize:
|
||||
mean: [[0.485,0.456,0.406],[0,0,0]]
|
||||
std: [[0.229,0.224,0.225],[0.003921,0.003921]]
|
||||
outputs: ['tmp_17','p2o.Concat.9']
|
||||
```
|
||||
|
||||
**修改outputs参数**
|
||||
由于Paddle2ONNX版本的不同,转换模型的输出节点名称也有所不同,请使用[Netron](https://netron.app),并找到以下蓝色方框标记的NonMaxSuppression节点,红色方框的节点名称即为目标名称。
|
||||
|
||||
例如,使用Netron可视化后,得到以下图片:
|
||||

|
||||
|
||||
找到蓝色方框标记的NonMaxSuppression节点,可以看到红色方框标记的两个节点名称为tmp_17和p2o.Concat.9,因此需要修改outputs参数,修改后如下:
|
||||
```yaml
|
||||
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
|
||||
output_folder: ./picodet_s_416_coco_lcnet
|
||||
target_platform: RK3568
|
||||
normalize: None
|
||||
outputs: ['tmp_17','p2o.Concat.9']
|
||||
```
|
||||
|
||||
### 转换模型
|
||||
```bash
|
||||
|
||||
# ONNX模型转RKNN模型
|
||||
# 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下
|
||||
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml
|
||||
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml
|
||||
```
|
||||
|
||||
### 修改模型运行时的配置文件
|
||||
|
||||
配置文件中,我们只需要修改**Preprocess**下的**Normalize**和**Permute**.
|
||||
|
||||
**删除Permute**
|
||||
|
||||
RKNPU只支持NHWC的输入格式,因此需要删除Permute操作.删除后,配置文件Precess部分后如下:
|
||||
```yaml
|
||||
Preprocess:
|
||||
- interp: 2
|
||||
keep_ratio: false
|
||||
target_size:
|
||||
- 416
|
||||
- 416
|
||||
type: Resize
|
||||
- is_scale: true
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
type: NormalizeImage
|
||||
```
|
||||
|
||||
**根据模型转换文件决定是否删除Normalize**
|
||||
|
||||
RKNPU支持使用NPU进行Normalize操作,如果你在导出模型时配置了Normalize参数,请删除**Normalize**.删除后配置文件Precess部分如下:
|
||||
```yaml
|
||||
Preprocess:
|
||||
- interp: 2
|
||||
keep_ratio: false
|
||||
target_size:
|
||||
- 416
|
||||
- 416
|
||||
type: Resize
|
||||
```
|
||||
|
||||
- [Python部署](./python)
|
||||
|
@@ -33,5 +33,5 @@ install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTIN
|
||||
file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*)
|
||||
install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib)
|
||||
|
||||
file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*)
|
||||
file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/${RKNN2_TARGET_SOC}/lib/*)
|
||||
install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib)
|
||||
|
@@ -62,7 +62,7 @@ make install
|
||||
|
||||
```bash
|
||||
cd ./build/install
|
||||
./rknpu_test
|
||||
./infer_picodet model/picodet_s_416_coco_lcnet images/000000014439.jpg
|
||||
```
|
||||
|
||||
|
||||
|
@@ -14,73 +14,53 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include "fastdeploy/vision.h"
|
||||
#include <sys/time.h>
|
||||
double __get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
|
||||
void InferPicodet(const std::string& model_dir, const std::string& image_file);
|
||||
|
||||
void InferPicodet(const std::string& device = "cpu");
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 3) {
|
||||
std::cout
|
||||
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
|
||||
"e.g ./infer_model ./picodet_model_dir ./test.jpeg"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
InferPicodet(argv[1], argv[2]);
|
||||
|
||||
int main() {
|
||||
InferPicodet("npu");
|
||||
return 0;
|
||||
}
|
||||
|
||||
fastdeploy::RuntimeOption GetOption(const std::string& device) {
|
||||
void InferPicodet(const std::string& model_dir, const std::string& image_file) {
|
||||
struct timeval start_time, stop_time;
|
||||
auto model_file = model_dir + "/picodet_s_416_coco_lcnet_rk3568.rknn";
|
||||
auto params_file = "";
|
||||
auto config_file = model_dir + "/infer_cfg.yml";
|
||||
|
||||
auto option = fastdeploy::RuntimeOption();
|
||||
if (device == "npu") {
|
||||
option.UseRKNPU2();
|
||||
} else {
|
||||
option.UseCpu();
|
||||
}
|
||||
return option;
|
||||
}
|
||||
|
||||
fastdeploy::ModelFormat GetFormat(const std::string& device) {
|
||||
auto format = fastdeploy::ModelFormat::ONNX;
|
||||
if (device == "npu") {
|
||||
format = fastdeploy::ModelFormat::RKNN;
|
||||
} else {
|
||||
format = fastdeploy::ModelFormat::ONNX;
|
||||
}
|
||||
return format;
|
||||
}
|
||||
auto format = fastdeploy::ModelFormat::RKNN;
|
||||
|
||||
std::string GetModelPath(std::string& model_path, const std::string& device) {
|
||||
if (device == "npu") {
|
||||
model_path += "rknn";
|
||||
} else {
|
||||
model_path += "onnx";
|
||||
}
|
||||
return model_path;
|
||||
}
|
||||
|
||||
void InferPicodet(const std::string &device) {
|
||||
std::string model_file = "./model/picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet_rk3588.";
|
||||
std::string params_file;
|
||||
std::string config_file = "./model/picodet_s_416_coco_lcnet/infer_cfg.yml";
|
||||
|
||||
fastdeploy::RuntimeOption option = GetOption(device);
|
||||
fastdeploy::ModelFormat format = GetFormat(device);
|
||||
model_file = GetModelPath(model_file, device);
|
||||
auto model = fastdeploy::vision::detection::RKPicoDet(
|
||||
auto model = fastdeploy::vision::detection::PicoDet(
|
||||
model_file, params_file, config_file,option,format);
|
||||
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
auto image_file = "./images/000000014439.jpg";
|
||||
model.GetPostprocessor().ApplyDecodeAndNMS();
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
|
||||
fastdeploy::vision::DetectionResult res;
|
||||
clock_t start = clock();
|
||||
gettimeofday(&start_time, NULL);
|
||||
if (!model.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
clock_t end = clock();
|
||||
auto dur = static_cast<double>(end - start);
|
||||
printf("picodet_npu use time:%f\n", (dur / CLOCKS_PER_SEC));
|
||||
gettimeofday(&stop_time, NULL);
|
||||
printf("infer use %f ms\n", (__get_us(stop_time) - __get_us(start_time)) / 1000);
|
||||
|
||||
std::cout << res.Str() << std::endl;
|
||||
auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5);
|
||||
cv::imwrite("picodet_npu_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./picodet_npu_result.jpg" << std::endl;
|
||||
cv::imwrite("picodet_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./picodet_result.jpg" << std::endl;
|
||||
}
|
@@ -15,11 +15,11 @@ cd FastDeploy/examples/vision/detection/paddledetection/rknpu2/python
|
||||
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
||||
|
||||
# copy model
|
||||
cp -r ./picodet_s_416_coco_npu /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python
|
||||
cp -r ./picodet_s_416_coco_lcnet /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python
|
||||
|
||||
# 推理
|
||||
python3 infer.py --model_file ./picodet_s_416_coco_npu/picodet_s_416_coco_npu_3588.rknn \
|
||||
--config_file ./picodet_s_416_coco_npu/infer_cfg.yml \
|
||||
python3 infer.py --model_file ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet_rk3568.rknn \
|
||||
--config_file ./picodet_s_416_coco_lcnet/infer_cfg.yml \
|
||||
--image 000000014439.jpg
|
||||
```
|
||||
|
||||
|
@@ -28,32 +28,32 @@ def parse_arguments():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
option.use_rknpu2()
|
||||
return option
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
|
||||
model_file = args.model_file
|
||||
params_file = ""
|
||||
config_file = args.config_file
|
||||
|
||||
args = parse_arguments()
|
||||
# 配置runtime,加载模型
|
||||
runtime_option = fd.RuntimeOption()
|
||||
runtime_option.use_rknpu2()
|
||||
|
||||
# 配置runtime,加载模型
|
||||
runtime_option = build_option(args)
|
||||
model_file = args.model_file
|
||||
params_file = ""
|
||||
config_file = args.config_file
|
||||
model = fd.vision.detection.RKPicoDet(
|
||||
model = fd.vision.detection.PicoDet(
|
||||
model_file,
|
||||
params_file,
|
||||
config_file,
|
||||
runtime_option=runtime_option,
|
||||
model_format=fd.ModelFormat.RKNN)
|
||||
|
||||
# 预测图片分割结果
|
||||
im = cv2.imread(args.image)
|
||||
result = model.predict(im.copy())
|
||||
print(result)
|
||||
model.postprocessor.apply_decode_and_nms()
|
||||
|
||||
# 可视化结果
|
||||
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
|
||||
cv2.imwrite("visualized_result.jpg", vis_im)
|
||||
print("Visualized result save in ./visualized_result.jpg")
|
||||
# 预测图片分割结果
|
||||
im = cv2.imread(args.image)
|
||||
result = model.predict(im.copy())
|
||||
print(result)
|
||||
|
||||
# 可视化结果
|
||||
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
|
||||
cv2.imwrite("visualized_result.jpg", vis_im)
|
||||
print("Visualized result save in ./visualized_result.jpg")
|
||||
|
@@ -6,10 +6,12 @@ namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
PPDetBase::PPDetBase(const std::string& model_file, const std::string& params_file,
|
||||
PPDetBase::PPDetBase(const std::string& model_file,
|
||||
const std::string& params_file,
|
||||
const std::string& config_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const ModelFormat& model_format) : preprocessor_(config_file) {
|
||||
const ModelFormat& model_format)
|
||||
: preprocessor_(config_file) {
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
@@ -37,7 +39,8 @@ bool PPDetBase::Predict(const cv::Mat& im, DetectionResult* result) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<DetectionResult>* results) {
|
||||
bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs,
|
||||
std::vector<DetectionResult>* results) {
|
||||
std::vector<FDMat> fd_images = WrapMat(imgs);
|
||||
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
|
||||
FDERROR << "Failed to preprocess the input image." << std::endl;
|
||||
@@ -46,8 +49,13 @@ bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<Detec
|
||||
reused_input_tensors_[0].name = "image";
|
||||
reused_input_tensors_[1].name = "scale_factor";
|
||||
reused_input_tensors_[2].name = "im_shape";
|
||||
// Some models don't need im_shape as input
|
||||
if (NumInputsOfRuntime() == 2) {
|
||||
|
||||
if(postprocessor_.DecodeAndNMSApplied()){
|
||||
postprocessor_.SetScaleFactor(static_cast<float*>(reused_input_tensors_[1].Data()));
|
||||
}
|
||||
|
||||
// Some models don't need scale_factor and im_shape as input
|
||||
while (reused_input_tensors_.size() != NumInputsOfRuntime()) {
|
||||
reused_input_tensors_.pop_back();
|
||||
}
|
||||
|
||||
@@ -57,7 +65,8 @@ bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<Detec
|
||||
}
|
||||
|
||||
if (!postprocessor_.Run(reused_output_tensors_, results)) {
|
||||
FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
|
||||
FDERROR << "Failed to postprocess the inference results by runtime."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@@ -38,6 +38,7 @@ class FASTDEPLOY_DECL PicoDet : public PPDetBase {
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
|
||||
Backend::PDINFER, Backend::LITE};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
|
||||
valid_rknpu_backends = {Backend::RKNPU2};
|
||||
initialized = Initialize();
|
||||
}
|
||||
|
||||
|
@@ -19,10 +19,12 @@ namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<DetectionResult>* results) {
|
||||
bool PaddleDetPostprocessor::ProcessMask(
|
||||
const FDTensor& tensor, std::vector<DetectionResult>* results) {
|
||||
auto shape = tensor.Shape();
|
||||
if (tensor.Dtype() != FDDataType::INT32) {
|
||||
FDERROR << "The data type of out mask tensor should be INT32, but now it's " << tensor.Dtype() << std::endl;
|
||||
FDERROR << "The data type of out mask tensor should be INT32, but now it's "
|
||||
<< tensor.Dtype() << std::endl;
|
||||
return false;
|
||||
}
|
||||
int64_t out_mask_h = shape[1];
|
||||
@@ -46,7 +48,8 @@ bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<Det
|
||||
(*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w};
|
||||
const int32_t* current_ptr = data + index * out_mask_numel;
|
||||
|
||||
int32_t* keep_mask_ptr = reinterpret_cast<int32_t*>((*results)[i].masks[j].Data());
|
||||
int32_t* keep_mask_ptr =
|
||||
reinterpret_cast<int32_t*>((*results)[i].masks[j].Data());
|
||||
for (int row = y1; row < y2; ++row) {
|
||||
size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
|
||||
const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1;
|
||||
@@ -59,7 +62,20 @@ bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<Det
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vector<DetectionResult>* results) {
|
||||
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results) {
|
||||
if (DecodeAndNMSApplied()) {
|
||||
FDASSERT(tensors.size() == 2,
|
||||
"While postprocessing with ApplyDecodeAndNMS, "
|
||||
"there should be 2 outputs for this model, but now it's %zu.",
|
||||
tensors.size());
|
||||
FDASSERT(tensors[0].shape.size() == 3,
|
||||
"While postprocessing with ApplyDecodeAndNMS, "
|
||||
"the rank of the first outputs should be 3, but now it's %zu",
|
||||
tensors[0].shape.size());
|
||||
return ProcessUnDecodeResults(tensors, results);
|
||||
}
|
||||
|
||||
if (tensors[0].shape[0] == 0) {
|
||||
// No detected boxes
|
||||
return true;
|
||||
@@ -69,13 +85,13 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
|
||||
std::vector<int> num_boxes(tensors[1].shape[0]);
|
||||
int total_num_boxes = 0;
|
||||
if (tensors[1].dtype == FDDataType::INT32) {
|
||||
const int32_t* data = static_cast<const int32_t*>(tensors[1].CpuData());
|
||||
const auto* data = static_cast<const int32_t*>(tensors[1].CpuData());
|
||||
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
total_num_boxes += num_boxes[i];
|
||||
}
|
||||
} else if (tensors[1].dtype == FDDataType::INT64) {
|
||||
const int64_t* data = static_cast<const int64_t*>(tensors[1].CpuData());
|
||||
const auto* data = static_cast<const int64_t*>(tensors[1].CpuData());
|
||||
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
}
|
||||
@@ -83,33 +99,37 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
|
||||
|
||||
// Special case for TensorRT, it has fixed output shape of NMS
|
||||
// So there's invalid boxes in its' output boxes
|
||||
int num_output_boxes = tensors[0].Shape()[0];
|
||||
int num_output_boxes = static_cast<int>(tensors[0].Shape()[0]);
|
||||
bool contain_invalid_boxes = false;
|
||||
if (total_num_boxes != num_output_boxes) {
|
||||
if (num_output_boxes % num_boxes.size() == 0) {
|
||||
contain_invalid_boxes = true;
|
||||
} else {
|
||||
FDERROR << "Cannot handle the output data for this model, unexpected situation." << std::endl;
|
||||
FDERROR << "Cannot handle the output data for this model, unexpected "
|
||||
"situation."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Get boxes for each input image
|
||||
results->resize(num_boxes.size());
|
||||
const float* box_data = static_cast<const float*>(tensors[0].CpuData());
|
||||
const auto* box_data = static_cast<const float*>(tensors[0].CpuData());
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < num_boxes.size(); ++i) {
|
||||
const float* ptr = box_data + offset;
|
||||
(*results)[i].Reserve(num_boxes[i]);
|
||||
for (size_t j = 0; j < num_boxes[i]; ++j) {
|
||||
(*results)[i].label_ids.push_back(static_cast<int32_t>(round(ptr[j * 6])));
|
||||
(*results)[i].label_ids.push_back(
|
||||
static_cast<int32_t>(round(ptr[j * 6])));
|
||||
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
||||
(*results)[i].boxes.emplace_back(std::array<float, 4>({ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
|
||||
(*results)[i].boxes.emplace_back(std::array<float, 4>(
|
||||
{ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
|
||||
}
|
||||
if (contain_invalid_boxes) {
|
||||
offset += (num_output_boxes * 6 / num_boxes.size());
|
||||
offset += static_cast<int>(num_output_boxes * 6 / num_boxes.size());
|
||||
} else {
|
||||
offset += (num_boxes[i] * 6);
|
||||
offset += static_cast<int>(num_boxes[i] * 6);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +139,10 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
|
||||
}
|
||||
|
||||
if (tensors[2].Shape()[0] != num_output_boxes) {
|
||||
FDERROR << "The first dimension of output mask tensor:" << tensors[2].Shape()[0] << " is not equal to the first dimension of output boxes tensor:" << num_output_boxes << "." << std::endl;
|
||||
FDERROR << "The first dimension of output mask tensor:"
|
||||
<< tensors[2].Shape()[0]
|
||||
<< " is not equal to the first dimension of output boxes tensor:"
|
||||
<< num_output_boxes << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -127,6 +150,80 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
|
||||
return ProcessMask(tensors[2], results);
|
||||
}
|
||||
|
||||
void PaddleDetPostprocessor::ApplyDecodeAndNMS() {
|
||||
apply_decode_and_nms_ = true;
|
||||
}
|
||||
|
||||
bool PaddleDetPostprocessor::ProcessUnDecodeResults(
|
||||
const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results) {
|
||||
if (tensors.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int boxes_index = 0;
|
||||
int scores_index = 1;
|
||||
if (tensors[0].shape[1] == tensors[1].shape[2]) {
|
||||
boxes_index = 0;
|
||||
scores_index = 1;
|
||||
} else if (tensors[0].shape[2] == tensors[1].shape[1]) {
|
||||
boxes_index = 1;
|
||||
scores_index = 0;
|
||||
} else {
|
||||
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
|
||||
"4], [batch, classes_num, boxes_num]"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
backend::MultiClassNMS nms;
|
||||
nms.background_label = -1;
|
||||
nms.keep_top_k = 100;
|
||||
nms.nms_eta = 1.0;
|
||||
nms.nms_threshold = 0.5;
|
||||
nms.score_threshold = 0.3;
|
||||
nms.nms_top_k = 1000;
|
||||
nms.normalized = true;
|
||||
nms.Compute(static_cast<const float*>(tensors[boxes_index].Data()),
|
||||
static_cast<const float*>(tensors[scores_index].Data()),
|
||||
tensors[boxes_index].shape, tensors[scores_index].shape);
|
||||
|
||||
auto num_boxes = nms.out_num_rois_data;
|
||||
auto box_data = static_cast<const float*>(nms.out_box_data.data());
|
||||
// Get boxes for each input image
|
||||
results->resize(num_boxes.size());
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < num_boxes.size(); ++i) {
|
||||
const float* ptr = box_data + offset;
|
||||
(*results)[i].Reserve(num_boxes[i]);
|
||||
for (size_t j = 0; j < num_boxes[i]; ++j) {
|
||||
(*results)[i].label_ids.push_back(
|
||||
static_cast<int32_t>(round(ptr[j * 6])));
|
||||
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
||||
(*results)[i].boxes.emplace_back(std::array<float, 4>(
|
||||
{ptr[j * 6 + 2] / GetScaleFactor()[1],
|
||||
ptr[j * 6 + 3] / GetScaleFactor()[0],
|
||||
ptr[j * 6 + 4] / GetScaleFactor()[1],
|
||||
ptr[j * 6 + 5] / GetScaleFactor()[0]}));
|
||||
}
|
||||
offset += (num_boxes[i] * 6);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<float> PaddleDetPostprocessor::GetScaleFactor(){
|
||||
return scale_factor_;
|
||||
}
|
||||
|
||||
void PaddleDetPostprocessor::SetScaleFactor(float* scale_factor_value){
|
||||
for (int i = 0; i < scale_factor_.size(); ++i) {
|
||||
scale_factor_[i] = scale_factor_value[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool PaddleDetPostprocessor::DecodeAndNMSApplied() {
|
||||
return apply_decode_and_nms_;
|
||||
}
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
@@ -33,10 +33,27 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
|
||||
*/
|
||||
bool Run(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* result);
|
||||
|
||||
/// Apply box decoding and nms step for the outputs for the model.This is
|
||||
/// only available for those model exported without box decoding and nms.
|
||||
void ApplyDecodeAndNMS();
|
||||
|
||||
bool DecodeAndNMSApplied();
|
||||
|
||||
/// Set scale_factor_ value.This is only available for those model exported
|
||||
/// without box decoding and nms.
|
||||
void SetScaleFactor(float* scale_factor_value);
|
||||
|
||||
private:
|
||||
// Process mask tensor for MaskRCNN
|
||||
bool ProcessMask(const FDTensor& tensor,
|
||||
std::vector<DetectionResult>* results);
|
||||
|
||||
bool apply_decode_and_nms_ = false;
|
||||
std::vector<float> scale_factor_{1.0, 1.0};
|
||||
std::vector<float> GetScaleFactor();
|
||||
bool ProcessUnDecodeResults(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results);
|
||||
};
|
||||
|
||||
} // namespace detection
|
||||
|
@@ -43,6 +43,10 @@ void BindPPDet(pybind11::module& m) {
|
||||
}
|
||||
return results;
|
||||
})
|
||||
.def("apply_decode_and_nms",
|
||||
[](vision::detection::PaddleDetPostprocessor& self){
|
||||
self.ApplyDecodeAndNMS();
|
||||
})
|
||||
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<pybind11::array>& input_array) {
|
||||
std::vector<vision::DetectionResult> results;
|
||||
std::vector<FDTensor> inputs;
|
||||
|
@@ -22,11 +22,13 @@ namespace vision {
|
||||
namespace detection {
|
||||
|
||||
PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) {
|
||||
FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleDetPreprocessor.");
|
||||
FDASSERT(BuildPreprocessPipelineFromConfig(config_file),
|
||||
"Failed to create PaddleDetPreprocessor.");
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) {
|
||||
bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(
|
||||
const std::string& config_file) {
|
||||
processors_.clear();
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
@@ -106,8 +108,6 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string&
|
||||
// permute = cast<float> + HWC2CHW
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
} else {
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
}
|
||||
|
||||
// Fusion will improve performance
|
||||
@@ -116,13 +116,15 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string&
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
|
||||
bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
if (!initialized_) {
|
||||
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (images->size() == 0) {
|
||||
FDERROR << "The size of input images should be greater than 0." << std::endl;
|
||||
FDERROR << "The size of input images should be greater than 0."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -140,7 +142,8 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor
|
||||
// All the tensor will pad to the max size to compose a batched tensor
|
||||
std::vector<int> max_hw({-1, -1});
|
||||
|
||||
float* scale_factor_ptr = reinterpret_cast<float*>((*outputs)[1].MutableData());
|
||||
float* scale_factor_ptr =
|
||||
reinterpret_cast<float*>((*outputs)[1].MutableData());
|
||||
float* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
|
||||
for (size_t i = 0; i < images->size(); ++i) {
|
||||
int origin_w = (*images)[i].Width();
|
||||
@@ -149,7 +152,8 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor
|
||||
scale_factor_ptr[2 * i + 1] = 1.0;
|
||||
for (size_t j = 0; j < processors_.size(); ++j) {
|
||||
if (!(*(processors_[j].get()))(&((*images)[i]))) {
|
||||
FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl;
|
||||
FDERROR << "Failed to processs image:" << i << " in "
|
||||
<< processors_[i]->Name() << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (processors_[j]->Name().find("Resize") != std::string::npos) {
|
||||
@@ -174,7 +178,10 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor
|
||||
// if the size of image less than max_hw, pad to max_hw
|
||||
FDTensor tensor;
|
||||
(*images)[i].ShareWithTensor(&tensor);
|
||||
function::Pad(tensor, &(im_tensors[i]), {0, 0, max_hw[0] - (*images)[i].Height(), max_hw[1] - (*images)[i].Width()}, 0);
|
||||
function::Pad(tensor, &(im_tensors[i]),
|
||||
{0, 0, max_hw[0] - (*images)[i].Height(),
|
||||
max_hw[1] - (*images)[i].Width()},
|
||||
0);
|
||||
} else {
|
||||
// No need pad
|
||||
(*images)[i].ShareWithTensor(&(im_tensors[i]));
|
||||
|
@@ -52,6 +52,11 @@ class PaddleDetPostprocessor:
|
||||
"""
|
||||
return self._postprocessor.run(runtime_results)
|
||||
|
||||
def apply_decode_and_nms(self):
|
||||
"""This function will enable decode and nms in postprocess step.
|
||||
"""
|
||||
return self._postprocessor.apply_decode_and_nms()
|
||||
|
||||
|
||||
class PPYOLOE(FastDeployModel):
|
||||
def __init__(self,
|
||||
@@ -70,7 +75,6 @@ class PPYOLOE(FastDeployModel):
|
||||
"""
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == ModelFormat.PADDLE, "PPYOLOE model only support model format of ModelFormat.Paddle now."
|
||||
self._model = C.vision.detection.PPYOLOE(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
@@ -179,7 +183,6 @@ class PicoDet(PPYOLOE):
|
||||
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == ModelFormat.PADDLE, "PicoDet model only support model format of ModelFormat.Paddle now."
|
||||
self._model = C.vision.detection.PicoDet(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
|
@@ -1,44 +0,0 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from typing import Union, List
|
||||
import logging
|
||||
from .... import FastDeployModel, ModelFormat
|
||||
from .... import c_lib_wrap as C
|
||||
from .. import PPYOLOE
|
||||
|
||||
|
||||
class RKPicoDet(PPYOLOE):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file,
|
||||
config_file,
|
||||
runtime_option=None,
|
||||
model_format=ModelFormat.RKNN):
|
||||
"""Load a PicoDet model exported by PaddleDetection.
|
||||
|
||||
:param model_file: (str)Path of model file, e.g picodet/model.pdmodel
|
||||
:param params_file: (str)Path of parameters file, e.g picodet/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
|
||||
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
|
||||
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
|
||||
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
|
||||
"""
|
||||
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == ModelFormat.RKNN, "RKPicoDet model only support model format of ModelFormat.RKNN now."
|
||||
self._model = C.vision.detection.RKPicoDet(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "RKPicoDet model initialize failed."
|
@@ -4,3 +4,4 @@ tqdm
|
||||
numpy
|
||||
opencv-python
|
||||
fd-auto-compress>=0.0.1
|
||||
pyyaml
|
||||
|
@@ -61,7 +61,8 @@ setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND",
|
||||
"OFF")
|
||||
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
|
||||
"OFF")
|
||||
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF")
|
||||
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND",
|
||||
"OFF")
|
||||
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
|
||||
setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF")
|
||||
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
|
||||
@@ -71,13 +72,15 @@ setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")
|
||||
setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF")
|
||||
setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF")
|
||||
setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
|
||||
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda")
|
||||
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY",
|
||||
"/usr/local/cuda")
|
||||
setup_configs["LIBRARY_NAME"] = PACKAGE_NAME
|
||||
setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main"
|
||||
setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "")
|
||||
setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "")
|
||||
|
||||
setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "")
|
||||
if setup_configs["RKNN2_TARGET_SOC"] != "":
|
||||
REQUIRED_PACKAGES = REQUIRED_PACKAGES.replace("opencv-python", "")
|
||||
|
||||
if setup_configs["WITH_GPU"] == "ON" or setup_configs[
|
||||
"BUILD_ON_JETSON"] == "ON":
|
||||
@@ -105,7 +108,8 @@ extras_require = {}
|
||||
|
||||
# Default value is set to TRUE\1 to keep the settings same as the current ones.
|
||||
# However going forward the recomemded way to is to set this to False\0
|
||||
USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
|
||||
USE_MSVC_STATIC_RUNTIME = bool(
|
||||
os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
|
||||
ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx')
|
||||
################################################################################
|
||||
# Version
|
||||
@@ -135,7 +139,8 @@ assert CMAKE, 'Could not find "cmake" executable!'
|
||||
@contextmanager
|
||||
def cd(path):
|
||||
if not os.path.isabs(path):
|
||||
raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))
|
||||
raise RuntimeError('Can only cd to absolute path, got: {}'.format(
|
||||
path))
|
||||
orig_path = os.getcwd()
|
||||
os.chdir(path)
|
||||
try:
|
||||
|
@@ -1,7 +1,5 @@
|
||||
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
|
||||
output_folder: ./picodet_s_416_coco_lcnet
|
||||
target_platform: RK3568
|
||||
normalize:
|
||||
mean: [[0.485,0.456,0.406]]
|
||||
std: [[0.229,0.224,0.225]]
|
||||
outputs: ['tmp_16','p2o.Concat.9']
|
||||
normalize: None
|
||||
outputs: ['tmp_17','p2o.Concat.9']
|
||||
|
@@ -1,5 +0,0 @@
|
||||
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
|
||||
output_folder: ./picodet_s_416_coco_npu
|
||||
target_platform: RK3568
|
||||
normalize: None
|
||||
outputs: ['tmp_16','p2o.Concat.17']
|
@@ -1,7 +1,5 @@
|
||||
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
|
||||
output_folder: ./picodet_s_416_coco_lcnet
|
||||
target_platform: RK3588
|
||||
normalize:
|
||||
mean: [[0.485,0.456,0.406]]
|
||||
std: [[0.229,0.224,0.225]]
|
||||
normalize: None
|
||||
outputs: ['tmp_16','p2o.Concat.9']
|
||||
|
@@ -1,5 +0,0 @@
|
||||
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
|
||||
output_folder: ./picodet_s_416_coco_npu
|
||||
target_platform: RK3588
|
||||
normalize: None
|
||||
outputs: ['tmp_16','p2o.Concat.17']
|
Reference in New Issue
Block a user