diff --git a/benchmark/cpp/CMakeLists.txt b/benchmark/cpp/CMakeLists.txt index dc0599f4f..2b54ac473 100755 --- a/benchmark/cpp/CMakeLists.txt +++ b/benchmark/cpp/CMakeLists.txt @@ -1,6 +1,5 @@ PROJECT(infer_demo C CXX) CMAKE_MINIMUM_REQUIRED (VERSION 3.10) - # specify the decompress directory of FastDeploy SDK option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") include(${FASTDEPLOY_INSTALL_DIR}/utils/gflags.cmake) @@ -39,6 +38,8 @@ add_executable(benchmark_retinanet ${PROJECT_SOURCE_DIR}/benchmark_retinanet.cc) add_executable(benchmark_tood ${PROJECT_SOURCE_DIR}/benchmark_tood.cc) add_executable(benchmark_ttfnet ${PROJECT_SOURCE_DIR}/benchmark_ttfnet.cc) add_executable(benchmark ${PROJECT_SOURCE_DIR}/benchmark.cc) +add_executable(benchmark_ppdet ${PROJECT_SOURCE_DIR}/benchmark_ppdet.cc) +add_executable(benchmark_dino ${PROJECT_SOURCE_DIR}/benchmark_dino.cc) if(UNIX AND (NOT APPLE) AND (NOT ANDROID)) target_link_libraries(benchmark_yolov5 ${FASTDEPLOY_LIBS} gflags pthread) @@ -72,6 +73,8 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID)) target_link_libraries(benchmark_tood ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_ttfnet ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags pthread) + target_link_libraries(benchmark_ppdet ${FASTDEPLOY_LIBS} gflags pthread) + target_link_libraries(benchmark_dino ${FASTDEPLOY_LIBS} gflags pthread) else() target_link_libraries(benchmark_yolov5 ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_ppyolov5 ${FASTDEPLOY_LIBS} gflags) @@ -104,6 +107,8 @@ else() target_link_libraries(benchmark_tood ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_ttfnet ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags) + target_link_libraries(benchmark_ppdet ${FASTDEPLOY_LIBS} gflags) + target_link_libraries(benchmark_dino ${FASTDEPLOY_LIBS} gflags) endif() # only for Android ADB test if(ANDROID) diff --git a/benchmark/cpp/README.md b/benchmark/cpp/README.md index b19702f43..beec3ffef 100755 --- a/benchmark/cpp/README.md +++ b/benchmark/cpp/README.md @@ -186,6 +186,10 @@ benchmark: ./benchmark -[info|diff|check|dump|mem] -model xxx -config_path xxx - ```bash ./benchmark --model ResNet50_vd_infer --config_path config/config.gpu.paddle_trt.fp16.txt --trt_shapes 1,3,224,224:1,3,224,224:1,3,224,224 --names inputs --dtypes FP32 ``` +- TensorRT/Paddle-TRT多输入示例: +```bash +./benchmark --model rtdetr_r50vd_6x_coco --trt_shapes 1,2:1,2:1,2:1,3,640,640:1,3,640,640:1,3,640,640:1,2:1,2:1,2 --names im_shape:image:scale_factor --shapes 1,2:1,3,640,640:1,2 --config_path config/config.gpu.paddle_trt.fp32.txt --dtypes FP32:FP32:FP32 +``` - 支持FD全部后端和全部模型格式:--model_file, --params_file(optional), --model_format ```bash # ONNX模型示例 @@ -206,4 +210,4 @@ benchmark: ./benchmark -[info|diff|check|dump|mem] -model xxx -config_path xxx - - 显示模型的输入信息: --info ```bash ./benchmark --info --model picodet_l_640_coco_lcnet --config_path config/config.arm.lite.fp32.txt -``` \ No newline at end of file +``` diff --git a/benchmark/cpp/benchmark_dino.cc b/benchmark/cpp/benchmark_dino.cc new file mode 100644 index 000000000..f11d08a02 --- /dev/null +++ b/benchmark/cpp/benchmark_dino.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "flags.h" +#include "macros.h" +#include "option.h" + +namespace vision = fastdeploy::vision; +namespace benchmark = fastdeploy::benchmark; + +DEFINE_bool(no_nms, false, "Whether the model contains nms."); + +int main(int argc, char* argv[]) { +#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION) + // Initialization + auto option = fastdeploy::RuntimeOption(); + if (!CreateRuntimeOption(&option, argc, argv, true)) { + return -1; + } + auto im = cv::imread(FLAGS_image); + std::unordered_map config_info; + benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path, + &config_info); + std::string model_name, params_name, config_name; + auto model_format = fastdeploy::ModelFormat::PADDLE; + if (!UpdateModelResourceName(&model_name, ¶ms_name, &config_name, + &model_format, config_info)) { + return -1; + } + auto model_file = FLAGS_model + sep + model_name; + auto params_file = FLAGS_model + sep + params_name; + auto config_file = FLAGS_model + sep + config_name; + if (config_info["backend"] == "paddle_trt") { + option.paddle_infer_option.collect_trt_shape = true; + } + if (config_info["backend"] == "paddle_trt" || + config_info["backend"] == "trt") { + option.trt_option.SetShape("im_shape",{1,2},{1,2},{1,2}); + option.trt_option.SetShape("image", {1, 3, 320,320},{1, 3, 640, 640}, + {1, 3, 1280, 1280}); + option.trt_option.SetShape("scale_factor", {1, 2}, {1, 2}, + {1, 2}); + } + auto model_ppdet = vision::detection::PaddleDetectionModel( + model_file, params_file, config_file, option, model_format); + vision::DetectionResult res; + if (config_info["precision_compare"] == "true") { + // Run once at least + model_ppdet.Predict(im, &res); + // 1. Test result diff + std::cout << "=============== Test result diff =================\n"; + // Save result to -> disk. + std::string det_result_path = "ppdet_result.txt"; + benchmark::ResultManager::SaveDetectionResult(res, det_result_path); + // Load result from <- disk. + vision::DetectionResult res_loaded; + benchmark::ResultManager::LoadDetectionResult(&res_loaded, det_result_path); + // Calculate diff between two results. + auto det_diff = + benchmark::ResultManager::CalculateDiffStatis(res, res_loaded); + std::cout << "Boxes diff: mean=" << det_diff.boxes.mean + << ", max=" << det_diff.boxes.max + << ", min=" << det_diff.boxes.min << std::endl; + std::cout << "Label_ids diff: mean=" << det_diff.labels.mean + << ", max=" << det_diff.labels.max + << ", min=" << det_diff.labels.min << std::endl; + // 2. Test tensor diff + std::cout << "=============== Test tensor diff =================\n"; + std::vector batch_res; + std::vector input_tensors, output_tensors; + std::vector imgs; + imgs.push_back(im); + std::vector fd_images = vision::WrapMat(imgs); + + model_ppdet.GetPreprocessor().Run(&fd_images, &input_tensors); + input_tensors[0].name = "image"; + input_tensors[1].name = "scale_factor"; + input_tensors[2].name = "im_shape"; + input_tensors.pop_back(); + model_ppdet.Infer(input_tensors, &output_tensors); + model_ppdet.GetPostprocessor().Run(output_tensors, &batch_res); + // Save tensor to -> disk. + auto& tensor_dump = output_tensors[0]; + std::string det_tensor_path = "ppdet_tensor.txt"; + benchmark::ResultManager::SaveFDTensor(tensor_dump, det_tensor_path); + // Load tensor from <- disk. + fastdeploy::FDTensor tensor_loaded; + benchmark::ResultManager::LoadFDTensor(&tensor_loaded, det_tensor_path); + // Calculate diff between two tensors. + auto det_tensor_diff = benchmark::ResultManager::CalculateDiffStatis( + tensor_dump, tensor_loaded); + std::cout << "Tensor diff: mean=" << det_tensor_diff.data.mean + << ", max=" << det_tensor_diff.data.max + << ", min=" << det_tensor_diff.data.min << std::endl; + } + // Run profiling + if (FLAGS_no_nms) { + model_ppdet.GetPostprocessor().ApplyNMS(); + } + BENCHMARK_MODEL(model_ppdet, model_ppdet.Predict(im, &res)) + auto vis_im = vision::VisDetection(im, res,0.3); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +#endif + + return 0; +} diff --git a/benchmark/cpp/benchmark_ppdet.cc b/benchmark/cpp/benchmark_ppdet.cc new file mode 100644 index 000000000..124414a2c --- /dev/null +++ b/benchmark/cpp/benchmark_ppdet.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "flags.h" +#include "macros.h" +#include "option.h" + +namespace vision = fastdeploy::vision; +namespace benchmark = fastdeploy::benchmark; + +DEFINE_bool(no_nms, false, "Whether the model contains nms."); + +int main(int argc, char* argv[]) { +#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION) + // Initialization + auto option = fastdeploy::RuntimeOption(); + if (!CreateRuntimeOption(&option, argc, argv, true)) { + return -1; + } + auto im = cv::imread(FLAGS_image); + std::unordered_map config_info; + benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path, + &config_info); + std::string model_name, params_name, config_name; + auto model_format = fastdeploy::ModelFormat::PADDLE; + if (!UpdateModelResourceName(&model_name, ¶ms_name, &config_name, + &model_format, config_info)) { + return -1; + } + auto model_file = FLAGS_model + sep + model_name; + auto params_file = FLAGS_model + sep + params_name; + auto config_file = FLAGS_model + sep + config_name; + if (config_info["backend"] == "paddle_trt") { + option.paddle_infer_option.collect_trt_shape = true; + } + if (config_info["backend"] == "paddle_trt" || + config_info["backend"] == "trt") { + option.trt_option.SetShape("image", {1, 3, 640, 640}, {1, 3, 640, 640}, + {1, 3, 640, 640}); + option.trt_option.SetShape("scale_factor", {1, 2}, {1, 2}, + {1, 2}); + } + auto model_ppdet = vision::detection::PaddleDetectionModel( + model_file, params_file, config_file, option, model_format); + vision::DetectionResult res; + if (config_info["precision_compare"] == "true") { + // Run once at least + model_ppdet.Predict(im, &res); + // 1. Test result diff + std::cout << "=============== Test result diff =================\n"; + // Save result to -> disk. + std::string det_result_path = "ppdet_result.txt"; + benchmark::ResultManager::SaveDetectionResult(res, det_result_path); + // Load result from <- disk. + vision::DetectionResult res_loaded; + benchmark::ResultManager::LoadDetectionResult(&res_loaded, det_result_path); + // Calculate diff between two results. + auto det_diff = + benchmark::ResultManager::CalculateDiffStatis(res, res_loaded); + std::cout << "Boxes diff: mean=" << det_diff.boxes.mean + << ", max=" << det_diff.boxes.max + << ", min=" << det_diff.boxes.min << std::endl; + std::cout << "Label_ids diff: mean=" << det_diff.labels.mean + << ", max=" << det_diff.labels.max + << ", min=" << det_diff.labels.min << std::endl; + // 2. Test tensor diff + std::cout << "=============== Test tensor diff =================\n"; + std::vector batch_res; + std::vector input_tensors, output_tensors; + std::vector imgs; + imgs.push_back(im); + std::vector fd_images = vision::WrapMat(imgs); + + model_ppdet.GetPreprocessor().Run(&fd_images, &input_tensors); + input_tensors[0].name = "image"; + input_tensors[1].name = "scale_factor"; + input_tensors[2].name = "im_shape"; + input_tensors.pop_back(); + model_ppdet.Infer(input_tensors, &output_tensors); + model_ppdet.GetPostprocessor().Run(output_tensors, &batch_res); + // Save tensor to -> disk. + auto& tensor_dump = output_tensors[0]; + std::string det_tensor_path = "ppdet_tensor.txt"; + benchmark::ResultManager::SaveFDTensor(tensor_dump, det_tensor_path); + // Load tensor from <- disk. + fastdeploy::FDTensor tensor_loaded; + benchmark::ResultManager::LoadFDTensor(&tensor_loaded, det_tensor_path); + // Calculate diff between two tensors. + auto det_tensor_diff = benchmark::ResultManager::CalculateDiffStatis( + tensor_dump, tensor_loaded); + std::cout << "Tensor diff: mean=" << det_tensor_diff.data.mean + << ", max=" << det_tensor_diff.data.max + << ", min=" << det_tensor_diff.data.min << std::endl; + } + // Run profiling + if (FLAGS_no_nms) { + model_ppdet.GetPostprocessor().ApplyNMS(); + } + BENCHMARK_MODEL(model_ppdet, model_ppdet.Predict(im, &res)) + auto vis_im = vision::VisDetection(im, res,0.3); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +#endif + + return 0; +} diff --git a/fastdeploy/runtime/option_pybind.cc b/fastdeploy/runtime/option_pybind.cc old mode 100644 new mode 100755 index b5ac9e5b9..425c581e5 --- a/fastdeploy/runtime/option_pybind.cc +++ b/fastdeploy/runtime/option_pybind.cc @@ -43,6 +43,8 @@ void BindOption(pybind11::module& m) { .def("use_sophgo", &RuntimeOption::UseSophgo) .def("use_ascend", &RuntimeOption::UseAscend) .def("use_kunlunxin", &RuntimeOption::UseKunlunXin) + .def("disable_valid_backend_check",&RuntimeOption::DisableValidBackendCheck) + .def("enable_valid_backend_check",&RuntimeOption::EnableValidBackendCheck) .def_readwrite("paddle_lite_option", &RuntimeOption::paddle_lite_option) .def_readwrite("openvino_option", &RuntimeOption::openvino_option) .def_readwrite("ort_option", &RuntimeOption::ort_option) diff --git a/fastdeploy/vision/detection/ppdet/base.cc b/fastdeploy/vision/detection/ppdet/base.cc old mode 100644 new mode 100755 index 9619e056d..c0ba9a711 --- a/fastdeploy/vision/detection/ppdet/base.cc +++ b/fastdeploy/vision/detection/ppdet/base.cc @@ -18,6 +18,7 @@ PPDetBase::PPDetBase(const std::string& model_file, runtime_option.model_format = model_format; runtime_option.model_file = model_file; runtime_option.params_file = params_file; + } std::unique_ptr PPDetBase::Clone() const { @@ -82,6 +83,22 @@ bool PPDetBase::BatchPredict(const std::vector& imgs, return true; } +bool PPDetBase::CheckArch(){ + std::vector archs = {"SOLOv2","YOLO","SSD","RetinaNet","RCNN","Face","GFL","YOLOX","YOLOv5","YOLOv6","YOLOv7","RTMDet","FCOS","TTFNet","TOOD","DETR"}; + auto arch_ = preprocessor_.GetArch(); + for (auto item : archs) { + if (arch_ == item) { + return true; + } + } + FDWARNING << "Please set model arch," + << "support value : SOLOv2, YOLO, SSD, RetinaNet, RCNN, Face , GFL , RTMDet ,"\ + <<"FCOS , TTFNet , TOOD , DETR." << std::endl; + return false; + + +} + } // namespace detection } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/base.h b/fastdeploy/vision/detection/ppdet/base.h old mode 100644 new mode 100755 index 7465d54b9..17f451aeb --- a/fastdeploy/vision/detection/ppdet/base.h +++ b/fastdeploy/vision/detection/ppdet/base.h @@ -77,6 +77,7 @@ class FASTDEPLOY_DECL PPDetBase : public FastDeployModel { virtual bool BatchPredict(const std::vector& imgs, std::vector* results); + PaddleDetPreprocessor& GetPreprocessor() { return preprocessor_; } @@ -84,6 +85,7 @@ class FASTDEPLOY_DECL PPDetBase : public FastDeployModel { PaddleDetPostprocessor& GetPostprocessor() { return postprocessor_; } + virtual bool CheckArch(); protected: virtual bool Initialize(); diff --git a/fastdeploy/vision/detection/ppdet/model.h b/fastdeploy/vision/detection/ppdet/model.h index 34ab1f0ce..6c027e387 100755 --- a/fastdeploy/vision/detection/ppdet/model.h +++ b/fastdeploy/vision/detection/ppdet/model.h @@ -440,6 +440,29 @@ class FASTDEPLOY_DECL GFL : public PPDetBase { virtual std::string ModelName() const { return "PaddleDetection/GFL"; } }; +class FASTDEPLOY_DECL PaddleDetectionModel : public PPDetBase { + public: + PaddleDetectionModel(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + CheckArch(); + valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER, + Backend::LITE}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; + valid_timvx_backends = {Backend::LITE}; + valid_kunlunxin_backends = {Backend::LITE}; + valid_rknpu_backends = {Backend::RKNPU2}; + valid_ascend_backends = {Backend::LITE}; + valid_sophgonpu_backends = {Backend::SOPHGOTPU}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PaddleDetectionModel"; } +}; + class FASTDEPLOY_DECL PPYOLOER : public PPDetBase { public: PPYOLOER(const std::string& model_file, const std::string& params_file, diff --git a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc old mode 100644 new mode 100755 index b25a1ab61..f6cf9fdc9 --- a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc +++ b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc @@ -238,7 +238,12 @@ void BindPPDet(pybind11::module& m) { m, "SOLOv2") .def(pybind11::init()); - + + pybind11::class_( + m, "PaddleDetectionModel") + .def(pybind11::init()); + pybind11::class_( m, "PPYOLOER") .def(pybind11::init