diff --git a/benchmark/cpp/benchmark.cc b/benchmark/cpp/benchmark.cc
index 4e582c92b..0f5710224 100644
--- a/benchmark/cpp/benchmark.cc
+++ b/benchmark/cpp/benchmark.cc
@@ -64,6 +64,12 @@ DEFINE_string(optimized_model_dir, "",
 DEFINE_bool(collect_trt_shape_by_device, false,
             "Optional, whether collect trt shape by device. "
             "default false.");
+DEFINE_double(custom_tensor_value, 1.0,
+              "Optional, set the value for fd tensor, "
+              "default 1.0");
+DEFINE_bool(collect_trt_shape_by_custom_tensor_value, false,
+            "Optional, whether collect trt shape by custom tensor value. "
+            "default false.");
 
 #if defined(ENABLE_BENCHMARK)
 static std::vector<int64_t> GetInt64Shape(const std::vector<int32_t>& shape) {
@@ -208,6 +214,23 @@ static void RuntimeProfiling(int argc, char* argv[]) {
     for (int i = 0; i < input_shapes.size(); ++i) {
       option.trt_option.SetShape(input_names[i], trt_shapes[i * 3],
                                  trt_shapes[i * 3 + 1], trt_shapes[i * 3 + 2]);
+      // Set custom input data for collect trt shapes
+      if (FLAGS_collect_trt_shape_by_custom_tensor_value) {
+        int min_shape_num = std::accumulate(trt_shapes[i * 3].begin(),
+                                            trt_shapes[i * 3].end(), 1,
+                                            std::multiplies<int>());
+        int opt_shape_num = std::accumulate(trt_shapes[i * 3 + 1].begin(),
+                                            trt_shapes[i * 3 + 1].end(), 1,
+                                            std::multiplies<int>());
+        int max_shape_num = std::accumulate(trt_shapes[i * 3 + 2].begin(),
+                                            trt_shapes[i * 3 + 2].end(), 1,
+                                            std::multiplies<int>());
+        std::vector<float> min_input_data(min_shape_num, FLAGS_custom_tensor_value);
+        std::vector<float> opt_input_data(opt_shape_num, FLAGS_custom_tensor_value);
+        std::vector<float> max_input_data(max_shape_num, FLAGS_custom_tensor_value);
+        option.trt_option.SetInputData(input_names[i], min_input_data,
+                                       opt_input_data, max_input_data);
+      }
     }
   }
 
@@ -232,8 +255,9 @@ static void RuntimeProfiling(int argc, char* argv[]) {
 
   // Feed inputs, all values set as 1.
   std::vector<fastdeploy::FDTensor> inputs(runtime.NumInputs());
   for (int i = 0; i < inputs.size(); ++i) {
-    fastdeploy::function::Full(1, GetInt64Shape(input_shapes[i]), &inputs[i],
-                               input_dtypes[i]);
+    fastdeploy::function::Full(
+        FLAGS_custom_tensor_value, GetInt64Shape(input_shapes[i]),
+        &inputs[i], input_dtypes[i]);
     inputs[i].name = input_names[i];
   }
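Note (not part of the patch): the two new flags work together. When --collect_trt_shape_by_custom_tensor_value is set, the TensorRT shape-collection buffers are filled with --custom_tensor_value instead of ones, and the same value now also feeds the profiling inputs through fastdeploy::function::Full. A minimal standalone sketch of that fill-and-register pattern, assuming only the TrtBackendOption::SetShape / SetInputData calls used in the hunk above; the helper name and the include path are assumptions, not something this patch adds:

// Sketch only: constant-filled calibration buffers for TRT dynamic-shape
// collection, mirroring the RuntimeProfiling() change above.
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

#include "fastdeploy/runtime.h"

static void SetCustomCollectData(fastdeploy::RuntimeOption* option,
                                 const std::string& name,
                                 const std::vector<int32_t>& min_shape,
                                 const std::vector<int32_t>& opt_shape,
                                 const std::vector<int32_t>& max_shape,
                                 float value) {
  auto numel = [](const std::vector<int32_t>& s) {
    return std::accumulate(s.begin(), s.end(), 1, std::multiplies<int>());
  };
  // Every element of the min/opt/max calibration tensors gets the same value.
  std::vector<float> min_data(numel(min_shape), value);
  std::vector<float> opt_data(numel(opt_shape), value);
  std::vector<float> max_data(numel(max_shape), value);
  option->trt_option.SetShape(name, min_shape, opt_shape, max_shape);
  option->trt_option.SetInputData(name, min_data, opt_data, max_data);
}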
diff --git a/benchmark/paddlex/CMakeLists.txt b/benchmark/paddlex/CMakeLists.txt
index d4862d622..b1439c546 100755
--- a/benchmark/paddlex/CMakeLists.txt
+++ b/benchmark/paddlex/CMakeLists.txt
@@ -21,6 +21,7 @@ add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_struc
 add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
 add_executable(benchmark_ppshituv2_rec ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_rec.cc)
 add_executable(benchmark_ppshituv2_det ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_det.cc)
+add_executable(benchmark_pp3d_centerpoint ${PROJECT_SOURCE_DIR}/benchmark_pp3d_centerpoint.cc)
 
 if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
   target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags pthread)
@@ -33,6 +34,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
   target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags pthread)
+  target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags pthread)
 else()
   target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_ppcls ${FASTDEPLOY_LIBS} gflags)
@@ -44,6 +46,7 @@ else()
   target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags)
+  target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags)
 endif()
 # only for Android ADB test
 if(ANDROID)
diff --git a/benchmark/paddlex/benchmark_pp3d_centerpoint.cc b/benchmark/paddlex/benchmark_pp3d_centerpoint.cc
new file mode 100644
index 000000000..f7c81c1b0
--- /dev/null
+++ b/benchmark/paddlex/benchmark_pp3d_centerpoint.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "flags.h"
+#include "macros.h"
+#include "option.h"
+
+namespace vision = fastdeploy::vision;
+namespace benchmark = fastdeploy::benchmark;
+
+
+static bool ReadTestPoint(const std::string &file_path,
+                          std::vector<float> &data) {
+  int with_timelag = 0;
+  int64_t num_point_dim = 5;
+  std::ifstream file_in(file_path, std::ios::in | std::ios::binary);
+
+  if (!file_in) {
+    std::cout << "Failed to read file: " << file_path << std::endl;
+    return false;
+  }
+
+  std::streampos file_size;
+  file_in.seekg(0, std::ios::end);
+  file_size = file_in.tellg();
+  file_in.seekg(0, std::ios::beg);
+
+  data.resize(file_size / sizeof(float));
+
+  file_in.read(reinterpret_cast<char *>(data.data()), file_size);
+  file_in.close();
+
+  if (file_size / sizeof(float) % num_point_dim != 0) {
+    std::cout << "Loaded file size (" << file_size
+              << ") is not evenly divisible by num_point_dim (" << num_point_dim
+              << ")\n";
+    return false;
+  }
+  size_t num_points = file_size / sizeof(float) / num_point_dim;
+  if ((!with_timelag && num_point_dim == 5) || num_point_dim > 5) {
+    for (int64_t i = 0; i < num_points; ++i) {
+      data[i * num_point_dim + 4] = 0.;
+    }
+  }
+  return true;
+}
+
+int main(int argc, char* argv[]) {
+#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
+  // Initialization
+  auto option = fastdeploy::RuntimeOption();
+  if (!CreateRuntimeOption(&option, argc, argv, true)) {
+    return -1;
+  }
+  std::string point_dir = FLAGS_image;
+  std::unordered_map<std::string, std::string> config_info;
+  benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
+                                                &config_info);
+  std::string model_name, params_name, config_name;
+  auto model_format = fastdeploy::ModelFormat::PADDLE;
+  if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
+                               &model_format, config_info, false)) {
+    return -1;
+  }
+  auto model_file = FLAGS_model + sep + model_name;
+  auto params_file = FLAGS_model + sep + params_name;
+  if (config_info["backend"] == "paddle_trt") {
+    option.paddle_infer_option.collect_trt_shape = true;
+    option.paddle_infer_option.collect_trt_shape_by_device = true;
+  }
+  if (config_info["backend"] == "paddle_trt" ||
+      config_info["backend"] == "trt") {
+    option.trt_option.SetShape("data", {34752, 5}, {34752, 5},
+                               {34752, 5});
+    std::vector<float> min_input_data;
+    ReadTestPoint(point_dir, min_input_data);
+    // use custom data to perform collect shapes.
+    option.trt_option.SetInputData("data", min_input_data);
+  }
+  auto model_centerpoint = vision::perception::Centerpoint(
+      model_file, params_file, "", option, model_format);
+  vision::PerceptionResult res;
+  // Run profiling
+  BENCHMARK_MODEL(model_centerpoint, model_centerpoint.Predict(point_dir, &res))
+  // std::cout << res.Str() << std::endl;
+#endif
+
+  return 0;
+}
diff --git a/benchmark/paddlex/get_models.sh b/benchmark/paddlex/get_models.sh
index 8d32a90eb..48f95c1b5 100755
--- a/benchmark/paddlex/get_models.sh
+++ b/benchmark/paddlex/get_models.sh
@@ -60,17 +60,24 @@ download PP-OCRv4-server-det.tgz
 
 # PP-ShiTuV2
 download PP-ShiTuv2-rec.tgz
+download PP-ShiTuv2-det.tgz
 
 # PP-StructureV2
 download PP-Structurev2-layout.tgz
 download PP-Structurev2-SLANet.tgz
 download PP-Structurev2-vi-layoutxlm.tgz
 
+# Paddle3D
+download CADNN_OCRNet-HRNetW18.tgz
+download CenterPoint-Pillars-02Voxel.tgz
+download PETRv1_v99.tgz
+download PETRv2_v99.tgz
+
 # Test resources
 # PaddleClas
 wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppcls_cls_demo.JPEG
 
-# PaddleDetection
+# PaddleDetection & ppshitu-det
 wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppdet_det_img.jpg
 wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppdet_det_img_800x800.jpg
@@ -93,3 +100,9 @@ wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/structurev2_layout_v
 wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/structurev2_vi_layoutxml_zh_val_0.jpg
 wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/table_structure_dict_ch.txt
 wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/layout_cdla_dict.txt
+
+# Paddle3D
+wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_cadnn_kitti_000780.png
+wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin
+wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_petrv1_v99_nuscenes_sample_6.tgz && tar -zxvf paddle3d_petrv1_v99_nuscenes_sample_6.tgz
+wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_petrv2_v99_nuscenes_sample_12.tgz && tar -zxvf paddle3d_petrv2_v99_nuscenes_sample_12.tgz
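Note (not part of the patch): ReadTestPoint() in the new benchmark expects a flat float32 .bin file with five values per point and zeroes the fifth column (timelag) when with_timelag is 0. A self-contained sketch for generating a synthetic point cloud to smoke-test the benchmark, with the point count chosen to match the {34752, 5} shape registered above; the file name and fill value are arbitrary:

// Sketch only: writes num_points x 5 float32 values in the layout that
// ReadTestPoint() above expects.
#include <cstdint>
#include <fstream>
#include <vector>

int main() {
  const int64_t num_points = 34752;  // matches the TRT profile shape above
  const int64_t num_point_dim = 5;   // five floats per point
  std::vector<float> points(num_points * num_point_dim, 0.5f);
  for (int64_t i = 0; i < num_points; ++i) {
    points[i * num_point_dim + 4] = 0.f;  // column 4 is the timelag the loader zeroes
  }
  std::ofstream out("synthetic_centerpoint.pcd.bin",
                    std::ios::out | std::ios::binary);
  out.write(reinterpret_cast<const char*>(points.data()),
            points.size() * sizeof(float));
  return 0;
}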
diff --git a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc
index e7b041c3d..42d4e8dc2 100644
--- a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc
+++ b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cc
@@ -25,6 +25,9 @@
 #include "paddle/extension.h"
 #endif
 
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
 std::vector<paddle::Tensor> postprocess_gpu(
     const std::vector<paddle::Tensor> &hm,
     const std::vector<paddle::Tensor> &reg,
@@ -97,11 +100,14 @@ std::vector<paddle::DataType> PostProcessInferDtype(
   return {reg_dtype[0], hm_dtype[0], paddle::DataType::INT64};
 }
 
+} // namespace paddle_custom_ops
+} // namespace fastdeploy
+
 PD_BUILD_OP(centerpoint_postprocess)
     .Inputs({paddle::Vec("HM"), paddle::Vec("REG"), paddle::Vec("HEIGHT"),
              paddle::Vec("DIM"), paddle::Vec("VEL"), paddle::Vec("ROT")})
     .Outputs({"BBOXES", "SCORES", "LABELS"})
-    .SetKernelFn(PD_KERNEL(centerpoint_postprocess))
+    .SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::centerpoint_postprocess))
    .Attrs({"voxel_size: std::vector<float>",
            "point_cloud_range: std::vector<float>",
            "post_center_range: std::vector<float>",
@@ -109,7 +115,7 @@ PD_BUILD_OP(centerpoint_postprocess)
            "score_threshold: float", "nms_iou_threshold: float",
            "nms_pre_max_size: int", "nms_post_max_size: int",
            "with_velocity: bool"})
-    .SetInferShapeFn(PD_INFER_SHAPE(PostProcessInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(PostProcessInferDtype));
+    .SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::PostProcessInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::PostProcessInferDtype));
 
 #endif // WITH_GPU
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu
index fccd103f5..86a61769e 100644
--- a/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu
+++ b/fastdeploy/runtime/backends/paddle/ops/centerpoint_postprocess_op.cu
@@ -20,6 +20,9 @@
 #include "paddle/extension.h"
 #endif
 
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
 #define CHECK_INPUT_CUDA(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
 
 #define CHECK_INPUT_BATCHSIZE(x) \
@@ -284,3 +287,6 @@ std::vector<paddle::Tensor> postprocess_gpu(
   auto out_bboxes = paddle::experimental::concat(bboxes, 0);
   return {out_bboxes, out_scores, out_labels};
 }
+
+} // namespace paddle_custom_ops
+} // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu
index 1d2cd2b44..7a2808198 100644
--- a/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu
+++ b/fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu
@@ -19,6 +19,10 @@ All Rights Reserved 2019-2020.
 */
 
 #include <stdio.h>
+
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
 #define THREADS_PER_BLOCK 16
 #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
 
@@ -315,3 +319,6 @@ void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
       num_bboxes, num_bboxes_for_nms, nms_overlap_thresh, decode_bboxes_dims,
       bboxes, index, sorted_index, mask);
 }
+
+} // namespace paddle_custom_ops
+} // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cc b/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cc
index bab67f674..7fad9a3aa 100644
--- a/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cc
+++ b/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cc
@@ -14,8 +14,6 @@
 
 #include <vector>
 
-#if defined(WITH_GPU)
-
 #if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
 #include "paddle/include/experimental/ext_all.h"
 #elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
@@ -24,6 +22,9 @@
 #include "paddle/extension.h"
 #endif
 
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
 template <typename T, typename T_int>
 bool hard_voxelize_cpu_kernel(
     const T *points, const float point_cloud_range_x_min,
@@ -147,7 +148,8 @@ std::vector<paddle::Tensor> hard_voxelize_cpu(
   return {voxels, coords, num_points_per_voxel, num_voxels};
 }
 
-#ifdef PADDLE_WITH_CUDA
+
+#if defined(PADDLE_WITH_CUDA) && defined(WITH_GPU)
 std::vector<paddle::Tensor> hard_voxelize_cuda(
     const paddle::Tensor &points, const std::vector<float> &voxel_size,
     const std::vector<float> &point_cloud_range, int max_num_points_in_voxel,
@@ -161,7 +163,7 @@ std::vector<paddle::Tensor> hard_voxelize(
   if (points.is_cpu()) {
     return hard_voxelize_cpu(points, voxel_size, point_cloud_range,
                              max_num_points_in_voxel, max_voxels);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && defined(WITH_GPU)
   } else if (points.is_gpu() || points.is_gpu_pinned()) {
     return hard_voxelize_cuda(points, voxel_size, point_cloud_range,
                               max_num_points_in_voxel, max_voxels);
@@ -188,14 +190,15 @@ std::vector<paddle::DataType> HardInferDtype(paddle::DataType points_dtype) {
           paddle::DataType::INT32};
 }
 
+} // namespace paddle_custom_ops
+} // namespace fastdeploy
+
 PD_BUILD_OP(hard_voxelize)
     .Inputs({"POINTS"})
     .Outputs({"VOXELS", "COORS", "NUM_POINTS_PER_VOXEL", "num_voxels"})
-    .SetKernelFn(PD_KERNEL(hard_voxelize))
+    .SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::hard_voxelize))
     .Attrs({"voxel_size: std::vector<float>",
             "point_cloud_range: std::vector<float>",
             "max_num_points_in_voxel: int", "max_voxels: int"})
-    .SetInferShapeFn(PD_INFER_SHAPE(HardInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(HardInferDtype));
-
-#endif // WITH_GPU
\ No newline at end of file
+    .SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::HardInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::HardInferDtype));
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cu b/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cu
index c5c0b7e1b..4b697b9e9 100644
--- a/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cu
+++ b/fastdeploy/runtime/backends/paddle/ops/voxelize_op.cu
@@ -20,6 +20,9 @@
 #include "paddle/extension.h"
 #endif
 
+namespace fastdeploy {
+namespace paddle_custom_ops {
+
 #define CHECK_INPUT_CUDA(x) \
   PD_CHECK(x.is_gpu() || x.is_gpu_pinned(), #x " must be a GPU Tensor.")
 
@@ -349,3 +352,6 @@ std::vector<paddle::Tensor> hard_voxelize_cuda(
 
   return {voxels, coords, num_points_per_voxel, num_voxels};
 }
+
+} // namespace paddle_custom_ops
+} // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/runtime/backends/paddle/option.h b/fastdeploy/runtime/backends/paddle/option.h
index 80eeac4a0..e881946fc 100755
--- a/fastdeploy/runtime/backends/paddle/option.h
+++ b/fastdeploy/runtime/backends/paddle/option.h
@@ -111,6 +111,8 @@ struct PaddleBackendOption {
   int gpu_mem_init_size = 100;
   /// The option to enable fixed size optimization for transformer model
   bool enable_fixed_size_opt = false;
+  /// min_subgraph_size for paddle-trt
+  int trt_min_subgraph_size = 3;
 
   /// Disable type of operators run on TensorRT
   void DisableTrtOps(const std::vector<std::string>& ops) {
diff --git a/fastdeploy/runtime/backends/paddle/option_pybind.cc b/fastdeploy/runtime/backends/paddle/option_pybind.cc
index fc107ed4a..ce2306cd9 100644
--- a/fastdeploy/runtime/backends/paddle/option_pybind.cc
+++ b/fastdeploy/runtime/backends/paddle/option_pybind.cc
@@ -59,6 +59,7 @@ void BindPaddleOption(pybind11::module& m) {
                      &PaddleBackendOption::is_quantize_model)
       .def_readwrite("inference_precision", &PaddleBackendOption::inference_precision)
       .def_readwrite("enable_inference_cutlass",&PaddleBackendOption::enable_inference_cutlass)
+      .def_readwrite("trt_min_subgraph_size",&PaddleBackendOption::trt_min_subgraph_size)
       .def("disable_trt_ops", &PaddleBackendOption::DisableTrtOps)
       .def("delete_pass", &PaddleBackendOption::DeletePass)
       .def("set_ipu_config", &PaddleBackendOption::SetIpuConfig);
diff --git a/fastdeploy/runtime/backends/paddle/paddle_backend.cc b/fastdeploy/runtime/backends/paddle/paddle_backend.cc
index 71dceecde..1e60b5af3 100644
--- a/fastdeploy/runtime/backends/paddle/paddle_backend.cc
+++ b/fastdeploy/runtime/backends/paddle/paddle_backend.cc
@@ -88,7 +88,8 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
       config_.SetOptimCacheDir(opt_cache_dir);
     }
     config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
-                                 option.trt_option.max_batch_size, 3,
+                                 option.trt_option.max_batch_size,
+                                 option.trt_min_subgraph_size,
                                  precision, use_static);
     SetTRTDynamicShapeToConfig(option);
     if (option_.enable_fixed_size_opt) {
@@ -225,7 +226,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model,
       use_static = true;
     }
     config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
-                                 option.trt_option.max_batch_size, 3,
+                                 option.trt_option.max_batch_size,
+                                 option.trt_min_subgraph_size,
                                  paddle_infer::PrecisionType::kInt8,
                                  use_static, false);
     SetTRTDynamicShapeToConfig(option);
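Note (not part of the patch): the new trt_min_subgraph_size field replaces the hard-coded 3 in both EnableTensorRtEngine call sites above and is exposed to Python through the def_readwrite binding. A minimal sketch of driving it from user code, assuming the existing RuntimeOption / PaddleBackendOption API; the model paths are placeholders:

// Sketch only: exercises the new trt_min_subgraph_size knob together with
// Paddle-TRT shape collection, as the CenterPoint benchmark above does.
#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu(0);
  option.UsePaddleInferBackend();
  option.paddle_infer_option.enable_trt = true;
  option.paddle_infer_option.collect_trt_shape = true;
  option.paddle_infer_option.collect_trt_shape_by_device = true;
  // New in this patch: subgraphs with fewer ops than this stay on Paddle
  // instead of being offloaded to TensorRT (previously fixed at 3).
  option.paddle_infer_option.trt_min_subgraph_size = 1;
  option.SetModelPath("model.pdmodel", "model.pdiparams");
  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) {
    return -1;
  }
  return 0;
}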