diff --git a/docs/cn/faq/rknpu2/rknpu2.md b/docs/cn/faq/rknpu2/rknpu2.md index c488f311c..0e1f920d6 100644 --- a/docs/cn/faq/rknpu2/rknpu2.md +++ b/docs/cn/faq/rknpu2/rknpu2.md @@ -14,19 +14,18 @@ ONNX模型不能直接调用RK芯片中的NPU进行运算,需要把ONNX模型 | 任务场景 | 模型 | 模型版本(表示已经测试的版本) | ARM CPU/RKNN速度(ms) | |------------------|-------------------|-------------------------------|--------------------| -| Detection | Picodet | Picodet-s | 599/136 | +| Detection | Picodet | Picodet-s | 162/112 | | Segmentation | Unet | Unet-cityscapes | -/- | -| Segmentation | PP-LiteSeg | PP_LiteSeg_T_STDC1_cityscapes | 6634/5598 | -| Segmentation | PP-HumanSegV2Lite | portrait | 456/266 | -| Segmentation | PP-HumanSegV2Lite | human | 496/256 | -| Face Detection | SCRFD | SCRFD-2.5G-kps-640 | 963/142 | +| Segmentation | PP-LiteSeg | PP_LiteSeg_T_STDC1_cityscapes | -/- | +| Segmentation | PP-HumanSegV2Lite | portrait | 53/50 | +| Segmentation | PP-HumanSegV2Lite | human | 53/50 | +| Face Detection | SCRFD | SCRFD-2.5G-kps-640 | 112/108 | ## TODO 以下为TODO计划,表示还正在准备支持,但是还存在问题或还可以改进的模型。 | 任务场景 | 模型 | 模型版本(表示已经测试的版本) | ARM CPU/RKNN速度(ms) | |------------------|---------|---------------------|--------------------| -| Detection | Picodet | Picodet-s(int8) | -/- | | Detection | PPYOLOE | PPYOLOE(int8) | -/- | | Detection | YOLOv5 | YOLOv5-s_v6.2(int8) | -/- | | Face Recognition | ArcFace | ArcFace_r18 | 600/3 | diff --git a/examples/vision/detection/paddledetection/rknpu2/README.md b/examples/vision/detection/paddledetection/rknpu2/README.md index 98f1ada10..d242cf339 100644 --- a/examples/vision/detection/paddledetection/rknpu2/README.md +++ b/examples/vision/detection/paddledetection/rknpu2/README.md @@ -45,8 +45,8 @@ model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx output_folder: ./picodet_s_416_coco_lcnet target_platform: RK3568 normalize: - mean: [[0.485,0.456,0.406],[0,0,0]] - std: [[0.229,0.224,0.225],[0.003921,0.003921]] + mean: [[0.485,0.456,0.406]] + std: [[0.229,0.224,0.225]] outputs: ['tmp_17','p2o.Concat.9'] ``` diff --git a/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc b/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc index 12f405b52..8535aa338 100644 --- a/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc +++ b/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc @@ -15,26 +15,39 @@ #include #include "fastdeploy/vision.h" #include -double __get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); } -void InferPicodet(const std::string& model_dir, const std::string& image_file); -int main(int argc, char* argv[]) { - if (argc < 3) { - std::cout - << "Usage: infer_demo path/to/model_dir path/to/image run_option, " - "e.g ./infer_model ./picodet_model_dir ./test.jpeg" - << std::endl; - return -1; +void ONNXInfer(const std::string& model_dir, const std::string& image_file) { + std::string model_file = model_dir + "/picodet_s_416_coco_lcnet.onnx"; + std::string params_file; + std::string config_file = model_dir + "/deploy.yaml"; + auto option = fastdeploy::RuntimeOption(); + option.UseCpu(); + auto format = fastdeploy::ModelFormat::ONNX; + + auto model = fastdeploy::vision::detection::PicoDet( + model_file, params_file, config_file,option,format); + model.GetPostprocessor().ApplyDecodeAndNMS(); + + fastdeploy::TimeCounter tc; + tc.Start(); + auto im = cv::imread(image_file); + fastdeploy::vision::DetectionResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; } + auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5); + tc.End(); + tc.PrintInfo("PPDet in ONNX"); - InferPicodet(argv[1], argv[2]); - - return 0; + cv::imwrite("infer_onnx.jpg", vis_im); + std::cout + << "Visualized result saved in ./infer_onnx.jpg" + << std::endl; } -void InferPicodet(const std::string& model_dir, const std::string& image_file) { - struct timeval start_time, stop_time; - auto model_file = model_dir + "/picodet_s_416_coco_lcnet_rk3568.rknn"; +void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + "/picodet_s_416_coco_lcnet_rk3588.rknn"; auto params_file = ""; auto config_file = model_dir + "/infer_cfg.yml"; @@ -51,16 +64,31 @@ void InferPicodet(const std::string& model_dir, const std::string& image_file) { auto im = cv::imread(image_file); fastdeploy::vision::DetectionResult res; - gettimeofday(&start_time, NULL); + fastdeploy::TimeCounter tc; + tc.Start(); if (!model.Predict(&im, &res)) { std::cerr << "Failed to predict." << std::endl; return; } - gettimeofday(&stop_time, NULL); - printf("infer use %f ms\n", (__get_us(stop_time) - __get_us(start_time)) / 1000); + tc.End(); + tc.PrintInfo("PPDet in RKNPU2"); std::cout << res.Str() << std::endl; auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5); - cv::imwrite("picodet_result.jpg", vis_im); - std::cout << "Visualized result saved in ./picodet_result.jpg" << std::endl; -} \ No newline at end of file + cv::imwrite("infer_rknpu2.jpg", vis_im); + std::cout << "Visualized result saved in ./infer_rknpu2.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./picodet_model_dir ./test.jpeg" + << std::endl; + return -1; + } + RKNPU2Infer(argv[1], argv[2]); +//ONNXInfer(argv[1], argv[2]); + return 0; +} + diff --git a/examples/vision/facedet/scrfd/rknpu2/cpp/infer.cc b/examples/vision/facedet/scrfd/rknpu2/cpp/infer.cc index a01f1b184..4ac3496f5 100644 --- a/examples/vision/facedet/scrfd/rknpu2/cpp/infer.cc +++ b/examples/vision/facedet/scrfd/rknpu2/cpp/infer.cc @@ -2,50 +2,13 @@ #include #include "fastdeploy/vision.h" -void InferScrfd(const std::string& device = "cpu"); - -int main() { - InferScrfd("npu"); - return 0; -} - -fastdeploy::RuntimeOption GetOption(const std::string& device) { - auto option = fastdeploy::RuntimeOption(); - if (device == "npu") { - option.UseRKNPU2(); - } else { - option.UseCpu(); - } - return option; -} - -fastdeploy::ModelFormat GetFormat(const std::string& device) { - auto format = fastdeploy::ModelFormat::ONNX; - if (device == "npu") { - format = fastdeploy::ModelFormat::RKNN; - } else { - format = fastdeploy::ModelFormat::ONNX; - } - return format; -} - -std::string GetModelPath(std::string& model_path, const std::string& device) { - if (device == "npu") { - model_path += "rknn"; - } else { - model_path += "onnx"; - } - return model_path; -} - -void InferScrfd(const std::string& device) { - std::string model_file = - "./model/scrfd_500m_bnkps_shape640x640_rk3588."; +void ONNXInfer(const std::string& model_dir, const std::string& image_file) { + std::string model_file = model_dir + "/scrfd_500m_bnkps_shape640x640.onnx"; std::string params_file; + auto option = fastdeploy::RuntimeOption(); + option.UseCpu(); + auto format = fastdeploy::ModelFormat::ONNX; - fastdeploy::RuntimeOption option = GetOption(device); - fastdeploy::ModelFormat format = GetFormat(device); - model_file = GetModelPath(model_file, device); auto model = fastdeploy::vision::facedet::SCRFD( model_file, params_file, option, format); @@ -53,27 +16,68 @@ void InferScrfd(const std::string& device) { std::cerr << "Failed to initialize." << std::endl; return; } - auto image_file = - "./images/test_lite_face_detector_3.jpg"; + + fastdeploy::TimeCounter tc; + tc.Start(); auto im = cv::imread(image_file); - - if (device == "npu") { - model.DisableNormalizeAndPermute(); - } - fastdeploy::vision::FaceDetectionResult res; - clock_t start = clock(); if (!model.Predict(&im, &res)) { std::cerr << "Failed to predict." << std::endl; return; } - clock_t end = clock(); - auto dur = static_cast(end - start); - printf("InferScrfd use time:%f\n", - (dur / CLOCKS_PER_SEC)); - - std::cout << res.Str() << std::endl; auto vis_im = fastdeploy::vision::Visualize::VisFaceDetection(im, res); - cv::imwrite("vis_result.jpg", vis_im); - std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; + tc.End(); + tc.PrintInfo("SCRFD in ONNX"); + + cv::imwrite("infer_onnx.jpg", vis_im); + std::cout + << "Visualized result saved in ./infer_onnx.jpg" + << std::endl; +} + +void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) { + std::string model_file = model_dir + "/scrfd_500m_bnkps_shape640x640_rk3588.rknn"; + std::string params_file; + auto option = fastdeploy::RuntimeOption(); + option.UseRKNPU2(); + auto format = fastdeploy::ModelFormat::RKNN; + + auto model = fastdeploy::vision::facedet::SCRFD(model_file, params_file, option, format); + + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + model.DisableNormalizeAndPermute(); + + fastdeploy::TimeCounter tc; + tc.Start(); + auto im = cv::imread(image_file); + fastdeploy::vision::FaceDetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + auto vis_im = fastdeploy::vision::Visualize::VisFaceDetection(im, res); + tc.End(); + tc.PrintInfo("SCRFD in RKNN"); + + cv::imwrite("infer_rknn.jpg", vis_im); + std::cout + << "Visualized result saved in ./infer_rknn.jpg" + << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./picodet_model_dir ./test.jpeg" + << std::endl; + return -1; + } + + RKNPU2Infer(argv[1], argv[2]); + ONNXInfer(argv[1], argv[2]); + return 0; } \ No newline at end of file diff --git a/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc b/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc index bfc108d05..4e02ae32e 100644 --- a/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc +++ b/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc @@ -15,83 +15,84 @@ #include #include "fastdeploy/vision.h" -void InferHumanPPHumansegv2Lite(const std::string& device = "cpu"); - -int main() { - InferHumanPPHumansegv2Lite("npu"); - return 0; -} - -fastdeploy::RuntimeOption GetOption(const std::string& device) { - auto option = fastdeploy::RuntimeOption(); - if (device == "npu") { - option.UseRKNPU2(); - } else { - option.UseCpu(); - } - return option; -} - -fastdeploy::ModelFormat GetFormat(const std::string& device) { - auto format = fastdeploy::ModelFormat::ONNX; - if (device == "npu") { - format = fastdeploy::ModelFormat::RKNN; - } else { - format = fastdeploy::ModelFormat::ONNX; - } - return format; -} - -std::string GetModelPath(std::string& model_path, const std::string& device) { - if (device == "npu") { - model_path += "rknn"; - } else { - model_path += "onnx"; - } - return model_path; -} - -void InferHumanPPHumansegv2Lite(const std::string& device) { - std::string model_file = - "./model/Portrait_PP_HumanSegV2_Lite_256x144_infer/" - "Portrait_PP_HumanSegV2_Lite_256x144_infer_rk3588."; +void ONNXInfer(const std::string& model_dir, const std::string& image_file) { + std::string model_file = model_dir + "/Portrait_PP_HumanSegV2_Lite_256x144_infer.onnx"; std::string params_file; - std::string config_file = - "./model/Portrait_PP_HumanSegV2_Lite_256x144_infer/deploy.yaml"; + std::string config_file = model_dir + "/deploy.yaml"; + auto option = fastdeploy::RuntimeOption(); + option.UseCpu(); + auto format = fastdeploy::ModelFormat::ONNX; - fastdeploy::RuntimeOption option = GetOption(device); - fastdeploy::ModelFormat format = GetFormat(device); - model_file = GetModelPath(model_file, device); auto model = fastdeploy::vision::segmentation::PaddleSegModel( model_file, params_file, config_file, option, format); - if (!model.Initialized()) { std::cerr << "Failed to initialize." << std::endl; return; } - auto image_file = - "./images/portrait_heng.jpg"; + + fastdeploy::TimeCounter tc; + tc.Start(); auto im = cv::imread(image_file); - - if (device == "npu") { - model.GetPreprocessor().DisableNormalizeAndPermute(); - } - fastdeploy::vision::SegmentationResult res; - clock_t start = clock(); if (!model.Predict(im, &res)) { std::cerr << "Failed to predict." << std::endl; return; } - clock_t end = clock(); - auto dur = (double)(end - start); - printf("infer_human_pp_humansegv2_lite_npu use time:%f\n", - (dur / CLOCKS_PER_SEC)); - - std::cout << res.Str() << std::endl; auto vis_im = fastdeploy::vision::VisSegmentation(im, res); - cv::imwrite("human_pp_humansegv2_lite_npu_result.jpg", vis_im); + tc.End(); + tc.PrintInfo("PPSeg in ONNX"); + + cv::imwrite("infer_onnx.jpg", vis_im); std::cout - << "Visualized result saved in ./human_pp_humansegv2_lite_npu_result.jpg" + << "Visualized result saved in ./infer_onnx.jpg" << std::endl; -} \ No newline at end of file +} + +void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) { + std::string model_file = model_dir + "/Portrait_PP_HumanSegV2_Lite_256x144_infer_rk3588.rknn"; + std::string params_file; + std::string config_file = model_dir + "/deploy.yaml"; + auto option = fastdeploy::RuntimeOption(); + option.UseRKNPU2(); + auto format = fastdeploy::ModelFormat::RKNN; + + auto model = fastdeploy::vision::segmentation::PaddleSegModel( + model_file, params_file, config_file, option, format); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + model.GetPreprocessor().DisableNormalizeAndPermute(); + + fastdeploy::TimeCounter tc; + tc.Start(); + auto im = cv::imread(image_file); + fastdeploy::vision::SegmentationResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + auto vis_im = fastdeploy::vision::VisSegmentation(im, res); + tc.End(); + tc.PrintInfo("PPSeg in RKNPU2"); + + cv::imwrite("infer_rknn.jpg", vis_im); + std::cout + << "Visualized result saved in ./infer_rknn.jpg" + << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./picodet_model_dir ./test.jpeg" + << std::endl; + return -1; + } + + RKNPU2Infer(argv[1], argv[2]); + ONNXInfer(argv[1], argv[2]); + return 0; +} + diff --git a/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py b/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py index d7239eb42..4168d591d 100644 --- a/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py +++ b/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py @@ -49,7 +49,7 @@ model = fd.vision.segmentation.PaddleSegModel( runtime_option=runtime_option, model_format=fd.ModelFormat.RKNN) -model.disable_normalize_and_permute() +model.preprocessor.disable_normalize_and_permute() # 预测图片分割结果 im = cv2.imread(args.image) diff --git a/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc b/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc index b577c2791..8046fd87a 100644 --- a/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc +++ b/fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc @@ -345,6 +345,9 @@ bool RKNPU2Backend::Infer(std::vector& inputs, FDERROR << "rknn_create_mem output_mems_ error." << std::endl; return false; } + if(output_attrs_[i].type == RKNN_TENSOR_FLOAT16){ + output_attrs_[i].type = RKNN_TENSOR_FLOAT32; + } // default output type is depend on model, this requires float32 to compute top5 ret = rknn_set_io_mem(ctx, output_mems_[i], &output_attrs_[i]); // set output memory and attribute diff --git a/fastdeploy/vision/detection/contrib/rknpu2/model.h b/fastdeploy/vision/detection/contrib/rknpu2/model.h index 9a0fd423d..8fe036cb5 100644 --- a/fastdeploy/vision/detection/contrib/rknpu2/model.h +++ b/fastdeploy/vision/detection/contrib/rknpu2/model.h @@ -35,7 +35,10 @@ class FASTDEPLOY_DECL RKYOLOV5 : public RKYOLO { valid_cpu_backends = {}; valid_gpu_backends = {}; valid_rknpu_backends = {Backend::RKNPU2}; - GetPostprocessor().SetModelType(ModelType::RKYOLOV5); + std::vector anchors = {10, 13, 16, 30, 33, 23, 30, 61, 62, + 45, 59, 119, 116, 90, 156, 198, 373, 326}; + int anchor_per_branch_ = 3; + GetPostprocessor().SetAnchor(anchors, anchor_per_branch_); } virtual std::string ModelName() const { return "RKYOLOV5"; } @@ -58,7 +61,10 @@ class FASTDEPLOY_DECL RKYOLOV7 : public RKYOLO { valid_cpu_backends = {}; valid_gpu_backends = {}; valid_rknpu_backends = {Backend::RKNPU2}; - GetPostprocessor().SetModelType(ModelType::RKYOLOV7); + std::vector anchors = {12, 16, 19, 36, 40, 28, 36, 75, 76, + 55, 72, 146, 142, 110, 192, 243, 459, 401}; + int anchor_per_branch_ = 3; + GetPostprocessor().SetAnchor(anchors, anchor_per_branch_); } virtual std::string ModelName() const { return "RKYOLOV7"; } @@ -81,7 +87,10 @@ class FASTDEPLOY_DECL RKYOLOX : public RKYOLO { valid_cpu_backends = {}; valid_gpu_backends = {}; valid_rknpu_backends = {Backend::RKNPU2}; - GetPostprocessor().SetModelType(ModelType::RKYOLOX); + std::vector anchors = {10, 13, 16, 30, 33, 23, 30, 61, 62, + 45, 59, 119, 116, 90, 156, 198, 373, 326}; + int anchor_per_branch_ = 1; + GetPostprocessor().SetAnchor(anchors, anchor_per_branch_); } virtual std::string ModelName() const { return "RKYOLOV7"; } diff --git a/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.cc b/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.cc index bb46eff5c..bf8be2727 100755 --- a/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.cc +++ b/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.cc @@ -21,32 +21,8 @@ namespace detection { RKYOLOPostprocessor::RKYOLOPostprocessor() {} -void RKYOLOPostprocessor::SetModelType(ModelType model_type) { - model_type_ = model_type; - if (model_type == RKYOLOV5) { - anchors_ = {10, 13, 16, 30, 33, 23, 30, 61, 62, - 45, 59, 119, 116, 90, 156, 198, 373, 326}; - anchor_per_branch_ = 3; - } else if (model_type == RKYOLOX) { - anchors_ = {10, 13, 16, 30, 33, 23, 30, 61, 62, - 45, 59, 119, 116, 90, 156, 198, 373, 326}; - anchor_per_branch_ = 1; - } else if (model_type == RKYOLOV7) { - anchors_ = {12, 16, 19, 36, 40, 28, 36, 75, 76, - 55, 72, 146, 142, 110, 192, 243, 459, 401}; - anchor_per_branch_ = 3; - } else { - return; - } -} - bool RKYOLOPostprocessor::Run(const std::vector& tensors, std::vector* results) { - if (model_type_ == ModelType::UNKNOWN) { - FDERROR << "RKYOLO Only Support YOLOV5,YOLOV7,YOLOX" << std::endl; - return false; - } - results->resize(tensors[0].shape[0]); for (int num = 0; num < tensors[0].shape[0]; ++num) { int validCount = 0; @@ -62,13 +38,15 @@ bool RKYOLOPostprocessor::Run(const std::vector& tensors, int grid_h = height_ / stride; int grid_w = width_ / stride; int* anchor = &(anchors_.data()[i * 2 * anchor_per_branch_]); - if (tensors[i].dtype == FDDataType::INT8 || tensors[i].dtype == FDDataType::UINT8) { + if (tensors[i].dtype == FDDataType::INT8 || + tensors[i].dtype == FDDataType::UINT8) { auto quantization_info = tensors[i].GetQuantizationInfo(); - validCount = validCount + - ProcessInt8((int8_t*)tensors[i].Data() + skip_address, - anchor, grid_h, grid_w, stride, filterBoxes, - boxesScore, classId, conf_threshold_, - quantization_info.first, quantization_info.second[0]); + validCount = + validCount + ProcessInt8((int8_t*)tensors[i].Data() + skip_address, + anchor, grid_h, grid_w, stride, + filterBoxes, boxesScore, classId, + conf_threshold_, quantization_info.first, + quantization_info.second[0]); } else { FDERROR << "RKYOLO Only Support INT8 Model" << std::endl; } @@ -87,10 +65,13 @@ bool RKYOLOPostprocessor::Run(const std::vector& tensors, QuickSortIndiceInverse(boxesScore, 0, validCount - 1, indexArray); - if (model_type_ == RKYOLOV5 || model_type_ == RKYOLOV7) { + if (anchor_per_branch_ == 3) { NMS(validCount, filterBoxes, classId, indexArray, nms_threshold_, false); - } else if (model_type_ == RKYOLOX) { + } else if (anchor_per_branch_ == 1) { NMS(validCount, filterBoxes, classId, indexArray, nms_threshold_, true); + }else{ + FDERROR << "anchor_per_branch_ only support 3 or 1." << std::endl; + return false; } int last_count = 0; @@ -110,19 +91,18 @@ bool RKYOLOPostprocessor::Run(const std::vector& tensors, float y2 = y1 + filterBoxes[n * 4 + 3]; int id = classId[n]; (*results)[num].boxes.emplace_back(std::array{ - (float)((clamp(x1, 0, width_) - pad_hw_values_[num][1] / 2) / + (float)((Clamp(x1, 0, width_) - pad_hw_values_[num][1] / 2) / scale_[num]), - (float)((clamp(y1, 0, height_) - pad_hw_values_[num][0] / 2) / + (float)((Clamp(y1, 0, height_) - pad_hw_values_[num][0] / 2) / scale_[num]), - (float)((clamp(x2, 0, width_) - pad_hw_values_[num][1] / 2) / + (float)((Clamp(x2, 0, width_) - pad_hw_values_[num][1] / 2) / scale_[num]), - (float)((clamp(y2, 0, height_) - pad_hw_values_[num][0] / 2) / + (float)((Clamp(y2, 0, height_) - pad_hw_values_[num][0] / 2) / scale_[0])}); (*results)[num].label_ids.push_back(id); (*results)[num].scores.push_back(boxesScore[i]); last_count++; } - std::cout << "last_count" << last_count << std::endl; } return true; } @@ -159,7 +139,7 @@ int RKYOLOPostprocessor::ProcessInt8(int8_t* input, int* anchor, int grid_h, float box_conf_f32 = DeqntAffineToF32(box_confidence, zp, scale); float class_prob_f32 = DeqntAffineToF32(maxClassProbs, zp, scale); float limit_score = 0; - if (model_type_ == RKYOLOX) { + if (anchor_per_branch_ == 1) { limit_score = box_conf_f32 * class_prob_f32; } else { limit_score = class_prob_f32; @@ -167,7 +147,7 @@ int RKYOLOPostprocessor::ProcessInt8(int8_t* input, int* anchor, int grid_h, //printf("limit score: %f\n", limit_score); if (limit_score > conf_threshold_) { float box_x, box_y, box_w, box_h; - if (model_type_ == RKYOLOX) { + if (anchor_per_branch_ == 1) { box_x = DeqntAffineToF32(*in_ptr, zp, scale); box_y = DeqntAffineToF32(in_ptr[grid_len], zp, scale); box_w = DeqntAffineToF32(in_ptr[2 * grid_len], zp, scale); @@ -234,6 +214,6 @@ int RKYOLOPostprocessor::QuickSortIndiceInverse(std::vector& input, return low; } -} // namespace detection -} // namespace vision -} // namespace fastdeploy +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.h b/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.h index 0332b2efd..238c1c465 100755 --- a/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.h +++ b/fastdeploy/vision/detection/contrib/rknpu2/postprocessor.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/detection/contrib/rknpu2/utils.h" @@ -54,9 +55,6 @@ class FASTDEPLOY_DECL RKYOLOPostprocessor { /// Get nms_threshold, default 0.45 float GetNMSThreshold() const { return nms_threshold_; } - // Set model_type - void SetModelType(ModelType model_type); - // Set height and weight void SetHeightAndWeight(int& height, int& width) { height_ = height; @@ -69,10 +67,16 @@ class FASTDEPLOY_DECL RKYOLOPostprocessor { } // Set scale - void SetScale(std::vector scale) { scale_ = scale; } + void SetScale(std::vector scale) { + scale_ = scale; + } + // Set Anchor + void SetAnchor(std::vector anchors,int anchor_per_branch){ + anchors_ = anchors; + anchor_per_branch_ = anchor_per_branch; + }; private: - ModelType model_type_ = ModelType::UNKNOWN; std::vector anchors_ = {10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326}; int strides_[3] = {8, 16, 32}; diff --git a/fastdeploy/vision/detection/contrib/rknpu2/preprocessor.cc b/fastdeploy/vision/detection/contrib/rknpu2/preprocessor.cc index 29480459b..068004346 100755 --- a/fastdeploy/vision/detection/contrib/rknpu2/preprocessor.cc +++ b/fastdeploy/vision/detection/contrib/rknpu2/preprocessor.cc @@ -57,7 +57,7 @@ void RKYOLOPreprocessor::LetterBox(FDMat* mat) { resize_w = size_[0]; } - pad_hw_values_.push_back({pad_h,pad_w}); + pad_hw_values_.push_back({pad_h, pad_w}); if (std::fabs(scale - 1.0f) > 1e-06) { Resize::Run(mat, resize_w, resize_h); @@ -75,17 +75,17 @@ void RKYOLOPreprocessor::LetterBox(FDMat* mat) { bool RKYOLOPreprocessor::Preprocess(FDMat* mat, FDTensor* output) { // process after image load -// float ratio = std::min(size_[1] * 1.0f / static_cast(mat->Height()), -// size_[0] * 1.0f / static_cast(mat->Width())); -// if (std::fabs(ratio - 1.0f) > 1e-06) { -// int interp = cv::INTER_AREA; -// if (ratio > 1.0) { -// interp = cv::INTER_LINEAR; -// } -// int resize_h = int(mat->Height() * ratio); -// int resize_w = int(mat->Width() * ratio); -// Resize::Run(mat, resize_w, resize_h, -1, -1, interp); -// } + // float ratio = std::min(size_[1] * 1.0f / static_cast(mat->Height()), + // size_[0] * 1.0f / static_cast(mat->Width())); + // if (std::fabs(ratio - 1.0f) > 1e-06) { + // int interp = cv::INTER_AREA; + // if (ratio > 1.0) { + // interp = cv::INTER_LINEAR; + // } + // int resize_h = int(mat->Height() * ratio); + // int resize_w = int(mat->Width() * ratio); + // Resize::Run(mat, resize_w, resize_h, -1, -1, interp); + // } // RKYOLO's preprocess steps // 1. letterbox @@ -93,7 +93,7 @@ bool RKYOLOPreprocessor::Preprocess(FDMat* mat, FDTensor* output) { LetterBox(mat); BGR2RGB::Run(mat); mat->ShareWithTensor(output); - output->ExpandDim(0); // reshape to n, h, w, c + output->ExpandDim(0); // reshape to n, h, w, c return true; } @@ -122,6 +122,6 @@ bool RKYOLOPreprocessor::Run(std::vector* images, return true; } -} // namespace detection -} // namespace vision -} // namespace fastdeploy +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/rknpu2/rkyolo.cc b/fastdeploy/vision/detection/contrib/rknpu2/rkyolo.cc index 017cb1be3..524afcef1 100644 --- a/fastdeploy/vision/detection/contrib/rknpu2/rkyolo.cc +++ b/fastdeploy/vision/detection/contrib/rknpu2/rkyolo.cc @@ -1,3 +1,16 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "fastdeploy/vision/detection/contrib/rknpu2/rkyolo.h" namespace fastdeploy { @@ -26,12 +39,11 @@ bool RKYOLO::Initialize() { return false; } auto size = GetPreprocessor().GetSize(); - GetPostprocessor().SetHeightAndWeight(size[0],size[1]); + GetPostprocessor().SetHeightAndWeight(size[0], size[1]); return true; } -bool RKYOLO::Predict(const cv::Mat& im, - DetectionResult* result) { +bool RKYOLO::Predict(const cv::Mat& im, DetectionResult* result) { std::vector results; if (!BatchPredict({im}, &results)) { return false; @@ -50,7 +62,8 @@ bool RKYOLO::BatchPredict(const std::vector& images, } auto pad_hw_values_ = preprocessor_.GetPadHWValues(); postprocessor_.SetPadHWValues(preprocessor_.GetPadHWValues()); - std::cout << "preprocessor_ scale_ = " << preprocessor_.GetScale()[0] << std::endl; + std::cout << "preprocessor_ scale_ = " << preprocessor_.GetScale()[0] + << std::endl; postprocessor_.SetScale(preprocessor_.GetScale()); reused_input_tensors_[0].name = InputInfoOfRuntime(0).name; @@ -59,15 +72,15 @@ bool RKYOLO::BatchPredict(const std::vector& images, return false; } - if (!postprocessor_.Run(reused_output_tensors_, results)) { - FDERROR << "Failed to postprocess the inference results by runtime." << std::endl; + FDERROR << "Failed to postprocess the inference results by runtime." + << std::endl; return false; } return true; } -} // namespace detection -} // namespace vision -} // namespace fastdeploy \ No newline at end of file +} // namespace detection +} // namespace vision +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/detection/contrib/rknpu2/utils.cc b/fastdeploy/vision/detection/contrib/rknpu2/utils.cc index faac26983..4271def4a 100644 --- a/fastdeploy/vision/detection/contrib/rknpu2/utils.cc +++ b/fastdeploy/vision/detection/contrib/rknpu2/utils.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "fastdeploy/vision/detection/contrib/rknpu2/utils.h" -float clamp(float val, int min, int max) { +namespace fastdeploy { +namespace vision { +namespace detection { +float Clamp(float val, int min, int max) { return val > min ? (val < max ? val : max) : min; } @@ -35,59 +38,56 @@ float DeqntAffineToF32(int8_t qnt, int32_t zp, float scale) { return ((float)qnt - (float)zp) * scale; } -static float CalculateOverlap(float xmin0, float ymin0, float xmax0, float ymax0, float xmin1, float ymin1, float xmax1, float ymax1) -{ +static float CalculateOverlap(float xmin0, float ymin0, float xmax0, + float ymax0, float xmin1, float ymin1, + float xmax1, float ymax1) { float w = fmax(0.f, fmin(xmax0, xmax1) - fmax(xmin0, xmin1) + 1.0); float h = fmax(0.f, fmin(ymax0, ymax1) - fmax(ymin0, ymin1) + 1.0); float i = w * h; - float u = (xmax0 - xmin0 + 1.0) * (ymax0 - ymin0 + 1.0) + (xmax1 - xmin1 + 1.0) * (ymax1 - ymin1 + 1.0) - i; + float u = (xmax0 - xmin0 + 1.0) * (ymax0 - ymin0 + 1.0) + + (xmax1 - xmin1 + 1.0) * (ymax1 - ymin1 + 1.0) - i; return u <= 0.f ? 0.f : (i / u); } -int NMS(int validCount, - std::vector &outputLocations, - std::vector &class_id, - std::vector &order, - float threshold, - bool class_agnostic) -{ +int NMS(int valid_count, std::vector& output_locations, + std::vector& class_id, std::vector& order, float threshold, + bool class_agnostic) { // printf("class_agnostic: %d\n", class_agnostic); - for (int i = 0; i < validCount; ++i) - { - if (order[i] == -1) - { + for (int i = 0; i < valid_count; ++i) { + if (order[i] == -1) { continue; } int n = order[i]; - for (int j = i + 1; j < validCount; ++j) - { + for (int j = i + 1; j < valid_count; ++j) { int m = order[j]; - if (m == -1) - { + if (m == -1) { continue; } - if (!class_agnostic && class_id[n] != class_id[m]){ + if (!class_agnostic && class_id[n] != class_id[m]) { continue; } - float xmin0 = outputLocations[n * 4 + 0]; - float ymin0 = outputLocations[n * 4 + 1]; - float xmax0 = outputLocations[n * 4 + 0] + outputLocations[n * 4 + 2]; - float ymax0 = outputLocations[n * 4 + 1] + outputLocations[n * 4 + 3]; + float xmin0 = output_locations[n * 4 + 0]; + float ymin0 = output_locations[n * 4 + 1]; + float xmax0 = output_locations[n * 4 + 0] + output_locations[n * 4 + 2]; + float ymax0 = output_locations[n * 4 + 1] + output_locations[n * 4 + 3]; - float xmin1 = outputLocations[m * 4 + 0]; - float ymin1 = outputLocations[m * 4 + 1]; - float xmax1 = outputLocations[m * 4 + 0] + outputLocations[m * 4 + 2]; - float ymax1 = outputLocations[m * 4 + 1] + outputLocations[m * 4 + 3]; + float xmin1 = output_locations[m * 4 + 0]; + float ymin1 = output_locations[m * 4 + 1]; + float xmax1 = output_locations[m * 4 + 0] + output_locations[m * 4 + 2]; + float ymax1 = output_locations[m * 4 + 1] + output_locations[m * 4 + 3]; - float iou = CalculateOverlap(xmin0, ymin0, xmax0, ymax0, xmin1, ymin1, xmax1, ymax1); + float iou = CalculateOverlap(xmin0, ymin0, xmax0, ymax0, xmin1, ymin1, + xmax1, ymax1); - if (iou > threshold) - { + if (iou > threshold) { order[j] = -1; } } } return 0; -} \ No newline at end of file +} +} // namespace detection +} // namespace vision +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/detection/contrib/rknpu2/utils.h b/fastdeploy/vision/detection/contrib/rknpu2/utils.h index 4414cb8a5..23efa25c8 100644 --- a/fastdeploy/vision/detection/contrib/rknpu2/utils.h +++ b/fastdeploy/vision/detection/contrib/rknpu2/utils.h @@ -14,13 +14,20 @@ #pragma once #include #include -typedef enum { RKYOLOX = 0, RKYOLOV5, RKYOLOV7, UNKNOWN } ModelType; -float clamp(float val, int min, int max); + +namespace fastdeploy { +namespace vision { +namespace detection { +float Clamp(float val, int min, int max); float Sigmoid(float x); float UnSigmoid(float y); inline static int32_t __clip(float val, float min, float max); int8_t QntF32ToAffine(float f32, int32_t zp, float scale); float DeqntAffineToF32(int8_t qnt, int32_t zp, float scale); -int NMS(int validCount, std::vector& outputLocations, +int NMS(int valid_count, std::vector& output_locations, std::vector& class_id, std::vector& order, float threshold, bool class_agnostic); + +} // namespace detection +} // namespace vision +} // namespace fastdeploy \ No newline at end of file