diff --git a/examples/vision/detection/paddledetection/rknpu2/README.md b/examples/vision/detection/paddledetection/rknpu2/README.md index d75e09aa2..d5f339db5 100644 --- a/examples/vision/detection/paddledetection/rknpu2/README.md +++ b/examples/vision/detection/paddledetection/rknpu2/README.md @@ -13,7 +13,9 @@ RKNPU部署模型前需要将Paddle模型转换成RKNN模型,具体步骤如 ## 模型转换example -下面以Picodet-npu为例子,教大家如何转换PaddleDetection模型到RKNN模型。 +以下步骤均在Ubuntu电脑上完成,请参考配置文档完成转换模型环境配置。下面以Picodet-s为例子,教大家如何转换PaddleDetection模型到RKNN模型。 + +### 导出ONNX模型 ```bash # 下载Paddle静态图模型并解压 wget https://paddledet.bj.bcebos.com/deploy/Inference/picodet_s_416_coco_lcnet.tar @@ -26,12 +28,89 @@ paddle2onnx --model_dir picodet_s_416_coco_lcnet \ --save_file picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \ --enable_dev_version True +# 固定shape python -m paddle2onnx.optimize --input_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \ --output_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \ --input_shape_dict "{'image':[1,3,416,416]}" +``` + +### 编写模型导出配置文件 +以转化RK3568的RKNN模型为例子,我们需要编辑tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml,来转换ONNX模型到RKNN模型。 + +**修改normalize参数** + +如果你需要在NPU上执行normalize操作,请根据你的模型配置normalize参数,例如: +```yaml +model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx +output_folder: ./picodet_s_416_coco_lcnet +target_platform: RK3568 +normalize: + mean: [[0.485,0.456,0.406],[0,0,0]] + std: [[0.229,0.224,0.225],[0.003921,0.003921]] +outputs: ['tmp_17','p2o.Concat.9'] +``` + +**修改outputs参数** +由于Paddle2ONNX版本的不同,转换模型的输出节点名称也有所不同,请使用[Netron](https://netron.app),并找到以下蓝色方框标记的NonMaxSuppression节点,红色方框的节点名称即为目标名称。 + +例如,使用Netron可视化后,得到以下图片: +![](https://user-images.githubusercontent.com/58363586/202728663-4af0b843-d012-4aeb-8a66-626b7b87ca69.png) + +找到蓝色方框标记的NonMaxSuppression节点,可以看到红色方框标记的两个节点名称为tmp_17和p2o.Concat.9,因此需要修改outputs参数,修改后如下: +```yaml +model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx +output_folder: ./picodet_s_416_coco_lcnet +target_platform: RK3568 +normalize: None +outputs: ['tmp_17','p2o.Concat.9'] +``` + +### 转换模型 +```bash + # ONNX模型转RKNN模型 # 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下 -python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml +python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml +``` + +### 修改模型运行时的配置文件 + +配置文件中,我们只需要修改**Preprocess**下的**Normalize**和**Permute**. + +**删除Permute** + +RKNPU只支持NHWC的输入格式,因此需要删除Permute操作.删除后,配置文件Precess部分后如下: +```yaml +Preprocess: +- interp: 2 + keep_ratio: false + target_size: + - 416 + - 416 + type: Resize +- is_scale: true + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + type: NormalizeImage +``` + +**根据模型转换文件决定是否删除Normalize** + +RKNPU支持使用NPU进行Normalize操作,如果你在导出模型时配置了Normalize参数,请删除**Normalize**.删除后配置文件Precess部分如下: +```yaml +Preprocess: +- interp: 2 + keep_ratio: false + target_size: + - 416 + - 416 + type: Resize ``` - [Python部署](./python) diff --git a/examples/vision/detection/paddledetection/rknpu2/cpp/CMakeLists.txt b/examples/vision/detection/paddledetection/rknpu2/cpp/CMakeLists.txt index b4eca78ec..196247773 100644 --- a/examples/vision/detection/paddledetection/rknpu2/cpp/CMakeLists.txt +++ b/examples/vision/detection/paddledetection/rknpu2/cpp/CMakeLists.txt @@ -33,5 +33,5 @@ install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTIN file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*) install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib) -file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*) +file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/${RKNN2_TARGET_SOC}/lib/*) install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib) diff --git a/examples/vision/detection/paddledetection/rknpu2/cpp/README.md b/examples/vision/detection/paddledetection/rknpu2/cpp/README.md index d0b131971..c5ebbdeef 100644 --- a/examples/vision/detection/paddledetection/rknpu2/cpp/README.md +++ b/examples/vision/detection/paddledetection/rknpu2/cpp/README.md @@ -62,7 +62,7 @@ make install ```bash cd ./build/install -./rknpu_test +./infer_picodet model/picodet_s_416_coco_lcnet images/000000014439.jpg ``` diff --git a/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc b/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc index 862d28adc..12f405b52 100644 --- a/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc +++ b/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc @@ -14,73 +14,53 @@ #include #include #include "fastdeploy/vision.h" +#include +double __get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); } +void InferPicodet(const std::string& model_dir, const std::string& image_file); -void InferPicodet(const std::string& device = "cpu"); +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./picodet_model_dir ./test.jpeg" + << std::endl; + return -1; + } + + InferPicodet(argv[1], argv[2]); -int main() { - InferPicodet("npu"); return 0; } -fastdeploy::RuntimeOption GetOption(const std::string& device) { +void InferPicodet(const std::string& model_dir, const std::string& image_file) { + struct timeval start_time, stop_time; + auto model_file = model_dir + "/picodet_s_416_coco_lcnet_rk3568.rknn"; + auto params_file = ""; + auto config_file = model_dir + "/infer_cfg.yml"; + auto option = fastdeploy::RuntimeOption(); - if (device == "npu") { - option.UseRKNPU2(); - } else { - option.UseCpu(); - } - return option; -} + option.UseRKNPU2(); -fastdeploy::ModelFormat GetFormat(const std::string& device) { - auto format = fastdeploy::ModelFormat::ONNX; - if (device == "npu") { - format = fastdeploy::ModelFormat::RKNN; - } else { - format = fastdeploy::ModelFormat::ONNX; - } - return format; -} + auto format = fastdeploy::ModelFormat::RKNN; -std::string GetModelPath(std::string& model_path, const std::string& device) { - if (device == "npu") { - model_path += "rknn"; - } else { - model_path += "onnx"; - } - return model_path; -} - -void InferPicodet(const std::string &device) { - std::string model_file = "./model/picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet_rk3588."; - std::string params_file; - std::string config_file = "./model/picodet_s_416_coco_lcnet/infer_cfg.yml"; - - fastdeploy::RuntimeOption option = GetOption(device); - fastdeploy::ModelFormat format = GetFormat(device); - model_file = GetModelPath(model_file, device); - auto model = fastdeploy::vision::detection::RKPicoDet( + auto model = fastdeploy::vision::detection::PicoDet( model_file, params_file, config_file,option,format); - if (!model.Initialized()) { - std::cerr << "Failed to initialize." << std::endl; - return; - } - auto image_file = "./images/000000014439.jpg"; + model.GetPostprocessor().ApplyDecodeAndNMS(); + auto im = cv::imread(image_file); fastdeploy::vision::DetectionResult res; - clock_t start = clock(); + gettimeofday(&start_time, NULL); if (!model.Predict(&im, &res)) { std::cerr << "Failed to predict." << std::endl; return; } - clock_t end = clock(); - auto dur = static_cast(end - start); - printf("picodet_npu use time:%f\n", (dur / CLOCKS_PER_SEC)); + gettimeofday(&stop_time, NULL); + printf("infer use %f ms\n", (__get_us(stop_time) - __get_us(start_time)) / 1000); std::cout << res.Str() << std::endl; auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5); - cv::imwrite("picodet_npu_result.jpg", vis_im); - std::cout << "Visualized result saved in ./picodet_npu_result.jpg" << std::endl; + cv::imwrite("picodet_result.jpg", vis_im); + std::cout << "Visualized result saved in ./picodet_result.jpg" << std::endl; } \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/rknpu2/python/README.md b/examples/vision/detection/paddledetection/rknpu2/python/README.md index 23b13cd3b..f191063f0 100644 --- a/examples/vision/detection/paddledetection/rknpu2/python/README.md +++ b/examples/vision/detection/paddledetection/rknpu2/python/README.md @@ -15,11 +15,11 @@ cd FastDeploy/examples/vision/detection/paddledetection/rknpu2/python wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg # copy model -cp -r ./picodet_s_416_coco_npu /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python +cp -r ./picodet_s_416_coco_lcnet /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python # 推理 -python3 infer.py --model_file ./picodet_s_416_coco_npu/picodet_s_416_coco_npu_3588.rknn \ - --config_file ./picodet_s_416_coco_npu/infer_cfg.yml \ +python3 infer.py --model_file ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet_rk3568.rknn \ + --config_file ./picodet_s_416_coco_lcnet/infer_cfg.yml \ --image 000000014439.jpg ``` diff --git a/examples/vision/detection/paddledetection/rknpu2/python/infer.py b/examples/vision/detection/paddledetection/rknpu2/python/infer.py index ae2d8796a..62372105c 100644 --- a/examples/vision/detection/paddledetection/rknpu2/python/infer.py +++ b/examples/vision/detection/paddledetection/rknpu2/python/infer.py @@ -28,32 +28,32 @@ def parse_arguments(): return parser.parse_args() -def build_option(args): - option = fd.RuntimeOption() - option.use_rknpu2() - return option +if __name__ == "__main__": + args = parse_arguments() + model_file = args.model_file + params_file = "" + config_file = args.config_file -args = parse_arguments() + # 配置runtime,加载模型 + runtime_option = fd.RuntimeOption() + runtime_option.use_rknpu2() -# 配置runtime,加载模型 -runtime_option = build_option(args) -model_file = args.model_file -params_file = "" -config_file = args.config_file -model = fd.vision.detection.RKPicoDet( - model_file, - params_file, - config_file, - runtime_option=runtime_option, - model_format=fd.ModelFormat.RKNN) + model = fd.vision.detection.PicoDet( + model_file, + params_file, + config_file, + runtime_option=runtime_option, + model_format=fd.ModelFormat.RKNN) -# 预测图片分割结果 -im = cv2.imread(args.image) -result = model.predict(im.copy()) -print(result) + model.postprocessor.apply_decode_and_nms() -# 可视化结果 -vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) -cv2.imwrite("visualized_result.jpg", vis_im) -print("Visualized result save in ./visualized_result.jpg") + # 预测图片分割结果 + im = cv2.imread(args.image) + result = model.predict(im.copy()) + print(result) + + # 可视化结果 + vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) + cv2.imwrite("visualized_result.jpg", vis_im) + print("Visualized result save in ./visualized_result.jpg") diff --git a/fastdeploy/vision/detection/ppdet/base.cc b/fastdeploy/vision/detection/ppdet/base.cc index 1db42f158..0d4e0f290 100755 --- a/fastdeploy/vision/detection/ppdet/base.cc +++ b/fastdeploy/vision/detection/ppdet/base.cc @@ -6,10 +6,12 @@ namespace fastdeploy { namespace vision { namespace detection { -PPDetBase::PPDetBase(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) : preprocessor_(config_file) { +PPDetBase::PPDetBase(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const ModelFormat& model_format) + : preprocessor_(config_file) { runtime_option = custom_option; runtime_option.model_format = model_format; runtime_option.model_file = model_file; @@ -37,7 +39,8 @@ bool PPDetBase::Predict(const cv::Mat& im, DetectionResult* result) { return true; } -bool PPDetBase::BatchPredict(const std::vector& imgs, std::vector* results) { +bool PPDetBase::BatchPredict(const std::vector& imgs, + std::vector* results) { std::vector fd_images = WrapMat(imgs); if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) { FDERROR << "Failed to preprocess the input image." << std::endl; @@ -46,8 +49,13 @@ bool PPDetBase::BatchPredict(const std::vector& imgs, std::vector(reused_input_tensors_[1].Data())); + } + + // Some models don't need scale_factor and im_shape as input + while (reused_input_tensors_.size() != NumInputsOfRuntime()) { reused_input_tensors_.pop_back(); } @@ -57,12 +65,13 @@ bool PPDetBase::BatchPredict(const std::vector& imgs, std::vector* results) { +bool PaddleDetPostprocessor::ProcessMask( + const FDTensor& tensor, std::vector* results) { auto shape = tensor.Shape(); if (tensor.Dtype() != FDDataType::INT32) { - FDERROR << "The data type of out mask tensor should be INT32, but now it's " << tensor.Dtype() << std::endl; + FDERROR << "The data type of out mask tensor should be INT32, but now it's " + << tensor.Dtype() << std::endl; return false; } int64_t out_mask_h = shape[1]; @@ -46,20 +48,34 @@ bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector((*results)[i].masks[j].Data()); + int32_t* keep_mask_ptr = + reinterpret_cast((*results)[i].masks[j].Data()); for (int row = y1; row < y2; ++row) { size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t); const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1; int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w; std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col); } - index += 1; + index += 1; } } return true; } -bool PaddleDetPostprocessor::Run(const std::vector& tensors, std::vector* results) { +bool PaddleDetPostprocessor::Run(const std::vector& tensors, + std::vector* results) { + if (DecodeAndNMSApplied()) { + FDASSERT(tensors.size() == 2, + "While postprocessing with ApplyDecodeAndNMS, " + "there should be 2 outputs for this model, but now it's %zu.", + tensors.size()); + FDASSERT(tensors[0].shape.size() == 3, + "While postprocessing with ApplyDecodeAndNMS, " + "the rank of the first outputs should be 3, but now it's %zu", + tensors[0].shape.size()); + return ProcessUnDecodeResults(tensors, results); + } + if (tensors[0].shape[0] == 0) { // No detected boxes return true; @@ -69,13 +85,13 @@ bool PaddleDetPostprocessor::Run(const std::vector& tensors, std::vect std::vector num_boxes(tensors[1].shape[0]); int total_num_boxes = 0; if (tensors[1].dtype == FDDataType::INT32) { - const int32_t* data = static_cast(tensors[1].CpuData()); + const auto* data = static_cast(tensors[1].CpuData()); for (size_t i = 0; i < tensors[1].shape[0]; ++i) { num_boxes[i] = static_cast(data[i]); total_num_boxes += num_boxes[i]; } } else if (tensors[1].dtype == FDDataType::INT64) { - const int64_t* data = static_cast(tensors[1].CpuData()); + const auto* data = static_cast(tensors[1].CpuData()); for (size_t i = 0; i < tensors[1].shape[0]; ++i) { num_boxes[i] = static_cast(data[i]); } @@ -83,33 +99,37 @@ bool PaddleDetPostprocessor::Run(const std::vector& tensors, std::vect // Special case for TensorRT, it has fixed output shape of NMS // So there's invalid boxes in its' output boxes - int num_output_boxes = tensors[0].Shape()[0]; + int num_output_boxes = static_cast(tensors[0].Shape()[0]); bool contain_invalid_boxes = false; if (total_num_boxes != num_output_boxes) { if (num_output_boxes % num_boxes.size() == 0) { contain_invalid_boxes = true; } else { - FDERROR << "Cannot handle the output data for this model, unexpected situation." << std::endl; + FDERROR << "Cannot handle the output data for this model, unexpected " + "situation." + << std::endl; return false; } } // Get boxes for each input image results->resize(num_boxes.size()); - const float* box_data = static_cast(tensors[0].CpuData()); + const auto* box_data = static_cast(tensors[0].CpuData()); int offset = 0; for (size_t i = 0; i < num_boxes.size(); ++i) { const float* ptr = box_data + offset; (*results)[i].Reserve(num_boxes[i]); for (size_t j = 0; j < num_boxes[i]; ++j) { - (*results)[i].label_ids.push_back(static_cast(round(ptr[j * 6]))); - (*results)[i].scores.push_back(ptr[j * 6 + 1]); - (*results)[i].boxes.emplace_back(std::array({ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]})); + (*results)[i].label_ids.push_back( + static_cast(round(ptr[j * 6]))); + (*results)[i].scores.push_back(ptr[j * 6 + 1]); + (*results)[i].boxes.emplace_back(std::array( + {ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]})); } if (contain_invalid_boxes) { - offset += (num_output_boxes * 6 / num_boxes.size()); + offset += static_cast(num_output_boxes * 6 / num_boxes.size()); } else { - offset += (num_boxes[i] * 6); + offset += static_cast(num_boxes[i] * 6); } } @@ -119,7 +139,10 @@ bool PaddleDetPostprocessor::Run(const std::vector& tensors, std::vect } if (tensors[2].Shape()[0] != num_output_boxes) { - FDERROR << "The first dimension of output mask tensor:" << tensors[2].Shape()[0] << " is not equal to the first dimension of output boxes tensor:" << num_output_boxes << "." << std::endl; + FDERROR << "The first dimension of output mask tensor:" + << tensors[2].Shape()[0] + << " is not equal to the first dimension of output boxes tensor:" + << num_output_boxes << "." << std::endl; return false; } @@ -127,6 +150,80 @@ bool PaddleDetPostprocessor::Run(const std::vector& tensors, std::vect return ProcessMask(tensors[2], results); } -} // namespace detection -} // namespace vision -} // namespace fastdeploy +void PaddleDetPostprocessor::ApplyDecodeAndNMS() { + apply_decode_and_nms_ = true; +} + +bool PaddleDetPostprocessor::ProcessUnDecodeResults( + const std::vector& tensors, + std::vector* results) { + if (tensors.size() != 2) { + return false; + } + + int boxes_index = 0; + int scores_index = 1; + if (tensors[0].shape[1] == tensors[1].shape[2]) { + boxes_index = 0; + scores_index = 1; + } else if (tensors[0].shape[2] == tensors[1].shape[1]) { + boxes_index = 1; + scores_index = 0; + } else { + FDERROR << "The shape of boxes and scores should be [batch, boxes_num, " + "4], [batch, classes_num, boxes_num]" + << std::endl; + return false; + } + + backend::MultiClassNMS nms; + nms.background_label = -1; + nms.keep_top_k = 100; + nms.nms_eta = 1.0; + nms.nms_threshold = 0.5; + nms.score_threshold = 0.3; + nms.nms_top_k = 1000; + nms.normalized = true; + nms.Compute(static_cast(tensors[boxes_index].Data()), + static_cast(tensors[scores_index].Data()), + tensors[boxes_index].shape, tensors[scores_index].shape); + + auto num_boxes = nms.out_num_rois_data; + auto box_data = static_cast(nms.out_box_data.data()); + // Get boxes for each input image + results->resize(num_boxes.size()); + int offset = 0; + for (size_t i = 0; i < num_boxes.size(); ++i) { + const float* ptr = box_data + offset; + (*results)[i].Reserve(num_boxes[i]); + for (size_t j = 0; j < num_boxes[i]; ++j) { + (*results)[i].label_ids.push_back( + static_cast(round(ptr[j * 6]))); + (*results)[i].scores.push_back(ptr[j * 6 + 1]); + (*results)[i].boxes.emplace_back(std::array( + {ptr[j * 6 + 2] / GetScaleFactor()[1], + ptr[j * 6 + 3] / GetScaleFactor()[0], + ptr[j * 6 + 4] / GetScaleFactor()[1], + ptr[j * 6 + 5] / GetScaleFactor()[0]})); + } + offset += (num_boxes[i] * 6); + } + return true; +} + +std::vector PaddleDetPostprocessor::GetScaleFactor(){ + return scale_factor_; +} + +void PaddleDetPostprocessor::SetScaleFactor(float* scale_factor_value){ + for (int i = 0; i < scale_factor_.size(); ++i) { + scale_factor_[i] = scale_factor_value[i]; + } +} + +bool PaddleDetPostprocessor::DecodeAndNMSApplied() { + return apply_decode_and_nms_; +} +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/postprocessor.h b/fastdeploy/vision/detection/ppdet/postprocessor.h index 54be1bfd9..9a1410dc9 100644 --- a/fastdeploy/vision/detection/ppdet/postprocessor.h +++ b/fastdeploy/vision/detection/ppdet/postprocessor.h @@ -32,11 +32,28 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor { * \return true if the postprocess successed, otherwise false */ bool Run(const std::vector& tensors, - std::vector* result); + std::vector* result); + + /// Apply box decoding and nms step for the outputs for the model.This is + /// only available for those model exported without box decoding and nms. + void ApplyDecodeAndNMS(); + + bool DecodeAndNMSApplied(); + + /// Set scale_factor_ value.This is only available for those model exported + /// without box decoding and nms. + void SetScaleFactor(float* scale_factor_value); + private: // Process mask tensor for MaskRCNN bool ProcessMask(const FDTensor& tensor, - std::vector* results); + std::vector* results); + + bool apply_decode_and_nms_ = false; + std::vector scale_factor_{1.0, 1.0}; + std::vector GetScaleFactor(); + bool ProcessUnDecodeResults(const std::vector& tensors, + std::vector* results); }; } // namespace detection diff --git a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc index 252097608..46cd5e8ea 100644 --- a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc +++ b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc @@ -43,6 +43,10 @@ void BindPPDet(pybind11::module& m) { } return results; }) + .def("apply_decode_and_nms", + [](vision::detection::PaddleDetPostprocessor& self){ + self.ApplyDecodeAndNMS(); + }) .def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector& input_array) { std::vector results; std::vector inputs; diff --git a/fastdeploy/vision/detection/ppdet/preprocessor.cc b/fastdeploy/vision/detection/ppdet/preprocessor.cc index b1179d036..bb38c67ec 100644 --- a/fastdeploy/vision/detection/ppdet/preprocessor.cc +++ b/fastdeploy/vision/detection/ppdet/preprocessor.cc @@ -22,11 +22,13 @@ namespace vision { namespace detection { PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) { - FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleDetPreprocessor."); + FDASSERT(BuildPreprocessPipelineFromConfig(config_file), + "Failed to create PaddleDetPreprocessor."); initialized_ = true; } -bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) { +bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig( + const std::string& config_file) { processors_.clear(); YAML::Node cfg; try { @@ -106,8 +108,6 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& // permute = cast + HWC2CHW processors_.push_back(std::make_shared("float")); processors_.push_back(std::make_shared()); - } else { - processors_.push_back(std::make_shared()); } // Fusion will improve performance @@ -116,13 +116,15 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& return true; } -bool PaddleDetPreprocessor::Run(std::vector* images, std::vector* outputs) { +bool PaddleDetPreprocessor::Run(std::vector* images, + std::vector* outputs) { if (!initialized_) { FDERROR << "The preprocessor is not initialized." << std::endl; return false; } if (images->size() == 0) { - FDERROR << "The size of input images should be greater than 0." << std::endl; + FDERROR << "The size of input images should be greater than 0." + << std::endl; return false; } @@ -140,8 +142,9 @@ bool PaddleDetPreprocessor::Run(std::vector* images, std::vector max_hw({-1, -1}); - float* scale_factor_ptr = reinterpret_cast((*outputs)[1].MutableData()); - float* im_shape_ptr = reinterpret_cast((*outputs)[2].MutableData()); + float* scale_factor_ptr = + reinterpret_cast((*outputs)[1].MutableData()); + float* im_shape_ptr = reinterpret_cast((*outputs)[2].MutableData()); for (size_t i = 0; i < images->size(); ++i) { int origin_w = (*images)[i].Width(); int origin_h = (*images)[i].Height(); @@ -149,7 +152,8 @@ bool PaddleDetPreprocessor::Run(std::vector* images, std::vectorName() << "." << std::endl; + FDERROR << "Failed to processs image:" << i << " in " + << processors_[i]->Name() << "." << std::endl; return false; } if (processors_[j]->Name().find("Resize") != std::string::npos) { @@ -166,15 +170,18 @@ bool PaddleDetPreprocessor::Run(std::vector* images, std::vector im_tensors(images->size()); + std::vector im_tensors(images->size()); for (size_t i = 0; i < images->size(); ++i) { if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) { // if the size of image less than max_hw, pad to max_hw FDTensor tensor; (*images)[i].ShareWithTensor(&tensor); - function::Pad(tensor, &(im_tensors[i]), {0, 0, max_hw[0] - (*images)[i].Height(), max_hw[1] - (*images)[i].Width()}, 0); + function::Pad(tensor, &(im_tensors[i]), + {0, 0, max_hw[0] - (*images)[i].Height(), + max_hw[1] - (*images)[i].Width()}, + 0); } else { // No need pad (*images)[i].ShareWithTensor(&(im_tensors[i])); @@ -196,6 +203,6 @@ bool PaddleDetPreprocessor::Run(std::vector* images, std::vector=0.0.1 +pyyaml diff --git a/python/setup.py b/python/setup.py index a411a12be..991d87b03 100755 --- a/python/setup.py +++ b/python/setup.py @@ -61,7 +61,8 @@ setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND", "OFF") setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND", "OFF") -setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF") +setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", + "OFF") setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF") setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF") setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF") @@ -71,13 +72,15 @@ setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF") setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF") setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF") setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED") -setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda") +setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", + "/usr/local/cuda") setup_configs["LIBRARY_NAME"] = PACKAGE_NAME setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main" setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "") setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "") - setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "") +if setup_configs["RKNN2_TARGET_SOC"] != "": + REQUIRED_PACKAGES = REQUIRED_PACKAGES.replace("opencv-python", "") if setup_configs["WITH_GPU"] == "ON" or setup_configs[ "BUILD_ON_JETSON"] == "ON": @@ -105,7 +108,8 @@ extras_require = {} # Default value is set to TRUE\1 to keep the settings same as the current ones. # However going forward the recomemded way to is to set this to False\0 -USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1') +USE_MSVC_STATIC_RUNTIME = bool( + os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1') ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx') ################################################################################ # Version @@ -135,7 +139,8 @@ assert CMAKE, 'Could not find "cmake" executable!' @contextmanager def cd(path): if not os.path.isabs(path): - raise RuntimeError('Can only cd to absolute path, got: {}'.format(path)) + raise RuntimeError('Can only cd to absolute path, got: {}'.format( + path)) orig_path = os.getcwd() os.chdir(path) try: diff --git a/tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml b/tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml index 985489163..7bb141eca 100644 --- a/tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml +++ b/tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml @@ -1,7 +1,5 @@ model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx output_folder: ./picodet_s_416_coco_lcnet target_platform: RK3568 -normalize: - mean: [[0.485,0.456,0.406]] - std: [[0.229,0.224,0.225]] -outputs: ['tmp_16','p2o.Concat.9'] +normalize: None +outputs: ['tmp_17','p2o.Concat.9'] diff --git a/tools/rknpu2/config/RK3568/picodet_s_416_coco_npu.yaml b/tools/rknpu2/config/RK3568/picodet_s_416_coco_npu.yaml deleted file mode 100644 index 723acc8b5..000000000 --- a/tools/rknpu2/config/RK3568/picodet_s_416_coco_npu.yaml +++ /dev/null @@ -1,5 +0,0 @@ -model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx -output_folder: ./picodet_s_416_coco_npu -target_platform: RK3568 -normalize: None -outputs: ['tmp_16','p2o.Concat.17'] diff --git a/tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml b/tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml index 6110e8c0f..ba12a4be1 100644 --- a/tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml +++ b/tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml @@ -1,7 +1,5 @@ model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx output_folder: ./picodet_s_416_coco_lcnet target_platform: RK3588 -normalize: - mean: [[0.485,0.456,0.406]] - std: [[0.229,0.224,0.225]] +normalize: None outputs: ['tmp_16','p2o.Concat.9'] diff --git a/tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml b/tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml deleted file mode 100644 index 356fcfad8..000000000 --- a/tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml +++ /dev/null @@ -1,5 +0,0 @@ -model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx -output_folder: ./picodet_s_416_coco_npu -target_platform: RK3588 -normalize: None -outputs: ['tmp_16','p2o.Concat.17']